
Commit d347058

Author: wangguoteng.p
Commit message: polish
1 parent: ca25b27

File tree: 15 files changed (+185, -79 lines)


Makefile

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@ WORKERS_COMMAND := $(if ${WORKERS},-n ${WORKERS} --dist=loadscope,)
 DURATIONS ?= 10
 DURATIONS_COMMAND := $(if ${DURATIONS},--durations=${DURATIONS},)
 
-TIMEOUT_LIMIT ?= 300
 
 docs:
 	$(MAKE) -C ${DING_DIR}/docs html

codecov.yml

Lines changed: 0 additions & 8 deletions
@@ -6,11 +6,3 @@ coverage:
     target: auto
     threshold: 0.5%
     if_ci_failed: success #success, failure, error, ignore
-
-# fix me
-# The unittests of the torchrpc module are tested by different runners and cannot be included
-# in the test_unittest's coverage report. To keep CI happy, we don't count torchrpc related coverage.
-ignore:
-  - ./ding/framework/message_queue/torch_rpc.py
-  - ./ding/framework/message_queue/tests/test_torch_rpc.py
-  - ./ding/framework/message_queue/perfs/*

ding/data/shm_buffer.py

Lines changed: 1 addition & 3 deletions
@@ -158,9 +158,7 @@ def __init__(
         self.copy_on_get = copy_on_get
         self.shape = shape
         self.device = device
-        # We don't want the buffer to be involved in the computational graph
-        with torch.no_grad():
-            self.buffer = torch.zeros(reduce(lambda x, y: x * y, shape), dtype=ttype, device=self.device)
+        self.buffer = torch.zeros(reduce(lambda x, y: x * y, shape), dtype=ttype, device=self.device)
 
     def fill(self, src_arr: Union[np.ndarray, torch.Tensor]) -> None:
         if self.ctype is np.ndarray:
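Context for the change above: wrapping a plain allocation in torch.no_grad() is redundant, because a freshly created tensor does not require gradients unless asked to. A minimal stand-alone check (not part of the commit, shown only to illustrate the default):

import torch

# torch.zeros creates a leaf tensor with requires_grad=False by default,
# so it never enters the autograd graph unless explicitly requested.
buf = torch.zeros(1024 * 1024, dtype=torch.float32)
assert buf.requires_grad is False
assert buf.grad_fn is None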

ding/data/tests/test_shm_buffer.py

Lines changed: 91 additions & 23 deletions
@@ -1,11 +1,9 @@
-from ding.data.shm_buffer import ShmBuffer, ShmBufferCuda
-from ding.compatibility import torch_ge_1121
-
 import pytest
 import numpy as np
 import timeit
 import torch
 import time
+from ding.data.shm_buffer import ShmBuffer, ShmBufferCuda
 
 
 def subprocess_np_shm(shm_buf):
@@ -14,8 +12,9 @@ def subprocess_np_shm(shm_buf):
     print("Mean: {:.4f}s, STD: {:.4f}s, Mean each call: {:.4f}ms".format(np.mean(res), np.std(res), np.mean(res)))
 
 
-def subprocess_cuda_shared_tensor(shm_buf_np, shm_buf_torch, event_run):
-    event_run.wait()
+def subprocess_cuda_shared_tensor(shm_buf_np, shm_buf_torch, event_wait, event_fire, copy_on_get):
+    event_wait.wait()
+    event_wait.clear()
     rtensor = shm_buf_torch.get()
     assert isinstance(rtensor, torch.Tensor)
     assert rtensor.device == torch.device('cuda:0')
@@ -26,12 +25,25 @@ def subprocess_cuda_shared_tensor(shm_buf_np, shm_buf_torch, event_run):
     assert isinstance(rarray, np.ndarray)
     assert rarray.dtype == np.dtype(np.float32)
     assert rarray.dtype == np.dtype(np.float32)
+    assert rtensor.sum() == 1024 * 1024
+
+    shm_buf_torch.fill(torch.zeros((1024, 1024), dtype=torch.float32, device=torch.device('cuda:0')))
+    shm_buf_np.fill(np.zeros((1024, 1024), dtype=np.float32))
+
+    event_fire.set()
+
+    if copy_on_get:
+        event_wait.wait()
+        shm_buf_torch.buffer[0] = 9.0
+        shm_buf_np.buffer[0] = 9.0
+        event_fire.set()
+
+    del shm_buf_np
+    del shm_buf_torch
 
-    res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.get(), repeat=5, number=1000)
-    print("CUDA-shared-tensor (torch) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))
-    res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.get(), repeat=5, number=1000)
-    print("CUDA-shared-tensor (numpy) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))
 
+def subprocess_cuda_shared_tensor_case2(shm_buf_np, shm_buf_torch, event_wait):
+    event_wait.wait()
     del shm_buf_np
     del shm_buf_torch
 
@@ -49,42 +61,98 @@ def test_shm_buffer():
 @pytest.mark.benchmark
 @pytest.mark.cudatest
 # @pytest.mark.multiprocesstest
-def test_cuda_shm():
-    if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
+@pytest.mark.parametrize("copy_on_get", [True, False])
+def test_cuda_shm(copy_on_get):
+    if torch.cuda.is_available():
         import torch.multiprocessing as mp
         ctx = mp.get_context('spawn')
 
-        event_run = ctx.Event()
-        shm_buf_np = ShmBufferCuda(np.dtype(np.float32), shape=(1024, 1024), copy_on_get=True)
-        shm_buf_torch = ShmBufferCuda(torch.float32, shape=(1024, 1024), copy_on_get=True)
-        proc = ctx.Process(target=subprocess_cuda_shared_tensor, args=[shm_buf_np, shm_buf_torch, event_run])
+        event_fire, event_wait = ctx.Event(), ctx.Event()
+        shm_buf_np = ShmBufferCuda(np.dtype(np.float32), shape=(1024, 1024), copy_on_get=copy_on_get)
+        shm_buf_torch = ShmBufferCuda(torch.float32, shape=(1024, 1024), copy_on_get=copy_on_get)
+        proc = ctx.Process(
+            target=subprocess_cuda_shared_tensor, args=[shm_buf_np, shm_buf_torch, event_fire, event_wait, copy_on_get]
+        )
         proc.start()
 
-        ltensor = torch.ones((1024, 1024), dtype=torch.float32).cuda(0 if torch.cuda.device_count() == 1 else 1)
-        larray = np.random.rand(1024, 1024).astype(np.float32)
+        ltensor = torch.ones((1024, 1024), dtype=torch.float32, device=torch.device('cuda:0'))
+        larray = np.ones((1024, 1024), dtype=np.float32)
         shm_buf_torch.fill(ltensor)
         shm_buf_np.fill(larray)
 
-        res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.fill(ltensor), repeat=5, number=1000)
-        print("CUDA-shared-tensor (torch) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))
-        res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.fill(larray), repeat=5, number=1000)
-        print("CUDA-shared-tensor (numpy) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))
-
         rtensor = shm_buf_torch.get()
         assert isinstance(rtensor, torch.Tensor)
         assert rtensor.device == torch.device('cuda:0')
         assert rtensor.shape == ltensor.shape
         assert rtensor.dtype == ltensor.dtype
+        assert rtensor.sum().item() == 1024 * 1024
 
         rarray = shm_buf_np.get()
        assert isinstance(rarray, np.ndarray)
         assert larray.shape == rarray.shape
         assert larray.dtype == rarray.dtype
+        assert larray.sum() == 1024 * 1024
+
+        event_fire.set()
+        event_wait.wait()
+        event_wait.clear()
+        rtensor = shm_buf_torch.get()
+        assert isinstance(rtensor, torch.Tensor)
+        assert rtensor.device == torch.device('cuda:0')
+        assert rtensor.shape == ltensor.shape
+        assert rtensor.dtype == ltensor.dtype
+        assert rtensor.sum().item() == 0
+
+        rarray = shm_buf_np.get()
+        assert isinstance(rarray, np.ndarray)
+        assert rarray.shape == larray.shape
+        assert rarray.dtype == larray.dtype
+        assert rarray.sum() == 0
 
-        event_run.set()
+        if copy_on_get:
+            event_fire.set()
+            event_wait.wait()
+            assert shm_buf_torch.buffer[0].item() == 9.0
+            assert shm_buf_np.buffer[0] == 9.0
 
         # Keep producer process running until all consumers exits.
         proc.join()
 
         del shm_buf_np
         del shm_buf_torch
+
+
+@pytest.mark.benchmark
+@pytest.mark.cudatest
+# @pytest.mark.multiprocesstest
+@pytest.mark.parametrize("copy_on_get", [True, False])
+def test_cudabuff_perf(copy_on_get):
+    if torch.cuda.is_available():
+        import torch.multiprocessing as mp
+        ctx = mp.get_context('spawn')
+
+        event_fire, event_wait = ctx.Event(), ctx.Event()
+        shm_buf_np = ShmBufferCuda(np.dtype(np.float32), shape=(1024, 1024), copy_on_get=copy_on_get)
+        shm_buf_torch = ShmBufferCuda(torch.float32, shape=(1024, 1024), copy_on_get=copy_on_get)
+        proc = ctx.Process(target=subprocess_cuda_shared_tensor_case2, args=[shm_buf_np, shm_buf_torch, event_fire])
+        proc.start()
+
+        ltensor = torch.ones((1024, 1024), dtype=torch.float32, device=torch.device('cuda:0'))
+        larray = np.ones((1024, 1024), dtype=np.float32)
+        shm_buf_torch.fill(ltensor)
+        shm_buf_np.fill(larray)
+
+        res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.fill(ltensor), repeat=5, number=1000)
+        print("CUDA-shared-tensor (torch) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))
+        res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.fill(larray), repeat=5, number=1000)
+        print("CUDA-shared-tensor (numpy) Fill: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))
+
+        res = timeit.repeat(lambda shm_buf_torch=shm_buf_torch: shm_buf_torch.get(), repeat=5, number=1000)
+        print("CUDA-shared-tensor (torch) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))
+        res = timeit.repeat(lambda shm_buf_np=shm_buf_np: shm_buf_np.get(), repeat=5, number=1000)
+        print("CUDA-shared-tensor (numpy) Get: mean: {:.4f}s, STD: {:.4f}s".format(np.mean(res), np.std(res)))
+        event_fire.set()
+        proc.join()
+
+        del shm_buf_np
+        del shm_buf_torch
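Distilled from the updated test, a minimal producer/consumer sketch of the ShmBufferCuda handshake (assumes a single visible CUDA device; constructor and method signatures follow the test above, not a separately documented API):

import torch
import torch.multiprocessing as mp
from ding.data.shm_buffer import ShmBufferCuda


def consumer(buf, ready):
    ready.wait()  # block until the producer has filled the shared buffer
    t = buf.get()  # read the shared CUDA tensor from this process
    assert t.device == torch.device('cuda:0')
    assert t.sum().item() == 1024 * 1024
    del buf


if __name__ == "__main__":
    ctx = mp.get_context('spawn')
    ready = ctx.Event()
    buf = ShmBufferCuda(torch.float32, shape=(1024, 1024), copy_on_get=True)
    proc = ctx.Process(target=consumer, args=[buf, ready])
    proc.start()
    buf.fill(torch.ones((1024, 1024), dtype=torch.float32, device='cuda:0'))
    ready.set()  # signal the consumer that data is available
    proc.join()
    del buf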

ding/entry/cli_ditask.py

Lines changed: 5 additions & 8 deletions
@@ -58,10 +58,7 @@ def print_version(ctx: Context, param: Option, value: bool) -> None:
 @click.option("--platform-spec", type=str, help="Platform specific configure.")
 @click.option("--platform", type=str, help="Platform type: slurm, k8s.")
 @click.option(
-    "--mq-type",
-    type=str,
-    default="nng",
-    help="Class type of message queue, i.e. nng, redis, torchrpc:cuda, torchrpc:cpu."
+    "--mq-type", type=str, default="nng", help="Class type of message queue, i.e. nng, redis, cuda, torchrpc:cpu."
 )
 @click.option("--redis-host", type=str, help="Redis host.")
 @click.option("--redis-port", type=int, help="Redis port.")
@@ -173,10 +170,10 @@ def _cli_ditask(
         node_ids = node_ids.split(",")
         node_ids = list(map(lambda i: int(i), node_ids))
     use_cuda = False
-    if mq_type == "torchrpc:cuda" or mq_type == "torchrpc:cpu":
-        mq_type, use_cuda = mq_type.split(":")
-        if use_cuda == "cuda":
-            use_cuda = True
+    if mq_type == "cuda":
+        mq_type, use_cuda = "torchrpc", True
+    if mq_type == "torchrpc:cpu":
+        mq_type, use_cuda = "torchrpc", False
     if local_cuda_devices:
         local_cuda_devices = local_cuda_devices.split(",")
         local_cuda_devices = list(map(lambda s: s.strip(), local_cuda_devices))
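For illustration, the new branch logic reads as a mapping from the --mq-type value to (mq_type, use_cuda). The helper name below is hypothetical; the real code lives inline in _cli_ditask as shown in the hunk above:

def parse_mq_type(mq_type: str):
    # "cuda" is shorthand for torchrpc with CUDA tensors enabled,
    # while "torchrpc:cpu" selects torchrpc with CPU-only transfers.
    use_cuda = False
    if mq_type == "cuda":
        mq_type, use_cuda = "torchrpc", True
    if mq_type == "torchrpc:cpu":
        mq_type, use_cuda = "torchrpc", False
    return mq_type, use_cuda


assert parse_mq_type("cuda") == ("torchrpc", True)
assert parse_mq_type("torchrpc:cpu") == ("torchrpc", False)
assert parse_mq_type("nng") == ("nng", False)  # nng and redis pass through unchanged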

ding/envs/env_manager/subprocess_env_manager.py

Lines changed: 6 additions & 0 deletions
@@ -118,6 +118,12 @@ def __init__(
         if not self._auto_reset:
             assert not self._reset_inplace, "reset_inplace is unavailable when auto_reset=False."
 
+        if self._cfg.cuda_shared_memory and not self._cuda_shared_memory:
+            logging.warning(
+                "Option 'cuda_shared_memory' is True but 'shared_memory' is False, so 'cuda_shared_memory'"
+                " will not be used."
+            )
+
     def _create_state(self) -> None:
         r"""
         Overview:
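An illustrative configuration fragment for the relationship the new warning checks: 'cuda_shared_memory' only takes effect when 'shared_memory' is enabled. The two keys come from the diff and the warning text; the surrounding config structure is an assumption:

from easydict import EasyDict

env_manager_cfg = EasyDict(
    shared_memory=True,  # CPU-side shared memory must stay enabled ...
    cuda_shared_memory=True,  # ... otherwise this flag is ignored and the warning above is logged
)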

ding/framework/message_queue/README.md

Lines changed: 42 additions & 1 deletion
@@ -1,5 +1,42 @@
 # Notes on using torchrpc
 
+## Performance
+We conducted performance tests in a k8s environment equipped with A100-80GB GPUs and a 200G HCA.
+
+### Intra-node GPU-P2P performance
+
+| test case (unit: ms) | 1.25 KB | 20.00 KB | 1.25 MB | 10.00 MB | 40.00 MB | 640.00 MB | 1.25 GB |
+| ------------------ | ------- | -------- | ------- | -------- | ------- | -------- | -------- |
+| shm | 0.3605 | 0.352 | 0.9924 | 7.1229 | 47.9575 | 798.8635 | 1548.782 |
+| nccl-nvlink | 0.1969 | 0.1104 | 0.2162 | 0.3285 | 0.4532 | 3.3166 | 5.3828 |
+| cuda-shared-tensor | 0.5307 | 0.578 | 0.9643 | 0.5908 | 1.2449 | 5.3707 | 9.686 |
+
+### Inter-node GPU-P2P performance
+
+| test case (unit: ms) | 20.00 KB | 1.25 MB | 10.00 MB | 40.00 MB | 640.00 MB | 1.25 GB | 2.50 GB |
+| ------------------------ | -------- | ------- | -------- | -------- | --------- | --------- | ---------- |
+| nng-TCP | 5.7353 | 9.6782 | 30.5187 | 172.9719 | 3450.7418 | 7083.6372 | 14072.1213 |
+| nccl-TCP | 0.0826 | 1.321 | 31.7813 | 128.0672 | 1259.72 | 2477.2957 | 5157.7578 |
+| nccl-IB | 0.0928 | 0.5618 | 2.1134 | 7.1768 | 120.131 | 260.2628 | 518.8091 |
+| nccl-GDR (PXN<->PXN) | 0.5541 | 45.601 | 9.3636 | 19.3071 | 108.11 | 280.0556 | 527.9732 |
+| torchrpc-TCP | 5.6691 | 5.4707 | 14.0155 | 39.4443 | 580.333 | 1154.0793 | 2297.3776 |
+| torchrpc-IB | 21.3884 | 4.4093 | 5.9105 | 22.3012 | 130.249 | 236.8084 | 477.2389 |
+| torchrpc-GDR (PXN<->PXN) | 20.5018 | 23.2081 | 15.6427 | 7.5357* | 48.7812 | 77.2657 | 143.4112 |
+
+### Atari performance
+Performance of dizoo/atari/example/atari_dqn_dist_rdma.py:
+- memory: "32Gi"
+- cpu: 16
+- gpu: A100
+
+
+| test case (unit: s) | avg |
+| ----------------- | ------- |
+| TCP-nng | 127.64 |
+| torchrpc-CP | 29.3906 |
+| torchrpc-IB | 28.7763 |
+
+
 ## Problems you may encounter
 
 Message queue of Torchrpc uses [tensorpipe](https://github.com/pytorch/tensorpipe) as a communication backend, a high-performance modular tensor-p2p communication library. However, several tensorpipe defects have been found in the test, which may make it difficult for you to use it.
@@ -10,4 +47,8 @@ Tensorpipe is not container aware. Processes can find themselves on the same phy
 
 ### 2. RDMA and fork subprocess
 
-Tensorpipe does not consider the case of calling [fork(2)](https://man7.org/linux/man-pages/man2/fork.2.html) when using RDMA. If the corresponding initialization measures are not performed when using RDMA, using fork will cause serious problems, refer to [here](https://www.rdmamojo.com/2012/05/24/ibv_fork_init/). Therefore, if you start ditask in the IB/RoCE network environment, please specify the environment variables `IBV_FORK_SAFE=1` and `RDMAV_FORK_SAFE=1` , so that ibverbs will automatically initialize fork support.
+Tensorpipe does not consider the case of calling [fork(2)](https://man7.org/linux/man-pages/man2/fork.2.html) when using RDMA. If the corresponding initialization measures are not taken, combining fork with RDMA will cause serious problems; see [here](https://www.rdmamojo.com/2012/05/24/ibv_fork_init/). Therefore, if you start ditask in an IB/RoCE network environment, please set the environment variables `IBV_FORK_SAFE=1` and `RDMAV_FORK_SAFE=1`, so that ibverbs will automatically initialize fork support.
+
+### 3. GPU direct RDMA
+
+If you use torchrpc in an environment that supports GPU direct RDMA and the tensor transmitted in an rpc call is very small (less than 32 B), a segmentation fault may occur. See [this issue](https://github.com/pytorch/pytorch/issues/57136). We are tracking this bug and hope it can be resolved eventually.
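A minimal way to satisfy the fork-safety requirement from section 2 when launching ditask from Python rather than exporting the variables in a shell. The environment variable names are the ones given above; the launch command and its other flags are illustrative only:

import os
import subprocess

env = dict(os.environ)
# Ask ibverbs to initialize fork support before any RDMA resources are created.
env["IBV_FORK_SAFE"] = "1"
env["RDMAV_FORK_SAFE"] = "1"
# Hypothetical ditask invocation; only --mq-type is taken from this commit.
subprocess.run(["ditask", "--main", "my_pkg.main", "--mq-type", "torchrpc:cpu"], env=env)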

ding/framework/message_queue/perfs/perf_shm.py

Lines changed: 4 additions & 3 deletions
@@ -3,8 +3,9 @@
 from ditk import logging
 from ding.framework.supervisor import RecvPayload, SendPayload, Supervisor, ChildType
 from ding.envs.env_manager.subprocess_env_manager import ShmBufferContainer, ShmBuffer
-from ding.utils.comm_perf_helper import tensor_size_beauty_print, byte_beauty_print, \
+from ding.utils.comm_perf_helper import tensor_size_beauty_print, \
     dtype_2_byte, TENSOR_SIZE_LIST, print_timer_result_csv
+from ding.utils import byte_beauty_print
 
 import torch
 import numpy as np
@@ -37,7 +38,7 @@ def cuda_shm_callback(payload: RecvPayload, buffers: Any):
     assert tensor.device == torch.device('cuda:1')
 
 
-class Recvier:
+class Receiver:
 
     def step(self, idx: int, __start_time):
         return {"idx": idx, "start_time": __start_time}
@@ -56,7 +57,7 @@ def __init__(self, gpu_tensors, buffers, ctx, is_cuda_buffer):
             _shm_callback = shm_callback
         else:
             _shm_callback = cuda_shm_callback
-        self.register(Recvier, shm_buffer=self.buffers, shm_callback=_shm_callback)
+        self.register(Receiver, shm_buffer=self.buffers, shm_callback=_shm_callback)
         super().start_link()
 
     def _send_recv_callback(self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None):

ding/framework/message_queue/perfs/perf_torchrpc_nccl.py

Lines changed: 2 additions & 1 deletion
@@ -12,8 +12,9 @@
 
 from ding.utils.data.structure.lifo_deque import LifoDeque
 from ding.framework.message_queue.torch_rpc import DeviceMap, TORCHRPCMQ, RPCEvent
-from ding.utils.comm_perf_helper import tensor_size_beauty_print, byte_beauty_print, \
+from ding.utils.comm_perf_helper import tensor_size_beauty_print, \
     dtype_2_byte, DO_PERF, time_perf_avg, time_perf_once, print_timer_result_csv
+from ding.utils import byte_beauty_print
 
 LENGTH = 5
 REPEAT = 2

ding/framework/middleware/functional/collector.py

Lines changed: 9 additions & 6 deletions
@@ -136,11 +136,15 @@ def _rollout(ctx: "OnlineRLContext"):
             # torchrpc currently uses "cuda:0" as the transmission device by default,
             # so all data on the cpu side is copied to "cuda:0" here. In fact this
             # copy is unnecessary, because torchrpc can support both cpu side and gpu
-            # side data to communicate using RDMA, but mixing the two transfer types
-            # will cause a bug, see issue:
-            # Because we have copied the large payload "obs" and "next_obs" from the
-            # collector's subprocess to "cuda:0" in advance, the copy operation here
-            # will not have too much overhead.
+            # side data to communicate using RDMA.
+            # But we met a bug in unittest, see: https://github.com/pytorch/pytorch/issues/57136
+            # We adopted some strategies to avoid the bug:
+            # 1. Try not to mix cpu and gpu args in one rpc call.
+            #    Because we have copied the large payloads "obs" and "next_obs" from the
+            #    collector's subprocess to "cuda:0" in advance, the copy operation here
+            #    will not have too much overhead.
+            # 2. Don't make the tensor size too small when using GPU direct RDMA.
+
             if use_cuda_shared_memory:
                 transition = to_device(transition, "cuda:0")
             transitions.append(timestep.env_id, transition)
@@ -149,6 +153,5 @@ def _rollout(ctx: "OnlineRLContext"):
                 env_episode_id[timestep.env_id] = current_id
                 current_id += 1
                 ctx.env_episode += 1
-            # TODO log
 
     return _rollout
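A minimal sketch of strategy 1 from the comment above, assuming a transition is a plain dict of tensors; to_device is the DI-engine utility referenced in the diff context, while the sample payload is made up:

import torch
from ding.torch_utils import to_device

transition = {
    'obs': torch.randn(4, 84, 84),
    'action': torch.tensor([1]),
    'reward': torch.tensor([0.5]),
}
# Move the whole payload to cuda:0 before it is handed to torchrpc, so a single
# rpc call never mixes cpu-side and gpu-side tensors (the bug trigger noted above).
transition = to_device(transition, "cuda:0")
assert all(v.device.type == 'cuda' for v in transition.values())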
