@@ -58,7 +58,8 @@ struct GlooCache {
58
58
using value_type = std::tuple<
59
59
std::shared_ptr<algorithm_type>, // algorithm
60
60
std::shared_ptr<buffer_type>, // input buffer (nullptr if not used)
61
- std::shared_ptr<buffer_type> // output buffer (nullptr if not used)
61
+ std::shared_ptr<buffer_type>, // output buffer (nullptr if not used)
62
+ std::shared_ptr<std::mutex> // mutex to make algorithms run atomically
62
63
>;
63
64
64
65
GlooCache (rank_type rank, std::shared_ptr<::gloo::transport::Device> device,
@@ -88,6 +89,10 @@ struct GlooCache {
88
89
template <CollectiveType D, typename T, typename ... Args>
89
90
value_type getAlgorithm (THDGroup group_id, const DataChannel::Group& group,
90
91
Args... args) {
92
+ // We need to protect from race when two (or more) threads are trying to
93
+ // create same algorithm simultaneously.
94
+ std::lock_guard<std::mutex> lock (_mutex);
95
+
91
96
auto key = algorithm_spec<D, T>::key (group_id, args...);
92
97
auto it = _algorithms.find (key);
93
98
if (it == _algorithms.end ()) {
@@ -116,6 +121,8 @@ struct GlooCache {
116
121
std::shared_ptr<::gloo::transport::Device> _device;
117
122
std::shared_ptr<store_type> _store;
118
123
124
+ std::mutex _mutex;
125
+
119
126
std::unordered_map<key_type, value_type> _algorithms;
120
127
};
121
128
@@ -164,7 +171,8 @@ struct algorithm_spec<CollectiveType::ALL_GATHER, T> {
164
171
reinterpret_cast <T*>(output_buffer.get ()),
165
172
count),
166
173
input_buffer,
167
- output_buffer
174
+ output_buffer,
175
+ std::make_shared<std::mutex>()
168
176
);
169
177
}
170
178
};
@@ -192,7 +200,8 @@ struct algorithm_spec<CollectiveType::ALL_REDUCE, T> {
192
200
count,
193
201
THDToGlooReduceOp<T>(op)),
194
202
input_buffer,
195
- input_buffer // we get the result in same buffer
203
+ input_buffer, // we get the result in same buffer
204
+ std::make_shared<std::mutex>()
196
205
);
197
206
}
198
207
};
@@ -220,7 +229,8 @@ struct algorithm_spec<CollectiveType::BROADCAST, T> {
220
229
count,
221
230
src_rank),
222
231
input_buffer,
223
- input_buffer // we get the result in same buffer
232
+ input_buffer, // we get the result in same buffer
233
+ std::make_shared<std::mutex>()
224
234
);
225
235
}
226
236
};
@@ -239,7 +249,8 @@ struct algorithm_spec<CollectiveType::BARRIER, T> {
239
249
return std::make_tuple (
240
250
std::make_shared<::gloo::BarrierAllToAll>(context),
241
251
nullptr ,
242
- nullptr
252
+ nullptr ,
253
+ std::make_shared<std::mutex>()
243
254
);
244
255
}
245
256
};
0 commit comments