@@ -14,6 +14,10 @@ class PropagatingThread(threading.Thread):
14
14
""" propagate exceptions to the parent's thread
15
15
refer to https://stackoverflow.com/a/31614591/9601110
16
16
"""
17
def __init__(self, callback=None, idx=-1, **kwargs):
    """Initialise the thread and remember an optional completion hook.

    Args:
        callback: optional callable stored as ``self.callback``; ``run``
            calls it with ``self.idx`` when it is not ``None``.
        idx: value stored as ``self.idx`` and handed to ``callback``;
            defaults to ``-1``.
        **kwargs: forwarded unchanged to the parent class constructor
            (e.g. ``target``, ``args``).
    """
    super().__init__(**kwargs)
    self.callback, self.idx = callback, idx
17
21
18
22
def run (self ):
19
23
self .exc = None
@@ -27,6 +31,8 @@ def run(self):
27
31
self .ret = self ._target (* self ._args , ** self ._kwargs )
28
32
except BaseException as e :
29
33
self .exc = e
34
+ if self .callback is not None :
35
+ self .callback (self .idx )
30
36
31
37
def join (self ):
32
38
super (PropagatingThread , self ).join ()
@@ -204,6 +210,27 @@ def parse_num_range(core_list):
204
210
ret .append ([list (a ) for a in temp ])
205
211
return ret
206
212
213
# Completion bookkeeping shared between the worker threads and the launcher.
# Workers report their index through done_callback(); join_threads() sleeps
# on the condition variable until at least one index is pending.
cv = threading.Condition(lock=threading.Lock())
done_threads = []  # indices reported as finished but not yet joined


def done_callback(idx):
    """Record that the thread with index ``idx`` is done and wake the joiner.

    Args:
        idx: position of the reporting thread in the sequence that is later
            passed to ``join_threads``.
    """
    with cv:
        done_threads.append(idx)
        cv.notify()


def join_threads(threads):
    """Join every thread in ``threads`` as its index is reported.

    Blocks until ``len(threads)`` indices have been taken from
    ``done_threads``; each index must therefore be reported exactly once via
    ``done_callback``. The most recently reported pending index is joined
    first. Any exception raised by a thread's ``join()`` propagates to the
    caller immediately, without waiting for the remaining threads.

    Args:
        threads: indexable collection of thread objects; ``threads[idx]`` must
            be valid for every reported ``idx``.
    """
    for _ in range(len(threads)):
        with cv:
            # Loop guards against wake-ups that arrive with nothing pending.
            while not done_threads:
                cv.wait()
            idx = done_threads.pop()
        # Join outside the lock so other threads can keep reporting.
        threads[idx].join()
        print("BytePS launcher: joined local rank ", idx)
233
+
207
234
def launch_bps ():
208
235
print ("BytePS launching " + os .environ ["DMLC_ROLE" ])
209
236
sys .stdout .flush ()
@@ -228,16 +255,16 @@ def launch_bps():
228
255
for i in range (local_size ):
229
256
command = ' ' .join (sys .argv [1 :])
230
257
if bind_to_cores :
231
- t [i ] = PropagatingThread (target = worker , args = [
232
- i , local_size , command , allocations [i ]])
258
+ t [i ] = PropagatingThread (idx = i , callback = done_callback ,
259
+ target = worker ,
260
+ args = [i , local_size , command , allocations [i ]])
233
261
else :
234
- t [i ] = PropagatingThread (target = worker , args = [
235
- i , local_size , command ])
262
+ t [i ] = PropagatingThread (idx = i , callback = done_callback ,
263
+ target = worker , args = [ i , local_size , command ])
236
264
t [i ].daemon = True
237
265
t [i ].start ()
238
266
239
- for i in range (local_size ):
240
- t [i ].join ()
267
+ join_threads (t )
241
268
242
269
elif os .environ .get ("BYTEPS_FORCE_DISTRIBUTED" , "" ) == "1" or \
243
270
int (os .environ .get ("DMLC_NUM_WORKER" , "1" )) > 1 :
0 commit comments