Skip to content

Commit f5df52f

Browse files
launcher: join workers as they exit (#429)
Check worker exit status in the order in which the workers exit. This way, failed workers can be discovered early and the entire job can be terminated as soon as possible. Signed-off-by: yulu.jia <[email protected]>
1 parent 08034cc commit f5df52f

File tree

1 file changed

+33
-6
lines changed

1 file changed

+33
-6
lines changed

launcher/launch.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ class PropagatingThread(threading.Thread):
1414
""" propagate exceptions to the parent's thread
1515
refer to https://stackoverflow.com/a/31614591/9601110
1616
"""
17+
def __init__(self, callback=None, idx=-1, **kwargs):
18+
super().__init__(**kwargs)
19+
self.callback = callback
20+
self.idx = idx
1721

1822
def run(self):
1923
self.exc = None
@@ -27,6 +31,8 @@ def run(self):
2731
self.ret = self._target(*self._args, **self._kwargs)
2832
except BaseException as e:
2933
self.exc = e
34+
if self.callback is not None:
35+
self.callback(self.idx)
3036

3137
def join(self):
3238
super(PropagatingThread, self).join()
@@ -204,6 +210,27 @@ def parse_num_range(core_list):
204210
ret.append([list(a) for a in temp])
205211
return ret
206212

213+
# Condition variable protecting ``done_threads``; ``done_callback`` notifies
# it every time a worker thread reports that it has finished.
cv = threading.Condition(lock=threading.Lock())
# Indices of worker threads that have finished but have not been joined yet,
# kept in completion order (oldest first).
done_threads = []


def done_callback(idx):
    """Record that the worker thread with index ``idx`` has finished.

    The index is appended to ``done_threads`` under ``cv`` and one waiter
    (normally ``join_threads``) is woken up.
    """
    with cv:
        done_threads.append(idx)
        cv.notify()


def join_threads(threads):
    """Join every thread in ``threads`` in the order the threads finish.

    Blocks until ``len(threads)`` completion notifications have been
    delivered through ``done_callback``.  Each reported index is used to
    look up the thread in ``threads`` and ``join()`` it, so whatever that
    ``join()`` raises propagates to the caller as soon as the thread is
    reported, rather than after all lower-indexed threads have finished.

    Args:
        threads: indexable collection of thread-like objects; the indices
            passed to ``done_callback`` must be valid positions in it.
    """
    remaining = len(threads)
    while remaining:
        with cv:
            # Loop to guard against spurious wake-ups.
            while not done_threads:
                cv.wait()
            # Take the OLDEST pending notification so threads are joined in
            # true exit order (FIFO) even when several finished before this
            # thread woke up.
            idx = done_threads.pop(0)
        # Join outside the lock so other workers can still report completion
        # while this join is in progress.
        threads[idx].join()
        print("BytePS launcher: joined local rank ", idx)
        remaining -= 1
233+
207234
def launch_bps():
208235
print("BytePS launching " + os.environ["DMLC_ROLE"])
209236
sys.stdout.flush()
@@ -228,16 +255,16 @@ def launch_bps():
228255
for i in range(local_size):
229256
command = ' '.join(sys.argv[1:])
230257
if bind_to_cores:
231-
t[i] = PropagatingThread(target=worker, args=[
232-
i, local_size, command, allocations[i]])
258+
t[i] = PropagatingThread(idx=i, callback=done_callback,
259+
target=worker,
260+
args=[i, local_size, command, allocations[i]])
233261
else:
234-
t[i] = PropagatingThread(target=worker, args=[
235-
i, local_size, command])
262+
t[i] = PropagatingThread(idx=i, callback=done_callback,
263+
target=worker, args=[i, local_size, command])
236264
t[i].daemon = True
237265
t[i].start()
238266

239-
for i in range(local_size):
240-
t[i].join()
267+
join_threads(t)
241268

242269
elif os.environ.get("BYTEPS_FORCE_DISTRIBUTED", "") == "1" or \
243270
int(os.environ.get("DMLC_NUM_WORKER", "1")) > 1:

0 commit comments

Comments
 (0)