13
13
# limitations under the License.
14
14
# ==============================================================================
15
15
16
- from horovod .run .common .service import driver_service
16
+ import os
17
+ import six
18
+ import sys
19
+
20
+ from socket import AF_INET
21
+ from psutil import net_if_addrs
17
22
23
+ from horovod .run .util import cache , lsf , threads
24
+ from horovod .run .common .service import driver_service
25
+ from horovod .run .common .util import codec , safe_shell_exec
26
+ from horovod .run .task import task_service
18
27
19
28
class HorovodRunDriverService (driver_service .BasicDriverService ):
20
29
NAME = 'horovodrun driver service'
@@ -33,3 +42,212 @@ def __init__(self, driver_addresses, key, verbose, match_intf=False):
33
42
key ,
34
43
verbose ,
35
44
match_intf = match_intf )
45
+
46
+
47
+ def _launch_task_servers (all_host_names , local_host_names , driver_addresses ,
48
+ settings ):
49
+ """
50
+ Executes the task server and service client task for registration on the
51
+ hosts.
52
+ :param all_host_names: list of addresses. for example,
53
+ ['worker-0','worker-1']
54
+ ['10.11.11.11', '10.11.11.12']
55
+ :type all_host_names: list(string)
56
+ :param local_host_names: names that are resolved to one of the addresses
57
+ of local hosts interfaces. For example,
58
+ set(['localhost', '127.0.0.1'])
59
+ :type local_host_names: set
60
+ :param driver_addresses: map of interfaces and their address and port for
61
+ the service. For example:
62
+ {
63
+ 'lo': [('127.0.0.1', 34588)],
64
+ 'docker0': [('172.122.10.1', 34588)],
65
+ 'eth0': [('11.111.33.73', 34588)]
66
+ }
67
+ :type driver_addresses: map
68
+ :param settings: the object that contains the setting for running horovod
69
+ :type settings: Horovod.run.common.util.settings.Settings
70
+ :return:
71
+ :rtype:
72
+ """
73
+
74
+ def _exec_command (command ):
75
+ host_output = six .StringIO ()
76
+ try :
77
+ exit_code = safe_shell_exec .execute (command ,
78
+ stdout = host_output ,
79
+ stderr = host_output )
80
+ if exit_code != 0 :
81
+ print (
82
+ 'Launching horovodrun task function was not '
83
+ 'successful:\n {host_output}'
84
+ .format (host_output = host_output .getvalue ()))
85
+ os ._exit (exit_code )
86
+ finally :
87
+ host_output .close ()
88
+ return exit_code
89
+
90
+ if settings .ssh_port :
91
+ ssh_port_arg = '-p {ssh_port}' .format (ssh_port = settings .ssh_port )
92
+ else :
93
+ ssh_port_arg = ''
94
+ args_list = []
95
+ for index in range (len (all_host_names )):
96
+ host_name = all_host_names [index ]
97
+ if host_name in local_host_names :
98
+ command = \
99
+ '{python} -m horovod.run.task_fn {index} ' \
100
+ '{driver_addresses} {settings}' \
101
+ .format (python = sys .executable ,
102
+ index = codec .dumps_base64 (index ),
103
+ driver_addresses = codec .dumps_base64 (driver_addresses ),
104
+ settings = codec .dumps_base64 (settings ))
105
+ else :
106
+ command = \
107
+ 'ssh -o StrictHostKeyChecking=no {host} {ssh_port_arg} ' \
108
+ '\' {python} -m horovod.run.task_fn {index} {driver_addresses}' \
109
+ ' {settings}\' ' \
110
+ .format (host = host_name ,
111
+ ssh_port_arg = ssh_port_arg ,
112
+ python = sys .executable ,
113
+ index = codec .dumps_base64 (index ),
114
+ driver_addresses = codec .dumps_base64 (driver_addresses ),
115
+ settings = codec .dumps_base64 (settings ))
116
+ args_list .append ([command ])
117
+ # Each thread will use ssh command to launch the server on one task. If an
118
+ # error occurs in one thread, entire process will be terminated. Otherwise,
119
+ # threads will keep running and ssh session -- and the the task server --
120
+ # will be bound to the thread. In case, the horovodrun process dies, all
121
+ # the ssh sessions and all the task servers will die as well.
122
+ threads .execute_function_multithreaded (_exec_command ,
123
+ args_list ,
124
+ block_until_all_done = False )
125
+
126
+
127
+ @cache .use_cache ()
128
+ def _driver_fn (all_host_names , local_host_names , settings ):
129
+ """
130
+ launches the service service, launches the task service on each worker and
131
+ have them register with the service service. Each worker probes all the
132
+ interfaces of the worker index + 1 (in a ring manner) and only keeps the
133
+ routed interfaces. Function returns the intersection of the set of all the
134
+ routed interfaces on all the workers.
135
+ :param all_host_names: list of addresses. for example,
136
+ ['worker-0','worker-1']
137
+ ['10.11.11.11', '10.11.11.12']
138
+ :type all_host_names: list(string)
139
+ :param local_host_names: host names that resolve into a local addresses.
140
+ :type local_host_names: set
141
+ :param settings: the object that contains the setting for running horovod
142
+ :type settings: Horovod.run.common.util.settings.Settings
143
+ :return: example: ['eth0', 'eth1']
144
+ :rtype: list[string]
145
+ """
146
+ # Launch a TCP server called service service on the host running
147
+ # horovodrun.
148
+ driver = HorovodRunDriverService (
149
+ settings .num_hosts , settings .key , settings .nic )
150
+ if settings .verbose >= 2 :
151
+ print ('Launched horovodrun server.' )
152
+ # Have all the workers register themselves with the service service.
153
+ _launch_task_servers (all_host_names , local_host_names ,
154
+ driver .addresses (), settings )
155
+ if settings .verbose >= 2 :
156
+ print ('Attempted to launch horovod task servers.' )
157
+ try :
158
+ # wait for all the hosts to register with the service service.
159
+ if settings .verbose >= 2 :
160
+ print ('Waiting for the hosts to acknowledge.' )
161
+ driver .wait_for_initial_registration (settings .timeout )
162
+ tasks = [
163
+ task_service .HorovodRunTaskClient (
164
+ index ,
165
+ driver .task_addresses_for_driver (index ),
166
+ settings .key ,
167
+ settings .verbose ) for index in range (
168
+ settings .num_hosts )]
169
+ # Notify all the drivers that the initial registration is complete.
170
+ for task in tasks :
171
+ task .notify_initial_registration_complete ()
172
+ if settings .verbose >= 2 :
173
+ print ('Notified all the hosts that the registration is complete.' )
174
+ # Each worker should probe the interfaces of the next worker in a ring
175
+ # manner and filter only the routed ones -- it should filter out
176
+ # interfaces that are not really connected to any external networks
177
+ # such as lo0 with address 127.0.0.1.
178
+ if settings .verbose >= 2 :
179
+ print ('Waiting for hosts to perform host-to-host '
180
+ 'interface checking.' )
181
+ driver .wait_for_task_to_task_address_updates (settings .timeout )
182
+ if settings .verbose >= 2 :
183
+ print ('Host-to-host interface checking successful.' )
184
+ # Determine a set of common interfaces for task-to-task communication.
185
+ common_intfs = set (driver .task_addresses_for_tasks (0 ).keys ())
186
+ for index in range (1 , settings .num_hosts ):
187
+ common_intfs .intersection_update (
188
+ driver .task_addresses_for_tasks (index ).keys ())
189
+ if not common_intfs :
190
+ raise Exception (
191
+ 'Unable to find a set of common task-to-task communication '
192
+ 'interfaces: %s'
193
+ % [(index , driver .task_addresses_for_tasks (index ))
194
+ for index in range (settings .num_hosts )])
195
+ return common_intfs
196
+ finally :
197
+ driver .shutdown ()
198
+
199
+
200
+ def _get_common_interfaces (settings , all_host_names , remote_host_names , fn_cache ):
201
+ '''
202
+ Find the set of common and routed interfaces on all the hosts.
203
+ :param settings: the object that contains the setting for running horovod
204
+ :type settings: Horovod.run.common.util.settings.Settings
205
+ :param all_host_names: list of the host names
206
+ :type all_host_names: list(string)
207
+ :param remote_host_names: list of the remote host names.
208
+ :type remote_host_names: list(string)
209
+ :param fn_cache: Cache storing the results of checks performed by horovodrun
210
+ :type fn_cache: Horovod.run.util.cache.Cache
211
+ :return: List of common interfaces
212
+ '''
213
+ # Skipping interface discovery for LSF cluster as it slows down considerably the job start
214
+ if lsf .LSFUtils .using_lsf ():
215
+ return None
216
+
217
+ if len (remote_host_names ) > 0 :
218
+ if settings .verbose >= 2 :
219
+ print ('Testing interfaces on all the hosts.' )
220
+
221
+ local_host_names = set (all_host_names ) - set (remote_host_names )
222
+ # Find the set of common, routed interfaces on all the hosts (remote
223
+ # and local) and specify it in the args to be used by NCCL. It is
224
+ # expected that the following function will find at least one interface
225
+ # otherwise, it will raise an exception.
226
+ common_intfs = _driver_fn (all_host_names , local_host_names ,
227
+ settings , fn_cache = fn_cache )
228
+
229
+ if settings .verbose >= 2 :
230
+ print ('Interfaces on all the hosts were successfully checked.' )
231
+ print ('Common interface found: ' + ' ' .join (common_intfs ))
232
+
233
+ else :
234
+ if settings .verbose >= 2 :
235
+ print ('All hosts are local, finding the interfaces '
236
+ 'with address 127.0.0.1' )
237
+ # If all the given hosts are local, find the interfaces with address
238
+ # 127.0.0.1
239
+ common_intfs = set ()
240
+ for iface , addrs in net_if_addrs ().items ():
241
+ if settings .nic and iface != settings .nic :
242
+ continue
243
+ for addr in addrs :
244
+ if addr .family == AF_INET and addr .address == '127.0.0.1' :
245
+ common_intfs .add (iface )
246
+ break
247
+
248
+ if len (common_intfs ) == 0 :
249
+ raise ValueError ('No interface is found for address 127.0.0.1.' )
250
+
251
+ if settings .verbose >= 2 :
252
+ print ('Local interface found ' + ' ' .join (common_intfs ))
253
+ return common_intfs
0 commit comments