Skip to content

Commit 8d2668c

Browse files
authored
Fix a bug in is_patched function (#66)
1 parent 70a56bd commit 8d2668c

File tree

2 files changed

+79
-13
lines changed

2 files changed

+79
-13
lines changed

.github/workflows/conda-package.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
name: Conda package
22

3-
on: push
3+
on:
4+
push:
5+
branches:
6+
- master
7+
pull_request:
48

59
permissions: read-all
610

mkl_umath/src/_patch.pyx

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
# cython: language_level=3
2828

2929
import mkl_umath._ufuncs as mu
30-
import numpy.core.umath as nu
3130

3231
cimport numpy as cnp
3332
import numpy as np
@@ -59,15 +58,15 @@ cdef class patch:
5958
self.functions_count = 0
6059
for umath in umaths:
6160
mkl_umath = getattr(mu, umath)
62-
self.functions_count = self.functions_count + mkl_umath.ntypes
61+
self.functions_count += mkl_umath.ntypes
6362

6463
self.functions = <function_info *> malloc(self.functions_count * sizeof(function_info))
6564

6665
func_number = 0
6766
for umath in umaths:
6867
patch_umath = getattr(mu, umath)
6968
c_patch_umath = <cnp.ufunc>patch_umath
70-
c_orig_umath = <cnp.ufunc>getattr(nu, umath)
69+
c_orig_umath = <cnp.ufunc>getattr(np, umath)
7170
nargs = c_patch_umath.nargs
7271
for pi in range(c_patch_umath.ntypes):
7372
oi = 0
@@ -103,7 +102,7 @@ cdef class patch:
103102
cdef int* signature
104103

105104
for func in self.functions_dict:
106-
np_umath = getattr(nu, func[0])
105+
np_umath = getattr(np, func[0])
107106
index = self.functions_dict[func]
108107
function = self.functions[index].patch_function
109108
signature = self.functions[index].signature
@@ -118,7 +117,7 @@ cdef class patch:
118117
cdef int* signature
119118

120119
for func in self.functions_dict:
121-
np_umath = getattr(nu, func[0])
120+
np_umath = getattr(np, func[0])
122121
index = self.functions_dict[func]
123122
function = self.functions[index].original_function
124123
signature = self.functions[index].signature
@@ -143,34 +142,97 @@ def _initialize_tls():
143142

144143

145144
def use_in_numpy():
146-
'''
145+
"""
147146
Enables using of mkl_umath in Numpy.
148-
'''
147+
148+
Examples
149+
--------
150+
>>> import mkl_umath, numpy as np
151+
>>> mkl_umath.is_patched()
152+
# False
153+
154+
>>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy
155+
>>> mkl_umath.is_patched()
156+
# True
157+
158+
>>> mkl_umath.restore() # Disable mkl_umath in Numpy
159+
>>> mkl_umath.is_patched()
160+
# False
161+
162+
"""
149163
if not _is_tls_initialized():
150164
_initialize_tls()
151165
_tls.patch.do_patch()
152166

153167

154168
def restore():
155-
'''
169+
"""
156170
Disables using of mkl_umath in Numpy.
157-
'''
171+
172+
Examples
173+
--------
174+
>>> import mkl_umath, numpy as np
175+
>>> mkl_umath.is_patched()
176+
# False
177+
178+
>>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy
179+
>>> mkl_umath.is_patched()
180+
# True
181+
182+
>>> mkl_umath.restore() # Disable mkl_umath in Numpy
183+
>>> mkl_umath.is_patched()
184+
# False
185+
186+
"""
158187
if not _is_tls_initialized():
159188
_initialize_tls()
160189
_tls.patch.do_unpatch()
161190

162191

163192
def is_patched():
164-
'''
193+
"""
165194
Returns whether Numpy has been patched with mkl_umath.
166-
'''
195+
196+
Examples
197+
--------
198+
>>> import mkl_umath, numpy as np
199+
>>> mkl_umath.is_patched()
200+
# False
201+
202+
>>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy
203+
>>> mkl_umath.is_patched()
204+
# True
205+
206+
>>> mkl_umath.restore() # Disable mkl_umath in Numpy
207+
>>> mkl_umath.is_patched()
208+
# False
209+
210+
"""
167211
if not _is_tls_initialized():
168212
_initialize_tls()
169-
_tls.patch.is_patched()
213+
return _tls.patch.is_patched()
170214

171215
from contextlib import ContextDecorator
172216

173217
class mkl_umath(ContextDecorator):
218+
"""
219+
Context manager and decorator to temporarily patch NumPy ufuncs
220+
with MKL-based implementations.
221+
222+
Examples
223+
--------
224+
>>> import mkl_umath, numpy as np
225+
>>> mkl_umath.is_patched()
226+
# False
227+
228+
>>> with mkl_umath.mkl_umath(): # Enable mkl_umath in Numpy
229+
>>> print(mkl_umath.is_patched())
230+
# True
231+
232+
>>> mkl_umath.is_patched()
233+
# False
234+
235+
"""
174236
def __enter__(self):
175237
use_in_numpy()
176238
return self

0 commit comments

Comments
 (0)