27
27
# cython: language_level=3
28
28
29
29
import mkl_umath._ufuncs as mu
30
- import numpy.core.umath as nu
31
30
32
31
cimport numpy as cnp
33
32
import numpy as np
@@ -59,15 +58,15 @@ cdef class patch:
59
58
self .functions_count = 0
60
59
for umath in umaths:
61
60
mkl_umath = getattr (mu, umath)
62
- self .functions_count = self .functions_count + mkl_umath.ntypes
61
+ self .functions_count += mkl_umath.ntypes
63
62
64
63
self .functions = < function_info * > malloc(self .functions_count * sizeof(function_info))
65
64
66
65
func_number = 0
67
66
for umath in umaths:
68
67
patch_umath = getattr (mu, umath)
69
68
c_patch_umath = < cnp.ufunc> patch_umath
70
- c_orig_umath = < cnp.ufunc> getattr (nu , umath)
69
+ c_orig_umath = < cnp.ufunc> getattr (np , umath)
71
70
nargs = c_patch_umath.nargs
72
71
for pi in range (c_patch_umath.ntypes):
73
72
oi = 0
@@ -103,7 +102,7 @@ cdef class patch:
103
102
cdef int * signature
104
103
105
104
for func in self .functions_dict:
106
- np_umath = getattr (nu , func[0 ])
105
+ np_umath = getattr (np , func[0 ])
107
106
index = self .functions_dict[func]
108
107
function = self .functions[index].patch_function
109
108
signature = self .functions[index].signature
@@ -118,7 +117,7 @@ cdef class patch:
118
117
cdef int * signature
119
118
120
119
for func in self .functions_dict:
121
- np_umath = getattr (nu , func[0 ])
120
+ np_umath = getattr (np , func[0 ])
122
121
index = self .functions_dict[func]
123
122
function = self .functions[index].original_function
124
123
signature = self .functions[index].signature
@@ -143,34 +142,97 @@ def _initialize_tls():
143
142
144
143
145
144
def use_in_numpy ():
146
- '''
145
+ """
147
146
Enables using of mkl_umath in Numpy.
148
- '''
147
+
148
+ Examples
149
+ --------
150
+ >>> import mkl_umath, numpy as np
151
+ >>> mkl_umath.is_patched()
152
+ # False
153
+
154
+ >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy
155
+ >>> mkl_umath.is_patched()
156
+ # True
157
+
158
+ >>> mkl_umath.restore() # Disable mkl_umath in Numpy
159
+ >>> mkl_umath.is_patched()
160
+ # False
161
+
162
+ """
149
163
if not _is_tls_initialized():
150
164
_initialize_tls()
151
165
_tls.patch.do_patch()
152
166
153
167
154
168
def restore ():
155
- '''
169
+ """
156
170
Disables using of mkl_umath in Numpy.
157
- '''
171
+
172
+ Examples
173
+ --------
174
+ >>> import mkl_umath, numpy as np
175
+ >>> mkl_umath.is_patched()
176
+ # False
177
+
178
+ >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy
179
+ >>> mkl_umath.is_patched()
180
+ # True
181
+
182
+ >>> mkl_umath.restore() # Disable mkl_umath in Numpy
183
+ >>> mkl_umath.is_patched()
184
+ # False
185
+
186
+ """
158
187
if not _is_tls_initialized():
159
188
_initialize_tls()
160
189
_tls.patch.do_unpatch()
161
190
162
191
163
192
def is_patched ():
164
- '''
193
+ """
165
194
Returns whether Numpy has been patched with mkl_umath.
166
- '''
195
+
196
+ Examples
197
+ --------
198
+ >>> import mkl_umath, numpy as np
199
+ >>> mkl_umath.is_patched()
200
+ # False
201
+
202
+ >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy
203
+ >>> mkl_umath.is_patched()
204
+ # True
205
+
206
+ >>> mkl_umath.restore() # Disable mkl_umath in Numpy
207
+ >>> mkl_umath.is_patched()
208
+ # False
209
+
210
+ """
167
211
if not _is_tls_initialized():
168
212
_initialize_tls()
169
- _tls.patch.is_patched()
213
+ return _tls.patch.is_patched()
170
214
171
215
from contextlib import ContextDecorator
172
216
173
217
class mkl_umath (ContextDecorator ):
218
+ """
219
+ Context manager and decorator to temporarily patch NumPy ufuncs
220
+ with MKL-based implementations.
221
+
222
+ Examples
223
+ --------
224
+ >>> import mkl_umath, numpy as np
225
+ >>> mkl_umath.is_patched()
226
+ # False
227
+
228
+ >>> with mkl_umath.mkl_umath(): # Enable mkl_umath in Numpy
229
+ >>> print(mkl_umath.is_patched())
230
+ # True
231
+
232
+ >>> mkl_umath.is_patched()
233
+ # False
234
+
235
+ """
174
236
def __enter__ (self ):
175
237
use_in_numpy()
176
238
return self
0 commit comments