@@ -137,16 +137,16 @@ void DataChannelGloo::allGatherT(std::vector<thpp::Tensor*>& output,
137
137
auto ret = _cache->getAlgorithm <CollectiveType::ALL_GATHER, T>(
138
138
group_id, _groups.at (group_id), tensor_bytes, all_tensor_bytes, input.numel ());
139
139
140
- std::memcpy (std::get< 1 > (ret).get (), input.data (), tensor_bytes);
140
+ std::memcpy (GlooCache::input_buffer (ret).get (), input.data (), tensor_bytes);
141
141
142
142
{
143
- std::lock_guard<std::mutex> lock (*std::get< 3 > (ret));
144
- std::get< 0 > (ret)->run ();
143
+ std::lock_guard<std::mutex> lock (*GlooCache::mutex (ret));
144
+ GlooCache::algorithm (ret)->run ();
145
145
}
146
146
147
147
for (std::size_t i = 0 ; i < output.size (); i++) {
148
148
std::memcpy (output.at (i)->data (),
149
- std::get< 2 > (ret).get () + (i * tensor_bytes),
149
+ GlooCache::output_buffer (ret).get () + (i * tensor_bytes),
150
150
tensor_bytes);
151
151
}
152
152
}
@@ -188,12 +188,12 @@ void DataChannelGloo::allReduceT(thpp::Tensor& t, THDReduceOp operation,
188
188
auto ret = _cache->getAlgorithm <CollectiveType::ALL_REDUCE, T>(
189
189
group_id, _groups.at (group_id), tensor_bytes, t.numel (), operation);
190
190
191
- std::memcpy (std::get< 1 > (ret).get (), t.data (), tensor_bytes);
191
+ std::memcpy (GlooCache::input_buffer (ret).get (), t.data (), tensor_bytes);
192
192
{
193
- std::lock_guard<std::mutex> lock (*std::get< 3 > (ret));
194
- std::get< 0 > (ret)->run ();
193
+ std::lock_guard<std::mutex> lock (*GlooCache::mutex (ret));
194
+ GlooCache::algorithm (ret)->run ();
195
195
}
196
- std::memcpy (t.data (), std::get< 2 > (ret).get (), tensor_bytes);
196
+ std::memcpy (t.data (), GlooCache::output_buffer (ret).get (), tensor_bytes);
197
197
}
198
198
199
199
void DataChannelGloo::allReduce (thpp::Tensor& data, THDReduceOp operation,
@@ -219,15 +219,15 @@ void DataChannelGloo::broadcastT(thpp::Tensor& data, rank_type src_rank,
219
219
_groups.at (group_id).mustGetGroupRank (src_rank));
220
220
221
221
if (_rank == src_rank)
222
- std::memcpy (std::get< 1 > (ret).get (), data.data (), tensor_bytes);
222
+ std::memcpy (GlooCache::input_buffer (ret).get (), data.data (), tensor_bytes);
223
223
224
224
{
225
- std::lock_guard<std::mutex> lock (*std::get< 3 > (ret));
226
- std::get< 0 > (ret)->run ();
225
+ std::lock_guard<std::mutex> lock (*GlooCache::mutex (ret));
226
+ GlooCache::algorithm (ret)->run ();
227
227
}
228
228
229
229
if (_rank != src_rank)
230
- std::memcpy (data.data (), std::get< 2 > (ret).get (), tensor_bytes);
230
+ std::memcpy (data.data (), GlooCache::output_buffer (ret).get (), tensor_bytes);
231
231
}
232
232
233
233
@@ -278,8 +278,8 @@ void DataChannelGloo::barrier(THDGroup group_id) {
278
278
auto ret = _cache->getAlgorithm <CollectiveType::BARRIER, void >(
279
279
group_id, _groups.at (group_id));
280
280
{
281
- std::lock_guard<std::mutex> lock (*std::get< 3 > (ret));
282
- std::get< 0 > (ret)->run ();
281
+ std::lock_guard<std::mutex> lock (*GlooCache::mutex (ret));
282
+ GlooCache::algorithm (ret)->run ();
283
283
}
284
284
}
285
285
0 commit comments