Skip to content

Commit 9297542

Browse files
Author: Fan, Kai
Commit message: "num_gpus checking"
1 parent 379f542, commit 9297542

File tree

1 file changed: +4 -2 lines changed

expert_model.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -628,6 +628,9 @@ def __init__(self,
628628
reverse_target_vocab_table=None,
629629
scope=None):
630630

631+
self.devices = [x.name for x in device_lib.list_local_devices() if x.device_type == "GPU"]
632+
assert len(self.devices) == hparams.num_gpus
633+
631634
self.iterator = iterator
632635
self.mode = mode
633636

@@ -765,8 +768,7 @@ def build_graph(self, hparams, scope=None):
765768
tgt_len_shards = approximate_split(self.tgt_sequence_length, hparams.num_gpus)
766769

767770
loss_shards = []
768-
devices = [x.name for x in device_lib.list_local_devices() if x.device_type == "GPU"]
769-
for i, device in enumerate(devices):
771+
for i, device in enumerate(self.devices):
770772
with tf.name_scope("parallel_{}".format(i)):
771773
with tf.variable_scope(tf.get_variable_scope(), reuse=True if i > 0 else None):
772774
with tf.device(device):

0 commit comments

Comments
 (0)