Videre
11 binary files not shown.
linedance-app/venv/lib/python3.12/site-packages/numba/np/arrayobj.py (new file, 7296 lines)
File diff suppressed because it is too large.

linedance-app/venv/lib/python3.12/site-packages/numba/np/extensions.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""
NumPy extensions.
"""

from numba.np.arraymath import cross2d


__all__ = [
    'cross2d'
]

linedance-app/venv/lib/python3.12/site-packages/numba/np/linalg.py (new file, 2839 lines)
File diff suppressed because it is too large.
4 binary files not shown.

linedance-app/venv/lib/python3.12/site-packages/numba/cpython/cmathimpl.py (new file, 542 lines)
@@ -0,0 +1,542 @@
"""
Implement the cmath module functions.
"""


import cmath
import math

from numba.core.imputils import impl_ret_untracked
from numba.core import types
from numba.core.typing import signature
from numba.cpython import mathimpl
from numba.core.extending import overload

# registry = Registry('cmathimpl')
# lower = registry.lower


def is_nan(builder, z):
    return builder.fcmp_unordered('uno', z.real, z.imag)

def is_inf(builder, z):
    return builder.or_(mathimpl.is_inf(builder, z.real),
                       mathimpl.is_inf(builder, z.imag))

def is_finite(builder, z):
    return builder.and_(mathimpl.is_finite(builder, z.real),
                        mathimpl.is_finite(builder, z.imag))


# @lower(cmath.isnan, types.Complex)
def isnan_float_impl(context, builder, sig, args):
    [typ] = sig.args
    [value] = args
    z = context.make_complex(builder, typ, value=value)
    res = is_nan(builder, z)
    return impl_ret_untracked(context, builder, sig.return_type, res)

# @lower(cmath.isinf, types.Complex)
def isinf_float_impl(context, builder, sig, args):
    [typ] = sig.args
    [value] = args
    z = context.make_complex(builder, typ, value=value)
    res = is_inf(builder, z)
    return impl_ret_untracked(context, builder, sig.return_type, res)


# @lower(cmath.isfinite, types.Complex)
def isfinite_float_impl(context, builder, sig, args):
    [typ] = sig.args
    [value] = args
    z = context.make_complex(builder, typ, value=value)
    res = is_finite(builder, z)
    return impl_ret_untracked(context, builder, sig.return_type, res)


# @overload(cmath.rect)
def impl_cmath_rect(r, phi):
    if all([isinstance(typ, types.Float) for typ in [r, phi]]):
        def impl(r, phi):
            if not math.isfinite(phi):
                if not r:
                    # cmath.rect(0, phi={inf, nan}) = 0
                    return abs(r)
                if math.isinf(r):
                    # cmath.rect(inf, phi={inf, nan}) = inf + j phi
                    return complex(r, phi)
            real = math.cos(phi)
            imag = math.sin(phi)
            if real == 0. and math.isinf(r):
                # 0 * inf would return NaN, we want to keep 0 but xor the sign
                real /= r
            else:
                real *= r
            if imag == 0. and math.isinf(r):
                # ditto
                imag /= r
            else:
                imag *= r
            return complex(real, imag)
        return impl
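

# Editor's note: a minimal sanity check of the special cases handled in
# impl_cmath_rect() above. This helper is not part of the upstream numba
# source; plain CPython's cmath follows the same C99 rules, so it runs
# without any compilation.
def _demo_rect_special_cases():
    assert cmath.rect(0.0, float('inf')) == 0j             # rect(0, inf) == 0
    assert cmath.rect(float('inf'), 0.0).real == float('inf')
    assert math.isnan(cmath.rect(1.0, float('nan')).real)  # nan phase -> nan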


def intrinsic_complex_unary(inner_func):
    def wrapper(context, builder, sig, args):
        [typ] = sig.args
        [value] = args
        z = context.make_complex(builder, typ, value=value)
        x = z.real
        y = z.imag
        # Same as above: math.isfinite() is unavailable on 2.x so we precompute
        # its value and pass it to the pure Python implementation.
        x_is_finite = mathimpl.is_finite(builder, x)
        y_is_finite = mathimpl.is_finite(builder, y)
        inner_sig = signature(sig.return_type,
                              *(typ.underlying_float,) * 2 + (types.boolean,) * 2)
        res = context.compile_internal(builder, inner_func, inner_sig,
                                       (x, y, x_is_finite, y_is_finite))
        return impl_ret_untracked(context, builder, sig, res)
    return wrapper


NAN = float('nan')
INF = float('inf')

# @lower(cmath.exp, types.Complex)
@intrinsic_complex_unary
def exp_impl(x, y, x_is_finite, y_is_finite):
    """cmath.exp(x + y j)"""
    if x_is_finite:
        if y_is_finite:
            c = math.cos(y)
            s = math.sin(y)
            r = math.exp(x)
            return complex(r * c, r * s)
        else:
            return complex(NAN, NAN)
    elif math.isnan(x):
        if y:
            return complex(x, x)  # nan + j nan
        else:
            return complex(x, y)  # nan + 0j
    elif x > 0.0:
        # x == +inf
        if y_is_finite:
            real = math.cos(y)
            imag = math.sin(y)
            # Avoid NaNs if math.cos(y) or math.sin(y) == 0
            # (e.g. cmath.exp(inf + 0j) == inf + 0j)
            if real != 0:
                real *= x
            if imag != 0:
                imag *= x
            return complex(real, imag)
        else:
            return complex(x, NAN)
    else:
        # x == -inf
        if y_is_finite:
            r = math.exp(x)
            c = math.cos(y)
            s = math.sin(y)
            return complex(r * c, r * s)
        else:
            r = 0
            return complex(r, r)
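

# Editor's note: a small illustration (not from the upstream source) of the
# special-value handling in exp_impl above; CPython's cmath.exp implements
# the same C99 Annex G behaviour for the infinity cases.
def _demo_exp_special_cases():
    assert cmath.exp(complex(INF, 0.0)) == complex(INF, 0.0)
    assert cmath.exp(complex(-INF, 42.0)) == 0j  # exp(-inf) underflows to 0
    # For finite x with non-finite y there is no meaningful angle: the
    # lowered code above returns nan+nanj (CPython raises ValueError here).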


# @lower(cmath.log, types.Complex)
@intrinsic_complex_unary
def log_impl(x, y, x_is_finite, y_is_finite):
    """cmath.log(x + y j)"""
    a = math.log(math.hypot(x, y))
    b = math.atan2(y, x)
    return complex(a, b)


# @lower(cmath.log, types.Complex, types.Complex)
def log_base_impl(context, builder, sig, args):
    """cmath.log(z, base)"""
    [z, base] = args

    def log_base(z, base):
        return cmath.log(z) / cmath.log(base)

    res = context.compile_internal(builder, log_base, sig, args)
    return impl_ret_untracked(context, builder, sig, res)


# @overload(cmath.log10)
def impl_cmath_log10(z):
    if not isinstance(z, types.Complex):
        return

    LN_10 = 2.302585092994045684

    def log10_impl(z):
        """cmath.log10(z)"""
        z = cmath.log(z)
        # This formula gives better results on +/-inf than cmath.log(z, 10)
        # See http://bugs.python.org/issue22544
        return complex(z.real / LN_10, z.imag / LN_10)

    return log10_impl
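

# Editor's note: a quick check (not upstream code) of why log10 is computed
# as log(z) / ln(10) above: the division stays exact on infinities, where
# the two-argument cmath.log(z, 10) can produce NaNs (see bpo-22544).
def _demo_log10_inf():
    assert cmath.log(complex(INF, 0.0)) == complex(INF, 0.0)
    assert cmath.log10(complex(INF, 0.0)) == complex(INF, 0.0)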


# @overload(cmath.phase)
def phase_impl(x):
    """cmath.phase(x + y j)"""

    if not isinstance(x, types.Complex):
        return

    def impl(x):
        return math.atan2(x.imag, x.real)
    return impl


# @overload(cmath.polar)
def polar_impl(x):
    if not isinstance(x, types.Complex):
        return

    def impl(x):
        r, i = x.real, x.imag
        return math.hypot(r, i), math.atan2(i, r)
    return impl


# @lower(cmath.sqrt, types.Complex)
def sqrt_impl(context, builder, sig, args):
    # We risk spurious overflow for components >= FLT_MAX / (1 + sqrt(2)).

    SQRT2 = 1.414213562373095048801688724209698079E0
    ONE_PLUS_SQRT2 = (1. + SQRT2)
    theargflt = sig.args[0].underlying_float
    # Get a type specific maximum value so scaling for overflow is based on that
    MAX = mathimpl.DBL_MAX if theargflt.bitwidth == 64 else mathimpl.FLT_MAX
    # THRES will be double precision, should not impact typing as it's just
    # used for comparison, there *may* be a few values near THRES which
    # deviate from e.g. NumPy due to rounding that occurs in the computation
    # of this value in the case of a 32bit argument.
    THRES = MAX / ONE_PLUS_SQRT2

    def sqrt_impl(z):
        """cmath.sqrt(z)"""
        # This is NumPy's algorithm, see npy_csqrt() in npy_math_complex.c.src
        a = z.real
        b = z.imag
        if a == 0.0 and b == 0.0:
            return complex(abs(b), b)
        if math.isinf(b):
            return complex(abs(b), b)
        if math.isnan(a):
            return complex(a, a)
        if math.isinf(a):
            if a < 0.0:
                return complex(abs(b - b), math.copysign(a, b))
            else:
                return complex(a, math.copysign(b - b, b))

        # The remaining special case (b is NaN) is handled just fine by
        # the normal code path below.

        # Scale to avoid overflow
        if abs(a) >= THRES or abs(b) >= THRES:
            a *= 0.25
            b *= 0.25
            scale = True
        else:
            scale = False
        # Algorithm 312, CACM vol 10, Oct 1967
        if a >= 0:
            t = math.sqrt((a + math.hypot(a, b)) * 0.5)
            real = t
            imag = b / (2 * t)
        else:
            t = math.sqrt((-a + math.hypot(a, b)) * 0.5)
            real = abs(b) / (2 * t)
            imag = math.copysign(t, b)
        # Rescale: sqrt(z / 4) == sqrt(z) / 2, so both components must be
        # doubled to undo the quartering above.
        if scale:
            return complex(real * 2., imag * 2.)
        else:
            return complex(real, imag)

    res = context.compile_internal(builder, sqrt_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)
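

# Editor's note: an illustration (not upstream code) of the scaling step in
# sqrt_impl: since sqrt(z / 4) == sqrt(z) / 2, quartering the components
# before the Algorithm-312 kernel and doubling the result avoids overflow
# in hypot() for components near DBL_MAX. CPython's cmath.sqrt uses the
# same trick, so this runs under plain CPython.
def _demo_sqrt_no_overflow():
    big = 1.5e308  # near DBL_MAX; hypot(big, big) alone would overflow
    w = cmath.sqrt(complex(big, big))
    assert math.isfinite(w.real) and math.isfinite(w.imag)
    assert abs(cmath.phase(w) - math.pi / 8) < 1e-12  # half the phase of 1+1j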


# @lower(cmath.cos, types.Complex)
def cos_impl(context, builder, sig, args):
    def cos_impl(z):
        """cmath.cos(z) = cmath.cosh(z j)"""
        return cmath.cosh(complex(-z.imag, z.real))

    res = context.compile_internal(builder, cos_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)

# @overload(cmath.cosh)
def impl_cmath_cosh(z):
    if not isinstance(z, types.Complex):
        return

    def cosh_impl(z):
        """cmath.cosh(z)"""
        x = z.real
        y = z.imag
        if math.isinf(x):
            if math.isnan(y):
                # x = +inf, y = NaN => cmath.cosh(x + y j) = inf + NaN * j
                real = abs(x)
                imag = y
            elif y == 0.0:
                # x = +inf, y = 0 => cmath.cosh(x + y j) = inf + 0j
                real = abs(x)
                imag = y
            else:
                real = math.copysign(x, math.cos(y))
                imag = math.copysign(x, math.sin(y))
                if x < 0.0:
                    # x = -inf => negate imaginary part of result
                    imag = -imag
            return complex(real, imag)
        return complex(math.cos(y) * math.cosh(x),
                       math.sin(y) * math.sinh(x))
    return cosh_impl


# @lower(cmath.sin, types.Complex)
def sin_impl(context, builder, sig, args):
    def sin_impl(z):
        """cmath.sin(z) = -j * cmath.sinh(z j)"""
        r = cmath.sinh(complex(-z.imag, z.real))
        return complex(r.imag, -r.real)

    res = context.compile_internal(builder, sin_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)

# @overload(cmath.sinh)
def impl_cmath_sinh(z):
    if not isinstance(z, types.Complex):
        return

    def sinh_impl(z):
        """cmath.sinh(z)"""
        x = z.real
        y = z.imag
        if math.isinf(x):
            if math.isnan(y):
                # x = +/-inf, y = NaN => cmath.sinh(x + y j) = x + NaN * j
                real = x
                imag = y
            else:
                real = math.cos(y)
                imag = math.sin(y)
                if real != 0.:
                    real *= x
                if imag != 0.:
                    imag *= abs(x)
            return complex(real, imag)
        return complex(math.cos(y) * math.sinh(x),
                       math.sin(y) * math.cosh(x))
    return sinh_impl


# @lower(cmath.tan, types.Complex)
def tan_impl(context, builder, sig, args):
    def tan_impl(z):
        """cmath.tan(z) = -j * cmath.tanh(z j)"""
        r = cmath.tanh(complex(-z.imag, z.real))
        return complex(r.imag, -r.real)

    res = context.compile_internal(builder, tan_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)


# @overload(cmath.tanh)
def impl_cmath_tanh(z):
    if not isinstance(z, types.Complex):
        return

    def tanh_impl(z):
        """cmath.tanh(z)"""
        x = z.real
        y = z.imag
        if math.isinf(x):
            real = math.copysign(1., x)
            if math.isinf(y):
                imag = 0.
            else:
                imag = math.copysign(0., math.sin(2. * y))
            return complex(real, imag)
        # This is CPython's algorithm (see c_tanh() in cmathmodule.c).
        # XXX how to force float constants into single precision?
        tx = math.tanh(x)
        ty = math.tan(y)
        cx = 1. / math.cosh(x)
        txty = tx * ty
        denom = 1. + txty * txty
        return complex(
            tx * (1. + ty * ty) / denom,
            ((ty / denom) * cx) * cx)

    return tanh_impl
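

# Editor's note: an illustration (not upstream code) of the saturation that
# the infinity branch of tanh_impl reproduces: for large |x|, tanh(x + yj)
# approaches copysign(1, x) with a vanishing imaginary part, and must be
# computed without evaluating the overflowing cosh/sinh terms.
def _demo_tanh_saturation():
    w = cmath.tanh(complex(750.0, 1.0))  # math.cosh(750) would overflow
    assert abs(w.real - 1.0) < 1e-15
    assert abs(w.imag) < 1e-15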


# @lower(cmath.acos, types.Complex)
def acos_impl(context, builder, sig, args):
    LN_4 = math.log(4)
    THRES = mathimpl.FLT_MAX / 4

    def acos_impl(z):
        """cmath.acos(z)"""
        # CPython's algorithm (see c_acos() in cmathmodule.c)
        if abs(z.real) > THRES or abs(z.imag) > THRES:
            # Avoid unnecessary overflow for large arguments
            # (also handles infinities gracefully)
            real = math.atan2(abs(z.imag), z.real)
            imag = math.copysign(
                math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4,
                -z.imag)
            return complex(real, imag)
        else:
            s1 = cmath.sqrt(complex(1. - z.real, -z.imag))
            s2 = cmath.sqrt(complex(1. + z.real, z.imag))
            real = 2. * math.atan2(s1.real, s2.real)
            imag = math.asinh(s2.real * s1.imag - s2.imag * s1.real)
            return complex(real, imag)

    res = context.compile_internal(builder, acos_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)

# @overload(cmath.acosh)
def impl_cmath_acosh(z):
    if not isinstance(z, types.Complex):
        return

    LN_4 = math.log(4)
    THRES = mathimpl.FLT_MAX / 4

    def acosh_impl(z):
        """cmath.acosh(z)"""
        # CPython's algorithm (see c_acosh() in cmathmodule.c)
        if abs(z.real) > THRES or abs(z.imag) > THRES:
            # Avoid unnecessary overflow for large arguments
            # (also handles infinities gracefully)
            real = math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4
            imag = math.atan2(z.imag, z.real)
            return complex(real, imag)
        else:
            s1 = cmath.sqrt(complex(z.real - 1., z.imag))
            s2 = cmath.sqrt(complex(z.real + 1., z.imag))
            real = math.asinh(s1.real * s2.real + s1.imag * s2.imag)
            imag = 2. * math.atan2(s1.imag, s2.real)
            return complex(real, imag)
        # Condensed formula (NumPy)
        #return cmath.log(z + cmath.sqrt(z + 1.) * cmath.sqrt(z - 1.))

    return acosh_impl


# @lower(cmath.asinh, types.Complex)
def asinh_impl(context, builder, sig, args):
    LN_4 = math.log(4)
    THRES = mathimpl.FLT_MAX / 4

    def asinh_impl(z):
        """cmath.asinh(z)"""
        # CPython's algorithm (see c_asinh() in cmathmodule.c)
        if abs(z.real) > THRES or abs(z.imag) > THRES:
            real = math.copysign(
                math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4,
                z.real)
            imag = math.atan2(z.imag, abs(z.real))
            return complex(real, imag)
        else:
            s1 = cmath.sqrt(complex(1. + z.imag, -z.real))
            s2 = cmath.sqrt(complex(1. - z.imag, z.real))
            real = math.asinh(s1.real * s2.imag - s2.real * s1.imag)
            imag = math.atan2(z.imag, s1.real * s2.real - s1.imag * s2.imag)
            return complex(real, imag)

    res = context.compile_internal(builder, asinh_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)

# @lower(cmath.asin, types.Complex)
def asin_impl(context, builder, sig, args):
    def asin_impl(z):
        """cmath.asin(z) = -j * cmath.asinh(z j)"""
        r = cmath.asinh(complex(-z.imag, z.real))
        return complex(r.imag, -r.real)

    res = context.compile_internal(builder, asin_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)

# @lower(cmath.atan, types.Complex)
def atan_impl(context, builder, sig, args):
    def atan_impl(z):
        """cmath.atan(z) = -j * cmath.atanh(z j)"""
        r = cmath.atanh(complex(-z.imag, z.real))
        if math.isinf(z.real) and math.isnan(z.imag):
            # XXX this is odd but necessary
            return complex(r.imag, r.real)
        else:
            return complex(r.imag, -r.real)

    res = context.compile_internal(builder, atan_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)


# @lower(cmath.atanh, types.Complex)
def atanh_impl(context, builder, sig, args):
    LN_4 = math.log(4)
    THRES_LARGE = math.sqrt(mathimpl.FLT_MAX / 4)
    THRES_SMALL = math.sqrt(mathimpl.FLT_MIN)
    PI_12 = math.pi / 2

    def atanh_impl(z):
        """cmath.atanh(z)"""
        # CPython's algorithm (see c_atanh() in cmathmodule.c)
        if z.real < 0.:
            # Reduce to case where z.real >= 0., using atanh(z) = -atanh(-z).
            negate = True
            z = -z
        else:
            negate = False

        ay = abs(z.imag)
        if math.isnan(z.real) or z.real > THRES_LARGE or ay > THRES_LARGE:
            if math.isinf(z.imag):
                real = math.copysign(0., z.real)
            elif math.isinf(z.real):
                real = 0.
            else:
                # may be safe from overflow, depending on hypot's implementation...
                h = math.hypot(z.real * 0.5, z.imag * 0.5)
                real = z.real / 4. / h / h
            imag = -math.copysign(PI_12, -z.imag)
        elif z.real == 1. and ay < THRES_SMALL:
            # C99 standard says: atanh(1+/-0.) should be inf +/- 0j
            if ay == 0.:
                real = INF
                imag = z.imag
            else:
                real = -math.log(math.sqrt(ay) /
                                 math.sqrt(math.hypot(ay, 2.)))
                imag = math.copysign(math.atan2(2., -ay) / 2, z.imag)
        else:
            sqay = ay * ay
            zr1 = 1 - z.real
            real = math.log1p(4. * z.real / (zr1 * zr1 + sqay)) * 0.25
            imag = -math.atan2(-2. * z.imag,
                               zr1 * (1 + z.real) - sqay) * 0.5

        if math.isnan(z.imag):
            imag = NAN
        if negate:
            return complex(-real, -imag)
        else:
            return complex(real, imag)

    res = context.compile_internal(builder, atanh_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)
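

# Editor's note: a sanity check (not upstream code) of the C99 corner case
# handled in atanh_impl: atanh(1 +/- 0j) = inf +/- 0j. The lowered code
# returns the infinity directly; plain CPython signals the pole by raising
# ValueError instead, which is what this demo verifies.
def _demo_atanh_pole():
    try:
        cmath.atanh(complex(1.0, 0.0))
    except ValueError:
        pass  # CPython raises on the pole; the C99 value is inf + 0j
    else:
        raise AssertionError("expected the atanh pole to be signalled")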

linedance-app/venv/lib/python3.12/site-packages/numba/cpython/mathimpl.py (new file, 454 lines)
@@ -0,0 +1,454 @@
"""
Provide math calls that use intrinsics or libc math functions.
"""

import math
import operator
import sys
import numpy as np

import llvmlite.ir
from llvmlite.ir import Constant

from numba.core.imputils import impl_ret_untracked
from numba.core import types, config, cgutils
from numba.core.extending import overload
from numba.core.typing import signature
from numba.cpython.unsafe.numbers import trailing_zeros


# registry = Registry('mathimpl')
# lower = registry.lower


# Helpers, shared with cmathimpl.
_NP_FLT_FINFO = np.finfo(np.dtype('float32'))
FLT_MAX = _NP_FLT_FINFO.max
FLT_MIN = _NP_FLT_FINFO.tiny

_NP_DBL_FINFO = np.finfo(np.dtype('float64'))
DBL_MAX = _NP_DBL_FINFO.max
DBL_MIN = _NP_DBL_FINFO.tiny

FLOAT_ABS_MASK = 0x7fffffff
FLOAT_SIGN_MASK = 0x80000000
DOUBLE_ABS_MASK = 0x7fffffffffffffff
DOUBLE_SIGN_MASK = 0x8000000000000000


def is_nan(builder, val):
    """
    Return a condition testing whether *val* is a NaN.
    """
    return builder.fcmp_unordered('uno', val, val)

def is_inf(builder, val):
    """
    Return a condition testing whether *val* is infinite.
    """
    pos_inf = Constant(val.type, float("+inf"))
    neg_inf = Constant(val.type, float("-inf"))
    isposinf = builder.fcmp_ordered('==', val, pos_inf)
    isneginf = builder.fcmp_ordered('==', val, neg_inf)
    return builder.or_(isposinf, isneginf)

def is_finite(builder, val):
    """
    Return a condition testing whether *val* is finite.
    """
    # is_finite(x)  <=>  x - x != NaN
    val_minus_val = builder.fsub(val, val)
    return builder.fcmp_ordered('ord', val_minus_val, val_minus_val)
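

# Editor's note: a pure-Python illustration (not upstream code) of the
# subtract-from-itself trick used by is_finite() above: x - x is 0.0 for
# every finite x but NaN for NaN and +/-inf, so the 'ord' (not-NaN)
# comparison on x - x is exactly an is-finite test.
def _demo_sub_self_trick():
    for finite in (0.0, -1.5, 1e308):
        assert finite - finite == 0.0
    for non_finite in (float('inf'), float('-inf'), float('nan')):
        assert math.isnan(non_finite - non_finite)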


def f64_as_int64(builder, val):
    """
    Bitcast a double into a 64-bit integer.
    """
    assert val.type == llvmlite.ir.DoubleType()
    return builder.bitcast(val, llvmlite.ir.IntType(64))

def int64_as_f64(builder, val):
    """
    Bitcast a 64-bit integer into a double.
    """
    assert val.type == llvmlite.ir.IntType(64)
    return builder.bitcast(val, llvmlite.ir.DoubleType())

def f32_as_int32(builder, val):
    """
    Bitcast a float into a 32-bit integer.
    """
    assert val.type == llvmlite.ir.FloatType()
    return builder.bitcast(val, llvmlite.ir.IntType(32))

def int32_as_f32(builder, val):
    """
    Bitcast a 32-bit integer into a float.
    """
    assert val.type == llvmlite.ir.IntType(32)
    return builder.bitcast(val, llvmlite.ir.FloatType())


def negate_real(builder, val):
    """
    Negate real number *val*, with proper handling of zeros.
    """
    # The negative zero forces LLVM to handle signed zeros properly.
    return builder.fsub(Constant(val.type, -0.0), val)


def call_fp_intrinsic(builder, name, args):
    """
    Call a LLVM intrinsic floating-point operation.
    """
    mod = builder.module
    intr = mod.declare_intrinsic(name, [a.type for a in args])
    return builder.call(intr, args)


def _unary_int_input_wrapper_impl(wrapped_impl):
    """
    Return an implementation factory to convert the single integral input
    argument to a float64, then defer to the *wrapped_impl*.
    """
    def implementer(context, builder, sig, args):
        val, = args
        input_type = sig.args[0]
        fpval = context.cast(builder, val, input_type, types.float64)
        inner_sig = signature(types.float64, types.float64)
        res = wrapped_impl(context, builder, inner_sig, (fpval,))
        return context.cast(builder, res, types.float64, sig.return_type)

    return implementer

def unary_math_int_impl(fn, float_impl):
    impl = _unary_int_input_wrapper_impl(float_impl)
    # lower(fn, types.Integer)(impl)

def unary_math_intr(fn, intrcode):
    """
    Implement the math function *fn* using the LLVM intrinsic *intrcode*.
    """
    # @lower(fn, types.Float)
    def float_impl(context, builder, sig, args):
        res = call_fp_intrinsic(builder, intrcode, args)
        return impl_ret_untracked(context, builder, sig.return_type, res)

    unary_math_int_impl(fn, float_impl)
    return float_impl

def unary_math_extern(fn, f32extern, f64extern, int_restype=False):
    """
    Register implementations of Python function *fn* using the
    external functions named *f32extern* and *f64extern* (for float32
    and float64 inputs, respectively).
    If *int_restype* is true, then the function's return value should be
    integral, otherwise floating-point.
    """
    f_restype = types.int64 if int_restype else None

    def float_impl(context, builder, sig, args):
        """
        Implement *fn* for a types.Float input.
        """
        [val] = args
        mod = builder.module
        input_type = sig.args[0]
        lty = context.get_value_type(input_type)
        func_name = {
            types.float32: f32extern,
            types.float64: f64extern,
        }[input_type]
        fnty = llvmlite.ir.FunctionType(lty, [lty])
        fn = cgutils.insert_pure_function(builder.module, fnty, name=func_name)
        res = builder.call(fn, (val,))
        res = context.cast(builder, res, input_type, sig.return_type)
        return impl_ret_untracked(context, builder, sig.return_type, res)

    # lower(fn, types.Float)(float_impl)

    # Implement wrapper for integer inputs
    unary_math_int_impl(fn, float_impl)

    return float_impl


unary_math_intr(math.fabs, 'llvm.fabs')
exp_impl = unary_math_intr(math.exp, 'llvm.exp')
if sys.version_info >= (3, 11):
    exp2_impl = unary_math_intr(math.exp2, 'llvm.exp2')
log_impl = unary_math_intr(math.log, 'llvm.log')
log10_impl = unary_math_intr(math.log10, 'llvm.log10')
sin_impl = unary_math_intr(math.sin, 'llvm.sin')
cos_impl = unary_math_intr(math.cos, 'llvm.cos')

log1p_impl = unary_math_extern(math.log1p, "log1pf", "log1p")
expm1_impl = unary_math_extern(math.expm1, "expm1f", "expm1")
erf_impl = unary_math_extern(math.erf, "erff", "erf")
erfc_impl = unary_math_extern(math.erfc, "erfcf", "erfc")

tan_impl = unary_math_extern(math.tan, "tanf", "tan")
asin_impl = unary_math_extern(math.asin, "asinf", "asin")
acos_impl = unary_math_extern(math.acos, "acosf", "acos")
atan_impl = unary_math_extern(math.atan, "atanf", "atan")

asinh_impl = unary_math_extern(math.asinh, "asinhf", "asinh")
acosh_impl = unary_math_extern(math.acosh, "acoshf", "acosh")
atanh_impl = unary_math_extern(math.atanh, "atanhf", "atanh")
sinh_impl = unary_math_extern(math.sinh, "sinhf", "sinh")
cosh_impl = unary_math_extern(math.cosh, "coshf", "cosh")
tanh_impl = unary_math_extern(math.tanh, "tanhf", "tanh")

log2_impl = unary_math_extern(math.log2, "log2f", "log2")
ceil_impl = unary_math_extern(math.ceil, "ceilf", "ceil", True)
floor_impl = unary_math_extern(math.floor, "floorf", "floor", True)

gamma_impl = unary_math_extern(math.gamma, "numba_gammaf", "numba_gamma")  # work-around
sqrt_impl = unary_math_extern(math.sqrt, "sqrtf", "sqrt")
trunc_impl = unary_math_extern(math.trunc, "truncf", "trunc", True)
lgamma_impl = unary_math_extern(math.lgamma, "lgammaf", "lgamma")


# @lower(math.isnan, types.Float)
def isnan_float_impl(context, builder, sig, args):
    [val] = args
    res = is_nan(builder, val)
    return impl_ret_untracked(context, builder, sig.return_type, res)

# @lower(math.isnan, types.Integer)
def isnan_int_impl(context, builder, sig, args):
    res = cgutils.false_bit
    return impl_ret_untracked(context, builder, sig.return_type, res)


# @lower(math.isinf, types.Float)
def isinf_float_impl(context, builder, sig, args):
    [val] = args
    res = is_inf(builder, val)
    return impl_ret_untracked(context, builder, sig.return_type, res)

# @lower(math.isinf, types.Integer)
def isinf_int_impl(context, builder, sig, args):
    res = cgutils.false_bit
    return impl_ret_untracked(context, builder, sig.return_type, res)


# @lower(math.isfinite, types.Float)
def isfinite_float_impl(context, builder, sig, args):
    [val] = args
    res = is_finite(builder, val)
    return impl_ret_untracked(context, builder, sig.return_type, res)


# @lower(math.isfinite, types.Integer)
def isfinite_int_impl(context, builder, sig, args):
    res = cgutils.true_bit
    return impl_ret_untracked(context, builder, sig.return_type, res)


# @lower(math.copysign, types.Float, types.Float)
def copysign_float_impl(context, builder, sig, args):
    lty = args[0].type
    mod = builder.module
    fn = cgutils.get_or_insert_function(mod, llvmlite.ir.FunctionType(lty, (lty, lty)),
                                        'llvm.copysign.%s' % lty.intrinsic_name)
    res = builder.call(fn, args)
    return impl_ret_untracked(context, builder, sig.return_type, res)


# -----------------------------------------------------------------------------


# @lower(math.frexp, types.Float)
def frexp_impl(context, builder, sig, args):
    val, = args
    fltty = context.get_data_type(sig.args[0])
    intty = context.get_data_type(sig.return_type[1])
    expptr = cgutils.alloca_once(builder, intty, name='exp')
    fnty = llvmlite.ir.FunctionType(fltty, (fltty, llvmlite.ir.PointerType(intty)))
    fname = {
        "float": "numba_frexpf",
        "double": "numba_frexp",
    }[str(fltty)]
    fn = cgutils.get_or_insert_function(builder.module, fnty, fname)
    res = builder.call(fn, (val, expptr))
    res = cgutils.make_anonymous_struct(builder, (res, builder.load(expptr)))
    return impl_ret_untracked(context, builder, sig.return_type, res)


# @lower(math.ldexp, types.Float, types.intc)
def ldexp_impl(context, builder, sig, args):
    val, exp = args
    fltty, intty = map(context.get_data_type, sig.args)
    fnty = llvmlite.ir.FunctionType(fltty, (fltty, intty))
    fname = {
        "float": "numba_ldexpf",
        "double": "numba_ldexp",
    }[str(fltty)]
    fn = cgutils.insert_pure_function(builder.module, fnty, name=fname)
    res = builder.call(fn, (val, exp))
    return impl_ret_untracked(context, builder, sig.return_type, res)
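

# Editor's note: an illustration (not upstream code) of the frexp/ldexp pair
# lowered above: frexp() splits a float into a mantissa in [0.5, 1) and an
# exponent, and ldexp() reassembles it exactly, since both only manipulate
# the exponent bits.
def _demo_frexp_ldexp_roundtrip():
    for x in (3.14, -0.5, 1e-300, 6.02e23):
        m, e = math.frexp(x)
        assert 0.5 <= abs(m) < 1.0
        assert math.ldexp(m, e) == x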


# -----------------------------------------------------------------------------


# @lower(math.atan2, types.int64, types.int64)
def atan2_s64_impl(context, builder, sig, args):
    [y, x] = args
    y = builder.sitofp(y, llvmlite.ir.DoubleType())
    x = builder.sitofp(x, llvmlite.ir.DoubleType())
    fsig = signature(types.float64, types.float64, types.float64)
    return atan2_float_impl(context, builder, fsig, (y, x))

# @lower(math.atan2, types.uint64, types.uint64)
def atan2_u64_impl(context, builder, sig, args):
    [y, x] = args
    y = builder.uitofp(y, llvmlite.ir.DoubleType())
    x = builder.uitofp(x, llvmlite.ir.DoubleType())
    fsig = signature(types.float64, types.float64, types.float64)
    return atan2_float_impl(context, builder, fsig, (y, x))

# @lower(math.atan2, types.Float, types.Float)
def atan2_float_impl(context, builder, sig, args):
    assert len(args) == 2
    mod = builder.module
    ty = sig.args[0]
    lty = context.get_value_type(ty)
    func_name = {
        types.float32: "atan2f",
        types.float64: "atan2"
    }[ty]
    fnty = llvmlite.ir.FunctionType(lty, (lty, lty))
    fn = cgutils.insert_pure_function(builder.module, fnty, name=func_name)
    res = builder.call(fn, args)
    return impl_ret_untracked(context, builder, sig.return_type, res)


# -----------------------------------------------------------------------------


# @lower(math.hypot, types.int64, types.int64)
def hypot_s64_impl(context, builder, sig, args):
    [x, y] = args
    y = builder.sitofp(y, llvmlite.ir.DoubleType())
    x = builder.sitofp(x, llvmlite.ir.DoubleType())
    fsig = signature(types.float64, types.float64, types.float64)
    res = hypot_float_impl(context, builder, fsig, (x, y))
    return impl_ret_untracked(context, builder, sig.return_type, res)


# @lower(math.hypot, types.uint64, types.uint64)
def hypot_u64_impl(context, builder, sig, args):
    [x, y] = args
    y = builder.uitofp(y, llvmlite.ir.DoubleType())
    x = builder.uitofp(x, llvmlite.ir.DoubleType())
    fsig = signature(types.float64, types.float64, types.float64)
    res = hypot_float_impl(context, builder, fsig, (x, y))
    return impl_ret_untracked(context, builder, sig.return_type, res)


# @lower(math.hypot, types.Float, types.Float)
def hypot_float_impl(context, builder, sig, args):
    xty, yty = sig.args
    assert xty == yty == sig.return_type
    x, y = args

    # Windows has alternate names for hypot/hypotf, see
    # https://msdn.microsoft.com/fr-fr/library/a9yb3dbt%28v=vs.80%29.aspx
    fname = {
        types.float32: "_hypotf" if sys.platform == 'win32' else "hypotf",
        types.float64: "_hypot" if sys.platform == 'win32' else "hypot",
    }[xty]
    plat_hypot = types.ExternalFunction(fname, sig)

    if sys.platform == 'win32' and config.MACHINE_BITS == 32:
        inf = xty(float('inf'))

        def hypot_impl(x, y):
            if math.isinf(x) or math.isinf(y):
                return inf
            return plat_hypot(x, y)
    else:
        def hypot_impl(x, y):
            return plat_hypot(x, y)

    res = context.compile_internal(builder, hypot_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, res)


# -----------------------------------------------------------------------------

# @lower(math.radians, types.Float)
def radians_float_impl(context, builder, sig, args):
    [x] = args
    coef = context.get_constant(sig.return_type, math.pi / 180)
    res = builder.fmul(x, coef)
    return impl_ret_untracked(context, builder, sig.return_type, res)

unary_math_int_impl(math.radians, radians_float_impl)

# -----------------------------------------------------------------------------

# @lower(math.degrees, types.Float)
def degrees_float_impl(context, builder, sig, args):
    [x] = args
    coef = context.get_constant(sig.return_type, 180 / math.pi)
    res = builder.fmul(x, coef)
    return impl_ret_untracked(context, builder, sig.return_type, res)

unary_math_int_impl(math.degrees, degrees_float_impl)

# -----------------------------------------------------------------------------

# @lower(math.pow, types.Float, types.Float)
# @lower(math.pow, types.Float, types.Integer)
def pow_impl(context, builder, sig, args):
    impl = context.get_function(operator.pow, sig)
    return impl(builder, args)

# -----------------------------------------------------------------------------


def _unsigned(T):
    """Convert integer to unsigned integer of equivalent width."""
    pass

@overload(_unsigned)
def _unsigned_impl(T):
    if T in types.unsigned_domain:
        return lambda T: T
    elif T in types.signed_domain:
        newT = getattr(types, 'uint{}'.format(T.bitwidth))
        return lambda T: newT(T)


def gcd_impl(context, builder, sig, args):
    xty, yty = sig.args
    assert xty == yty == sig.return_type
    x, y = args

    def gcd(a, b):
        """
        Stein's algorithm, heavily cribbed from Julia implementation.
        """
        T = type(a)
        if a == 0: return abs(b)
        if b == 0: return abs(a)
        za = trailing_zeros(a)
        zb = trailing_zeros(b)
        k = min(za, zb)
        # Uses np.*_shift instead of operators due to return types
        u = _unsigned(abs(np.right_shift(a, za)))
        v = _unsigned(abs(np.right_shift(b, zb)))
        while u != v:
            if u > v:
                u, v = v, u
            v -= u
            v = np.right_shift(v, trailing_zeros(v))
        r = np.left_shift(T(u), k)
        return r

    res = context.compile_internal(builder, gcd, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, res)


# lower(math.gcd, types.Integer, types.Integer)(gcd_impl)
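

# Editor's note: a pure-Python rendering (not upstream code) of the binary
# GCD above, without the numba-specific _unsigned/np.*_shift plumbing, to
# make the control flow easy to follow and to test against math.gcd.
def _demo_binary_gcd(a, b):
    if a == 0: return abs(b)
    if b == 0: return abs(a)
    a, b = abs(a), abs(b)
    za = (a & -a).bit_length() - 1   # count of trailing zero bits
    zb = (b & -b).bit_length() - 1
    k = min(za, zb)                  # gcd's common factor of 2**k
    u, v = a >> za, b >> zb          # both odd from here on
    while u != v:
        if u > v:
            u, v = v, u
        v -= u                       # even, since odd - odd
        v >>= (v & -v).bit_length() - 1
    return u << k


assert all(_demo_binary_gcd(a, b) == math.gcd(a, b)
           for a in range(-12, 13) for b in range(-12, 13))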

File diff suppressed because it is too large.
linedance-app/venv/lib/python3.12/site-packages/numba/np/npdatetime.py (new file, 840 lines)
@@ -0,0 +1,840 @@
"""
Implementation of operations on numpy timedelta64 and datetime64.
"""

import numpy as np
import operator

import llvmlite.ir
from llvmlite.ir import Constant

from numba.core import types, cgutils
from numba.core.cgutils import create_constant_array
from numba.core.imputils import (lower_builtin, lower_constant,
                                 impl_ret_untracked, lower_cast)
from numba.np import npdatetime_helpers, numpy_support, npyfuncs
from numba.extending import overload_method
from numba.core.config import IS_32BITS
from numba.core.errors import LoweringError

# datetime64 and timedelta64 use the same internal representation
DATETIME64 = TIMEDELTA64 = llvmlite.ir.IntType(64)
NAT = Constant(TIMEDELTA64, npdatetime_helpers.NAT)

TIMEDELTA_BINOP_SIG = (types.NPTimedelta,) * 2


def scale_by_constant(builder, val, factor):
    """
    Multiply *val* by the constant *factor*.
    """
    return builder.mul(val, Constant(TIMEDELTA64, factor))


def unscale_by_constant(builder, val, factor):
    """
    Divide *val* by the constant *factor*.
    """
    return builder.sdiv(val, Constant(TIMEDELTA64, factor))


def add_constant(builder, val, const):
    """
    Add constant *const* to *val*.
    """
    return builder.add(val, Constant(TIMEDELTA64, const))


def scale_timedelta(context, builder, val, srcty, destty):
    """
    Scale the timedelta64 *val* from *srcty* to *destty*
    (both numba.types.NPTimedelta instances)
    """
    factor = npdatetime_helpers.get_timedelta_conversion_factor(
        srcty.unit, destty.unit)
    if factor is None:
        # This can happen when using explicit output in a ufunc.
        msg = f"cannot convert timedelta64 from {srcty.unit} to {destty.unit}"
        raise LoweringError(msg)
    return scale_by_constant(builder, val, factor)


def normalize_timedeltas(context, builder, left, right, leftty, rightty):
    """
    Scale either *left* or *right* to the other's unit, in order to have
    homogeneous units.
    """
    factor = npdatetime_helpers.get_timedelta_conversion_factor(
        leftty.unit, rightty.unit)
    if factor is not None:
        return scale_by_constant(builder, left, factor), right
    factor = npdatetime_helpers.get_timedelta_conversion_factor(
        rightty.unit, leftty.unit)
    if factor is not None:
        return left, scale_by_constant(builder, right, factor)
    # Typing should not let this happen, except on == and != operators
    raise RuntimeError("cannot normalize %r and %r" % (leftty, rightty))
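

# Editor's note: an illustration (not upstream code) of what scale_timedelta
# compiles down to: converting between timedelta64 units is a single integer
# multiply whenever a constant factor exists (minutes -> seconds is x60).
# NumPy refuses the ambiguous month/year conversions, analogous to the
# LoweringError raised above.
def _demo_unit_scaling():
    m = np.timedelta64(3, 'm')
    assert m.astype('timedelta64[s]') == np.timedelta64(180, 's')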


def alloc_timedelta_result(builder, name='ret'):
    """
    Allocate a NaT-initialized datetime64 (or timedelta64) result slot.
    """
    ret = cgutils.alloca_once(builder, TIMEDELTA64, name=name)
    builder.store(NAT, ret)
    return ret


def alloc_boolean_result(builder, name='ret'):
    """
    Allocate an uninitialized boolean result slot.
    """
    ret = cgutils.alloca_once(builder, llvmlite.ir.IntType(1), name=name)
    return ret


def is_not_nat(builder, val):
    """
    Return a predicate which is true if *val* is not NaT.
    """
    return builder.icmp_unsigned('!=', val, NAT)


def are_not_nat(builder, vals):
    """
    Return a predicate which is true if all of *vals* are not NaT.
    """
    assert len(vals) >= 1
    pred = is_not_nat(builder, vals[0])
    for val in vals[1:]:
        pred = builder.and_(pred, is_not_nat(builder, val))
    return pred


normal_year_months = create_constant_array(
    TIMEDELTA64,
    [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31])
leap_year_months = create_constant_array(
    TIMEDELTA64,
    [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31])
normal_year_months_acc = create_constant_array(
    TIMEDELTA64,
    [0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334])
leap_year_months_acc = create_constant_array(
    TIMEDELTA64,
    [0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335])


@lower_constant(types.NPDatetime)
@lower_constant(types.NPTimedelta)
def datetime_constant(context, builder, ty, pyval):
    return DATETIME64(pyval.astype(np.int64))


# Arithmetic operators on timedelta64

@lower_builtin(operator.pos, types.NPTimedelta)
def timedelta_pos_impl(context, builder, sig, args):
    res = args[0]
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower_builtin(operator.neg, types.NPTimedelta)
def timedelta_neg_impl(context, builder, sig, args):
    res = builder.neg(args[0])
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower_builtin(abs, types.NPTimedelta)
def timedelta_abs_impl(context, builder, sig, args):
    val, = args
    ret = alloc_timedelta_result(builder)
    with builder.if_else(cgutils.is_scalar_neg(builder, val)) as (then, otherwise):
        with then:
            builder.store(builder.neg(val), ret)
        with otherwise:
            builder.store(val, ret)
    res = builder.load(ret)
    return impl_ret_untracked(context, builder, sig.return_type, res)


def timedelta_sign_impl(context, builder, sig, args):
    """
    np.sign(timedelta64)
    """
    val, = args
    ret = alloc_timedelta_result(builder)
    zero = Constant(TIMEDELTA64, 0)
    with builder.if_else(builder.icmp_signed('>', val, zero)
                         ) as (gt_zero, le_zero):
        with gt_zero:
            builder.store(Constant(TIMEDELTA64, 1), ret)
        with le_zero:
            with builder.if_else(builder.icmp_unsigned('==', val, zero)
                                 ) as (eq_zero, lt_zero):
                with eq_zero:
                    builder.store(Constant(TIMEDELTA64, 0), ret)
                with lt_zero:
                    builder.store(Constant(TIMEDELTA64, -1), ret)
    res = builder.load(ret)
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower_builtin(operator.add, *TIMEDELTA_BINOP_SIG)
@lower_builtin(operator.iadd, *TIMEDELTA_BINOP_SIG)
def timedelta_add_impl(context, builder, sig, args):
    [va, vb] = args
    [ta, tb] = sig.args
    ret = alloc_timedelta_result(builder)
    with cgutils.if_likely(builder, are_not_nat(builder, [va, vb])):
        va = scale_timedelta(context, builder, va, ta, sig.return_type)
        vb = scale_timedelta(context, builder, vb, tb, sig.return_type)
        builder.store(builder.add(va, vb), ret)
    res = builder.load(ret)
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower_builtin(operator.sub, *TIMEDELTA_BINOP_SIG)
@lower_builtin(operator.isub, *TIMEDELTA_BINOP_SIG)
def timedelta_sub_impl(context, builder, sig, args):
    [va, vb] = args
    [ta, tb] = sig.args
    ret = alloc_timedelta_result(builder)
    with cgutils.if_likely(builder, are_not_nat(builder, [va, vb])):
        va = scale_timedelta(context, builder, va, ta, sig.return_type)
        vb = scale_timedelta(context, builder, vb, tb, sig.return_type)
        builder.store(builder.sub(va, vb), ret)
    res = builder.load(ret)
    return impl_ret_untracked(context, builder, sig.return_type, res)
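

# Editor's note: an illustration (not upstream code) of the NaT semantics
# that the guarded branches above implement: any arithmetic involving NaT
# yields NaT, which is why result slots are pre-initialised to NAT and only
# overwritten inside the are_not_nat() branch.
def _demo_nat_propagation():
    nat = np.timedelta64('NaT')
    assert np.isnat(nat + np.timedelta64(1, 's'))
    assert np.isnat(np.timedelta64(5, 'ms') - nat)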


def _timedelta_times_number(context, builder, td_arg, td_type,
                            number_arg, number_type, return_type):
    ret = alloc_timedelta_result(builder)
    with cgutils.if_likely(builder, is_not_nat(builder, td_arg)):
        if isinstance(number_type, types.Float):
            val = builder.sitofp(td_arg, number_arg.type)
            val = builder.fmul(val, number_arg)
            val = _cast_to_timedelta(context, builder, val)
        else:
            val = builder.mul(td_arg, number_arg)
        # The scaling is required for ufunc np.multiply() with an explicit
        # output in a different unit.
        val = scale_timedelta(context, builder, val, td_type, return_type)
        builder.store(val, ret)
    return builder.load(ret)


@lower_builtin(operator.mul, types.NPTimedelta, types.Integer)
@lower_builtin(operator.imul, types.NPTimedelta, types.Integer)
@lower_builtin(operator.mul, types.NPTimedelta, types.Float)
@lower_builtin(operator.imul, types.NPTimedelta, types.Float)
def timedelta_times_number(context, builder, sig, args):
    res = _timedelta_times_number(context, builder,
                                  args[0], sig.args[0], args[1], sig.args[1],
                                  sig.return_type)
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower_builtin(operator.mul, types.Integer, types.NPTimedelta)
@lower_builtin(operator.imul, types.Integer, types.NPTimedelta)
@lower_builtin(operator.mul, types.Float, types.NPTimedelta)
@lower_builtin(operator.imul, types.Float, types.NPTimedelta)
def number_times_timedelta(context, builder, sig, args):
    res = _timedelta_times_number(context, builder,
                                  args[1], sig.args[1], args[0], sig.args[0],
                                  sig.return_type)
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower_builtin(operator.truediv, types.NPTimedelta, types.Integer)
@lower_builtin(operator.itruediv, types.NPTimedelta, types.Integer)
@lower_builtin(operator.floordiv, types.NPTimedelta, types.Integer)
@lower_builtin(operator.ifloordiv, types.NPTimedelta, types.Integer)
@lower_builtin(operator.truediv, types.NPTimedelta, types.Float)
@lower_builtin(operator.itruediv, types.NPTimedelta, types.Float)
@lower_builtin(operator.floordiv, types.NPTimedelta, types.Float)
@lower_builtin(operator.ifloordiv, types.NPTimedelta, types.Float)
def timedelta_over_number(context, builder, sig, args):
    td_arg, number_arg = args
    number_type = sig.args[1]
    ret = alloc_timedelta_result(builder)
    ok = builder.and_(is_not_nat(builder, td_arg),
                      builder.not_(cgutils.is_scalar_zero_or_nan(builder, number_arg)))
    with cgutils.if_likely(builder, ok):
        # Denominator is non-zero, non-NaN
        if isinstance(number_type, types.Float):
            val = builder.sitofp(td_arg, number_arg.type)
            val = builder.fdiv(val, number_arg)
            val = _cast_to_timedelta(context, builder, val)
        else:
            val = builder.sdiv(td_arg, number_arg)
        # The scaling is required for ufuncs np.*divide() with an explicit
        # output in a different unit.
        val = scale_timedelta(context, builder, val,
                              sig.args[0], sig.return_type)
        builder.store(val, ret)
    res = builder.load(ret)
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower_builtin(operator.truediv, *TIMEDELTA_BINOP_SIG)
@lower_builtin(operator.itruediv, *TIMEDELTA_BINOP_SIG)
def timedelta_over_timedelta(context, builder, sig, args):
    [va, vb] = args
    [ta, tb] = sig.args
    not_nan = are_not_nat(builder, [va, vb])
    ll_ret_type = context.get_value_type(sig.return_type)
    ret = cgutils.alloca_once(builder, ll_ret_type, name='ret')
    builder.store(Constant(ll_ret_type, float('nan')), ret)
    with cgutils.if_likely(builder, not_nan):
        va, vb = normalize_timedeltas(context, builder, va, vb, ta, tb)
        va = builder.sitofp(va, ll_ret_type)
        vb = builder.sitofp(vb, ll_ret_type)
        builder.store(builder.fdiv(va, vb), ret)
    res = builder.load(ret)
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower_builtin(operator.floordiv, *TIMEDELTA_BINOP_SIG)
def timedelta_floor_div_timedelta(context, builder, sig, args):
    [va, vb] = args
    [ta, tb] = sig.args
    ll_ret_type = context.get_value_type(sig.return_type)
    not_nan = are_not_nat(builder, [va, vb])
    ret = cgutils.alloca_once(builder, ll_ret_type, name='ret')
    zero = Constant(ll_ret_type, 0)
    one = Constant(ll_ret_type, 1)
    builder.store(zero, ret)
    with cgutils.if_likely(builder, not_nan):
        va, vb = normalize_timedeltas(context, builder, va, vb, ta, tb)
        # is the denominator zero or NaT?
        denom_ok = builder.not_(builder.icmp_signed('==', vb, zero))
        with cgutils.if_likely(builder, denom_ok):
            # is either arg negative?
            vaneg = builder.icmp_signed('<', va, zero)
            neg = builder.or_(vaneg, builder.icmp_signed('<', vb, zero))
            with builder.if_else(neg) as (then, otherwise):
                with then:  # one or more values are negative
                    with builder.if_else(vaneg) as (negthen, negotherwise):
                        with negthen:
                            top = builder.sub(va, one)
                            div = builder.sdiv(top, vb)
                            builder.store(div, ret)
                        with negotherwise:
                            top = builder.add(va, one)
                            div = builder.sdiv(top, vb)
                            builder.store(div, ret)
                with otherwise:
                    div = builder.sdiv(va, vb)
                    builder.store(div, ret)
    res = builder.load(ret)
    return impl_ret_untracked(context, builder, sig.return_type, res)
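

# Editor's note: an illustration (not upstream code) of the sign adjustment
# in timedelta_floor_div_timedelta above: LLVM's sdiv truncates toward zero,
# so the +/-1 correction on the numerator turns it into the floor semantics
# that Python and NumPy use.
def _demo_floordiv_negatives():
    assert np.timedelta64(-7, 's') // np.timedelta64(2, 's') == -4  # floor, not -3
    assert (-7) // 2 == -4 and int(-7 / 2) == -3  # floor vs truncation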


def timedelta_mod_timedelta(context, builder, sig, args):
    # inspired by https://github.com/numpy/numpy/blob/fe8072a12d65e43bd2e0b0f9ad67ab0108cc54b3/numpy/core/src/umath/loops.c.src#L1424
    # alg is basically as `a % b`:
    # if a or b is NaT return NaT
    # elseif b is 0 return NaT
    # else pretend a and b are int and do pythonic int modulus

    [va, vb] = args
    [ta, tb] = sig.args
    not_nan = are_not_nat(builder, [va, vb])
    ll_ret_type = context.get_value_type(sig.return_type)
    ret = alloc_timedelta_result(builder)
    builder.store(NAT, ret)
    zero = Constant(ll_ret_type, 0)
    with cgutils.if_likely(builder, not_nan):
        va, vb = normalize_timedeltas(context, builder, va, vb, ta, tb)
        # is the denominator zero or NaT?
        denom_ok = builder.not_(builder.icmp_signed('==', vb, zero))
        with cgutils.if_likely(builder, denom_ok):
            # do the operands have the same sign?
            vapos = builder.icmp_signed('>', va, zero)
            vbpos = builder.icmp_signed('>', vb, zero)
            rem = builder.srem(va, vb)
            # NumPy keeps the raw remainder when the signs match (or the
            # remainder is zero), and otherwise shifts it by one divisor.
            cond = builder.or_(builder.icmp_signed('==', vapos, vbpos),
                               builder.icmp_signed('==', rem, zero))
            with builder.if_else(cond) as (then, otherwise):
                with then:
                    builder.store(rem, ret)
                with otherwise:
                    builder.store(builder.add(rem, vb), ret)

    res = builder.load(ret)
    return impl_ret_untracked(context, builder, sig.return_type, res)
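

# Editor's note: an illustration (not upstream code) of the "pythonic"
# modulus implemented above: when the remainder is nonzero and the operands'
# signs differ, srem's C-style remainder is shifted by one divisor so the
# result takes the divisor's sign, matching Python and NumPy.
def _demo_timedelta_mod():
    assert np.timedelta64(-7, 's') % np.timedelta64(3, 's') == np.timedelta64(2, 's')
    assert (-7) % 3 == 2 and (-7) % (-3) == -1  # same rules for plain ints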


# Comparison operators on timedelta64


def _create_timedelta_comparison_impl(ll_op, default_value):
    def impl(context, builder, sig, args):
        [va, vb] = args
        [ta, tb] = sig.args
        ret = alloc_boolean_result(builder)
        with builder.if_else(are_not_nat(builder, [va, vb])) as (then, otherwise):
            with then:
                try:
                    norm_a, norm_b = normalize_timedeltas(
                        context, builder, va, vb, ta, tb)
                except RuntimeError:
                    # Cannot normalize units => the values are unequal (except if NaT)
                    builder.store(default_value, ret)
                else:
                    builder.store(builder.icmp_unsigned(ll_op, norm_a, norm_b), ret)
            with otherwise:
                # NaT ==/>=/>/</<= NaT is False
                # NaT != <anything, including NaT> is True
                if ll_op == '!=':
                    builder.store(cgutils.true_bit, ret)
                else:
                    builder.store(cgutils.false_bit, ret)
        res = builder.load(ret)
        return impl_ret_untracked(context, builder, sig.return_type, res)

    return impl


def _create_timedelta_ordering_impl(ll_op):
    def impl(context, builder, sig, args):
        [va, vb] = args
        [ta, tb] = sig.args
        ret = alloc_boolean_result(builder)
        with builder.if_else(are_not_nat(builder, [va, vb])) as (then, otherwise):
            with then:
                norm_a, norm_b = normalize_timedeltas(
                    context, builder, va, vb, ta, tb)
                builder.store(builder.icmp_signed(ll_op, norm_a, norm_b), ret)
            with otherwise:
                # NaT >=/>/</<= NaT is False
                builder.store(cgutils.false_bit, ret)
        res = builder.load(ret)
        return impl_ret_untracked(context, builder, sig.return_type, res)

    return impl


timedelta_eq_timedelta_impl = _create_timedelta_comparison_impl(
    '==', cgutils.false_bit)
timedelta_ne_timedelta_impl = _create_timedelta_comparison_impl(
    '!=', cgutils.true_bit)
timedelta_lt_timedelta_impl = _create_timedelta_ordering_impl('<')
timedelta_le_timedelta_impl = _create_timedelta_ordering_impl('<=')
timedelta_gt_timedelta_impl = _create_timedelta_ordering_impl('>')
timedelta_ge_timedelta_impl = _create_timedelta_ordering_impl('>=')

for op_, func in [(operator.eq, timedelta_eq_timedelta_impl),
                  (operator.ne, timedelta_ne_timedelta_impl),
                  (operator.lt, timedelta_lt_timedelta_impl),
                  (operator.le, timedelta_le_timedelta_impl),
                  (operator.gt, timedelta_gt_timedelta_impl),
                  (operator.ge, timedelta_ge_timedelta_impl)]:
    lower_builtin(op_, *TIMEDELTA_BINOP_SIG)(func)


# Arithmetic on datetime64


def is_leap_year(builder, year_val):
    """
    Return a predicate indicating whether *year_val* (offset by 1970) is a
    leap year.
    """
    actual_year = builder.add(year_val, Constant(DATETIME64, 1970))
    multiple_of_4 = cgutils.is_null(
        builder, builder.and_(actual_year, Constant(DATETIME64, 3)))
    not_multiple_of_100 = cgutils.is_not_null(
        builder, builder.srem(actual_year, Constant(DATETIME64, 100)))
    multiple_of_400 = cgutils.is_null(
        builder, builder.srem(actual_year, Constant(DATETIME64, 400)))
    return builder.and_(multiple_of_4,
                        builder.or_(not_multiple_of_100, multiple_of_400))


def year_to_days(builder, year_val):
    """
    Given a year *year_val* (offset to 1970), return the number of days
    since the 1970 epoch.
    """
    # The algorithm below is copied from Numpy's get_datetimestruct_days()
    # (src/multiarray/datetime.c)
    ret = cgutils.alloca_once(builder, TIMEDELTA64)
    # First approximation
    days = scale_by_constant(builder, year_val, 365)
    # Adjust for leap years
    with builder.if_else(cgutils.is_neg_int(builder, year_val)) \
            as (if_neg, if_pos):
        with if_pos:
            # At or after 1970:
            # 1968 is the closest leap year before 1970.
            # Exclude the current year, so add 1.
            from_1968 = add_constant(builder, year_val, 1)
            # Add one day for each 4 years
            p_days = builder.add(days,
                                 unscale_by_constant(builder, from_1968, 4))
            # 1900 is the closest previous year divisible by 100
            from_1900 = add_constant(builder, from_1968, 68)
            # Subtract one day for each 100 years
            p_days = builder.sub(p_days,
                                 unscale_by_constant(builder, from_1900, 100))
            # 1600 is the closest previous year divisible by 400
            from_1600 = add_constant(builder, from_1900, 300)
            # Add one day for each 400 years
            p_days = builder.add(p_days,
                                 unscale_by_constant(builder, from_1600, 400))
            builder.store(p_days, ret)
        with if_neg:
            # Before 1970:
            # NOTE `year_val` is negative, and so will be `from_1972` and `from_2000`.
            # 1972 is the closest leap year after 1970.
            # Include the current year, so subtract 2.
            from_1972 = add_constant(builder, year_val, -2)
            # Subtract one day for each 4 years (`from_1972` is negative)
            n_days = builder.add(days,
                                 unscale_by_constant(builder, from_1972, 4))
            # 2000 is the closest later year divisible by 100
            from_2000 = add_constant(builder, from_1972, -28)
            # Add one day for each 100 years
            n_days = builder.sub(n_days,
                                 unscale_by_constant(builder, from_2000, 100))
            # 2000 is also the closest later year divisible by 400
            # Subtract one day for each 400 years
            n_days = builder.add(n_days,
                                 unscale_by_constant(builder, from_2000, 400))
            builder.store(n_days, ret)
    return builder.load(ret)
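

# Editor's note: a pure-Python check (not upstream code) of the two helpers
# above, against datetime.date's day ordinals. The Gregorian rule is
# (year % 4 == 0) and (year % 100 != 0 or year % 400 == 0), offset here so
# that year 0 means 1970.
def _demo_year_to_days():
    from datetime import date
    for year in (1900, 1972, 2000, 2023):
        leap = year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
        assert (date(year + 1, 1, 1) - date(year, 1, 1)).days == 365 + leap
    # 1970 and 1971 are both normal years: day ordinal of 1972-01-01 is 730.
    assert (date(1972, 1, 1) - date(1970, 1, 1)).days == 730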
def reduce_datetime_for_unit(builder, dt_val, src_unit, dest_unit):
|
||||
dest_unit_code = npdatetime_helpers.DATETIME_UNITS[dest_unit]
|
||||
src_unit_code = npdatetime_helpers.DATETIME_UNITS[src_unit]
|
||||
if dest_unit_code < 2 or src_unit_code >= 2:
|
||||
return dt_val, src_unit
|
||||
# Need to compute the day ordinal for *dt_val*
|
||||
if src_unit_code == 0:
|
||||
# Years to days
|
||||
year_val = dt_val
|
||||
days_val = year_to_days(builder, year_val)
|
||||
|
||||
else:
|
||||
# Months to days
|
||||
leap_array = cgutils.global_constant(builder, "leap_year_months_acc",
|
||||
leap_year_months_acc)
|
||||
normal_array = cgutils.global_constant(builder, "normal_year_months_acc",
|
||||
normal_year_months_acc)
|
||||
|
||||
days = cgutils.alloca_once(builder, TIMEDELTA64)
|
||||
|
||||
# First compute year number and month number
|
||||
year, month = cgutils.divmod_by_constant(builder, dt_val, 12)
|
||||
|
||||
# Then deduce the number of days
|
||||
with builder.if_else(is_leap_year(builder, year)) as (then, otherwise):
|
||||
with then:
|
||||
addend = builder.load(cgutils.gep(builder, leap_array,
|
||||
0, month, inbounds=True))
|
||||
builder.store(addend, days)
|
||||
with otherwise:
|
||||
addend = builder.load(cgutils.gep(builder, normal_array,
|
||||
0, month, inbounds=True))
|
||||
builder.store(addend, days)
|
||||
|
||||
days_val = year_to_days(builder, year)
|
||||
days_val = builder.add(days_val, builder.load(days))
|
||||
|
||||
if dest_unit_code == 2:
|
||||
# Need to scale back to weeks
|
||||
weeks, _ = cgutils.divmod_by_constant(builder, days_val, 7)
|
||||
return weeks, 'W'
|
||||
else:
|
||||
return days_val, 'D'
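
# Illustrative example (not part of the module): for a datetime64 value of
# 14 months ('M'), divmod(14, 12) gives year == 1 and month == 2, so the day
# ordinal is year_to_days(1) plus the accumulated days of the first two months
# of 1971 (a normal year): 365 + 59 == 424 days since the epoch (1971-03-01).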


def convert_datetime_for_arith(builder, dt_val, src_unit, dest_unit):
    """
    Convert datetime *dt_val* from *src_unit* to *dest_unit*.
    """
    # First partial conversion to days or weeks, if necessary.
    dt_val, dt_unit = reduce_datetime_for_unit(
        builder, dt_val, src_unit, dest_unit)
    # Then multiply by the remaining constant factor.
    dt_factor = npdatetime_helpers.get_timedelta_conversion_factor(
        dt_unit, dest_unit)
    if dt_factor is None:
        # This can happen when using explicit output in a ufunc.
        raise LoweringError("cannot convert datetime64 from %r to %r"
                            % (src_unit, dest_unit))
    return scale_by_constant(builder, dt_val, dt_factor)


def _datetime_timedelta_arith(ll_op_name):
    def impl(context, builder, dt_arg, dt_unit,
             td_arg, td_unit, ret_unit):
        ret = alloc_timedelta_result(builder)
        with cgutils.if_likely(builder,
                               are_not_nat(builder, [dt_arg, td_arg])):
            dt_arg = convert_datetime_for_arith(builder, dt_arg,
                                                dt_unit, ret_unit)
            td_factor = npdatetime_helpers.get_timedelta_conversion_factor(
                td_unit, ret_unit)
            td_arg = scale_by_constant(builder, td_arg, td_factor)
            ret_val = getattr(builder, ll_op_name)(dt_arg, td_arg)
            builder.store(ret_val, ret)
        return builder.load(ret)
    return impl


_datetime_plus_timedelta = _datetime_timedelta_arith('add')
_datetime_minus_timedelta = _datetime_timedelta_arith('sub')

# datetime64 + timedelta64


@lower_builtin(operator.add, types.NPDatetime, types.NPTimedelta)
@lower_builtin(operator.iadd, types.NPDatetime, types.NPTimedelta)
def datetime_plus_timedelta(context, builder, sig, args):
    dt_arg, td_arg = args
    dt_type, td_type = sig.args
    res = _datetime_plus_timedelta(context, builder,
                                   dt_arg, dt_type.unit,
                                   td_arg, td_type.unit,
                                   sig.return_type.unit)
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower_builtin(operator.add, types.NPTimedelta, types.NPDatetime)
@lower_builtin(operator.iadd, types.NPTimedelta, types.NPDatetime)
def timedelta_plus_datetime(context, builder, sig, args):
    td_arg, dt_arg = args
    td_type, dt_type = sig.args
    res = _datetime_plus_timedelta(context, builder,
                                   dt_arg, dt_type.unit,
                                   td_arg, td_type.unit,
                                   sig.return_type.unit)
    return impl_ret_untracked(context, builder, sig.return_type, res)

# datetime64 - timedelta64


@lower_builtin(operator.sub, types.NPDatetime, types.NPTimedelta)
@lower_builtin(operator.isub, types.NPDatetime, types.NPTimedelta)
def datetime_minus_timedelta(context, builder, sig, args):
    dt_arg, td_arg = args
    dt_type, td_type = sig.args
    res = _datetime_minus_timedelta(context, builder,
                                    dt_arg, dt_type.unit,
                                    td_arg, td_type.unit,
                                    sig.return_type.unit)
    return impl_ret_untracked(context, builder, sig.return_type, res)

# datetime64 - datetime64


@lower_builtin(operator.sub, types.NPDatetime, types.NPDatetime)
def datetime_minus_datetime(context, builder, sig, args):
    va, vb = args
    ta, tb = sig.args
    unit_a = ta.unit
    unit_b = tb.unit
    ret_unit = sig.return_type.unit
    ret = alloc_timedelta_result(builder)
    with cgutils.if_likely(builder, are_not_nat(builder, [va, vb])):
        va = convert_datetime_for_arith(builder, va, unit_a, ret_unit)
        vb = convert_datetime_for_arith(builder, vb, unit_b, ret_unit)
        ret_val = builder.sub(va, vb)
        builder.store(ret_val, ret)
    res = builder.load(ret)
    return impl_ret_untracked(context, builder, sig.return_type, res)
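
# Illustrative semantics (not part of the module), mirroring NumPy:
#   np.datetime64('2020-01-02') - np.datetime64('2020-01-01')
# yields np.timedelta64(1, 'D'); if either operand is NaT, the if_likely
# block is skipped and the result slot keeps its preallocated NaT value.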

# datetime64 comparisons


def _create_datetime_comparison_impl(ll_op):
    def impl(context, builder, sig, args):
        va, vb = args
        ta, tb = sig.args
        unit_a = ta.unit
        unit_b = tb.unit
        ret_unit = npdatetime_helpers.get_best_unit(unit_a, unit_b)
        ret = alloc_boolean_result(builder)
        with builder.if_else(are_not_nat(builder, [va, vb])) \
                as (then, otherwise):
            with then:
                norm_a = convert_datetime_for_arith(
                    builder, va, unit_a, ret_unit)
                norm_b = convert_datetime_for_arith(
                    builder, vb, unit_b, ret_unit)
                ret_val = builder.icmp_signed(ll_op, norm_a, norm_b)
                builder.store(ret_val, ret)
            with otherwise:
                # Comparisons with NaT are False, except for '!='
                if ll_op == '!=':
                    ret_val = cgutils.true_bit
                else:
                    ret_val = cgutils.false_bit
                builder.store(ret_val, ret)
        res = builder.load(ret)
        return impl_ret_untracked(context, builder, sig.return_type, res)

    return impl


datetime_eq_datetime_impl = _create_datetime_comparison_impl('==')
datetime_ne_datetime_impl = _create_datetime_comparison_impl('!=')
datetime_lt_datetime_impl = _create_datetime_comparison_impl('<')
datetime_le_datetime_impl = _create_datetime_comparison_impl('<=')
datetime_gt_datetime_impl = _create_datetime_comparison_impl('>')
datetime_ge_datetime_impl = _create_datetime_comparison_impl('>=')

for op, func in [(operator.eq, datetime_eq_datetime_impl),
                 (operator.ne, datetime_ne_datetime_impl),
                 (operator.lt, datetime_lt_datetime_impl),
                 (operator.le, datetime_le_datetime_impl),
                 (operator.gt, datetime_gt_datetime_impl),
                 (operator.ge, datetime_ge_datetime_impl)]:
    lower_builtin(op, *[types.NPDatetime]*2)(func)
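
# Illustrative semantics (not part of the module): with a NaT operand every
# comparison lowers to False except '!=', which lowers to True -- matching
# NumPy, where np.datetime64('NaT') != np.datetime64('NaT') is True.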


########################################################################
# datetime/timedelta fmax/fmin maximum/minimum support

def _gen_datetime_max_impl(NAT_DOMINATES):
    def datetime_max_impl(context, builder, sig, args):
        # note this could be optimized by relying on the actual value of NAT,
        # but as NumPy doesn't rely on this, this seems more resilient
        in1, in2 = args
        in1_not_nat = is_not_nat(builder, in1)
        in2_not_nat = is_not_nat(builder, in2)
        in1_ge_in2 = builder.icmp_signed('>=', in1, in2)
        res = builder.select(in1_ge_in2, in1, in2)
        if NAT_DOMINATES:
            # NaT now dominates, like NaN
            in1, in2 = in2, in1
        res = builder.select(in1_not_nat, res, in2)
        res = builder.select(in2_not_nat, res, in1)

        return impl_ret_untracked(context, builder, sig.return_type, res)
    return datetime_max_impl


datetime_maximum_impl = _gen_datetime_max_impl(True)
datetime_fmax_impl = _gen_datetime_max_impl(False)
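
# Illustrative semantics (not part of the module): with NAT_DOMINATES=True
# (the maximum/minimum flavour) a NaT input wins, mirroring NaN; with
# NAT_DOMINATES=False (the fmax/fmin flavour) the non-NaT input wins:
#   np.maximum(np.datetime64('NaT'), d) -> NaT
#   np.fmax(np.datetime64('NaT'), d)    -> d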


def _gen_datetime_min_impl(NAT_DOMINATES):
    def datetime_min_impl(context, builder, sig, args):
        # note this could be optimized by relying on the actual value of NAT,
        # but as NumPy doesn't rely on this, this seems more resilient
        in1, in2 = args
        in1_not_nat = is_not_nat(builder, in1)
        in2_not_nat = is_not_nat(builder, in2)
        in1_le_in2 = builder.icmp_signed('<=', in1, in2)
        res = builder.select(in1_le_in2, in1, in2)
        if NAT_DOMINATES:
            # NaT now dominates, like NaN
            in1, in2 = in2, in1
        res = builder.select(in1_not_nat, res, in2)
        res = builder.select(in2_not_nat, res, in1)

        return impl_ret_untracked(context, builder, sig.return_type, res)
    return datetime_min_impl


datetime_minimum_impl = _gen_datetime_min_impl(True)
datetime_fmin_impl = _gen_datetime_min_impl(False)


def _gen_timedelta_max_impl(NAT_DOMINATES):
    def timedelta_max_impl(context, builder, sig, args):
        # note this could be optimized by relying on the actual value of NAT,
        # but as NumPy doesn't rely on this, this seems more resilient
        in1, in2 = args
        in1_not_nat = is_not_nat(builder, in1)
        in2_not_nat = is_not_nat(builder, in2)
        in1_ge_in2 = builder.icmp_signed('>=', in1, in2)
        res = builder.select(in1_ge_in2, in1, in2)
        if NAT_DOMINATES:
            # NaT now dominates, like NaN
            in1, in2 = in2, in1
        res = builder.select(in1_not_nat, res, in2)
        res = builder.select(in2_not_nat, res, in1)

        return impl_ret_untracked(context, builder, sig.return_type, res)
    return timedelta_max_impl


timedelta_maximum_impl = _gen_timedelta_max_impl(True)
timedelta_fmax_impl = _gen_timedelta_max_impl(False)


def _gen_timedelta_min_impl(NAT_DOMINATES):
    def timedelta_min_impl(context, builder, sig, args):
        # note this could be optimized by relying on the actual value of NAT,
        # but as NumPy doesn't rely on this, this seems more resilient
        in1, in2 = args
        in1_not_nat = is_not_nat(builder, in1)
        in2_not_nat = is_not_nat(builder, in2)
        in1_le_in2 = builder.icmp_signed('<=', in1, in2)
        res = builder.select(in1_le_in2, in1, in2)
        if NAT_DOMINATES:
            # NaT now dominates, like NaN
            in1, in2 = in2, in1
        res = builder.select(in1_not_nat, res, in2)
        res = builder.select(in2_not_nat, res, in1)

        return impl_ret_untracked(context, builder, sig.return_type, res)
    return timedelta_min_impl


timedelta_minimum_impl = _gen_timedelta_min_impl(True)
timedelta_fmin_impl = _gen_timedelta_min_impl(False)


def _cast_to_timedelta(context, builder, val):
    temp = builder.alloca(TIMEDELTA64)
    val_is_nan = builder.fcmp_unordered('uno', val, val)
    with builder.if_else(val_is_nan) as (then, els):
        with then:
            # Casting NaN to an integer is not guaranteed to produce NaT,
            # so store NaT explicitly.
            builder.store(NAT, temp)
        with els:
            builder.store(builder.fptosi(val, TIMEDELTA64), temp)
    return builder.load(temp)
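
# Illustrative note (not part of the module): NaT is INT64_MIN, i.e.
#   np.timedelta64('NaT').astype(np.int64) == -2**63
# whereas LLVM's fptosi on a NaN input yields an unspecified value,
# hence the explicit store above.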


@lower_builtin(np.isnat, types.NPDatetime)
@lower_builtin(np.isnat, types.NPTimedelta)
def _np_isnat_impl(context, builder, sig, args):
    return npyfuncs.np_datetime_isnat_impl(context, builder, sig, args)


@lower_cast(types.NPDatetime, types.Integer)
@lower_cast(types.NPTimedelta, types.Integer)
def _cast_npdatetime_int64(context, builder, fromty, toty, val):
    if toty.bitwidth != 64:  # all datetime types are 64-bit
        msg = f"Cannot cast {fromty} to {toty} as {toty} is not 64 bits wide."
        raise ValueError(msg)
    return val


@overload_method(types.NPTimedelta, '__hash__')
@overload_method(types.NPDatetime, '__hash__')
def ol_hash_npdatetime(x):
    if numpy_support.numpy_version >= (2, 2) \
            and isinstance(x, types.NPTimedelta) and not x.unit:
        raise ValueError("Can't hash generic timedelta64")

    if IS_32BITS:
        def impl(x):
            x = np.int64(x)
            if x < 2**31 - 1:  # x < LONG_MAX
                y = np.int32(x)
            else:
                # Fold the 64-bit value into 32 bits, CPython-style
                hi = (np.int64(x) & 0xffffffff00000000) >> 32
                lo = (np.int64(x) & 0x00000000ffffffff)
                y = np.int32(lo + (1000003) * hi)
            if y == -1:
                # -1 is reserved as the error marker for CPython hashes
                y = np.int32(-2)
            return y
    else:
        def impl(x):
            if np.int64(x) == -1:
                return np.int64(-2)
            return np.int64(x)
    return impl
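
# Illustrative check (not part of the module): on a 32-bit target,
# x == 2**32 + 5 gives hi == 1 and lo == 5, so y == np.int32(5 + 1000003).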


lower_builtin(npdatetime_helpers.datetime_minimum,
              types.NPDatetime, types.NPDatetime)(datetime_minimum_impl)
lower_builtin(npdatetime_helpers.datetime_minimum,
              types.NPTimedelta, types.NPTimedelta)(datetime_minimum_impl)
lower_builtin(npdatetime_helpers.datetime_maximum,
              types.NPDatetime, types.NPDatetime)(datetime_maximum_impl)
lower_builtin(npdatetime_helpers.datetime_maximum,
              types.NPTimedelta, types.NPTimedelta)(datetime_maximum_impl)
@@ -0,0 +1,212 @@
"""
Helper functions for np.timedelta64 and np.datetime64.
For now, multiples-of-units (for example timedeltas expressed in tens
of seconds) are not supported.
"""


import numpy as np


DATETIME_UNITS = {
    'Y': 0,   # Years
    'M': 1,   # Months
    'W': 2,   # Weeks
    # Yes, there's a gap here: 3 was NumPy's removed business-day unit
    'D': 4,   # Days
    'h': 5,   # Hours
    'm': 6,   # Minutes
    's': 7,   # Seconds
    'ms': 8,  # Milliseconds
    'us': 9,  # Microseconds
    'ns': 10,  # Nanoseconds
    'ps': 11,  # Picoseconds
    'fs': 12,  # Femtoseconds
    'as': 13,  # Attoseconds
    '': 14,   # "generic", i.e. unit-less
}
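
# Illustrative usage (not part of the module): coarser units have lower codes,
# e.g. DATETIME_UNITS['h'] < DATETIME_UNITS['s'], and the "generic" unit ''
# (code 14) sorts after every concrete unit.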


NAT = np.timedelta64('nat').astype(np.int64)

# NOTE: numpy has several inconsistent functions for timedelta casting:
# - can_cast_timedelta64_{metadata,units}() disallows "safe" casting
#   to and from generic units
# - cast_timedelta_to_timedelta() allows casting from (but not to)
#   generic units
# - compute_datetime_metadata_greatest_common_divisor() allows casting from
#   generic units (used for promotion)


def same_kind(src, dest):
    """
    Whether the *src* and *dest* units are of the same kind.
    """
    return (DATETIME_UNITS[src] < 5) == (DATETIME_UNITS[dest] < 5)


def can_cast_timedelta_units(src, dest):
    # Mimic NumPy's "safe" casting and promotion:
    # `dest` must be more precise than `src` and they must be compatible
    # for conversion.
    # XXX should we switch to enforcing "same-kind" for Numpy 1.10+ ?
    src = DATETIME_UNITS[src]
    dest = DATETIME_UNITS[dest]
    if src == dest:
        return True
    if src == 14:
        return True
    if src > dest:
        return False
    if dest == 14:
        # unit-less timedelta64 is not compatible with anything else
        return False
    if src <= 1 and dest > 1:
        # Cannot convert between months or years and other units
        return False
    return True
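
# Illustrative behaviour (not part of the module):
#   can_cast_timedelta_units('h', 's')  -> True   (towards a finer unit)
#   can_cast_timedelta_units('s', 'h')  -> False  (precision would be lost)
#   can_cast_timedelta_units('M', 'D')  -> False  (months have no fixed length)
#   can_cast_timedelta_units('', 'ns')  -> True   (generic casts to anything)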


# Exact conversion factors from one unit to the immediately more precise one
_factors = {
    0: (1, 12),    # Years -> Months
    2: (4, 7),     # Weeks -> Days
    4: (5, 24),    # Days -> Hours
    5: (6, 60),    # Hours -> Minutes
    6: (7, 60),    # Minutes -> Seconds
    7: (8, 1000),  # Seconds -> Milliseconds
    8: (9, 1000),  # Milliseconds -> Microseconds
    9: (10, 1000),   # Microseconds -> Nanoseconds
    10: (11, 1000),  # Nanoseconds -> Picoseconds
    11: (12, 1000),  # Picoseconds -> Femtoseconds
    12: (13, 1000),  # Femtoseconds -> Attoseconds
}


def _get_conversion_multiplier(big_unit_code, small_unit_code):
    """
    Return an integer multiplier for converting from *big_unit_code*
    to *small_unit_code*.
    None is returned if the conversion is not possible through a
    simple integer multiplication.
    """
    # Mimics get_datetime_units_factor() in NumPy's datetime.c,
    # with a twist to allow no-op conversion from generic units.
    if big_unit_code == 14:
        return 1
    c = big_unit_code
    factor = 1
    while c < small_unit_code:
        try:
            c, mult = _factors[c]
        except KeyError:
            # No possible conversion
            return None
        factor *= mult
    if c == small_unit_code:
        return factor
    else:
        return None
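
# Illustrative behaviour (not part of the module):
#   _get_conversion_multiplier(5, 7)  -> 3600  (hours to seconds: 60 * 60)
#   _get_conversion_multiplier(0, 4)  -> None  (years to days is inexact)
#   _get_conversion_multiplier(14, 9) -> 1     (generic converts as a no-op)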


def get_timedelta_conversion_factor(src_unit, dest_unit):
    """
    Return an integer multiplier for converting from timedeltas
    of *src_unit* to *dest_unit*.
    """
    return _get_conversion_multiplier(DATETIME_UNITS[src_unit],
                                      DATETIME_UNITS[dest_unit])


def get_datetime_timedelta_conversion(datetime_unit, timedelta_unit):
    """
    Compute a possible conversion for combining *datetime_unit* and
    *timedelta_unit* (presumably for adding or subtracting).
    Return (result unit, integer datetime multiplier, integer timedelta
    multiplier). RuntimeError is raised if the combination is impossible.
    """
    # XXX now unused (I don't know where / how Numpy uses this)
    dt_unit_code = DATETIME_UNITS[datetime_unit]
    td_unit_code = DATETIME_UNITS[timedelta_unit]
    if td_unit_code == 14 or dt_unit_code == 14:
        return datetime_unit, 1, 1
    if td_unit_code < 2 and dt_unit_code >= 2:
        # Cannot combine Y or M timedelta64 with a finer-grained datetime64
        raise RuntimeError("cannot combine datetime64(%r) and timedelta64(%r)"
                           % (datetime_unit, timedelta_unit))
    dt_factor, td_factor = 1, 1

    # If years or months, the datetime unit is first scaled to weeks or days,
    # then conversion continues below. This is the same algorithm as used
    # in Numpy's get_datetime_conversion_factor() (src/multiarray/datetime.c):
    # """Conversions between years/months and other units use
    # the factor averaged over the 400 year leap year cycle."""
    if dt_unit_code == 0:
        if td_unit_code >= 4:
            dt_factor = 97 + 400 * 365
            td_factor = 400
            dt_unit_code = 4
        elif td_unit_code == 2:
            dt_factor = 97 + 400 * 365
            td_factor = 400 * 7
            dt_unit_code = 2
    elif dt_unit_code == 1:
        if td_unit_code >= 4:
            dt_factor = 97 + 400 * 365
            td_factor = 400 * 12
            dt_unit_code = 4
        elif td_unit_code == 2:
            dt_factor = 97 + 400 * 365
            td_factor = 400 * 12 * 7
            dt_unit_code = 2

    if td_unit_code >= dt_unit_code:
        factor = _get_conversion_multiplier(dt_unit_code, td_unit_code)
        assert factor is not None, (dt_unit_code, td_unit_code)
        return timedelta_unit, dt_factor * factor, td_factor
    else:
        factor = _get_conversion_multiplier(td_unit_code, dt_unit_code)
        assert factor is not None, (dt_unit_code, td_unit_code)
        return datetime_unit, dt_factor, td_factor * factor
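
# Illustrative example (not part of the module): combining datetime64[Y] with
# timedelta64[D] gives
#   get_datetime_timedelta_conversion('Y', 'D') == ('D', 146097, 400)
# since 146097 / 400 == 365.2425, the average Gregorian year length over the
# 400-year leap cycle.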


def combine_datetime_timedelta_units(datetime_unit, timedelta_unit):
    """
    Return the unit result of combining *datetime_unit* with *timedelta_unit*
    (e.g. by adding or subtracting). None is returned if combining
    those units is forbidden.
    """
    dt_unit_code = DATETIME_UNITS[datetime_unit]
    td_unit_code = DATETIME_UNITS[timedelta_unit]
    if dt_unit_code == 14:
        return timedelta_unit
    elif td_unit_code == 14:
        return datetime_unit
    if td_unit_code < 2 and dt_unit_code >= 2:
        return None
    if dt_unit_code > td_unit_code:
        return datetime_unit
    else:
        return timedelta_unit


def get_best_unit(unit_a, unit_b):
    """
    Get the best (i.e. finer-grained) of two units.
    """
    a = DATETIME_UNITS[unit_a]
    b = DATETIME_UNITS[unit_b]
    if a == 14:
        return unit_b
    if b == 14:
        return unit_a
    if b > a:
        return unit_b
    return unit_a


# Stub functions: used as lowering keys, registered to the real
# implementations elsewhere (see the lower_builtin calls above).
def datetime_minimum(a, b):
    pass


def datetime_maximum(a, b):
    pass
1704
linedance-app/venv/lib/python3.12/site-packages/numba/np/npyfuncs.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,876 @@
"""
Implementation of functions in the Numpy package.
"""


import math
import sys
import itertools
from collections import namedtuple

import llvmlite.ir as ir

import numpy as np
import operator

from numba.np import arrayobj, ufunc_db, numpy_support
from numba.np.ufunc.sigparse import parse_signature
from numba.core.imputils import (Registry, impl_ret_new_ref,
                                 force_error_model, impl_ret_borrowed)
from numba.core import typing, types, utils, cgutils, callconv, config
from numba.np.numpy_support import (
    ufunc_find_matching_loop, select_array_wrapper, from_dtype, _ufunc_loop_sig
)
from numba.np.arrayobj import _getitem_array_generic
from numba.core.typing import npydecl
from numba.core.extending import overload, intrinsic

from numba.core import errors

registry = Registry('npyimpl')


########################################################################

# In the way we generate code, ufuncs work with scalar as well as
# with array arguments. The following helper classes help dealing
# with scalar and array arguments in a regular way.
#
# In short, the classes provide a uniform interface. The interface
# handles the indexing of as many dimensions as the array may have.
# For scalars, all indexing is ignored and when the value is read,
# the scalar is returned. For arrays, code for actual indexing is
# generated and reading performs the appropriate indirection.


class _ScalarIndexingHelper(object):
    def update_indices(self, loop_indices, name):
        pass

    def as_values(self):
        pass


class _ScalarHelper(object):
    """Helper class to handle scalar arguments (and result).
    Note that store_data is only used when generating code for
    a scalar ufunc, to write the output value.

    For loading, the value is used directly, without any
    kind of indexing nor memory backing it up. This is the use
    for input arguments.

    For storing, a variable is created on the stack where the
    value will be written.

    Note that reading back a stored value is not supported (as it
    is unneeded for our current use-cases). This class
    will always "load" the original value it got at its creation.
    """
    def __init__(self, ctxt, bld, val, ty):
        self.context = ctxt
        self.builder = bld
        self.val = val
        self.base_type = ty
        intpty = ctxt.get_value_type(types.intp)
        self.shape = [ir.Constant(intpty, 1)]

        lty = ctxt.get_data_type(ty) if ty != types.boolean else ir.IntType(1)
        self._ptr = cgutils.alloca_once(bld, lty)

    def create_iter_indices(self):
        return _ScalarIndexingHelper()

    def load_data(self, indices):
        return self.val

    def store_data(self, indices, val):
        self.builder.store(val, self._ptr)

    @property
    def return_val(self):
        return self.builder.load(self._ptr)


class _ArrayIndexingHelper(namedtuple('_ArrayIndexingHelper',
                                      ('array', 'indices'))):
    def update_indices(self, loop_indices, name):
        bld = self.array.builder
        intpty = self.array.context.get_value_type(types.intp)
        ONE = ir.Constant(ir.IntType(intpty.width), 1)

        # we are only interested in as many inner dimensions as dimensions
        # the indexed array has (the outer dimensions are broadcast, so
        # ignoring the outer indices produces the desired result).
        indices = loop_indices[len(loop_indices) - len(self.indices):]
        for src, dst, dim in zip(indices, self.indices, self.array.shape):
            cond = bld.icmp_unsigned('>', dim, ONE)
            with bld.if_then(cond):
                bld.store(src, dst)

    def as_values(self):
        """
        The indexing helper is built using alloca for each value, so it
        actually contains pointers to the actual indices to load. Note
        that update_indices assumes the same. This method returns the
        indices as values.
        """
        bld = self.array.builder
        return [bld.load(index) for index in self.indices]


class _ArrayHelper(namedtuple('_ArrayHelper', ('context', 'builder',
                                               'shape', 'strides', 'data',
                                               'layout', 'base_type', 'ndim',
                                               'return_val'))):
    """Helper class to handle array arguments/result.
    It provides methods to generate code loading/storing specific
    items as well as support code for handling indices.
    """
    def create_iter_indices(self):
        intpty = self.context.get_value_type(types.intp)
        ZERO = ir.Constant(ir.IntType(intpty.width), 0)

        indices = []
        for i in range(self.ndim):
            x = cgutils.alloca_once(self.builder, ir.IntType(intpty.width))
            self.builder.store(ZERO, x)
            indices.append(x)
        return _ArrayIndexingHelper(self, indices)

    def _load_effective_address(self, indices):
        return cgutils.get_item_pointer2(self.context,
                                         self.builder,
                                         data=self.data,
                                         shape=self.shape,
                                         strides=self.strides,
                                         layout=self.layout,
                                         inds=indices)

    def load_data(self, indices):
        model = self.context.data_model_manager[self.base_type]
        ptr = self._load_effective_address(indices)
        return model.load_from_data_pointer(self.builder, ptr)

    def store_data(self, indices, value):
        ctx = self.context
        bld = self.builder
        store_value = ctx.get_value_as_data(bld, self.base_type, value)
        assert ctx.get_data_type(self.base_type) == store_value.type
        bld.store(store_value, self._load_effective_address(indices))


class _ArrayGUHelper(namedtuple('_ArrayGUHelper', ('context', 'builder',
                                                   'shape', 'strides', 'data',
                                                   'layout', 'base_type',
                                                   'ndim', 'inner_arr_ty',
                                                   'is_input_arg'))):
    """Helper class to handle array arguments/result.
    It provides methods to generate code loading/storing specific
    items as well as support code for handling indices.

    Contrary to _ArrayHelper, this class can create a view to a subarray.
    """
    def create_iter_indices(self):
        intpty = self.context.get_value_type(types.intp)
        ZERO = ir.Constant(ir.IntType(intpty.width), 0)

        indices = []
        for i in range(self.ndim - self.inner_arr_ty.ndim):
            x = cgutils.alloca_once(self.builder, ir.IntType(intpty.width))
            self.builder.store(ZERO, x)
            indices.append(x)
        return _ArrayIndexingHelper(self, indices)

    def _load_effective_address(self, indices):
        context = self.context
        builder = self.builder
        arr_ty = types.Array(self.base_type, self.ndim, self.layout)
        arr = context.make_array(arr_ty)(context, builder, self.data)

        return cgutils.get_item_pointer2(context,
                                         builder,
                                         data=arr.data,
                                         shape=self.shape,
                                         strides=self.strides,
                                         layout=self.layout,
                                         inds=indices)

    def load_data(self, indices):
        context, builder = self.context, self.builder

        if self.inner_arr_ty.ndim == 0 and self.is_input_arg:
            # scalar case for input arguments
            model = context.data_model_manager[self.base_type]
            ptr = self._load_effective_address(indices)
            return model.load_from_data_pointer(builder, ptr)
        elif self.inner_arr_ty.ndim == 0 and not self.is_input_arg:
            # Output arrays are handled as 1d with shape=(1,) when their
            # signature represents a scalar. For instance: "(n),(m) -> ()"
            intpty = context.get_value_type(types.intp)
            one = intpty(1)

            fromty = types.Array(self.base_type, self.ndim, self.layout)
            toty = types.Array(self.base_type, 1, self.layout)
            itemsize = intpty(arrayobj.get_itemsize(context, fromty))

            # create a view from the original ndarray to a 1d array
            arr_from = self.context.make_array(fromty)(context,
                                                       builder,
                                                       self.data)
            arr_to = self.context.make_array(toty)(context, builder)
            arrayobj.populate_array(
                arr_to,
                data=self._load_effective_address(indices),
                shape=cgutils.pack_array(builder, [one]),
                strides=cgutils.pack_array(builder, [itemsize]),
                itemsize=arr_from.itemsize,
                meminfo=arr_from.meminfo,
                parent=arr_from.parent)
            return arr_to._getvalue()
        else:
            # generic case:
            # getitem n-dim array -> m-dim array, where N > M
            index_types = (types.int64,) * (self.ndim - self.inner_arr_ty.ndim)
            arrty = types.Array(self.base_type, self.ndim, self.layout)
            arr = self.context.make_array(arrty)(context, builder, self.data)
            res = _getitem_array_generic(context, builder,
                                         self.inner_arr_ty, arrty, arr,
                                         index_types, indices)
            # NOTE: don't call impl_ret_borrowed since the caller doesn't
            # handle references; but this is a borrow.
            return res

    def guard_shape(self, loopshape):
        inner_ndim = self.inner_arr_ty.ndim

        def raise_impl(loop_shape, array_shape):
            # This is in fact a test for broadcasting: broadcast would fail
            # if, ignoring the core dimensions, the remaining ones differ
            # from the sizes given by the loop shape.

            remaining = len(array_shape) - inner_ndim
            _raise = (remaining > len(loop_shape))
            if not _raise:
                for i in range(remaining):
                    _raise |= (array_shape[i] != loop_shape[i])
            if _raise:
                # Ideally we would call `np.broadcast_shapes` with the loop
                # and array shapes. But since broadcasting is not supported
                # here, we just raise an error.
                # TODO: check why raising a dynamic exception here fails
                raise ValueError('Loop and array shapes are incompatible')

        context, builder = self.context, self.builder
        sig = types.none(
            types.UniTuple(types.intp, len(loopshape)),
            types.UniTuple(types.intp, len(self.shape)),
        )
        tup = (context.make_tuple(builder, sig.args[0], loopshape),
               context.make_tuple(builder, sig.args[1], self.shape))
        context.compile_internal(builder, raise_impl, sig, tup)

    def guard_match_core_dims(self, other: '_ArrayGUHelper', ndims: int):
        # arguments with the same signature should match their core dimensions
        #
        # @guvectorize('(n,m), (n,m) -> (n)')
        # def foo(x, y, res):
        #     ...
        #
        # x and y should have the same core (2D) dimensions
        def raise_impl(self_shape, other_shape):
            same = True
            a, b = len(self_shape) - ndims, len(other_shape) - ndims
            for i in range(ndims):
                same &= self_shape[a + i] == other_shape[b + i]
            if not same:
                # NumPy raises the following:
                #   ValueError: gufunc: Input operand 1 has a mismatch in its
                #   core dimension 0, with gufunc signature (n),(n) -> ()
                #   (size 3 is different from 2)
                # But since we cannot raise a dynamic exception here, we try
                # something meaningful instead.
                msg = ('Operand has a mismatch in one of its core dimensions. '
                       'Please, check if all arguments to a @guvectorize '
                       'function have the same core dimensions.')
                raise ValueError(msg)

        context, builder = self.context, self.builder
        sig = types.none(
            types.UniTuple(types.intp, len(self.shape)),
            types.UniTuple(types.intp, len(other.shape)),
        )
        tup = (context.make_tuple(builder, sig.args[0], self.shape),
               context.make_tuple(builder, sig.args[1], other.shape),)
        context.compile_internal(builder, raise_impl, sig, tup)


def _prepare_argument(ctxt, bld, inp, tyinp, where='input operand'):
    """Returns an instance of the appropriate Helper (either
    _ScalarHelper or _ArrayHelper) class to handle the argument.
    Using the polymorphic interface of the Helper classes, scalar
    and array cases can be handled with the same code."""

    # first un-Optional Optionals
    if isinstance(tyinp, types.Optional):
        oty = tyinp
        tyinp = tyinp.type
        inp = ctxt.cast(bld, inp, oty, tyinp)

    # then prepare the arg for a concrete instance
    if isinstance(tyinp, types.ArrayCompatible):
        ary = ctxt.make_array(tyinp)(ctxt, bld, inp)
        shape = cgutils.unpack_tuple(bld, ary.shape, tyinp.ndim)
        strides = cgutils.unpack_tuple(bld, ary.strides, tyinp.ndim)
        return _ArrayHelper(ctxt, bld, shape, strides, ary.data,
                            tyinp.layout, tyinp.dtype, tyinp.ndim, inp)
    elif (types.unliteral(tyinp) in types.number_domain | {types.boolean}
          or isinstance(tyinp, types.scalars._NPDatetimeBase)):
        return _ScalarHelper(ctxt, bld, inp, tyinp)
    else:
        raise NotImplementedError('unsupported type for {0}: {1}'.format(
            where, str(tyinp)))


_broadcast_onto_sig = types.intp(types.intp, types.CPointer(types.intp),
                                 types.intp, types.CPointer(types.intp))


def _broadcast_onto(src_ndim, src_shape, dest_ndim, dest_shape):
    '''Low-level utility function used in calculating a shape for
    an implicit output array. This function assumes that the
    destination shape is an LLVM pointer to a C-style array that was
    already initialized to a size of one along all axes.

    Returns an integer value:
    >= 1 : Succeeded. Return value should equal the number of dimensions in
           the destination shape.
    0    : Failed to broadcast because source shape is larger than the
           destination shape (this case should be weeded out at type
           checking).
    < 0  : Failed to broadcast onto destination axis, at axis number ==
           -(return_value + 1).
    '''
    if src_ndim > dest_ndim:
        # This check should have been done during type checking, but
        # let's be defensive anyway...
        return 0
    else:
        src_index = 0
        dest_index = dest_ndim - src_ndim
        while src_index < src_ndim:
            src_dim_size = src_shape[src_index]
            dest_dim_size = dest_shape[dest_index]
            # Check to see if we've already mutated the destination
            # shape along this axis.
            if dest_dim_size != 1:
                # If we have mutated the destination shape already,
                # then the source axis size must either be one,
                # or the destination axis size.
                if src_dim_size != dest_dim_size and src_dim_size != 1:
                    return -(dest_index + 1)
            elif src_dim_size != 1:
                # If the destination size is still its initial value of
                # one, mutate it to the source size.
                dest_shape[dest_index] = src_dim_size
            src_index += 1
            dest_index += 1
        return dest_index
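
# Illustrative check (not part of the module): the same logic in plain Python
# broadcasts a (3, 1) source onto a 2-d destination initialised to [1, 1]:
# axis 0 mutates 1 -> 3, axis 1 is left at 1, and the function returns 2
# (== dest_ndim), signalling success.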


def _build_array(context, builder, array_ty, input_types, inputs):
    """Utility function to handle allocation of an implicit output array
    given the target context, builder, output array type, and a list of
    _ArrayHelper instances.
    """
    # First, strip optional types; ufunc loops are typed on concrete types
    input_types = [x.type if isinstance(x, types.Optional) else x
                   for x in input_types]

    intp_ty = context.get_value_type(types.intp)

    def make_intp_const(val):
        return context.get_constant(types.intp, val)

    ZERO = make_intp_const(0)
    ONE = make_intp_const(1)

    src_shape = cgutils.alloca_once(builder, intp_ty, array_ty.ndim,
                                    "src_shape")
    dest_ndim = make_intp_const(array_ty.ndim)
    dest_shape = cgutils.alloca_once(builder, intp_ty, array_ty.ndim,
                                     "dest_shape")
    dest_shape_addrs = tuple(cgutils.gep_inbounds(builder, dest_shape, index)
                             for index in range(array_ty.ndim))

    # Initialize the destination shape with all ones.
    for dest_shape_addr in dest_shape_addrs:
        builder.store(ONE, dest_shape_addr)

    # For each argument, try to broadcast onto the destination shape,
    # mutating along any axis where the argument shape is not one and
    # the destination shape is one.
    for arg_number, arg in enumerate(inputs):
        if not hasattr(arg, "ndim"):  # Skip scalar arguments
            continue
        arg_ndim = make_intp_const(arg.ndim)
        for index in range(arg.ndim):
            builder.store(arg.shape[index],
                          cgutils.gep_inbounds(builder, src_shape, index))
        arg_result = context.compile_internal(
            builder, _broadcast_onto, _broadcast_onto_sig,
            [arg_ndim, src_shape, dest_ndim, dest_shape])
        with cgutils.if_unlikely(builder,
                                 builder.icmp_signed('<', arg_result, ONE)):
            msg = "unable to broadcast argument %d to output array" % (
                arg_number,)

            loc = errors.loc_info.get('loc', None)
            if loc is not None:
                msg += '\nFile "%s", line %d, ' % (loc.filename, loc.line)

            context.call_conv.return_user_exc(builder, ValueError, (msg,))

    real_array_ty = array_ty.as_array

    dest_shape_tup = tuple(builder.load(dest_shape_addr)
                           for dest_shape_addr in dest_shape_addrs)
    array_val = arrayobj._empty_nd_impl(context, builder, real_array_ty,
                                        dest_shape_tup)

    # Get the best argument to call __array_wrap__ on
    array_wrapper_index = select_array_wrapper(input_types)
    array_wrapper_ty = input_types[array_wrapper_index]
    try:
        # __array_wrap__(source wrapped array, out array) -> out wrapped array
        array_wrap = context.get_function(
            '__array_wrap__', array_ty(array_wrapper_ty, real_array_ty))
    except NotImplementedError:
        # If it's the same priority as a regular array, assume we
        # should use the allocated array unchanged.
        if array_wrapper_ty.array_priority != types.Array.array_priority:
            raise
        out_val = array_val._getvalue()
    else:
        wrap_args = (inputs[array_wrapper_index].return_val,
                     array_val._getvalue())
        out_val = array_wrap(builder, wrap_args)

    ndim = array_ty.ndim
    shape = cgutils.unpack_tuple(builder, array_val.shape, ndim)
    strides = cgutils.unpack_tuple(builder, array_val.strides, ndim)
    return _ArrayHelper(context, builder, shape, strides, array_val.data,
                        array_ty.layout, array_ty.dtype, ndim,
                        out_val)

# ufuncs either return a single result when nout == 1, else a tuple of results


def _unpack_output_types(ufunc, sig):
    if ufunc.nout == 1:
        return [sig.return_type]
    else:
        return list(sig.return_type)


def _unpack_output_values(ufunc, builder, values):
    if ufunc.nout == 1:
        return [values]
    else:
        return cgutils.unpack_tuple(builder, values)


def _pack_output_values(ufunc, context, builder, typ, values):
    if ufunc.nout == 1:
        return values[0]
    else:
        return context.make_tuple(builder, typ, values)


def numpy_ufunc_kernel(context, builder, sig, args, ufunc, kernel_class):
    # This is the code generator that builds all the looping needed
    # to execute a numpy function over several dimensions (including
    # scalar cases).
    #
    # context - the code generation context
    # builder - the code emitter
    # sig - signature of the ufunc
    # args - the args to the ufunc
    # ufunc - the ufunc itself
    # kernel_class - a code generating subclass of _Kernel that provides
    #     the generate() method emitting the inner loop body

    arguments = [_prepare_argument(context, builder, arg, tyarg)
                 for arg, tyarg in zip(args, sig.args)]

    if len(arguments) < ufunc.nin:
        raise RuntimeError(
            "Not enough inputs to {}, expected {} got {}"
            .format(ufunc.__name__, ufunc.nin, len(arguments)))

    for out_i, ret_ty in enumerate(_unpack_output_types(ufunc, sig)):
        if ufunc.nin + out_i >= len(arguments):
            # this out argument is not provided
            if isinstance(ret_ty, types.ArrayCompatible):
                output = _build_array(context, builder, ret_ty,
                                      sig.args, arguments)
            else:
                output = _prepare_argument(
                    context, builder,
                    ir.Constant(context.get_value_type(ret_ty), None), ret_ty)
            arguments.append(output)
        elif context.enable_nrt:
            # Incref the output
            context.nrt.incref(builder, ret_ty, args[ufunc.nin + out_i])

    inputs = arguments[:ufunc.nin]
    outputs = arguments[ufunc.nin:]
    assert len(outputs) == ufunc.nout

    outer_sig = _ufunc_loop_sig(
        [a.base_type for a in outputs],
        [a.base_type for a in inputs]
    )
    kernel = kernel_class(context, builder, outer_sig)
    intpty = context.get_value_type(types.intp)

    indices = [inp.create_iter_indices() for inp in inputs]

    # assume outputs are all the same size, which numpy requires
    loopshape = outputs[0].shape

    # count the number of C and F layout arrays, respectively
    input_layouts = [inp.layout for inp in inputs
                     if isinstance(inp, _ArrayHelper)]
    num_c_layout = len([x for x in input_layouts if x == 'C'])
    num_f_layout = len([x for x in input_layouts if x == 'F'])

    # Only choose F iteration order if more arrays are in F layout.
    # Default to C order otherwise.
    # This is a best effort for performance. NumPy has more fancy logic
    # that uses array iterators in non-trivial cases.
    if num_f_layout > num_c_layout:
        order = 'F'
    else:
        order = 'C'

    with cgutils.loop_nest(builder, loopshape, intp=intpty,
                           order=order) as loop_indices:
        vals_in = []
        for i, (index, arg) in enumerate(zip(indices, inputs)):
            index.update_indices(loop_indices, i)
            vals_in.append(arg.load_data(index.as_values()))

        vals_out = _unpack_output_values(ufunc, builder,
                                         kernel.generate(*vals_in))
        for val_out, output in zip(vals_out, outputs):
            output.store_data(loop_indices, val_out)

    out = _pack_output_values(ufunc, context, builder, sig.return_type,
                              [o.return_val for o in outputs])
    return impl_ret_new_ref(context, builder, sig.return_type, out)
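
# Illustrative behaviour (not part of the module): for a binary ufunc with one
# C-contiguous and one F-contiguous input, num_c_layout == num_f_layout == 1,
# so the loop nest above falls back to C iteration order; only a strict F
# majority flips it to 'F'.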


def numpy_gufunc_kernel(context, builder, sig, args, ufunc, kernel_class):
    arguments = []
    expected_ndims = kernel_class.dufunc.expected_ndims()
    expected_ndims = expected_ndims[0] + expected_ndims[1]
    is_input = [True] * ufunc.nin + [False] * ufunc.nout
    for arg, ty, exp_ndim, is_inp in zip(args, sig.args, expected_ndims, is_input):  # noqa: E501
        if isinstance(ty, types.ArrayCompatible):
            # Create an array helper whose iteration returns a subarray
            # with ndim specified by "exp_ndim"
            arr = context.make_array(ty)(context, builder, arg)
            shape = cgutils.unpack_tuple(builder, arr.shape, ty.ndim)
            strides = cgutils.unpack_tuple(builder, arr.strides, ty.ndim)
            inner_arr_ty = ty.copy(ndim=exp_ndim)
            ndim = ty.ndim
            layout = ty.layout
            base_type = ty.dtype
            array_helper = _ArrayGUHelper(context, builder,
                                          shape, strides, arg,
                                          layout, base_type, ndim,
                                          inner_arr_ty, is_inp)
            arguments.append(array_helper)
        else:
            scalar_helper = _ScalarHelper(context, builder, arg, ty)
            arguments.append(scalar_helper)
    kernel = kernel_class(context, builder, sig)

    layouts = [arg.layout for arg in arguments
               if isinstance(arg, _ArrayGUHelper)]
    num_c_layout = len([x for x in layouts if x == 'C'])
    num_f_layout = len([x for x in layouts if x == 'F'])

    # Only choose F iteration order if more arrays are in F layout.
    # Default to C order otherwise.
    # This is a best effort for performance. NumPy has more fancy logic
    # that uses array iterators in non-trivial cases.
    if num_f_layout > num_c_layout:
        order = 'F'
    else:
        order = 'C'

    outputs = arguments[ufunc.nin:]
    intpty = context.get_value_type(types.intp)
    indices = [inp.create_iter_indices() for inp in arguments]
    loopshape_ndim = outputs[0].ndim - outputs[0].inner_arr_ty.ndim
    loopshape = outputs[0].shape[:loopshape_ndim]

    _sig = parse_signature(ufunc.gufunc_builder.signature)
    for (idx_a, sig_a), (idx_b, sig_b) in itertools.combinations(
            zip(range(len(arguments)),
                _sig[0] + _sig[1]),
            r=2
    ):
        # Each pair of arguments (inputs and outputs alike) must match
        # their inner dimensions if their signatures are the same.
        arg_a, arg_b = arguments[idx_a], arguments[idx_b]
        if sig_a == sig_b and \
                all(isinstance(x, _ArrayGUHelper) for x in (arg_a, arg_b)):
            arg_a.guard_match_core_dims(arg_b, len(sig_a))

    for arg in arguments[:ufunc.nin]:
        if isinstance(arg, _ArrayGUHelper):
            arg.guard_shape(loopshape)

    with cgutils.loop_nest(builder,
                           loopshape,
                           intp=intpty,
                           order=order) as loop_indices:
        vals_in = []
        for i, (index, arg) in enumerate(zip(indices, arguments)):
            index.update_indices(loop_indices, i)
            vals_in.append(arg.load_data(index.as_values()))

        kernel.generate(*vals_in)


# Kernels are the code to be executed inside the multidimensional loop.
class _Kernel(object):
    def __init__(self, context, builder, outer_sig):
        self.context = context
        self.builder = builder
        self.outer_sig = outer_sig

    def cast(self, val, fromty, toty):
        """Numpy uses cast semantics that are different from standard Python
        (for example, it does allow casting from complex to float).

        This method acts as a patch to context.cast so that it allows
        complex to real/int casts.
        """
        if (isinstance(fromty, types.Complex) and
                not isinstance(toty, types.Complex)):
            # attempt conversion of the real part to the specified type.
            # note that NumPy issues a warning for this kind of conversion
            newty = fromty.underlying_float
            attr = self.context.get_getattr(fromty, 'real')
            val = attr(self.context, self.builder, fromty, val, 'real')
            fromty = newty
            # let the regular cast do the rest...

        return self.context.cast(self.builder, val, fromty, toty)

    def generate(self, *args):
        isig = self.inner_sig
        osig = self.outer_sig
        cast_args = [self.cast(val, inty, outty)
                     for val, inty, outty in
                     zip(args, osig.args, isig.args)]
        if self.cres.objectmode:
            func_type = self.context.call_conv.get_function_type(
                types.pyobject, [types.pyobject] * len(isig.args))
        else:
            func_type = self.context.call_conv.get_function_type(
                isig.return_type, isig.args)
        module = self.builder.block.function.module
        entry_point = cgutils.get_or_insert_function(
            module, func_type,
            self.cres.fndesc.llvm_func_name)
        entry_point.attributes.add("alwaysinline")

        _, res = self.context.call_conv.call_function(
            self.builder, entry_point, isig.return_type, isig.args,
            cast_args)
        return self.cast(res, isig.return_type, osig.return_type)


def _ufunc_db_function(ufunc):
    """Use the ufunc loop type information to select the code generation
    function from the table provided by the dict_of_kernels. The dict
    of kernels maps the loop identifier to a function with the
    following signature: (context, builder, signature, args).

    The loop type information has the form 'AB->C'. The letters to the
    left of '->' are the input types (specified as NumPy letter
    types). The letters to the right of '->' are the output
    types. There must be 'ufunc.nin' letters to the left of '->', and
    'ufunc.nout' letters to the right.

    For example, a binary float loop resulting in a float will have
    the following signature: 'ff->f'.

    A given ufunc implements many loops. The list of loops implemented
    for a given ufunc can be accessed using the 'types' attribute in
    the ufunc object. The NumPy machinery selects the first loop that
    fits a given calling signature (in our case, what we call the
    outer_sig). This logic is mimicked by 'ufunc_find_matching_loop'.
    """

    class _KernelImpl(_Kernel):
        def __init__(self, context, builder, outer_sig):
            super(_KernelImpl, self).__init__(context, builder, outer_sig)
            loop = ufunc_find_matching_loop(
                ufunc,
                outer_sig.args + tuple(_unpack_output_types(ufunc, outer_sig)))
            self.fn = context.get_ufunc_info(ufunc).get(loop.ufunc_sig)
            self.inner_sig = _ufunc_loop_sig(loop.outputs, loop.inputs)

            if self.fn is None:
                msg = "Don't know how to lower ufunc '{0}' for loop '{1}'"
                raise NotImplementedError(msg.format(ufunc.__name__, loop))

        def generate(self, *args):
            isig = self.inner_sig
            osig = self.outer_sig

            cast_args = [self.cast(val, inty, outty)
                         for val, inty, outty in zip(args, osig.args,
                                                     isig.args)]
            with force_error_model(self.context, 'numpy'):
                res = self.fn(self.context, self.builder, isig, cast_args)
            dmm = self.context.data_model_manager
            res = dmm[isig.return_type].from_return(self.builder, res)
            return self.cast(res, isig.return_type, osig.return_type)

    return _KernelImpl


################################################################################
# Helper functions that register the ufuncs

def register_ufunc_kernel(ufunc, kernel, lower):
    def do_ufunc(context, builder, sig, args):
        return numpy_ufunc_kernel(context, builder, sig, args, ufunc, kernel)

    _any = types.Any
    in_args = (_any,) * ufunc.nin

    # Add a lowering for each out argument that is missing.
    for n_explicit_out in range(ufunc.nout + 1):
        out_args = (types.Array,) * n_explicit_out
        lower(ufunc, *in_args, *out_args)(do_ufunc)

    return kernel


def register_unary_operator_kernel(operator, ufunc, kernel, lower,
                                   inplace=False):
    assert not inplace  # are there any inplace unary operators?

    def lower_unary_operator(context, builder, sig, args):
        return numpy_ufunc_kernel(context, builder, sig, args, ufunc, kernel)
    _arr_kind = types.Array
    lower(operator, _arr_kind)(lower_unary_operator)


def register_binary_operator_kernel(op, ufunc, kernel, lower, inplace=False):
    def lower_binary_operator(context, builder, sig, args):
        return numpy_ufunc_kernel(context, builder, sig, args, ufunc, kernel)

    def lower_inplace_operator(context, builder, sig, args):
        # The visible signature is (A, B) -> A, while the implementation's
        # signature (with explicit output) is (A, B, A) -> A,
        # i.e. the equivalent of ufunc(a, b, out=a).
        args = tuple(args) + (args[0],)
        sig = typing.signature(sig.return_type, *sig.args + (sig.args[0],))
        return numpy_ufunc_kernel(context, builder, sig, args, ufunc, kernel)

    _any = types.Any
    _arr_kind = types.Array
    formal_sigs = [(_arr_kind, _arr_kind), (_any, _arr_kind),
                   (_arr_kind, _any)]
    for sig in formal_sigs:
        if not inplace:
            lower(op, *sig)(lower_binary_operator)
        else:
            lower(op, *sig)(lower_inplace_operator)


################################################################################
# Use the contents of ufunc_db to initialize the supported ufuncs

@registry.lower(operator.pos, types.Array)
def array_positive_impl(context, builder, sig, args):
    '''Lowering function for +(array) expressions. Defined here
    (numba.targets.npyimpl) since the remaining array-operator
    lowering functions are also registered in this module.
    '''
    class _UnaryPositiveKernel(_Kernel):
        def generate(self, *args):
            [val] = args
            return val

    return numpy_ufunc_kernel(context, builder, sig, args, np.positive,
                              _UnaryPositiveKernel)


def register_ufuncs(ufuncs, lower):
    kernels = {}
    for ufunc in ufuncs:
        db_func = _ufunc_db_function(ufunc)
        kernels[ufunc] = register_ufunc_kernel(ufunc, db_func, lower)

    for _op_map in (npydecl.NumpyRulesUnaryArrayOperator._op_map,
                    npydecl.NumpyRulesArrayOperator._op_map,
                    ):
        for operator, ufunc_name in _op_map.items():
            ufunc = getattr(np, ufunc_name)
            kernel = kernels[ufunc]
            if ufunc.nin == 1:
                register_unary_operator_kernel(operator, ufunc, kernel, lower)
            elif ufunc.nin == 2:
                register_binary_operator_kernel(operator, ufunc, kernel, lower)
            else:
                raise RuntimeError(
                    "There shouldn't be any non-unary or binary operators")

    for _op_map in (npydecl.NumpyRulesInplaceArrayOperator._op_map,
                    ):
        for operator, ufunc_name in _op_map.items():
            ufunc = getattr(np, ufunc_name)
            kernel = kernels[ufunc]
            if ufunc.nin == 1:
                register_unary_operator_kernel(operator, ufunc, kernel, lower,
                                               inplace=True)
            elif ufunc.nin == 2:
                register_binary_operator_kernel(operator, ufunc, kernel, lower,
                                                inplace=True)
            else:
                raise RuntimeError(
                    "There shouldn't be any non-unary or binary operators")


register_ufuncs(ufunc_db.get_ufuncs(), registry.lower)


@intrinsic
def _make_dtype_object(typingctx, desc):
    """Given a string or NumberClass description *desc*,
    returns the dtype object.
    """
    def from_nb_type(nb_type):
        return_type = types.DType(nb_type)
        sig = return_type(desc)

        def codegen(context, builder, signature, args):
            # All dtype objects are dummy values in LLVM.
            # They only exist at the type level.
            return context.get_dummy_value()

        return sig, codegen

    if isinstance(desc, types.Literal):
        # Convert the str description into np.dtype then to numba type.
        nb_type = from_dtype(np.dtype(desc.literal_value))
        return from_nb_type(nb_type)
    elif isinstance(desc, types.functions.NumberClass):
        thestr = str(desc.dtype)
        # Convert the str description into np.dtype then to numba type.
        nb_type = from_dtype(np.dtype(thestr))
        return from_nb_type(nb_type)


@overload(np.dtype)
def numpy_dtype(dtype):
    """Provide an implementation so that the numpy.dtype function
    can be lowered.
    """
    if isinstance(dtype, (types.Literal, types.functions.NumberClass)):
        def imp(dtype):
            return _make_dtype_object(dtype)
        return imp
    else:
        raise errors.NumbaTypeError(
            'unknown dtype descriptor: {}'.format(dtype))
@@ -0,0 +1,774 @@
import collections
import ctypes
import re

import numpy as np

from numba.core import errors, types
from numba.core.typing.templates import signature
from numba.np import npdatetime_helpers
from numba.core.errors import TypingError

# re-export
from numba.core.cgutils import is_nonelike  # noqa: F401


numpy_version = tuple(map(int, np.__version__.split('.')[:2]))


FROM_DTYPE = {
    np.dtype('bool'): types.boolean,
    np.dtype('int8'): types.int8,
    np.dtype('int16'): types.int16,
    np.dtype('int32'): types.int32,
    np.dtype('int64'): types.int64,

    np.dtype('uint8'): types.uint8,
    np.dtype('uint16'): types.uint16,
    np.dtype('uint32'): types.uint32,
    np.dtype('uint64'): types.uint64,

    np.dtype('float32'): types.float32,
    np.dtype('float64'): types.float64,
    np.dtype('float16'): types.float16,
    np.dtype('complex64'): types.complex64,
    np.dtype('complex128'): types.complex128,

    np.dtype(object): types.pyobject,
}
|
||||
|
||||
|
||||
re_typestr = re.compile(r'[<>=\|]([a-z])(\d+)?$', re.I)
|
||||
re_datetimestr = re.compile(r'[<>=\|]([mM])8?(\[([a-z]+)\])?$', re.I)
|
||||
|
||||
sizeof_unicode_char = np.dtype('U1').itemsize
|
||||
|
||||
|
||||
def _from_str_dtype(dtype):
|
||||
m = re_typestr.match(dtype.str)
|
||||
if not m:
|
||||
raise errors.NumbaNotImplementedError(dtype)
|
||||
groups = m.groups()
|
||||
typecode = groups[0]
|
||||
if typecode == 'U':
|
||||
# unicode
|
||||
if dtype.byteorder not in '=|':
|
||||
raise errors.NumbaNotImplementedError("Does not support non-native "
|
||||
"byteorder")
|
||||
count = dtype.itemsize // sizeof_unicode_char
|
||||
assert count == int(groups[1]), "Unicode char size mismatch"
|
||||
return types.UnicodeCharSeq(count)
|
||||
|
||||
elif typecode == 'S':
|
||||
# char
|
||||
count = dtype.itemsize
|
||||
assert count == int(groups[1]), "Char size mismatch"
|
||||
return types.CharSeq(count)
|
||||
|
||||
else:
|
||||
raise errors.NumbaNotImplementedError(dtype)
|
||||
|
||||
|
||||
def _from_datetime_dtype(dtype):
|
||||
m = re_datetimestr.match(dtype.str)
|
||||
if not m:
|
||||
raise errors.NumbaNotImplementedError(dtype)
|
||||
groups = m.groups()
|
||||
typecode = groups[0]
|
||||
unit = groups[2] or ''
|
||||
if typecode == 'm':
|
||||
return types.NPTimedelta(unit)
|
||||
elif typecode == 'M':
|
||||
return types.NPDatetime(unit)
|
||||
else:
|
||||
raise errors.NumbaNotImplementedError(dtype)
|
||||
|
||||
|
||||
def from_dtype(dtype):
|
||||
"""
|
||||
Return a Numba Type instance corresponding to the given Numpy *dtype*.
|
||||
NumbaNotImplementedError is raised on unsupported Numpy dtypes.
|
||||
"""
|
||||
if type(dtype) is type and issubclass(dtype, np.generic):
|
||||
dtype = np.dtype(dtype)
|
||||
elif getattr(dtype, "fields", None) is not None:
|
||||
return from_struct_dtype(dtype)
|
||||
|
||||
try:
|
||||
return FROM_DTYPE[dtype]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
char = dtype.char
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
if char in 'SU':
|
||||
return _from_str_dtype(dtype)
|
||||
if char in 'mM':
|
||||
return _from_datetime_dtype(dtype)
|
||||
if char in 'V' and dtype.subdtype is not None:
|
||||
subtype = from_dtype(dtype.subdtype[0])
|
||||
return types.NestedArray(subtype, dtype.shape)
|
||||
|
||||
raise errors.NumbaNotImplementedError(dtype)
|
||||
|
||||
|
||||
_as_dtype_letters = {
|
||||
types.NPDatetime: 'M8',
|
||||
types.NPTimedelta: 'm8',
|
||||
types.CharSeq: 'S',
|
||||
types.UnicodeCharSeq: 'U',
|
||||
}
|
||||
|
||||
|
||||
def as_dtype(nbtype):
|
||||
"""
|
||||
Return a numpy dtype instance corresponding to the given Numba type.
|
||||
NumbaNotImplementedError is if no correspondence is known.
|
||||
"""
|
||||
nbtype = types.unliteral(nbtype)
|
||||
if isinstance(nbtype, (types.Complex, types.Integer, types.Float)):
|
||||
return np.dtype(str(nbtype))
|
||||
if isinstance(nbtype, (types.Boolean)):
|
||||
return np.dtype('?')
|
||||
if isinstance(nbtype, (types.NPDatetime, types.NPTimedelta)):
|
||||
letter = _as_dtype_letters[type(nbtype)]
|
||||
if nbtype.unit:
|
||||
return np.dtype('%s[%s]' % (letter, nbtype.unit))
|
||||
else:
|
||||
return np.dtype(letter)
|
||||
if isinstance(nbtype, (types.CharSeq, types.UnicodeCharSeq)):
|
||||
letter = _as_dtype_letters[type(nbtype)]
|
||||
return np.dtype('%s%d' % (letter, nbtype.count))
|
||||
if isinstance(nbtype, types.Record):
|
||||
return as_struct_dtype(nbtype)
|
||||
if isinstance(nbtype, types.EnumMember):
|
||||
return as_dtype(nbtype.dtype)
|
||||
if isinstance(nbtype, types.npytypes.DType):
|
||||
return as_dtype(nbtype.dtype)
|
||||
if isinstance(nbtype, types.NumberClass):
|
||||
return as_dtype(nbtype.dtype)
|
||||
if isinstance(nbtype, types.NestedArray):
|
||||
spec = (as_dtype(nbtype.dtype), tuple(nbtype.shape))
|
||||
return np.dtype(spec)
|
||||
if isinstance(nbtype, types.PyObject):
|
||||
return np.dtype(object)
|
||||
|
||||
msg = f"{nbtype} cannot be represented as a NumPy dtype"
|
||||
raise errors.NumbaNotImplementedError(msg)
|
||||
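
# --- Editor's sketch (illustrative, not part of the vendored module) ---
# from_dtype and as_dtype are inverse mappings between NumPy dtypes and Numba
# types. A small round-trip check, assuming this module is importable as
# numba.np.numpy_support:

def _example_dtype_round_trip():
    import numpy as np
    from numba.np import numpy_support

    nb_ty = numpy_support.from_dtype(np.dtype("int32"))   # -> types.int32
    np_dt = numpy_support.as_dtype(nb_ty)                 # -> dtype('int32')
    assert np_dt == np.dtype("int32")
    return nb_ty, np_dt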

def as_struct_dtype(rec):
    """Convert Numba Record type to NumPy structured dtype
    """
    assert isinstance(rec, types.Record)
    names = []
    formats = []
    offsets = []
    titles = []
    # Fill the fields if they are not a title.
    for k, t in rec.members:
        if not rec.is_title(k):
            names.append(k)
            formats.append(as_dtype(t))
            offsets.append(rec.offset(k))
            titles.append(rec.fields[k].title)

    fields = {
        'names': names,
        'formats': formats,
        'offsets': offsets,
        'itemsize': rec.size,
        'titles': titles,
    }
    _check_struct_alignment(rec, fields)
    return np.dtype(fields, align=rec.aligned)


def _check_struct_alignment(rec, fields):
    """Check alignment compatibility with Numpy"""
    if rec.aligned:
        for k, dt in zip(fields['names'], fields['formats']):
            llvm_align = rec.alignof(k)
            npy_align = dt.alignment
            if llvm_align is not None and npy_align != llvm_align:
                msg = (
                    'NumPy is using a different alignment ({}) '
                    'than Numba/LLVM ({}) for {}. '
                    'This is likely a NumPy bug.'
                )
                raise ValueError(msg.format(npy_align, llvm_align, dt))


def map_arrayscalar_type(val):
    if isinstance(val, np.generic):
        # We can't blindly call np.dtype() as it loses information
        # on some types, e.g. datetime64 and timedelta64.
        dtype = val.dtype
    else:
        try:
            dtype = np.dtype(type(val))
        except TypeError:
            raise errors.NumbaNotImplementedError("no corresponding numpy "
                                                  "dtype for %r" % type(val))
    return from_dtype(dtype)


def is_array(val):
    return isinstance(val, np.ndarray)


def map_layout(val):
    if val.flags['C_CONTIGUOUS']:
        layout = 'C'
    elif val.flags['F_CONTIGUOUS']:
        layout = 'F'
    else:
        layout = 'A'
    return layout


def select_array_wrapper(inputs):
    """
    Given the array-compatible input types to an operation (e.g. ufunc),
    select the appropriate input for wrapping the operation output,
    according to each input's __array_priority__.

    An index into *inputs* is returned.
    """
    max_prio = float('-inf')
    selected_index = None
    for index, ty in enumerate(inputs):
        # Ties are broken by choosing the first winner, as in Numpy
        if (isinstance(ty, types.ArrayCompatible) and
                ty.array_priority > max_prio):
            selected_index = index
            max_prio = ty.array_priority

    assert selected_index is not None
    return selected_index
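
# --- Editor's sketch (illustrative, not part of the vendored module) ---
# select_array_wrapper picks the input with the highest __array_priority__,
# breaking ties in favor of the first input, mirroring NumPy's wrapping
# rules. A sketch using plain Numba array types:

def _example_select_array_wrapper():
    from numba.core import types

    a = types.Array(types.float64, 1, 'C')
    b = types.Array(types.int32, 2, 'A')
    # Both carry the default array priority, so the tie goes to index 0.
    return select_array_wrapper([a, b])  # -> 0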

def resolve_output_type(context, inputs, formal_output):
    """
    Given the array-compatible input types to an operation (e.g. ufunc),
    and the operation's formal output type (a types.Array instance),
    resolve the actual output type using the typing *context*.

    This uses a mechanism compatible with Numpy's __array_priority__ /
    __array_wrap__.
    """
    selected_input = inputs[select_array_wrapper(inputs)]
    args = selected_input, formal_output
    sig = context.resolve_function_type('__array_wrap__', args, {})
    if sig is None:
        if selected_input.array_priority == types.Array.array_priority:
            # If it's the same priority as a regular array, assume we
            # should return the output unchanged.
            # (we can't define __array_wrap__ explicitly for types.Buffer,
            #  as that would be inherited by most array-compatible objects)
            return formal_output
        raise errors.TypingError("__array_wrap__ failed for %s" % (args,))
    return sig.return_type


def supported_ufunc_loop(ufunc, loop):
    """Return whether the *loop* for the *ufunc* is supported in nopython mode.

    *loop* should be a UFuncLoopSpec instance, and *ufunc* a numpy ufunc.

    For ufuncs implemented using the ufunc_db, it is supported if the ufunc_db
    contains a lowering definition for 'loop' in the 'ufunc' entry.

    For other ufuncs, it is type based. The loop will be considered valid if it
    only contains the following letter types: '?bBhHiIlLqQfd'. Note this is
    legacy and when implementing new ufuncs the ufunc_db should be preferred,
    as it allows for more fine-grained, incremental support.
    """
    # NOTE: Assuming ufunc for the CPUContext
    from numba.np import ufunc_db
    loop_sig = loop.ufunc_sig
    try:
        # check if the loop has a codegen description in the
        # ufunc_db. If so, we can proceed.

        # note that as of now not all ufuncs have an entry in the
        # ufunc_db
        supported_loop = loop_sig in ufunc_db.get_ufunc_info(ufunc)
    except KeyError:
        # for ufuncs not in ufunc_db, base the decision of whether the
        # loop is supported on its types
        loop_types = [x.char for x in loop.numpy_inputs + loop.numpy_outputs]
        supported_types = '?bBhHiIlLqQfd'
        # check if all the types involved in the ufunc loop are
        # supported in this mode
        supported_loop = all(t in supported_types for t in loop_types)

    return supported_loop


class UFuncLoopSpec(collections.namedtuple('_UFuncLoopSpec',
                                           ('inputs', 'outputs', 'ufunc_sig'))):
    """
    An object describing a ufunc loop's inner types.  Properties:
    - inputs: the inputs' Numba types
    - outputs: the outputs' Numba types
    - ufunc_sig: the string representing the ufunc's type signature, in
      Numpy format (e.g. "ii->i")
    """

    __slots__ = ()

    @property
    def numpy_inputs(self):
        return [as_dtype(x) for x in self.inputs]

    @property
    def numpy_outputs(self):
        return [as_dtype(x) for x in self.outputs]


def _ufunc_loop_sig(out_tys, in_tys):
    if len(out_tys) == 1:
        return signature(out_tys[0], *in_tys)
    else:
        return signature(types.Tuple(out_tys), *in_tys)


def ufunc_can_cast(from_, to, has_mixed_inputs, casting='safe'):
    """
    A variant of np.can_cast() that can allow casting any integer to
    any real or complex type, in case the operation has mixed-kind
    inputs.

    For example we want `np.power(float32, int32)` to be computed using
    SP arithmetic and return `float32`.
    However, `np.sqrt(int32)` should use DP arithmetic and return `float64`.
    """
    from_ = np.dtype(from_)
    to = np.dtype(to)
    if has_mixed_inputs and from_.kind in 'iu' and to.kind in 'cf':
        # Decide that all integers can cast to any real or complex type.
        return True
    return np.can_cast(from_, to, casting)
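
# --- Editor's sketch (illustrative, not part of the vendored module) ---
# ufunc_can_cast relaxes np.can_cast only for the mixed-kind case: with
# has_mixed_inputs=True, any integer may cast to any float/complex type.

def _example_ufunc_can_cast():
    import numpy as np

    # Plain NumPy rejects int64 -> float32 as unsafe...
    assert not np.can_cast(np.int64, np.float32, 'safe')
    # ...but with mixed-kind inputs the helper allows it, so that e.g.
    # np.power(float32, int64) can stay in single precision.
    assert ufunc_can_cast(np.dtype(np.int64), np.dtype(np.float32), True)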

def ufunc_find_matching_loop(ufunc, arg_types):
    """Find the appropriate loop to be used for a ufunc based on the types
    of the operands

    ufunc        - The ufunc we want to check
    arg_types    - The tuple of arguments to the ufunc, including any
                   explicit output(s).
    return value - A UFuncLoopSpec identifying the loop, or None
                   if no matching loop is found.
    """

    # Separate logical input from explicit output arguments
    input_types = arg_types[:ufunc.nin]
    output_types = arg_types[ufunc.nin:]
    assert (len(input_types) == ufunc.nin)

    try:
        np_input_types = [as_dtype(x) for x in input_types]
    except errors.NumbaNotImplementedError:
        return None
    try:
        np_output_types = [as_dtype(x) for x in output_types]
    except errors.NumbaNotImplementedError:
        return None

    # Whether the inputs are mixed integer / floating-point
    has_mixed_inputs = (
        any(dt.kind in 'iu' for dt in np_input_types) and
        any(dt.kind in 'cf' for dt in np_input_types))

    def choose_types(numba_types, ufunc_letters):
        """
        Return a list of Numba types representing *ufunc_letters*,
        except when the letter designates a datetime64 or timedelta64,
        in which case the type is taken from *numba_types*.
        """
        assert len(ufunc_letters) >= len(numba_types)
        types = [tp if letter in 'mM' else from_dtype(np.dtype(letter))
                 for tp, letter in zip(numba_types, ufunc_letters)]
        # Add missing types (presumably implicit outputs)
        types += [from_dtype(np.dtype(letter))
                  for letter in ufunc_letters[len(numba_types):]]
        return types

    def set_output_dt_units(inputs, outputs, ufunc_inputs, ufunc_name):
        """
        Sets the output unit of a datetime type based on the input units

        Timedelta is a special dtype that requires the time unit to be
        specified (day, month, etc). Not every operation with timedelta inputs
        leads to a timedelta output. However, for those that do,
        the unit of the output must be inferred based on the units of the
        inputs.

        At the moment this function takes care of two cases:
        a) where all inputs are timedelta with the same unit (mm), and
           therefore the output has the same unit.
           This case is used for arr.sum, and for arr1+arr2 where all arrays
           are timedeltas.
           If in the future this needs to be extended to a case with mixed
           units, the rules should be implemented in `npdatetime_helpers` and
           called from this function to set the correct output unit.
        b) where the left operand is a timedelta, i.e. the "m?" case. This
           case is used for division, e.g. timedelta / int.

        At the time of writing, Numba does not support addition of timedelta
        and other types, so this function does not consider the case "?m",
        i.e. where timedelta is the right operand to a non-timedelta left
        operand. To extend it in the future, just add another elif clause.
        """
        def make_specific(outputs, unit):
            new_outputs = []
            for out in outputs:
                if isinstance(out, types.NPTimedelta) and out.unit == "":
                    new_outputs.append(types.NPTimedelta(unit))
                else:
                    new_outputs.append(out)
            return new_outputs

        def make_datetime_specific(outputs, dt_unit, td_unit):
            new_outputs = []
            for out in outputs:
                if isinstance(out, types.NPDatetime) and out.unit == "":
                    unit = npdatetime_helpers.combine_datetime_timedelta_units(
                        dt_unit, td_unit)
                    if unit is None:
                        raise TypingError(f"ufunc '{ufunc_name}' is not " +
                                          "supported between " +
                                          f"datetime64[{dt_unit}] " +
                                          f"and timedelta64[{td_unit}]"
                                          )
                    new_outputs.append(types.NPDatetime(unit))
                else:
                    new_outputs.append(out)
            return new_outputs

        if ufunc_inputs == 'mm':
            if all(inp.unit == inputs[0].unit for inp in inputs):
                # Case with operation on same units. Operations on different
                # units not adjusted for now but might need to be
                # added in the future
                unit = inputs[0].unit
                new_outputs = make_specific(outputs, unit)
            else:
                return outputs
            return new_outputs
        elif ufunc_inputs == 'mM':
            # case where the left operand has timedelta type
            # and the right operand has datetime
            td_unit = inputs[0].unit
            dt_unit = inputs[1].unit
            return make_datetime_specific(outputs, dt_unit, td_unit)

        elif ufunc_inputs == 'Mm':
            # case where the right operand has timedelta type
            # and the left operand has datetime
            dt_unit = inputs[0].unit
            td_unit = inputs[1].unit
            return make_datetime_specific(outputs, dt_unit, td_unit)

        elif ufunc_inputs[0] == 'm':
            # case where the left operand has timedelta type
            unit = inputs[0].unit
            new_outputs = make_specific(outputs, unit)
            return new_outputs

    # In NumPy, the loops are evaluated from first to last. The first one
    # that is viable is the one used. One loop is viable if it is possible
    # to cast every input operand to the one expected by the ufunc.
    # Also under NumPy 1.10+ the output must be able to be cast back
    # to a close enough type ("same_kind").

    for candidate in ufunc.types:
        ufunc_inputs = candidate[:ufunc.nin]
        ufunc_outputs = candidate[-ufunc.nout:] if ufunc.nout else []

        if 'e' in ufunc_inputs:
            # Skip float16 arrays since we don't have an implementation
            # for them
            continue
        if 'O' in ufunc_inputs:
            # Skip object arrays
            continue
        found = True
        # Skip if any input or output argument is mismatching
        for outer, inner in zip(np_input_types, ufunc_inputs):
            # (outer is a dtype instance, inner is a type char)
            if outer.char in 'mM' or inner in 'mM':
                # For datetime64 and timedelta64, we want to retain
                # precise typing (i.e. the units); therefore we look for
                # an exact match.
                if outer.char != inner:
                    found = False
                    break
            elif not ufunc_can_cast(outer.char, inner,
                                    has_mixed_inputs, 'safe'):
                found = False
                break
        if found:
            # Can we cast the inner result to the outer result type?
            for outer, inner in zip(np_output_types, ufunc_outputs):
                if (outer.char not in 'mM' and not
                        ufunc_can_cast(inner, outer.char,
                                       has_mixed_inputs, 'same_kind')):
                    found = False
                    break
        if found:
            # Found: determine the Numba types for the loop's inputs and
            # outputs.
            try:
                inputs = choose_types(input_types, ufunc_inputs)
                outputs = choose_types(output_types, ufunc_outputs)
                # if the left operand or both are timedeltas, or the first
                # argument is datetime and the second argument is timedelta,
                # then the output units need to be determined.
                if ufunc_inputs[0] == 'm' or ufunc_inputs == 'Mm':
                    outputs = set_output_dt_units(inputs, outputs,
                                                  ufunc_inputs, ufunc.__name__)

            except errors.NumbaNotImplementedError:
                # One of the selected dtypes isn't supported by Numba
                # (e.g. float16), try other candidates
                continue
            else:
                return UFuncLoopSpec(inputs, outputs, candidate)

    return None
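
# --- Editor's sketch (illustrative, not part of the vendored module) ---
# ufunc_find_matching_loop walks ufunc.types in order and returns the first
# candidate whose inputs all cast safely. A sketch; the exact loop chosen for
# integer arguments can vary with the platform's C int/long sizes, so a
# float example is used here:

def _example_find_matching_loop():
    import numpy as np
    from numba.core import types

    spec = ufunc_find_matching_loop(np.add, (types.float32, types.float32))
    # Expected: UFuncLoopSpec(inputs=[float32, float32],
    #                         outputs=[float32], ufunc_sig='ff->f')
    return spec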

def _is_aligned_struct(struct):
    return struct.isalignedstruct


def from_struct_dtype(dtype):
    """Convert a NumPy structured dtype to a Numba Record type
    """
    if dtype.hasobject:
        msg = "dtypes that contain object are not supported."
        raise errors.NumbaNotImplementedError(msg)

    fields = []
    for name, info in dtype.fields.items():
        # *info* may have 3 elements
        [elemdtype, offset] = info[:2]
        title = info[2] if len(info) == 3 else None

        ty = from_dtype(elemdtype)
        infos = {
            'type': ty,
            'offset': offset,
            'title': title,
        }
        fields.append((name, infos))

    # Note: dtype.alignment is not consistent.
    # It is different after passing into a recarray.
    # recarray(N, dtype=mydtype).dtype.alignment != mydtype.alignment
    size = dtype.itemsize
    aligned = _is_aligned_struct(dtype)

    return types.Record(fields, size, aligned)


def _get_bytes_buffer(ptr, nbytes):
    """
    Get a ctypes array of *nbytes* starting at *ptr*.
    """
    if isinstance(ptr, ctypes.c_void_p):
        ptr = ptr.value
    arrty = ctypes.c_byte * nbytes
    return arrty.from_address(ptr)


def _get_array_from_ptr(ptr, nbytes, dtype):
    return np.frombuffer(_get_bytes_buffer(ptr, nbytes), dtype)


def carray(ptr, shape, dtype=None):
    """
    Return a Numpy array view over the data pointed to by *ptr* with the
    given *shape*, in C order.  If *dtype* is given, it is used as the
    array's dtype, otherwise the array's dtype is inferred from *ptr*'s type.
    """
    from numba.core.typing.ctypes_utils import from_ctypes

    try:
        # Use ctypes parameter protocol if available
        ptr = ptr._as_parameter_
    except AttributeError:
        pass

    # Normalize dtype, to accept e.g. "int64" or np.int64
    if dtype is not None:
        dtype = np.dtype(dtype)

    if isinstance(ptr, ctypes.c_void_p):
        if dtype is None:
            raise TypeError("explicit dtype required for void* argument")
        p = ptr
    elif isinstance(ptr, ctypes._Pointer):
        ptrty = from_ctypes(ptr.__class__)
        assert isinstance(ptrty, types.CPointer)
        ptr_dtype = as_dtype(ptrty.dtype)
        if dtype is not None and dtype != ptr_dtype:
            raise TypeError("mismatching dtype '%s' for pointer %s"
                            % (dtype, ptr))
        dtype = ptr_dtype
        p = ctypes.cast(ptr, ctypes.c_void_p)
    else:
        raise TypeError("expected a ctypes pointer, got %r" % (ptr,))

    nbytes = dtype.itemsize * np.prod(shape, dtype=np.intp)
    return _get_array_from_ptr(p, nbytes, dtype).reshape(shape)
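
# --- Editor's sketch (illustrative, not part of the vendored module) ---
# carray gives a zero-copy NumPy view over a raw pointer, which is the usual
# way to handle array arguments inside @cfunc callbacks. A small host-side
# sketch using a ctypes buffer:

def _example_carray():
    import ctypes
    import numpy as np

    buf = (ctypes.c_double * 6)(*range(6))
    ptr = ctypes.cast(buf, ctypes.POINTER(ctypes.c_double))
    view = carray(ptr, (2, 3))          # dtype inferred from the pointer
    assert view.shape == (2, 3) and view[1, 2] == 5.0
    # farray (defined below) interprets the same data in Fortran order:
    fview = farray(ptr, (2, 3))
    assert fview[1, 0] == 1.0
    return view, fview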

def farray(ptr, shape, dtype=None):
    """
    Return a Numpy array view over the data pointed to by *ptr* with the
    given *shape*, in Fortran order.  If *dtype* is given, it is used as the
    array's dtype, otherwise the array's dtype is inferred from *ptr*'s type.
    """
    if not isinstance(shape, int):
        shape = shape[::-1]
    return carray(ptr, shape, dtype).T


def is_contiguous(dims, strides, itemsize):
    """Is the given shape, strides, and itemsize of C layout?

    Note: The code is usable as a numba-compiled function
    """
    nd = len(dims)
    # Check and skip 1s or 0s in inner dims
    innerax = nd - 1
    while innerax > -1 and dims[innerax] <= 1:
        innerax -= 1

    # Early exit if all axis are 1s or 0s
    if innerax < 0:
        return True

    # Check itemsize matches innermost stride
    if itemsize != strides[innerax]:
        return False

    # Check and skip 1s or 0s in outer dims
    outerax = 0
    while outerax < innerax and dims[outerax] <= 1:
        outerax += 1

    # Check remaining strides to be contiguous
    ax = innerax
    while ax > outerax:
        if strides[ax] * dims[ax] != strides[ax - 1]:
            return False
        ax -= 1
    return True
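
# --- Editor's sketch (illustrative, not part of the vendored module) ---
# is_contiguous checks C layout from shape/strides alone, skipping axes of
# extent 0 or 1 so that e.g. a (1, 3) slice still counts as contiguous:

def _example_layout_checks():
    # A C-contiguous (2, 3) float64 array: strides are (24, 8), itemsize 8.
    assert is_contiguous((2, 3), (24, 8), 8)
    # A transposed view of it is not C-contiguous...
    assert not is_contiguous((3, 2), (8, 24), 8)
    # ...but it is Fortran-contiguous (see is_fortran below).
    assert is_fortran((3, 2), (8, 24), 8)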

def is_fortran(dims, strides, itemsize):
    """Is the given shape, strides, and itemsize of F layout?

    Note: The code is usable as a numba-compiled function
    """
    nd = len(dims)
    # Check and skip 1s or 0s in inner dims
    firstax = 0
    while firstax < nd and dims[firstax] <= 1:
        firstax += 1

    # Early exit if all axis are 1s or 0s
    if firstax >= nd:
        return True

    # Check itemsize matches innermost stride
    if itemsize != strides[firstax]:
        return False

    # Check and skip 1s or 0s in outer dims
    lastax = nd - 1
    while lastax > firstax and dims[lastax] <= 1:
        lastax -= 1

    # Check remaining strides to be contiguous
    ax = firstax
    while ax < lastax:
        if strides[ax] * dims[ax] != strides[ax + 1]:
            return False
        ax += 1
    return True


def type_can_asarray(arr):
    """ Returns True if the type of 'arr' is supported by the Numba
    `np.asarray` implementation, False otherwise.
    """

    ok = (types.Array, types.Sequence, types.Tuple, types.StringLiteral,
          types.Number, types.Boolean, types.containers.ListType)

    return isinstance(arr, ok)


def type_is_scalar(typ):
    """ Returns True if the type of 'typ' is a scalar type, according to
    NumPy rules. False otherwise.
    https://numpy.org/doc/stable/reference/arrays.scalars.html#built-in-scalar-types
    """

    ok = (types.Boolean, types.Number, types.UnicodeType, types.StringLiteral,
          types.NPTimedelta, types.NPDatetime)
    return isinstance(typ, ok)


def check_is_integer(v, name):
    """Raises TypingError if the value is not an integer."""
    if not isinstance(v, (int, types.Integer)):
        raise TypingError('{} must be an integer'.format(name))


def lt_floats(a, b):
    # Adapted from NumPy commit 717c7acf, which introduced the behavior of
    # putting NaNs at the end.
    # The code was later moved to numpy/core/src/npysort/npysort_common.h
    # This info is gathered as of NumPy commit d8c09c50
    return a < b or (np.isnan(b) and not np.isnan(a))


def lt_complex(a, b):
    if np.isnan(a.real):
        if np.isnan(b.real):
            if np.isnan(a.imag):
                return False
            else:
                if np.isnan(b.imag):
                    return True
                else:
                    return a.imag < b.imag
        else:
            return False

    else:
        if np.isnan(b.real):
            return True
        else:
            if np.isnan(a.imag):
                if np.isnan(b.imag):
                    return a.real < b.real
                else:
                    return False
            else:
                if np.isnan(b.imag):
                    return True
                else:
                    if a.real < b.real:
                        return True
                    elif a.real == b.real:
                        return a.imag < b.imag
                    return False
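
# --- Editor's sketch (illustrative, not part of the vendored module) ---
# lt_floats and lt_complex implement NumPy's sort order, where NaNs compare
# as the largest values so they end up at the end of a sorted array:

def _example_nan_ordering():
    import numpy as np

    assert lt_floats(1.0, np.nan)        # any number sorts before NaN
    assert not lt_floats(np.nan, 1.0)
    # For complex values the real parts are compared first, then the
    # imaginary parts; NaN in either part pushes a value towards the end.
    assert lt_complex(complex(1, 2), complex(1, 3))
    assert lt_complex(complex(1, 0), complex(np.nan, 0))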
@@ -0,0 +1,223 @@
from numba.extending import (models, register_model, type_callable,
                             unbox, NativeValue, make_attribute_wrapper, box,
                             lower_builtin)
from numba.core import types, cgutils
import warnings
from numba.core.errors import NumbaExperimentalFeatureWarning, NumbaValueError
from numpy.polynomial.polynomial import Polynomial
from contextlib import ExitStack
import numpy as np
from llvmlite import ir


@register_model(types.PolynomialType)
class PolynomialModel(models.StructModel):
    def __init__(self, dmm, fe_type):
        members = [
            ('coef', fe_type.coef),
            ('domain', fe_type.domain),
            ('window', fe_type.window)
            # Introduced in NumPy 1.24, maybe leave it out for now
            # ('symbol', types.string)
        ]
        super(PolynomialModel, self).__init__(dmm, fe_type, members)


@type_callable(Polynomial)
def type_polynomial(context):
    def typer(coef, domain=None, window=None):
        default_domain = types.Array(types.int64, 1, 'C')
        double_domain = types.Array(types.double, 1, 'C')
        default_window = types.Array(types.int64, 1, 'C')
        double_window = types.Array(types.double, 1, 'C')
        double_coef = types.Array(types.double, 1, 'C')

        warnings.warn("Polynomial class is experimental",
                      category=NumbaExperimentalFeatureWarning)

        if isinstance(coef, types.Array) and \
                all([a is None for a in (domain, window)]):
            if coef.ndim == 1:
                # If Polynomial(coef) is called, coef is cast to double dtype,
                # and domain and window are set to equal [-1, 1], i.e. have
                # integer dtype
                return types.PolynomialType(double_coef,
                                            default_domain,
                                            default_window,
                                            1)
            else:
                msg = 'Coefficient array is not 1-d'
                raise NumbaValueError(msg)
        elif all([isinstance(a, types.Array) for a in (coef, domain, window)]):
            if coef.ndim == 1:
                if all([a.ndim == 1 for a in (domain, window)]):
                    # If Polynomial(coef, domain, window) is called, then coef,
                    # domain and window are cast to double dtype
                    return types.PolynomialType(double_coef,
                                                double_domain,
                                                double_window,
                                                3)
                else:
                    msg = 'Coefficient array is not 1-d'
                    raise NumbaValueError(msg)
    return typer


make_attribute_wrapper(types.PolynomialType, 'coef', 'coef')
make_attribute_wrapper(types.PolynomialType, 'domain', 'domain')
make_attribute_wrapper(types.PolynomialType, 'window', 'window')
# Introduced in NumPy 1.24, maybe leave it out for now
# make_attribute_wrapper(types.PolynomialType, 'symbol', 'symbol')


@lower_builtin(Polynomial, types.Array)
def impl_polynomial1(context, builder, sig, args):

    def to_double(arr):
        return np.asarray(arr, dtype=np.double)

    def const_impl():
        return np.asarray([-1, 1])

    typ = sig.return_type
    polynomial = cgutils.create_struct_proxy(typ)(context, builder)
    sig_coef = sig.args[0].copy(dtype=types.double)(sig.args[0])
    coef_cast = context.compile_internal(builder, to_double, sig_coef, args)
    sig_domain = sig.args[0].copy(dtype=types.intp)()
    sig_window = sig.args[0].copy(dtype=types.intp)()
    domain_cast = context.compile_internal(builder, const_impl, sig_domain, ())
    window_cast = context.compile_internal(builder, const_impl, sig_window, ())
    polynomial.coef = coef_cast
    polynomial.domain = domain_cast
    polynomial.window = window_cast

    return polynomial._getvalue()


@lower_builtin(Polynomial, types.Array, types.Array, types.Array)
def impl_polynomial3(context, builder, sig, args):

    def to_double(coef):
        return np.asarray(coef, dtype=np.double)

    typ = sig.return_type
    polynomial = cgutils.create_struct_proxy(typ)(context, builder)

    coef_sig = sig.args[0].copy(dtype=types.double)(sig.args[0])
    domain_sig = sig.args[1].copy(dtype=types.double)(sig.args[1])
    window_sig = sig.args[2].copy(dtype=types.double)(sig.args[2])
    coef_cast = context.compile_internal(builder,
                                         to_double, coef_sig,
                                         (args[0],))
    domain_cast = context.compile_internal(builder,
                                           to_double, domain_sig,
                                           (args[1],))
    window_cast = context.compile_internal(builder,
                                           to_double, window_sig,
                                           (args[2],))

    domain_helper = context.make_helper(builder,
                                        domain_sig.return_type,
                                        value=domain_cast)
    window_helper = context.make_helper(builder,
                                        window_sig.return_type,
                                        value=window_cast)

    i64 = ir.IntType(64)
    two = i64(2)

    s1 = builder.extract_value(domain_helper.shape, 0)
    s2 = builder.extract_value(window_helper.shape, 0)
    pred1 = builder.icmp_signed('!=', s1, two)
    pred2 = builder.icmp_signed('!=', s2, two)

    with cgutils.if_unlikely(builder, pred1):
        context.call_conv.return_user_exc(
            builder, ValueError,
            ("Domain has wrong number of elements.",))

    with cgutils.if_unlikely(builder, pred2):
        context.call_conv.return_user_exc(
            builder, ValueError,
            ("Window has wrong number of elements.",))

    polynomial.coef = coef_cast
    polynomial.domain = domain_helper._getvalue()
    polynomial.window = window_helper._getvalue()

    return polynomial._getvalue()


@unbox(types.PolynomialType)
def unbox_polynomial(typ, obj, c):
    """
    Convert a Polynomial object to a native polynomial structure.
    """
    is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit)
    polynomial = cgutils.create_struct_proxy(typ)(c.context, c.builder)
    with ExitStack() as stack:
        natives = []
        for name in ("coef", "domain", "window"):
            attr = c.pyapi.object_getattr_string(obj, name)
            with cgutils.early_exit_if_null(c.builder, stack, attr):
                c.builder.store(cgutils.true_bit, is_error_ptr)
            t = getattr(typ, name)
            native = c.unbox(t, attr)
            c.pyapi.decref(attr)
            with cgutils.early_exit_if(c.builder, stack, native.is_error):
                c.builder.store(cgutils.true_bit, is_error_ptr)
            natives.append(native)

        polynomial.coef = natives[0]
        polynomial.domain = natives[1]
        polynomial.window = natives[2]

    return NativeValue(polynomial._getvalue(),
                       is_error=c.builder.load(is_error_ptr))


@box(types.PolynomialType)
def box_polynomial(typ, val, c):
    """
    Convert a native polynomial structure to a Polynomial object.
    """
    ret_ptr = cgutils.alloca_once(c.builder, c.pyapi.pyobj)
    fail_obj = c.pyapi.get_null_object()

    with ExitStack() as stack:
        polynomial = cgutils.create_struct_proxy(typ)(c.context, c.builder,
                                                      value=val)
        coef_obj = c.box(typ.coef, polynomial.coef)
        with cgutils.early_exit_if_null(c.builder, stack, coef_obj):
            c.builder.store(fail_obj, ret_ptr)

        domain_obj = c.box(typ.domain, polynomial.domain)
        with cgutils.early_exit_if_null(c.builder, stack, domain_obj):
            c.builder.store(fail_obj, ret_ptr)

        window_obj = c.box(typ.window, polynomial.window)
        with cgutils.early_exit_if_null(c.builder, stack, window_obj):
            c.builder.store(fail_obj, ret_ptr)

        class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Polynomial))
        with cgutils.early_exit_if_null(c.builder, stack, class_obj):
            c.pyapi.decref(coef_obj)
            c.pyapi.decref(domain_obj)
            c.pyapi.decref(window_obj)
            c.builder.store(fail_obj, ret_ptr)

        if typ.n_args == 1:
            res1 = c.pyapi.call_function_objargs(class_obj, (coef_obj,))
            c.builder.store(res1, ret_ptr)
        else:
            res3 = c.pyapi.call_function_objargs(class_obj, (coef_obj,
                                                             domain_obj,
                                                             window_obj))
            c.builder.store(res3, ret_ptr)

        c.pyapi.decref(coef_obj)
        c.pyapi.decref(domain_obj)
        c.pyapi.decref(window_obj)
        c.pyapi.decref(class_obj)

    return c.builder.load(ret_ptr)
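
# --- Editor's sketch (illustrative, not part of the vendored module) ---
# With the model, typer, lowering, boxing and unboxing above, np.polynomial's
# Polynomial can cross the JIT boundary. A sketch; note the class emits a
# NumbaExperimentalFeatureWarning when constructed in compiled code:

def _example_polynomial_in_njit():
    from numba import njit
    import numpy as np
    from numpy.polynomial.polynomial import Polynomial

    @njit
    def make_poly(c):
        # Coefficients are cast to float64; domain/window default to [-1, 1].
        return Polynomial(c)

    p = make_poly(np.array([1.0, 2.0, 3.0]))
    return p.coef  # array([1., 2., 3.])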
@@ -0,0 +1,375 @@
"""
Implementation of operations involving polynomials.
"""


import numpy as np
from numpy.polynomial import polynomial as poly
from numpy.polynomial import polyutils as pu

from numba import literal_unroll
from numba.core import types, errors
from numba.core.extending import overload
from numba.np.numpy_support import type_can_asarray, as_dtype, from_dtype


@overload(np.roots)
def roots_impl(p):

    # cast int vectors to float cf. numpy, this is a bit dicey as
    # the roots could be complex which will fail anyway
    ty = getattr(p, 'dtype', p)
    if isinstance(ty, types.Integer):
        cast_t = np.float64
    else:
        cast_t = as_dtype(ty)

    def roots_impl(p):
        # impl based on numpy:
        # https://github.com/numpy/numpy/blob/master/numpy/lib/polynomial.py

        if len(p.shape) != 1:
            raise ValueError("Input must be a 1d array.")

        non_zero = np.nonzero(p)[0]

        if len(non_zero) == 0:
            return np.zeros(0, dtype=cast_t)

        tz = len(p) - non_zero[-1] - 1

        # pull out the coeffs selecting between possible zero pads
        p = p[int(non_zero[0]):int(non_zero[-1]) + 1]

        n = len(p)
        if n > 1:
            # construct companion matrix, ensure fortran order
            # to give to eigvals, write to upper diag and then
            # transpose.
            A = np.diag(np.ones((n - 2,), cast_t), 1).T
            A[0, :] = -p[1:] / p[0]  # normalize
            roots = np.linalg.eigvals(A)
        else:
            roots = np.zeros(0, dtype=cast_t)

        # add in additional zeros on the end if needed
        if tz > 0:
            return np.hstack((roots, np.zeros(tz, dtype=cast_t)))
        else:
            return roots

    return roots_impl
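
# --- Editor's sketch (illustrative, not part of the vendored module) ---
# The implementation above mirrors np.roots: strip leading/trailing zeros,
# build the companion matrix of the polynomial, and take its eigenvalues.
# For p(x) = x**2 - 3x + 2 the coefficient vector is [1, -3, 2]:

def _example_roots():
    import numpy as np

    r = np.sort(np.roots(np.array([1.0, -3.0, 2.0])))
    # The companion-matrix eigenvalues are the roots 1 and 2.
    assert np.allclose(r, [1.0, 2.0])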

@overload(pu.trimseq)
def polyutils_trimseq(seq):
    if not type_can_asarray(seq):
        msg = 'The argument "seq" must be array-like'
        raise errors.TypingError(msg)

    if isinstance(seq, types.BaseTuple):
        msg = 'Unsupported type %r for argument "seq"'
        raise errors.TypingError(msg % (seq))

    if np.ndim(seq) > 1:
        msg = 'Coefficient array is not 1-d'
        raise errors.NumbaValueError(msg)

    def impl(seq):
        if len(seq) == 0:
            return seq
        else:
            for i in range(len(seq) - 1, -1, -1):
                if seq[i] != 0:
                    break
            return seq[:i + 1]

    return impl


@overload(pu.as_series)
def polyutils_as_series(alist, trim=True):
    if not type_can_asarray(alist):
        msg = 'The argument "alist" must be array-like'
        raise errors.TypingError(msg)

    if not isinstance(trim, (bool, types.Boolean)):
        msg = 'The argument "trim" must be boolean'
        raise errors.TypingError(msg)

    res_dtype = np.float64

    tuple_input = isinstance(alist, types.BaseTuple)
    list_input = isinstance(alist, types.List)
    if tuple_input:
        if np.any(np.array([np.ndim(a) > 1 for a in alist])):
            raise errors.NumbaValueError("Coefficient array is not 1-d")

        res_dtype = _poly_result_dtype(*alist)

    elif list_input:
        dt = as_dtype(_get_list_type(alist))
        res_dtype = np.result_type(dt, np.float64)

    else:
        if np.ndim(alist) <= 2:
            res_dtype = np.result_type(res_dtype, as_dtype(alist.dtype))
        else:
            # If total dimension has ndim > 2, then coeff arrays are not 1D
            raise errors.NumbaValueError("Coefficient array is not 1-d")

    def impl(alist, trim=True):
        if tuple_input:
            arrays = []
            for item in literal_unroll(alist):
                arrays.append(np.atleast_1d(np.asarray(item)).astype(res_dtype))

        elif list_input:
            arrays = [np.atleast_1d(np.asarray(a)).astype(res_dtype)
                      for a in alist]

        else:
            alist_arr = np.asarray(alist)
            arrays = [np.atleast_1d(np.asarray(a)).astype(res_dtype)
                      for a in alist_arr]

        if min([a.size for a in arrays]) == 0:
            raise ValueError("Coefficient array is empty")

        if trim:
            arrays = [pu.trimseq(a) for a in arrays]

        ret = arrays
        return ret

    return impl


def _get_list_type(l):
    # A helper function that takes a list (possibly nested) and returns its
    # dtype. Returns a Numba type.
    dt = l.dtype
    if (not isinstance(dt, types.Number)) and type_can_asarray(dt):
        return _get_list_type(dt)
    else:
        return dt


def _poly_result_dtype(*args):
    # A helper function that takes a tuple of inputs and returns their result
    # dtype. Used for poly functions. Returns a Numba type.
    res_dtype = np.float64
    for item in args:
        if isinstance(item, types.BaseTuple):
            s1 = item.types
        elif isinstance(item, types.List):
            s1 = [_get_list_type(item)]
        elif isinstance(item, types.Number):
            s1 = [item]
        elif isinstance(item, types.Array):
            s1 = [item.dtype]
        else:
            msg = 'Input dtype must be scalar'
            raise errors.TypingError(msg)

        try:
            l = [as_dtype(t) for t in s1]
            l.append(res_dtype)
            res_dtype = (np.result_type(*l))
        except errors.NumbaNotImplementedError:
            msg = 'Input dtype must be scalar.'
            raise errors.TypingError(msg)

    return from_dtype(res_dtype)


@overload(poly.polyadd)
def numpy_polyadd(c1, c2):
    if not type_can_asarray(c1):
        msg = 'The argument "c1" must be array-like'
        raise errors.TypingError(msg)

    if not type_can_asarray(c2):
        msg = 'The argument "c2" must be array-like'
        raise errors.TypingError(msg)

    def impl(c1, c2):
        arr1, arr2 = pu.as_series((c1, c2))
        diff = len(arr2) - len(arr1)
        if diff > 0:
            zr = np.zeros(diff)
            arr1 = np.concatenate((arr1, zr))
        if diff < 0:
            zr = np.zeros(-diff)
            arr2 = np.concatenate((arr2, zr))
        val = arr1 + arr2
        return pu.trimseq(val)

    return impl
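
# --- Editor's sketch (illustrative, not part of the vendored module) ---
# polyadd pads the shorter coefficient series with zeros, adds termwise, and
# trims trailing zeros, so (1 + 2x) + (1 - 2x + x**3) gives 2 + x**3:

def _example_polyadd():
    import numpy as np
    from numpy.polynomial import polynomial as poly

    out = poly.polyadd(np.array([1.0, 2.0]), np.array([1.0, -2.0, 0.0, 1.0]))
    assert np.allclose(out, [2.0, 0.0, 0.0, 1.0])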

@overload(poly.polysub)
def numpy_polysub(c1, c2):
    if not type_can_asarray(c1):
        msg = 'The argument "c1" must be array-like'
        raise errors.TypingError(msg)

    if not type_can_asarray(c2):
        msg = 'The argument "c2" must be array-like'
        raise errors.TypingError(msg)

    def impl(c1, c2):
        arr1, arr2 = pu.as_series((c1, c2))
        diff = len(arr2) - len(arr1)
        if diff > 0:
            zr = np.zeros(diff)
            arr1 = np.concatenate((arr1, zr))
        if diff < 0:
            zr = np.zeros(-diff)
            arr2 = np.concatenate((arr2, zr))
        val = arr1 - arr2
        return pu.trimseq(val)

    return impl


@overload(poly.polymul)
def numpy_polymul(c1, c2):
    if not type_can_asarray(c1):
        msg = 'The argument "c1" must be array-like'
        raise errors.TypingError(msg)

    if not type_can_asarray(c2):
        msg = 'The argument "c2" must be array-like'
        raise errors.TypingError(msg)

    def impl(c1, c2):
        arr1, arr2 = pu.as_series((c1, c2))
        val = np.convolve(arr1, arr2)
        return pu.trimseq(val)

    return impl


@overload(poly.polyval, prefer_literal=True)
def poly_polyval(x, c, tensor=True):
    if not type_can_asarray(x):
        msg = 'The argument "x" must be array-like'
        raise errors.TypingError(msg)

    if not type_can_asarray(c):
        msg = 'The argument "c" must be array-like'
        raise errors.TypingError(msg)

    if not isinstance(tensor, (bool, types.BooleanLiteral)):
        msg = 'The argument "tensor" must be boolean'
        raise errors.RequireLiteralValue(msg)

    res_dtype = _poly_result_dtype(c, x)

    # Simulate new_shape = (1,) * np.ndim(x) in the general case
    # If x is a number, new_shape is not used
    # If x is a tuple or a list, then it's 1d hence new_shape=(1,)
    x_nd_array = not isinstance(x, types.Number)
    new_shape = (1,)
    if isinstance(x, types.Array):
        # If x is a np.array, then take its dimension
        new_shape = (1,) * np.ndim(x)

    if isinstance(tensor, bool):
        tensor_arg = tensor
    else:
        tensor_arg = tensor.literal_value

    def impl(x, c, tensor=True):
        arr = np.asarray(c).astype(res_dtype)
        inputs = np.asarray(x).astype(res_dtype)
        if x_nd_array and tensor_arg:
            arr = arr.reshape(arr.shape + new_shape)

        l = len(arr)
        y = arr[l - 1] + inputs * 0

        for i in range(l - 1, 0, -1):
            y = arr[i - 1] + y * inputs

        return y

    return impl
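
# --- Editor's sketch (illustrative, not part of the vendored module) ---
# polyval evaluates by Horner's rule, iterating the coefficients from the
# highest degree down: y = c[n-1]; then y = c[i-1] + y*x. For c = [1, 2, 3]
# (i.e. 1 + 2x + 3x**2) at x = 2 this gives ((3*2) + 2)*2 + 1 = 17:

def _example_polyval():
    import numpy as np
    from numpy.polynomial import polynomial as poly

    assert poly.polyval(2.0, np.array([1.0, 2.0, 3.0])) == 17.0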

@overload(poly.polyint)
def poly_polyint(c, m=1):

    if not type_can_asarray(c):
        msg = 'The argument "c" must be array-like'
        raise errors.TypingError(msg)

    if not isinstance(m, (int, types.Integer)):
        msg = 'The argument "m" must be an integer'
        raise errors.TypingError(msg)

    res_dtype = as_dtype(_poly_result_dtype(c))

    if not np.issubdtype(res_dtype, np.number):
        msg = f'Input dtype must be scalar. Found {res_dtype} instead'
        raise errors.TypingError(msg)

    is1D = ((np.ndim(c) == 1) or
            (isinstance(c, (types.List, types.BaseTuple))
             and isinstance(c.dtype, types.Number)))

    def impl(c, m=1):
        c = np.asarray(c).astype(res_dtype)
        cdt = c.dtype
        for i in range(m):
            n = len(c)

            tmp = np.empty((n + 1,) + c.shape[1:], dtype=cdt)
            tmp[0] = c[0] * 0
            tmp[1] = c[0]
            for j in range(1, n):
                tmp[j + 1] = c[j] / (j + 1)
            c = tmp
        if is1D:
            return pu.trimseq(c)
        else:
            return c

    return impl


@overload(poly.polydiv)
def numpy_polydiv(c1, c2):
    if not type_can_asarray(c1):
        msg = 'The argument "c1" must be array-like'
        raise errors.TypingError(msg)

    if not type_can_asarray(c2):
        msg = 'The argument "c2" must be array-like'
        raise errors.TypingError(msg)

    def impl(c1, c2):
        arr1, arr2 = pu.as_series((c1, c2))
        if arr2[-1] == 0:
            raise ZeroDivisionError()

        l1 = len(arr1)
        l2 = len(arr2)
        if l1 < l2:
            return arr1[:1] * 0, arr1
        elif l2 == 1:
            return arr1 / arr2[-1], arr1[:1] * 0
        else:
            dlen = l1 - l2
            scl = arr2[-1]
            arr2 = arr2[:-1] / scl
            i = dlen
            j = l1 - 1
            while i >= 0:
                arr1[i:j] -= arr2 * arr1[j]
                i -= 1
                j -= 1
            return arr1[j + 1:] / scl, pu.trimseq(arr1[:j + 1])

    return impl
@@ -0,0 +1,740 @@
|
||||
"""
|
||||
Algorithmic implementations for generating different types
|
||||
of random distributions.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from numba.core.extending import register_jitable
|
||||
from numba.np.random._constants import (wi_double, ki_double,
|
||||
ziggurat_nor_r, fi_double,
|
||||
wi_float, ki_float,
|
||||
ziggurat_nor_inv_r_f,
|
||||
ziggurat_nor_r_f, fi_float,
|
||||
we_double, ke_double,
|
||||
ziggurat_exp_r, fe_double,
|
||||
we_float, ke_float,
|
||||
ziggurat_exp_r_f, fe_float,
|
||||
INT64_MAX, ziggurat_nor_inv_r)
|
||||
from numba.np.random.generator_core import (next_double, next_float,
|
||||
next_uint32, next_uint64)
|
||||
from numba import float32, int64
|
||||
from numba.np.numpy_support import numpy_version
|
||||
# All of the following implementations are direct translations from:
|
||||
# https://github.com/numpy/numpy/blob/7cfef93c77599bd387ecc6a15d186c5a46024dac/numpy/random/src/distributions/distributions.c
|
||||
|
||||
|
||||
@register_jitable
|
||||
def np_log1p(x):
|
||||
return np.log1p(x)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def np_log1pf(x):
|
||||
return np.log1p(float32(x))
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_rayleigh(bitgen, mode):
|
||||
return mode * np.sqrt(2.0 * random_standard_exponential(bitgen))
|
||||
|
||||
|
||||
@register_jitable
|
||||
def np_expm1(x):
|
||||
return np.expm1(x)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_standard_normal(bitgen):
|
||||
while 1:
|
||||
r = next_uint64(bitgen)
|
||||
idx = r & 0xff
|
||||
r >>= 8
|
||||
sign = r & 0x1
|
||||
rabs = (r >> 1) & 0x000fffffffffffff
|
||||
x = rabs * wi_double[idx]
|
||||
if (sign & 0x1):
|
||||
x = -x
|
||||
if rabs < ki_double[idx]:
|
||||
return x
|
||||
if idx == 0:
|
||||
while 1:
|
||||
xx = -ziggurat_nor_inv_r * np.log1p(-next_double(bitgen))
|
||||
yy = -np.log1p(-next_double(bitgen))
|
||||
if (yy + yy > xx * xx):
|
||||
if ((rabs >> 8) & 0x1):
|
||||
return -(ziggurat_nor_r + xx)
|
||||
else:
|
||||
return ziggurat_nor_r + xx
|
||||
else:
|
||||
if (((fi_double[idx - 1] - fi_double[idx]) *
|
||||
next_double(bitgen) + fi_double[idx]) <
|
||||
np.exp(-0.5 * x * x)):
|
||||
return x
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_standard_normal_f(bitgen):
|
||||
while 1:
|
||||
r = next_uint32(bitgen)
|
||||
idx = r & 0xff
|
||||
sign = (r >> 8) & 0x1
|
||||
rabs = (r >> 9) & 0x0007fffff
|
||||
x = float32(float32(rabs) * wi_float[idx])
|
||||
if (sign & 0x1):
|
||||
x = -x
|
||||
if (rabs < ki_float[idx]):
|
||||
return x
|
||||
if (idx == 0):
|
||||
while 1:
|
||||
xx = float32(-ziggurat_nor_inv_r_f *
|
||||
np_log1pf(-next_float(bitgen)))
|
||||
yy = float32(-np_log1pf(-next_float(bitgen)))
|
||||
if (float32(yy + yy) > float32(xx * xx)):
|
||||
if ((rabs >> 8) & 0x1):
|
||||
return -float32(ziggurat_nor_r_f + xx)
|
||||
else:
|
||||
return float32(ziggurat_nor_r_f + xx)
|
||||
else:
|
||||
if (((fi_float[idx - 1] - fi_float[idx]) * next_float(bitgen) +
|
||||
fi_float[idx]) < float32(np.exp(-float32(0.5) * x * x))):
|
||||
return x
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_standard_exponential(bitgen):
|
||||
while 1:
|
||||
ri = next_uint64(bitgen)
|
||||
ri >>= 3
|
||||
idx = ri & 0xFF
|
||||
ri >>= 8
|
||||
x = ri * we_double[idx]
|
||||
if (ri < ke_double[idx]):
|
||||
return x
|
||||
else:
|
||||
if idx == 0:
|
||||
return ziggurat_exp_r - np_log1p(-next_double(bitgen))
|
||||
elif ((fe_double[idx - 1] - fe_double[idx]) * next_double(bitgen) +
|
||||
fe_double[idx] < np.exp(-x)):
|
||||
return x
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_standard_exponential_f(bitgen):
|
||||
while 1:
|
||||
ri = next_uint32(bitgen)
|
||||
ri >>= 1
|
||||
idx = ri & 0xFF
|
||||
ri >>= 8
|
||||
x = float32(float32(ri) * we_float[idx])
|
||||
if (ri < ke_float[idx]):
|
||||
return x
|
||||
else:
|
||||
if (idx == 0):
|
||||
return float32(ziggurat_exp_r_f -
|
||||
float32(np_log1pf(-next_float(bitgen))))
|
||||
elif ((fe_float[idx - 1] - fe_float[idx]) * next_float(bitgen) +
|
||||
fe_float[idx] < float32(np.exp(float32(-x)))):
|
||||
return x
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_standard_exponential_inv(bitgen):
|
||||
return -np_log1p(-next_double(bitgen))
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_standard_exponential_inv_f(bitgen):
|
||||
return -np.log(float32(1.0) - next_float(bitgen))
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_standard_gamma(bitgen, shape):
|
||||
if (shape == 1.0):
|
||||
return random_standard_exponential(bitgen)
|
||||
elif (shape == 0.0):
|
||||
return 0.0
|
||||
elif (shape < 1.0):
|
||||
while 1:
|
||||
U = next_double(bitgen)
|
||||
V = random_standard_exponential(bitgen)
|
||||
if (U <= 1.0 - shape):
|
||||
X = pow(U, 1. / shape)
|
||||
if (X <= V):
|
||||
return X
|
||||
else:
|
||||
Y = -np.log((1 - U) / shape)
|
||||
X = pow(1.0 - shape + shape * Y, 1. / shape)
|
||||
if (X <= (V + Y)):
|
||||
return X
|
||||
else:
|
||||
b = shape - 1. / 3.
|
||||
c = 1. / np.sqrt(9 * b)
|
||||
while 1:
|
||||
while 1:
|
||||
X = random_standard_normal(bitgen)
|
||||
V = 1.0 + c * X
|
||||
if (V > 0.0):
|
||||
break
|
||||
|
||||
V = V * V * V
|
||||
U = next_double(bitgen)
|
||||
if (U < 1.0 - 0.0331 * (X * X) * (X * X)):
|
||||
return (b * V)
|
||||
|
||||
if (np.log(U) < 0.5 * X * X + b * (1. - V + np.log(V))):
|
||||
return (b * V)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_standard_gamma_f(bitgen, shape):
|
||||
f32_one = float32(1.0)
|
||||
shape = float32(shape)
|
||||
if (shape == f32_one):
|
||||
return random_standard_exponential_f(bitgen)
|
||||
elif (shape == float32(0.0)):
|
||||
return float32(0.0)
|
||||
elif (shape < f32_one):
|
||||
while 1:
|
||||
U = next_float(bitgen)
|
||||
V = random_standard_exponential_f(bitgen)
|
||||
if (U <= f32_one - shape):
|
||||
X = float32(pow(U, float32(f32_one / shape)))
|
||||
if (X <= V):
|
||||
return X
|
||||
else:
|
||||
Y = float32(-np.log(float32((f32_one - U) / shape)))
|
||||
X = float32(pow(f32_one - shape + float32(shape * Y),
|
||||
float32(f32_one / shape)))
|
||||
if (X <= (V + Y)):
|
||||
return X
|
||||
else:
|
||||
b = shape - f32_one / float32(3.0)
|
||||
c = float32(f32_one / float32(np.sqrt(float32(9.0) * b)))
|
||||
while 1:
|
||||
while 1:
|
||||
X = float32(random_standard_normal_f(bitgen))
|
||||
V = float32(f32_one + c * X)
|
||||
if (V > float32(0.0)):
|
||||
break
|
||||
|
||||
V = float32(V * V * V)
|
||||
U = next_float(bitgen)
|
||||
if (U < f32_one - float32(0.0331) * (X * X) * (X * X)):
|
||||
return float32(b * V)
|
||||
|
||||
if (np.log(U) < float32(0.5) * X * X + b *
|
||||
(f32_one - V + np.log(V))):
|
||||
return float32(b * V)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_normal(bitgen, loc, scale):
|
||||
scaled_normal = scale * random_standard_normal(bitgen)
|
||||
return loc + scaled_normal
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_normal_f(bitgen, loc, scale):
|
||||
scaled_normal = float32(scale * random_standard_normal_f(bitgen))
|
||||
return float32(loc + scaled_normal)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_exponential(bitgen, scale):
|
||||
return scale * random_standard_exponential(bitgen)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_uniform(bitgen, lower, range):
|
||||
scaled_uniform = range * next_double(bitgen)
|
||||
return lower + scaled_uniform
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_gamma(bitgen, shape, scale):
|
||||
return scale * random_standard_gamma(bitgen, shape)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_gamma_f(bitgen, shape, scale):
|
||||
return float32(scale * random_standard_gamma_f(bitgen, shape))
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_beta(bitgen, a, b):
|
||||
if a <= 1.0 and b <= 1.0:
|
||||
while 1:
|
||||
U = next_double(bitgen)
|
||||
V = next_double(bitgen)
|
||||
X = pow(U, 1.0 / a)
|
||||
Y = pow(V, 1.0 / b)
|
||||
XpY = X + Y
|
||||
if XpY <= 1.0 and XpY > 0.0:
|
||||
if (X + Y > 0):
|
||||
return X / XpY
|
||||
else:
|
||||
logX = np.log(U) / a
|
||||
logY = np.log(V) / b
|
||||
logM = min(logX, logY)
|
||||
logX -= logM
|
||||
logY -= logM
|
||||
|
||||
return np.exp(logX - np.log(np.exp(logX) + np.exp(logY)))
|
||||
else:
|
||||
Ga = random_standard_gamma(bitgen, a)
|
||||
Gb = random_standard_gamma(bitgen, b)
|
||||
return Ga / (Ga + Gb)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def random_chisquare(bitgen, df):
|
||||
return 2.0 * random_standard_gamma(bitgen, df / 2.0)


@register_jitable
def random_f(bitgen, dfnum, dfden):
    return ((random_chisquare(bitgen, dfnum) * dfden) /
            (random_chisquare(bitgen, dfden) * dfnum))


@register_jitable
def random_standard_cauchy(bitgen):
    return random_standard_normal(bitgen) / random_standard_normal(bitgen)


@register_jitable
def random_pareto(bitgen, a):
    return np_expm1(random_standard_exponential(bitgen) / a)


@register_jitable
def random_weibull(bitgen, a):
    if (a == 0.0):
        return 0.0
    return pow(random_standard_exponential(bitgen), 1. / a)


@register_jitable
def random_power(bitgen, a):
    return pow(-np_expm1(-random_standard_exponential(bitgen)), 1. / a)


@register_jitable
def random_laplace(bitgen, loc, scale):
    U = next_double(bitgen)
    while U <= 0:
        U = next_double(bitgen)
    if (U >= 0.5):
        U = loc - scale * np.log(2.0 - U - U)
    elif (U > 0.0):
        U = loc + scale * np.log(U + U)
    return U


@register_jitable
def random_logistic(bitgen, loc, scale):
    U = next_double(bitgen)
    while U <= 0.0:
        U = next_double(bitgen)
    return loc + scale * np.log(U / (1.0 - U))


@register_jitable
def random_lognormal(bitgen, mean, sigma):
    return np.exp(random_normal(bitgen, mean, sigma))


@register_jitable
def random_standard_t(bitgen, df):
    num = random_standard_normal(bitgen)
    denom = random_standard_gamma(bitgen, df / 2)
    return np.sqrt(df / 2) * num / np.sqrt(denom)


@register_jitable
def random_wald(bitgen, mean, scale):
    mu_2l = mean / (2 * scale)
    Y = random_standard_normal(bitgen)
    Y = mean * Y * Y
    X = mean + mu_2l * (Y - np.sqrt(4 * scale * Y + Y * Y))
    U = next_double(bitgen)
    if (U <= mean / (mean + X)):
        return X
    else:
        return mean * mean / X


@register_jitable
def random_geometric_search(bitgen, p):
    X = 1
    sum = prod = p
    q = 1.0 - p
    U = next_double(bitgen)
    while (U > sum):
        prod *= q
        sum += prod
        X = X + 1
    return X


@register_jitable
def random_geometric_inversion(bitgen, p):
    return np.ceil(-random_standard_exponential(bitgen) / np.log1p(-p))
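
# Illustrative sketch, not part of the NumPy translation: the inversion
# method above computes ceil(E / -log1p(-p)) with E ~ Exponential(1), i.e.
# it inverts the geometric CDF 1 - (1 - p)**k. A pure-NumPy equivalent
# (helper name hypothetical):
def _sketch_geometric_inversion(seed=0, p=0.25, n=100000):
    rng = np.random.default_rng(seed)
    k = np.ceil(rng.standard_exponential(n) / -np.log1p(-p))
    # The sample mean should be close to the geometric mean 1 / p.
    return k.mean()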


@register_jitable
def random_geometric(bitgen, p):
    if (p >= 0.333333333333333333333333):
        return random_geometric_search(bitgen, p)
    else:
        return random_geometric_inversion(bitgen, p)


if numpy_version < (2, 1):
    @register_jitable
    def random_zipf(bitgen, a):
        am1 = a - 1.0
        b = pow(2.0, am1)
        while 1:
            U = 1.0 - next_double(bitgen)
            V = next_double(bitgen)
            X = np.floor(pow(U, -1.0 / am1))
            if (X > INT64_MAX or X < 1.0):
                continue
            T = pow(1.0 + 1.0 / X, am1)
            if (V * X * (T - 1.0) / (b - 1.0) <= T / b):
                return X
else:
    @register_jitable
    def random_zipf(bitgen, a):
        am1 = a - 1.0
        b = pow(2.0, am1)
        Umin = pow(INT64_MAX, -am1)
        while 1:
            U01 = next_double(bitgen)
            U = U01 * Umin + (1 - U01)
            V = next_double(bitgen)
            X = np.floor(pow(U, -1.0 / am1))
            if (X > INT64_MAX or X < 1.0):
                continue

            T = pow(1.0 + 1.0 / X, am1)
            if (V * X * (T - 1.0) / (b - 1.0) <= T / b):
                return X


@register_jitable
def random_triangular(bitgen, left, mode,
                      right):
    base = right - left
    leftbase = mode - left
    ratio = leftbase / base
    leftprod = leftbase * base
    rightprod = (right - mode) * base

    U = next_double(bitgen)
    if (U <= ratio):
        return left + np.sqrt(U * leftprod)
    else:
        return right - np.sqrt((1.0 - U) * rightprod)


@register_jitable
def random_loggam(x):
    a = [8.333333333333333e-02, -2.777777777777778e-03,
         7.936507936507937e-04, -5.952380952380952e-04,
         8.417508417508418e-04, -1.917526917526918e-03,
         6.410256410256410e-03, -2.955065359477124e-02,
         1.796443723688307e-01, -1.39243221690590e+00]

    if ((x == 1.0) or (x == 2.0)):
        return 0.0
    elif (x < 7.0):
        n = int(7 - x)
    else:
        n = 0

    x0 = x + n
    x2 = (1.0 / x0) * (1.0 / x0)
    # /* log(2 * M_PI) */
    lg2pi = 1.8378770664093453e+00
    gl0 = a[9]

    for k in range(0, 9):
        gl0 *= x2
        gl0 += a[8 - k]

    gl = gl0 / x0 + 0.5 * lg2pi + (x0 - 0.5) * np.log(x0) - x0
    if (x < 7.0):
        for k in range(1, n + 1):
            gl = gl - np.log(x0 - 1.0)
            x0 = x0 - 1.0

    return gl


@register_jitable
def random_poisson_mult(bitgen, lam):
    enlam = np.exp(-lam)
    X = 0
    prod = 1.0
    while (1):
        U = next_double(bitgen)
        prod *= U
        if (prod > enlam):
            X += 1
        else:
            return X


@register_jitable
def random_poisson_ptrs(bitgen, lam):

    slam = np.sqrt(lam)
    loglam = np.log(lam)
    b = 0.931 + 2.53 * slam
    a = -0.059 + 0.02483 * b
    invalpha = 1.1239 + 1.1328 / (b - 3.4)
    vr = 0.9277 - 3.6224 / (b - 2)

    while (1):
        U = next_double(bitgen) - 0.5
        V = next_double(bitgen)
        us = 0.5 - np.fabs(U)
        k = int((2 * a / us + b) * U + lam + 0.43)
        if ((us >= 0.07) and (V <= vr)):
            return k

        if ((k < 0) or ((us < 0.013) and (V > us))):
            continue

        # /* log(V) == log(0.0) ok here */
        # /* if U==0.0 so that us==0.0, log is ok since always returns */
        if ((np.log(V) + np.log(invalpha) - np.log(a / (us * us) + b)) <=
                (-lam + k * loglam - random_loggam(k + 1))):
            return k


@register_jitable
def random_poisson(bitgen, lam):
    if (lam >= 10):
        return random_poisson_ptrs(bitgen, lam)
    elif (lam == 0):
        return 0
    else:
        return random_poisson_mult(bitgen, lam)


@register_jitable
def random_negative_binomial(bitgen, n, p):
    Y = random_gamma(bitgen, n, (1 - p) / p)
    return random_poisson(bitgen, Y)


@register_jitable
def random_noncentral_chisquare(bitgen, df, nonc):
    if np.isnan(nonc):
        return np.nan

    if nonc == 0:
        return random_chisquare(bitgen, df)

    if 1 < df:
        Chi2 = random_chisquare(bitgen, df - 1)
        n = random_standard_normal(bitgen) + np.sqrt(nonc)
        return Chi2 + n * n
    else:
        i = random_poisson(bitgen, nonc / 2.0)
        return random_chisquare(bitgen, df + 2 * i)


@register_jitable
def random_noncentral_f(bitgen, dfnum, dfden, nonc):
    t = random_noncentral_chisquare(bitgen, dfnum, nonc) * dfden
    return t / (random_chisquare(bitgen, dfden) * dfnum)


@register_jitable
def random_logseries(bitgen, p):
    r = np_log1p(-p)

    while 1:
        V = next_double(bitgen)
        if (V >= p):
            return 1
        U = next_double(bitgen)
        q = -np.expm1(r * U)
        if (V <= q * q):
            result = int64(np.floor(1 + np.log(V) / np.log(q)))
            if result < 1 or V == 0.0:
                continue
            else:
                return result
        if (V >= q):
            return 1
        else:
            return 2


@register_jitable
def random_binomial_btpe(bitgen, n, p):
    r = min(p, 1.0 - p)
    q = 1.0 - r
    fm = n * r + r
    m = int(np.floor(fm))
    p1 = np.floor(2.195 * np.sqrt(n * r * q) - 4.6 * q) + 0.5
    xm = m + 0.5
    xl = xm - p1
    xr = xm + p1
    c = 0.134 + 20.5 / (15.3 + m)
    a = (fm - xl) / (fm - xl * r)
    laml = a * (1.0 + a / 2.0)
    a = (xr - fm) / (xr * q)
    lamr = a * (1.0 + a / 2.0)
    p2 = p1 * (1.0 + 2.0 * c)
    p3 = p2 + c / laml
    p4 = p3 + c / lamr

    case = 10
    y = k = 0
    while 1:
        if case == 10:
            nrq = n * r * q
            u = next_double(bitgen) * p4
            v = next_double(bitgen)
            if (u > p1):
                case = 20
                continue
            y = int(np.floor(xm - p1 * v + u))
            case = 60
            continue
        elif case == 20:
            if (u > p2):
                case = 30
                continue
            x = xl + (u - p1) / c
            v = v * c + 1.0 - np.fabs(m - x + 0.5) / p1
            if (v > 1.0):
                case = 10
                continue
            y = int(np.floor(x))
            case = 50
            continue
        elif case == 30:
            if (u > p3):
                case = 40
                continue
            y = int(np.floor(xl + np.log(v) / laml))
            if ((y < 0) or (v == 0.0)):
                case = 10
                continue
            v = v * (u - p2) * laml
            case = 50
            continue
        elif case == 40:
            y = int(np.floor(xr - np.log(v) / lamr))
            if ((y > n) or (v == 0.0)):
                case = 10
                continue
            v = v * (u - p3) * lamr
            case = 50
            continue
        elif case == 50:
            k = abs(y - m)
            if ((k > 20) and (k < ((nrq) / 2.0 - 1))):
                case = 52
                continue
            s = r / q
            a = s * (n + 1)
            F = 1.0
            if (m < y):
                for i in range(m + 1, y + 1):
                    F = F * (a / i - s)
            elif (m > y):
                for i in range(y + 1, m + 1):
                    F = F / (a / i - s)
            if (v > F):
                case = 10
                continue
            case = 60
            continue
        elif case == 52:
            rho = (k / (nrq)) * \
                ((k * (k / 3.0 + 0.625) + 0.16666666666666666) /
                 nrq + 0.5)
            t = -k * k / (2 * nrq)
            A = np.log(v)
            if (A < (t - rho)):
                case = 60
                continue
            if (A > (t + rho)):
                case = 10
                continue
            x1 = y + 1
            f1 = m + 1
            z = n + 1 - m
            w = n - y + 1
            x2 = x1 * x1
            f2 = f1 * f1
            z2 = z * z
            w2 = w * w
            if (A > (xm * np.log(f1 / x1) + (n - m + 0.5) * np.log(z / w) +
                     (y - m) * np.log(w * r / (x1 * q)) +
                     (13680. - (462. - (132. - (99. - 140. / f2) / f2) / f2)
                      / f2) / f1 / 166320. +
                     (13680. - (462. - (132. - (99. - 140. / z2) / z2) / z2)
                      / z2) / z / 166320. +
                     (13680. - (462. - (132. - (99. - 140. / x2) / x2) / x2)
                      / x2) / x1 / 166320. +
                     (13680. - (462. - (132. - (99. - 140. / w2) / w2) / w2)
                      / w2) / w / 166320.)):
                case = 10
                continue
            case = 60
            continue
        elif case == 60:
            if (p > 0.5):
                y = n - y
            return y


@register_jitable
def random_binomial_inversion(bitgen, n, p):
    q = 1.0 - p
    qn = np.exp(n * np.log(q))
    _np = n * p
    bound = min(n, _np + 10.0 * np.sqrt(_np * q + 1))

    X = 0
    px = qn
    U = next_double(bitgen)
    while (U > px):
        X = X + 1
        if (X > bound):
            X = 0
            px = qn
            U = next_double(bitgen)
        else:
            U -= px
            px = ((n - X + 1) * p * px) / (X * q)

    return X


@register_jitable
def random_binomial(bitgen, n, p):
    if ((n == 0) or (p == 0.0)):
        return 0

    if (p <= 0.5):
        if (p * n <= 30.0):
            return random_binomial_inversion(bitgen, n, p)
        else:
            return random_binomial_btpe(bitgen, n, p)
    else:
        q = 1.0 - p
        if (q * n <= 30.0):
            return n - random_binomial_inversion(bitgen, n, q)
        else:
            return n - random_binomial_btpe(bitgen, n, q)
@@ -0,0 +1,120 @@
"""
Core Implementations for Generator/BitGenerator Models.
"""

from llvmlite import ir
from numba.core import cgutils, types
from numba.core.extending import (intrinsic, make_attribute_wrapper, models,
                                  overload, register_jitable,
                                  register_model)


@register_model(types.NumPyRandomBitGeneratorType)
class NumPyRngBitGeneratorModel(models.StructModel):
    def __init__(self, dmm, fe_type):
        members = [
            ('parent', types.pyobject),
            ('state_address', types.uintp),
            ('state', types.uintp),
            ('fnptr_next_uint64', types.uintp),
            ('fnptr_next_uint32', types.uintp),
            ('fnptr_next_double', types.uintp),
            ('bit_generator', types.uintp),
        ]
        super(NumPyRngBitGeneratorModel, self).__init__(dmm, fe_type, members)


_bit_gen_type = types.NumPyRandomBitGeneratorType('bit_generator')


@register_model(types.NumPyRandomGeneratorType)
class NumPyRandomGeneratorTypeModel(models.StructModel):
    def __init__(self, dmm, fe_type):
        members = [
            ('bit_generator', _bit_gen_type),
            ('meminfo', types.MemInfoPointer(types.voidptr)),
            ('parent', types.pyobject)
        ]
        super(NumPyRandomGeneratorTypeModel, self).__init__(dmm, fe_type,
                                                            members)


# The Generator instances have a bit_generator attr
make_attribute_wrapper(
    types.NumPyRandomGeneratorType,
    'bit_generator',
    'bit_generator')


def _generate_next_binding(overloadable_function, return_type):
    """
    Generate the overloads for "next_(some type)" functions.
    """
    @intrinsic
    def intrin_NumPyRandomBitGeneratorType_next_ty(tyctx, inst):
        sig = return_type(inst)

        def codegen(cgctx, builder, sig, llargs):
            name = overloadable_function.__name__
            struct_ptr = cgutils.create_struct_proxy(inst)(cgctx, builder,
                                                           value=llargs[0])

            # Get the 'state' and 'fnptr_next_(type)' members of the struct
            state = struct_ptr.state
            next_double_addr = getattr(struct_ptr, f'fnptr_{name}')

            # LLVM IR types needed
            ll_void_ptr_t = cgctx.get_value_type(types.voidptr)
            ll_return_t = cgctx.get_value_type(return_type)
            ll_uintp_t = cgctx.get_value_type(types.uintp)

            # Convert the stored Generator function address to a pointer
            next_fn_fnptr = builder.inttoptr(
                next_double_addr, ll_void_ptr_t)
            # Add the function to the module
            fnty = ir.FunctionType(ll_return_t, (ll_uintp_t,))
            next_fn = cgutils.get_or_insert_function(
                builder.module, fnty, name)
            # Bit cast the function pointer to the function type
            fnptr_as_fntype = builder.bitcast(next_fn_fnptr, next_fn.type)
            # call it with the "state" address as the arg
            ret = builder.call(fnptr_as_fntype, (state,))
            return ret
        return sig, codegen

    @overload(overloadable_function)
    def ol_next_ty(bitgen):
        if isinstance(bitgen, types.NumPyRandomBitGeneratorType):
            def impl(bitgen):
                return intrin_NumPyRandomBitGeneratorType_next_ty(bitgen)
            return impl


# Some function stubs for "next(some type)", these will be overloaded
def next_double(bitgen):
    return bitgen.ctypes.next_double(bitgen.ctypes.state)


def next_uint32(bitgen):
    return bitgen.ctypes.next_uint32(bitgen.ctypes.state)


def next_uint64(bitgen):
    return bitgen.ctypes.next_uint64(bitgen.ctypes.state)


_generate_next_binding(next_double, types.double)
_generate_next_binding(next_uint32, types.uint32)
_generate_next_binding(next_uint64, types.uint64)


# See: https://github.com/numpy/numpy/pull/20314
@register_jitable
def next_float(bitgen):
    return types.float32(types.float32(next_uint32(bitgen) >> 8)
                         * types.float32(1.0)
                         / types.float32(16777216.0))
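

# Illustrative sketch, not part of the vendored module: next_float keeps the
# top 24 bits of a 32-bit draw and scales by 2**-24, so the result is a
# float32-representable uniform on [0, 1). Pure-Python equivalent, where
# `u32` stands for one next_uint32() draw (name hypothetical):
def _sketch_next_float(u32):
    return (u32 >> 8) * (1.0 / 16777216.0)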
@@ -0,0 +1,971 @@
"""
Implementation of method overloads for Generator objects.
"""

import numpy as np
from numba.core import types
from numba.core.extending import overload_method, register_jitable
from numba.np.numpy_support import as_dtype, from_dtype
from numba.np.random.generator_core import next_float, next_double
from numba.np.numpy_support import is_nonelike
from numba.core.errors import TypingError
from numba.core.types.containers import Tuple, UniTuple
from numba.np.random.distributions import \
    (random_standard_exponential_inv_f, random_standard_exponential_inv,
     random_standard_exponential, random_standard_normal_f,
     random_standard_gamma, random_standard_normal, random_uniform,
     random_standard_exponential_f, random_standard_gamma_f, random_normal,
     random_exponential, random_gamma, random_beta, random_power,
     random_f, random_chisquare, random_standard_cauchy, random_pareto,
     random_weibull, random_laplace, random_logistic,
     random_lognormal, random_rayleigh, random_standard_t, random_wald,
     random_geometric, random_zipf, random_triangular,
     random_poisson, random_negative_binomial, random_logseries,
     random_noncentral_chisquare, random_noncentral_f, random_binomial)
from numba.np.random import random_methods


def _get_proper_func(func_32, func_64, dtype, dist_name="the given"):
    """
    Most of the standard NumPy distributions that accept a dtype argument
    only support either np.float32 or np.float64 as dtypes.

    This is a helper function that helps Numba select the proper underlying
    implementation according to the provided dtype.
    """
    if isinstance(dtype, types.Omitted):
        dtype = dtype.value

    np_dt = dtype
    if isinstance(dtype, type):
        nb_dt = from_dtype(np.dtype(dtype))
    elif isinstance(dtype, types.NumberClass):
        nb_dt = dtype
        np_dt = as_dtype(nb_dt)

    if np_dt not in [np.float32, np.float64]:
        raise TypingError("Argument dtype is not one of the" +
                          " expected type(s): " +
                          " np.float32 or np.float64")

    if np_dt == np.float32:
        next_func = func_32
    else:
        next_func = func_64

    return next_func, nb_dt
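
# Illustrative usage sketch, not part of the vendored module: the overloads
# below call this helper once at typing time, so the compiled body carries
# no dtype branch, e.g.
#
#     dist_func, nb_dt = _get_proper_func(random_standard_normal_f,
#                                         random_standard_normal,
#                                         np.float32)
#     # dist_func is the float32 implementation, nb_dt is types.float32.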


def check_size(size):
    if not any([isinstance(size, UniTuple) and
                isinstance(size.dtype, types.Integer),
                isinstance(size, Tuple) and size.count == 0,
                isinstance(size, types.Integer)]):
        raise TypingError("Argument size is not one of the" +
                          " expected type(s): " +
                          " an integer, an empty tuple or a tuple of integers")


def check_types(obj, type_list, arg_name):
    """
    Check if the given object is one of the provided types.
    If not, raises a TypingError.
    """
    if isinstance(obj, types.Omitted):
        obj = obj.value

    if not isinstance(type_list, (list, tuple)):
        type_list = [type_list]

    if not any([isinstance(obj, _type) for _type in type_list]):
        raise TypingError(f"Argument {arg_name} is not one of the" +
                          f" expected type(s): {type_list}")


# Overload the Generator().integers()
@overload_method(types.NumPyRandomGeneratorType, 'integers')
def NumPyRandomGeneratorType_integers(inst, low, high, size=None,
                                      dtype=np.int64, endpoint=False):
    check_types(low, [types.Integer,
                      types.Boolean, bool, int], 'low')
    check_types(high, [types.Integer, types.Boolean,
                       bool, int], 'high')
    check_types(endpoint, [types.Boolean, bool], 'endpoint')

    if isinstance(size, types.Omitted):
        size = size.value

    if isinstance(dtype, types.Omitted):
        dtype = dtype.value

    if isinstance(dtype, type):
        nb_dt = from_dtype(np.dtype(dtype))
        _dtype = dtype
    elif isinstance(dtype, types.NumberClass):
        nb_dt = dtype
        _dtype = as_dtype(nb_dt)
    else:
        raise TypingError("Argument dtype is not one of the" +
                          " expected type(s): " +
                          "np.int32, np.int64, np.int16, np.int8, "
                          "np.uint32, np.uint64, np.uint16, np.uint8, "
                          "np.bool_")

    if _dtype == np.bool_:
        int_func = random_methods.random_bounded_bool_fill
        lower_bound = -1
        upper_bound = 2
    else:
        try:
            i_info = np.iinfo(_dtype)
        except ValueError:
            raise TypingError("Argument dtype is not one of the" +
                              " expected type(s): " +
                              "np.int32, np.int64, np.int16, np.int8, "
                              "np.uint32, np.uint64, np.uint16, np.uint8, "
                              "np.bool_")
        int_func = getattr(random_methods,
                           f'random_bounded_uint{i_info.bits}_fill')
        lower_bound = i_info.min
        upper_bound = i_info.max

    if is_nonelike(size):
        def impl(inst, low, high, size=None,
                 dtype=np.int64, endpoint=False):
            random_methods._randint_arg_check(low, high, endpoint,
                                              lower_bound, upper_bound)
            if not endpoint:
                high -= dtype(1)
                low = dtype(low)
                high = dtype(high)
                rng = high - low
                return int_func(inst.bit_generator, low, rng, 1, dtype)[0]
            else:
                low = dtype(low)
                high = dtype(high)
                rng = high - low
                return int_func(inst.bit_generator, low, rng, 1, dtype)[0]
        return impl
    else:
        check_size(size)

        def impl(inst, low, high, size=None,
                 dtype=np.int64, endpoint=False):
            random_methods._randint_arg_check(low, high, endpoint,
                                              lower_bound, upper_bound)
            if not endpoint:
                high -= dtype(1)
                low = dtype(low)
                high = dtype(high)
                rng = high - low
                return int_func(inst.bit_generator, low, rng, size, dtype)
            else:
                low = dtype(low)
                high = dtype(high)
                rng = high - low
                return int_func(inst.bit_generator, low, rng, size, dtype)
        return impl
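
# Illustrative usage sketch (assumes numba is importable as `nb`; not part
# of the vendored module): from jitted code the overload above dispatches to
# a width-matched bounded fill, e.g.
#
#     @nb.njit
#     def roll(gen):
#         # default int64 dtype -> random_bounded_uint64_fill
#         return gen.integers(1, 6, size=10, endpoint=True)
#
#     roll(np.random.default_rng(0))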


# The following `shuffle` implementation is a direct translation from:
# https://github.com/numpy/numpy/blob/95e3e7f445407e4f355b23d6a9991d8774f0eb0c/numpy/random/_generator.pyx#L4578

# Overload the Generator().shuffle()
@overload_method(types.NumPyRandomGeneratorType, 'shuffle')
def NumPyRandomGeneratorType_shuffle(inst, x, axis=0):
    check_types(x, [types.Array], 'x')
    check_types(axis, [int, types.Integer], 'axis')

    def impl(inst, x, axis=0):
        if axis < 0:
            axis = axis + x.ndim
        if axis > x.ndim - 1 or axis < 0:
            raise IndexError("Axis is out of bounds for the given array")

        z = np.swapaxes(x, 0, axis)
        buf = np.empty_like(z[0, ...])

        for i in range(len(z) - 1, 0, -1):
            j = types.intp(random_methods.random_interval(inst.bit_generator,
                                                          i))
            if i == j:
                continue
            buf[...] = z[j, ...]
            z[j, ...] = z[i, ...]
            z[i, ...] = buf

    return impl
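
# The loop above is an in-place Fisher-Yates shuffle over axis 0 of the
# swapped view; random_interval(bitgen, i) supplies an unbiased index in
# [0, i] by masked rejection (see random_methods.py).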


# The following `permutation` implementation is a direct translation from:
# https://github.com/numpy/numpy/blob/95e3e7f445407e4f355b23d6a9991d8774f0eb0c/numpy/random/_generator.pyx#L4710
# Overload the Generator().permutation()
@overload_method(types.NumPyRandomGeneratorType, 'permutation')
def NumPyRandomGeneratorType_permutation(inst, x, axis=0):
    check_types(x, [types.Array, types.Integer], 'x')
    check_types(axis, [int, types.Integer], 'axis')

    IS_INT = isinstance(x, types.Integer)

    def impl(inst, x, axis=0):
        if IS_INT:
            new_arr = np.arange(x)
            # NumPy ignores the axis argument when x is an integer
            inst.shuffle(new_arr)
        else:
            new_arr = x.copy()
            inst.shuffle(new_arr, axis=axis)
        return new_arr

    return impl


# Overload the Generator().random()
@overload_method(types.NumPyRandomGeneratorType, 'random')
def NumPyRandomGeneratorType_random(inst, size=None, dtype=np.float64):
    dist_func, nb_dt = _get_proper_func(next_float, next_double,
                                        dtype, "random")
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, size=None, dtype=np.float64):
            return nb_dt(dist_func(inst.bit_generator))
        return impl
    else:
        check_size(size)

        def impl(inst, size=None, dtype=np.float64):
            out = np.empty(size, dtype=dtype)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = dist_func(inst.bit_generator)
            return out
        return impl


# Overload the Generator().standard_exponential() method
@overload_method(types.NumPyRandomGeneratorType, 'standard_exponential')
def NumPyRandomGeneratorType_standard_exponential(inst, size=None,
                                                  dtype=np.float64,
                                                  method='zig'):
    check_types(method, [types.UnicodeType, str], 'method')
    dist_func_inv, nb_dt = _get_proper_func(
        random_standard_exponential_inv_f,
        random_standard_exponential_inv,
        dtype
    )

    dist_func, nb_dt = _get_proper_func(random_standard_exponential_f,
                                        random_standard_exponential,
                                        dtype)

    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, size=None, dtype=np.float64, method='zig'):
            if method == 'zig':
                return nb_dt(dist_func(inst.bit_generator))
            elif method == 'inv':
                return nb_dt(dist_func_inv(inst.bit_generator))
            else:
                raise ValueError("Method must be either 'zig' or 'inv'")
        return impl
    else:
        check_size(size)

        def impl(inst, size=None, dtype=np.float64, method='zig'):
            out = np.empty(size, dtype=dtype)
            out_f = out.flat
            if method == 'zig':
                for i in range(out.size):
                    out_f[i] = dist_func(inst.bit_generator)
            elif method == 'inv':
                for i in range(out.size):
                    out_f[i] = dist_func_inv(inst.bit_generator)
            else:
                raise ValueError("Method must be either 'zig' or 'inv'")
            return out
        return impl


# Overload the Generator().standard_normal() method
@overload_method(types.NumPyRandomGeneratorType, 'standard_normal')
def NumPyRandomGeneratorType_standard_normal(inst, size=None, dtype=np.float64):
    dist_func, nb_dt = _get_proper_func(random_standard_normal_f,
                                        random_standard_normal,
                                        dtype)
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, size=None, dtype=np.float64):
            return nb_dt(dist_func(inst.bit_generator))
        return impl
    else:
        check_size(size)

        def impl(inst, size=None, dtype=np.float64):
            out = np.empty(size, dtype=dtype)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = dist_func(inst.bit_generator)
            return out
        return impl


# Overload the Generator().standard_gamma() method
@overload_method(types.NumPyRandomGeneratorType, 'standard_gamma')
def NumPyRandomGeneratorType_standard_gamma(inst, shape, size=None,
                                            dtype=np.float64):
    check_types(shape, [types.Float, types.Integer, int, float], 'shape')
    dist_func, nb_dt = _get_proper_func(random_standard_gamma_f,
                                        random_standard_gamma,
                                        dtype)
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, shape, size=None, dtype=np.float64):
            return nb_dt(dist_func(inst.bit_generator, shape))
        return impl
    else:
        check_size(size)

        def impl(inst, shape, size=None, dtype=np.float64):
            out = np.empty(size, dtype=dtype)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = dist_func(inst.bit_generator, shape)
            return out
        return impl


# Overload the Generator().normal() method
@overload_method(types.NumPyRandomGeneratorType, 'normal')
def NumPyRandomGeneratorType_normal(inst, loc=0.0, scale=1.0,
                                    size=None):
    check_types(loc, [types.Float, types.Integer, int, float], 'loc')
    check_types(scale, [types.Float, types.Integer, int, float], 'scale')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, loc=0.0, scale=1.0, size=None):
            return random_normal(inst.bit_generator, loc, scale)
        return impl
    else:
        check_size(size)

        def impl(inst, loc=0.0, scale=1.0, size=None):
            out = np.empty(size, dtype=np.float64)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_normal(inst.bit_generator, loc, scale)
            return out
        return impl


# Overload the Generator().uniform() method
@overload_method(types.NumPyRandomGeneratorType, 'uniform')
def NumPyRandomGeneratorType_uniform(inst, low=0.0, high=1.0,
                                     size=None):
    check_types(low, [types.Float, types.Integer, int, float], 'low')
    check_types(high, [types.Float, types.Integer, int, float], 'high')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, low=0.0, high=1.0, size=None):
            return random_uniform(inst.bit_generator, low, high - low)
        return impl
    else:
        check_size(size)

        def impl(inst, low=0.0, high=1.0, size=None):
            out = np.empty(size, dtype=np.float64)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_uniform(inst.bit_generator, low, high - low)
            return out
        return impl


# Overload the Generator().exponential() method
@overload_method(types.NumPyRandomGeneratorType, 'exponential')
def NumPyRandomGeneratorType_exponential(inst, scale=1.0, size=None):
    check_types(scale, [types.Float, types.Integer, int, float], 'scale')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, scale=1.0, size=None):
            return random_exponential(inst.bit_generator, scale)
        return impl
    else:
        check_size(size)

        def impl(inst, scale=1.0, size=None):
            out = np.empty(size, dtype=np.float64)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_exponential(inst.bit_generator, scale)
            return out
        return impl


# Overload the Generator().gamma() method
@overload_method(types.NumPyRandomGeneratorType, 'gamma')
def NumPyRandomGeneratorType_gamma(inst, shape, scale=1.0, size=None):
    check_types(shape, [types.Float, types.Integer, int, float], 'shape')
    check_types(scale, [types.Float, types.Integer, int, float], 'scale')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, shape, scale=1.0, size=None):
            return random_gamma(inst.bit_generator, shape, scale)
        return impl
    else:
        check_size(size)

        def impl(inst, shape, scale=1.0, size=None):
            out = np.empty(size, dtype=np.float64)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_gamma(inst.bit_generator, shape, scale)
            return out
        return impl


# Overload the Generator().beta() method
@overload_method(types.NumPyRandomGeneratorType, 'beta')
def NumPyRandomGeneratorType_beta(inst, a, b, size=None):
    check_types(a, [types.Float, types.Integer, int, float], 'a')
    check_types(b, [types.Float, types.Integer, int, float], 'b')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, a, b, size=None):
            return random_beta(inst.bit_generator, a, b)
        return impl
    else:
        check_size(size)

        def impl(inst, a, b, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_beta(inst.bit_generator, a, b)
            return out
        return impl


# Overload the Generator().f() method
@overload_method(types.NumPyRandomGeneratorType, 'f')
def NumPyRandomGeneratorType_f(inst, dfnum, dfden, size=None):
    check_types(dfnum, [types.Float, types.Integer, int, float], 'dfnum')
    check_types(dfden, [types.Float, types.Integer, int, float], 'dfden')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, dfnum, dfden, size=None):
            return random_f(inst.bit_generator, dfnum, dfden)
        return impl
    else:
        check_size(size)

        def impl(inst, dfnum, dfden, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_f(inst.bit_generator, dfnum, dfden)
            return out
        return impl


# Overload the Generator().chisquare() method
@overload_method(types.NumPyRandomGeneratorType, 'chisquare')
def NumPyRandomGeneratorType_chisquare(inst, df, size=None):
    check_types(df, [types.Float, types.Integer, int, float], 'df')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, df, size=None):
            return random_chisquare(inst.bit_generator, df)
        return impl
    else:
        check_size(size)

        def impl(inst, df, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_chisquare(inst.bit_generator, df)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'standard_cauchy')
def NumPyRandomGeneratorType_standard_cauchy(inst, size=None):

    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, size=None):
            return random_standard_cauchy(inst.bit_generator)
        return impl
    else:
        check_size(size)

        def impl(inst, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_standard_cauchy(inst.bit_generator)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'pareto')
def NumPyRandomGeneratorType_pareto(inst, a, size=None):
    check_types(a, [types.Float, types.Integer, int, float], 'a')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, a, size=None):
            return random_pareto(inst.bit_generator, a)
        return impl
    else:
        check_size(size)

        def impl(inst, a, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_pareto(inst.bit_generator, a)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'weibull')
def NumPyRandomGeneratorType_weibull(inst, a, size=None):
    check_types(a, [types.Float, types.Integer, int, float], 'a')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, a, size=None):
            return random_weibull(inst.bit_generator, a)
        return impl
    else:
        check_size(size)

        def impl(inst, a, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_weibull(inst.bit_generator, a)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'power')
def NumPyRandomGeneratorType_power(inst, a, size=None):
    check_types(a, [types.Float, types.Integer, int, float], 'a')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, a, size=None):
            return random_power(inst.bit_generator, a)
        return impl
    else:
        check_size(size)

        def impl(inst, a, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_power(inst.bit_generator, a)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'laplace')
def NumPyRandomGeneratorType_laplace(inst, loc=0.0, scale=1.0, size=None):
    check_types(loc, [types.Float, types.Integer, int, float], 'loc')
    check_types(scale, [types.Float, types.Integer, int, float], 'scale')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, loc=0.0, scale=1.0, size=None):
            return random_laplace(inst.bit_generator, loc, scale)
        return impl
    else:
        check_size(size)

        def impl(inst, loc=0.0, scale=1.0, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_laplace(inst.bit_generator, loc, scale)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'logistic')
def NumPyRandomGeneratorType_logistic(inst, loc=0.0, scale=1.0, size=None):
    check_types(loc, [types.Float, types.Integer, int, float], 'loc')
    check_types(scale, [types.Float, types.Integer, int, float], 'scale')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, loc=0.0, scale=1.0, size=None):
            return random_logistic(inst.bit_generator, loc, scale)
        return impl
    else:
        check_size(size)

        def impl(inst, loc=0.0, scale=1.0, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_logistic(inst.bit_generator, loc, scale)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'lognormal')
def NumPyRandomGeneratorType_lognormal(inst, mean=0.0, sigma=1.0, size=None):
    check_types(mean, [types.Float, types.Integer, int, float], 'mean')
    check_types(sigma, [types.Float, types.Integer, int, float], 'sigma')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, mean=0.0, sigma=1.0, size=None):
            return random_lognormal(inst.bit_generator, mean, sigma)
        return impl
    else:
        check_size(size)

        def impl(inst, mean=0.0, sigma=1.0, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_lognormal(inst.bit_generator, mean, sigma)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'rayleigh')
def NumPyRandomGeneratorType_rayleigh(inst, scale=1.0, size=None):
    check_types(scale, [types.Float, types.Integer, int, float], 'scale')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, scale=1.0, size=None):
            return random_rayleigh(inst.bit_generator, scale)
        return impl
    else:
        check_size(size)

        def impl(inst, scale=1.0, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_rayleigh(inst.bit_generator, scale)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'standard_t')
def NumPyRandomGeneratorType_standard_t(inst, df, size=None):
    check_types(df, [types.Float, types.Integer, int, float], 'df')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, df, size=None):
            return random_standard_t(inst.bit_generator, df)
        return impl
    else:
        check_size(size)

        def impl(inst, df, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_standard_t(inst.bit_generator, df)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'wald')
def NumPyRandomGeneratorType_wald(inst, mean, scale, size=None):
    check_types(mean, [types.Float, types.Integer, int, float], 'mean')
    check_types(scale, [types.Float, types.Integer, int, float], 'scale')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, mean, scale, size=None):
            return random_wald(inst.bit_generator, mean, scale)
        return impl
    else:
        check_size(size)

        def impl(inst, mean, scale, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_wald(inst.bit_generator, mean, scale)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'geometric')
def NumPyRandomGeneratorType_geometric(inst, p, size=None):
    check_types(p, [types.Float, types.Integer, int, float], 'p')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, p, size=None):
            return np.int64(random_geometric(inst.bit_generator, p))
        return impl
    else:
        check_size(size)

        def impl(inst, p, size=None):
            out = np.empty(size, dtype=np.int64)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_geometric(inst.bit_generator, p)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'zipf')
def NumPyRandomGeneratorType_zipf(inst, a, size=None):
    check_types(a, [types.Float, types.Integer, int, float], 'a')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, a, size=None):
            return np.int64(random_zipf(inst.bit_generator, a))
        return impl
    else:
        check_size(size)

        def impl(inst, a, size=None):
            out = np.empty(size, dtype=np.int64)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_zipf(inst.bit_generator, a)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'triangular')
def NumPyRandomGeneratorType_triangular(inst, left, mode, right, size=None):
    check_types(left, [types.Float, types.Integer, int, float], 'left')
    check_types(mode, [types.Float, types.Integer, int, float], 'mode')
    check_types(right, [types.Float, types.Integer, int, float], 'right')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, left, mode, right, size=None):
            return random_triangular(inst.bit_generator, left, mode, right)
        return impl
    else:
        check_size(size)

        def impl(inst, left, mode, right, size=None):
            out = np.empty(size)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_triangular(inst.bit_generator,
                                             left, mode, right)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'poisson')
def NumPyRandomGeneratorType_poisson(inst, lam, size=None):
    check_types(lam, [types.Float, types.Integer, int, float], 'lam')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, lam, size=None):
            return np.int64(random_poisson(inst.bit_generator, lam))
        return impl
    else:
        check_size(size)

        def impl(inst, lam, size=None):
            out = np.empty(size, dtype=np.int64)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_poisson(inst.bit_generator, lam)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'negative_binomial')
def NumPyRandomGeneratorType_negative_binomial(inst, n, p, size=None):
    check_types(n, [types.Float, types.Integer, int, float], 'n')
    check_types(p, [types.Float, types.Integer, int, float], 'p')
    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, n, p, size=None):
            return np.int64(random_negative_binomial(inst.bit_generator, n, p))
        return impl
    else:
        check_size(size)

        def impl(inst, n, p, size=None):
            out = np.empty(size, dtype=np.int64)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_negative_binomial(inst.bit_generator, n, p)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'noncentral_chisquare')
def NumPyRandomGeneratorType_noncentral_chisquare(inst, df, nonc, size=None):
    check_types(df, [types.Float, types.Integer, int, float], 'df')
    check_types(nonc, [types.Float, types.Integer, int, float], 'nonc')
    if isinstance(size, types.Omitted):
        size = size.value

    @register_jitable
    def check_arg_bounds(df, nonc):
        if df <= 0:
            raise ValueError("df <= 0")
        if nonc < 0:
            raise ValueError("nonc < 0")

    if is_nonelike(size):
        def impl(inst, df, nonc, size=None):
            check_arg_bounds(df, nonc)
            return np.float64(random_noncentral_chisquare(inst.bit_generator,
                                                          df, nonc))
        return impl
    else:
        check_size(size)

        def impl(inst, df, nonc, size=None):
            check_arg_bounds(df, nonc)
            out = np.empty(size, dtype=np.float64)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_noncentral_chisquare(inst.bit_generator,
                                                       df, nonc)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'noncentral_f')
def NumPyRandomGeneratorType_noncentral_f(inst, dfnum, dfden, nonc, size=None):
    check_types(dfnum, [types.Float, types.Integer, int, float], 'dfnum')
    check_types(dfden, [types.Float, types.Integer, int, float], 'dfden')
    check_types(nonc, [types.Float, types.Integer, int, float], 'nonc')
    if isinstance(size, types.Omitted):
        size = size.value

    @register_jitable
    def check_arg_bounds(dfnum, dfden, nonc):
        if dfnum <= 0:
            raise ValueError("dfnum <= 0")
        if dfden <= 0:
            raise ValueError("dfden <= 0")
        if nonc < 0:
            raise ValueError("nonc < 0")

    if is_nonelike(size):
        def impl(inst, dfnum, dfden, nonc, size=None):
            check_arg_bounds(dfnum, dfden, nonc)
            return np.float64(random_noncentral_f(inst.bit_generator,
                                                  dfnum, dfden, nonc))
        return impl
    else:
        check_size(size)

        def impl(inst, dfnum, dfden, nonc, size=None):
            check_arg_bounds(dfnum, dfden, nonc)
            out = np.empty(size, dtype=np.float64)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_noncentral_f(inst.bit_generator,
                                               dfnum, dfden, nonc)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'logseries')
def NumPyRandomGeneratorType_logseries(inst, p, size=None):
    check_types(p, [types.Float, types.Integer, int, float], 'p')
    if isinstance(size, types.Omitted):
        size = size.value

    @register_jitable
    def check_arg_bounds(p):
        if p < 0 or p >= 1 or np.isnan(p):
            raise ValueError("p < 0, p >= 1 or p is NaN")

    if is_nonelike(size):
        def impl(inst, p, size=None):
            check_arg_bounds(p)
            return np.int64(random_logseries(inst.bit_generator, p))
        return impl
    else:
        check_size(size)

        def impl(inst, p, size=None):
            check_arg_bounds(p)
            out = np.empty(size, dtype=np.int64)
            out_f = out.flat
            for i in range(out.size):
                out_f[i] = random_logseries(inst.bit_generator, p)
            return out
        return impl


@overload_method(types.NumPyRandomGeneratorType, 'binomial')
def NumPyRandomGeneratorType_binomial(inst, n, p, size=None):
    check_types(n, [types.Float, types.Integer, int, float], 'n')
    check_types(p, [types.Float, types.Integer, int, float], 'p')

    if isinstance(size, types.Omitted):
        size = size.value

    if is_nonelike(size):
        def impl(inst, n, p, size=None):
            return np.int64(random_binomial(inst.bit_generator, n, p))
        return impl
    else:
        check_size(size)

        def impl(inst, n, p, size=None):
            out = np.empty(size, dtype=np.int64)
            for i in np.ndindex(size):
                out[i] = random_binomial(inst.bit_generator, n, p)
            return out
        return impl
@@ -0,0 +1,365 @@
import numpy as np

from numba import uint64, uint32, uint16, uint8
from numba.core.extending import register_jitable

from numba.np.random._constants import (UINT32_MAX, UINT64_MAX,
                                        UINT16_MAX, UINT8_MAX)
from numba.np.random.generator_core import next_uint32, next_uint64

# All following implementations are direct translations from:
# https://github.com/numpy/numpy/blob/7cfef93c77599bd387ecc6a15d186c5a46024dac/numpy/random/src/distributions/distributions.c


@register_jitable
def gen_mask(max):
    mask = uint64(max)
    mask |= mask >> 1
    mask |= mask >> 2
    mask |= mask >> 4
    mask |= mask >> 8
    mask |= mask >> 16
    mask |= mask >> 32
    return mask
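
# gen_mask smears the highest set bit downward, yielding the smallest
# all-ones mask that covers `max`, e.g. gen_mask(5) == 7, gen_mask(16) == 31.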


@register_jitable
def buffered_bounded_bool(bitgen, off, rng, bcnt, buf):
    if (rng == 0):
        return off, bcnt, buf
    if not bcnt:
        buf = next_uint32(bitgen)
        bcnt = 31
    else:
        buf >>= 1
        bcnt -= 1

    return ((buf & 1) != 0), bcnt, buf


@register_jitable
def buffered_uint8(bitgen, bcnt, buf):
    if not bcnt:
        buf = next_uint32(bitgen)
        bcnt = 3
    else:
        buf >>= 8
        bcnt -= 1

    return uint8(buf), bcnt, buf


@register_jitable
def buffered_uint16(bitgen, bcnt, buf):
    if not bcnt:
        buf = next_uint32(bitgen)
        bcnt = 1
    else:
        buf >>= 16
        bcnt -= 1

    return uint16(buf), bcnt, buf


# The following implementations use Lemire's algorithm:
# https://arxiv.org/abs/1805.10941
@register_jitable
def buffered_bounded_lemire_uint8(bitgen, rng, bcnt, buf):
    """
    Generates a random unsigned 8 bit integer bounded
    within a given interval using Lemire's rejection.

    The buffer acts as storage for a 32 bit integer
    drawn from the associated BitGenerator so that
    multiple integers of smaller bitsize can be generated
    from a single draw of the BitGenerator.
    """
    # Note: `rng` should not be 0xFF. When this happens `rng_excl` becomes
    # zero.
    rng_excl = uint8(rng) + uint8(1)

    assert (rng != 0xFF)

    # Generate a scaled random number.
    n, bcnt, buf = buffered_uint8(bitgen, bcnt, buf)
    m = uint16(n * rng_excl)

    # Rejection sampling to remove any bias
    leftover = m & 0xFF

    if (leftover < rng_excl):
        # `rng_excl` is a simple upper bound for `threshold`.
        threshold = ((uint8(UINT8_MAX) - rng) % rng_excl)

        while (leftover < threshold):
            n, bcnt, buf = buffered_uint8(bitgen, bcnt, buf)
            m = uint16(n * rng_excl)
            leftover = m & 0xFF

    return m >> 8, bcnt, buf
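
# Illustrative pure-Python sketch, not part of the NumPy translation: the
# rejection test above is Lemire's bounded sampling. For a uniform w-bit
# word n, m = n * (rng + 1) splits into a high part (the candidate sample)
# and a low part (`leftover`); rejecting leftovers below 2**w % (rng + 1)
# removes the bias of the plain multiply-shift. Helper names hypothetical:
def _sketch_lemire_u8(draw_u8, rng):
    # draw_u8: callable returning uniform ints in [0, 255]; rng < 255.
    rng_excl = rng + 1
    threshold = (256 - rng_excl) % rng_excl  # == 256 % rng_excl
    while True:
        m = draw_u8() * rng_excl
        if (m & 0xFF) >= threshold:
            return m >> 8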


@register_jitable
def buffered_bounded_lemire_uint16(bitgen, rng, bcnt, buf):
    """
    Generates a random unsigned 16 bit integer bounded
    within a given interval using Lemire's rejection.

    The buffer acts as storage for a 32 bit integer
    drawn from the associated BitGenerator so that
    multiple integers of smaller bitsize can be generated
    from a single draw of the BitGenerator.
    """
    # Note: `rng` should not be 0xFFFF. When this happens `rng_excl` becomes
    # zero.
    rng_excl = uint16(rng) + uint16(1)

    assert (rng != 0xFFFF)

    # Generate a scaled random number.
    n, bcnt, buf = buffered_uint16(bitgen, bcnt, buf)
    m = uint32(n * rng_excl)

    # Rejection sampling to remove any bias
    leftover = m & 0xFFFF

    if (leftover < rng_excl):
        # `rng_excl` is a simple upper bound for `threshold`.
        threshold = ((uint16(UINT16_MAX) - rng) % rng_excl)

        while (leftover < threshold):
            n, bcnt, buf = buffered_uint16(bitgen, bcnt, buf)
            m = uint32(n * rng_excl)
            leftover = m & 0xFFFF

    return m >> 16, bcnt, buf


@register_jitable
def buffered_bounded_lemire_uint32(bitgen, rng):
    """
    Generates a random unsigned 32 bit integer bounded
    within a given interval using Lemire's rejection.
    """
    rng_excl = uint32(rng) + uint32(1)

    assert (rng != 0xFFFFFFFF)

    # Generate a scaled random number.
    m = uint64(next_uint32(bitgen)) * uint64(rng_excl)

    # Rejection sampling to remove any bias
    leftover = m & 0xFFFFFFFF

    if (leftover < rng_excl):
        # `rng_excl` is a simple upper bound for `threshold`.
        threshold = (UINT32_MAX - rng) % rng_excl

        while (leftover < threshold):
            m = uint64(next_uint32(bitgen)) * uint64(rng_excl)
            leftover = m & 0xFFFFFFFF

    return (m >> 32)


@register_jitable
def bounded_lemire_uint64(bitgen, rng):
    """
    Generates a random unsigned 64 bit integer bounded
    within a given interval using Lemire's rejection.
    """
    rng_excl = uint64(rng) + uint64(1)

    assert (rng != 0xFFFFFFFFFFFFFFFF)

    x = next_uint64(bitgen)

    leftover = uint64(x) * uint64(rng_excl)

    if (leftover < rng_excl):
        threshold = (UINT64_MAX - rng) % rng_excl

        while (leftover < threshold):
            x = next_uint64(bitgen)
            leftover = uint64(x) * uint64(rng_excl)

    x0 = x & uint64(0xFFFFFFFF)
    x1 = x >> 32
    rng_excl0 = rng_excl & uint64(0xFFFFFFFF)
    rng_excl1 = rng_excl >> 32
    w0 = x0 * rng_excl0
    t = x1 * rng_excl0 + (w0 >> 32)
    w1 = t & uint64(0xFFFFFFFF)
    w2 = t >> 32
    w1 += x0 * rng_excl1
    m1 = x1 * rng_excl1 + w2 + (w1 >> 32)

    return m1
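
# The limb arithmetic above computes the high 64 bits of the 128-bit product
# x * rng_excl using 32-bit halves, since uint64 multiplication wraps. A
# big-integer cross-check in pure Python (helper name hypothetical):
def _sketch_mulhi64(x, rng_excl):
    return ((x * rng_excl) >> 64) & 0xFFFFFFFFFFFFFFFF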


@register_jitable
def random_bounded_uint64_fill(bitgen, low, rng, size, dtype):
    """
    Returns a new array of given size with 64 bit integers
    bounded by given interval.
    """
    out = np.empty(size, dtype=dtype)
    if rng == 0:
        for i in np.ndindex(size):
            out[i] = low
    elif rng <= 0xFFFFFFFF:
        if (rng == 0xFFFFFFFF):
            for i in np.ndindex(size):
                out[i] = low + next_uint32(bitgen)
        else:
            for i in np.ndindex(size):
                out[i] = low + buffered_bounded_lemire_uint32(bitgen, rng)

    elif (rng == 0xFFFFFFFFFFFFFFFF):
        for i in np.ndindex(size):
            out[i] = low + next_uint64(bitgen)
    else:
        for i in np.ndindex(size):
            out[i] = low + bounded_lemire_uint64(bitgen, rng)

    return out


@register_jitable
def random_bounded_uint32_fill(bitgen, low, rng, size, dtype):
    """
    Returns a new array of given size with 32 bit integers
    bounded by given interval.
    """
    out = np.empty(size, dtype=dtype)
    if rng == 0:
        for i in np.ndindex(size):
            out[i] = low
    elif rng == 0xFFFFFFFF:
        # Lemire32 doesn't support rng = 0xFFFFFFFF.
        for i in np.ndindex(size):
            out[i] = low + next_uint32(bitgen)
    else:
        for i in np.ndindex(size):
            out[i] = low + buffered_bounded_lemire_uint32(bitgen, rng)
    return out


@register_jitable
def random_bounded_uint16_fill(bitgen, low, rng, size, dtype):
    """
    Returns a new array of given size with 16 bit integers
    bounded by given interval.
    """
    buf = 0
    bcnt = 0

    out = np.empty(size, dtype=dtype)
    if rng == 0:
        for i in np.ndindex(size):
            out[i] = low
    elif rng == 0xFFFF:
        # Lemire16 doesn't support rng = 0xFFFF.
        for i in np.ndindex(size):
            val, bcnt, buf = buffered_uint16(bitgen, bcnt, buf)
            out[i] = low + val

    else:
        for i in np.ndindex(size):
            val, bcnt, buf = \
                buffered_bounded_lemire_uint16(bitgen, rng,
                                               bcnt, buf)
            out[i] = low + val
    return out


@register_jitable
def random_bounded_uint8_fill(bitgen, low, rng, size, dtype):
    """
    Returns a new array of given size with 8 bit integers
    bounded by given interval.
    """
    buf = 0
    bcnt = 0

    out = np.empty(size, dtype=dtype)
    if rng == 0:
        for i in np.ndindex(size):
            out[i] = low
    elif rng == 0xFF:
        # Lemire8 doesn't support rng = 0xFF.
        for i in np.ndindex(size):
            val, bcnt, buf = buffered_uint8(bitgen, bcnt, buf)
            out[i] = low + val
    else:
        for i in np.ndindex(size):
            val, bcnt, buf = \
                buffered_bounded_lemire_uint8(bitgen, rng,
                                              bcnt, buf)
            out[i] = low + val
    return out


@register_jitable
def random_bounded_bool_fill(bitgen, low, rng, size, dtype):
    """
    Returns a new array of given size with boolean values.
    """
    buf = 0
    bcnt = 0
    out = np.empty(size, dtype=dtype)
    for i in np.ndindex(size):
        val, bcnt, buf = buffered_bounded_bool(bitgen, low, rng, bcnt, buf)
        out[i] = low + val
    return out


@register_jitable
def _randint_arg_check(low, high, endpoint, lower_bound, upper_bound):
    """
    Check that low and high are within the bounds
    for the given datatype.
    """

    if low < lower_bound:
        raise ValueError("low is out of bounds")

    # This is being done to avoid high being accidentally
    # casted to int64/32 while subtracting 1 before
    # checking bounds, avoids overflow.
    if high > 0:
        high = uint64(high)
        if not endpoint:
            high -= uint64(1)
        upper_bound = uint64(upper_bound)
        if low > 0:
            low = uint64(low)
        if high > upper_bound:
            raise ValueError("high is out of bounds")
        if low > high:  # -1 already subtracted, closed interval
            raise ValueError("low is greater than high in given interval")
    else:
        if high > upper_bound:
            raise ValueError("high is out of bounds")
        if low > high:  # -1 already subtracted, closed interval
            raise ValueError("low is greater than high in given interval")


@register_jitable
def random_interval(bitgen, max_val):
    if (max_val == 0):
        return 0

    max_val = uint64(max_val)
    mask = uint64(gen_mask(max_val))

    if (max_val <= 0xffffffff):
        value = uint64(next_uint32(bitgen)) & mask
        while value > max_val:
            value = uint64(next_uint32(bitgen)) & mask
    else:
        value = next_uint64(bitgen) & mask
        while value > max_val:
            value = next_uint64(bitgen) & mask

    return uint64(value)
|
||||
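
# --- Illustrative sketch, not part of the original module ---
# random_interval() uses the simpler mask-and-reject scheme: round the bound
# up to an all-ones bit mask, mask raw bits down to that width, and retry on
# overshoot. A plain-Python model, where `next_raw` is a hypothetical 64-bit
# source and bit_length() plays the role of gen_mask():
def _masked_interval_sketch(next_raw, max_val):
    if max_val == 0:
        return 0
    mask = (1 << max_val.bit_length()) - 1   # smallest all-ones mask >= max_val
    value = next_raw() & mask
    while value > max_val:                   # rejects under 50% of draws on average
        value = next_raw() & mask
    return value
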
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-

from numba.np.ufunc.decorators import Vectorize, GUVectorize, vectorize, guvectorize
from numba.np.ufunc._internal import PyUFunc_None, PyUFunc_Zero, PyUFunc_One
from numba.np.ufunc import _internal, array_exprs
from numba.np.ufunc.parallel import (threading_layer, get_num_threads,
                                     set_num_threads, get_thread_id,
                                     set_parallel_chunksize,
                                     get_parallel_chunksize)


if hasattr(_internal, 'PyUFunc_ReorderableNone'):
    PyUFunc_ReorderableNone = _internal.PyUFunc_ReorderableNone
del _internal, array_exprs


def _init():

    def init_cuda_vectorize():
        from numba.cuda.vectorizers import CUDAVectorize
        return CUDAVectorize

    def init_cuda_guvectorize():
        from numba.cuda.vectorizers import CUDAGUFuncVectorize
        return CUDAGUFuncVectorize

    Vectorize.target_registry.ondemand['cuda'] = init_cuda_vectorize
    GUVectorize.target_registry.ondemand['cuda'] = init_cuda_guvectorize


_init()
del _init
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,428 @@
import ast
from collections import defaultdict, OrderedDict
import contextlib
import sys
from types import SimpleNamespace

import numpy as np
import operator

from numba.core import types, targetconfig, ir, rewrites, compiler
from numba.core.typing import npydecl
from numba.np.ufunc.dufunc import DUFunc


def _is_ufunc(func):
    return isinstance(func, (np.ufunc, DUFunc))


@rewrites.register_rewrite('after-inference')
class RewriteArrayExprs(rewrites.Rewrite):
    '''The RewriteArrayExprs class is responsible for finding array
    expressions in Numba intermediate representation code, and
    rewriting those expressions to a single operation that will expand
    into something similar to a ufunc call.
    '''
    def __init__(self, state, *args, **kws):
        super(RewriteArrayExprs, self).__init__(state, *args, **kws)
        # Install a lowering hook if we are using this rewrite.
        special_ops = state.targetctx.special_ops
        if 'arrayexpr' not in special_ops:
            special_ops['arrayexpr'] = _lower_array_expr

    def match(self, func_ir, block, typemap, calltypes):
        """
        Using typing and a basic block, search the basic block for array
        expressions.
        Return True when one or more matches were found, False otherwise.
        """
        # We can trivially reject everything if there are no
        # calls in the type results.
        if len(calltypes) == 0:
            return False

        self.crnt_block = block
        self.typemap = typemap
        # { variable name: IR assignment (of a function call or operator) }
        self.array_assigns = OrderedDict()
        # { variable name: IR assignment (of a constant) }
        self.const_assigns = {}

        assignments = block.find_insts(ir.Assign)
        for instr in assignments:
            target_name = instr.target.name
            expr = instr.value
            # Does it assign an expression to an array variable?
            if (isinstance(expr, ir.Expr) and
                    isinstance(typemap.get(target_name, None), types.Array)):
                self._match_array_expr(instr, expr, target_name)
            elif isinstance(expr, ir.Const):
                # Track constants since we might need them for an
                # array expression.
                self.const_assigns[target_name] = expr

        return len(self.array_assigns) > 0

    def _match_array_expr(self, instr, expr, target_name):
        """
        Find whether the given assignment (*instr*) of an expression (*expr*)
        to variable *target_name* is an array expression.
        """
        # We've matched a subexpression assignment to an
        # array variable. Now see if the expression is an
        # array expression.
        expr_op = expr.op
        array_assigns = self.array_assigns

        if ((expr_op in ('unary', 'binop')) and (
                expr.fn in npydecl.supported_array_operators)):
            # It is an array operator that maps to a ufunc.
            # check that all args have internal types
            if all(self.typemap[var.name].is_internal
                   for var in expr.list_vars()):
                array_assigns[target_name] = instr

        elif ((expr_op == 'call') and (expr.func.name in self.typemap)):
            # It could be a match for a known ufunc call.
            func_type = self.typemap[expr.func.name]
            if isinstance(func_type, types.Function):
                func_key = func_type.typing_key
                if _is_ufunc(func_key):
                    # If so, check whether an explicit output is passed.
                    if not self._has_explicit_output(expr, func_key):
                        # If not, match it as a (sub)expression.
                        array_assigns[target_name] = instr

    def _has_explicit_output(self, expr, func):
        """
        Return whether the *expr* call to *func* (a ufunc) features an
        explicit output argument.
        """
        nargs = len(expr.args) + len(expr.kws)
        if expr.vararg is not None:
            # XXX *args unsupported here, assume there may be an explicit
            # output
            return True
        return nargs > func.nin

    def _get_array_operator(self, ir_expr):
        ir_op = ir_expr.op
        if ir_op in ('unary', 'binop'):
            return ir_expr.fn
        elif ir_op == 'call':
            return self.typemap[ir_expr.func.name].typing_key
        raise NotImplementedError(
            "Don't know how to find the operator for '{0}' expressions.".format(
                ir_op))

    def _get_operands(self, ir_expr):
        '''Given a Numba IR expression, return the operands to the expression
        in the order they appear in the expression.
        '''
        ir_op = ir_expr.op
        if ir_op == 'binop':
            return ir_expr.lhs, ir_expr.rhs
        elif ir_op == 'unary':
            return ir_expr.list_vars()
        elif ir_op == 'call':
            return ir_expr.args
        raise NotImplementedError(
            "Don't know how to find the operands for '{0}' expressions.".format(
                ir_op))

    def _translate_expr(self, ir_expr):
        '''Translate the given expression from Numba IR to an array expression
        tree.
        '''
        ir_op = ir_expr.op
        if ir_op == 'arrayexpr':
            return ir_expr.expr
        operands_or_args = [self.const_assigns.get(op_var.name, op_var)
                            for op_var in self._get_operands(ir_expr)]
        return self._get_array_operator(ir_expr), operands_or_args

    def _handle_matches(self):
        '''Iterate over the matches, trying to find which instructions should
        be rewritten, deleted, or moved.
        '''
        replace_map = {}
        dead_vars = set()
        used_vars = defaultdict(int)
        for instr in self.array_assigns.values():
            expr = instr.value
            arr_inps = []
            arr_expr = self._get_array_operator(expr), arr_inps
            new_expr = ir.Expr(op='arrayexpr',
                               loc=expr.loc,
                               expr=arr_expr,
                               ty=self.typemap[instr.target.name])
            new_instr = ir.Assign(new_expr, instr.target, instr.loc)
            replace_map[instr] = new_instr
            self.array_assigns[instr.target.name] = new_instr
            for operand in self._get_operands(expr):
                operand_name = operand.name
                if operand.is_temp and operand_name in self.array_assigns:
                    child_assign = self.array_assigns[operand_name]
                    child_expr = child_assign.value
                    child_operands = child_expr.list_vars()
                    for child_operand in child_operands:
                        used_vars[child_operand.name] += 1
                    arr_inps.append(self._translate_expr(child_expr))
                    if child_assign.target.is_temp:
                        dead_vars.add(child_assign.target.name)
                        replace_map[child_assign] = None
                elif operand_name in self.const_assigns:
                    arr_inps.append(self.const_assigns[operand_name])
                else:
                    used_vars[operand.name] += 1
                    arr_inps.append(operand)
        return replace_map, dead_vars, used_vars

    def _get_final_replacement(self, replacement_map, instr):
        '''Find the final replacement instruction for a given initial
        instruction by chasing instructions in a map from instructions
        to replacement instructions.
        '''
        replacement = replacement_map[instr]
        while replacement in replacement_map:
            replacement = replacement_map[replacement]
        return replacement

    def apply(self):
        '''When we've found array expressions in a basic block, rewrite that
        block, returning a new, transformed block.
        '''
        # Part 1: Figure out what instructions should be rewritten
        # based on the matches found.
        replace_map, dead_vars, used_vars = self._handle_matches()
        # Part 2: Using the information above, rewrite the target
        # basic block.
        result = self.crnt_block.copy()
        result.clear()
        delete_map = {}
        for instr in self.crnt_block.body:
            if isinstance(instr, ir.Assign):
                if instr in replace_map:
                    replacement = self._get_final_replacement(
                        replace_map, instr)
                    if replacement:
                        result.append(replacement)
                        for var in replacement.value.list_vars():
                            var_name = var.name
                            if var_name in delete_map:
                                result.append(delete_map.pop(var_name))
                            if used_vars[var_name] > 0:
                                used_vars[var_name] -= 1

                else:
                    result.append(instr)
            elif isinstance(instr, ir.Del):
                instr_value = instr.value
                if used_vars[instr_value] > 0:
                    used_vars[instr_value] -= 1
                    delete_map[instr_value] = instr
                elif instr_value not in dead_vars:
                    result.append(instr)
            else:
                result.append(instr)
        if delete_map:
            for instr in delete_map.values():
                result.insert_before_terminator(instr)
        return result


_unaryops = {
    operator.pos: ast.UAdd,
    operator.neg: ast.USub,
    operator.invert: ast.Invert,
}

_binops = {
    operator.add: ast.Add,
    operator.sub: ast.Sub,
    operator.mul: ast.Mult,
    operator.truediv: ast.Div,
    operator.mod: ast.Mod,
    operator.or_: ast.BitOr,
    operator.rshift: ast.RShift,
    operator.xor: ast.BitXor,
    operator.lshift: ast.LShift,
    operator.and_: ast.BitAnd,
    operator.pow: ast.Pow,
    operator.floordiv: ast.FloorDiv,
}


_cmpops = {
    operator.eq: ast.Eq,
    operator.ne: ast.NotEq,
    operator.lt: ast.Lt,
    operator.le: ast.LtE,
    operator.gt: ast.Gt,
    operator.ge: ast.GtE,
}


def _arr_expr_to_ast(expr):
    '''Build a Python expression AST from an array expression built by
    RewriteArrayExprs.
    '''
    if isinstance(expr, tuple):
        op, arr_expr_args = expr
        ast_args = []
        env = {}
        for arg in arr_expr_args:
            ast_arg, child_env = _arr_expr_to_ast(arg)
            ast_args.append(ast_arg)
            env.update(child_env)
        if op in npydecl.supported_array_operators:
            if len(ast_args) == 2:
                if op in _binops:
                    return ast.BinOp(
                        ast_args[0], _binops[op](), ast_args[1]), env
                if op in _cmpops:
                    return ast.Compare(
                        ast_args[0], [_cmpops[op]()], [ast_args[1]]), env
            else:
                assert op in _unaryops
                return ast.UnaryOp(_unaryops[op](), ast_args[0]), env
        elif _is_ufunc(op):
            fn_name = "__ufunc_or_dufunc_{0}".format(
                hex(hash(op)).replace("-", "_"))
            fn_ast_name = ast.Name(fn_name, ast.Load())
            env[fn_name] = op  # Stash the ufunc or DUFunc in the environment
            ast_call = ast.Call(fn_ast_name, ast_args, [])
            return ast_call, env
    elif isinstance(expr, ir.Var):
        return ast.Name(expr.name, ast.Load(),
                        lineno=expr.loc.line,
                        col_offset=expr.loc.col if expr.loc.col else 0), {}
    elif isinstance(expr, ir.Const):
        return ast.Constant(expr.value), {}
    raise NotImplementedError(
        "Don't know how to translate array expression '%r'" % (expr,))

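# Illustrative note, not part of the original module: for a matched tree such
# as (operator.add, [a_var, (operator.mul, [b_var, const_2])]), the function
# above returns roughly
#   ast.BinOp(Name('a'), Add(), ast.BinOp(Name('b'), Mult(), Constant(2)))
# together with an env dict mapping any stashed ufunc names to the ufunc
# objects themselves (empty when only plain operators are involved).
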
@contextlib.contextmanager
def _legalize_parameter_names(var_list):
    """
    Legalize names in the variable list for use as a Python function's
    parameter names.
    """
    var_map = OrderedDict()
    for var in var_list:
        old_name = var.name
        new_name = var.scope.redefine(old_name, loc=var.loc).name
        new_name = new_name.replace("$", "_").replace(".", "_")
        # Caller should ensure the names are unique
        if new_name in var_map:
            raise AssertionError(f"{new_name!r} not unique")
        var_map[new_name] = var, old_name
        var.name = new_name
    param_names = list(var_map)
    try:
        yield param_names
    finally:
        # Make sure the old names are restored, to avoid confusing
        # other parts of Numba (see issue #1466)
        for var, old_name in var_map.values():
            var.name = old_name


class _EraseInvalidLineRanges(ast.NodeTransformer):
    def generic_visit(self, node: ast.AST) -> ast.AST:
        node = super().generic_visit(node)
        if hasattr(node, "lineno"):
            if getattr(node, "end_lineno", None) is not None:
                if node.lineno > node.end_lineno:
                    del node.lineno
                    del node.end_lineno
        return node


def _fix_invalid_lineno_ranges(astree: ast.AST):
    """Fixes invalid lineno ranges in place.
    """
    # Make sure lineno and end_lineno are present
    ast.fix_missing_locations(astree)
    # Delete invalid lineno ranges
    _EraseInvalidLineRanges().visit(astree)
    # Make sure lineno and end_lineno are present
    ast.fix_missing_locations(astree)


def _lower_array_expr(lowerer, expr):
    '''Lower an array expression built by RewriteArrayExprs.
    '''
    expr_name = "__numba_array_expr_%s" % (hex(hash(expr)).replace("-", "_"))
    expr_filename = expr.loc.filename
    expr_var_list = expr.list_vars()
    # The expression may use a given variable several times, but we
    # should only create one parameter for it.
    expr_var_unique = sorted(set(expr_var_list), key=lambda var: var.name)

    # Arguments are the names external to the new closure
    expr_args = [var.name for var in expr_var_unique]

    # 1. Create an AST tree from the array expression.
    with _legalize_parameter_names(expr_var_unique) as expr_params:
        ast_args = [ast.arg(param_name, None)
                    for param_name in expr_params]
        # Parse a stub function to ensure the AST is populated with
        # reasonable defaults for the Python version.
        ast_module = ast.parse('def {0}(): return'.format(expr_name),
                               expr_filename, 'exec')
        assert hasattr(ast_module, 'body') and len(ast_module.body) == 1
        ast_fn = ast_module.body[0]
        ast_fn.args.args = ast_args
        ast_fn.body[0].value, namespace = _arr_expr_to_ast(expr.expr)
        _fix_invalid_lineno_ranges(ast_module)

    # 2. Compile the AST module and extract the Python function.
    code_obj = compile(ast_module, expr_filename, 'exec')
    exec(code_obj, namespace)
    impl = namespace[expr_name]

    # 3. Now compile a ufunc using the Python function as kernel.

    context = lowerer.context
    builder = lowerer.builder
    outer_sig = expr.ty(*(lowerer.typeof(name) for name in expr_args))
    inner_sig_args = []
    for argty in outer_sig.args:
        if isinstance(argty, types.Optional):
            argty = argty.type
        if isinstance(argty, types.Array):
            inner_sig_args.append(argty.dtype)
        else:
            inner_sig_args.append(argty)
    inner_sig = outer_sig.return_type.dtype(*inner_sig_args)

    flags = targetconfig.ConfigStack().top_or_none()
    # Make sure it's a clone or a fresh instance
    flags = compiler.Flags() if flags is None else flags.copy()
    # Follow the Numpy error model. Note this also allows e.g. vectorizing
    # division (issue #1223).
    flags.error_model = 'numpy'
    cres = context.compile_subroutine(builder, impl, inner_sig, flags=flags,
                                      caching=False)

    # Create kernel subclass calling our native function
    from numba.np import npyimpl

    class ExprKernel(npyimpl._Kernel):
        def generate(self, *args):
            arg_zip = zip(args, self.outer_sig.args, inner_sig.args)
            cast_args = [self.cast(val, inty, outty)
                         for val, inty, outty in arg_zip]
            result = self.context.call_internal(
                builder, cres.fndesc, inner_sig, cast_args)
            return self.cast(result, inner_sig.return_type,
                             self.outer_sig.return_type)

    # create a fake ufunc object which is enough to trick numpy_ufunc_kernel
    ufunc = SimpleNamespace(nin=len(expr_args), nout=1, __name__=expr_name)
    ufunc.nargs = ufunc.nin + ufunc.nout

    args = [lowerer.loadvar(name) for name in expr_args]
    return npyimpl.numpy_ufunc_kernel(
        context, builder, outer_sig, args, ufunc, ExprKernel)
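
# --- Illustrative sketch, not part of the original module ---
# Net effect of this rewrite for end users: inside an @njit function, a chain
# of elementwise operations is fused into a single 'arrayexpr' kernel rather
# than three array temporaries. Kept as a comment here, since importing numba
# from inside numba itself would be circular; assuming an ordinary user module:
#
#   from numba import njit
#
#   @njit
#   def f(a, b, c):
#       return a * b + c - 1.0    # lowered as one fused loop
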
@@ -0,0 +1,208 @@
import inspect

from numba.np.ufunc import _internal
from numba.np.ufunc.parallel import ParallelUFuncBuilder, ParallelGUFuncBuilder

from numba.core.registry import DelayedRegistry
from numba.np.ufunc import dufunc
from numba.np.ufunc import gufunc


class _BaseVectorize(object):

    @classmethod
    def get_identity(cls, kwargs):
        return kwargs.pop('identity', None)

    @classmethod
    def get_cache(cls, kwargs):
        return kwargs.pop('cache', False)

    @classmethod
    def get_writable_args(cls, kwargs):
        return kwargs.pop('writable_args', ())

    @classmethod
    def get_target_implementation(cls, kwargs):
        target = kwargs.pop('target', 'cpu')
        try:
            return cls.target_registry[target]
        except KeyError:
            raise ValueError("Unsupported target: %s" % target)


class Vectorize(_BaseVectorize):
    target_registry = DelayedRegistry({'cpu': dufunc.DUFunc,
                                       'parallel': ParallelUFuncBuilder})

    def __new__(cls, func, **kws):
        identity = cls.get_identity(kws)
        cache = cls.get_cache(kws)
        imp = cls.get_target_implementation(kws)
        return imp(func, identity=identity, cache=cache, targetoptions=kws)


class GUVectorize(_BaseVectorize):
    target_registry = DelayedRegistry({'cpu': gufunc.GUFunc,
                                       'parallel': ParallelGUFuncBuilder})

    def __new__(cls, func, signature, **kws):
        identity = cls.get_identity(kws)
        cache = cls.get_cache(kws)
        imp = cls.get_target_implementation(kws)
        writable_args = cls.get_writable_args(kws)
        if imp is gufunc.GUFunc:
            is_dyn = kws.pop('is_dynamic', False)
            return imp(func, signature, identity=identity, cache=cache,
                       is_dynamic=is_dyn, targetoptions=kws,
                       writable_args=writable_args)
        else:
            return imp(func, signature, identity=identity, cache=cache,
                       targetoptions=kws, writable_args=writable_args)
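
# Illustrative note, not part of the original module: target dispatch is a
# plain registry lookup, so e.g. @vectorize(..., target='parallel') resolves
# through Vectorize.target_registry['parallel'] to ParallelUFuncBuilder, while
# the 'cuda' key is filled in lazily via the ondemand hook installed in this
# package's __init__.
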
def vectorize(ftylist_or_function=(), **kws):
    """vectorize(ftylist_or_function=(), target='cpu', identity=None, **kws)

    A decorator that creates a NumPy ufunc object using Numba compiled
    code. When no arguments or only keyword arguments are given,
    vectorize will return a Numba dynamic ufunc (DUFunc) object, where
    compilation/specialization may occur at call-time.

    Args
    ----
    ftylist_or_function: function or iterable

        When the first argument is a function, signatures are dealt
        with at call-time.

        When the first argument is an iterable of type signatures,
        which are either a function type object or a string describing
        the function type, signatures are finalized at decoration
        time.

    Keyword Args
    ------------

    target: str
        A string for code generation target. Defaults to "cpu".

    identity: int, str, or None
        The identity (or unit) value for the element-wise function
        being implemented. Allowed values are None (the default), 0, 1,
        and "reorderable".

    cache: bool
        Turns on caching.


    Returns
    -------

    A NumPy universal function

    Examples
    --------
        @vectorize(['float32(float32, float32)',
                    'float64(float64, float64)'], identity=0)
        def sum(a, b):
            return a + b

        @vectorize
        def sum(a, b):
            return a + b

        @vectorize(identity=1)
        def mul(a, b):
            return a * b

    """
    if isinstance(ftylist_or_function, str):
        # Common user mistake
        ftylist = [ftylist_or_function]
    elif inspect.isfunction(ftylist_or_function):
        return dufunc.DUFunc(ftylist_or_function, **kws)
    elif ftylist_or_function is not None:
        ftylist = ftylist_or_function

    def wrap(func):
        vec = Vectorize(func, **kws)
        for sig in ftylist:
            vec.add(sig)
        if len(ftylist) > 0:
            vec.disable_compile()
        return vec.build_ufunc()

    return wrap


def guvectorize(*args, **kwargs):
    """guvectorize(ftylist, signature, target='cpu', identity=None, **kws)

    A decorator to create a NumPy generalized-ufunc object from Numba
    compiled code.

    Args
    ----
    ftylist: iterable
        An iterable of type signatures, which are either a
        function type object or a string describing the
        function type.

    signature: str
        A NumPy generalized-ufunc signature.
        e.g. "(m, n), (n, p)->(m, p)"

    identity: int, str, or None
        The identity (or unit) value for the element-wise function
        being implemented. Allowed values are None (the default), 0, 1,
        and "reorderable".

    cache: bool
        Turns on caching.

    writable_args: tuple
        A tuple of indices of input variables that are writable.

    target: str
        A string for code generation target. Defaults to "cpu".

    Returns
    -------

    A NumPy generalized universal-function

    Example
    -------
        @guvectorize(['void(int32[:,:], int32[:,:], int32[:,:])',
                      'void(float32[:,:], float32[:,:], float32[:,:])'],
                     '(x, y),(x, y)->(x, y)')
        def add_2d_array(a, b, c):
            for i in range(c.shape[0]):
                for j in range(c.shape[1]):
                    c[i, j] = a[i, j] + b[i, j]

    """
    if len(args) == 1:
        ftylist = []
        signature = args[0]
        kwargs.setdefault('is_dynamic', True)
    elif len(args) == 2:
        ftylist = args[0]
        signature = args[1]
    else:
        raise TypeError('guvectorize() takes one or two positional arguments')

    if isinstance(ftylist, str):
        # Common user mistake
        ftylist = [ftylist]

    def wrap(func):
        guvec = GUVectorize(func, signature, **kwargs)
        for fty in ftylist:
            guvec.add(fty)
        if len(ftylist) > 0:
            guvec.disable_compile()
        return guvec.build_ufunc()

    return wrap
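
# Illustrative note, not part of the original module: the single-argument
# form above builds a *dynamic* gufunc that compiles on first call, e.g.
#
#   @guvectorize('(n),(n)->(n)')
#   def elem_add(a, b, res):
#       for i in range(a.shape[0]):
#           res[i] = a[i] + b[i]
#
# No type list is given, so kwargs.setdefault('is_dynamic', True) routes the
# call through GUFunc's call-time compilation path.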
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,325 @@
from numba import typeof
from numba.core import types
from numba.np.ufunc.ufuncbuilder import GUFuncBuilder
from numba.np.ufunc.sigparse import parse_signature
from numba.np.ufunc.ufunc_base import UfuncBase, UfuncLowererBase
from numba.np.numpy_support import ufunc_find_matching_loop
from numba.core import serialize, errors
from numba.core.typing import npydecl
from numba.core.typing.templates import signature, AbstractTemplate
import functools


def make_gufunc_kernel(_dufunc):
    from numba.np import npyimpl

    class GUFuncKernel(npyimpl._Kernel):
        """
        npyimpl._Kernel subclass responsible for lowering a gufunc kernel
        (element-wise function) inside a broadcast loop (which is
        generated by npyimpl.numpy_gufunc_kernel()).
        """
        dufunc = _dufunc

        def __init__(self, context, builder, outer_sig):
            super().__init__(context, builder, outer_sig)
            ewise_types = self.dufunc._get_ewise_dtypes(outer_sig.args)
            self.inner_sig, self.cres = self.dufunc.find_ewise_function(
                ewise_types)

        def cast(self, val, fromty, toty):
            # Handle the case where "fromty" is an array and "toty" a scalar
            if isinstance(fromty, types.Array) and not \
                    isinstance(toty, types.Array):
                return super().cast(val, fromty.dtype, toty)
            return super().cast(val, fromty, toty)

        def generate(self, *args):
            if self.cres.objectmode:
                msg = ('Calling a guvectorize function in object mode is not '
                       'supported yet.')
                raise errors.NumbaRuntimeError(msg)
            self.context.add_linking_libs((self.cres.library,))
            return super().generate(*args)

    GUFuncKernel.__name__ += _dufunc.__name__
    return GUFuncKernel


class GUFuncLowerer(UfuncLowererBase):
    '''Callable class responsible for lowering calls to a specific gufunc.
    '''
    def __init__(self, gufunc):
        from numba.np import npyimpl
        super().__init__(gufunc,
                         make_gufunc_kernel,
                         npyimpl.numpy_gufunc_kernel)


class GUFunc(serialize.ReduceMixin, UfuncBase):
    """
    Dynamic generalized universal function (GUFunc)
    intended to act like a normal Numpy gufunc, but capable
    of call-time (just-in-time) compilation of fast loops
    specialized to inputs.
    """

    def __init__(self, py_func, signature, identity=None, cache=None,
                 is_dynamic=False, targetoptions=None, writable_args=()):
        if targetoptions is None:
            targetoptions = {}
        self.ufunc = None
        self._frozen = False
        self._is_dynamic = is_dynamic
        self._identity = identity

        # GUFunc cannot inherit from GUFuncBuilder because "identity"
        # is a property of GUFunc. Thus, we hold a reference to a GUFuncBuilder
        # object here
        self.gufunc_builder = GUFuncBuilder(
            py_func, signature, identity, cache, targetoptions, writable_args)

        self.__name__ = self.gufunc_builder.py_func.__name__
        self.__doc__ = self.gufunc_builder.py_func.__doc__
        self._dispatcher = self.gufunc_builder.nb_func
        self._initialize(self._dispatcher)
        functools.update_wrapper(self, py_func)

    def _initialize(self, dispatcher):
        self.build_ufunc()
        self._install_type()
        self._lower_me = GUFuncLowerer(self)
        self._install_cg()

    def _reduce_states(self):
        gb = self.gufunc_builder
        dct = dict(
            py_func=gb.py_func,
            signature=gb.signature,
            identity=self._identity,
            cache=gb.cache,
            is_dynamic=self._is_dynamic,
            targetoptions=gb.targetoptions,
            writable_args=gb.writable_args,
            typesigs=gb._sigs,
            frozen=self._frozen,
        )
        return dct

    @classmethod
    def _rebuild(cls, py_func, signature, identity, cache, is_dynamic,
                 targetoptions, writable_args, typesigs, frozen):
        self = cls(py_func=py_func, signature=signature, identity=identity,
                   cache=cache, is_dynamic=is_dynamic,
                   targetoptions=targetoptions, writable_args=writable_args)
        for sig in typesigs:
            self.add(sig)
        self.build_ufunc()
        self._frozen = frozen
        return self

    def __repr__(self):
        return f"<numba._GUFunc '{self.__name__}'>"

    def _install_type(self, typingctx=None):
        """Constructs and installs a typing class for a gufunc object in the
        input typing context. If no typing context is given, then
        _install_type() installs into the typing context of the
        dispatcher object (should be the same default context used by
        jit() and njit()).
        """
        if typingctx is None:
            typingctx = self._dispatcher.targetdescr.typing_context
        _ty_cls = type('GUFuncTyping_' + self.__name__,
                       (AbstractTemplate,),
                       dict(key=self, generic=self._type_me))
        typingctx.insert_user_function(self, _ty_cls)

    def add(self, fty):
        self.gufunc_builder.add(fty)

    def build_ufunc(self):
        self.ufunc = self.gufunc_builder.build_ufunc()
        return self

    def expected_ndims(self):
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        return (tuple(map(len, parsed_sig[0])), tuple(map(len, parsed_sig[1])))
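
    # Illustrative note, not part of the original module: for a core
    # signature "(m,n),(n)->(m)", expected_ndims() above returns
    # ((2, 1), (1,)) -- the first input needs at least 2 dims, the second
    # at least 1, and the single output at least 1.
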
    def _type_me(self, argtys, kws):
        """
        Implement AbstractTemplate.generic() for the typing class
        built by gufunc._install_type().

        Return the call-site signature after either validating the
        element-wise signature or compiling for it.
        """
        assert not kws
        ufunc = self.ufunc
        sig = self.gufunc_builder.signature
        inp_ndims, out_ndims = self.expected_ndims()
        ndims = inp_ndims + out_ndims

        # Outputs may be omitted at the call site, so argtys can be shorter
        # than ndims, never longer. (The original line read
        # `assert len(argtys), len(ndims)`, an assert-with-message typo.)
        assert len(argtys) <= len(ndims)
        for idx, arg in enumerate(argtys):
            if isinstance(arg, types.Array) and arg.ndim < ndims[idx]:
                kind = "Input" if idx < len(inp_ndims) else "Output"
                i = idx if idx < len(inp_ndims) else idx - len(inp_ndims)
                msg = (
                    f"{self.__name__}: {kind} operand {i} does not have "
                    f"enough dimensions (has {arg.ndim}, gufunc core with "
                    f"signature {sig} requires {ndims[idx]})")
                raise errors.TypingError(msg)

        _handle_inputs_result = npydecl.Numpy_rules_ufunc._handle_inputs(
            ufunc, argtys, kws)
        ewise_types, _, _, _ = _handle_inputs_result
        sig, _ = self.find_ewise_function(ewise_types)

        if sig is None:
            # Matching element-wise signature was not found; must
            # compile.
            if self._frozen:
                msg = f"cannot call {self} with types {argtys}"
                raise errors.TypingError(msg)
            self._compile_for_argtys(ewise_types)
            # double check to ensure there is a match
            sig, _ = self.find_ewise_function(ewise_types)
            if sig is None:
                msg = f"Failed to compile {self.__name__} with types {argtys}"
                raise errors.TypingError(msg)

        assert sig is not None

        return signature(types.none, *argtys)

    def _compile_for_argtys(self, argtys, return_type=None):
        # Compile a new guvectorize function! Use the gufunc signature
        # i.e. (n,m),(m)->(n)
        # plus ewise_types to build a numba function type
        fnty = self._get_function_type(*argtys)
        self.gufunc_builder.add(fnty)

    def match_signature(self, ewise_types, sig):
        dtypes = self._get_ewise_dtypes(sig.args)
        return tuple(dtypes) == tuple(ewise_types)

    @property
    def is_dynamic(self):
        return self._is_dynamic

    def _get_ewise_dtypes(self, args):
        argtys = map(lambda arg: arg if isinstance(arg, types.Type) else
                     typeof(arg), args)
        tys = []
        for argty in argtys:
            if isinstance(argty, types.Array):
                tys.append(argty.dtype)
            else:
                tys.append(argty)
        return tys

    def _num_args_match(self, *args):
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        return len(args) == len(parsed_sig[0]) + len(parsed_sig[1])

    def _get_function_type(self, *args):
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        # ewise_types is a list of [int32, int32, int32, ...]
        ewise_types = self._get_ewise_dtypes(args)

        # first time calling the gufunc
        # generate a signature based on input arguments
        l = []
        for idx, sig_dim in enumerate(parsed_sig[0]):
            ndim = len(sig_dim)
            if ndim == 0:  # append scalar
                l.append(ewise_types[idx])
            else:
                l.append(types.Array(ewise_types[idx], ndim, 'A'))

        offset = len(parsed_sig[0])
        # add return type to signature
        for idx, sig_dim in enumerate(parsed_sig[1]):
            retty = ewise_types[idx + offset]
            ret_ndim = len(sig_dim) or 1  # small hack to return scalars
            l.append(types.Array(retty, ret_ndim, 'A'))

        return types.none(*l)
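
    # Illustrative note, not part of the original module: with the core
    # signature "(n,m),(m)->(n)" and float64 arguments,
    # _get_function_type() above produces roughly
    #   none(Array(float64, 2, 'A'), Array(float64, 1, 'A'),
    #        Array(float64, 1, 'A'))
    # i.e. the output is appended as a writable array argument rather than
    # being a return value.
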
    def __call__(self, *args, **kwargs):
        # If compilation is disabled OR it is NOT a dynamic gufunc
        # call the underlying gufunc
        if self._frozen or not self.is_dynamic:
            # Do not unwrap the ufunc if the argument is a wrapper that will
            # potentially pickle the ufunc after it receives it in
            # __array_ufunc__. The same logic in theory should be replicated
            # for reduce(), outer(), etc., but they're not implemented in dask.
            if args and _is_array_wrapper(args[0]):
                return args[0].__array_ufunc__(
                    self, "__call__", *args, **kwargs
                )
            else:
                return self.ufunc(*args, **kwargs)
        elif "out" in kwargs:
            # If "out" argument is supplied
            args += (kwargs.pop("out"),)

        if not self._num_args_match(*args):
            # It is not allowed to call a dynamic gufunc without
            # providing all the arguments
            # see: https://github.com/numba/numba/pull/5938#discussion_r506429392  # noqa: E501
            msg = (
                f"Too few arguments for function '{self.__name__}'. "
                "Note that the pattern `out = gufunc(Arg1, Arg2, ..., ArgN)` "
                "is not allowed. Use `gufunc(Arg1, Arg2, ..., ArgN, out)` "
                "instead.")
            raise TypeError(msg)

        # at this point we know the gufunc is a dynamic one
        ewise = self._get_ewise_dtypes(args)
        if not (self.ufunc and ufunc_find_matching_loop(self.ufunc, ewise)):
            # A previous call (@njit -> @guvectorize) may have compiled a
            # version for the element-wise dtypes. In this case, we don't need
            # to compile it again, just build the (g)ufunc
            if self.find_ewise_function(ewise) == (None, None):
                sig = self._get_function_type(*args)
                self.add(sig)
            self.build_ufunc()

        return self.ufunc(*args, **kwargs)


def _is_array_wrapper(obj):
    """Return True if obj wraps around numpy or another numpy-like library
    and is likely going to apply the ufunc to the wrapped array; False
    otherwise.

    At the moment, this returns True for

    - dask.array.Array
    - dask.dataframe.DataFrame
    - dask.dataframe.Series
    - xarray.DataArray
    - xarray.Dataset
    - xarray.Variable
    - pint.Quantity
    - other potential wrappers around dask array or dask dataframe

    We may need to add other libraries that pickle ufuncs from their
    __array_ufunc__ method in the future.

    Note that the test below is a lot more naive than
    `dask.base.is_dask_collection`
    (https://github.com/dask/dask/blob/5949e54bc04158d215814586a44d51e0eb4a964d/dask/base.py#L209-L249),  # noqa: E501
    because it doesn't need to find out if we're actually dealing with
    a dask collection, only that we're dealing with a wrapper.
    Namely, it will return True for a pint.Quantity wrapping around a plain
    float, a numpy.ndarray, or a dask.array.Array, and that's OK because in
    all cases Quantity.__array_ufunc__ is going to forward the ufunc call
    inwards.
    """
    return (
        not isinstance(obj, type)
        and hasattr(obj, "__dask_graph__")
        and hasattr(obj, "__array_ufunc__")
    )
Binary file not shown.
@@ -0,0 +1,761 @@
"""
This file implements the code-generator for parallel-vectorize.

ParallelUFunc is the platform independent base class for generating
the thread dispatcher. This thread dispatcher launches threads
that execute the generated function of UFuncCore.
UFuncCore is subclassed to specialize for the input/output types.
The actual workload is invoked inside the function generated by UFuncCore.
UFuncCore also defines a work-stealing mechanism that allows idle threads
to steal work from other threads.
"""

import os
import sys
import warnings
from threading import RLock as threadRLock
from ctypes import CFUNCTYPE, c_int, CDLL, POINTER, c_uint

import numpy as np

import llvmlite.binding as ll
from llvmlite import ir

from numba.np.numpy_support import as_dtype
from numba.core import types, cgutils, config, errors
from numba.core.typing import signature
from numba.np.ufunc.wrappers import _wrapper_info
from numba.np.ufunc import ufuncbuilder
from numba.extending import overload, intrinsic

_IS_OSX = sys.platform.startswith('darwin')
_IS_LINUX = sys.platform.startswith('linux')
_IS_WINDOWS = sys.platform.startswith('win32')


def get_thread_count():
    """
    Gets the available thread count.
    """
    t = config.NUMBA_NUM_THREADS
    if t < 1:
        raise ValueError("Number of threads specified must be > 0.")
    return t


NUM_THREADS = get_thread_count()


def build_gufunc_kernel(library, ctx, info, sig, inner_ndim):
    """Wrap the original CPU ufunc/gufunc with a parallel dispatcher.
    This function wraps gufuncs and ufuncs alike.

    Args
    ----
    ctx
        numba's codegen context

    info: (library, env, name)
        inner function info

    sig
        type signature of the gufunc

    inner_ndim
        inner dimension of the gufunc (this is len(sig.args) in the case of a
        ufunc)

    Returns
    -------
    wrapper_info : (library, env, name)
        The info for the gufunc wrapper.

    Details
    -------

    The kernel signature looks like this:

    void kernel(char **args, npy_intp *dimensions, npy_intp* steps, void* data)

    args - the input arrays + output arrays
    dimensions - the dimensions of the arrays
    steps - the step size for the array (this is like sizeof(type))
    data - any additional data

    The parallel backend then stages multiple calls to this kernel concurrently
    across a number of threads. Practically, for each item of work, the backend
    duplicates `dimensions` and adjusts the first entry to reflect the size of
    the item of work; it also forms up an array of pointers into the args for
    offsets to read/write from/to with respect to its position in the items of
    work. This allows the same kernel to be used for each item of work, with
    simply adjusted reads/writes/domain sizes, and is safe by virtue of the
    domain partitioning.

    NOTE: The execution backend is passed the requested thread count, but it
    can choose to ignore it (TBB)!
    """
    assert isinstance(info, tuple)  # guard against old usage
    # Declare types and function
    byte_t = ir.IntType(8)
    byte_ptr_t = ir.PointerType(byte_t)
    byte_ptr_ptr_t = ir.PointerType(byte_ptr_t)

    intp_t = ctx.get_value_type(types.intp)
    intp_ptr_t = ir.PointerType(intp_t)

    fnty = ir.FunctionType(ir.VoidType(), [ir.PointerType(byte_ptr_t),
                                           ir.PointerType(intp_t),
                                           ir.PointerType(intp_t),
                                           byte_ptr_t])
    wrapperlib = ctx.codegen().create_library('parallelgufuncwrapper')
    mod = wrapperlib.create_ir_module('parallel.gufunc.wrapper')
    kernel_name = ".kernel.{}_{}".format(id(info.env), info.name)
    lfunc = ir.Function(mod, fnty, name=kernel_name)

    bb_entry = lfunc.append_basic_block('')

    # Function body starts
    builder = ir.IRBuilder(bb_entry)

    args, dimensions, steps, data = lfunc.args

    # Release the GIL (and ensure we have the GIL)
    # Note: numpy ufunc may not always release the GIL; thus,
    # we need to ensure we have the GIL.
    pyapi = ctx.get_python_api(builder)
    gil_state = pyapi.gil_ensure()
    thread_state = pyapi.save_thread()

    def as_void_ptr(arg):
        return builder.bitcast(arg, byte_ptr_t)

    # Array count depends on whether an "output" array is needed. In the case
    # of a void return type cf. gufunc it is the number of args, in the case of
    # a non-void return type cf. ufunc it is the number of args + 1 so as to
    # account for the output array.
    array_count = len(sig.args)
    if not isinstance(sig.return_type, types.NoneType):
        array_count += 1

    parallel_for_ty = ir.FunctionType(ir.VoidType(),
                                      [byte_ptr_t] * 5 + [intp_t, ] * 3)
    parallel_for = cgutils.get_or_insert_function(mod, parallel_for_ty,
                                                  'numba_parallel_for')

    # Reference inner-function and link
    innerfunc_fnty = ir.FunctionType(
        ir.VoidType(),
        [byte_ptr_ptr_t, intp_ptr_t, intp_ptr_t, byte_ptr_t],
    )
    tmp_voidptr = cgutils.get_or_insert_function(mod, innerfunc_fnty,
                                                 info.name)
    wrapperlib.add_linking_library(info.library)

    get_num_threads = cgutils.get_or_insert_function(
        builder.module,
        ir.FunctionType(ir.IntType(types.intp.bitwidth), []),
        "get_num_threads")

    num_threads = builder.call(get_num_threads, [])

    # Prepare call
    fnptr = builder.bitcast(tmp_voidptr, byte_ptr_t)
    innerargs = [as_void_ptr(x) for x
                 in [args, dimensions, steps, data]]
    builder.call(parallel_for, [fnptr] + innerargs +
                 [intp_t(x) for x in (inner_ndim, array_count)] +
                 [num_threads])

    # Release the GIL
    pyapi.restore_thread(thread_state)
    pyapi.gil_release(gil_state)

    builder.ret_void()

    wrapperlib.add_ir_module(mod)
    wrapperlib.add_linking_library(library)
    return _wrapper_info(library=wrapperlib, name=lfunc.name, env=info.env)


# ------------------------------------------------------------------------------

class ParallelUFuncBuilder(ufuncbuilder.UFuncBuilder):
    def build(self, cres, sig):
        _launch_threads()

        # Build wrapper for ufunc entry point
        ctx = cres.target_context
        signature = cres.signature
        library = cres.library
        fname = cres.fndesc.llvm_func_name

        info = build_ufunc_wrapper(library, ctx, fname, signature, cres)
        ptr = info.library.get_pointer_to_function(info.name)
        # Get dtypes
        dtypenums = [np.dtype(a.name).num for a in signature.args]
        dtypenums.append(np.dtype(signature.return_type.name).num)
        keepalive = ()
        return dtypenums, ptr, keepalive


def build_ufunc_wrapper(library, ctx, fname, signature, cres):
    innerfunc = ufuncbuilder.build_ufunc_wrapper(library, ctx, fname,
                                                 signature, objmode=False,
                                                 cres=cres)
    info = build_gufunc_kernel(library, ctx, innerfunc, signature,
                               len(signature.args))
    return info

# ---------------------------------------------------------------------------


class ParallelGUFuncBuilder(ufuncbuilder.GUFuncBuilder):
    def __init__(self, py_func, signature, identity=None, cache=False,
                 targetoptions=None, writable_args=()):
        # Force nopython mode
        if targetoptions is None:
            targetoptions = {}
        targetoptions.update(dict(nopython=True))
        super(
            ParallelGUFuncBuilder,
            self).__init__(
            py_func=py_func,
            signature=signature,
            identity=identity,
            cache=cache,
            targetoptions=targetoptions,
            writable_args=writable_args)

    def build(self, cres):
        """
        Returns (dtype numbers, function ptr, EnvironmentObject)
        """
        _launch_threads()

        # Build wrapper for ufunc entry point
        info = build_gufunc_wrapper(
            self.py_func, cres, self.sin, self.sout, cache=self.cache,
            is_parfors=False,
        )
        ptr = info.library.get_pointer_to_function(info.name)
        env = info.env

        # Get dtypes
        dtypenums = []
        for a in cres.signature.args:
            if isinstance(a, types.Array):
                ty = a.dtype
            else:
                ty = a
            dtypenums.append(as_dtype(ty).num)

        return dtypenums, ptr, env


# This is not a member of the ParallelGUFuncBuilder class because it is
# called without an enclosing instance from parfors

def build_gufunc_wrapper(py_func, cres, sin, sout, cache, is_parfors):
    """Build the gufunc wrapper for the given arguments.
    The *is_parfors* argument is a boolean indicating whether the gufunc is
    being built for use as a ParFors kernel. This changes codegen and caching
    behavior.
    """
    library = cres.library
    ctx = cres.target_context
    signature = cres.signature
    innerinfo = ufuncbuilder.build_gufunc_wrapper(
        py_func, cres, sin, sout, cache=cache, is_parfors=is_parfors,
    )
    sym_in = set(sym for term in sin for sym in term)
    sym_out = set(sym for term in sout for sym in term)
    inner_ndim = len(sym_in | sym_out)

    info = build_gufunc_kernel(
        library, ctx, innerinfo, signature, inner_ndim,
    )
    return info

# ---------------------------------------------------------------------------


_backend_init_thread_lock = threadRLock()

_windows = sys.platform.startswith('win32')


class _nop(object):
    """A no-op contextmanager
    """

    def __enter__(self):
        pass

    def __exit__(self, *args):
        pass


_backend_init_process_lock = None


def _set_init_process_lock():
    global _backend_init_process_lock
    try:
        # Force the use of an RLock in the case a fork was used to start the
        # process and thereby the init sequence; some of the threading backend
        # init sequences are not fork safe. Also, windows global mp locks seem
        # to be fine.
        with _backend_init_thread_lock:  # protect part-initialized module access
            import multiprocessing
            if "fork" in multiprocessing.get_start_method() or _windows:
                ctx = multiprocessing.get_context()
                _backend_init_process_lock = ctx.RLock()
            else:
                _backend_init_process_lock = _nop()

    except OSError as e:

        # probably lack of /dev/shm for semaphore writes, warn the user
        msg = (
            "Could not obtain multiprocessing lock due to OS level error: %s\n"
            "A likely cause of this problem is '/dev/shm' is missing or "
            "read-only such that necessary semaphores cannot be written.\n"
            "*** The responsibility of ensuring multiprocessing safe access to "
            "this initialization sequence/module import is deferred to the "
            "user! ***\n"
        )
        warnings.warn(msg % str(e), errors.NumbaSystemWarning)

        _backend_init_process_lock = _nop()


_is_initialized = False

# this is set by _launch_threads
_threading_layer = None


def threading_layer():
    """
    Get the name of the threading layer in use for parallel CPU targets
    """
    if _threading_layer is None:
        raise ValueError("Threading layer is not initialized.")
    else:
        return _threading_layer
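
# Illustrative usage sketch, not part of the original module: the layer name
# is only known after the first parallel execution forces backend
# initialisation. Assuming a normal user script:
#
#   import numpy as np
#   from numba import njit, prange, threading_layer
#
#   @njit(parallel=True)
#   def total(a):
#       acc = 0.0
#       for i in prange(a.shape[0]):
#           acc += a[i]
#       return acc
#
#   total(np.ones(100))        # triggers _launch_threads()
#   print(threading_layer())   # e.g. 'tbb', 'omp' or 'workqueue'
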
def _check_tbb_version_compatible():
    """
    Checks that if TBB is present it is of a compatible version.
    """
    try:
        # first check that the TBB version is new enough
        if _IS_WINDOWS:
            libtbb_name = 'tbb12.dll'
        elif _IS_OSX:
            libtbb_name = 'libtbb.12.dylib'
        elif _IS_LINUX:
            libtbb_name = 'libtbb.so.12'
        else:
            raise ValueError("Unknown operating system")
        libtbb = CDLL(libtbb_name)
        version_func = libtbb.TBB_runtime_interface_version
        version_func.argtypes = []
        version_func.restype = c_int
        tbb_iface_ver = version_func()
        if tbb_iface_ver < 12060:  # magic number from TBB
            msg = ("The TBB threading layer requires TBB "
                   "version 2021 update 6 or later i.e., "
                   "TBB_INTERFACE_VERSION >= 12060. Found "
                   "TBB_INTERFACE_VERSION = %s. The TBB "
                   "threading layer is disabled.") % tbb_iface_ver
            problem = errors.NumbaWarning(msg)
            warnings.warn(problem)
            raise ImportError("Problem with TBB. Reason: %s" % msg)
    except (ValueError, OSError) as e:
        # Translate as an ImportError for consistent error class use; this
        # error will never materialise
        raise ImportError("Problem with TBB. Reason: %s" % e)


def _launch_threads():
    if not _backend_init_process_lock:
        _set_init_process_lock()

    with _backend_init_process_lock:
        with _backend_init_thread_lock:
            global _is_initialized
            if _is_initialized:
                return

            def select_known_backend(backend):
                """
                Loads a specific threading layer backend based on string
                """
                lib = None
                if backend.startswith("tbb"):
                    try:
                        # check if TBB is present and compatible
                        _check_tbb_version_compatible()
                        # now try and load the backend
                        from numba.np.ufunc import tbbpool as lib
                    except ImportError:
                        pass
                elif backend.startswith("omp"):
                    # TODO: Check that if MKL is present that it is a version
                    # that understands GNU OMP might be present
                    try:
                        from numba.np.ufunc import omppool as lib
                    except ImportError:
                        pass
                elif backend.startswith("workqueue"):
                    from numba.np.ufunc import workqueue as lib
                else:
                    msg = "Unknown value specified for threading layer: %s"
                    raise ValueError(msg % backend)
                return lib

            def select_from_backends(backends):
                """
                Selects from presented backends and returns the first working
                """
                lib = None
                for backend in backends:
                    lib = select_known_backend(backend)
                    if lib is not None:
                        break
                else:
                    backend = ''
                return lib, backend

            t = str(config.THREADING_LAYER).lower()
            namedbackends = config.THREADING_LAYER_PRIORITY
            if not (len(namedbackends) == 3 and
                    set(namedbackends) == {'tbb', 'omp', 'workqueue'}):
                raise ValueError(
                    "THREADING_LAYER_PRIORITY invalid: %s. "
                    "It must be a permutation of "
                    "{'tbb', 'omp', 'workqueue'}"
                    % namedbackends
                )

            lib = None
            err_helpers = dict()
            err_helpers['TBB'] = ("Intel TBB is required, try:\n"
                                  "$ conda/pip install tbb")
            err_helpers['OSX_OMP'] = ("Intel OpenMP is required, try:\n"
                                      "$ conda/pip install intel-openmp")
            requirements = []

            def raise_with_hint(required):
                errmsg = "No threading layer could be loaded.\n%s"
                hintmsg = "HINT:\n%s"
                if len(required) == 0:
                    hint = ''
                if len(required) == 1:
                    hint = hintmsg % err_helpers[required[0]]
                if len(required) > 1:
                    options = '\nOR\n'.join([err_helpers[x] for x in required])
                    hint = hintmsg % ("One of:\n%s" % options)
                raise ValueError(errmsg % hint)

            if t in namedbackends:
                # Try and load the specific named backend
                lib = select_known_backend(t)
                if not lib:
                    # something is missing preventing a valid backend from
                    # loading, set requirements for hinting
                    if t == 'tbb':
                        requirements.append('TBB')
                    elif t == 'omp' and _IS_OSX:
                        requirements.append('OSX_OMP')
                libname = t
            elif t in ['threadsafe', 'forksafe', 'safe']:
                # User wants a specific behaviour...
                available = ['tbb']
                requirements.append('TBB')
                if t == "safe":
                    # "safe" is TBB, which is fork and threadsafe everywhere
                    pass
                elif t == "threadsafe":
                    if _IS_OSX:
                        requirements.append('OSX_OMP')
                    # omp is threadsafe everywhere
                    available.append('omp')
                elif t == "forksafe":
                    # everywhere apart from linux (GNU OpenMP) has a guaranteed
                    # forksafe OpenMP; as OpenMP has better performance, prefer
                    # it to workqueue
                    if not _IS_LINUX:
                        available.append('omp')
                        if _IS_OSX:
                            requirements.append('OSX_OMP')
                    # workqueue is forksafe everywhere
                    available.append('workqueue')
                else:  # unreachable
                    msg = "No threading layer available for purpose %s"
                    raise ValueError(msg % t)
                # select amongst available
                lib, libname = select_from_backends(available)
            elif t == 'default':
                # If default is supplied, try them in order, tbb, omp,
                # workqueue
                lib, libname = select_from_backends(namedbackends)
                if not lib:
                    # set requirements for hinting
                    requirements.append('TBB')
                    if _IS_OSX:
                        requirements.append('OSX_OMP')
            else:
                msg = "The threading layer requested '%s' is unknown to Numba."
                raise ValueError(msg % t)

            # No lib found, raise and hint
            if not lib:
                raise_with_hint(requirements)

            ll.add_symbol('numba_parallel_for', lib.parallel_for)
            ll.add_symbol('do_scheduling_signed', lib.do_scheduling_signed)
            ll.add_symbol('do_scheduling_unsigned', lib.do_scheduling_unsigned)
            ll.add_symbol('allocate_sched', lib.allocate_sched)
            ll.add_symbol('deallocate_sched', lib.deallocate_sched)

            launch_threads = CFUNCTYPE(None, c_int)(lib.launch_threads)
            launch_threads(NUM_THREADS)

            _load_threading_functions(lib)  # load late

            # set library name so it can be queried
            global _threading_layer
            _threading_layer = libname
            _is_initialized = True


def _load_threading_functions(lib):

    ll.add_symbol('get_num_threads', lib.get_num_threads)
    ll.add_symbol('set_num_threads', lib.set_num_threads)
    ll.add_symbol('get_thread_id', lib.get_thread_id)

    global _set_num_threads
    _set_num_threads = CFUNCTYPE(None, c_int)(lib.set_num_threads)
    _set_num_threads(NUM_THREADS)

    global _get_num_threads
    _get_num_threads = CFUNCTYPE(c_int)(lib.get_num_threads)

    global _get_thread_id
    _get_thread_id = CFUNCTYPE(c_int)(lib.get_thread_id)

    ll.add_symbol('set_parallel_chunksize', lib.set_parallel_chunksize)
    ll.add_symbol('get_parallel_chunksize', lib.get_parallel_chunksize)
    ll.add_symbol('get_sched_size', lib.get_sched_size)
    global _set_parallel_chunksize
    _set_parallel_chunksize = CFUNCTYPE(c_uint,
                                        c_uint)(lib.set_parallel_chunksize)
    global _get_parallel_chunksize
    _get_parallel_chunksize = CFUNCTYPE(c_uint)(lib.get_parallel_chunksize)
    global _get_sched_size
    _get_sched_size = CFUNCTYPE(c_uint,
                                c_uint,
                                c_uint,
                                POINTER(c_int),
                                POINTER(c_int))(lib.get_sched_size)


# Some helpers to make set_num_threads jittable

def gen_snt_check():
    from numba.core.config import NUMBA_NUM_THREADS
    msg = "The number of threads must be between 1 and %s" % NUMBA_NUM_THREADS

    def snt_check(n):
        if n > NUMBA_NUM_THREADS or n < 1:
            raise ValueError(msg)
    return snt_check


snt_check = gen_snt_check()


@overload(snt_check)
def ol_snt_check(n):
    return snt_check


def set_num_threads(n):
    """
    Set the number of threads to use for parallel execution.

    By default, all :obj:`numba.config.NUMBA_NUM_THREADS` threads are used.

    This functionality works by masking out threads that are not used.
    Therefore, the number of threads *n* must be less than or equal to
    :obj:`~.NUMBA_NUM_THREADS`, the total number of threads that are launched.
    See its documentation for more details.

    This function can be used inside of a jitted function.

    Parameters
    ----------
    n: The number of threads. Must be between 1 and NUMBA_NUM_THREADS.

    See Also
    --------
    get_num_threads, numba.config.NUMBA_NUM_THREADS,
    numba.config.NUMBA_DEFAULT_NUM_THREADS, :envvar:`NUMBA_NUM_THREADS`

    """
    _launch_threads()
    if not isinstance(n, (int, np.integer)):
        raise TypeError("The number of threads specified must be an integer")
    snt_check(n)
    _set_num_threads(n)
@overload(set_num_threads)
|
||||
def ol_set_num_threads(n):
|
||||
_launch_threads()
|
||||
if not isinstance(n, types.Integer):
|
||||
msg = "The number of threads specified must be an integer"
|
||||
raise errors.TypingError(msg)
|
||||
|
||||
def impl(n):
|
||||
snt_check(n)
|
||||
_set_num_threads(n)
|
||||
return impl
|
||||
|
||||
|
||||
def get_num_threads():
|
||||
"""
|
||||
Get the number of threads used for parallel execution.
|
||||
|
||||
By default (if :func:`~.set_num_threads` is never called), all
|
||||
:obj:`numba.config.NUMBA_NUM_THREADS` threads are used.
|
||||
|
||||
This number is less than or equal to the total number of threads that are
|
||||
launched, :obj:`numba.config.NUMBA_NUM_THREADS`.
|
||||
|
||||
This function can be used inside of a jitted function.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The number of threads.
|
||||
|
||||
See Also
|
||||
--------
|
||||
set_num_threads, numba.config.NUMBA_NUM_THREADS,
|
||||
numba.config.NUMBA_DEFAULT_NUM_THREADS, :envvar:`NUMBA_NUM_THREADS`
|
||||
|
||||
"""
|
||||
_launch_threads()
|
||||
num_threads = _get_num_threads()
|
||||
if num_threads <= 0:
|
||||
raise RuntimeError("Invalid number of threads. "
|
||||
"This likely indicates a bug in Numba. "
|
||||
"(thread_id=%s, num_threads=%s)" %
|
||||
(get_thread_id(), num_threads))
|
||||
return num_threads
|
||||
|
||||
|
||||
@overload(get_num_threads)
|
||||
def ol_get_num_threads():
|
||||
_launch_threads()
|
||||
|
||||
def impl():
|
||||
num_threads = _get_num_threads()
|
||||
if num_threads <= 0:
|
||||
print("Broken thread_id: ", get_thread_id())
|
||||
print("num_threads: ", num_threads)
|
||||
raise RuntimeError("Invalid number of threads. "
|
||||
"This likely indicates a bug in Numba.")
|
||||
return num_threads
|
||||
return impl
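

# Illustrative usage sketch (not part of the original module): masking the
# thread count from inside a jitted function, using the public Numba
# decorators. Assumes NUMBA_NUM_THREADS >= 2 on the host.
#
#     from numba import njit, prange, set_num_threads, get_num_threads
#
#     @njit(parallel=True)
#     def masked_sum(a):
#         set_num_threads(2)           # mask execution down to 2 threads
#         acc = 0.0
#         for i in prange(a.size):     # runs on at most get_num_threads() threads
#             acc += a[i]
#         return acc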


@intrinsic
def _iget_num_threads(typingctx):
    _launch_threads()

    def codegen(context, builder, signature, args):
        mod = builder.module
        fnty = ir.FunctionType(cgutils.intp_t, [])
        fn = cgutils.get_or_insert_function(mod, fnty, "get_num_threads")
        return builder.call(fn, [])
    return signature(types.intp), codegen


def get_thread_id():
    """
    Returns a unique ID for each thread in the range 0 (inclusive)
    to :func:`~.get_num_threads` (exclusive).
    """
    # Called from the interpreter directly, this should return 0
    # Called from a sequential JIT region, this should return 0
    # Called from a parallel JIT region, this should return 0..N
    # Called from objmode in a parallel JIT region, this should return 0..N
    _launch_threads()
    return _get_thread_id()


@overload(get_thread_id)
def ol_get_thread_id():
    _launch_threads()

    def impl():
        return _iget_thread_id()
    return impl


@intrinsic
def _iget_thread_id(typingctx):
    def codegen(context, builder, signature, args):
        mod = builder.module
        fnty = ir.FunctionType(cgutils.intp_t, [])
        fn = cgutils.get_or_insert_function(mod, fnty, "get_thread_id")
        return builder.call(fn, [])
    return signature(types.intp), codegen


_DYLD_WORKAROUND_SET = 'NUMBA_DYLD_WORKAROUND' in os.environ
_DYLD_WORKAROUND_VAL = int(os.environ.get('NUMBA_DYLD_WORKAROUND', 0))

if _DYLD_WORKAROUND_SET and _DYLD_WORKAROUND_VAL:
    _launch_threads()


def set_parallel_chunksize(n):
    _launch_threads()
    if not isinstance(n, (int, np.integer)):
        raise TypeError("The parallel chunksize must be an integer")
    if n < 0:
        raise ValueError("chunksize must be greater than or equal to zero")
    return _set_parallel_chunksize(n)


def get_parallel_chunksize():
    _launch_threads()
    return _get_parallel_chunksize()


@overload(set_parallel_chunksize)
def ol_set_parallel_chunksize(n):
    _launch_threads()
    if not isinstance(n, types.Integer):
        msg = "The parallel chunksize must be an integer"
        raise errors.TypingError(msg)

    def impl(n):
        if n < 0:
            raise ValueError("chunksize must be greater than or equal to zero")
        return _set_parallel_chunksize(n)
    return impl


@overload(get_parallel_chunksize)
def ol_get_parallel_chunksize():
    _launch_threads()

    def impl():
        return _get_parallel_chunksize()
    return impl
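

# Illustrative usage sketch (not part of the original module): the chunksize
# pair behaves like a save/restore API around a parallel region;
# set_parallel_chunksize returns the previous value (see the c_uint return
# type bound above). Recent Numba also re-exports these at the top level.
#
#     from numba import njit, prange
#     from numba.np.ufunc.parallel import (set_parallel_chunksize,
#                                          get_parallel_chunksize)
#
#     @njit(parallel=True)
#     def chunked_sum(a):
#         old = set_parallel_chunksize(4)   # returns the previous chunksize
#         acc = 0.0
#         for i in prange(a.size):
#             acc += a[i]
#         set_parallel_chunksize(old)       # restore the previous setting
#         return acc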
@@ -0,0 +1,63 @@
import tokenize
import string


def parse_signature(sig):
    '''Parse a generalized ufunc signature.

    NOTE: ',' (COMMA) is a delimiter, not a separator.
    This means a trailing comma is legal.
    '''
    def stripws(s):
        return ''.join(c for c in s if c not in string.whitespace)

    def tokenizer(src):
        def readline():
            yield src
        gen = readline()
        return tokenize.generate_tokens(lambda: next(gen))

    def parse(src):
        tokgen = tokenizer(src)
        while True:
            tok = next(tokgen)
            if tok[1] == '(':
                symbols = []
                while True:
                    tok = next(tokgen)
                    if tok[1] == ')':
                        break
                    elif tok[0] == tokenize.NAME:
                        symbols.append(tok[1])
                    elif tok[1] == ',':
                        continue
                    else:
                        raise ValueError('bad token in signature "%s"' % tok[1])
                yield tuple(symbols)
                tok = next(tokgen)
                if tok[1] == ',':
                    continue
                elif tokenize.ISEOF(tok[0]):
                    break
            elif tokenize.ISEOF(tok[0]):
                break
            else:
                raise ValueError('bad token in signature "%s"' % tok[1])

    ins, _, outs = stripws(sig).partition('->')
    inputs = list(parse(ins))
    outputs = list(parse(outs))

    # Check that all output symbols are defined in the inputs
    isym = set()
    osym = set()
    for grp in inputs:
        isym |= set(grp)
    for grp in outputs:
        osym |= set(grp)

    diff = osym.difference(isym)
    if diff:
        raise NameError('undefined output symbols: %s' % ','.join(sorted(diff)))

    return inputs, outputs
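

# Illustrative example (not part of the original module): parsing a
# matrix-multiply shape signature.
#
#     >>> parse_signature('(m,n),(n,p)->(m,p)')
#     ([('m', 'n'), ('n', 'p')], [('m', 'p')])
#
# A scalar is written as '()', so parse_signature('(n)->()') gives
# ([('n',)], [()]).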
Binary file not shown.
@@ -0,0 +1,113 @@
from numba.np import numpy_support
from numba.core import types


class UfuncLowererBase:
    '''Callable class responsible for lowering calls to a specific gufunc.
    '''
    def __init__(self, ufunc, make_kernel_fn, make_ufunc_kernel_fn):
        self.ufunc = ufunc
        self.make_ufunc_kernel_fn = make_ufunc_kernel_fn
        self.kernel = make_kernel_fn(ufunc)
        self.libs = []

    def __call__(self, context, builder, sig, args):
        return self.make_ufunc_kernel_fn(context, builder, sig, args,
                                         self.ufunc, self.kernel)


class UfuncBase:

    @property
    def nin(self):
        return self.ufunc.nin

    @property
    def nout(self):
        return self.ufunc.nout

    @property
    def nargs(self):
        return self.ufunc.nargs

    @property
    def ntypes(self):
        return self.ufunc.ntypes

    @property
    def types(self):
        return self.ufunc.types

    @property
    def identity(self):
        return self.ufunc.identity

    @property
    def signature(self):
        return self.ufunc.signature

    @property
    def accumulate(self):
        return self.ufunc.accumulate

    @property
    def at(self):
        return self.ufunc.at

    @property
    def outer(self):
        return self.ufunc.outer

    @property
    def reduce(self):
        return self.ufunc.reduce

    @property
    def reduceat(self):
        return self.ufunc.reduceat

    def disable_compile(self):
        """
        Disable the compilation of new signatures at call time.
        """
        # If compilation is being disabled then there must already be at
        # least one compiled signature.
        assert len(self._dispatcher.overloads) > 0
        self._frozen = True

    def _install_cg(self, targetctx=None):
        """
        Install an implementation function for a GUFunc/DUFunc object in the
        given target context. If no target context is given, _install_cg()
        installs into the target context of the dispatcher object (which
        should be the same default context used by jit() and njit()).
        """
        if targetctx is None:
            targetctx = self._dispatcher.targetdescr.target_context
        _any = types.Any
        _arr = types.Array
        # Either all outputs are explicit or none of them are
        sig0 = (_any,) * self.ufunc.nin + (_arr,) * self.ufunc.nout
        sig1 = (_any,) * self.ufunc.nin
        targetctx.insert_func_defn(
            [(self._lower_me, self, sig) for sig in (sig0, sig1)])

    def find_ewise_function(self, ewise_types):
        """
        Given a tuple of element-wise argument types, find a matching
        signature in the dispatcher.

        Return a 2-tuple containing the matching signature and the
        compilation result. Returns two ``None``s if no matching signature
        was found.
        """
        if self._frozen:
            # If we cannot compile, coerce to the best matching loop
            loop = numpy_support.ufunc_find_matching_loop(self, ewise_types)
            if loop is None:
                return None, None
            ewise_types = tuple(loop.inputs + loop.outputs)[:len(ewise_types)]
        for sig, cres in self._dispatcher.overloads.items():
            if self.match_signature(ewise_types, sig):
                return sig, cres
        return None, None
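

# Illustrative note (not part of the original module; DUFunc behaviour is an
# assumption based on the properties above): UfuncBase simply proxies the
# metadata of the wrapped np.ufunc, so a wrapping object introspects like the
# underlying ufunc, e.g.
#
#     from numba import vectorize
#     d = vectorize(["float64(float64, float64)"])(lambda a, b: a + b)
#     d.nin, d.nout, d.nargs   # -> (2, 1, 3)
#     d.identity               # -> None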


@@ -0,0 +1,444 @@
# -*- coding: utf-8 -*-

import inspect
import warnings
from contextlib import contextmanager

from numba.core import config, targetconfig
from numba.core.decorators import jit
from numba.core.descriptors import TargetDescriptor
from numba.core.extending import is_jitted
from numba.core.errors import NumbaDeprecationWarning
from numba.core.options import TargetOptions, include_default_options
from numba.core.registry import cpu_target
from numba.core.target_extension import dispatcher_registry, target_registry
from numba.core import utils, types, serialize, compiler, sigutils
from numba.np.numpy_support import as_dtype
from numba.np.ufunc import _internal
from numba.np.ufunc.sigparse import parse_signature
from numba.np.ufunc.wrappers import build_ufunc_wrapper, build_gufunc_wrapper
from numba.core.caching import FunctionCache, NullCache
from numba.core.compiler_lock import global_compiler_lock


_options_mixin = include_default_options(
    "nopython",
    "forceobj",
    "boundscheck",
    "fastmath",
    "writable_args"
)


class UFuncTargetOptions(_options_mixin, TargetOptions):

    def finalize(self, flags, options):
        if not flags.is_set("enable_pyobject"):
            flags.enable_pyobject = True

        if not flags.is_set("enable_looplift"):
            flags.enable_looplift = True

        flags.inherit_if_not_set("nrt", default=True)

        if not flags.is_set("debuginfo"):
            flags.debuginfo = config.DEBUGINFO_DEFAULT

        if not flags.is_set("boundscheck"):
            flags.boundscheck = flags.debuginfo

        flags.enable_pyobject_looplift = True

        flags.inherit_if_not_set("fastmath")


class UFuncTarget(TargetDescriptor):
    options = UFuncTargetOptions

    def __init__(self):
        super().__init__('ufunc')

    @property
    def typing_context(self):
        return cpu_target.typing_context

    @property
    def target_context(self):
        return cpu_target.target_context


ufunc_target = UFuncTarget()


class UFuncDispatcher(serialize.ReduceMixin):
    """
    An object handling compilation of various signatures for a ufunc.
    """
    targetdescr = ufunc_target

    def __init__(self, py_func, locals=None, targetoptions=None):
        if locals is None:
            locals = {}
        if targetoptions is None:
            targetoptions = {}
        self.py_func = py_func
        self.overloads = utils.UniqueDict()
        self.targetoptions = targetoptions
        self.locals = locals
        self.cache = NullCache()

    def _reduce_states(self):
        """
        NOTE: part of the ReduceMixin protocol
        """
        return dict(
            pyfunc=self.py_func,
            locals=self.locals,
            targetoptions=self.targetoptions,
        )

    @classmethod
    def _rebuild(cls, pyfunc, locals, targetoptions):
        """
        NOTE: part of the ReduceMixin protocol
        """
        return cls(py_func=pyfunc, locals=locals, targetoptions=targetoptions)

    def enable_caching(self):
        self.cache = FunctionCache(self.py_func)

    def compile(self, sig, locals=None, **targetoptions):
        if locals is None:
            locals = {}
        locs = self.locals.copy()
        locs.update(locals)

        topt = self.targetoptions.copy()
        topt.update(targetoptions)

        flags = compiler.Flags()
        self.targetdescr.options.parse_as_flags(flags, topt)

        flags.no_cpython_wrapper = True
        flags.error_model = "numpy"
        # Disable loop lifting: the feature requires a real python function
        flags.enable_looplift = False

        return self._compile_core(sig, flags, locals)

    def _compile_core(self, sig, flags, locals):
        """
        Trigger the compiler on the core function, or load a previously
        compiled version from the cache. Returns the CompileResult.
        """
        typingctx = self.targetdescr.typing_context
        targetctx = self.targetdescr.target_context

        @contextmanager
        def store_overloads_on_success():
            # used to ensure overloads are stored on success
            try:
                yield
            except Exception:
                raise
            else:
                exists = self.overloads.get(cres.signature)
                if exists is None:
                    self.overloads[cres.signature] = cres

        # Use cache and compiler in a critical section
        with global_compiler_lock:
            with targetconfig.ConfigStack().enter(flags.copy()):
                with store_overloads_on_success():
                    # attempt lookup of an existing compile result
                    cres = self.cache.load_overload(sig, targetctx)
                    if cres is not None:
                        return cres

                    # Compile
                    args, return_type = sigutils.normalize_signature(sig)
                    cres = compiler.compile_extra(typingctx, targetctx,
                                                  self.py_func, args=args,
                                                  return_type=return_type,
                                                  flags=flags, locals=locals)

                    # The cache lookup failed before, so it is safe to save
                    self.cache.save_overload(sig, cres)

                    return cres


dispatcher_registry[target_registry['npyufunc']] = UFuncDispatcher


# Utility functions

def _compile_element_wise_function(nb_func, targetoptions, sig):
    # Do the compilation and return the CompileResult to test
    cres = nb_func.compile(sig, **targetoptions)
    args, return_type = sigutils.normalize_signature(sig)
    return cres, args, return_type


def _finalize_ufunc_signature(cres, args, return_type):
    '''Given a compilation result, argument types, and a return type,
    build a valid Numba signature after validating that it doesn't
    violate the constraints for the compilation mode.
    '''
    if return_type is None:
        if cres.objectmode:
            # Object mode is used and the return type is not specified
            raise TypeError("return type must be specified for object mode")
        else:
            return_type = cres.signature.return_type

    assert return_type != types.pyobject
    return return_type(*args)


def _build_element_wise_ufunc_wrapper(cres, signature):
    '''Build a wrapper for the ufunc loop entry point given by the
    compilation result object, using the element-wise signature.
    '''
    ctx = cres.target_context
    library = cres.library
    fname = cres.fndesc.llvm_func_name

    with global_compiler_lock:
        info = build_ufunc_wrapper(library, ctx, fname, signature,
                                   cres.objectmode, cres)
        ptr = info.library.get_pointer_to_function(info.name)
    # Get dtypes
    dtypenums = [as_dtype(a).num for a in signature.args]
    dtypenums.append(as_dtype(signature.return_type).num)
    return dtypenums, ptr, cres.environment


_identities = {
    0: _internal.PyUFunc_Zero,
    1: _internal.PyUFunc_One,
    None: _internal.PyUFunc_None,
    "reorderable": _internal.PyUFunc_ReorderableNone,
}


def parse_identity(identity):
    """
    Parse an identity value and return the corresponding low-level value
    for Numpy.
    """
    try:
        identity = _identities[identity]
    except KeyError:
        raise ValueError("Invalid identity value %r" % (identity,))
    return identity
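

# Illustrative mapping (not part of the original module): the `identity`
# kwarg accepted by the ufunc builders maps onto NumPy's low-level identity
# enums defined in `_identities` above, e.g.
#
#     parse_identity(0)              # -> _internal.PyUFunc_Zero (as for np.add)
#     parse_identity(1)              # -> _internal.PyUFunc_One  (as for np.multiply)
#     parse_identity(None)           # -> _internal.PyUFunc_None
#     parse_identity("reorderable")  # -> _internal.PyUFunc_ReorderableNone
#     parse_identity("bogus")        # raises ValueError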


@contextmanager
def _suppress_deprecation_warning_nopython_not_supplied():
    """This suppresses the NumbaDeprecationWarning that occurs through the use
    of `jit` without the `nopython` kwarg. This use of `jit` occurs in a few
    places in the `{g,}ufunc` mechanism in Numba, predominantly to wrap the
    "kernel" function."""
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore',
                                category=NumbaDeprecationWarning,
                                message=(".*The 'nopython' keyword argument "
                                         "was not supplied*"),)
        yield


# Class definitions

class _BaseUFuncBuilder(object):

    def add(self, sig=None):
        if hasattr(self, 'targetoptions'):
            targetoptions = self.targetoptions
        else:
            targetoptions = self.nb_func.targetoptions
        cres, args, return_type = _compile_element_wise_function(
            self.nb_func, targetoptions, sig)
        sig = self._finalize_signature(cres, args, return_type)
        self._sigs.append(sig)
        self._cres[sig] = cres
        return cres

    def disable_compile(self):
        """
        Disable the compilation of new signatures at call time.
        """
        # Override this for implementations that support lazy compilation


class UFuncBuilder(_BaseUFuncBuilder):

    def __init__(self, py_func, identity=None, cache=False, targetoptions=None):
        if targetoptions is None:
            targetoptions = {}
        if is_jitted(py_func):
            py_func = py_func.py_func
        self.py_func = py_func
        self.identity = parse_identity(identity)
        with _suppress_deprecation_warning_nopython_not_supplied():
            self.nb_func = jit(_target='npyufunc',
                               cache=cache,
                               **targetoptions)(py_func)
        self._sigs = []
        self._cres = {}

    def _finalize_signature(self, cres, args, return_type):
        '''Slated for deprecation, use ufuncbuilder._finalize_ufunc_signature()
        instead.
        '''
        return _finalize_ufunc_signature(cres, args, return_type)

    def build_ufunc(self):
        with global_compiler_lock:
            dtypelist = []
            ptrlist = []
            if not self.nb_func:
                raise TypeError("No definition")

            # Get the signatures in the order they were added
            keepalive = []
            cres = None
            for sig in self._sigs:
                cres = self._cres[sig]
                dtypenums, ptr, env = self.build(cres, sig)
                dtypelist.append(dtypenums)
                ptrlist.append(int(ptr))
                keepalive.append((cres.library, env))

            datlist = [None] * len(ptrlist)

            if cres is None:
                argspec = inspect.getfullargspec(self.py_func)
                inct = len(argspec.args)
            else:
                inct = len(cres.signature.args)
            outct = 1

            # Be careful: fromfunc does not provide full error checking yet.
            # If typenum is out of bounds, we have nasty memory corruptions.
            # For instance, -1 for typenum will cause a segfault.
            # If an element of the type list (2nd arg) is a tuple instead,
            # there will also be memory corruption. (Seems to require a code
            # rewrite.)
            ufunc = _internal.fromfunc(
                self.py_func.__name__, self.py_func.__doc__,
                ptrlist, dtypelist, inct, outct, datlist,
                keepalive, self.identity,
            )

            return ufunc

    def build(self, cres, signature):
        '''Slated for deprecation, use
        ufuncbuilder._build_element_wise_ufunc_wrapper().
        '''
        return _build_element_wise_ufunc_wrapper(cres, signature)
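

# Illustrative usage sketch (not part of the original module): this builder
# is the machinery the public @vectorize decorator drives. Building a ufunc
# by hand looks roughly like this (the chosen types are an assumption):
#
#     import numpy as np
#     from numba.core import types
#
#     def add(a, b):
#         return a + b
#
#     builder = UFuncBuilder(add, identity=0)
#     builder.add(types.float64(types.float64, types.float64))
#     my_add = builder.build_ufunc()        # a real np.ufunc
#     my_add(np.arange(3.0), np.ones(3))    # -> array([1., 2., 3.])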


class GUFuncBuilder(_BaseUFuncBuilder):

    # TODO: handle scalars
    def __init__(self, py_func, signature, identity=None, cache=False,
                 targetoptions=None, writable_args=()):
        if targetoptions is None:
            targetoptions = {}
        self.py_func = py_func
        self.identity = parse_identity(identity)
        with _suppress_deprecation_warning_nopython_not_supplied():
            self.nb_func = jit(_target='npyufunc', cache=cache)(py_func)
        self.signature = signature
        self.sin, self.sout = parse_signature(signature)
        self.targetoptions = targetoptions
        self.cache = cache
        self._sigs = []
        self._cres = {}

        transform_arg = _get_transform_arg(py_func)
        self.writable_args = tuple([transform_arg(a) for a in writable_args])

    def _finalize_signature(self, cres, args, return_type):
        if not cres.objectmode and cres.signature.return_type != types.void:
            raise TypeError("gufunc kernel must have void return type")

        if return_type is None:
            return_type = types.void

        return return_type(*args)

    @global_compiler_lock
    def build_ufunc(self):
        type_list = []
        func_list = []
        if not self.nb_func:
            raise TypeError("No definition")

        # Get the signatures in the order they were added
        keepalive = []
        for sig in self._sigs:
            cres = self._cres[sig]
            dtypenums, ptr, env = self.build(cres)
            type_list.append(dtypenums)
            func_list.append(int(ptr))
            keepalive.append((cres.library, env))

        datalist = [None] * len(func_list)

        nin = len(self.sin)
        nout = len(self.sout)

        # Pass envs to fromfuncsig to bind to the lifetime of the ufunc object
        ufunc = _internal.fromfunc(
            self.py_func.__name__, self.py_func.__doc__,
            func_list, type_list, nin, nout, datalist,
            keepalive, self.identity, self.signature, self.writable_args
        )
        return ufunc

    def build(self, cres):
        """
        Returns (dtype numbers, function ptr, EnvironmentObject)
        """
        # Build the wrapper for the ufunc entry point
        signature = cres.signature
        info = build_gufunc_wrapper(
            self.py_func, cres, self.sin, self.sout,
            cache=self.cache, is_parfors=False,
        )

        env = info.env
        ptr = info.library.get_pointer_to_function(info.name)
        # Get dtypes
        dtypenums = []
        for a in signature.args:
            if isinstance(a, types.Array):
                ty = a.dtype
            else:
                ty = a
            dtypenums.append(as_dtype(ty).num)
        return dtypenums, ptr, env


def _get_transform_arg(py_func):
    """Return a function that transforms an argument name or positional
    index into a positional index."""
    args = inspect.getfullargspec(py_func).args
    pos_by_arg = {arg: i for i, arg in enumerate(args)}

    def transform_arg(arg):
        if isinstance(arg, int):
            return arg

        try:
            return pos_by_arg[arg]
        except KeyError:
            msg = (f"Specified writable arg {arg} not found in arg list "
                   f"{args} for function {py_func.__qualname__}")
            raise RuntimeError(msg)

    return transform_arg
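

# Illustrative example (not part of the original module): writable_args may
# name parameters either by position or by name; both normalize to indices.
#
#     def kernel(x, y, out):
#         ...
#
#     transform = _get_transform_arg(kernel)
#     transform('out')   # -> 2
#     transform(1)       # -> 1  (already an index, passed through)
#     transform('oops')  # raises RuntimeError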
Binary file not shown.
@@ -0,0 +1,743 @@
from collections import namedtuple

import numpy as np

from llvmlite.ir import Constant, IRBuilder
from llvmlite import ir

from numba.core import types, cgutils
from numba.core.compiler_lock import global_compiler_lock
from numba.core.caching import make_library_cache, NullCache


_wrapper_info = namedtuple('_wrapper_info', ['library', 'env', 'name'])


def _build_ufunc_loop_body(load, store, context, func, builder, arrays, out,
                           offsets, store_offset, signature, pyapi, env):
    elems = load()

    # Compute
    status, retval = context.call_conv.call_function(builder, func,
                                                     signature.return_type,
                                                     signature.args, elems)

    # Store
    with builder.if_else(status.is_ok, likely=True) as (if_ok, if_error):
        with if_ok:
            store(retval)
        with if_error:
            gil = pyapi.gil_ensure()
            context.call_conv.raise_error(builder, pyapi, status)
            pyapi.gil_release(gil)

    # increment indices
    for off, ary in zip(offsets, arrays):
        builder.store(builder.add(builder.load(off), ary.step), off)

    builder.store(builder.add(builder.load(store_offset), out.step),
                  store_offset)

    return status.code


def _build_ufunc_loop_body_objmode(load, store, context, func, builder,
                                   arrays, out, offsets, store_offset,
                                   signature, env, pyapi):
    elems = load()

    # Compute
    _objargs = [types.pyobject] * len(signature.args)
    # We need to push the error indicator to avoid it messing with
    # the ufunc's execution. We restore it unless the ufunc raised
    # a new error.
    with pyapi.err_push(keep_new=True):
        status, retval = context.call_conv.call_function(builder, func,
                                                         types.pyobject,
                                                         _objargs, elems)
        # Release owned reference to arguments
        for elem in elems:
            pyapi.decref(elem)
        # NOTE: if an error occurred, it will be caught by the Numpy machinery

        # Store
        store(retval)

    # increment indices
    for off, ary in zip(offsets, arrays):
        builder.store(builder.add(builder.load(off), ary.step), off)

    builder.store(builder.add(builder.load(store_offset), out.step),
                  store_offset)

    return status.code


def build_slow_loop_body(context, func, builder, arrays, out, offsets,
                         store_offset, signature, pyapi, env):
    def load():
        elems = [ary.load_direct(builder.load(off))
                 for off, ary in zip(offsets, arrays)]
        return elems

    def store(retval):
        out.store_direct(retval, builder.load(store_offset))

    return _build_ufunc_loop_body(load, store, context, func, builder, arrays,
                                  out, offsets, store_offset, signature, pyapi,
                                  env=env)


def build_obj_loop_body(context, func, builder, arrays, out, offsets,
                        store_offset, signature, pyapi, envptr, env):
    env_body = context.get_env_body(builder, envptr)
    env_manager = pyapi.get_env_manager(env, env_body, envptr)

    def load():
        # Load
        elems = [ary.load_direct(builder.load(off))
                 for off, ary in zip(offsets, arrays)]
        # Box
        elems = [pyapi.from_native_value(t, v, env_manager)
                 for v, t in zip(elems, signature.args)]
        return elems

    def store(retval):
        is_ok = cgutils.is_not_null(builder, retval)
        # If an error is raised by the object mode ufunc, it will
        # simply get caught by the Numpy ufunc machinery.
        with builder.if_then(is_ok, likely=True):
            # Unbox
            native = pyapi.to_native_value(signature.return_type, retval)
            assert native.cleanup is None
            # Store
            out.store_direct(native.value, builder.load(store_offset))
            # Release owned reference
            pyapi.decref(retval)

    return _build_ufunc_loop_body_objmode(load, store, context, func, builder,
                                          arrays, out, offsets, store_offset,
                                          signature, envptr, pyapi)


def build_fast_loop_body(context, func, builder, arrays, out, offsets,
                         store_offset, signature, ind, pyapi, env):
    def load():
        elems = [ary.load_aligned(ind)
                 for ary in arrays]
        return elems

    def store(retval):
        out.store_aligned(retval, ind)

    return _build_ufunc_loop_body(load, store, context, func, builder, arrays,
                                  out, offsets, store_offset, signature, pyapi,
                                  env=env)


def build_ufunc_wrapper(library, context, fname, signature, objmode, cres):
    """
    Wrap the scalar function with a loop that iterates over the arguments

    Returns
    -------
    (library, env, name)
    """
    assert isinstance(fname, str)
    byte_t = ir.IntType(8)
    byte_ptr_t = ir.PointerType(byte_t)
    byte_ptr_ptr_t = ir.PointerType(byte_ptr_t)
    intp_t = context.get_value_type(types.intp)
    intp_ptr_t = ir.PointerType(intp_t)

    fnty = ir.FunctionType(ir.VoidType(), [byte_ptr_ptr_t, intp_ptr_t,
                                           intp_ptr_t, byte_ptr_t])

    wrapperlib = context.codegen().create_library('ufunc_wrapper')
    wrapper_module = wrapperlib.create_ir_module('')
    if objmode:
        func_type = context.call_conv.get_function_type(
            types.pyobject, [types.pyobject] * len(signature.args))
    else:
        func_type = context.call_conv.get_function_type(
            signature.return_type, signature.args)

    func = ir.Function(wrapper_module, func_type, name=fname)
    func.attributes.add("alwaysinline")

    wrapper = ir.Function(wrapper_module, fnty, "__ufunc__." + func.name)
    arg_args, arg_dims, arg_steps, arg_data = wrapper.args
    arg_args.name = "args"
    arg_dims.name = "dims"
    arg_steps.name = "steps"
    arg_data.name = "data"

    builder = IRBuilder(wrapper.append_basic_block("entry"))

    # Prepare Environment
    envname = context.get_env_name(cres.fndesc)
    env = cres.environment
    envptr = builder.load(context.declare_env_global(builder.module, envname))

    # Emit loop
    loopcount = builder.load(arg_dims, name="loopcount")

    # Prepare inputs
    arrays = []
    for i, typ in enumerate(signature.args):
        arrays.append(UArrayArg(context, builder, arg_args, arg_steps, i, typ))

    # Prepare output
    out = UArrayArg(context, builder, arg_args, arg_steps, len(arrays),
                    signature.return_type)

    # Setup indices
    offsets = []
    zero = context.get_constant(types.intp, 0)
    for _ in arrays:
        p = cgutils.alloca_once(builder, intp_t)
        offsets.append(p)
        builder.store(zero, p)

    store_offset = cgutils.alloca_once(builder, intp_t)
    builder.store(zero, store_offset)

    unit_strided = cgutils.true_bit
    for ary in arrays:
        unit_strided = builder.and_(unit_strided, ary.is_unit_strided)

    pyapi = context.get_python_api(builder)
    if objmode:
        # General loop
        gil = pyapi.gil_ensure()
        with cgutils.for_range(builder, loopcount, intp=intp_t):
            build_obj_loop_body(
                context, func, builder, arrays, out, offsets,
                store_offset, signature, pyapi, envptr, env,
            )
        pyapi.gil_release(gil)
        builder.ret_void()

    else:
        with builder.if_else(unit_strided) as (is_unit_strided, is_strided):
            with is_unit_strided:
                with cgutils.for_range(builder, loopcount, intp=intp_t) as loop:
                    build_fast_loop_body(
                        context, func, builder, arrays, out, offsets,
                        store_offset, signature, loop.index, pyapi,
                        env=envptr,
                    )

            with is_strided:
                # General loop
                with cgutils.for_range(builder, loopcount, intp=intp_t):
                    build_slow_loop_body(
                        context, func, builder, arrays, out, offsets,
                        store_offset, signature, pyapi,
                        env=envptr,
                    )

        builder.ret_void()
    del builder

    # Link and finalize
    wrapperlib.add_ir_module(wrapper_module)
    wrapperlib.add_linking_library(library)
    return _wrapper_info(library=wrapperlib, env=env, name=wrapper.name)
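

# For orientation (an assumption-level sketch, not part of the original
# module): the generated "__ufunc__.*" wrapper has the standard NumPy 1-d
# loop signature `void loop(char **args, npy_intp *dims, npy_intp *steps,
# void *data)`. In pure Python, the emitted code does roughly:
#
#     def loop(args, dims, steps, data):    # args[i]: base pointer of operand i
#         n = dims[0]                       # number of elements to process
#         for k in range(n):
#             elems = [load(args[i] + k * steps[i]) for i in range(nin)]
#             store(args[nin] + k * steps[nin], scalar_kernel(*elems))
#
# with the fast path taken when every step equals the itemsize
# (unit-strided), which lets LLVM vectorize the loop.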


class UArrayArg(object):
    def __init__(self, context, builder, args, steps, i, fe_type):
        self.context = context
        self.builder = builder
        self.fe_type = fe_type
        offset = self.context.get_constant(types.intp, i)
        offseted_args = self.builder.load(builder.gep(args, [offset]))
        data_type = context.get_data_type(fe_type)
        self.dataptr = self.builder.bitcast(offseted_args,
                                            data_type.as_pointer())
        sizeof = self.context.get_abi_sizeof(data_type)
        self.abisize = self.context.get_constant(types.intp, sizeof)
        offseted_step = self.builder.gep(steps, [offset])
        self.step = self.builder.load(offseted_step)
        self.is_unit_strided = builder.icmp_unsigned('==',
                                                     self.abisize, self.step)
        self.builder = builder

    def load_direct(self, byteoffset):
        """
        Generic load from the given *byteoffset*. load_aligned() is
        preferred if possible.
        """
        ptr = cgutils.pointer_add(self.builder, self.dataptr, byteoffset)
        return self.context.unpack_value(self.builder, self.fe_type, ptr)

    def load_aligned(self, ind):
        # Using gep() instead of explicit pointer addition helps LLVM
        # vectorize the loop.
        ptr = self.builder.gep(self.dataptr, [ind])
        return self.context.unpack_value(self.builder, self.fe_type, ptr)

    def store_direct(self, value, byteoffset):
        ptr = cgutils.pointer_add(self.builder, self.dataptr, byteoffset)
        self.context.pack_value(self.builder, self.fe_type, value, ptr)

    def store_aligned(self, value, ind):
        ptr = self.builder.gep(self.dataptr, [ind])
        self.context.pack_value(self.builder, self.fe_type, value, ptr)


GufWrapperCache = make_library_cache('guf')


class _GufuncWrapper(object):
    def __init__(self, py_func, cres, sin, sout, cache, is_parfors):
        """
        The *is_parfors* argument is a boolean that indicates if the GUfunc
        being built is to be used as a ParFors kernel. If True, it disables
        the caching on the wrapper as a separate unit because it will be
        linked into the caller function and cached along with it.
        """
        self.py_func = py_func
        self.cres = cres
        self.sin = sin
        self.sout = sout
        self.is_objectmode = self.signature.return_type == types.pyobject
        self.cache = (GufWrapperCache(py_func=self.py_func)
                      if cache else NullCache())
        self.is_parfors = bool(is_parfors)

    @property
    def library(self):
        return self.cres.library

    @property
    def context(self):
        return self.cres.target_context

    @property
    def call_conv(self):
        return self.context.call_conv

    @property
    def signature(self):
        return self.cres.signature

    @property
    def fndesc(self):
        return self.cres.fndesc

    @property
    def env(self):
        return self.cres.environment

    def _wrapper_function_type(self):
        byte_t = ir.IntType(8)
        byte_ptr_t = ir.PointerType(byte_t)
        byte_ptr_ptr_t = ir.PointerType(byte_ptr_t)
        intp_t = self.context.get_value_type(types.intp)
        intp_ptr_t = ir.PointerType(intp_t)

        fnty = ir.FunctionType(ir.VoidType(), [byte_ptr_ptr_t, intp_ptr_t,
                                               intp_ptr_t, byte_ptr_t])
        return fnty

    def _build_wrapper(self, library, name):
        """
        The LLVM IRBuilder code to create the gufunc wrapper.
        The *library* arg is the CodeLibrary to which the wrapper should
        be added. The *name* arg is the name of the wrapper function being
        created.
        """
        intp_t = self.context.get_value_type(types.intp)
        fnty = self._wrapper_function_type()

        wrapper_module = library.create_ir_module('_gufunc_wrapper')
        func_type = self.call_conv.get_function_type(self.fndesc.restype,
                                                     self.fndesc.argtypes)
        fname = self.fndesc.llvm_func_name
        func = ir.Function(wrapper_module, func_type, name=fname)

        func.attributes.add("alwaysinline")
        wrapper = ir.Function(wrapper_module, fnty, name)
        # The use of weak_odr linkage avoids the function being dropped due
        # to the order in which the wrappers and the user function are linked.
        wrapper.linkage = 'weak_odr'
        arg_args, arg_dims, arg_steps, arg_data = wrapper.args
        arg_args.name = "args"
        arg_dims.name = "dims"
        arg_steps.name = "steps"
        arg_data.name = "data"

        builder = IRBuilder(wrapper.append_basic_block("entry"))
        loopcount = builder.load(arg_dims, name="loopcount")
        pyapi = self.context.get_python_api(builder)

        # Unpack shapes
        unique_syms = set()
        for grp in (self.sin, self.sout):
            for syms in grp:
                unique_syms |= set(syms)

        sym_map = {}
        for syms in self.sin:
            for s in syms:
                if s not in sym_map:
                    sym_map[s] = len(sym_map)

        sym_dim = {}
        for s, i in sym_map.items():
            sym_dim[s] = builder.load(builder.gep(arg_dims,
                                                  [self.context.get_constant(
                                                      types.intp,
                                                      i + 1)]))

        # Prepare inputs
        arrays = []
        step_offset = len(self.sin) + len(self.sout)
        for i, (typ, sym) in enumerate(zip(self.signature.args,
                                           self.sin + self.sout)):
            ary = GUArrayArg(self.context, builder, arg_args,
                             arg_steps, i, step_offset, typ, sym, sym_dim)
            step_offset += len(sym)
            arrays.append(ary)

        bbreturn = builder.append_basic_block('.return')

        # Prologue
        self.gen_prologue(builder, pyapi)

        # Loop
        with cgutils.for_range(builder, loopcount, intp=intp_t) as loop:
            args = [a.get_array_at_offset(loop.index) for a in arrays]
            innercall, error = self.gen_loop_body(builder, pyapi, func, args)
            # If there was an error, escape
            cgutils.cbranch_or_continue(builder, error, bbreturn)

        builder.branch(bbreturn)
        builder.position_at_end(bbreturn)

        # Epilogue
        self.gen_epilogue(builder, pyapi)

        builder.ret_void()

        # Link
        library.add_ir_module(wrapper_module)
        library.add_linking_library(self.library)

    def _compile_wrapper(self, wrapper_name):
        # Was this gufunc created by ParFors?
        if self.is_parfors:
            # No wrapper caching for parfors
            wrapperlib = self.context.codegen().create_library(str(self))
            # Build the wrapper
            self._build_wrapper(wrapperlib, wrapper_name)
        # Non-parfors?
        else:
            # Use cache and compiler in a critical section
            wrapperlib = self.cache.load_overload(
                self.cres.signature, self.cres.target_context,
            )
            if wrapperlib is None:
                # Create a library and enable caching
                wrapperlib = self.context.codegen().create_library(str(self))
                wrapperlib.enable_object_caching()
                # Build the wrapper
                self._build_wrapper(wrapperlib, wrapper_name)
                # Cache
                self.cache.save_overload(self.cres.signature, wrapperlib)

        return wrapperlib

    @global_compiler_lock
    def build(self):
        wrapper_name = "__gufunc__." + self.fndesc.mangled_name
        wrapperlib = self._compile_wrapper(wrapper_name)
        return _wrapper_info(
            library=wrapperlib, env=self.env, name=wrapper_name,
        )

    def gen_loop_body(self, builder, pyapi, func, args):
        status, retval = self.call_conv.call_function(
            builder, func, self.signature.return_type, self.signature.args,
            args)

        with builder.if_then(status.is_error, likely=False):
            gil = pyapi.gil_ensure()
            self.context.call_conv.raise_error(builder, pyapi, status)
            pyapi.gil_release(gil)

        return status.code, status.is_error

    def gen_prologue(self, builder, pyapi):
        pass  # Do nothing

    def gen_epilogue(self, builder, pyapi):
        pass  # Do nothing


class _GufuncObjectWrapper(_GufuncWrapper):
    def gen_loop_body(self, builder, pyapi, func, args):
        innercall, error = _prepare_call_to_object_mode(self.context,
                                                        builder, pyapi, func,
                                                        self.signature,
                                                        args)
        return innercall, error

    def gen_prologue(self, builder, pyapi):
        # Acquire the GIL
        self.gil = pyapi.gil_ensure()

    def gen_epilogue(self, builder, pyapi):
        # Release the GIL
        pyapi.gil_release(self.gil)


def build_gufunc_wrapper(py_func, cres, sin, sout, cache, is_parfors):
    signature = cres.signature
    wrapcls = (_GufuncObjectWrapper
               if signature.return_type == types.pyobject
               else _GufuncWrapper)
    return wrapcls(
        py_func, cres, sin, sout, cache, is_parfors=is_parfors,
    ).build()


def _prepare_call_to_object_mode(context, builder, pyapi, func,
                                 signature, args):
    mod = builder.module

    bb_core_return = builder.append_basic_block('ufunc.core.return')

    # Call to
    # PyObject* ndarray_new(int nd,
    #                       npy_intp *dims,   /* shape */
    #                       npy_intp *strides,
    #                       void* data,
    #                       int type_num,
    #                       int itemsize)

    ll_int = context.get_value_type(types.int32)
    ll_intp = context.get_value_type(types.intp)
    ll_intp_ptr = ir.PointerType(ll_intp)
    ll_voidptr = context.get_value_type(types.voidptr)
    ll_pyobj = context.get_value_type(types.pyobject)
    fnty = ir.FunctionType(ll_pyobj, [ll_int, ll_intp_ptr,
                                      ll_intp_ptr, ll_voidptr,
                                      ll_int, ll_int])

    fn_array_new = cgutils.get_or_insert_function(mod, fnty,
                                                  "numba_ndarray_new")

    # Convert each llarray into a pyobject
    error_pointer = cgutils.alloca_once(builder, ir.IntType(1), name='error')
    builder.store(cgutils.true_bit, error_pointer)

    # The PyObject* arguments to the kernel function
    object_args = []
    object_pointers = []

    for i, (arg, argty) in enumerate(zip(args, signature.args)):
        # Allocate a NULL-initialized slot for this argument
        objptr = cgutils.alloca_once(builder, ll_pyobj, zfill=True)
        object_pointers.append(objptr)

        if isinstance(argty, types.Array):
            # Special case arrays: we don't need full-blown NRT reflection
            # since the argument will be gone at the end of the kernel
            arycls = context.make_array(argty)
            array = arycls(context, builder, value=arg)

            zero = Constant(ll_int, 0)

            # Extract members of the llarray
            nd = Constant(ll_int, argty.ndim)
            dims = builder.gep(array._get_ptr_by_name('shape'), [zero, zero])
            strides = builder.gep(array._get_ptr_by_name('strides'),
                                  [zero, zero])
            data = builder.bitcast(array.data, ll_voidptr)
            dtype = np.dtype(str(argty.dtype))

            # Prepare other info for reconstruction of the PyArray
            type_num = Constant(ll_int, dtype.num)
            itemsize = Constant(ll_int, dtype.itemsize)

            # Call the helper to reconstruct the PyArray objects
            obj = builder.call(fn_array_new, [nd, dims, strides, data,
                                              type_num, itemsize])
        else:
            # Other argument types => use generic boxing
            obj = pyapi.from_native_value(argty, arg)

        builder.store(obj, objptr)
        object_args.append(obj)

        obj_is_null = cgutils.is_null(builder, obj)
        builder.store(obj_is_null, error_pointer)
        cgutils.cbranch_or_continue(builder, obj_is_null, bb_core_return)

    # Call the ufunc core function
    object_sig = [types.pyobject] * len(object_args)

    status, retval = context.call_conv.call_function(
        builder, func, types.pyobject, object_sig,
        object_args)
    builder.store(status.is_error, error_pointer)

    # Release the returned object
    pyapi.decref(retval)

    builder.branch(bb_core_return)
    # At the return block
    builder.position_at_end(bb_core_return)

    # Release the argument objects
    for objptr in object_pointers:
        pyapi.decref(builder.load(objptr))

    innercall = status.code
    return innercall, builder.load(error_pointer)


class GUArrayArg(object):
    def __init__(self, context, builder, args, steps, i, step_offset,
                 typ, syms, sym_dim):

        self.context = context
        self.builder = builder

        offset = context.get_constant(types.intp, i)

        data = builder.load(builder.gep(args, [offset], name="data.ptr"),
                            name="data")
        self.data = data

        core_step_ptr = builder.gep(steps, [offset], name="core.step.ptr")
        core_step = builder.load(core_step_ptr)

        if isinstance(typ, types.Array):
            as_scalar = not syms

            # The number of symbols in the shape spec should match the
            # dimensionality of the array type.
            if len(syms) != typ.ndim:
                if len(syms) == 0 and typ.ndim == 1:
                    # This is an exception for handling scalar arguments.
                    # The type can be a 1D array for a scalar.
                    # In the future, we may deprecate this exception.
                    pass
                else:
                    raise TypeError("type and shape signature mismatch for arg "
                                    "#{0}".format(i + 1))

            ndim = typ.ndim
            shape = [sym_dim[s] for s in syms]
            strides = []

            for j in range(ndim):
                stepptr = builder.gep(steps,
                                      [context.get_constant(types.intp,
                                                            step_offset + j)],
                                      name="step.ptr")
                step = builder.load(stepptr)
                strides.append(step)

            ldcls = (_ArrayAsScalarArgLoader
                     if as_scalar
                     else _ArrayArgLoader)

            self._loader = ldcls(dtype=typ.dtype,
                                 ndim=ndim,
                                 core_step=core_step,
                                 as_scalar=as_scalar,
                                 shape=shape,
                                 strides=strides)
        else:
            # If typ is not an array
            if syms:
                raise TypeError("scalar type {0} given for non scalar "
                                "argument #{1}".format(typ, i + 1))
            self._loader = _ScalarArgLoader(dtype=typ, stride=core_step)

    def get_array_at_offset(self, ind):
        return self._loader.load(context=self.context, builder=self.builder,
                                 data=self.data, ind=ind)


class _ScalarArgLoader(object):
    """
    Handle GUFunc argument loading where a scalar type is used in the core
    function.
    Note: It still has a stride because the input to the gufunc can be an
    array for this argument.
    """

    def __init__(self, dtype, stride):
        self.dtype = dtype
        self.stride = stride

    def load(self, context, builder, data, ind):
        # Load at base + ind * stride
        data = builder.gep(data, [builder.mul(ind, self.stride)])
        dptr = builder.bitcast(data,
                               context.get_data_type(self.dtype).as_pointer())
        return builder.load(dptr)


class _ArrayArgLoader(object):
    """
    Handle GUFunc argument loading where an array is expected.
    """

    def __init__(self, dtype, ndim, core_step, as_scalar, shape, strides):
        self.dtype = dtype
        self.ndim = ndim
        self.core_step = core_step
        self.as_scalar = as_scalar
        self.shape = shape
        self.strides = strides

    def load(self, context, builder, data, ind):
        arytyp = types.Array(dtype=self.dtype, ndim=self.ndim, layout="A")
        arycls = context.make_array(arytyp)

        array = arycls(context, builder)
        offseted_data = cgutils.pointer_add(builder,
                                            data,
                                            builder.mul(self.core_step,
                                                        ind))

        shape, strides = self._shape_and_strides(context, builder)

        itemsize = context.get_abi_sizeof(context.get_data_type(self.dtype))
        context.populate_array(array,
                               data=builder.bitcast(offseted_data,
                                                    array.data.type),
                               shape=shape,
                               strides=strides,
                               itemsize=context.get_constant(types.intp,
                                                             itemsize),
                               meminfo=None)

        return array._getvalue()

    def _shape_and_strides(self, context, builder):
        shape = cgutils.pack_array(builder, self.shape)
        strides = cgutils.pack_array(builder, self.strides)
        return shape, strides


class _ArrayAsScalarArgLoader(_ArrayArgLoader):
    """
    Handle GUFunc argument loading where the shape signature specifies
    a scalar "()" but a 1D array is used for the type of the core function.
    """

    def _shape_and_strides(self, context, builder):
        # Set shape and strides for a 1D size-1 array
        one = context.get_constant(types.intp, 1)
        zero = context.get_constant(types.intp, 0)
        shape = cgutils.pack_array(builder, [one])
        strides = cgutils.pack_array(builder, [zero])
        return shape, strides
1211
linedance-app/venv/lib/python3.12/site-packages/numba/np/ufunc_db.py
Normal file
File diff suppressed because it is too large
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,79 @@
"""
This file provides internal compiler utilities that support certain special
operations with numpy.
"""
from numba.core import types, typing
from numba.core.cgutils import unpack_tuple
from numba.core.extending import intrinsic
from numba.core.imputils import impl_ret_new_ref
from numba.core.errors import RequireLiteralValue, TypingError

from numba.cpython.unsafe.tuple import tuple_setitem


@intrinsic
def empty_inferred(typingctx, shape):
    """A version of numpy.empty whose dtype is inferred by the type system.

    Expects `shape` to be an int-tuple.

    There is special logic in the type-inferencer to handle the "refine"-ing
    of an undefined dtype.
    """
    from numba.np.arrayobj import _empty_nd_impl

    def codegen(context, builder, signature, args):
        # Check that the return type is now defined
        arrty = signature.return_type
        assert arrty.is_precise()
        shapes = unpack_tuple(builder, args[0])
        # Redirect the implementation to np.empty
        res = _empty_nd_impl(context, builder, arrty, shapes)
        return impl_ret_new_ref(context, builder, arrty, res._getvalue())

    # Make the function signature
    nd = len(shape)
    array_ty = types.Array(ndim=nd, layout='C', dtype=types.undefined)
    sig = array_ty(shape)
    return sig, codegen


@intrinsic
def to_fixed_tuple(typingctx, array, length):
    """Convert *array* into a tuple of *length*.

    Returns ``UniTuple(array.dtype, length)``

    **Warning**
    - No bounds checking.
      If *length* is longer than *array.size*, the behavior is undefined.
    """
    if not isinstance(length, types.IntegerLiteral):
        raise RequireLiteralValue('*length* argument must be a constant')

    if array.ndim != 1:
        raise TypingError("Not supported on array.ndim={}".format(array.ndim))

    # Determine the types
    tuple_size = int(length.literal_value)
    tuple_type = types.UniTuple(dtype=array.dtype, count=tuple_size)
    sig = tuple_type(array, length)

    def codegen(context, builder, signature, args):
        def impl(array, length, empty_tuple):
            out = empty_tuple
            for i in range(length):
                out = tuple_setitem(out, i, array[i])
            return out

        inner_argtypes = [signature.args[0], types.intp, tuple_type]
        inner_sig = typing.signature(tuple_type, *inner_argtypes)
        ll_idx_type = context.get_value_type(types.intp)
        # Allocate an empty tuple
        empty_tuple = context.get_constant_undef(tuple_type)
        inner_args = [args[0], ll_idx_type(tuple_size), empty_tuple]

        res = context.compile_internal(builder, impl, inner_sig, inner_args)
        return res

    return sig, codegen
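

# Illustrative usage sketch (not part of the original module; the import
# path shown is an assumption): converting a 1D array prefix into a
# fixed-size tuple inside a jitted function. The length must be a
# compile-time constant (an integer literal).
#
#     import numpy as np
#     from numba import njit
#     from numba.np.unsafe.ndarray import to_fixed_tuple
#
#     @njit
#     def first_three(a):
#         return to_fixed_tuple(a, 3)   # UniTuple(a.dtype, 3)
#
#     first_three(np.arange(10.0))      # -> (0.0, 1.0, 2.0)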