
Commit a17d96d

VirrageS authored and apaszke committed
Add multiple thread support for DataChannels
Previously, when using the same data channel from multiple threads, there was no guarantee against deadlocks or outright errors.
1 parent b7dcc29 commit a17d96d

File tree

5 files changed: +56 -10 lines

  torch/lib/THD/base/data_channels/DataChannelGloo.cpp
  torch/lib/THD/base/data_channels/DataChannelMPI.cpp
  torch/lib/THD/base/data_channels/DataChannelTCP.cpp
  torch/lib/THD/base/data_channels/DataChannelTCP.hpp
  torch/lib/THD/base/data_channels/GlooCache.hpp


torch/lib/THD/base/data_channels/DataChannelGloo.cpp

+17-4
@@ -138,7 +138,11 @@ void DataChannelGloo::allGatherT(std::vector<thpp::Tensor*>& output,
       group_id, _groups.at(group_id), tensor_bytes, all_tensor_bytes, input.numel());
 
   std::memcpy(std::get<1>(ret).get(), input.data(), tensor_bytes);
-  std::get<0>(ret)->run();
+
+  {
+    std::lock_guard<std::mutex> lock(*std::get<3>(ret));
+    std::get<0>(ret)->run();
+  }
 
   for (std::size_t i = 0; i < output.size(); i++) {
     std::memcpy(output.at(i)->data(),
@@ -185,7 +189,10 @@ void DataChannelGloo::allReduceT(thpp::Tensor& t, THDReduceOp operation,
       group_id, _groups.at(group_id), tensor_bytes, t.numel(), operation);
 
   std::memcpy(std::get<1>(ret).get(), t.data(), tensor_bytes);
-  std::get<0>(ret)->run();
+  {
+    std::lock_guard<std::mutex> lock(*std::get<3>(ret));
+    std::get<0>(ret)->run();
+  }
   std::memcpy(t.data(), std::get<2>(ret).get(), tensor_bytes);
 }

@@ -214,7 +221,10 @@ void DataChannelGloo::broadcastT(thpp::Tensor& data, rank_type src_rank,
   if (_rank == src_rank)
     std::memcpy(std::get<1>(ret).get(), data.data(), tensor_bytes);
 
-  std::get<0>(ret)->run();
+  {
+    std::lock_guard<std::mutex> lock(*std::get<3>(ret));
+    std::get<0>(ret)->run();
+  }
 
   if (_rank != src_rank)
     std::memcpy(data.data(), std::get<2>(ret).get(), tensor_bytes);
@@ -267,7 +277,10 @@ void DataChannelGloo::barrier(THDGroup group_id) {
   RETURN_IF_NOT_IN_GROUP
   auto ret = _cache->getAlgorithm<CollectiveType::BARRIER, void>(
       group_id, _groups.at(group_id));
-  std::get<0>(ret)->run();
+  {
+    std::lock_guard<std::mutex> lock(*std::get<3>(ret));
+    std::get<0>(ret)->run();
+  }
 }
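Note: each collective above applies the same guarded-run pattern to the tuple returned by GlooCache (algorithm, input buffer, output buffer, mutex). A minimal sketch of that pattern, assuming that tuple layout; the helper name runGuarded is illustrative and not part of the codebase:

#include <mutex>
#include <tuple>

// Sketch only: Ret is any tuple whose element 0 has run() and whose
// element 3 is a std::shared_ptr<std::mutex>, as in GlooCache::value_type.
template <typename Ret>
void runGuarded(Ret& ret) {
  // Serialize concurrent run() calls on the same cached algorithm.
  std::lock_guard<std::mutex> lock(*std::get<3>(ret));
  std::get<0>(ret)->run();
}

Because the mutex lives in the cache entry, threads running different cached algorithms do not block each other; only runs of the same algorithm are serialized.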

torch/lib/THD/base/data_channels/DataChannelMPI.cpp

+7-1
@@ -92,7 +92,13 @@ DataChannelMPI::~DataChannelMPI() {
 
 
 bool DataChannelMPI::init() {
-  MPI_Init(NULL, NULL);
+  int* provided = NULL;
+  MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, provided);
+  if (*provided != MPI_THREAD_MULTIPLE) {
+    std::cerr << "MPI implementation does not support multiple threads."
+              << "Using same data channel in multiple thread can result in"
+              << "wrong results or errors." << std::endl;
+  }
 
   int rank, num_processes;
   MPI_Comm_size(MPI_COMM_WORLD, &num_processes);
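For reference, MPI_Init_thread(argc, argv, required, provided) treats its last parameter as an output: the library writes the threading level it actually grants into the int it points to. A minimal standalone sketch of the conventional call (not code from this diff), passing the address of a local int and checking it afterwards:

#include <mpi.h>
#include <iostream>

int main(int argc, char** argv) {
  int provided = 0;
  // Request full multi-threading; MPI stores the granted level in `provided`.
  MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
  if (provided != MPI_THREAD_MULTIPLE) {
    std::cerr << "MPI implementation does not support MPI_THREAD_MULTIPLE; "
              << "using the same data channel from several threads may fail."
              << std::endl;
  }
  MPI_Finalize();
  return 0;
}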

torch/lib/THD/base/data_channels/DataChannelTCP.cpp

+14
@@ -293,6 +293,8 @@ void DataChannelTCP::allGather(std::vector<thpp::Tensor*>& output,
   * efficient also for small data (< 512 KB).
   */
 
+  std::lock_guard<std::mutex> lock(_mutex);
+
   const auto& group = _groups.at(group_id);
   rank_type group_rank;
   bool exists;
@@ -325,6 +327,8 @@ void DataChannelTCP::allGather(std::vector<thpp::Tensor*>& output,
 
 void DataChannelTCP::gather(std::vector<thpp::Tensor*>& output,
                             thpp::Tensor& input, rank_type dst_rank, THDGroup group_id) {
+  std::lock_guard<std::mutex> lock(_mutex);
+
   const auto& group = _groups.at(group_id);
   bool exists;

@@ -358,6 +362,8 @@ void DataChannelTCP::gather(std::vector<thpp::Tensor*>& output,
 void DataChannelTCP::scatter(std::vector<thpp::Tensor*>& input,
                              thpp::Tensor& output, rank_type src_rank,
                              THDGroup group_id) {
+  std::lock_guard<std::mutex> lock(_mutex);
+
   const auto& group = _groups.at(group_id);
   bool exists;

@@ -404,6 +410,8 @@ void DataChannelTCP::allReduce(thpp::Tensor& data, THDReduceOp operation,
   * > https://github.com/pmodels/mpich/blob/master/src/mpi/coll/allreduce.c
   */
 
+  std::lock_guard<std::mutex> lock(_mutex);
+
   const auto& group = _groups.at(group_id);
   rank_type group_rank;
   bool exists;
@@ -471,6 +479,8 @@ void DataChannelTCP::reduce(thpp::Tensor& data, THDReduceOp operation,
   * order and direction of communication.
   */
 
+  std::lock_guard<std::mutex> lock(_mutex);
+
   const auto& group = _groups.at(group_id);
   rank_type group_rank;
   bool exists;
@@ -518,6 +528,8 @@ void DataChannelTCP::broadcast(thpp::Tensor& data, rank_type src_rank,
   * virtual ones where `virtual_rank` for `src_rank` is 0.
   */
 
+  std::lock_guard<std::mutex> lock(_mutex);
+
   const auto& group = _groups.at(group_id);
   rank_type group_rank;
   bool exists;
@@ -644,6 +656,8 @@ void DataChannelTCP::barrier(THDGroup group_id) {
   * we do recv asynchronously (thread), send byte and then wait for recv to complete.
   */
 
+  std::lock_guard<std::mutex> lock(_mutex);
+
   const auto& group = _groups.at(group_id);
   rank_type group_rank;
   bool exists;
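Unlike the Gloo channel, DataChannelTCP takes a coarse-grained approach: a single per-channel mutex (declared in the header change below) is locked at the top of every collective, so concurrent calls on one channel object run one at a time. A stripped-down sketch of that shape, with hypothetical method content:

#include <mutex>

// Illustrative shape only; the real methods perform blocking socket I/O.
struct TcpChannelSketch {
  void allReduce(/* tensor, op, group */) {
    // Held for the whole collective: calls on this object are serialized.
    std::lock_guard<std::mutex> lock(_mutex);
    // ... send/recv work happens here while the lock is held ...
  }

 private:
  std::mutex _mutex;  // plays the same role as DataChannelTCP::_mutex
};

The cost is that two threads cannot overlap different collectives on the same TCP channel, which the per-algorithm mutexes on the Gloo side avoid.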

torch/lib/THD/base/data_channels/DataChannelTCP.hpp

+2
@@ -87,6 +87,8 @@ struct DataChannelTCP : DataChannel {
   std::vector<Process> _processes; // Other processes in network
   std::unique_ptr<struct pollfd[]> _poll_events; // Events array for `poll`
 
+  std::mutex _mutex; // General mutex for methods - to make methods run atomically.
+
   // Existing groups of processes and corresponding group ids
   std::unordered_map<THDGroup, DataChannel::Group> _groups;

torch/lib/THD/base/data_channels/GlooCache.hpp

+16-5
@@ -58,7 +58,8 @@ struct GlooCache {
   using value_type = std::tuple<
     std::shared_ptr<algorithm_type>, // algorithm
     std::shared_ptr<buffer_type>, // input buffer (nullptr if not used)
-    std::shared_ptr<buffer_type> // output buffer (nullptr if not used)
+    std::shared_ptr<buffer_type>, // output buffer (nullptr if not used)
+    std::shared_ptr<std::mutex> // mutex to make algorithms run atomically
   >;
 
   GlooCache(rank_type rank, std::shared_ptr<::gloo::transport::Device> device,
@@ -88,6 +89,10 @@ struct GlooCache {
   template<CollectiveType D, typename T, typename... Args>
   value_type getAlgorithm(THDGroup group_id, const DataChannel::Group& group,
                           Args... args) {
+    // We need to protect from race when two (or more) threads are trying to
+    // create same algorithm simultaneously.
+    std::lock_guard<std::mutex> lock(_mutex);
+
     auto key = algorithm_spec<D, T>::key(group_id, args...);
     auto it = _algorithms.find(key);
     if (it == _algorithms.end()) {
@@ -116,6 +121,8 @@ struct GlooCache {
   std::shared_ptr<::gloo::transport::Device> _device;
   std::shared_ptr<store_type> _store;
 
+  std::mutex _mutex;
+
   std::unordered_map<key_type, value_type> _algorithms;
 };

@@ -164,7 +171,8 @@ struct algorithm_spec<CollectiveType::ALL_GATHER, T> {
         reinterpret_cast<T*>(output_buffer.get()),
         count),
       input_buffer,
-      output_buffer
+      output_buffer,
+      std::make_shared<std::mutex>()
     );
   }
 };
@@ -192,7 +200,8 @@ struct algorithm_spec<CollectiveType::ALL_REDUCE, T> {
         count,
         THDToGlooReduceOp<T>(op)),
       input_buffer,
-      input_buffer // we get the result in same buffer
+      input_buffer, // we get the result in same buffer
+      std::make_shared<std::mutex>()
     );
   }
 };
@@ -220,7 +229,8 @@ struct algorithm_spec<CollectiveType::BROADCAST, T> {
         count,
         src_rank),
       input_buffer,
-      input_buffer // we get the result in same buffer
+      input_buffer, // we get the result in same buffer
+      std::make_shared<std::mutex>()
     );
   }
 };
@@ -239,7 +249,8 @@ struct algorithm_spec<CollectiveType::BARRIER, T> {
     return std::make_tuple(
       std::make_shared<::gloo::BarrierAllToAll>(context),
       nullptr,
-      nullptr
+      nullptr,
+      std::make_shared<std::mutex>()
    );
  }
 };
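The cache-level _mutex in getAlgorithm guards a different race than the per-algorithm mutexes above: without it, two threads requesting the same key could both miss the lookup and construct the algorithm twice while mutating _algorithms concurrently. A generic get-or-create-under-lock sketch with hypothetical names (not the GlooCache API):

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct AlgorithmCacheSketch {
  std::shared_ptr<int> getOrCreate(const std::string& key) {
    // The lock spans both the lookup and the insertion, so exactly one
    // thread constructs the entry for a given key and the map is never
    // mutated concurrently.
    std::lock_guard<std::mutex> lock(_mutex);
    auto it = _entries.find(key);
    if (it == _entries.end())
      it = _entries.emplace(key, std::make_shared<int>(0)).first;
    return it->second;
  }

 private:
  std::mutex _mutex;
  std::unordered_map<std::string, std::shared_ptr<int>> _entries;
};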
