3
3
import ctypes
4
4
import numpy as np
5
5
import torch
6
+ import torch .multiprocessing as mp
7
+ from functools import reduce
8
+ from ditk import logging
9
+ from abc import abstractmethod
6
10
7
11
_NTYPE_TO_CTYPE = {
8
12
np .bool_ : ctypes .c_bool ,
18
22
np .float64 : ctypes .c_double ,
19
23
}
20
24
25
+ # uint16, uint32, uint32
26
+ _NTYPE_TO_TTYPE = {
27
+ np .bool_ : torch .bool ,
28
+ np .uint8 : torch .uint8 ,
29
+ # np.uint16: torch.int16,
30
+ # np.uint32: torch.int32,
31
+ # np.uint64: torch.int64,
32
+ np .int8 : torch .uint8 ,
33
+ np .int16 : torch .int16 ,
34
+ np .int32 : torch .int32 ,
35
+ np .int64 : torch .int64 ,
36
+ np .float32 : torch .float32 ,
37
+ np .float64 : torch .float64 ,
38
+ }
39
+
40
+ _NOT_SUPPORT_NTYPE = {np .uint16 : torch .int16 , np .uint32 : torch .int32 , np .uint64 : torch .int64 }
41
+ _CONVERSION_TYPE = {np .uint16 : np .int16 , np .uint32 : np .int32 , np .uint64 : np .int64 }
42
+
43
+
44
class ShmBufferBase:
    """
    Overview:
        Common interface for shared-memory buffers. Concrete subclasses must
        implement ``fill`` (write an array/tensor into the buffer) and
        ``get`` (read the buffer contents back).
    """

    @abstractmethod
    def fill(self, src_arr: Union[np.ndarray, torch.Tensor]) -> None:
        """Copy ``src_arr`` into the shared buffer; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def get(self) -> Union[np.ndarray, torch.Tensor]:
        """Return the shared buffer contents; must be overridden."""
        raise NotImplementedError
53
+
54
+
55
+ class ShmBuffer (ShmBufferBase ):
23
56
"""
24
57
Overview:
25
58
Shared memory buffer to store numpy array.
@@ -78,6 +111,94 @@ def get(self) -> np.ndarray:
78
111
return data
79
112
80
113
114
class ShmBufferCuda(ShmBufferBase):

    def __init__(
            self,
            dtype: Union[torch.dtype, np.dtype],
            shape: Tuple[int],
            ctype: Optional[type] = None,
            copy_on_get: bool = True,
            device: Optional[torch.device] = torch.device('cuda:0')
    ) -> None:
        """
        Overview:
            Use torch.multiprocessing for shared tensor or ndarray between processes.
        Arguments:
            - dtype (Union[torch.dtype, np.dtype]): dtype of torch.tensor or numpy.ndarray.
            - shape (Tuple[int]): Shape of torch.tensor or numpy.ndarray.
            - ctype (type): Origin class type, e.g. np.ndarray, torch.Tensor.
            - copy_on_get (bool, optional): Can be set to False only if the shared object
                is a tensor, otherwise True.
            - device (Optional[torch.device], optional): The GPU device where cuda-shared-tensor
                is located, the default is cuda:0.

        Raises:
            RuntimeError: Unsupported share type by ShmBufferCuda.
        """
        if isinstance(dtype, np.dtype):  # e.g. the dtype of a gym.spaces sample
            self.ctype = np.ndarray
            scalar_type = dtype.type
            if scalar_type in _NOT_SUPPORT_NTYPE:
                # Unsigned >8-bit numpy types have no torch equivalent; fall back
                # to the signed dtype of the same width (may lose precision).
                logging.warning(
                    "Torch tensor unsupport numpy type {}, attempt to do a type conversion, which may lose precision.".
                    format(scalar_type)
                )
                torch_dtype = _NOT_SUPPORT_NTYPE[scalar_type]
                self.dtype = _CONVERSION_TYPE[scalar_type]
            else:
                torch_dtype = _NTYPE_TO_TTYPE[scalar_type]
                self.dtype = scalar_type
        elif isinstance(dtype, torch.dtype):
            self.ctype = torch.Tensor
            torch_dtype = dtype
        else:
            raise RuntimeError("The dtype parameter only supports torch.dtype and np.dtype")

        self.copy_on_get = copy_on_get
        self.shape = shape
        self.device = device
        # Allocate a flat backing tensor; keep it out of any computational graph.
        with torch.no_grad():
            numel = reduce(lambda a, b: a * b, shape)
            self.buffer = torch.zeros(numel, dtype=torch_dtype, device=self.device)

    def fill(self, src_arr: Union[np.ndarray, torch.Tensor]) -> None:
        """Copy ``src_arr`` (flattened) into the shared buffer."""
        if self.ctype is np.ndarray:
            if src_arr.dtype.type != self.dtype:
                # Convert incompatible numpy dtypes before the torch handoff.
                logging.warning(
                    "Torch tensor unsupport numpy type {}, attempt to do a type conversion, which may lose precision.".
                    format(self.dtype)
                )
                src_arr = src_arr.astype(self.dtype)
            src_tensor = torch.from_numpy(src_arr)
        elif self.ctype is torch.Tensor:
            src_tensor = src_arr
        else:
            raise RuntimeError("Unsopport CUDA-shared-tensor input type:\"{}\"".format(type(src_arr)))

        # If the GPU-a and GPU-b are connected using nvlink, the copy is very fast.
        with torch.no_grad():
            self.buffer.copy_(src_tensor.view(src_tensor.numel()))

    def get(self) -> Union[np.ndarray, torch.Tensor]:
        """Return the buffer contents reshaped to the original shape."""
        with torch.no_grad():
            if self.ctype is np.ndarray:
                # The data lives in CUDA memory, so a host copy is unavoidable
                # for numpy consumers regardless of copy_on_get.
                out = self.buffer.cpu().detach().view(self.shape).numpy()
            elif self.copy_on_get:
                out = self.buffer.clone().detach().view(self.shape)
            else:
                out = self.buffer.view(self.shape)

            return out

    def __del__(self):
        # Release the backing shared tensor eagerly.
        del self.buffer
+
81
202
class ShmBufferContainer (object ):
82
203
"""
83
204
Overview:
@@ -88,7 +209,8 @@ def __init__(
88
209
self ,
89
210
dtype : Union [Dict [Any , type ], type , np .dtype ],
90
211
shape : Union [Dict [Any , tuple ], tuple ],
91
- copy_on_get : bool = True
212
+ copy_on_get : bool = True ,
213
+ is_cuda_buffer : bool = False
92
214
) -> None :
93
215
"""
94
216
Overview:
@@ -98,11 +220,15 @@ def __init__(
98
220
- shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \
99
221
multiple buffers; If `tuple`, use single buffer.
100
222
- copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
223
+ - is_cuda_buffer (:obj:`bool`): Whether to use pytorch CUDA shared tensor as the implementation of shm.
101
224
"""
102
225
if isinstance (shape , dict ):
103
- self ._data = {k : ShmBufferContainer (dtype [k ], v , copy_on_get ) for k , v in shape .items ()}
226
+ self ._data = {k : ShmBufferContainer (dtype [k ], v , copy_on_get , is_cuda_buffer ) for k , v in shape .items ()}
104
227
elif isinstance (shape , (tuple , list )):
105
- self ._data = ShmBuffer (dtype , shape , copy_on_get )
228
+ if not is_cuda_buffer :
229
+ self ._data = ShmBuffer (dtype , shape , copy_on_get )
230
+ else :
231
+ self ._data = ShmBufferCuda (dtype , shape , copy_on_get )
106
232
else :
107
233
raise RuntimeError ("not support shape: {}" .format (shape ))
108
234
self ._shape = shape
0 commit comments