
Commit eeca2c0

[ray integration] Initial Ray Integration with RayExecutor API (horovod#2218)
Signed-off-by: Richard Liaw <[email protected]>
Co-authored-by: Travis Addair <[email protected]>
1 parent 9464e20 commit eeca2c0

15 files changed: +974, -26 lines

.buildkite/gen-pipeline.sh

Lines changed: 3 additions & 3 deletions
@@ -104,7 +104,7 @@ run_mpi_pytest() {
     local excluded_tests="| sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g'"
 
     # Spark and Run test does not need to be executed with horovodrun, but we still run it below.
-    local exclude_standalone_test="| sed 's/test_spark.py//g' | sed 's/test_run.py//g'"
+    local exclude_standalone_test="| sed 's/test_spark.py//g' | sed 's/test_run.py//g' | sed 's/test_ray.py//g'"
     local standalone_tests="test_spark.py test_run.py"
 
     # pytests have 4x GPU use cases and require a separate queue
@@ -209,8 +209,8 @@ run_gloo_pytest() {
     local excluded_tests="| sed 's/test_interactiverun.py//g' | sed 's/test_spark_keras.py//g' | sed 's/test_spark_torch.py//g'"
 
     # Spark and Run test does not need to be executed with horovodrun, but we still run it below.
-    local exclude_standalone_test="| sed 's/test_spark.py//g' | sed 's/test_run.py//g'"
-    local standalone_tests="test_spark.py test_run.py"
+    local exclude_standalone_test="| sed 's/test_spark.py//g' | sed 's/test_run.py//g' | sed 's/test_ray.py//g'"
+    local standalone_tests="test_spark.py test_run.py test_ray.py"
 
     run_test "${test}" "${queue}" \
         ":pytest: Run PyTests (${test})" \

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 - Added `hvd.is_initialized()` method. ([#2020](https://github.com/horovod/horovod/pull/2020))
 
+- Added Ray integration. ([#2218](https://github.com/horovod/horovod/pull/2218))
+
 ### Changed
 
 - Moved `horovod.run.runner.run` to `horovod.run`. ([#2099](https://github.com/horovod/horovod/pull/2099))

Dockerfile.test.cpu

Lines changed: 2 additions & 2 deletions
@@ -169,9 +169,9 @@ RUN if [[ ${MPI_KIND} == "ONECCL" ]]; then \
         fi; \
         . /usr/local/oneccl/env/setvars.sh; \
         echo "pip install horovod, mpicxx is $(which mpicxx)"; \
-        pip install -v $(ls /horovod/dist/horovod-*.tar.gz)[spark]; \
+        pip install -v $(ls /horovod/dist/horovod-*.tar.gz)[spark,ray]; \
     else \
-        pip install -v $(ls /horovod/dist/horovod-*.tar.gz)[spark]; \
+        pip install -v $(ls /horovod/dist/horovod-*.tar.gz)[spark,ray]; \
     fi
 
 # Prefetch Spark MNIST dataset.

Dockerfile.test.gpu

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ RUN if [[ ${MXNET_PACKAGE} == "mxnet-nightly" ]]; then \
 # Install Horovod.
 RUN cd /horovod && python setup.py sdist
 RUN ldconfig /usr/local/cuda/targets/x86_64-linux/lib/stubs && \
-    bash -c "${HOROVOD_BUILD_FLAGS} HOROVOD_WITH_TENSORFLOW=1 HOROVOD_WITH_PYTORCH=1 HOROVOD_WITH_MXNET=1 pip install -v $(ls /horovod/dist/horovod-*.tar.gz)[spark]" && \
+    bash -c "${HOROVOD_BUILD_FLAGS} HOROVOD_WITH_TENSORFLOW=1 HOROVOD_WITH_PYTORCH=1 HOROVOD_WITH_MXNET=1 pip install -v $(ls /horovod/dist/horovod-*.tar.gz)[spark,ray]" && \
     ldconfig
 
 # Hack for compatibility of MNIST example with TensorFlow 1.1.0.

docs/api.rst

Lines changed: 6 additions & 0 deletions
@@ -52,6 +52,12 @@ horovod.spark.common
 .. automodule:: horovod.spark.common.store
     :show-inheritance:
 
+.. _horovod_ray_api:
+
+horovod.ray
+-----------
+.. automodule:: horovod.ray
+
 horovod.run
 -------------
 .. automodule:: horovod.run

docs/index.rst

Lines changed: 4 additions & 2 deletions
@@ -118,15 +118,17 @@ Guides
    gpus_include
 
    conda_include
-
+
    docker_include
 
    spark_include
 
+   ray_include
+
    lsf_include
 
    tensor-fusion_include
-
+
    adasum_user_guide_include
 
    timeline_include

docs/mocks.py

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,8 @@ def _dummy():
     'pyspark.sql.functions',
     'pyspark.sql.types',
 
+    'ray',
+
     'tensorflow',
     'tensorflow.python',
     'tensorflow.python.framework',

docs/ray.rst

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
.. inclusion-marker-start-do-not-remove

Horovod on Ray
==============

``horovod.ray`` allows users to leverage Horovod on `a Ray cluster <https://docs.ray.io/en/latest/cluster/index.html>`_.

Currently, the Ray + Horovod integration provides a :ref:`RayExecutor API <horovod_ray_api>`.

.. note:: The Ray + Horovod integration currently only supports a Gloo backend.

Installation
------------

Use the extra ``[ray]`` option to install Ray along with Horovod.

.. code-block:: bash

    $ HOROVOD_WITH_GLOO=1 ... pip install 'horovod[ray]'

See the Ray documentation for `advanced installation instructions <https://docs.ray.io/en/latest/installation.html>`_.
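A quick way to confirm the extra installed correctly is to import both packages (an illustrative check, not part of this commit):

.. code-block:: python

    # Both imports should succeed once `horovod[ray]` is installed;
    # `import horovod.ray` raises ImportError if Ray support is missing.
    import ray
    import horovod.ray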
Horovod Ray Job
---------------

The Horovod Ray integration offers a ``RayExecutor`` abstraction (:ref:`docs <horovod_ray_api>`),
which is a wrapper over a group of `Ray actors (stateful processes) <https://docs.ray.io/en/latest/walkthrough.html#remote-classes-actors>`_.

.. code-block:: python

    import ray
    from horovod.ray import RayExecutor

    # Start the Ray cluster or attach to an existing Ray cluster.
    ray.init()

    # Start num_hosts * num_slots actors on the cluster
    # (num_hosts and num_slots are chosen to match your cluster).
    executor = RayExecutor(
        settings, num_hosts=num_hosts, num_slots=num_slots, use_gpu=True)

    # Launch the Ray actors on each machine.
    # This will launch `num_slots` actors on each machine, each with
    # 1 GPU allocated (set via CUDA_VISIBLE_DEVICES).
    executor.start()


All actors will be part of the Horovod ring, so ``RayExecutor`` invocations will be able to support arbitrary Horovod collective operations.

Note that there is an implicit assumption that the cluster is homogeneous in shape (i.e., all machines have the same number of slots available). This is simply
an implementation detail and is not a fundamental limitation.

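The ``settings`` object above holds the executor configuration; a minimal way to construct it, as done in ``examples/tensorflow2_mnist_ray.py`` added in this commit, is:

.. code-block:: python

    from horovod.ray import RayExecutor

    # 30-second setup timeout, matching the example script in this commit.
    settings = RayExecutor.create_settings(timeout_s=30)
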
To actually execute a function, you can run the following:

.. code-block:: python

    # In its simplest form, a function must take in a dummy variable.
    # `hvd` is the Horovod module for your framework (e.g., horovod.torch).
    def simple_fn(_):
        hvd.init()
        print("hvd rank", hvd.rank())
        return hvd.rank()

    # Execute the function on all workers at once.
    result = executor.execute(simple_fn)
    # Check that the rank of each worker is unique.
    assert len(set(result)) == num_hosts * num_slots

    executor.shutdown()

Execution
~~~~~~~~~

A unique feature of Ray is its support for `stateful Actors <https://docs.ray.io/en/latest/walkthrough.html#remote-classes-actors>`_. This means that you can start arbitrary Python classes on each worker, easily supporting operations and calls where data is cached in memory.

.. code-block:: python

    import ray
    import torch
    import horovod.torch as hvd
    from horovod.ray import RayExecutor

    class MyModel:
        def __init__(self, learning_rate):
            self.model = NeuralNet()  # NeuralNet is a user-defined torch.nn.Module
            optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=learning_rate,
            )
            self.optimizer = hvd.DistributedOptimizer(optimizer)

        def get_weights(self):
            return dict(self.model.named_parameters())

        def train(self):
            return train(self.model, self.optimizer)  # user-defined training step


    ray.init()
    executor = RayExecutor(...)
    executor.start(executable_cls=MyModel)
    for i in range(5):
        executor.execute(lambda worker: worker.train())

    result = executor.execute(lambda worker: worker.get_weights())

    # result will be N copies of the model weights
    assert all(isinstance(res, dict) for res in result)

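``MyModel`` above takes a ``learning_rate`` constructor argument that the snippet never passes. Current versions of ``horovod.ray`` let constructor arguments be forwarded when starting the actors; whether this initial commit already supports that is an assumption here, so treat the following as an illustrative sketch rather than part of this change:

.. code-block:: python

    # Assumption: `executable_kwargs` is forwarded to MyModel.__init__ on each worker.
    executor.start(executable_cls=MyModel,
                   executable_kwargs={"learning_rate": 0.1})
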
AWS: Cluster Launcher
---------------------

You can also easily leverage the `Ray cluster launcher <https://docs.ray.io/en/latest/cluster/launcher.html>`_ to spin up cloud instances.

.. code-block:: yaml

    # Save as `ray_cluster.yaml`

    cluster_name: horovod-cluster
    provider: {type: aws, region: us-west-2}
    auth: {ssh_user: ubuntu}
    min_workers: 3
    max_workers: 3

    # Deep Learning AMI (Ubuntu) Version 21.0
    head_node: {InstanceType: p3.2xlarge, ImageId: ami-0b294f219d14e6a82}
    worker_nodes: {InstanceType: p3.2xlarge, ImageId: ami-0b294f219d14e6a82}
    setup_commands: # Set up each node.
        - HOROVOD_WITH_GLOO=1 HOROVOD_GPU_OPERATIONS=NCCL pip install horovod[ray]

You can start the specified Ray cluster and monitor its status with:

.. code-block:: bash

    $ ray up ray_cluster.yaml       # starts the head node
    $ ray monitor ray_cluster.yaml  # wait for worker nodes

Then, in your python script, make sure you add ``ray.init(address="auto")`` to connect
to the distributed Ray cluster.

.. code-block:: diff

    -ray.init()
    +ray.init(address="auto")

Then you can execute Ray scripts on the cluster:

.. code-block:: bash

    $ ray submit ray_cluster.yaml <your_script.py>

    # the above is equivalent to
    $ ray attach ray_cluster.yaml  # ssh
    ubuntu@ip-172-31-24-53:~$ python <your_script.py>

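For illustration, a minimal ``<your_script.py>`` might look like the following (a sketch assembled from the snippets above; the host and slot counts are placeholders that should match your cluster, e.g. four p3.2xlarge nodes with one GPU each, and it assumes a framework such as PyTorch is available on the nodes):

.. code-block:: python

    import ray
    import horovod.torch as hvd
    from horovod.ray import RayExecutor

    def check_rank(_):
        hvd.init()
        return hvd.rank()

    # Connect to the cluster started with `ray up` (see above).
    ray.init(address="auto")

    settings = RayExecutor.create_settings(timeout_s=30)
    executor = RayExecutor(settings, num_hosts=4, num_slots=1, use_gpu=True)
    executor.start()
    print(executor.execute(check_rank))  # one rank per worker
    executor.shutdown()
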
.. inclusion-marker-end-do-not-remove

docs/ray_include.rst

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
.. include:: ./ray.rst
   :start-after: inclusion-marker-start-do-not-remove
   :end-before: inclusion-marker-end-do-not-remove

examples/tensorflow2_mnist_ray.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
import tensorflow as tf
import horovod.tensorflow.keras as hvd

import ray
from horovod.ray import RayExecutor


def train(num_epochs):
    # Horovod: initialize Horovod.
    hvd.init()

    # Horovod: pin GPU to be used to process local rank (one GPU per process)
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.experimental.set_visible_devices(
            gpus[hvd.local_rank()], 'GPU')

    (mnist_images, mnist_labels), _ = \
        tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % hvd.rank())

    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
         tf.cast(mnist_labels, tf.int64))
    )
    dataset = dataset.repeat().shuffle(10000).batch(128)

    mnist_model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, [3, 3], activation='relu'),
        tf.keras.layers.Conv2D(64, [3, 3], activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    # Horovod: adjust learning rate based on number of GPUs.
    scaled_lr = 0.001 * hvd.size()
    opt = tf.optimizers.Adam(scaled_lr)

    # Horovod: add Horovod DistributedOptimizer.
    opt = hvd.DistributedOptimizer(opt)

    # Horovod: Specify `experimental_run_tf_function=False` to ensure TensorFlow
    # uses hvd.DistributedOptimizer() to compute gradients.
    mnist_model.compile(loss=tf.losses.SparseCategoricalCrossentropy(),
                        optimizer=opt,
                        metrics=['accuracy'],
                        experimental_run_tf_function=False)

    callbacks = [
        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        hvd.callbacks.BroadcastGlobalVariablesCallback(0),

        # Horovod: average metrics among workers at the end of every epoch.
        #
        # Note: This callback must be in the list before the ReduceLROnPlateau,
        # TensorBoard or other metrics-based callbacks.
        hvd.callbacks.MetricAverageCallback(),

        # Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final
        # accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during
        # the first three epochs. See https://arxiv.org/abs/1706.02677 for details.
        hvd.callbacks.LearningRateWarmupCallback(
            warmup_epochs=3, initial_lr=scaled_lr, verbose=1),
    ]

    # Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
    if hvd.rank() == 0:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(
            './checkpoint-{epoch}.h5'))

    # Horovod: write logs on worker 0.
    verbose = 1 if hvd.rank() == 0 else 0

    # Train the model.
    # Horovod: adjust number of steps based on number of GPUs.
    mnist_model.fit(dataset, steps_per_epoch=500 // hvd.size(),
                    callbacks=callbacks, epochs=num_epochs, verbose=verbose)


ray.init()
settings = RayExecutor.create_settings(timeout_s=30)
executor = RayExecutor(settings, num_hosts=1, num_slots=2, use_gpu=False)
executor.start()
executor.run(train, kwargs=dict(num_epochs=1))
executor.shutdown()
