
Commit 3ad88aa

[Train] Test Train code snippets (#40432)
As a byproduct of the recent documentation rewrites, the Train docs contain several code snippets that aren't tested. This PR updates the snippets so that the ones that can reasonably be tested are tested.

Signed-off-by: Balaji Veeramani <[email protected]>
1 parent 837ec26 commit 3ad88aa

16 files changed (+228, -112 lines)
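
The recurring change across the diffs below converts untested ``.. code-block:: python`` directives in the Train docs into Sphinx ``testcode`` directives, adding ``:skipif: True`` where a snippet can't run in CI (GPU-only or deliberately incomplete code) and ``:hide:`` setup blocks where a snippet needs variables defined first. A representative before/after sketch; the exact options vary per snippet:

    Before:

        .. code-block:: python

            trainer.fit()

    After:

        .. testcode::
            :skipif: True

            trainer.fit()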

doc/BUILD

Lines changed: 29 additions & 9 deletions
@@ -312,8 +312,8 @@ doctest(
         "source/serve/production-guide/fault-tolerance.md",
         "source/data/batch_inference.rst",
         "source/data/transforming-data.rst",
-        "source/train/faq.rst",
-        "source/train/user-guides/data-loading-preprocessing.rst",
+        "source/train/**/*.rst",
+        "source/train/**/*.md",
         "source/workflows/**/*.rst",
         "source/workflows/**/*.md",
         "source/rllib/**/*.rst",
@@ -326,6 +326,33 @@ doctest(
 )


+doctest(
+    name="doctest[train]",
+    files = glob(
+        include=[
+            "source/train/**/*.rst",
+            "source/train/**/*.md"
+        ],
+        exclude=[
+            # GPU
+            "source/train/user-guides/data-loading-preprocessing.rst",
+            "source/train/user-guides/using-gpus.rst"
+        ]
+    ),
+    tags = ["team:ml"]
+)
+
+doctest(
+    name="doctest[train]",
+    files = [
+        "source/train/user-guides/data-loading-preprocessing.rst",
+        "source/train/user-guides/using-gpus.rst"
+    ],
+    tags = ["team:ml"],
+    gpu = True,
+)
+
+
 doctest(
     name="doctest[workflow]",
     files = glob(
@@ -362,10 +389,3 @@ doctest(
     gpu = True
 )

-doctest(
-    name="quarantine",
-    files = [
-        "source/train/user-guides/data-loading-preprocessing.rst",
-    ],
-    tags = ["team:data"],
-)

doc/source/train/deepspeed.rst

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,8 @@ Code example

 You only need to run your existing training code with a TorchTrainer. You can expect the final code to look like this:

-.. code-block:: python
+.. testcode::
+    :skipif: True

     import deepspeed
     from deepspeed.accelerator import get_accelerator

doc/source/train/distributed-tensorflow-keras.rst

Lines changed: 17 additions & 8 deletions
@@ -46,7 +46,8 @@ variable set up for you.
 The `MultiWorkerMirroredStrategy <https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy>`_
 enables synchronous distributed training. You *must* build and compile the ``Model`` within the scope of the strategy.

-.. code-block:: python
+.. testcode::
+    :skipif: True

     with tf.distribute.MultiWorkerMirroredStrategy().scope():
         model = ... # build model
@@ -81,7 +82,12 @@ execute training. For distributed Tensorflow,
 use a :class:`~ray.train.tensorflow.TensorflowTrainer`
 that you can setup like this:

-.. code-block:: python
+.. testcode::
+    :hide:
+
+    train_func = lambda: None
+
+.. testcode::

     from ray.train import ScalingConfig
     from ray.train.tensorflow import TensorflowTrainer
@@ -95,7 +101,8 @@ that you can setup like this:
 To customize the backend setup, you can pass a
 :class:`~ray.train.tensorflow.TensorflowConfig`:

-.. code-block:: python
+.. testcode::
+    :skipif: True

     from ray.train import ScalingConfig
     from ray.train.tensorflow import TensorflowTrainer, TensorflowConfig
@@ -116,7 +123,8 @@ Run a training function
 With a distributed training function and a Ray Train ``Trainer``, you are now
 ready to start training.

-.. code-block:: python
+.. testcode::
+    :skipif: True

     trainer.fit()

@@ -138,7 +146,7 @@ API for model training.
 `See this example <https://github.com/ray-project/ray/blob/master/python/ray/train/examples/tf/tune_tensorflow_autoencoder_example.py>`__
 for distributed data loading. The relevant parts are:

-.. code-block:: python
+.. testcode::

     import tensorflow as tf
     from ray import train
@@ -188,7 +196,7 @@ local log files. The logging also triggers :ref:`checkpoint bookkeeping <train-d
 The easiest way to report your results with Keras is by using the
 :class:`~ray.train.tensorflow.keras.ReportCheckpointCallback`:

-.. code-block:: python
+.. testcode::

     from ray.train.tensorflow.keras import ReportCheckpointCallback

@@ -223,8 +231,9 @@ attribute.
 These concrete examples demonstrate how Ray Train appropriately saves checkpoints, model weights but not models, in distributed training.


-.. code-block:: python
+.. testcode::

+    import json
     import os
     import tempfile

@@ -275,7 +284,7 @@ directory <train-log-dir>` of each run.
 Load checkpoints
 ~~~~~~~~~~~~~~~~

-.. code-block:: python
+.. testcode::

     import os
     import tempfile
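
The ReportCheckpointCallback hunk above shows only the import line. For context, a minimal sketch of how that callback typically attaches to a Keras training loop inside a Ray Train training function; the model and data here are placeholders rather than code from this commit:

    import numpy as np
    import tensorflow as tf
    from ray.train.tensorflow.keras import ReportCheckpointCallback

    def train_func():
        # Placeholder one-layer model; the real docs build the model inside
        # MultiWorkerMirroredStrategy's scope.
        model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
        model.compile(optimizer="sgd", loss="mse")
        X = np.random.rand(32, 1)
        y = np.random.rand(32, 1)
        # The callback reports metrics and checkpoints back to Ray Train each epoch.
        model.fit(X, y, epochs=2, callbacks=[ReportCheckpointCallback()])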

doc/source/train/examples/pytorch/torch_data_prefetch_benchmark/benchmark_example.rst

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@ Torch Data Prefetching Benchmark for Ray Train
 We provide a benchmark example to show how the auto pipeline for host to device data transfer speeds up training on GPUs.
 This functionality can be easily enabled by setting ``auto_transfer=True`` in :func:`train.torch.prepare_data_loader`.

-.. code-block:: python
+.. testcode::
+    :skipif: True

     from torch.utils.data import DataLoader
     from ray import train
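
For reference, the auto pipeline described above is a one-argument change when preparing the data loader. A minimal sketch with a placeholder DataLoader (the real benchmark script builds its own):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from ray.train.torch import prepare_data_loader

    def train_func():
        dataset = TensorDataset(torch.rand(64, 8), torch.rand(64, 1))
        data_loader = DataLoader(dataset, batch_size=16)
        # auto_transfer=True overlaps host-to-device copies with compute on GPU workers.
        data_loader = prepare_data_loader(data_loader, auto_transfer=True)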

doc/source/train/getting-started-pytorch-lightning.rst

Lines changed: 29 additions & 10 deletions
@@ -17,14 +17,15 @@ Quickstart

 For reference, the final code is as follows:

-.. code-block:: python
+.. testcode::
+    :skipif: True

     from ray.train.torch import TorchTrainer
     from ray.train import ScalingConfig

     def train_func(config):
         # Your PyTorch Lightning training code here.
-
+
     scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
     trainer = TorchTrainer(train_func, scaling_config=scaling_config)
     result = trainer.fit()
@@ -39,7 +40,10 @@ Compare a PyTorch Lightning training script with and without Ray Train.

     .. group-tab:: PyTorch Lightning

-        .. code-block:: python
+        .. This snippet isn't tested because it doesn't use any Ray code.
+
+        .. testcode::
+            :skipif: True

             import torch
             from torchvision.models import resnet18
@@ -154,7 +158,8 @@ Set up a training function
 First, update your training code to support distributed training.
 Begin by wrapping your code in a :ref:`training function <train-overview-training-function>`:

-.. code-block:: python
+.. testcode::
+    :skipif: True

     def train_func(config):
         # Your PyTorch Lightning training code here.
@@ -324,7 +329,7 @@ Outside of your training function, create a :class:`~ray.train.ScalingConfig` ob
 1. `num_workers` - The number of distributed training worker processes.
 2. `use_gpu` - Whether each worker should use a GPU (or CPU).

-.. code-block:: python
+.. testcode::

     from ray.train import ScalingConfig
     scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
@@ -338,7 +343,15 @@ Launch a training job
 Tying this all together, you can now launch a distributed training job
 with a :class:`~ray.train.torch.TorchTrainer`.

-.. code-block:: python
+.. testcode::
+    :hide:
+
+    from ray.train import ScalingConfig
+
+    train_func = lambda: None
+    scaling_config = ScalingConfig(num_workers=1)
+
+.. testcode::

     from ray.train.torch import TorchTrainer

@@ -353,7 +366,7 @@ Access training results
 After training completes, Ray Train returns a :class:`~ray.train.Result` object, which contains
 information about the training run, including the metrics and checkpoints reported during training.

-.. code-block:: python
+.. testcode::

     result.metrics # The metrics reported during training.
     result.checkpoint # The latest checkpoint reported during training.
@@ -407,9 +420,11 @@ control over their native Lightning code.

     .. group-tab:: (Deprecating) LightningTrainer

+        .. This snippet isn't tested because it raises a hard deprecation warning.
+
+        .. testcode::
+            :skipif: True

-        .. code-block:: python
-
             from ray.train.lightning import LightningConfigBuilder, LightningTrainer

             config_builder = LightningConfigBuilder()
@@ -449,9 +464,13 @@ control over their native Lightning code.

     .. group-tab:: (New API) TorchTrainer

-        .. code-block:: python
+        .. This snippet isn't tested because it runs with 4 GPUs, and CI is only run with 1.
+
+        .. testcode::
+            :skipif: True

             import lightning.pytorch as pl
+            from ray.air import CheckpointConfig, RunConfig
             from ray.train.torch import TorchTrainer
             from ray.train.lightning import (
                 RayDDPStrategy,
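
Putting the quickstart pieces from this file together, a minimal sketch of the launch pattern the page describes; the training function body is a placeholder, and the worker count is reduced so the sketch can run on a single CPU machine:

    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_func(config):
        # Your PyTorch Lightning training code here.
        pass

    scaling_config = ScalingConfig(num_workers=1, use_gpu=False)
    trainer = TorchTrainer(train_func, scaling_config=scaling_config)
    result = trainer.fit()
    print(result.metrics)     # The metrics reported during training.
    print(result.checkpoint)  # The latest checkpoint reported during training.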

doc/source/train/getting-started-pytorch.rst

Lines changed: 22 additions & 8 deletions
@@ -18,7 +18,8 @@ Quickstart

 For reference, the final code is as follows:

-.. code-block:: python
+.. testcode::
+    :skipif: True

     from ray.train.torch import TorchTrainer
     from ray.train import ScalingConfig
@@ -40,7 +41,10 @@ Compare a PyTorch training script with and without Ray Train.

     .. group-tab:: PyTorch

-        .. code-block:: python
+        .. This snippet isn't tested because it doesn't use any Ray code.
+
+        .. testcode::
+            :skipif: True

             import tempfile
             import torch
@@ -138,7 +142,8 @@ Set up a training function
 First, update your training code to support distributed training.
 Begin by wrapping your code in a :ref:`training function <train-overview-training-function>`:

-.. code-block:: python
+.. testcode::
+    :skipif: True

     def train_func(config):
         # Your PyTorch training code here.
@@ -212,8 +217,9 @@ See :ref:`data-ingest-torch`.
 Keep in mind that ``DataLoader`` takes in a ``batch_size`` which is the batch size for each worker.
 The global batch size can be calculated from the worker batch size (and vice-versa) with the following equation:

-.. code-block:: python
-
+.. testcode::
+    :skipif: True
+
     global_batch_size = worker_batch_size * ray.train.get_context().get_world_size()


@@ -248,7 +254,7 @@ Outside of your training function, create a :class:`~ray.train.ScalingConfig` ob
 1. :class:`num_workers <ray.train.ScalingConfig>` - The number of distributed training worker processes.
 2. :class:`use_gpu <ray.train.ScalingConfig>` - Whether each worker should use a GPU (or CPU).

-.. code-block:: python
+.. testcode::

     from ray.train import ScalingConfig
     scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
@@ -262,7 +268,15 @@ Launch a training job
 Tying this all together, you can now launch a distributed training job
 with a :class:`~ray.train.torch.TorchTrainer`.

-.. code-block:: python
+.. testcode::
+    :hide:
+
+    from ray.train import ScalingConfig
+
+    train_func = lambda: None
+    scaling_config = ScalingConfig(num_workers=1)
+
+.. testcode::

     from ray.train.torch import TorchTrainer

@@ -275,7 +289,7 @@ Access training results
 After training completes, a :class:`~ray.train.Result` object is returned which contains
 information about the training run, including the metrics and checkpoints reported during training.

-.. code-block:: python
+.. testcode::

     result.metrics # The metrics reported during training.
     result.checkpoint # The latest checkpoint reported during training.
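
The global batch size relation shown in this file can be sanity-checked with plain arithmetic. A sketch with hypothetical values; inside a training function, the worker count would come from ray.train.get_context().get_world_size():

    worker_batch_size = 16
    num_workers = 2  # hypothetical; use get_world_size() inside a Ray Train worker
    global_batch_size = worker_batch_size * num_workers
    assert global_batch_size == 32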
