
Commit c19fbd3

VirrageSapaszke authored and committed
Update comments; Add inline accessors for value_type tuple in GlooCache
1 parent a17d96d commit c19fbd3

5 files changed: +47 −21 lines


torch/lib/THD/base/data_channels/DataChannelGloo.cpp

+14 −14

@@ -137,16 +137,16 @@ void DataChannelGloo::allGatherT(std::vector<thpp::Tensor*>& output,
   auto ret = _cache->getAlgorithm<CollectiveType::ALL_GATHER, T>(
     group_id, _groups.at(group_id), tensor_bytes, all_tensor_bytes, input.numel());

-  std::memcpy(std::get<1>(ret).get(), input.data(), tensor_bytes);
+  std::memcpy(GlooCache::input_buffer(ret).get(), input.data(), tensor_bytes);

   {
-    std::lock_guard<std::mutex> lock(*std::get<3>(ret));
-    std::get<0>(ret)->run();
+    std::lock_guard<std::mutex> lock(*GlooCache::mutex(ret));
+    GlooCache::algorithm(ret)->run();
   }

   for (std::size_t i = 0; i < output.size(); i++) {
     std::memcpy(output.at(i)->data(),
-                std::get<2>(ret).get() + (i * tensor_bytes),
+                GlooCache::output_buffer(ret).get() + (i * tensor_bytes),
                 tensor_bytes);
   }
 }

@@ -188,12 +188,12 @@ void DataChannelGloo::allReduceT(thpp::Tensor& t, THDReduceOp operation,
   auto ret = _cache->getAlgorithm<CollectiveType::ALL_REDUCE, T>(
     group_id, _groups.at(group_id), tensor_bytes, t.numel(), operation);

-  std::memcpy(std::get<1>(ret).get(), t.data(), tensor_bytes);
+  std::memcpy(GlooCache::input_buffer(ret).get(), t.data(), tensor_bytes);
   {
-    std::lock_guard<std::mutex> lock(*std::get<3>(ret));
-    std::get<0>(ret)->run();
+    std::lock_guard<std::mutex> lock(*GlooCache::mutex(ret));
+    GlooCache::algorithm(ret)->run();
   }
-  std::memcpy(t.data(), std::get<2>(ret).get(), tensor_bytes);
+  std::memcpy(t.data(), GlooCache::output_buffer(ret).get(), tensor_bytes);
 }

 void DataChannelGloo::allReduce(thpp::Tensor& data, THDReduceOp operation,

@@ -219,15 +219,15 @@ void DataChannelGloo::broadcastT(thpp::Tensor& data, rank_type src_rank,
     _groups.at(group_id).mustGetGroupRank(src_rank));

   if (_rank == src_rank)
-    std::memcpy(std::get<1>(ret).get(), data.data(), tensor_bytes);
+    std::memcpy(GlooCache::input_buffer(ret).get(), data.data(), tensor_bytes);

   {
-    std::lock_guard<std::mutex> lock(*std::get<3>(ret));
-    std::get<0>(ret)->run();
+    std::lock_guard<std::mutex> lock(*GlooCache::mutex(ret));
+    GlooCache::algorithm(ret)->run();
   }

   if (_rank != src_rank)
-    std::memcpy(data.data(), std::get<2>(ret).get(), tensor_bytes);
+    std::memcpy(data.data(), GlooCache::output_buffer(ret).get(), tensor_bytes);
 }

@@ -278,8 +278,8 @@ void DataChannelGloo::barrier(THDGroup group_id) {
   auto ret = _cache->getAlgorithm<CollectiveType::BARRIER, void>(
     group_id, _groups.at(group_id));
   {
-    std::lock_guard<std::mutex> lock(*std::get<3>(ret));
-    std::get<0>(ret)->run();
+    std::lock_guard<std::mutex> lock(*GlooCache::mutex(ret));
+    GlooCache::algorithm(ret)->run();
   }
 }
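Every collective in this file now follows the same pattern: stage the input into the cached input buffer, run the cached Gloo algorithm under its per-entry mutex, then copy the result back out, all through the named accessors instead of positional std::get. A minimal sketch of that pattern as a free-standing helper (runCachedAlgorithm is hypothetical and not part of this commit; it assumes GlooCache.hpp is included and that the cached entry uses both buffers, which the barrier entry does not):

#include "GlooCache.hpp"  // assumed include path within the same directory
#include <cstddef>
#include <cstring>
#include <mutex>

// Hypothetical helper illustrating the copy-in / run-under-lock / copy-out
// pattern shared by allGatherT, allReduceT and broadcastT above.
void runCachedAlgorithm(const GlooCache::value_type& entry,
                        const void* src, void* dst, std::size_t bytes) {
  // Stage the caller's data into the algorithm's pre-registered input buffer.
  std::memcpy(GlooCache::input_buffer(entry).get(), src, bytes);
  {
    // The per-entry mutex keeps two threads from running the same cached
    // algorithm instance concurrently.
    std::lock_guard<std::mutex> lock(*GlooCache::mutex(entry));
    GlooCache::algorithm(entry)->run();
  }
  // Copy the result back out of the pre-registered output buffer.
  std::memcpy(dst, GlooCache::output_buffer(entry).get(), bytes);
}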

torch/lib/THD/base/data_channels/DataChannelMPI.cpp

+5 −5

@@ -92,11 +92,11 @@ DataChannelMPI::~DataChannelMPI() {


 bool DataChannelMPI::init() {
-  int* provided = NULL;
-  MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, provided);
-  if (*provided != MPI_THREAD_MULTIPLE) {
-    std::cerr << "MPI implementation does not support multiple threads."
-              << "Using same data channel in multiple thread can result in"
+  int provided;
+  MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &provided);
+  if (provided != MPI_THREAD_MULTIPLE) {
+    std::cerr << "WARNING: MPI implementation does not support multiple threads. "
+              << "Using same data channel in multiple thread can result in "
               << "wrong results or errors." << std::endl;
   }
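The old code asked MPI to write the granted thread level through an uninitialized (NULL) pointer and then dereferenced it; the fix passes the address of a local int instead. A standalone sketch of the corrected call outside THD (the warning text is illustrative):

#include <mpi.h>
#include <iostream>

int main(int argc, char** argv) {
  // MPI_Init_thread reports the thread level it actually granted through
  // its last argument, so the caller must pass the address of a real int.
  int provided;
  MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
  if (provided != MPI_THREAD_MULTIPLE) {
    std::cerr << "WARNING: MPI implementation does not support multiple threads."
              << std::endl;
  }
  MPI_Finalize();
  return 0;
}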

torch/lib/THD/base/data_channels/DataChannelTCP.hpp

+2 −1

@@ -87,7 +87,8 @@ struct DataChannelTCP : DataChannel {
   std::vector<Process> _processes; // Other processes in network
   std::unique_ptr<struct pollfd[]> _poll_events; // Events array for `poll`

-  std::mutex _mutex; // General mutex for methods - to make methods run atomically.
+  // General mutex for methods - to protect access to the TCP data channel.
+  std::mutex _mutex;

   // Existing groups of processes and corresponding group ids
   std::unordered_map<THDGroup, DataChannel::Group> _groups;
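The reworded comment describes the same locking scheme: each public method on the TCP channel holds _mutex for its whole duration, so concurrent callers are serialized. A hypothetical sketch of such a guarded method (someMethod is illustrative, not taken from the file):

#include <mutex>

void DataChannelTCP::someMethod() {
  // Hold the channel-wide mutex for the entire call so other threads using
  // the same DataChannelTCP cannot interleave with it.
  std::lock_guard<std::mutex> guard(_mutex);
  // ... socket I/O and updates to _processes / _groups ...
}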

torch/lib/THD/base/data_channels/GlooCache.hpp

+20 −1

@@ -59,7 +59,7 @@ struct GlooCache {
     std::shared_ptr<algorithm_type>, // algorithm
     std::shared_ptr<buffer_type>,    // input buffer (nullptr if not used)
     std::shared_ptr<buffer_type>,    // output buffer (nullptr if not used)
-    std::shared_ptr<std::mutex>      // mutex to make algorithms run atomically
+    std::shared_ptr<std::mutex>      // mutex to protect same algorithm from running concurrently
   >;

   GlooCache(rank_type rank, std::shared_ptr<::gloo::transport::Device> device,

@@ -72,6 +72,25 @@ struct GlooCache {
   GlooCache(GlooCache const&) = delete;
   void operator=(GlooCache const&) = delete;

+
+  // Accessors for value_type tuple
+  static inline std::shared_ptr<algorithm_type> algorithm(const value_type& t) {
+    return std::get<0>(t);
+  }
+
+  static inline std::shared_ptr<buffer_type> input_buffer(const value_type& t) {
+    return std::get<1>(t);
+  }
+
+  static inline std::shared_ptr<buffer_type> output_buffer(const value_type& t) {
+    return std::get<2>(t);
+  }
+
+  static inline std::shared_ptr<std::mutex> mutex(const value_type& t) {
+    return std::get<3>(t);
+  }
+
+
   std::shared_ptr<context_type> createContext(
     const DataChannel::Group& group,
     prefix_store_type& store
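The new accessors only name the tuple slots; the stored value_type is unchanged. The same idiom in isolation, on toy types rather than THD's, might look like this self-contained sketch:

#include <memory>
#include <mutex>
#include <tuple>

struct Cache {
  using value_type = std::tuple<std::shared_ptr<int>,          // payload
                                std::shared_ptr<std::mutex>>;  // guard

  // Named static accessors hide the tuple indices from call sites.
  static inline std::shared_ptr<int> payload(const value_type& t) {
    return std::get<0>(t);
  }
  static inline std::shared_ptr<std::mutex> guard(const value_type& t) {
    return std::get<1>(t);
  }
};

int main() {
  Cache::value_type entry{std::make_shared<int>(42),
                          std::make_shared<std::mutex>()};
  std::lock_guard<std::mutex> lock(*Cache::guard(entry));
  return *Cache::payload(entry) == 42 ? 0 : 1;
}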

torch/lib/THD/test/data_channel_collectives.cpp

+6 −0

@@ -28,6 +28,7 @@ constexpr int BARRIER_WAIT_TIME = 200; // milliseconds
 std::vector<std::thread> g_all_workers;
 std::mutex g_mutex;
 std::string g_data_channel_type;
+std::unique_ptr<Barrier> g_barrier;


 void test_send_recv_tensor(std::shared_ptr<thd::DataChannel> data_channel) {

@@ -684,6 +685,8 @@ void init_gloo_master(int workers) {

   assert(masterChannel->init());
   run_all_tests(masterChannel, workers);
+
+  g_barrier->wait();
 }

 void init_gloo_worker(unsigned int id, int workers) {

@@ -695,6 +698,8 @@ void init_gloo_worker(unsigned int id, int workers) {

   assert(worker_channel->init());
   run_all_tests(worker_channel, workers);
+
+  g_barrier->wait();
 }
 #endif // WITH_GLOO

@@ -733,6 +738,7 @@ int main(int argc, char const *argv[]) {
 #ifdef WITH_GLOO
   g_data_channel_type = "gloo";
   for (auto workers : WORKERS_NUM) {
+    g_barrier.reset(new Barrier(workers + 1));
     std::cout << "Gloo (workers: " << workers << "):" << std::endl;
     // start gloo master
     std::thread gloo_master_thread(init_gloo_master, workers);
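g_barrier makes the master and all worker threads wait for each other after run_all_tests, so each Gloo iteration only finishes once every participant (workers + 1 of them) is done. The Barrier type itself comes from the test utilities and is not part of this diff; a minimal sketch of a reusable barrier with the semantics the test relies on:

#include <condition_variable>
#include <cstddef>
#include <mutex>

// Illustrative reusable barrier: constructed with a participant count,
// wait() blocks until that many threads have arrived, then resets itself.
class Barrier {
 public:
  explicit Barrier(std::size_t count) : _threshold(count), _count(count) {}

  void wait() {
    std::unique_lock<std::mutex> lock(_mutex);
    auto generation = _generation;
    if (--_count == 0) {
      // Last arrival: start a new generation and release everyone.
      ++_generation;
      _count = _threshold;
      _cv.notify_all();
    } else {
      _cv.wait(lock, [this, generation] { return generation != _generation; });
    }
  }

 private:
  std::mutex _mutex;
  std::condition_variable _cv;
  std::size_t _threshold;
  std::size_t _count;
  std::size_t _generation = 0;
};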
