Commit f4d519c
XLA HorovodAllreduce for tf.function(jit_compile=True) (horovod#3053)
* Add implementation of XLA HorovodAllreduce. It depends on TF2.6 because of the new CustomCallSchedule to give scheduling hints to HLOs, which is essential to performance when lowering HorovodAllreduce into XLA.
* Fix a build break due to an interface change in TF2.6.
* Implement a customized XLA op registrar, as we want to make it an opt-in.
* Ran clang-format.
* Polish code.
* Improve cmake for XLA.
* Polish comments.
* Minor polishing.
* Don't set an alias for the `start` custom-call; XLA may have problems dealing with it.
* Add a unit test for XLA.
* Add process_id in XLA ops.
* Add test_xla.py.
* Embed HOROVOD_ENABLE_XLA_OPS.
* Ran clang-format.
* autopep8 for Python formatting.
* Add documentation for Horovod XLA ops.
* Format docs/xla.rst.
* Automatically set HOROVOD_ENABLE_ASYNC_COMPLETION for XLA ops.
* Add a link to XLA in summary.rst.
* Make the title line long enough in xla.rst.
* Add xla into toctree.
* Compile XLA Horovod ops only for TF2.5+.
* Set the default Cycle Time to 0 because the XLA runtime is sensitive to latencies.
* Skip XLA tests if TF is older than TF2.5.
* Don't use tf.function() as a decorator.
* Remove a redundant test.
* xla::CustomCallSchedule requires TF2.6. Also, prefix EARLIEST and LATEST with SCHEDULE_.
* Do not link _pywrap_tensorflow_internal.so if XLA is not enabled.

Signed-off-by: Trent Lo <[email protected]>
1 parent 283549c commit f4d519c

File tree

12 files changed: +1317 −2 lines

README.rst

Lines changed: 1 addition & 0 deletions

@@ -171,6 +171,7 @@ Supported frameworks
 See these pages for Horovod examples and best practices:

 - `Horovod with TensorFlow <docs/tensorflow.rst>`_
+- `Horovod with XLA in Tensorflow <xla.rst>`_
 - `Horovod with Keras <docs/keras.rst>`_
 - `Horovod with PyTorch <docs/pytorch.rst>`_
 - `Horovod with MXNet <docs/mxnet.rst>`_

cmake/Modules/FindTensorflow.cmake

Lines changed: 7 additions & 1 deletion

@@ -19,7 +19,13 @@ if (LEN EQUAL "4")
   list(GET Tensorflow_OUTPUT 0 Tensorflow_VERSION)
   list(GET Tensorflow_OUTPUT 1 Tensorflow_INCLUDE_DIRS)
   list(GET Tensorflow_OUTPUT 2 Tensorflow_LIBRARIES)
-  string(REPLACE " " ";" Tensorflow_LIBRARIES "${Tensorflow_LIBRARIES}")
+  string(REPLACE " " ";" Tensorflow_LIBRARIES_LIST "${Tensorflow_LIBRARIES}")
+  list(GET Tensorflow_LIBRARIES_LIST 0 Tensorflow_LIB_PATH)
+  if (Tensorflow_VERSION VERSION_GREATER_EQUAL "2.6")
+    # XLA implementations are in _pywrap_tensorflow_internal.so
+    set(Tensorflow_LIBRARIES "${Tensorflow_LIBRARIES} ${Tensorflow_LIB_PATH}/python/ -l:_pywrap_tensorflow_internal.so")
+  endif()
+  message("Tensorflow_LIBRARIES := ${Tensorflow_LIBRARIES}")
   list(GET Tensorflow_OUTPUT 3 Tensorflow_COMPILE_FLAGS)
   if("${Tensorflow_COMPILE_FLAGS}" MATCHES "-D_GLIBCXX_USE_CXX11_ABI=1")
     set(Tensorflow_CXX11 TRUE)
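The version gate above can be transliterated into a short sketch (illustration only; the real logic is CMake, and the helper name `append_xla_link_flags` is hypothetical): split the space-separated library string, take the first entry as the library path, and append the `_pywrap_tensorflow_internal.so` linker flag only for TF 2.6 or newer.

```python
# Python sketch of the FindTensorflow.cmake hunk above. The function name is
# hypothetical; only major.minor are compared, unlike CMake's full
# VERSION_GREATER_EQUAL comparison.

def append_xla_link_flags(tf_version: str, tf_libraries: str) -> str:
    """Append the XLA linker flag for TensorFlow >= 2.6."""
    libs = tf_libraries.split(" ")   # mirrors string(REPLACE " " ";")
    lib_path = libs[0]               # mirrors list(GET ... 0 Tensorflow_LIB_PATH)
    major, minor = (int(x) for x in tf_version.split(".")[:2])
    if (major, minor) >= (2, 6):
        # XLA implementations are in _pywrap_tensorflow_internal.so
        return f"{tf_libraries} {lib_path}/python/ -l:_pywrap_tensorflow_internal.so"
    return tf_libraries
```

This mirrors why the commit also stops linking `_pywrap_tensorflow_internal.so` when XLA is not enabled: the flag is added only under the version gate.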

docs/index.rst

Lines changed: 2 additions & 0 deletions

@@ -101,6 +101,8 @@ Guides

    tensorflow

+   xla
+
    keras

    pytorch

docs/summary.rst

Lines changed: 1 addition & 0 deletions

@@ -163,6 +163,7 @@ Supported frameworks
 See these pages for Horovod examples and best practices:

 - `Horovod with TensorFlow <tensorflow.rst>`_
+- `Horovod with XLA in Tensorflow <xla.rst>`_
 - `Horovod with Keras <keras.rst>`_
 - `Horovod with PyTorch <pytorch.rst>`_
 - `Horovod with MXNet <mxnet.rst>`_

docs/xla.rst

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+Horovod with XLA in Tensorflow
+===============================
+
+Basic usage
+-----------
+
+XLA Horovod ops can be enabled by setting ``HOROVOD_ENABLE_XLA_OPS = 1``, which controls the registration of the ops with Tensorflow/XLA.
+
+There are two main ways to enable XLA, and they can work with Horovod in different ways:
+
+For **Explicit compilation with tf.function(jit_compile=True)**:
+
+.. code-block:: python
+
+    os.environ["HOROVOD_ENABLE_XLA_OPS"] = "1"
+
+    @tf.function(jit_compile=True)
+    def compiled_hvd_allreduce(self, dtype, dim):
+        tensor = self.random_uniform(
+            [17] * dim, -100, 100, dtype=dtype)
+        summed = hvd.allreduce(tensor, average=False)
+        return summed
+
+In this way, all the ops in the ``compiled_hvd_allreduce`` function are lowered into XLA per the compilation requirement. If the XLA Horovod ops are not enabled, XLA will report compilation errors.
+
+For **Auto-clustering**:
+
+Auto-clustering is a convenient way to use XLA: simply set ``TF_XLA_FLAGS=--tf_xla_auto_jit=2`` and the XLA JIT automatically selects ops in the Tensorflow graph to be lowered into XLA. In this mode, enabling the XLA Horovod ops is optional, because auto-clustering works even if the Horovod ops are left to be run by Tensorflow (devices) while only parts of the graph are lowered onto XLA (devices).
+
+List of supported XLA Horovod ops
+---------------------------------
+
+The supported op list is:
+
+``HorovodAllreduce``
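The opt-in registration the new doc describes can be sketched in plain Python (a simplified stand-in: `register_xla_ops` and `XLA_OP_REGISTRY` are hypothetical names; the real registrar is the C++ code in xla_mpi_ops.cc registering against TensorFlow/XLA):

```python
import os

# Hypothetical stand-in for the customized XLA op registrar: the ops are
# registered only when HOROVOD_ENABLE_XLA_OPS is set, making them an opt-in.
XLA_OP_REGISTRY = {}

def register_xla_ops():
    if os.environ.get("HOROVOD_ENABLE_XLA_OPS", "0") in ("1", "true", "True"):
        # HorovodAllreduce is the only supported XLA op in this commit.
        XLA_OP_REGISTRY["HorovodAllreduce"] = "registered"
    return XLA_OP_REGISTRY
```

With the variable unset, the registry stays empty and an explicitly jit-compiled function that contains `hvd.allreduce` fails to compile, which matches the behavior described above.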

horovod/common/common.h

Lines changed: 4 additions & 1 deletion

@@ -137,6 +137,7 @@ namespace common {
 #define HOROVOD_DISABLE_NVTX_RANGES "HOROVOD_DISABLE_NVTX_RANGES"
 #define HOROVOD_ENABLE_ASYNC_COMPLETION "HOROVOD_ENABLE_ASYNC_COMPLETION"
 #define HOROVOD_DYNAMIC_PROCESS_SETS "HOROVOD_DYNAMIC_PROCESS_SETS"
+#define HOROVOD_ENABLE_XLA_OPS "HOROVOD_ENABLE_XLA_OPS"

 // String constant for gloo interface.
 #define GLOO_DEFAULT_IFACE ""

@@ -153,7 +154,7 @@ namespace common {
 #define JOIN_TENSOR_NAME "join.noname"

 // List of supported frameworks.
-enum Framework { TENSORFLOW, PYTORCH, MXNET };
+enum Framework { TENSORFLOW, PYTORCH, MXNET, XLA };

 enum StatusType { OK, UNKNOWN_ERROR, PRECONDITION_ERROR, ABORTED, INVALID_ARGUMENT, IN_PROGRESS };

@@ -228,6 +229,8 @@ const Status DUPLICATE_NAME_ERROR = Status::InvalidArgument(

 class TensorShape {
 public:
+  TensorShape() : shape_() {}
+  TensorShape(std::vector<int64_t> vec) : shape_(vec) {}
   void AddDim(int64_t dim);
   void AppendShape(TensorShape& other);
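The two constructors added above let the XLA path build a `TensorShape` directly from a dimension vector instead of repeated `AddDim` calls. A rough Python analogue of the class's behavior (a hypothetical transliteration; the real class is C++ and its remaining methods live in the existing Horovod sources):

```python
class TensorShape:
    # Python analogue of horovod::common::TensorShape; the two __init__
    # forms correspond to the new default and vector constructors.
    def __init__(self, dims=None):
        self.shape_ = list(dims) if dims is not None else []

    def AddDim(self, dim):
        self.shape_.append(dim)

    def AppendShape(self, other):
        self.shape_.extend(other.shape_)
```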

horovod/common/operations.cc

Lines changed: 13 additions & 0 deletions

@@ -494,6 +494,14 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {

   // Override the cycle time.
   state.parameter_manager.SetCycleTimeMs(1);
+  bool enable_xla_ops = false;
+  common::SetBoolFromEnv(HOROVOD_ENABLE_XLA_OPS, enable_xla_ops, true);
+  if (enable_xla_ops) {
+    // Set the default Cycle Time to 0 because the XLA runtime is sensitive
+    // to latencies.
+    state.parameter_manager.SetCycleTimeMs(0);
+  }
   auto horovod_cycle_time = std::getenv(HOROVOD_CYCLE_TIME);
   if (horovod_cycle_time != nullptr) {
     state.parameter_manager.SetCycleTimeMs(

@@ -563,6 +571,11 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {

   // Check if async completion should be enabled
   SetBoolFromEnv(HOROVOD_ENABLE_ASYNC_COMPLETION, state.enable_async_completion, true);
+  if (enable_xla_ops) {
+    // Enable async completion when XLA ops are enabled. Since the XLA runtime
+    // is single-threaded, async completion is essential to reduce host
+    // overhead.
+    state.enable_async_completion = true;
+  }

   // Enable auto-tuning.
   auto horovod_autotune = std::getenv(HOROVOD_AUTOTUNE);
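The control flow of the two hunks above can be summarized in a small sketch (hypothetical function name; the real logic is C++ inside `BackgroundThreadLoop`): enabling XLA ops lowers the default cycle time from 1 ms to 0 and forces async completion, while an explicit `HOROVOD_CYCLE_TIME` still overrides the cycle-time default.

```python
# Sketch of the env-driven tuning decisions added to BackgroundThreadLoop.
# resolve_tuning is a hypothetical name used for illustration only.

def resolve_tuning(env):
    cycle_time_ms = 1.0  # default set just before the new hunk
    enable_xla_ops = env.get("HOROVOD_ENABLE_XLA_OPS", "0") == "1"
    if enable_xla_ops:
        # The XLA runtime is sensitive to latencies.
        cycle_time_ms = 0.0
    if "HOROVOD_CYCLE_TIME" in env:
        # An explicit setting still wins, as in the code that follows the hunk.
        cycle_time_ms = float(env["HOROVOD_CYCLE_TIME"])
    enable_async_completion = env.get("HOROVOD_ENABLE_ASYNC_COMPLETION", "0") == "1"
    if enable_xla_ops:
        # The single-threaded XLA runtime needs async completion to reduce
        # host overhead.
        enable_async_completion = True
    return cycle_time_ms, enable_async_completion
```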

horovod/tensorflow/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -59,6 +59,7 @@ set(Tensorflow_CXX11 ${Tensorflow_CXX11} PARENT_SCOPE)

 # TF SOURCES
 list(APPEND TF_SOURCES "${PROJECT_SOURCE_DIR}/horovod/tensorflow/mpi_ops.cc")
+list(APPEND TF_SOURCES "${PROJECT_SOURCE_DIR}/horovod/tensorflow/xla_mpi_ops.cc")

 # Create library
 set_output_dir()
Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+// Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+// Modifications copyright (C) 2017 Uber Technologies, Inc.
+// Modifications copyright Microsoft
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+include "horovod/common/wire/message.fbs";
+
+namespace horovod.xla.wire;
+
+table TensorShape {
+  dims:[long];
+}
+
+table CustomCallConfig {
+  tensor_name:string;
+  tensor_type:common.wire.DataType;
+  input_shapes:[TensorShape];
+  output_shapes:[TensorShape];
+
+  // Prescale and postscale factors
+  prescale_factor:float;
+  postscale_factor:float;
+
+  // Root rank is necessary for broadcast operation.
+  root_rank:int;
+
+  // Reduce op.
+  reduce_op:int;
+
+  process_set_id:int;
+}
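The `CustomCallConfig` table above carries everything the XLA custom-call needs to re-create a Horovod op on the other side of the compilation boundary. A plain-Python mirror of the schema's fields (illustrative only; the real wire format is FlatBuffers, and the default values here are chosen for the sketch, not taken from the schema):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class TensorShape:
    dims: List[int] = field(default_factory=list)

@dataclass
class CustomCallConfig:
    # Fields mirror horovod.xla.wire.CustomCallConfig one-to-one.
    tensor_name: str = ""
    tensor_type: int = 0                  # common.wire.DataType enum value
    input_shapes: List[TensorShape] = field(default_factory=list)
    output_shapes: List[TensorShape] = field(default_factory=list)
    prescale_factor: float = 1.0          # prescale/postscale factors
    postscale_factor: float = 1.0
    root_rank: int = 0                    # needed by broadcast
    reduce_op: int = 0                    # reduce op
    process_set_id: int = 0
```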

0 commit comments