Videre
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,542 @@
|
||||
"""
|
||||
Implement the cmath module functions.
|
||||
"""
|
||||
|
||||
|
||||
import cmath
|
||||
import math
|
||||
|
||||
from numba.core.imputils import Registry, impl_ret_untracked
|
||||
from numba.core import types, cgutils
|
||||
from numba.core.typing import signature
|
||||
from numba.cpython import builtins, mathimpl
|
||||
from numba.core.extending import overload
|
||||
|
||||
registry = Registry('cmathimpl')
|
||||
lower = registry.lower
|
||||
|
||||
|
||||
def is_nan(builder, z):
|
||||
return builder.fcmp_unordered('uno', z.real, z.imag)
|
||||
|
||||
def is_inf(builder, z):
|
||||
return builder.or_(mathimpl.is_inf(builder, z.real),
|
||||
mathimpl.is_inf(builder, z.imag))
|
||||
|
||||
def is_finite(builder, z):
|
||||
return builder.and_(mathimpl.is_finite(builder, z.real),
|
||||
mathimpl.is_finite(builder, z.imag))
|
||||
|
||||
|
||||
@lower(cmath.isnan, types.Complex)
|
||||
def isnan_float_impl(context, builder, sig, args):
|
||||
[typ] = sig.args
|
||||
[value] = args
|
||||
z = context.make_complex(builder, typ, value=value)
|
||||
res = is_nan(builder, z)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
@lower(cmath.isinf, types.Complex)
|
||||
def isinf_float_impl(context, builder, sig, args):
|
||||
[typ] = sig.args
|
||||
[value] = args
|
||||
z = context.make_complex(builder, typ, value=value)
|
||||
res = is_inf(builder, z)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower(cmath.isfinite, types.Complex)
|
||||
def isfinite_float_impl(context, builder, sig, args):
|
||||
[typ] = sig.args
|
||||
[value] = args
|
||||
z = context.make_complex(builder, typ, value=value)
|
||||
res = is_finite(builder, z)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@overload(cmath.rect)
|
||||
def impl_cmath_rect(r, phi):
|
||||
if all([isinstance(typ, types.Float) for typ in [r, phi]]):
|
||||
def impl(r, phi):
|
||||
if not math.isfinite(phi):
|
||||
if not r:
|
||||
# cmath.rect(0, phi={inf, nan}) = 0
|
||||
return abs(r)
|
||||
if math.isinf(r):
|
||||
# cmath.rect(inf, phi={inf, nan}) = inf + j phi
|
||||
return complex(r, phi)
|
||||
real = math.cos(phi)
|
||||
imag = math.sin(phi)
|
||||
if real == 0. and math.isinf(r):
|
||||
# 0 * inf would return NaN, we want to keep 0 but xor the sign
|
||||
real /= r
|
||||
else:
|
||||
real *= r
|
||||
if imag == 0. and math.isinf(r):
|
||||
# ditto
|
||||
imag /= r
|
||||
else:
|
||||
imag *= r
|
||||
return complex(real, imag)
|
||||
return impl
|
||||
|
||||
|
||||
def intrinsic_complex_unary(inner_func):
|
||||
def wrapper(context, builder, sig, args):
|
||||
[typ] = sig.args
|
||||
[value] = args
|
||||
z = context.make_complex(builder, typ, value=value)
|
||||
x = z.real
|
||||
y = z.imag
|
||||
# Same as above: math.isfinite() is unavailable on 2.x so we precompute
|
||||
# its value and pass it to the pure Python implementation.
|
||||
x_is_finite = mathimpl.is_finite(builder, x)
|
||||
y_is_finite = mathimpl.is_finite(builder, y)
|
||||
inner_sig = signature(sig.return_type,
|
||||
*(typ.underlying_float,) * 2 + (types.boolean,) * 2)
|
||||
res = context.compile_internal(builder, inner_func, inner_sig,
|
||||
(x, y, x_is_finite, y_is_finite))
|
||||
return impl_ret_untracked(context, builder, sig, res)
|
||||
return wrapper
|
||||
|
||||
|
||||
NAN = float('nan')
|
||||
INF = float('inf')
|
||||
|
||||
@lower(cmath.exp, types.Complex)
|
||||
@intrinsic_complex_unary
|
||||
def exp_impl(x, y, x_is_finite, y_is_finite):
|
||||
"""cmath.exp(x + y j)"""
|
||||
if x_is_finite:
|
||||
if y_is_finite:
|
||||
c = math.cos(y)
|
||||
s = math.sin(y)
|
||||
r = math.exp(x)
|
||||
return complex(r * c, r * s)
|
||||
else:
|
||||
return complex(NAN, NAN)
|
||||
elif math.isnan(x):
|
||||
if y:
|
||||
return complex(x, x) # nan + j nan
|
||||
else:
|
||||
return complex(x, y) # nan + 0j
|
||||
elif x > 0.0:
|
||||
# x == +inf
|
||||
if y_is_finite:
|
||||
real = math.cos(y)
|
||||
imag = math.sin(y)
|
||||
# Avoid NaNs if math.cos(y) or math.sin(y) == 0
|
||||
# (e.g. cmath.exp(inf + 0j) == inf + 0j)
|
||||
if real != 0:
|
||||
real *= x
|
||||
if imag != 0:
|
||||
imag *= x
|
||||
return complex(real, imag)
|
||||
else:
|
||||
return complex(x, NAN)
|
||||
else:
|
||||
# x == -inf
|
||||
if y_is_finite:
|
||||
r = math.exp(x)
|
||||
c = math.cos(y)
|
||||
s = math.sin(y)
|
||||
return complex(r * c, r * s)
|
||||
else:
|
||||
r = 0
|
||||
return complex(r, r)
|
||||
|
||||
@lower(cmath.log, types.Complex)
|
||||
@intrinsic_complex_unary
|
||||
def log_impl(x, y, x_is_finite, y_is_finite):
|
||||
"""cmath.log(x + y j)"""
|
||||
a = math.log(math.hypot(x, y))
|
||||
b = math.atan2(y, x)
|
||||
return complex(a, b)
|
||||
|
||||
|
||||
@lower(cmath.log, types.Complex, types.Complex)
|
||||
def log_base_impl(context, builder, sig, args):
|
||||
"""cmath.log(z, base)"""
|
||||
[z, base] = args
|
||||
|
||||
def log_base(z, base):
|
||||
return cmath.log(z) / cmath.log(base)
|
||||
|
||||
res = context.compile_internal(builder, log_base, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig, res)
|
||||
|
||||
|
||||
@overload(cmath.log10)
|
||||
def impl_cmath_log10(z):
|
||||
if not isinstance(z, types.Complex):
|
||||
return
|
||||
|
||||
LN_10 = 2.302585092994045684
|
||||
|
||||
def log10_impl(z):
|
||||
"""cmath.log10(z)"""
|
||||
z = cmath.log(z)
|
||||
# This formula gives better results on +/-inf than cmath.log(z, 10)
|
||||
# See http://bugs.python.org/issue22544
|
||||
return complex(z.real / LN_10, z.imag / LN_10)
|
||||
|
||||
return log10_impl
|
||||
|
||||
|
||||
@overload(cmath.phase)
|
||||
def phase_impl(x):
|
||||
"""cmath.phase(x + y j)"""
|
||||
|
||||
if not isinstance(x, types.Complex):
|
||||
return
|
||||
|
||||
def impl(x):
|
||||
return math.atan2(x.imag, x.real)
|
||||
return impl
|
||||
|
||||
|
||||
@overload(cmath.polar)
|
||||
def polar_impl(x):
|
||||
if not isinstance(x, types.Complex):
|
||||
return
|
||||
|
||||
def impl(x):
|
||||
r, i = x.real, x.imag
|
||||
return math.hypot(r, i), math.atan2(i, r)
|
||||
return impl
|
||||
|
||||
|
||||
@lower(cmath.sqrt, types.Complex)
|
||||
def sqrt_impl(context, builder, sig, args):
|
||||
# We risk spurious overflow for components >= FLT_MAX / (1 + sqrt(2)).
|
||||
|
||||
SQRT2 = 1.414213562373095048801688724209698079E0
|
||||
ONE_PLUS_SQRT2 = (1. + SQRT2)
|
||||
theargflt = sig.args[0].underlying_float
|
||||
# Get a type specific maximum value so scaling for overflow is based on that
|
||||
MAX = mathimpl.DBL_MAX if theargflt.bitwidth == 64 else mathimpl.FLT_MAX
|
||||
# THRES will be double precision, should not impact typing as it's just
|
||||
# used for comparison, there *may* be a few values near THRES which
|
||||
# deviate from e.g. NumPy due to rounding that occurs in the computation
|
||||
# of this value in the case of a 32bit argument.
|
||||
THRES = MAX / ONE_PLUS_SQRT2
|
||||
|
||||
def sqrt_impl(z):
|
||||
"""cmath.sqrt(z)"""
|
||||
# This is NumPy's algorithm, see npy_csqrt() in npy_math_complex.c.src
|
||||
a = z.real
|
||||
b = z.imag
|
||||
if a == 0.0 and b == 0.0:
|
||||
return complex(abs(b), b)
|
||||
if math.isinf(b):
|
||||
return complex(abs(b), b)
|
||||
if math.isnan(a):
|
||||
return complex(a, a)
|
||||
if math.isinf(a):
|
||||
if a < 0.0:
|
||||
return complex(abs(b - b), math.copysign(a, b))
|
||||
else:
|
||||
return complex(a, math.copysign(b - b, b))
|
||||
|
||||
# The remaining special case (b is NaN) is handled just fine by
|
||||
# the normal code path below.
|
||||
|
||||
# Scale to avoid overflow
|
||||
if abs(a) >= THRES or abs(b) >= THRES:
|
||||
a *= 0.25
|
||||
b *= 0.25
|
||||
scale = True
|
||||
else:
|
||||
scale = False
|
||||
# Algorithm 312, CACM vol 10, Oct 1967
|
||||
if a >= 0:
|
||||
t = math.sqrt((a + math.hypot(a, b)) * 0.5)
|
||||
real = t
|
||||
imag = b / (2 * t)
|
||||
else:
|
||||
t = math.sqrt((-a + math.hypot(a, b)) * 0.5)
|
||||
real = abs(b) / (2 * t)
|
||||
imag = math.copysign(t, b)
|
||||
# Rescale
|
||||
if scale:
|
||||
return complex(real * 2, imag)
|
||||
else:
|
||||
return complex(real, imag)
|
||||
|
||||
res = context.compile_internal(builder, sqrt_impl, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig, res)
|
||||
|
||||
|
||||
@lower(cmath.cos, types.Complex)
|
||||
def cos_impl(context, builder, sig, args):
|
||||
def cos_impl(z):
|
||||
"""cmath.cos(z) = cmath.cosh(z j)"""
|
||||
return cmath.cosh(complex(-z.imag, z.real))
|
||||
|
||||
res = context.compile_internal(builder, cos_impl, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig, res)
|
||||
|
||||
@overload(cmath.cosh)
|
||||
def impl_cmath_cosh(z):
|
||||
if not isinstance(z, types.Complex):
|
||||
return
|
||||
|
||||
def cosh_impl(z):
|
||||
"""cmath.cosh(z)"""
|
||||
x = z.real
|
||||
y = z.imag
|
||||
if math.isinf(x):
|
||||
if math.isnan(y):
|
||||
# x = +inf, y = NaN => cmath.cosh(x + y j) = inf + Nan * j
|
||||
real = abs(x)
|
||||
imag = y
|
||||
elif y == 0.0:
|
||||
# x = +inf, y = 0 => cmath.cosh(x + y j) = inf + 0j
|
||||
real = abs(x)
|
||||
imag = y
|
||||
else:
|
||||
real = math.copysign(x, math.cos(y))
|
||||
imag = math.copysign(x, math.sin(y))
|
||||
if x < 0.0:
|
||||
# x = -inf => negate imaginary part of result
|
||||
imag = -imag
|
||||
return complex(real, imag)
|
||||
return complex(math.cos(y) * math.cosh(x),
|
||||
math.sin(y) * math.sinh(x))
|
||||
return cosh_impl
|
||||
|
||||
|
||||
@lower(cmath.sin, types.Complex)
|
||||
def sin_impl(context, builder, sig, args):
|
||||
def sin_impl(z):
|
||||
"""cmath.sin(z) = -j * cmath.sinh(z j)"""
|
||||
r = cmath.sinh(complex(-z.imag, z.real))
|
||||
return complex(r.imag, -r.real)
|
||||
|
||||
res = context.compile_internal(builder, sin_impl, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig, res)
|
||||
|
||||
@overload(cmath.sinh)
|
||||
def impl_cmath_sinh(z):
|
||||
if not isinstance(z, types.Complex):
|
||||
return
|
||||
|
||||
def sinh_impl(z):
|
||||
"""cmath.sinh(z)"""
|
||||
x = z.real
|
||||
y = z.imag
|
||||
if math.isinf(x):
|
||||
if math.isnan(y):
|
||||
# x = +/-inf, y = NaN => cmath.sinh(x + y j) = x + NaN * j
|
||||
real = x
|
||||
imag = y
|
||||
else:
|
||||
real = math.cos(y)
|
||||
imag = math.sin(y)
|
||||
if real != 0.:
|
||||
real *= x
|
||||
if imag != 0.:
|
||||
imag *= abs(x)
|
||||
return complex(real, imag)
|
||||
return complex(math.cos(y) * math.sinh(x),
|
||||
math.sin(y) * math.cosh(x))
|
||||
return sinh_impl
|
||||
|
||||
|
||||
@lower(cmath.tan, types.Complex)
|
||||
def tan_impl(context, builder, sig, args):
|
||||
def tan_impl(z):
|
||||
"""cmath.tan(z) = -j * cmath.tanh(z j)"""
|
||||
r = cmath.tanh(complex(-z.imag, z.real))
|
||||
return complex(r.imag, -r.real)
|
||||
|
||||
res = context.compile_internal(builder, tan_impl, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig, res)
|
||||
|
||||
|
||||
@overload(cmath.tanh)
|
||||
def impl_cmath_tanh(z):
|
||||
if not isinstance(z, types.Complex):
|
||||
return
|
||||
|
||||
def tanh_impl(z):
|
||||
"""cmath.tanh(z)"""
|
||||
x = z.real
|
||||
y = z.imag
|
||||
if math.isinf(x):
|
||||
real = math.copysign(1., x)
|
||||
if math.isinf(y):
|
||||
imag = 0.
|
||||
else:
|
||||
imag = math.copysign(0., math.sin(2. * y))
|
||||
return complex(real, imag)
|
||||
# This is CPython's algorithm (see c_tanh() in cmathmodule.c).
|
||||
# XXX how to force float constants into single precision?
|
||||
tx = math.tanh(x)
|
||||
ty = math.tan(y)
|
||||
cx = 1. / math.cosh(x)
|
||||
txty = tx * ty
|
||||
denom = 1. + txty * txty
|
||||
return complex(
|
||||
tx * (1. + ty * ty) / denom,
|
||||
((ty / denom) * cx) * cx)
|
||||
|
||||
return tanh_impl
|
||||
|
||||
|
||||
@lower(cmath.acos, types.Complex)
|
||||
def acos_impl(context, builder, sig, args):
|
||||
LN_4 = math.log(4)
|
||||
THRES = mathimpl.FLT_MAX / 4
|
||||
|
||||
def acos_impl(z):
|
||||
"""cmath.acos(z)"""
|
||||
# CPython's algorithm (see c_acos() in cmathmodule.c)
|
||||
if abs(z.real) > THRES or abs(z.imag) > THRES:
|
||||
# Avoid unnecessary overflow for large arguments
|
||||
# (also handles infinities gracefully)
|
||||
real = math.atan2(abs(z.imag), z.real)
|
||||
imag = math.copysign(
|
||||
math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4,
|
||||
-z.imag)
|
||||
return complex(real, imag)
|
||||
else:
|
||||
s1 = cmath.sqrt(complex(1. - z.real, -z.imag))
|
||||
s2 = cmath.sqrt(complex(1. + z.real, z.imag))
|
||||
real = 2. * math.atan2(s1.real, s2.real)
|
||||
imag = math.asinh(s2.real * s1.imag - s2.imag * s1.real)
|
||||
return complex(real, imag)
|
||||
|
||||
res = context.compile_internal(builder, acos_impl, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig, res)
|
||||
|
||||
@overload(cmath.acosh)
|
||||
def impl_cmath_acosh(z):
|
||||
if not isinstance(z, types.Complex):
|
||||
return
|
||||
|
||||
LN_4 = math.log(4)
|
||||
THRES = mathimpl.FLT_MAX / 4
|
||||
|
||||
def acosh_impl(z):
|
||||
"""cmath.acosh(z)"""
|
||||
# CPython's algorithm (see c_acosh() in cmathmodule.c)
|
||||
if abs(z.real) > THRES or abs(z.imag) > THRES:
|
||||
# Avoid unnecessary overflow for large arguments
|
||||
# (also handles infinities gracefully)
|
||||
real = math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4
|
||||
imag = math.atan2(z.imag, z.real)
|
||||
return complex(real, imag)
|
||||
else:
|
||||
s1 = cmath.sqrt(complex(z.real - 1., z.imag))
|
||||
s2 = cmath.sqrt(complex(z.real + 1., z.imag))
|
||||
real = math.asinh(s1.real * s2.real + s1.imag * s2.imag)
|
||||
imag = 2. * math.atan2(s1.imag, s2.real)
|
||||
return complex(real, imag)
|
||||
# Condensed formula (NumPy)
|
||||
#return cmath.log(z + cmath.sqrt(z + 1.) * cmath.sqrt(z - 1.))
|
||||
|
||||
return acosh_impl
|
||||
|
||||
|
||||
@lower(cmath.asinh, types.Complex)
|
||||
def asinh_impl(context, builder, sig, args):
|
||||
LN_4 = math.log(4)
|
||||
THRES = mathimpl.FLT_MAX / 4
|
||||
|
||||
def asinh_impl(z):
|
||||
"""cmath.asinh(z)"""
|
||||
# CPython's algorithm (see c_asinh() in cmathmodule.c)
|
||||
if abs(z.real) > THRES or abs(z.imag) > THRES:
|
||||
real = math.copysign(
|
||||
math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4,
|
||||
z.real)
|
||||
imag = math.atan2(z.imag, abs(z.real))
|
||||
return complex(real, imag)
|
||||
else:
|
||||
s1 = cmath.sqrt(complex(1. + z.imag, -z.real))
|
||||
s2 = cmath.sqrt(complex(1. - z.imag, z.real))
|
||||
real = math.asinh(s1.real * s2.imag - s2.real * s1.imag)
|
||||
imag = math.atan2(z.imag, s1.real * s2.real - s1.imag * s2.imag)
|
||||
return complex(real, imag)
|
||||
|
||||
res = context.compile_internal(builder, asinh_impl, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig, res)
|
||||
|
||||
@lower(cmath.asin, types.Complex)
|
||||
def asin_impl(context, builder, sig, args):
|
||||
def asin_impl(z):
|
||||
"""cmath.asin(z) = -j * cmath.asinh(z j)"""
|
||||
r = cmath.asinh(complex(-z.imag, z.real))
|
||||
return complex(r.imag, -r.real)
|
||||
|
||||
res = context.compile_internal(builder, asin_impl, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig, res)
|
||||
|
||||
@lower(cmath.atan, types.Complex)
|
||||
def atan_impl(context, builder, sig, args):
|
||||
def atan_impl(z):
|
||||
"""cmath.atan(z) = -j * cmath.atanh(z j)"""
|
||||
r = cmath.atanh(complex(-z.imag, z.real))
|
||||
if math.isinf(z.real) and math.isnan(z.imag):
|
||||
# XXX this is odd but necessary
|
||||
return complex(r.imag, r.real)
|
||||
else:
|
||||
return complex(r.imag, -r.real)
|
||||
|
||||
res = context.compile_internal(builder, atan_impl, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig, res)
|
||||
|
||||
@lower(cmath.atanh, types.Complex)
|
||||
def atanh_impl(context, builder, sig, args):
|
||||
LN_4 = math.log(4)
|
||||
THRES_LARGE = math.sqrt(mathimpl.FLT_MAX / 4)
|
||||
THRES_SMALL = math.sqrt(mathimpl.FLT_MIN)
|
||||
PI_12 = math.pi / 2
|
||||
|
||||
def atanh_impl(z):
|
||||
"""cmath.atanh(z)"""
|
||||
# CPython's algorithm (see c_atanh() in cmathmodule.c)
|
||||
if z.real < 0.:
|
||||
# Reduce to case where z.real >= 0., using atanh(z) = -atanh(-z).
|
||||
negate = True
|
||||
z = -z
|
||||
else:
|
||||
negate = False
|
||||
|
||||
ay = abs(z.imag)
|
||||
if math.isnan(z.real) or z.real > THRES_LARGE or ay > THRES_LARGE:
|
||||
if math.isinf(z.imag):
|
||||
real = math.copysign(0., z.real)
|
||||
elif math.isinf(z.real):
|
||||
real = 0.
|
||||
else:
|
||||
# may be safe from overflow, depending on hypot's implementation...
|
||||
h = math.hypot(z.real * 0.5, z.imag * 0.5)
|
||||
real = z.real/4./h/h
|
||||
imag = -math.copysign(PI_12, -z.imag)
|
||||
elif z.real == 1. and ay < THRES_SMALL:
|
||||
# C99 standard says: atanh(1+/-0.) should be inf +/- 0j
|
||||
if ay == 0.:
|
||||
real = INF
|
||||
imag = z.imag
|
||||
else:
|
||||
real = -math.log(math.sqrt(ay) /
|
||||
math.sqrt(math.hypot(ay, 2.)))
|
||||
imag = math.copysign(math.atan2(2., -ay) / 2, z.imag)
|
||||
else:
|
||||
sqay = ay * ay
|
||||
zr1 = 1 - z.real
|
||||
real = math.log1p(4. * z.real / (zr1 * zr1 + sqay)) * 0.25
|
||||
imag = -math.atan2(-2. * z.imag,
|
||||
zr1 * (1 + z.real) - sqay) * 0.5
|
||||
|
||||
if math.isnan(z.imag):
|
||||
imag = NAN
|
||||
if negate:
|
||||
return complex(-real, -imag)
|
||||
else:
|
||||
return complex(real, imag)
|
||||
|
||||
res = context.compile_internal(builder, atanh_impl, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig, res)
|
||||
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Implementation of enums.
|
||||
"""
|
||||
import operator
|
||||
|
||||
from numba.core.imputils import (lower_builtin, lower_getattr,
|
||||
lower_getattr_generic, lower_cast,
|
||||
lower_constant, impl_ret_untracked)
|
||||
from numba.core import types
|
||||
from numba.core.extending import overload_method
|
||||
|
||||
|
||||
@lower_builtin(operator.eq, types.EnumMember, types.EnumMember)
|
||||
def enum_eq(context, builder, sig, args):
|
||||
tu, tv = sig.args
|
||||
u, v = args
|
||||
res = context.generic_compare(builder, operator.eq,
|
||||
(tu.dtype, tv.dtype), (u, v))
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower_builtin(operator.is_, types.EnumMember, types.EnumMember)
|
||||
def enum_is(context, builder, sig, args):
|
||||
tu, tv = sig.args
|
||||
u, v = args
|
||||
if tu == tv:
|
||||
res = context.generic_compare(builder, operator.eq,
|
||||
(tu.dtype, tv.dtype), (u, v))
|
||||
else:
|
||||
res = context.get_constant(sig.return_type, False)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower_builtin(operator.ne, types.EnumMember, types.EnumMember)
|
||||
def enum_ne(context, builder, sig, args):
|
||||
tu, tv = sig.args
|
||||
u, v = args
|
||||
res = context.generic_compare(builder, operator.ne,
|
||||
(tu.dtype, tv.dtype), (u, v))
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower_getattr(types.EnumMember, 'value')
|
||||
def enum_value(context, builder, ty, val):
|
||||
return val
|
||||
|
||||
|
||||
@lower_cast(types.IntEnumMember, types.Integer)
|
||||
def int_enum_to_int(context, builder, fromty, toty, val):
|
||||
"""
|
||||
Convert an IntEnum member to its raw integer value.
|
||||
"""
|
||||
return context.cast(builder, val, fromty.dtype, toty)
|
||||
|
||||
|
||||
@lower_constant(types.EnumMember)
|
||||
def enum_constant(context, builder, ty, pyval):
|
||||
"""
|
||||
Return a LLVM constant representing enum member *pyval*.
|
||||
"""
|
||||
return context.get_constant_generic(builder, ty.dtype, pyval.value)
|
||||
|
||||
|
||||
@lower_getattr_generic(types.EnumClass)
|
||||
def enum_class_getattr(context, builder, ty, val, attr):
|
||||
"""
|
||||
Return an enum member by attribute name.
|
||||
"""
|
||||
member = getattr(ty.instance_class, attr)
|
||||
return context.get_constant_generic(builder, ty.dtype, member.value)
|
||||
|
||||
|
||||
@lower_builtin('static_getitem', types.EnumClass, types.StringLiteral)
|
||||
def enum_class_getitem(context, builder, sig, args):
|
||||
"""
|
||||
Return an enum member by index name.
|
||||
"""
|
||||
enum_cls_typ, idx = sig.args
|
||||
member = enum_cls_typ.instance_class[idx.literal_value]
|
||||
return context.get_constant_generic(builder, enum_cls_typ.dtype,
|
||||
member.value)
|
||||
|
||||
|
||||
@overload_method(types.IntEnumMember, '__hash__')
|
||||
def intenum_hash(val):
|
||||
# uses the hash of the value, for IntEnums this will be int.__hash__
|
||||
def hash_impl(val):
|
||||
return hash(val.value)
|
||||
return hash_impl
|
||||
@@ -0,0 +1,743 @@
|
||||
"""
|
||||
Hash implementations for Numba types
|
||||
"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import sys
|
||||
import ctypes
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
|
||||
import llvmlite.binding as ll
|
||||
from llvmlite import ir
|
||||
|
||||
from numba import literal_unroll
|
||||
from numba.core.extending import (
|
||||
overload, overload_method, intrinsic, register_jitable)
|
||||
from numba.core import errors
|
||||
from numba.core import types
|
||||
from numba.core.unsafe.bytes import grab_byte, grab_uint64_t
|
||||
from numba.cpython.randomimpl import (const_int, get_next_int, get_next_int32,
|
||||
get_state_ptr)
|
||||
|
||||
# This is Py_hash_t, which is a Py_ssize_t, which has sizeof(size_t):
|
||||
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Include/pyport.h#L91-L96 # noqa: E501
|
||||
_hash_width = sys.hash_info.width
|
||||
_Py_hash_t = getattr(types, 'int%s' % _hash_width)
|
||||
_Py_uhash_t = getattr(types, 'uint%s' % _hash_width)
|
||||
|
||||
# Constants from CPython source, obtained by various means:
|
||||
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Include/pyhash.h # noqa: E501
|
||||
_PyHASH_INF = sys.hash_info.inf
|
||||
_PyHASH_NAN = sys.hash_info.nan
|
||||
_PyHASH_MODULUS = _Py_uhash_t(sys.hash_info.modulus)
|
||||
_PyHASH_BITS = 31 if types.intp.bitwidth == 32 else 61 # mersenne primes
|
||||
_PyHASH_MULTIPLIER = 0xf4243 # 1000003UL
|
||||
_PyHASH_IMAG = _PyHASH_MULTIPLIER
|
||||
_PyLong_SHIFT = sys.int_info.bits_per_digit
|
||||
_Py_HASH_CUTOFF = sys.hash_info.cutoff
|
||||
_Py_hashfunc_name = sys.hash_info.algorithm
|
||||
|
||||
|
||||
# This stub/overload pair are used to force branch pruning to remove the dead
|
||||
# branch based on the potential `None` type of the hash_func which works better
|
||||
# if the predicate for the prune in an ir.Arg. The obj is an arg to allow for
|
||||
# a custom error message.
|
||||
def _defer_hash(hash_func):
|
||||
pass
|
||||
|
||||
|
||||
@overload(_defer_hash)
|
||||
def ol_defer_hash(obj, hash_func):
|
||||
err_msg = f"unhashable type: '{obj}'"
|
||||
|
||||
def impl(obj, hash_func):
|
||||
if hash_func is None:
|
||||
raise TypeError(err_msg)
|
||||
else:
|
||||
return hash_func()
|
||||
return impl
|
||||
|
||||
|
||||
# hash(obj) is implemented by calling obj.__hash__()
|
||||
@overload(hash)
|
||||
def hash_overload(obj):
|
||||
attempt_generic_msg = ("No __hash__ is defined for object of type "
|
||||
f"'{obj}' and a generic hash() cannot be "
|
||||
"performed as there is no suitable object "
|
||||
"represention in Numba compiled code!")
|
||||
|
||||
def impl(obj):
|
||||
if hasattr(obj, '__hash__'):
|
||||
return _defer_hash(obj, getattr(obj, '__hash__'))
|
||||
else:
|
||||
raise TypeError(attempt_generic_msg)
|
||||
return impl
|
||||
|
||||
|
||||
@register_jitable
|
||||
def process_return(val):
|
||||
asint = _Py_hash_t(val)
|
||||
if (asint == int(-1)):
|
||||
asint = int(-2)
|
||||
return asint
|
||||
|
||||
|
||||
# This is a translation of CPython's _Py_HashDouble:
|
||||
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Python/pyhash.c#L34-L129 # noqa: E501
|
||||
# NOTE: In Python 3.10 hash of nan is now hash of the pointer to the PyObject
|
||||
# containing said nan. Numba cannot replicate this as there is no object, so it
|
||||
# elects to replicate the behaviour i.e. hash of nan is something "unique" which
|
||||
# satisfies https://bugs.python.org/issue43475.
|
||||
|
||||
@register_jitable(locals={'x': _Py_uhash_t,
|
||||
'y': _Py_uhash_t,
|
||||
'm': types.double,
|
||||
'e': types.intc,
|
||||
'sign': types.intc,
|
||||
'_PyHASH_MODULUS': _Py_uhash_t,
|
||||
'_PyHASH_BITS': types.intc})
|
||||
def _Py_HashDouble(v):
|
||||
if not np.isfinite(v):
|
||||
if (np.isinf(v)):
|
||||
if (v > 0):
|
||||
return _PyHASH_INF
|
||||
else:
|
||||
return -_PyHASH_INF
|
||||
else:
|
||||
# Python 3.10 does not use `_PyHASH_NAN`.
|
||||
# https://github.com/python/cpython/blob/2c4792264f9218692a1bd87398a60591f756b171/Python/pyhash.c#L102 # noqa: E501
|
||||
# Numba returns a pseudo-random number to reflect the spirit of the
|
||||
# change.
|
||||
x = _prng_random_hash()
|
||||
return process_return(x)
|
||||
|
||||
m, e = math.frexp(v)
|
||||
|
||||
sign = 1
|
||||
if (m < 0):
|
||||
sign = -1
|
||||
m = -m
|
||||
|
||||
# process 28 bits at a time; this should work well both for binary
|
||||
# and hexadecimal floating point.
|
||||
x = 0
|
||||
while (m):
|
||||
x = ((x << 28) & _PyHASH_MODULUS) | x >> (_PyHASH_BITS - 28)
|
||||
m *= 268435456.0 # /* 2**28 */
|
||||
e -= 28
|
||||
y = int(m) # /* pull out integer part */
|
||||
m -= y
|
||||
x += y
|
||||
if x >= _PyHASH_MODULUS:
|
||||
x -= _PyHASH_MODULUS
|
||||
# /* adjust for the exponent; first reduce it modulo _PyHASH_BITS */
|
||||
if e >= 0:
|
||||
e = e % _PyHASH_BITS
|
||||
else:
|
||||
e = _PyHASH_BITS - 1 - ((-1 - e) % _PyHASH_BITS)
|
||||
|
||||
x = ((x << e) & _PyHASH_MODULUS) | x >> (_PyHASH_BITS - e)
|
||||
|
||||
x = x * sign
|
||||
return process_return(x)
|
||||
|
||||
|
||||
@intrinsic
|
||||
def _fpext(tyctx, val):
|
||||
def impl(cgctx, builder, signature, args):
|
||||
val = args[0]
|
||||
return builder.fpext(val, ir.DoubleType())
|
||||
sig = types.float64(types.float32)
|
||||
return sig, impl
|
||||
|
||||
|
||||
@intrinsic
|
||||
def _prng_random_hash(tyctx):
|
||||
|
||||
def impl(cgctx, builder, signature, args):
|
||||
state_ptr = get_state_ptr(cgctx, builder, "internal")
|
||||
bits = const_int(_hash_width)
|
||||
|
||||
# Why not just use get_next_int() with the correct bitwidth?
|
||||
# get_next_int() always returns an i64, because the bitwidth it is
|
||||
# passed may not be a compile-time constant, so it needs to allocate
|
||||
# the largest unit of storage that may be required. Therefore, if the
|
||||
# hash width is 32, then we need to use get_next_int32() to ensure we
|
||||
# don't return a wider-than-expected hash, even if everything above
|
||||
# the low 32 bits would have been zero.
|
||||
if _hash_width == 32:
|
||||
value = get_next_int32(cgctx, builder, state_ptr)
|
||||
else:
|
||||
value = get_next_int(cgctx, builder, state_ptr, bits, False)
|
||||
|
||||
return value
|
||||
|
||||
sig = _Py_hash_t()
|
||||
return sig, impl
|
||||
|
||||
|
||||
# This is a translation of CPython's long_hash, but restricted to the numerical
|
||||
# domain reachable by int64/uint64 (i.e. no BigInt like support):
|
||||
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Objects/longobject.c#L2934-L2989 # noqa: E501
|
||||
# obdigit is a uint32_t which is typedef'd to digit
|
||||
# int32_t is typedef'd to sdigit
|
||||
|
||||
|
||||
@register_jitable(locals={'x': _Py_uhash_t,
|
||||
'p1': _Py_uhash_t,
|
||||
'p2': _Py_uhash_t,
|
||||
'p3': _Py_uhash_t,
|
||||
'p4': _Py_uhash_t,
|
||||
'_PyHASH_MODULUS': _Py_uhash_t,
|
||||
'_PyHASH_BITS': types.int32,
|
||||
'_PyLong_SHIFT': types.int32,})
|
||||
def _long_impl(val):
|
||||
# This function assumes val came from a long int repr with val being a
|
||||
# uint64_t this means having to split the input into PyLong_SHIFT size
|
||||
# chunks in an unsigned hash wide type, max numba can handle is a 64bit int
|
||||
|
||||
# mask to select low _PyLong_SHIFT bits
|
||||
_tmp_shift = 32 - _PyLong_SHIFT
|
||||
mask_shift = (~types.uint32(0x0)) >> _tmp_shift
|
||||
|
||||
# a 64bit wide max means Numba only needs 3 x 30 bit values max,
|
||||
# or 5 x 15 bit values max on 32bit platforms
|
||||
i = (64 // _PyLong_SHIFT) + 1
|
||||
|
||||
# alg as per hash_long
|
||||
x = 0
|
||||
p3 = (_PyHASH_BITS - _PyLong_SHIFT)
|
||||
for idx in range(i - 1, -1, -1):
|
||||
p1 = x << _PyLong_SHIFT
|
||||
p2 = p1 & _PyHASH_MODULUS
|
||||
p4 = x >> p3
|
||||
x = p2 | p4
|
||||
# the shift and mask splits out the `ob_digit` parts of a Long repr
|
||||
x += types.uint32((val >> idx * _PyLong_SHIFT) & mask_shift)
|
||||
if x >= _PyHASH_MODULUS:
|
||||
x -= _PyHASH_MODULUS
|
||||
return _Py_hash_t(x)
|
||||
|
||||
|
||||
# This has no CPython equivalent, CPython uses long_hash.
|
||||
@overload_method(types.Integer, '__hash__')
|
||||
@overload_method(types.Boolean, '__hash__')
|
||||
def int_hash(val):
|
||||
|
||||
_HASH_I64_MIN = -2 if sys.maxsize <= 2 ** 32 else -4
|
||||
_SIGNED_MIN = types.int64(-0x8000000000000000)
|
||||
|
||||
# Find a suitable type to hold a "big" value, i.e. iinfo(ty).min/max
|
||||
# this is to ensure e.g. int32.min is handled ok as it's abs() is its value
|
||||
_BIG = types.int64 if getattr(val, 'signed', False) else types.uint64
|
||||
|
||||
# this is a bit involved due to the CPython repr of ints
|
||||
def impl(val):
|
||||
# If the magnitude is under PyHASH_MODULUS, just return the
|
||||
# value val as the hash, couple of special cases if val == val:
|
||||
# 1. it's 0, in which case return 0
|
||||
# 2. it's signed int minimum value, return the value CPython computes
|
||||
# but Numba cannot as there's no type wide enough to hold the shifts.
|
||||
#
|
||||
# If the magnitude is greater than PyHASH_MODULUS then... if the value
|
||||
# is negative then negate it switch the sign on the hash once computed
|
||||
# and use the standard wide unsigned hash implementation
|
||||
val = _BIG(val)
|
||||
mag = abs(val)
|
||||
if mag < _PyHASH_MODULUS:
|
||||
if val == 0:
|
||||
ret = 0
|
||||
elif val == _SIGNED_MIN: # e.g. int64 min, -0x8000000000000000
|
||||
ret = _Py_hash_t(_HASH_I64_MIN)
|
||||
else:
|
||||
ret = _Py_hash_t(val)
|
||||
else:
|
||||
needs_negate = False
|
||||
if val < 0:
|
||||
val = -val
|
||||
needs_negate = True
|
||||
ret = _long_impl(val)
|
||||
if needs_negate:
|
||||
ret = -ret
|
||||
return process_return(ret)
|
||||
return impl
|
||||
|
||||
# This is a translation of CPython's float_hash:
|
||||
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Objects/floatobject.c#L528-L532 # noqa: E501
|
||||
|
||||
|
||||
@overload_method(types.Float, '__hash__')
|
||||
def float_hash(val):
|
||||
if val.bitwidth == 64:
|
||||
def impl(val):
|
||||
hashed = _Py_HashDouble(val)
|
||||
return hashed
|
||||
else:
|
||||
def impl(val):
|
||||
# widen the 32bit float to 64bit
|
||||
fpextended = np.float64(_fpext(val))
|
||||
hashed = _Py_HashDouble(fpextended)
|
||||
return hashed
|
||||
return impl
|
||||
|
||||
# This is a translation of CPython's complex_hash:
|
||||
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Objects/complexobject.c#L408-L428 # noqa: E501
|
||||
|
||||
|
||||
@overload_method(types.Complex, '__hash__')
|
||||
def complex_hash(val):
|
||||
def impl(val):
|
||||
hashreal = hash(val.real)
|
||||
hashimag = hash(val.imag)
|
||||
# Note: if the imaginary part is 0, hashimag is 0 now,
|
||||
# so the following returns hashreal unchanged. This is
|
||||
# important because numbers of different types that
|
||||
# compare equal must have the same hash value, so that
|
||||
# hash(x + 0*j) must equal hash(x).
|
||||
combined = hashreal + _PyHASH_IMAG * hashimag
|
||||
return process_return(combined)
|
||||
return impl
|
||||
|
||||
|
||||
# Python 3.8 strengthened its hash alg for tuples.
|
||||
# This is a translation of CPython's tuplehash for Python >=3.8
|
||||
# https://github.com/python/cpython/blob/b738237d6792acba85b1f6e6c8993a812c7fd815/Objects/tupleobject.c#L338-L391 # noqa: E501
|
||||
|
||||
# These consts are needed for this alg variant, they are from:
|
||||
# https://github.com/python/cpython/blob/b738237d6792acba85b1f6e6c8993a812c7fd815/Objects/tupleobject.c#L353-L363 # noqa: E501
|
||||
if _Py_uhash_t.bitwidth // 8 > 4:
|
||||
_PyHASH_XXPRIME_1 = _Py_uhash_t(11400714785074694791)
|
||||
_PyHASH_XXPRIME_2 = _Py_uhash_t(14029467366897019727)
|
||||
_PyHASH_XXPRIME_5 = _Py_uhash_t(2870177450012600261)
|
||||
|
||||
@register_jitable(locals={'x': types.uint64})
|
||||
def _PyHASH_XXROTATE(x):
|
||||
# Rotate left 31 bits
|
||||
return ((x << types.uint64(31)) | (x >> types.uint64(33)))
|
||||
else:
|
||||
_PyHASH_XXPRIME_1 = _Py_uhash_t(2654435761)
|
||||
_PyHASH_XXPRIME_2 = _Py_uhash_t(2246822519)
|
||||
_PyHASH_XXPRIME_5 = _Py_uhash_t(374761393)
|
||||
|
||||
@register_jitable(locals={'x': types.uint64})
|
||||
def _PyHASH_XXROTATE(x):
|
||||
# Rotate left 13 bits
|
||||
return ((x << types.uint64(13)) | (x >> types.uint64(19)))
|
||||
|
||||
|
||||
@register_jitable(locals={'acc': _Py_uhash_t, 'lane': _Py_uhash_t,
|
||||
'_PyHASH_XXPRIME_5': _Py_uhash_t,
|
||||
'_PyHASH_XXPRIME_1': _Py_uhash_t,
|
||||
'tl': _Py_uhash_t})
|
||||
def _tuple_hash(tup):
|
||||
tl = len(tup)
|
||||
acc = _PyHASH_XXPRIME_5
|
||||
for x in literal_unroll(tup):
|
||||
lane = hash(x)
|
||||
if lane == _Py_uhash_t(-1):
|
||||
return -1
|
||||
acc += lane * _PyHASH_XXPRIME_2
|
||||
acc = _PyHASH_XXROTATE(acc)
|
||||
acc *= _PyHASH_XXPRIME_1
|
||||
|
||||
acc += tl ^ (_PyHASH_XXPRIME_5 ^ _Py_uhash_t(3527539))
|
||||
|
||||
if acc == _Py_uhash_t(-1):
|
||||
return process_return(1546275796)
|
||||
|
||||
return process_return(acc)
|
||||
|
||||
|
||||
@overload_method(types.BaseTuple, '__hash__')
|
||||
def tuple_hash(val):
|
||||
def impl(val):
|
||||
return _tuple_hash(val)
|
||||
return impl
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# String/bytes hashing needs hashseed info, this is from:
|
||||
# https://stackoverflow.com/a/41088757
|
||||
# with thanks to Martijn Pieters
|
||||
#
|
||||
# Developer note:
|
||||
# CPython makes use of an internal "hashsecret" which is essentially a struct
|
||||
# containing some state that is set on CPython initialization and contains magic
|
||||
# numbers used particularly in unicode/string hashing. This code binds to the
|
||||
# Python runtime libraries in use by the current process and reads the
|
||||
# "hashsecret" state so that it can be used by Numba. As this is done at runtime
|
||||
# the behaviour and influence of the PYTHONHASHSEED environment variable is
|
||||
# accommodated.
|
||||
|
||||
from ctypes import ( # noqa
|
||||
c_size_t,
|
||||
c_ubyte,
|
||||
c_uint64,
|
||||
pythonapi,
|
||||
Structure,
|
||||
Union,
|
||||
) # noqa
|
||||
|
||||
|
||||
class FNV(Structure):
|
||||
_fields_ = [
|
||||
('prefix', c_size_t),
|
||||
('suffix', c_size_t)
|
||||
]
|
||||
|
||||
|
||||
class SIPHASH(Structure):
|
||||
_fields_ = [
|
||||
('k0', c_uint64),
|
||||
('k1', c_uint64),
|
||||
]
|
||||
|
||||
|
||||
class DJBX33A(Structure):
|
||||
_fields_ = [
|
||||
('padding', c_ubyte * 16),
|
||||
('suffix', c_size_t),
|
||||
]
|
||||
|
||||
|
||||
class EXPAT(Structure):
|
||||
_fields_ = [
|
||||
('padding', c_ubyte * 16),
|
||||
('hashsalt', c_size_t),
|
||||
]
|
||||
|
||||
|
||||
class _Py_HashSecret_t(Union):
|
||||
_fields_ = [
|
||||
# ensure 24 bytes
|
||||
('uc', c_ubyte * 24),
|
||||
# two Py_hash_t for FNV
|
||||
('fnv', FNV),
|
||||
# two uint64 for SipHash24
|
||||
('siphash', SIPHASH),
|
||||
# a different (!) Py_hash_t for small string optimization
|
||||
('djbx33a', DJBX33A),
|
||||
('expat', EXPAT),
|
||||
]
|
||||
|
||||
|
||||
_hashsecret_entry = namedtuple('_hashsecret_entry', ['symbol', 'value'])
|
||||
|
||||
|
||||
# Only a few members are needed at present
|
||||
def _build_hashsecret():
|
||||
"""Read hash secret from the Python process
|
||||
|
||||
Returns
|
||||
-------
|
||||
info : dict
|
||||
- keys are "djbx33a_suffix", "siphash_k0", siphash_k1".
|
||||
- values are the namedtuple[symbol:str, value:int]
|
||||
"""
|
||||
# Read hashsecret and inject it into the LLVM symbol map under the
|
||||
# prefix `_numba_hashsecret_`.
|
||||
pyhashsecret = _Py_HashSecret_t.in_dll(pythonapi, '_Py_HashSecret')
|
||||
info = {}
|
||||
|
||||
def inject(name, val):
|
||||
symbol_name = "_numba_hashsecret_{}".format(name)
|
||||
val = ctypes.c_uint64(val)
|
||||
addr = ctypes.addressof(val)
|
||||
ll.add_symbol(symbol_name, addr)
|
||||
info[name] = _hashsecret_entry(symbol=symbol_name, value=val)
|
||||
|
||||
inject('djbx33a_suffix', pyhashsecret.djbx33a.suffix)
|
||||
inject('siphash_k0', pyhashsecret.siphash.k0)
|
||||
inject('siphash_k1', pyhashsecret.siphash.k1)
|
||||
return info
|
||||
|
||||
|
||||
_hashsecret = _build_hashsecret()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
if _Py_hashfunc_name in ('siphash13', 'siphash24', 'fnv'):
|
||||
|
||||
# Check for use of the FNV hashing alg, warn users that it's not implemented
|
||||
# and functionality relying of properties derived from hashing will be fine
|
||||
# but hash values themselves are likely to be different.
|
||||
if _Py_hashfunc_name == 'fnv':
|
||||
msg = ("FNV hashing is not implemented in Numba. See PEP 456 "
|
||||
"https://www.python.org/dev/peps/pep-0456/ "
|
||||
"for rationale over not using FNV. Numba will continue to work, "
|
||||
"but hashes for built in types will be computed using "
|
||||
"siphash24. This will permit e.g. dictionaries to continue to "
|
||||
"behave as expected, however anything relying on the value of "
|
||||
"the hash opposed to hash as a derived property is likely to "
|
||||
"not work as expected.")
|
||||
warnings.warn(msg)
|
||||
|
||||
# This is a translation of CPython's siphash24 function:
|
||||
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Python/pyhash.c#L287-L413 # noqa: E501
|
||||
# and also, since Py 3.11, a translation of CPython's siphash13 function:
|
||||
# https://github.com/python/cpython/blob/9dda9020abcf0d51d59b283a89c58c8e1fb0f574/Python/pyhash.c#L376-L424
|
||||
# the only differences are in the use of SINGLE_ROUND in siphash13 vs.
|
||||
# DOUBLE_ROUND in siphash24, and that siphash13 has an extra "ROUND" applied
|
||||
# just before the final XORing of components to create the return value.
|
||||
|
||||
# /* *********************************************************************
|
||||
# <MIT License>
|
||||
# Copyright (c) 2013 Marek Majkowski <marek@popcount.org>
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||
# copy of this software and associated documentation files (the "Software"),
|
||||
# to deal in the Software without restriction, including without limitation
|
||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
# and/or sell copies of the Software, and to permit persons to whom the
|
||||
# Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
# DEALINGS IN THE SOFTWARE.
|
||||
# </MIT License>
|
||||
|
||||
# Original location:
|
||||
# https://github.com/majek/csiphash/
|
||||
|
||||
# Solution inspired by code from:
|
||||
# Samuel Neves (supercop/crypto_auth/siphash24/little)
|
||||
#djb (supercop/crypto_auth/siphash24/little2)
|
||||
# Jean-Philippe Aumasson (https://131002.net/siphash/siphash24.c)
|
||||
|
||||
# Modified for Python by Christian Heimes:
|
||||
# - C89 / MSVC compatibility
|
||||
# - _rotl64() on Windows
|
||||
# - letoh64() fallback
|
||||
# */
|
||||
|
||||
@register_jitable(locals={'x': types.uint64,
|
||||
'b': types.uint64, })
|
||||
def _ROTATE(x, b):
|
||||
return types.uint64(((x) << (b)) | ((x) >> (types.uint64(64) - (b))))
|
||||
|
||||
@register_jitable(locals={'a': types.uint64,
|
||||
'b': types.uint64,
|
||||
'c': types.uint64,
|
||||
'd': types.uint64,
|
||||
's': types.uint64,
|
||||
't': types.uint64, })
|
||||
def _HALF_ROUND(a, b, c, d, s, t):
|
||||
a += b
|
||||
c += d
|
||||
b = _ROTATE(b, s) ^ a
|
||||
d = _ROTATE(d, t) ^ c
|
||||
a = _ROTATE(a, 32)
|
||||
return a, b, c, d
|
||||
|
||||
@register_jitable(locals={'v0': types.uint64,
|
||||
'v1': types.uint64,
|
||||
'v2': types.uint64,
|
||||
'v3': types.uint64, })
|
||||
def _SINGLE_ROUND(v0, v1, v2, v3):
|
||||
v0, v1, v2, v3 = _HALF_ROUND(v0, v1, v2, v3, 13, 16)
|
||||
v2, v1, v0, v3 = _HALF_ROUND(v2, v1, v0, v3, 17, 21)
|
||||
return v0, v1, v2, v3
|
||||
|
||||
@register_jitable(locals={'v0': types.uint64,
|
||||
'v1': types.uint64,
|
||||
'v2': types.uint64,
|
||||
'v3': types.uint64, })
|
||||
def _DOUBLE_ROUND(v0, v1, v2, v3):
|
||||
v0, v1, v2, v3 = _SINGLE_ROUND(v0, v1, v2, v3)
|
||||
v0, v1, v2, v3 = _SINGLE_ROUND(v0, v1, v2, v3)
|
||||
return v0, v1, v2, v3
|
||||
|
||||
def _gen_siphash(alg):
|
||||
if alg == 'siphash13':
|
||||
_ROUNDER = _SINGLE_ROUND
|
||||
_EXTRA_ROUND = True
|
||||
elif alg == 'siphash24':
|
||||
_ROUNDER = _DOUBLE_ROUND
|
||||
_EXTRA_ROUND = False
|
||||
else:
|
||||
assert 0, 'unreachable'
|
||||
|
||||
@register_jitable(locals={'v0': types.uint64,
|
||||
'v1': types.uint64,
|
||||
'v2': types.uint64,
|
||||
'v3': types.uint64,
|
||||
'b': types.uint64,
|
||||
'mi': types.uint64,
|
||||
't': types.uint64,
|
||||
'mask': types.uint64,
|
||||
'jmp': types.uint64,
|
||||
'ohexefef': types.uint64})
|
||||
def _siphash(k0, k1, src, src_sz):
|
||||
b = types.uint64(src_sz) << 56
|
||||
v0 = k0 ^ types.uint64(0x736f6d6570736575)
|
||||
v1 = k1 ^ types.uint64(0x646f72616e646f6d)
|
||||
v2 = k0 ^ types.uint64(0x6c7967656e657261)
|
||||
v3 = k1 ^ types.uint64(0x7465646279746573)
|
||||
|
||||
idx = 0
|
||||
while (src_sz >= 8):
|
||||
mi = grab_uint64_t(src, idx)
|
||||
idx += 1
|
||||
src_sz -= 8
|
||||
v3 ^= mi
|
||||
v0, v1, v2, v3 = _ROUNDER(v0, v1, v2, v3)
|
||||
v0 ^= mi
|
||||
|
||||
# this is the switch fallthrough:
|
||||
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Python/pyhash.c#L390-L400 # noqa: E501
|
||||
t = types.uint64(0x0)
|
||||
boffset = idx * 8
|
||||
ohexefef = types.uint64(0xff)
|
||||
if src_sz >= 7:
|
||||
jmp = (6 * 8)
|
||||
mask = ~types.uint64(ohexefef << jmp)
|
||||
t = (t & mask) | (types.uint64(grab_byte(src, boffset + 6))
|
||||
<< jmp)
|
||||
if src_sz >= 6:
|
||||
jmp = (5 * 8)
|
||||
mask = ~types.uint64(ohexefef << jmp)
|
||||
t = (t & mask) | (types.uint64(grab_byte(src, boffset + 5))
|
||||
<< jmp)
|
||||
if src_sz >= 5:
|
||||
jmp = (4 * 8)
|
||||
mask = ~types.uint64(ohexefef << jmp)
|
||||
t = (t & mask) | (types.uint64(grab_byte(src, boffset + 4))
|
||||
<< jmp)
|
||||
if src_sz >= 4:
|
||||
t &= types.uint64(0xffffffff00000000)
|
||||
for i in range(4):
|
||||
jmp = i * 8
|
||||
mask = ~types.uint64(ohexefef << jmp)
|
||||
t = (t & mask) | (types.uint64(grab_byte(src, boffset + i))
|
||||
<< jmp)
|
||||
if src_sz >= 3:
|
||||
jmp = (2 * 8)
|
||||
mask = ~types.uint64(ohexefef << jmp)
|
||||
t = (t & mask) | (types.uint64(grab_byte(src, boffset + 2))
|
||||
<< jmp)
|
||||
if src_sz >= 2:
|
||||
jmp = (1 * 8)
|
||||
mask = ~types.uint64(ohexefef << jmp)
|
||||
t = (t & mask) | (types.uint64(grab_byte(src, boffset + 1))
|
||||
<< jmp)
|
||||
if src_sz >= 1:
|
||||
mask = ~(ohexefef)
|
||||
t = (t & mask) | (types.uint64(grab_byte(src, boffset + 0)))
|
||||
|
||||
b |= t
|
||||
v3 ^= b
|
||||
v0, v1, v2, v3 = _ROUNDER(v0, v1, v2, v3)
|
||||
v0 ^= b
|
||||
v2 ^= ohexefef
|
||||
v0, v1, v2, v3 = _ROUNDER(v0, v1, v2, v3)
|
||||
v0, v1, v2, v3 = _ROUNDER(v0, v1, v2, v3)
|
||||
if _EXTRA_ROUND:
|
||||
v0, v1, v2, v3 = _ROUNDER(v0, v1, v2, v3)
|
||||
t = (v0 ^ v1) ^ (v2 ^ v3)
|
||||
return t
|
||||
|
||||
return _siphash
|
||||
|
||||
_siphash13 = _gen_siphash('siphash13')
|
||||
_siphash24 = _gen_siphash('siphash24')
|
||||
|
||||
_siphasher = _siphash13 if _Py_hashfunc_name == 'siphash13' else _siphash24
|
||||
|
||||
else:
|
||||
msg = "Unsupported hashing algorithm in use %s" % _Py_hashfunc_name
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@intrinsic
|
||||
def _inject_hashsecret_read(tyctx, name):
|
||||
"""Emit code to load the hashsecret.
|
||||
"""
|
||||
if not isinstance(name, types.StringLiteral):
|
||||
raise errors.TypingError("requires literal string")
|
||||
|
||||
sym = _hashsecret[name.literal_value].symbol
|
||||
resty = types.uint64
|
||||
sig = resty(name)
|
||||
|
||||
def impl(cgctx, builder, sig, args):
|
||||
mod = builder.module
|
||||
try:
|
||||
# Search for existing global
|
||||
gv = mod.get_global(sym)
|
||||
except KeyError:
|
||||
# Inject the symbol if not already exist.
|
||||
gv = ir.GlobalVariable(mod, ir.IntType(64), name=sym)
|
||||
v = builder.load(gv)
|
||||
return v
|
||||
|
||||
return sig, impl
|
||||
|
||||
|
||||
def _load_hashsecret(name):
|
||||
return _hashsecret[name].value
|
||||
|
||||
|
||||
@overload(_load_hashsecret)
|
||||
def _impl_load_hashsecret(name):
|
||||
def imp(name):
|
||||
return _inject_hashsecret_read(name)
|
||||
return imp
|
||||
|
||||
|
||||
# This is a translation of CPythons's _Py_HashBytes:
|
||||
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Python/pyhash.c#L145-L191 # noqa: E501
|
||||
|
||||
|
||||
@register_jitable(locals={'_hash': _Py_uhash_t})
|
||||
def _Py_HashBytes(val, _len):
|
||||
if (_len == 0):
|
||||
return process_return(0)
|
||||
|
||||
if (_len < _Py_HASH_CUTOFF):
|
||||
# TODO: this branch needs testing, needs a CPython setup for it!
|
||||
# /* Optimize hashing of very small strings with inline DJBX33A. */
|
||||
_hash = _Py_uhash_t(5381) # /* DJBX33A starts with 5381 */
|
||||
for idx in range(_len):
|
||||
_hash = ((_hash << 5) + _hash) + np.uint8(grab_byte(val, idx))
|
||||
|
||||
_hash ^= _len
|
||||
_hash ^= _load_hashsecret('djbx33a_suffix')
|
||||
else:
|
||||
tmp = _siphasher(types.uint64(_load_hashsecret('siphash_k0')),
|
||||
types.uint64(_load_hashsecret('siphash_k1')),
|
||||
val, _len)
|
||||
_hash = process_return(tmp)
|
||||
return process_return(_hash)
|
||||
|
||||
# This is an approximate translation of CPython's unicode_hash:
|
||||
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Objects/unicodeobject.c#L11635-L11663 # noqa: E501
|
||||
|
||||
|
||||
@overload_method(types.UnicodeType, '__hash__')
|
||||
def unicode_hash(val):
|
||||
from numba.cpython.unicode import _kind_to_byte_width
|
||||
|
||||
def impl(val):
|
||||
kindwidth = _kind_to_byte_width(val._kind)
|
||||
_len = len(val)
|
||||
# use the cache if possible
|
||||
current_hash = val._hash
|
||||
if current_hash != -1:
|
||||
return current_hash
|
||||
else:
|
||||
# cannot write hash value to cache in the unicode struct due to
|
||||
# pass by value on the struct making the struct member immutable
|
||||
return _Py_HashBytes(val._data, kindwidth * _len)
|
||||
|
||||
return impl
|
||||
@@ -0,0 +1,266 @@
|
||||
# A port of https://github.com/python/cpython/blob/e42b7051/Lib/heapq.py
|
||||
|
||||
|
||||
import heapq as hq
|
||||
|
||||
from numba.core import types
|
||||
from numba.core.errors import TypingError
|
||||
from numba.core.extending import overload, register_jitable
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _siftdown(heap, startpos, pos):
|
||||
newitem = heap[pos]
|
||||
|
||||
while pos > startpos:
|
||||
parentpos = (pos - 1) >> 1
|
||||
parent = heap[parentpos]
|
||||
if newitem < parent:
|
||||
heap[pos] = parent
|
||||
pos = parentpos
|
||||
continue
|
||||
break
|
||||
|
||||
heap[pos] = newitem
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _siftup(heap, pos):
|
||||
endpos = len(heap)
|
||||
startpos = pos
|
||||
newitem = heap[pos]
|
||||
|
||||
childpos = 2 * pos + 1
|
||||
while childpos < endpos:
|
||||
|
||||
rightpos = childpos + 1
|
||||
if rightpos < endpos and not heap[childpos] < heap[rightpos]:
|
||||
childpos = rightpos
|
||||
|
||||
heap[pos] = heap[childpos]
|
||||
pos = childpos
|
||||
childpos = 2 * pos + 1
|
||||
|
||||
heap[pos] = newitem
|
||||
_siftdown(heap, startpos, pos)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _siftdown_max(heap, startpos, pos):
|
||||
newitem = heap[pos]
|
||||
|
||||
while pos > startpos:
|
||||
parentpos = (pos - 1) >> 1
|
||||
parent = heap[parentpos]
|
||||
if parent < newitem:
|
||||
heap[pos] = parent
|
||||
pos = parentpos
|
||||
continue
|
||||
break
|
||||
heap[pos] = newitem
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _siftup_max(heap, pos):
|
||||
endpos = len(heap)
|
||||
startpos = pos
|
||||
newitem = heap[pos]
|
||||
|
||||
childpos = 2 * pos + 1
|
||||
while childpos < endpos:
|
||||
|
||||
rightpos = childpos + 1
|
||||
if rightpos < endpos and not heap[rightpos] < heap[childpos]:
|
||||
childpos = rightpos
|
||||
|
||||
heap[pos] = heap[childpos]
|
||||
pos = childpos
|
||||
childpos = 2 * pos + 1
|
||||
|
||||
heap[pos] = newitem
|
||||
_siftdown_max(heap, startpos, pos)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def reversed_range(x):
|
||||
# analogous to reversed(range(x))
|
||||
return range(x - 1, -1, -1)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _heapify_max(x):
|
||||
n = len(x)
|
||||
|
||||
for i in reversed_range(n // 2):
|
||||
_siftup_max(x, i)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _heapreplace_max(heap, item):
|
||||
returnitem = heap[0]
|
||||
heap[0] = item
|
||||
_siftup_max(heap, 0)
|
||||
return returnitem
|
||||
|
||||
|
||||
def assert_heap_type(heap):
|
||||
if not isinstance(heap, (types.List, types.ListType)):
|
||||
raise TypingError('heap argument must be a list')
|
||||
|
||||
dt = heap.dtype
|
||||
if isinstance(dt, types.Complex):
|
||||
msg = ("'<' not supported between instances "
|
||||
"of 'complex' and 'complex'")
|
||||
raise TypingError(msg)
|
||||
|
||||
|
||||
def assert_item_type_consistent_with_heap_type(heap, item):
|
||||
if not heap.dtype == item:
|
||||
raise TypingError('heap type must be the same as item type')
|
||||
|
||||
|
||||
@overload(hq.heapify)
|
||||
def hq_heapify(x):
|
||||
assert_heap_type(x)
|
||||
|
||||
def hq_heapify_impl(x):
|
||||
n = len(x)
|
||||
for i in reversed_range(n // 2):
|
||||
_siftup(x, i)
|
||||
|
||||
return hq_heapify_impl
|
||||
|
||||
|
||||
@overload(hq.heappop)
|
||||
def hq_heappop(heap):
|
||||
assert_heap_type(heap)
|
||||
|
||||
def hq_heappop_impl(heap):
|
||||
lastelt = heap.pop()
|
||||
if heap:
|
||||
returnitem = heap[0]
|
||||
heap[0] = lastelt
|
||||
_siftup(heap, 0)
|
||||
return returnitem
|
||||
return lastelt
|
||||
|
||||
return hq_heappop_impl
|
||||
|
||||
|
||||
@overload(hq.heappush)
|
||||
def heappush(heap, item):
|
||||
assert_heap_type(heap)
|
||||
assert_item_type_consistent_with_heap_type(heap, item)
|
||||
|
||||
def hq_heappush_impl(heap, item):
|
||||
heap.append(item)
|
||||
_siftdown(heap, 0, len(heap) - 1)
|
||||
|
||||
return hq_heappush_impl
|
||||
|
||||
|
||||
@overload(hq.heapreplace)
|
||||
def heapreplace(heap, item):
|
||||
assert_heap_type(heap)
|
||||
assert_item_type_consistent_with_heap_type(heap, item)
|
||||
|
||||
def hq_heapreplace(heap, item):
|
||||
returnitem = heap[0]
|
||||
heap[0] = item
|
||||
_siftup(heap, 0)
|
||||
return returnitem
|
||||
|
||||
return hq_heapreplace
|
||||
|
||||
|
||||
@overload(hq.heappushpop)
|
||||
def heappushpop(heap, item):
|
||||
assert_heap_type(heap)
|
||||
assert_item_type_consistent_with_heap_type(heap, item)
|
||||
|
||||
def hq_heappushpop_impl(heap, item):
|
||||
if heap and heap[0] < item:
|
||||
item, heap[0] = heap[0], item
|
||||
_siftup(heap, 0)
|
||||
return item
|
||||
|
||||
return hq_heappushpop_impl
|
||||
|
||||
|
||||
def check_input_types(n, iterable):
|
||||
|
||||
if not isinstance(n, (types.Integer, types.Boolean)):
|
||||
raise TypingError("First argument 'n' must be an integer")
|
||||
# heapq also accepts 1.0 (but not 0.0, 2.0, 3.0...) but
|
||||
# this isn't replicated
|
||||
|
||||
if not isinstance(iterable, (types.Sequence, types.Array, types.ListType)):
|
||||
raise TypingError("Second argument 'iterable' must be iterable")
|
||||
|
||||
|
||||
@overload(hq.nsmallest)
|
||||
def nsmallest(n, iterable):
|
||||
check_input_types(n, iterable)
|
||||
|
||||
def hq_nsmallest_impl(n, iterable):
|
||||
|
||||
if n == 0:
|
||||
return [iterable[0] for _ in range(0)]
|
||||
elif n == 1:
|
||||
out = min(iterable)
|
||||
return [out]
|
||||
|
||||
size = len(iterable)
|
||||
if n >= size:
|
||||
return sorted(iterable)[:n]
|
||||
|
||||
it = iter(iterable)
|
||||
result = [(elem, i) for i, elem in zip(range(n), it)]
|
||||
|
||||
_heapify_max(result)
|
||||
top = result[0][0]
|
||||
order = n
|
||||
|
||||
for elem in it:
|
||||
if elem < top:
|
||||
_heapreplace_max(result, (elem, order))
|
||||
top, _order = result[0]
|
||||
order += 1
|
||||
result.sort()
|
||||
return [elem for (elem, order) in result]
|
||||
|
||||
return hq_nsmallest_impl
|
||||
|
||||
|
||||
@overload(hq.nlargest)
|
||||
def nlargest(n, iterable):
|
||||
check_input_types(n, iterable)
|
||||
|
||||
def hq_nlargest_impl(n, iterable):
|
||||
|
||||
if n == 0:
|
||||
return [iterable[0] for _ in range(0)]
|
||||
elif n == 1:
|
||||
out = max(iterable)
|
||||
return [out]
|
||||
|
||||
size = len(iterable)
|
||||
if n >= size:
|
||||
return sorted(iterable)[::-1][:n]
|
||||
|
||||
it = iter(iterable)
|
||||
result = [(elem, i) for i, elem in zip(range(0, -n, -1), it)]
|
||||
|
||||
hq.heapify(result)
|
||||
top = result[0][0]
|
||||
order = -n
|
||||
|
||||
for elem in it:
|
||||
if top < elem:
|
||||
hq.heapreplace(result, (elem, order))
|
||||
top, _order = result[0]
|
||||
order -= 1
|
||||
result.sort(reverse=True)
|
||||
return [elem for (elem, order) in result]
|
||||
|
||||
return hq_nlargest_impl
|
||||
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Implementation of various iterable and iterator types.
|
||||
"""
|
||||
|
||||
from numba.core import types, cgutils
|
||||
from numba.core.imputils import (
|
||||
lower_builtin, iternext_impl, call_iternext, call_getiter,
|
||||
impl_ret_borrowed, impl_ret_new_ref, RefType)
|
||||
|
||||
|
||||
|
||||
@lower_builtin('getiter', types.IteratorType)
|
||||
def iterator_getiter(context, builder, sig, args):
|
||||
[it] = args
|
||||
return impl_ret_borrowed(context, builder, sig.return_type, it)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# builtin `enumerate` implementation
|
||||
|
||||
@lower_builtin(enumerate, types.IterableType)
|
||||
@lower_builtin(enumerate, types.IterableType, types.Integer)
|
||||
def make_enumerate_object(context, builder, sig, args):
|
||||
assert len(args) == 1 or len(args) == 2 # enumerate(it) or enumerate(it, start)
|
||||
srcty = sig.args[0]
|
||||
|
||||
if len(args) == 1:
|
||||
src = args[0]
|
||||
start_val = context.get_constant(types.intp, 0)
|
||||
elif len(args) == 2:
|
||||
src = args[0]
|
||||
start_val = context.cast(builder, args[1], sig.args[1], types.intp)
|
||||
|
||||
iterobj = call_getiter(context, builder, srcty, src)
|
||||
|
||||
enum = context.make_helper(builder, sig.return_type)
|
||||
|
||||
countptr = cgutils.alloca_once(builder, start_val.type)
|
||||
builder.store(start_val, countptr)
|
||||
|
||||
enum.count = countptr
|
||||
enum.iter = iterobj
|
||||
|
||||
res = enum._getvalue()
|
||||
return impl_ret_new_ref(context, builder, sig.return_type, res)
|
||||
|
||||
@lower_builtin('iternext', types.EnumerateType)
|
||||
@iternext_impl(RefType.NEW)
|
||||
def iternext_enumerate(context, builder, sig, args, result):
|
||||
[enumty] = sig.args
|
||||
[enum] = args
|
||||
|
||||
enum = context.make_helper(builder, enumty, value=enum)
|
||||
|
||||
count = builder.load(enum.count)
|
||||
ncount = builder.add(count, context.get_constant(types.intp, 1))
|
||||
builder.store(ncount, enum.count)
|
||||
|
||||
srcres = call_iternext(context, builder, enumty.source_type, enum.iter)
|
||||
is_valid = srcres.is_valid()
|
||||
result.set_valid(is_valid)
|
||||
|
||||
with builder.if_then(is_valid):
|
||||
srcval = srcres.yielded_value()
|
||||
result.yield_(context.make_tuple(builder, enumty.yield_type,
|
||||
[count, srcval]))
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# builtin `zip` implementation
|
||||
|
||||
@lower_builtin(zip, types.VarArg(types.Any))
|
||||
def make_zip_object(context, builder, sig, args):
|
||||
zip_type = sig.return_type
|
||||
|
||||
assert len(args) == len(zip_type.source_types)
|
||||
|
||||
zipobj = context.make_helper(builder, zip_type)
|
||||
|
||||
for i, (arg, srcty) in enumerate(zip(args, sig.args)):
|
||||
zipobj[i] = call_getiter(context, builder, srcty, arg)
|
||||
|
||||
res = zipobj._getvalue()
|
||||
return impl_ret_new_ref(context, builder, sig.return_type, res)
|
||||
|
||||
@lower_builtin('iternext', types.ZipType)
|
||||
@iternext_impl(RefType.NEW)
|
||||
def iternext_zip(context, builder, sig, args, result):
|
||||
[zip_type] = sig.args
|
||||
[zipobj] = args
|
||||
|
||||
zipobj = context.make_helper(builder, zip_type, value=zipobj)
|
||||
|
||||
if len(zipobj) == 0:
|
||||
# zip() is an empty iterator
|
||||
result.set_exhausted()
|
||||
return
|
||||
|
||||
p_ret_tup = cgutils.alloca_once(builder,
|
||||
context.get_value_type(zip_type.yield_type))
|
||||
p_is_valid = cgutils.alloca_once_value(builder, value=cgutils.true_bit)
|
||||
|
||||
for i, (iterobj, srcty) in enumerate(zip(zipobj, zip_type.source_types)):
|
||||
is_valid = builder.load(p_is_valid)
|
||||
# Avoid calling the remaining iternext if a iterator has been exhausted
|
||||
with builder.if_then(is_valid):
|
||||
srcres = call_iternext(context, builder, srcty, iterobj)
|
||||
is_valid = builder.and_(is_valid, srcres.is_valid())
|
||||
builder.store(is_valid, p_is_valid)
|
||||
val = srcres.yielded_value()
|
||||
ptr = cgutils.gep_inbounds(builder, p_ret_tup, 0, i)
|
||||
builder.store(val, ptr)
|
||||
|
||||
is_valid = builder.load(p_is_valid)
|
||||
result.set_valid(is_valid)
|
||||
|
||||
with builder.if_then(is_valid):
|
||||
result.yield_(builder.load(p_ret_tup))
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# generator implementation
|
||||
|
||||
@lower_builtin('iternext', types.Generator)
|
||||
@iternext_impl(RefType.BORROWED)
|
||||
def iternext_zip(context, builder, sig, args, result):
|
||||
genty, = sig.args
|
||||
gen, = args
|
||||
impl = context.get_generator_impl(genty)
|
||||
status, retval = impl(context, builder, sig, args)
|
||||
context.add_linking_libs(getattr(impl, 'libs', ()))
|
||||
|
||||
with cgutils.if_likely(builder, status.is_ok):
|
||||
result.set_valid(True)
|
||||
result.yield_(retval)
|
||||
with cgutils.if_unlikely(builder, status.is_stop_iteration):
|
||||
result.set_exhausted()
|
||||
with cgutils.if_unlikely(builder,
|
||||
builder.and_(status.is_error,
|
||||
builder.not_(status.is_stop_iteration))):
|
||||
context.call_conv.return_status_propagate(builder, status)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,471 @@
|
||||
"""
|
||||
Provide math calls that uses intrinsics or libc math functions.
|
||||
"""
|
||||
|
||||
import math
|
||||
import operator
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
import llvmlite.ir
|
||||
from llvmlite.ir import Constant
|
||||
|
||||
from numba.core.imputils import Registry, impl_ret_untracked
|
||||
from numba import typeof
|
||||
from numba.core import types, utils, config, cgutils
|
||||
from numba.core.extending import overload
|
||||
from numba.core.typing import signature
|
||||
from numba.cpython.unsafe.numbers import trailing_zeros
|
||||
|
||||
|
||||
registry = Registry('mathimpl')
|
||||
lower = registry.lower
|
||||
|
||||
|
||||
# Helpers, shared with cmathimpl.
|
||||
_NP_FLT_FINFO = np.finfo(np.dtype('float32'))
|
||||
FLT_MAX = _NP_FLT_FINFO.max
|
||||
FLT_MIN = _NP_FLT_FINFO.tiny
|
||||
|
||||
_NP_DBL_FINFO = np.finfo(np.dtype('float64'))
|
||||
DBL_MAX = _NP_DBL_FINFO.max
|
||||
DBL_MIN = _NP_DBL_FINFO.tiny
|
||||
|
||||
FLOAT_ABS_MASK = 0x7fffffff
|
||||
FLOAT_SIGN_MASK = 0x80000000
|
||||
DOUBLE_ABS_MASK = 0x7fffffffffffffff
|
||||
DOUBLE_SIGN_MASK = 0x8000000000000000
|
||||
|
||||
|
||||
def is_nan(builder, val):
|
||||
"""
|
||||
Return a condition testing whether *val* is a NaN.
|
||||
"""
|
||||
return builder.fcmp_unordered('uno', val, val)
|
||||
|
||||
def is_inf(builder, val):
|
||||
"""
|
||||
Return a condition testing whether *val* is an infinite.
|
||||
"""
|
||||
pos_inf = Constant(val.type, float("+inf"))
|
||||
neg_inf = Constant(val.type, float("-inf"))
|
||||
isposinf = builder.fcmp_ordered('==', val, pos_inf)
|
||||
isneginf = builder.fcmp_ordered('==', val, neg_inf)
|
||||
return builder.or_(isposinf, isneginf)
|
||||
|
||||
def is_finite(builder, val):
|
||||
"""
|
||||
Return a condition testing whether *val* is a finite.
|
||||
"""
|
||||
# is_finite(x) <=> x - x != NaN
|
||||
val_minus_val = builder.fsub(val, val)
|
||||
return builder.fcmp_ordered('ord', val_minus_val, val_minus_val)
|
||||
|
||||
def f64_as_int64(builder, val):
|
||||
"""
|
||||
Bitcast a double into a 64-bit integer.
|
||||
"""
|
||||
assert val.type == llvmlite.ir.DoubleType()
|
||||
return builder.bitcast(val, llvmlite.ir.IntType(64))
|
||||
|
||||
def int64_as_f64(builder, val):
|
||||
"""
|
||||
Bitcast a 64-bit integer into a double.
|
||||
"""
|
||||
assert val.type == llvmlite.ir.IntType(64)
|
||||
return builder.bitcast(val, llvmlite.ir.DoubleType())
|
||||
|
||||
def f32_as_int32(builder, val):
|
||||
"""
|
||||
Bitcast a float into a 32-bit integer.
|
||||
"""
|
||||
assert val.type == llvmlite.ir.FloatType()
|
||||
return builder.bitcast(val, llvmlite.ir.IntType(32))
|
||||
|
||||
def int32_as_f32(builder, val):
|
||||
"""
|
||||
Bitcast a 32-bit integer into a float.
|
||||
"""
|
||||
assert val.type == llvmlite.ir.IntType(32)
|
||||
return builder.bitcast(val, llvmlite.ir.FloatType())
|
||||
|
||||
def negate_real(builder, val):
|
||||
"""
|
||||
Negate real number *val*, with proper handling of zeros.
|
||||
"""
|
||||
# The negative zero forces LLVM to handle signed zeros properly.
|
||||
return builder.fsub(Constant(val.type, -0.0), val)
|
||||
|
||||
def call_fp_intrinsic(builder, name, args):
|
||||
"""
|
||||
Call a LLVM intrinsic floating-point operation.
|
||||
"""
|
||||
mod = builder.module
|
||||
intr = mod.declare_intrinsic(name, [a.type for a in args])
|
||||
return builder.call(intr, args)
|
||||
|
||||
|
||||
def _unary_int_input_wrapper_impl(wrapped_impl):
|
||||
"""
|
||||
Return an implementation factory to convert the single integral input
|
||||
argument to a float64, then defer to the *wrapped_impl*.
|
||||
"""
|
||||
def implementer(context, builder, sig, args):
|
||||
val, = args
|
||||
input_type = sig.args[0]
|
||||
fpval = context.cast(builder, val, input_type, types.float64)
|
||||
inner_sig = signature(types.float64, types.float64)
|
||||
res = wrapped_impl(context, builder, inner_sig, (fpval,))
|
||||
return context.cast(builder, res, types.float64, sig.return_type)
|
||||
|
||||
return implementer
|
||||
|
||||
def unary_math_int_impl(fn, float_impl):
|
||||
impl = _unary_int_input_wrapper_impl(float_impl)
|
||||
lower(fn, types.Integer)(impl)
|
||||
|
||||
def unary_math_intr(fn, intrcode):
|
||||
"""
|
||||
Implement the math function *fn* using the LLVM intrinsic *intrcode*.
|
||||
"""
|
||||
@lower(fn, types.Float)
|
||||
def float_impl(context, builder, sig, args):
|
||||
res = call_fp_intrinsic(builder, intrcode, args)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
unary_math_int_impl(fn, float_impl)
|
||||
return float_impl
|
||||
|
||||
def unary_math_extern(fn, f32extern, f64extern, int_restype=False):
|
||||
"""
|
||||
Register implementations of Python function *fn* using the
|
||||
external function named *f32extern* and *f64extern* (for float32
|
||||
and float64 inputs, respectively).
|
||||
If *int_restype* is true, then the function's return value should be
|
||||
integral, otherwise floating-point.
|
||||
"""
|
||||
f_restype = types.int64 if int_restype else None
|
||||
|
||||
def float_impl(context, builder, sig, args):
|
||||
"""
|
||||
Implement *fn* for a types.Float input.
|
||||
"""
|
||||
[val] = args
|
||||
mod = builder.module
|
||||
input_type = sig.args[0]
|
||||
lty = context.get_value_type(input_type)
|
||||
func_name = {
|
||||
types.float32: f32extern,
|
||||
types.float64: f64extern,
|
||||
}[input_type]
|
||||
fnty = llvmlite.ir.FunctionType(lty, [lty])
|
||||
fn = cgutils.insert_pure_function(builder.module, fnty, name=func_name)
|
||||
res = builder.call(fn, (val,))
|
||||
res = context.cast(builder, res, input_type, sig.return_type)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
lower(fn, types.Float)(float_impl)
|
||||
|
||||
# Implement wrapper for integer inputs
|
||||
unary_math_int_impl(fn, float_impl)
|
||||
|
||||
return float_impl
|
||||
|
||||
|
||||
unary_math_intr(math.fabs, 'llvm.fabs')
|
||||
exp_impl = unary_math_intr(math.exp, 'llvm.exp')
|
||||
if sys.version_info >= (3, 11):
|
||||
exp2_impl = unary_math_intr(math.exp2, 'llvm.exp2')
|
||||
log_impl = unary_math_intr(math.log, 'llvm.log')
|
||||
log10_impl = unary_math_intr(math.log10, 'llvm.log10')
|
||||
log2_impl = unary_math_intr(math.log2, 'llvm.log2')
|
||||
sin_impl = unary_math_intr(math.sin, 'llvm.sin')
|
||||
cos_impl = unary_math_intr(math.cos, 'llvm.cos')
|
||||
|
||||
log1p_impl = unary_math_extern(math.log1p, "log1pf", "log1p")
|
||||
expm1_impl = unary_math_extern(math.expm1, "expm1f", "expm1")
|
||||
erf_impl = unary_math_extern(math.erf, "erff", "erf")
|
||||
erfc_impl = unary_math_extern(math.erfc, "erfcf", "erfc")
|
||||
|
||||
tan_impl = unary_math_extern(math.tan, "tanf", "tan")
|
||||
asin_impl = unary_math_extern(math.asin, "asinf", "asin")
|
||||
acos_impl = unary_math_extern(math.acos, "acosf", "acos")
|
||||
atan_impl = unary_math_extern(math.atan, "atanf", "atan")
|
||||
|
||||
asinh_impl = unary_math_extern(math.asinh, "asinhf", "asinh")
|
||||
acosh_impl = unary_math_extern(math.acosh, "acoshf", "acosh")
|
||||
atanh_impl = unary_math_extern(math.atanh, "atanhf", "atanh")
|
||||
sinh_impl = unary_math_extern(math.sinh, "sinhf", "sinh")
|
||||
cosh_impl = unary_math_extern(math.cosh, "coshf", "cosh")
|
||||
tanh_impl = unary_math_extern(math.tanh, "tanhf", "tanh")
|
||||
|
||||
log2_impl = unary_math_extern(math.log2, "log2f", "log2")
|
||||
ceil_impl = unary_math_extern(math.ceil, "ceilf", "ceil", True)
|
||||
floor_impl = unary_math_extern(math.floor, "floorf", "floor", True)
|
||||
|
||||
gamma_impl = unary_math_extern(math.gamma, "numba_gammaf", "numba_gamma") # work-around
|
||||
sqrt_impl = unary_math_extern(math.sqrt, "sqrtf", "sqrt")
|
||||
trunc_impl = unary_math_extern(math.trunc, "truncf", "trunc", True)
|
||||
lgamma_impl = unary_math_extern(math.lgamma, "lgammaf", "lgamma")
|
||||
|
||||
|
||||
@lower(math.isnan, types.Float)
|
||||
def isnan_float_impl(context, builder, sig, args):
|
||||
[val] = args
|
||||
res = is_nan(builder, val)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
@lower(math.isnan, types.Integer)
|
||||
def isnan_int_impl(context, builder, sig, args):
|
||||
res = cgutils.false_bit
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower(math.isinf, types.Float)
|
||||
def isinf_float_impl(context, builder, sig, args):
|
||||
[val] = args
|
||||
res = is_inf(builder, val)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
@lower(math.isinf, types.Integer)
|
||||
def isinf_int_impl(context, builder, sig, args):
|
||||
res = cgutils.false_bit
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower(math.isfinite, types.Float)
|
||||
def isfinite_float_impl(context, builder, sig, args):
|
||||
[val] = args
|
||||
res = is_finite(builder, val)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower(math.isfinite, types.Integer)
|
||||
def isfinite_int_impl(context, builder, sig, args):
|
||||
res = cgutils.true_bit
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower(math.copysign, types.Float, types.Float)
|
||||
def copysign_float_impl(context, builder, sig, args):
|
||||
lty = args[0].type
|
||||
mod = builder.module
|
||||
fn = cgutils.get_or_insert_function(mod, llvmlite.ir.FunctionType(lty, (lty, lty)),
|
||||
'llvm.copysign.%s' % lty.intrinsic_name)
|
||||
res = builder.call(fn, args)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@lower(math.frexp, types.Float)
|
||||
def frexp_impl(context, builder, sig, args):
|
||||
val, = args
|
||||
fltty = context.get_data_type(sig.args[0])
|
||||
intty = context.get_data_type(sig.return_type[1])
|
||||
expptr = cgutils.alloca_once(builder, intty, name='exp')
|
||||
fnty = llvmlite.ir.FunctionType(fltty, (fltty, llvmlite.ir.PointerType(intty)))
|
||||
fname = {
|
||||
"float": "numba_frexpf",
|
||||
"double": "numba_frexp",
|
||||
}[str(fltty)]
|
||||
fn = cgutils.get_or_insert_function(builder.module, fnty, fname)
|
||||
res = builder.call(fn, (val, expptr))
|
||||
res = cgutils.make_anonymous_struct(builder, (res, builder.load(expptr)))
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower(math.ldexp, types.Float, types.intc)
|
||||
def ldexp_impl(context, builder, sig, args):
|
||||
val, exp = args
|
||||
fltty, intty = map(context.get_data_type, sig.args)
|
||||
fnty = llvmlite.ir.FunctionType(fltty, (fltty, intty))
|
||||
fname = {
|
||||
"float": "numba_ldexpf",
|
||||
"double": "numba_ldexp",
|
||||
}[str(fltty)]
|
||||
fn = cgutils.insert_pure_function(builder.module, fnty, name=fname)
|
||||
res = builder.call(fn, (val, exp))
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@lower(math.atan2, types.int64, types.int64)
|
||||
def atan2_s64_impl(context, builder, sig, args):
|
||||
[y, x] = args
|
||||
y = builder.sitofp(y, llvmlite.ir.DoubleType())
|
||||
x = builder.sitofp(x, llvmlite.ir.DoubleType())
|
||||
fsig = signature(types.float64, types.float64, types.float64)
|
||||
return atan2_float_impl(context, builder, fsig, (y, x))
|
||||
|
||||
@lower(math.atan2, types.uint64, types.uint64)
|
||||
def atan2_u64_impl(context, builder, sig, args):
|
||||
[y, x] = args
|
||||
y = builder.uitofp(y, llvmlite.ir.DoubleType())
|
||||
x = builder.uitofp(x, llvmlite.ir.DoubleType())
|
||||
fsig = signature(types.float64, types.float64, types.float64)
|
||||
return atan2_float_impl(context, builder, fsig, (y, x))
|
||||
|
||||
@lower(math.atan2, types.Float, types.Float)
|
||||
def atan2_float_impl(context, builder, sig, args):
|
||||
assert len(args) == 2
|
||||
mod = builder.module
|
||||
ty = sig.args[0]
|
||||
lty = context.get_value_type(ty)
|
||||
func_name = {
|
||||
types.float32: "atan2f",
|
||||
types.float64: "atan2"
|
||||
}[ty]
|
||||
fnty = llvmlite.ir.FunctionType(lty, (lty, lty))
|
||||
fn = cgutils.insert_pure_function(builder.module, fnty, name=func_name)
|
||||
res = builder.call(fn, args)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@lower(math.hypot, types.int64, types.int64)
|
||||
def hypot_s64_impl(context, builder, sig, args):
|
||||
[x, y] = args
|
||||
y = builder.sitofp(y, llvmlite.ir.DoubleType())
|
||||
x = builder.sitofp(x, llvmlite.ir.DoubleType())
|
||||
fsig = signature(types.float64, types.float64, types.float64)
|
||||
res = hypot_float_impl(context, builder, fsig, (x, y))
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower(math.hypot, types.uint64, types.uint64)
|
||||
def hypot_u64_impl(context, builder, sig, args):
|
||||
[x, y] = args
|
||||
y = builder.sitofp(y, llvmlite.ir.DoubleType())
|
||||
x = builder.sitofp(x, llvmlite.ir.DoubleType())
|
||||
fsig = signature(types.float64, types.float64, types.float64)
|
||||
res = hypot_float_impl(context, builder, fsig, (x, y))
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower(math.hypot, types.Float, types.Float)
|
||||
def hypot_float_impl(context, builder, sig, args):
|
||||
xty, yty = sig.args
|
||||
assert xty == yty == sig.return_type
|
||||
x, y = args
|
||||
|
||||
# Windows has alternate names for hypot/hypotf, see
|
||||
# https://msdn.microsoft.com/fr-fr/library/a9yb3dbt%28v=vs.80%29.aspx
|
||||
fname = {
|
||||
types.float32: "_hypotf" if sys.platform == 'win32' else "hypotf",
|
||||
types.float64: "_hypot" if sys.platform == 'win32' else "hypot",
|
||||
}[xty]
|
||||
plat_hypot = types.ExternalFunction(fname, sig)
|
||||
|
||||
if sys.platform == 'win32' and config.MACHINE_BITS == 32:
|
||||
inf = xty(float('inf'))
|
||||
|
||||
def hypot_impl(x, y):
|
||||
if math.isinf(x) or math.isinf(y):
|
||||
return inf
|
||||
return plat_hypot(x, y)
|
||||
else:
|
||||
def hypot_impl(x, y):
|
||||
return plat_hypot(x, y)
|
||||
|
||||
res = context.compile_internal(builder, hypot_impl, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@lower(math.radians, types.Float)
|
||||
def radians_float_impl(context, builder, sig, args):
|
||||
[x] = args
|
||||
coef = context.get_constant(sig.return_type, math.pi / 180)
|
||||
res = builder.fmul(x, coef)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
unary_math_int_impl(math.radians, radians_float_impl)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@lower(math.degrees, types.Float)
|
||||
def degrees_float_impl(context, builder, sig, args):
|
||||
[x] = args
|
||||
coef = context.get_constant(sig.return_type, 180 / math.pi)
|
||||
res = builder.fmul(x, coef)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
unary_math_int_impl(math.degrees, degrees_float_impl)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@lower(math.pow, types.Float, types.Float)
|
||||
@lower(math.pow, types.Float, types.Integer)
|
||||
def pow_impl(context, builder, sig, args):
|
||||
impl = context.get_function(operator.pow, sig)
|
||||
return impl(builder, args)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@lower(math.nextafter, types.Float, types.Float)
|
||||
def nextafter_impl(context, builder, sig, args):
|
||||
assert len(args) == 2
|
||||
ty = sig.args[0]
|
||||
lty = context.get_value_type(ty)
|
||||
func_name = {
|
||||
types.float32: "nextafterf",
|
||||
types.float64: "nextafter"
|
||||
}[ty]
|
||||
fnty = llvmlite.ir.FunctionType(lty, (lty, lty))
|
||||
fn = cgutils.insert_pure_function(builder.module, fnty, name=func_name)
|
||||
res = builder.call(fn, args)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
def _unsigned(T):
|
||||
"""Convert integer to unsigned integer of equivalent width."""
|
||||
pass
|
||||
|
||||
@overload(_unsigned)
|
||||
def _unsigned_impl(T):
|
||||
if T in types.unsigned_domain:
|
||||
return lambda T: T
|
||||
elif T in types.signed_domain:
|
||||
newT = getattr(types, 'uint{}'.format(T.bitwidth))
|
||||
return lambda T: newT(T)
|
||||
|
||||
|
||||
def gcd_impl(context, builder, sig, args):
|
||||
xty, yty = sig.args
|
||||
assert xty == yty == sig.return_type
|
||||
x, y = args
|
||||
|
||||
def gcd(a, b):
|
||||
"""
|
||||
Stein's algorithm, heavily cribbed from Julia implementation.
|
||||
"""
|
||||
T = type(a)
|
||||
if a == 0: return abs(b)
|
||||
if b == 0: return abs(a)
|
||||
za = trailing_zeros(a)
|
||||
zb = trailing_zeros(b)
|
||||
k = min(za, zb)
|
||||
# Uses np.*_shift instead of operators due to return types
|
||||
u = _unsigned(abs(np.right_shift(a, za)))
|
||||
v = _unsigned(abs(np.right_shift(b, zb)))
|
||||
while u != v:
|
||||
if u > v:
|
||||
u, v = v, u
|
||||
v -= u
|
||||
v = np.right_shift(v, trailing_zeros(v))
|
||||
r = np.left_shift(T(u), k)
|
||||
return r
|
||||
|
||||
res = context.compile_internal(builder, gcd, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
lower(math.gcd, types.Integer, types.Integer)(gcd_impl)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
This file implements print functionality for the CPU.
|
||||
"""
|
||||
from numba.core import types, typing, cgutils
|
||||
from numba.core.imputils import Registry, impl_ret_untracked
|
||||
|
||||
registry = Registry('printimpl')
|
||||
lower = registry.lower
|
||||
|
||||
|
||||
# NOTE: the current implementation relies on CPython API even in
|
||||
# nopython mode.
|
||||
|
||||
|
||||
@lower("print_item", types.Literal)
|
||||
def print_item_impl(context, builder, sig, args):
|
||||
"""
|
||||
Print a single constant value.
|
||||
"""
|
||||
ty, = sig.args
|
||||
val = ty.literal_value
|
||||
|
||||
pyapi = context.get_python_api(builder)
|
||||
|
||||
strobj = pyapi.unserialize(pyapi.serialize_object(val))
|
||||
pyapi.print_object(strobj)
|
||||
pyapi.decref(strobj)
|
||||
|
||||
res = context.get_dummy_value()
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower("print_item", types.Any)
|
||||
def print_item_impl(context, builder, sig, args):
|
||||
"""
|
||||
Print a single native value by boxing it in a Python object and
|
||||
invoking the Python interpreter's print routine.
|
||||
"""
|
||||
ty, = sig.args
|
||||
val, = args
|
||||
|
||||
pyapi = context.get_python_api(builder)
|
||||
env_manager = context.get_env_manager(builder)
|
||||
|
||||
if context.enable_nrt:
|
||||
context.nrt.incref(builder, ty, val)
|
||||
|
||||
obj = pyapi.from_native_value(ty, val, env_manager)
|
||||
with builder.if_else(cgutils.is_not_null(builder, obj), likely=True) as (if_ok, if_error):
|
||||
with if_ok:
|
||||
pyapi.print_object(obj)
|
||||
pyapi.decref(obj)
|
||||
with if_error:
|
||||
cstr = context.insert_const_string(builder.module,
|
||||
"the print() function")
|
||||
strobj = pyapi.string_from_string(cstr)
|
||||
pyapi.err_write_unraisable(strobj)
|
||||
pyapi.decref(strobj)
|
||||
|
||||
res = context.get_dummy_value()
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower(print, types.VarArg(types.Any))
|
||||
def print_varargs_impl(context, builder, sig, args):
|
||||
"""
|
||||
A entire print() call.
|
||||
"""
|
||||
pyapi = context.get_python_api(builder)
|
||||
gil = pyapi.gil_ensure()
|
||||
|
||||
for i, (argtype, argval) in enumerate(zip(sig.args, args)):
|
||||
signature = typing.signature(types.none, argtype)
|
||||
imp = context.get_function("print_item", signature)
|
||||
imp(builder, [argval])
|
||||
if i < len(args) - 1:
|
||||
pyapi.print_string(' ')
|
||||
pyapi.print_string('\n')
|
||||
|
||||
pyapi.gil_release(gil)
|
||||
res = context.get_dummy_value()
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
Implementation of the range object for fixed-size integers.
|
||||
"""
|
||||
|
||||
import operator
|
||||
|
||||
from numba import prange
|
||||
from numba.core import types, cgutils, errors, config
|
||||
from numba.core.imputils import (lower_builtin, lower_cast,
|
||||
iterator_impl, impl_ret_untracked)
|
||||
from numba.core.typing import signature
|
||||
from numba.core.extending import intrinsic, overload, overload_attribute, register_jitable
|
||||
from numba.parfors.parfor import internal_prange
|
||||
|
||||
def make_range_iterator(typ):
|
||||
"""
|
||||
Return the Structure representation of the given *typ* (an
|
||||
instance of types.RangeIteratorType).
|
||||
"""
|
||||
return cgutils.create_struct_proxy(typ)
|
||||
|
||||
|
||||
def make_range_impl(int_type, range_state_type, range_iter_type):
|
||||
RangeState = cgutils.create_struct_proxy(range_state_type)
|
||||
|
||||
@lower_builtin(range, int_type)
|
||||
@lower_builtin(prange, int_type)
|
||||
@lower_builtin(internal_prange, int_type)
|
||||
def range1_impl(context, builder, sig, args):
|
||||
"""
|
||||
range(stop: int) -> range object
|
||||
"""
|
||||
[stop] = args
|
||||
state = RangeState(context, builder)
|
||||
state.start = context.get_constant(int_type, 0)
|
||||
state.stop = stop
|
||||
state.step = context.get_constant(int_type, 1)
|
||||
return impl_ret_untracked(context,
|
||||
builder,
|
||||
range_state_type,
|
||||
state._getvalue())
|
||||
|
||||
@lower_builtin(range, int_type, int_type)
|
||||
@lower_builtin(prange, int_type, int_type)
|
||||
@lower_builtin(internal_prange, int_type, int_type)
|
||||
def range2_impl(context, builder, sig, args):
|
||||
"""
|
||||
range(start: int, stop: int) -> range object
|
||||
"""
|
||||
start, stop = args
|
||||
state = RangeState(context, builder)
|
||||
state.start = start
|
||||
state.stop = stop
|
||||
state.step = context.get_constant(int_type, 1)
|
||||
return impl_ret_untracked(context,
|
||||
builder,
|
||||
range_state_type,
|
||||
state._getvalue())
|
||||
|
||||
@lower_builtin(range, int_type, int_type, int_type)
|
||||
@lower_builtin(prange, int_type, int_type, int_type)
|
||||
@lower_builtin(internal_prange, int_type, int_type, int_type)
|
||||
def range3_impl(context, builder, sig, args):
|
||||
"""
|
||||
range(start: int, stop: int, step: int) -> range object
|
||||
"""
|
||||
[start, stop, step] = args
|
||||
state = RangeState(context, builder)
|
||||
state.start = start
|
||||
state.stop = stop
|
||||
state.step = step
|
||||
return impl_ret_untracked(context,
|
||||
builder,
|
||||
range_state_type,
|
||||
state._getvalue())
|
||||
|
||||
@lower_builtin(len, range_state_type)
|
||||
def range_len(context, builder, sig, args):
|
||||
"""
|
||||
len(range)
|
||||
"""
|
||||
(value,) = args
|
||||
state = RangeState(context, builder, value)
|
||||
res = RangeIter.from_range_state(context, builder, state)
|
||||
return impl_ret_untracked(context, builder, int_type, builder.load(res.count))
|
||||
|
||||
@lower_builtin('getiter', range_state_type)
|
||||
def getiter_range32_impl(context, builder, sig, args):
|
||||
"""
|
||||
range.__iter__
|
||||
"""
|
||||
(value,) = args
|
||||
state = RangeState(context, builder, value)
|
||||
res = RangeIter.from_range_state(context, builder, state)._getvalue()
|
||||
return impl_ret_untracked(context, builder, range_iter_type, res)
|
||||
|
||||
@iterator_impl(range_state_type, range_iter_type)
|
||||
class RangeIter(make_range_iterator(range_iter_type)):
|
||||
|
||||
@classmethod
|
||||
def from_range_state(cls, context, builder, state):
|
||||
"""
|
||||
Create a RangeIter initialized from the given RangeState *state*.
|
||||
"""
|
||||
self = cls(context, builder)
|
||||
start = state.start
|
||||
stop = state.stop
|
||||
step = state.step
|
||||
|
||||
startptr = cgutils.alloca_once(builder, start.type)
|
||||
builder.store(start, startptr)
|
||||
|
||||
countptr = cgutils.alloca_once(builder, start.type)
|
||||
|
||||
self.iter = startptr
|
||||
self.stop = stop
|
||||
self.step = step
|
||||
self.count = countptr
|
||||
|
||||
diff = builder.sub(stop, start)
|
||||
zero = context.get_constant(int_type, 0)
|
||||
one = context.get_constant(int_type, 1)
|
||||
pos_diff = builder.icmp_signed('>', diff, zero)
|
||||
pos_step = builder.icmp_signed('>', step, zero)
|
||||
sign_differs = builder.xor(pos_diff, pos_step)
|
||||
zero_step = builder.icmp_unsigned('==', step, zero)
|
||||
|
||||
with cgutils.if_unlikely(builder, zero_step):
|
||||
# step shouldn't be zero
|
||||
context.call_conv.return_user_exc(builder, ValueError,
|
||||
("range() arg 3 must not be zero",))
|
||||
|
||||
with builder.if_else(sign_differs) as (then, orelse):
|
||||
with then:
|
||||
builder.store(zero, self.count)
|
||||
|
||||
with orelse:
|
||||
rem = builder.srem(diff, step)
|
||||
rem = builder.select(pos_diff, rem, builder.neg(rem))
|
||||
uneven = builder.icmp_signed('>', rem, zero)
|
||||
newcount = builder.add(builder.sdiv(diff, step),
|
||||
builder.select(uneven, one, zero))
|
||||
builder.store(newcount, self.count)
|
||||
|
||||
return self
|
||||
|
||||
def iternext(self, context, builder, result):
|
||||
zero = context.get_constant(int_type, 0)
|
||||
countptr = self.count
|
||||
count = builder.load(countptr)
|
||||
is_valid = builder.icmp_signed('>', count, zero)
|
||||
result.set_valid(is_valid)
|
||||
|
||||
with builder.if_then(is_valid):
|
||||
value = builder.load(self.iter)
|
||||
result.yield_(value)
|
||||
one = context.get_constant(int_type, 1)
|
||||
|
||||
builder.store(builder.sub(count, one, flags=["nsw"]), countptr)
|
||||
builder.store(builder.add(value, self.step), self.iter)
|
||||
|
||||
range_impl_map = {
|
||||
types.int32 : (types.range_state32_type, types.range_iter32_type),
|
||||
types.int64 : (types.range_state64_type, types.range_iter64_type),
|
||||
types.uint64 : (types.unsigned_range_state64_type, types.unsigned_range_iter64_type)
|
||||
}
|
||||
|
||||
for int_type, state_types in range_impl_map.items():
|
||||
make_range_impl(int_type, *state_types)
|
||||
|
||||
@lower_cast(types.RangeType, types.RangeType)
|
||||
def range_to_range(context, builder, fromty, toty, val):
|
||||
olditems = cgutils.unpack_tuple(builder, val, 3)
|
||||
items = [context.cast(builder, v, fromty.dtype, toty.dtype)
|
||||
for v in olditems]
|
||||
return cgutils.make_anonymous_struct(builder, items)
|
||||
|
||||
|
||||
def make_range_attr(index, attribute):
|
||||
@intrinsic
|
||||
def rangetype_attr_getter(typingctx, a):
|
||||
if isinstance(a, types.RangeType):
|
||||
def codegen(context, builder, sig, args):
|
||||
(val,) = args
|
||||
items = cgutils.unpack_tuple(builder, val, 3)
|
||||
return impl_ret_untracked(context, builder, sig.return_type,
|
||||
items[index])
|
||||
return signature(a.dtype, a), codegen
|
||||
|
||||
@overload_attribute(types.RangeType, attribute)
|
||||
def range_attr(rnge):
|
||||
def get(rnge):
|
||||
return rangetype_attr_getter(rnge)
|
||||
return get
|
||||
|
||||
|
||||
@register_jitable
|
||||
def impl_contains_helper(robj, val):
|
||||
if robj.step > 0 and (val < robj.start or val >= robj.stop):
|
||||
return False
|
||||
elif robj.step < 0 and (val <= robj.stop or val > robj.start):
|
||||
return False
|
||||
|
||||
return ((val - robj.start) % robj.step) == 0
|
||||
|
||||
|
||||
@overload(operator.contains)
|
||||
def impl_contains(robj, val):
|
||||
def impl_false(robj, val):
|
||||
return False
|
||||
|
||||
if not isinstance(robj, types.RangeType):
|
||||
return
|
||||
|
||||
elif isinstance(val, (types.Integer, types.Boolean)):
|
||||
return impl_contains_helper
|
||||
|
||||
elif isinstance(val, types.Float):
|
||||
def impl(robj, val):
|
||||
if val % 1 != 0:
|
||||
return False
|
||||
else:
|
||||
return impl_contains_helper(robj, int(val))
|
||||
return impl
|
||||
|
||||
elif isinstance(val, types.Complex):
|
||||
def impl(robj, val):
|
||||
if val.imag != 0:
|
||||
return False
|
||||
elif val.real % 1 != 0:
|
||||
return False
|
||||
else:
|
||||
return impl_contains_helper(robj, int(val.real))
|
||||
return impl
|
||||
|
||||
elif not isinstance(val, types.Number):
|
||||
return impl_false
|
||||
|
||||
|
||||
for ix, attr in enumerate(('start', 'stop', 'step')):
|
||||
make_range_attr(index=ix, attribute=attr)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
Implement slices and various slice computations.
|
||||
"""
|
||||
|
||||
from itertools import zip_longest
|
||||
|
||||
from llvmlite import ir
|
||||
from numba.core import cgutils, types, typing, utils
|
||||
from numba.core.imputils import (impl_ret_borrowed, impl_ret_new_ref,
|
||||
impl_ret_untracked, iternext_impl,
|
||||
lower_builtin, lower_cast, lower_constant,
|
||||
lower_getattr)
|
||||
|
||||
|
||||
def fix_index(builder, idx, size):
|
||||
"""
|
||||
Fix negative index by adding *size* to it. Positive
|
||||
indices are left untouched.
|
||||
"""
|
||||
is_negative = builder.icmp_signed('<', idx, ir.Constant(size.type, 0))
|
||||
wrapped_index = builder.add(idx, size)
|
||||
return builder.select(is_negative, wrapped_index, idx)
|
||||
|
||||
|
||||
def fix_slice(builder, slice, size):
|
||||
"""
|
||||
Fix *slice* start and stop to be valid (inclusive and exclusive, resp)
|
||||
indexing bounds for a sequence of the given *size*.
|
||||
"""
|
||||
# See PySlice_GetIndicesEx()
|
||||
zero = ir.Constant(size.type, 0)
|
||||
minus_one = ir.Constant(size.type, -1)
|
||||
|
||||
def fix_bound(bound_name, lower_repl, upper_repl):
|
||||
bound = getattr(slice, bound_name)
|
||||
bound = fix_index(builder, bound, size)
|
||||
# Store value
|
||||
setattr(slice, bound_name, bound)
|
||||
# Still negative? => clamp to lower_repl
|
||||
underflow = builder.icmp_signed('<', bound, zero)
|
||||
with builder.if_then(underflow, likely=False):
|
||||
setattr(slice, bound_name, lower_repl)
|
||||
# Greater than size? => clamp to upper_repl
|
||||
overflow = builder.icmp_signed('>=', bound, size)
|
||||
with builder.if_then(overflow, likely=False):
|
||||
setattr(slice, bound_name, upper_repl)
|
||||
|
||||
with builder.if_else(cgutils.is_neg_int(builder, slice.step)) as (if_neg_step, if_pos_step):
|
||||
with if_pos_step:
|
||||
# < 0 => 0; >= size => size
|
||||
fix_bound('start', zero, size)
|
||||
fix_bound('stop', zero, size)
|
||||
with if_neg_step:
|
||||
# < 0 => -1; >= size => size - 1
|
||||
lower = minus_one
|
||||
upper = builder.add(size, minus_one)
|
||||
fix_bound('start', lower, upper)
|
||||
fix_bound('stop', lower, upper)
|
||||
|
||||
|
||||
def get_slice_length(builder, slicestruct):
|
||||
"""
|
||||
Given a slice, compute the number of indices it spans, i.e. the
|
||||
number of iterations that for_range_slice() will execute.
|
||||
|
||||
Pseudo-code:
|
||||
assert step != 0
|
||||
if step > 0:
|
||||
if stop <= start:
|
||||
return 0
|
||||
else:
|
||||
return (stop - start - 1) // step + 1
|
||||
else:
|
||||
if stop >= start:
|
||||
return 0
|
||||
else:
|
||||
return (stop - start + 1) // step + 1
|
||||
|
||||
(see PySlice_GetIndicesEx() in CPython)
|
||||
"""
|
||||
start = slicestruct.start
|
||||
stop = slicestruct.stop
|
||||
step = slicestruct.step
|
||||
one = ir.Constant(start.type, 1)
|
||||
zero = ir.Constant(start.type, 0)
|
||||
|
||||
is_step_negative = cgutils.is_neg_int(builder, step)
|
||||
delta = builder.sub(stop, start)
|
||||
|
||||
# Nominal case
|
||||
pos_dividend = builder.sub(delta, one)
|
||||
neg_dividend = builder.add(delta, one)
|
||||
dividend = builder.select(is_step_negative, neg_dividend, pos_dividend)
|
||||
nominal_length = builder.add(one, builder.sdiv(dividend, step))
|
||||
|
||||
# Catch zero length
|
||||
is_zero_length = builder.select(is_step_negative,
|
||||
builder.icmp_signed('>=', delta, zero),
|
||||
builder.icmp_signed('<=', delta, zero))
|
||||
|
||||
# Clamp to 0 if is_zero_length
|
||||
return builder.select(is_zero_length, zero, nominal_length)
|
||||
|
||||
|
||||
def get_slice_bounds(builder, slicestruct):
|
||||
"""
|
||||
Return the [lower, upper) indexing bounds of a slice.
|
||||
"""
|
||||
start = slicestruct.start
|
||||
stop = slicestruct.stop
|
||||
zero = start.type(0)
|
||||
one = start.type(1)
|
||||
# This is a bit pessimal, e.g. it will return [1, 5) instead
|
||||
# of [1, 4) for `1:5:2`
|
||||
is_step_negative = builder.icmp_signed('<', slicestruct.step, zero)
|
||||
lower = builder.select(is_step_negative,
|
||||
builder.add(stop, one), start)
|
||||
upper = builder.select(is_step_negative,
|
||||
builder.add(start, one), stop)
|
||||
return lower, upper
|
||||
|
||||
|
||||
def fix_stride(builder, slice, stride):
|
||||
"""
|
||||
Fix the given stride for the slice's step.
|
||||
"""
|
||||
return builder.mul(slice.step, stride)
|
||||
|
||||
def guard_invalid_slice(context, builder, typ, slicestruct):
|
||||
"""
|
||||
Guard against *slicestruct* having a zero step (and raise ValueError).
|
||||
"""
|
||||
if typ.has_step:
|
||||
cgutils.guard_null(context, builder, slicestruct.step,
|
||||
(ValueError, "slice step cannot be zero"))
|
||||
|
||||
|
||||
def get_defaults(context):
|
||||
"""
|
||||
Get the default values for a slice's members:
|
||||
(start for positive step, start for negative step,
|
||||
stop for positive step, stop for negative step, step)
|
||||
"""
|
||||
maxint = (1 << (context.address_size - 1)) - 1
|
||||
return (0, maxint, maxint, - maxint - 1, 1)
|
||||
|
||||
|
||||
#---------------------------------------------------------------------------
|
||||
# The slice structure
|
||||
|
||||
@lower_builtin(slice, types.VarArg(types.Any))
|
||||
def slice_constructor_impl(context, builder, sig, args):
|
||||
(
|
||||
default_start_pos,
|
||||
default_start_neg,
|
||||
default_stop_pos,
|
||||
default_stop_neg,
|
||||
default_step,
|
||||
) = [context.get_constant(types.intp, x) for x in get_defaults(context)]
|
||||
|
||||
slice_args = [None] * 3
|
||||
|
||||
# Fetch non-None arguments
|
||||
if len(args) == 1 and sig.args[0] is not types.none:
|
||||
slice_args[1] = args[0]
|
||||
else:
|
||||
for i, (ty, val) in enumerate(zip(sig.args, args)):
|
||||
if ty is not types.none:
|
||||
slice_args[i] = val
|
||||
|
||||
# Fill omitted arguments
|
||||
def get_arg_value(i, default):
|
||||
val = slice_args[i]
|
||||
if val is None:
|
||||
return default
|
||||
else:
|
||||
return val
|
||||
|
||||
step = get_arg_value(2, default_step)
|
||||
is_step_negative = builder.icmp_signed('<', step,
|
||||
context.get_constant(types.intp, 0))
|
||||
default_stop = builder.select(is_step_negative,
|
||||
default_stop_neg, default_stop_pos)
|
||||
default_start = builder.select(is_step_negative,
|
||||
default_start_neg, default_start_pos)
|
||||
stop = get_arg_value(1, default_stop)
|
||||
start = get_arg_value(0, default_start)
|
||||
|
||||
ty = sig.return_type
|
||||
sli = context.make_helper(builder, sig.return_type)
|
||||
sli.start = start
|
||||
sli.stop = stop
|
||||
sli.step = step
|
||||
|
||||
res = sli._getvalue()
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower_getattr(types.SliceType, "start")
|
||||
def slice_start_impl(context, builder, typ, value):
|
||||
sli = context.make_helper(builder, typ, value)
|
||||
return sli.start
|
||||
|
||||
@lower_getattr(types.SliceType, "stop")
|
||||
def slice_stop_impl(context, builder, typ, value):
|
||||
sli = context.make_helper(builder, typ, value)
|
||||
return sli.stop
|
||||
|
||||
@lower_getattr(types.SliceType, "step")
|
||||
def slice_step_impl(context, builder, typ, value):
|
||||
if typ.has_step:
|
||||
sli = context.make_helper(builder, typ, value)
|
||||
return sli.step
|
||||
else:
|
||||
return context.get_constant(types.intp, 1)
|
||||
|
||||
|
||||
@lower_builtin("slice.indices", types.SliceType, types.Integer)
|
||||
def slice_indices(context, builder, sig, args):
|
||||
length = args[1]
|
||||
sli = context.make_helper(builder, sig.args[0], args[0])
|
||||
|
||||
with builder.if_then(cgutils.is_neg_int(builder, length), likely=False):
|
||||
context.call_conv.return_user_exc(
|
||||
builder, ValueError,
|
||||
("length should not be negative",)
|
||||
)
|
||||
with builder.if_then(cgutils.is_scalar_zero(builder, sli.step), likely=False):
|
||||
context.call_conv.return_user_exc(
|
||||
builder, ValueError,
|
||||
("slice step cannot be zero",)
|
||||
)
|
||||
|
||||
fix_slice(builder, sli, length)
|
||||
|
||||
return context.make_tuple(
|
||||
builder,
|
||||
sig.return_type,
|
||||
(sli.start, sli.stop, sli.step)
|
||||
)
|
||||
|
||||
|
||||
def make_slice_from_constant(context, builder, ty, pyval):
|
||||
sli = context.make_helper(builder, ty)
|
||||
lty = context.get_value_type(types.intp)
|
||||
|
||||
(
|
||||
default_start_pos,
|
||||
default_start_neg,
|
||||
default_stop_pos,
|
||||
default_stop_neg,
|
||||
default_step,
|
||||
) = [context.get_constant(types.intp, x) for x in get_defaults(context)]
|
||||
|
||||
step = pyval.step
|
||||
if step is None:
|
||||
step_is_neg = False
|
||||
step = default_step
|
||||
else:
|
||||
step_is_neg = step < 0
|
||||
step = lty(step)
|
||||
|
||||
start = pyval.start
|
||||
if start is None:
|
||||
if step_is_neg:
|
||||
start = default_start_neg
|
||||
else:
|
||||
start = default_start_pos
|
||||
else:
|
||||
start = lty(start)
|
||||
|
||||
stop = pyval.stop
|
||||
if stop is None:
|
||||
if step_is_neg:
|
||||
stop = default_stop_neg
|
||||
else:
|
||||
stop = default_stop_pos
|
||||
else:
|
||||
stop = lty(stop)
|
||||
|
||||
sli.start = start
|
||||
sli.stop = stop
|
||||
sli.step = step
|
||||
|
||||
return sli._getvalue()
|
||||
|
||||
|
||||
@lower_constant(types.SliceType)
|
||||
def constant_slice(context, builder, ty, pyval):
|
||||
if isinstance(ty, types.Literal):
|
||||
typ = ty.literal_type
|
||||
else:
|
||||
typ = ty
|
||||
|
||||
return make_slice_from_constant(context, builder, typ, pyval)
|
||||
|
||||
|
||||
@lower_cast(types.misc.SliceLiteral, types.SliceType)
|
||||
def cast_from_literal(context, builder, fromty, toty, val):
|
||||
return make_slice_from_constant(
|
||||
context, builder, toty, fromty.literal_value,
|
||||
)
|
||||
@@ -0,0 +1,412 @@
|
||||
"""
|
||||
Implementation of tuple objects
|
||||
"""
|
||||
|
||||
import operator
|
||||
|
||||
from numba.core.imputils import (lower_builtin, lower_getattr_generic,
|
||||
lower_cast, lower_constant, iternext_impl,
|
||||
impl_ret_borrowed, impl_ret_untracked,
|
||||
RefType)
|
||||
from numba.core import typing, types, cgutils
|
||||
from numba.core.extending import overload_method, overload, intrinsic
|
||||
|
||||
|
||||
@lower_builtin(types.NamedTupleClass, types.VarArg(types.Any))
|
||||
def namedtuple_constructor(context, builder, sig, args):
|
||||
# A namedtuple has the same representation as a regular tuple
|
||||
# the arguments need casting (lower_cast) from the types in the ctor args
|
||||
# to those in the ctor return type, this is to handle cases such as a
|
||||
# literal present in the args, but a type present in the return type.
|
||||
newargs = []
|
||||
for i, arg in enumerate(args):
|
||||
casted = context.cast(builder, arg, sig.args[i], sig.return_type[i])
|
||||
newargs.append(casted)
|
||||
res = context.make_tuple(builder, sig.return_type, tuple(newargs))
|
||||
# The tuple's contents are borrowed
|
||||
return impl_ret_borrowed(context, builder, sig.return_type, res)
|
||||
|
||||
@lower_builtin(operator.add, types.BaseTuple, types.BaseTuple)
|
||||
def tuple_add(context, builder, sig, args):
|
||||
left, right = [cgutils.unpack_tuple(builder, x) for x in args]
|
||||
res = context.make_tuple(builder, sig.return_type, left + right)
|
||||
# The tuple's contents are borrowed
|
||||
return impl_ret_borrowed(context, builder, sig.return_type, res)
|
||||
|
||||
def tuple_cmp_ordered(context, builder, op, sig, args):
|
||||
tu, tv = sig.args
|
||||
u, v = args
|
||||
res = cgutils.alloca_once_value(builder, cgutils.true_bit)
|
||||
bbend = builder.append_basic_block("cmp_end")
|
||||
for i, (ta, tb) in enumerate(zip(tu.types, tv.types)):
|
||||
a = builder.extract_value(u, i)
|
||||
b = builder.extract_value(v, i)
|
||||
not_equal = context.generic_compare(builder, operator.ne, (ta, tb), (a, b))
|
||||
with builder.if_then(not_equal):
|
||||
pred = context.generic_compare(builder, op, (ta, tb), (a, b))
|
||||
builder.store(pred, res)
|
||||
builder.branch(bbend)
|
||||
# Everything matched equal => compare lengths
|
||||
len_compare = op(len(tu.types), len(tv.types))
|
||||
pred = context.get_constant(types.boolean, len_compare)
|
||||
builder.store(pred, res)
|
||||
builder.branch(bbend)
|
||||
builder.position_at_end(bbend)
|
||||
return builder.load(res)
|
||||
|
||||
|
||||
@lower_builtin(operator.eq, types.BaseTuple, types.BaseTuple)
|
||||
def tuple_eq(context, builder, sig, args):
|
||||
tu, tv = sig.args
|
||||
u, v = args
|
||||
if len(tu.types) != len(tv.types):
|
||||
res = context.get_constant(types.boolean, False)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
res = context.get_constant(types.boolean, True)
|
||||
for i, (ta, tb) in enumerate(zip(tu.types, tv.types)):
|
||||
a = builder.extract_value(u, i)
|
||||
b = builder.extract_value(v, i)
|
||||
pred = context.generic_compare(builder, operator.eq, (ta, tb), (a, b))
|
||||
res = builder.and_(res, pred)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
@lower_builtin(operator.ne, types.BaseTuple, types.BaseTuple)
|
||||
def tuple_ne(context, builder, sig, args):
|
||||
res = builder.not_(tuple_eq(context, builder, sig, args))
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
@lower_builtin(operator.lt, types.BaseTuple, types.BaseTuple)
|
||||
def tuple_lt(context, builder, sig, args):
|
||||
res = tuple_cmp_ordered(context, builder, operator.lt, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
@lower_builtin(operator.le, types.BaseTuple, types.BaseTuple)
|
||||
def tuple_le(context, builder, sig, args):
|
||||
res = tuple_cmp_ordered(context, builder, operator.le, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
@lower_builtin(operator.gt, types.BaseTuple, types.BaseTuple)
|
||||
def tuple_gt(context, builder, sig, args):
|
||||
res = tuple_cmp_ordered(context, builder, operator.gt, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
@lower_builtin(operator.ge, types.BaseTuple, types.BaseTuple)
|
||||
def tuple_ge(context, builder, sig, args):
|
||||
res = tuple_cmp_ordered(context, builder, operator.ge, sig, args)
|
||||
return impl_ret_untracked(context, builder, sig.return_type, res)
|
||||
|
||||
# for hashing see hashing.py
|
||||
|
||||
@lower_getattr_generic(types.BaseNamedTuple)
|
||||
def namedtuple_getattr(context, builder, typ, value, attr):
|
||||
"""
|
||||
Fetch a namedtuple's field.
|
||||
"""
|
||||
index = typ.fields.index(attr)
|
||||
res = builder.extract_value(value, index)
|
||||
return impl_ret_borrowed(context, builder, typ[index], res)
|
||||
|
||||
|
||||
@lower_constant(types.UniTuple)
|
||||
@lower_constant(types.NamedUniTuple)
|
||||
def unituple_constant(context, builder, ty, pyval):
|
||||
"""
|
||||
Create a homogeneous tuple constant.
|
||||
"""
|
||||
consts = [context.get_constant_generic(builder, ty.dtype, v)
|
||||
for v in pyval]
|
||||
return impl_ret_borrowed(
|
||||
context, builder, ty, cgutils.pack_array(builder, consts),
|
||||
)
|
||||
|
||||
@lower_constant(types.Tuple)
|
||||
@lower_constant(types.NamedTuple)
|
||||
def unituple_constant(context, builder, ty, pyval):
|
||||
"""
|
||||
Create a heterogeneous tuple constant.
|
||||
"""
|
||||
consts = [context.get_constant_generic(builder, ty.types[i], v)
|
||||
for i, v in enumerate(pyval)]
|
||||
return impl_ret_borrowed(
|
||||
context, builder, ty, cgutils.pack_struct(builder, consts),
|
||||
)
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
# Tuple iterators
|
||||
|
||||
@lower_builtin('getiter', types.UniTuple)
|
||||
@lower_builtin('getiter', types.NamedUniTuple)
|
||||
def getiter_unituple(context, builder, sig, args):
|
||||
[tupty] = sig.args
|
||||
[tup] = args
|
||||
|
||||
iterval = context.make_helper(builder, types.UniTupleIter(tupty))
|
||||
|
||||
index0 = context.get_constant(types.intp, 0)
|
||||
indexptr = cgutils.alloca_once(builder, index0.type)
|
||||
builder.store(index0, indexptr)
|
||||
|
||||
iterval.index = indexptr
|
||||
iterval.tuple = tup
|
||||
|
||||
res = iterval._getvalue()
|
||||
return impl_ret_borrowed(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower_builtin('iternext', types.UniTupleIter)
|
||||
@iternext_impl(RefType.BORROWED)
|
||||
def iternext_unituple(context, builder, sig, args, result):
|
||||
[tupiterty] = sig.args
|
||||
[tupiter] = args
|
||||
|
||||
iterval = context.make_helper(builder, tupiterty, value=tupiter)
|
||||
|
||||
tup = iterval.tuple
|
||||
idxptr = iterval.index
|
||||
idx = builder.load(idxptr)
|
||||
count = context.get_constant(types.intp, tupiterty.container.count)
|
||||
|
||||
is_valid = builder.icmp_signed('<', idx, count)
|
||||
result.set_valid(is_valid)
|
||||
|
||||
with builder.if_then(is_valid):
|
||||
getitem_sig = typing.signature(tupiterty.container.dtype,
|
||||
tupiterty.container,
|
||||
types.intp)
|
||||
getitem_out = getitem_unituple(context, builder, getitem_sig,
|
||||
[tup, idx])
|
||||
# As a iternext_impl function, this will incref the yieled value.
|
||||
# We need to release the new reference from getitem_unituple.
|
||||
if context.enable_nrt:
|
||||
context.nrt.decref(builder, tupiterty.container.dtype, getitem_out)
|
||||
result.yield_(getitem_out)
|
||||
nidx = builder.add(idx, context.get_constant(types.intp, 1))
|
||||
builder.store(nidx, iterval.index)
|
||||
|
||||
|
||||
@overload(operator.getitem)
|
||||
def getitem_literal_idx(tup, idx):
|
||||
"""
|
||||
Overloads BaseTuple getitem to cover cases where constant
|
||||
inference and RewriteConstGetitems cannot replace it
|
||||
with a static_getitem.
|
||||
"""
|
||||
if not (isinstance(tup, types.BaseTuple)
|
||||
and isinstance(idx, types.IntegerLiteral)):
|
||||
return None
|
||||
|
||||
idx_val = idx.literal_value
|
||||
def getitem_literal_idx_impl(tup, idx):
|
||||
return tup[idx_val]
|
||||
|
||||
return getitem_literal_idx_impl
|
||||
|
||||
|
||||
@lower_builtin('typed_getitem', types.BaseTuple, types.Any)
|
||||
def getitem_typed(context, builder, sig, args):
|
||||
tupty, _ = sig.args
|
||||
tup, idx = args
|
||||
errmsg_oob = ("tuple index out of range",)
|
||||
|
||||
if len(tupty) == 0:
|
||||
# Empty tuple.
|
||||
|
||||
# Always branch and raise IndexError
|
||||
with builder.if_then(cgutils.true_bit):
|
||||
context.call_conv.return_user_exc(builder, IndexError,
|
||||
errmsg_oob)
|
||||
# This is unreachable in runtime,
|
||||
# but it exists to not terminate the current basicblock.
|
||||
res = context.get_constant_null(sig.return_type)
|
||||
return impl_ret_untracked(context, builder,
|
||||
sig.return_type, res)
|
||||
else:
|
||||
# The tuple is not empty
|
||||
|
||||
bbelse = builder.append_basic_block("typed_switch.else")
|
||||
bbend = builder.append_basic_block("typed_switch.end")
|
||||
switch = builder.switch(idx, bbelse)
|
||||
|
||||
with builder.goto_block(bbelse):
|
||||
context.call_conv.return_user_exc(builder, IndexError,
|
||||
errmsg_oob)
|
||||
|
||||
lrtty = context.get_value_type(sig.return_type)
|
||||
voidptrty = context.get_value_type(types.voidptr)
|
||||
with builder.goto_block(bbend):
|
||||
phinode = builder.phi(voidptrty)
|
||||
|
||||
for i in range(tupty.count):
|
||||
ki = context.get_constant(types.intp, i)
|
||||
bbi = builder.append_basic_block("typed_switch.%d" % i)
|
||||
switch.add_case(ki, bbi)
|
||||
# handle negative indexing, create case (-tuple.count + i) to
|
||||
# reference same block as i
|
||||
kin = context.get_constant(types.intp, -tupty.count + i)
|
||||
switch.add_case(kin, bbi)
|
||||
with builder.goto_block(bbi):
|
||||
value = builder.extract_value(tup, i)
|
||||
# Dragon warning...
|
||||
# The fact the code has made it this far suggests that type
|
||||
# inference decided whatever was being done with the item pulled
|
||||
# from the tuple was legitimate, it is not the job of lowering
|
||||
# to argue about that. However, here lies a problem, the tuple
|
||||
# lowering is implemented as a switch table with each case
|
||||
# writing to a phi node slot that is returned. The type of this
|
||||
# phi node slot needs to be "correct" for the current type but
|
||||
# it also needs to survive stores being made to it from the
|
||||
# other cases that will in effect never run. To do this a stack
|
||||
# slot is made for each case for the specific type and then cast
|
||||
# to a void pointer type, this is then added as an incoming on
|
||||
# the phi node, at the end of the switch the phi node is then
|
||||
# cast back to the required return type for this typed_getitem.
|
||||
# The only further complication is that if the value is not a
|
||||
# pointer then the void* juggle won't work so a cast is made
|
||||
# prior to store, again, that type inference has permitted it
|
||||
# suggests this is safe.
|
||||
# End Dragon warning...
|
||||
DOCAST = context.typing_context.unify_types(sig.args[0][i],
|
||||
sig.return_type) == sig.return_type
|
||||
if DOCAST:
|
||||
value_slot = builder.alloca(lrtty,
|
||||
name="TYPED_VALUE_SLOT%s" % i)
|
||||
casted = context.cast(builder, value, sig.args[0][i],
|
||||
sig.return_type)
|
||||
builder.store(casted, value_slot)
|
||||
else:
|
||||
value_slot = builder.alloca(value.type,
|
||||
name="TYPED_VALUE_SLOT%s" % i)
|
||||
builder.store(value, value_slot)
|
||||
phinode.add_incoming(builder.bitcast(value_slot, voidptrty),
|
||||
bbi)
|
||||
builder.branch(bbend)
|
||||
|
||||
builder.position_at_end(bbend)
|
||||
res = builder.bitcast(phinode, lrtty.as_pointer())
|
||||
res = builder.load(res)
|
||||
return impl_ret_borrowed(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower_builtin(operator.getitem, types.UniTuple, types.intp)
|
||||
@lower_builtin(operator.getitem, types.UniTuple, types.uintp)
|
||||
@lower_builtin(operator.getitem, types.NamedUniTuple, types.intp)
|
||||
@lower_builtin(operator.getitem, types.NamedUniTuple, types.uintp)
|
||||
def getitem_unituple(context, builder, sig, args):
|
||||
tupty, _ = sig.args
|
||||
tup, idx = args
|
||||
|
||||
errmsg_oob = ("tuple index out of range",)
|
||||
|
||||
if len(tupty) == 0:
|
||||
# Empty tuple.
|
||||
|
||||
# Always branch and raise IndexError
|
||||
with builder.if_then(cgutils.true_bit):
|
||||
context.call_conv.return_user_exc(builder, IndexError,
|
||||
errmsg_oob)
|
||||
# This is unreachable in runtime,
|
||||
# but it exists to not terminate the current basicblock.
|
||||
res = context.get_constant_null(sig.return_type)
|
||||
return impl_ret_untracked(context, builder,
|
||||
sig.return_type, res)
|
||||
else:
|
||||
# The tuple is not empty
|
||||
bbelse = builder.append_basic_block("switch.else")
|
||||
bbend = builder.append_basic_block("switch.end")
|
||||
switch = builder.switch(idx, bbelse)
|
||||
|
||||
with builder.goto_block(bbelse):
|
||||
context.call_conv.return_user_exc(builder, IndexError,
|
||||
errmsg_oob)
|
||||
|
||||
lrtty = context.get_value_type(tupty.dtype)
|
||||
with builder.goto_block(bbend):
|
||||
phinode = builder.phi(lrtty)
|
||||
|
||||
for i in range(tupty.count):
|
||||
ki = context.get_constant(types.intp, i)
|
||||
bbi = builder.append_basic_block("switch.%d" % i)
|
||||
switch.add_case(ki, bbi)
|
||||
# handle negative indexing, create case (-tuple.count + i) to
|
||||
# reference same block as i
|
||||
kin = context.get_constant(types.intp, -tupty.count + i)
|
||||
switch.add_case(kin, bbi)
|
||||
with builder.goto_block(bbi):
|
||||
value = builder.extract_value(tup, i)
|
||||
builder.branch(bbend)
|
||||
phinode.add_incoming(value, bbi)
|
||||
|
||||
builder.position_at_end(bbend)
|
||||
res = phinode
|
||||
assert sig.return_type == tupty.dtype
|
||||
return impl_ret_borrowed(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
@lower_builtin('static_getitem', types.LiteralStrKeyDict, types.StringLiteral)
|
||||
@lower_builtin('static_getitem', types.LiteralList, types.IntegerLiteral)
|
||||
@lower_builtin('static_getitem', types.LiteralList, types.SliceLiteral)
|
||||
@lower_builtin('static_getitem', types.BaseTuple, types.IntegerLiteral)
|
||||
@lower_builtin('static_getitem', types.BaseTuple, types.SliceLiteral)
|
||||
def static_getitem_tuple(context, builder, sig, args):
|
||||
tupty, idxty = sig.args
|
||||
tup, idx = args
|
||||
if isinstance(idx, int):
|
||||
if idx < 0:
|
||||
idx += len(tupty)
|
||||
if not 0 <= idx < len(tupty):
|
||||
raise IndexError("cannot index at %d in %s" % (idx, tupty))
|
||||
res = builder.extract_value(tup, idx)
|
||||
elif isinstance(idx, slice):
|
||||
items = cgutils.unpack_tuple(builder, tup)[idx]
|
||||
res = context.make_tuple(builder, sig.return_type, items)
|
||||
elif isinstance(tupty, types.LiteralStrKeyDict):
|
||||
# pretend to be a dictionary
|
||||
idx_val = idxty.literal_value
|
||||
idx_offset = tupty.fields.index(idx_val)
|
||||
res = builder.extract_value(tup, idx_offset)
|
||||
else:
|
||||
raise NotImplementedError("unexpected index %r for %s"
|
||||
% (idx, sig.args[0]))
|
||||
return impl_ret_borrowed(context, builder, sig.return_type, res)
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
# Implicit conversion
|
||||
|
||||
@lower_cast(types.BaseTuple, types.BaseTuple)
|
||||
def tuple_to_tuple(context, builder, fromty, toty, val):
|
||||
if (isinstance(fromty, types.BaseNamedTuple)
|
||||
or isinstance(toty, types.BaseNamedTuple)):
|
||||
# Disallowed by typing layer
|
||||
raise NotImplementedError
|
||||
|
||||
if len(fromty) != len(toty):
|
||||
# Disallowed by typing layer
|
||||
raise NotImplementedError
|
||||
|
||||
olditems = cgutils.unpack_tuple(builder, val, len(fromty))
|
||||
items = [context.cast(builder, v, f, t)
|
||||
for v, f, t in zip(olditems, fromty, toty)]
|
||||
return context.make_tuple(builder, toty, items)
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
# Methods
|
||||
|
||||
@overload_method(types.BaseTuple, 'index')
|
||||
def tuple_index(tup, value):
|
||||
|
||||
def tuple_index_impl(tup, value):
|
||||
for i in range(len(tup)):
|
||||
if tup[i] == value:
|
||||
return i
|
||||
raise ValueError("tuple.index(x): x not in tuple")
|
||||
|
||||
return tuple_index_impl
|
||||
|
||||
|
||||
@overload(operator.contains)
|
||||
def in_seq_empty_tuple(x, y):
|
||||
if isinstance(x, types.Tuple) and not x.types:
|
||||
return lambda x, y: False
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,765 @@
|
||||
"""
|
||||
This module contains support functions for more advanced unicode operations.
|
||||
This is not a public API and is for Numba internal use only. Most of the
|
||||
functions are relatively straightforward translations of the functions with the
|
||||
same name in CPython.
|
||||
"""
|
||||
from collections import namedtuple
|
||||
from enum import IntEnum
|
||||
|
||||
import llvmlite.ir
|
||||
import numpy as np
|
||||
|
||||
from numba.core import types, cgutils
|
||||
from numba.core.imputils import (impl_ret_untracked)
|
||||
|
||||
from numba.core.extending import overload, intrinsic, register_jitable
|
||||
from numba.core.errors import TypingError
|
||||
|
||||
# This is equivalent to the struct `_PyUnicode_TypeRecord defined in CPython's
|
||||
# Objects/unicodectype.c
|
||||
typerecord = namedtuple('typerecord',
|
||||
'upper lower title decimal digit flags')
|
||||
|
||||
# The Py_UCS4 type from CPython:
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/unicodeobject.h#L112 # noqa: E501
|
||||
_Py_UCS4 = types.uint32
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Start code related to/from CPython's unicodectype impl
|
||||
#
|
||||
# NOTE: the original source at:
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c # noqa: E501
|
||||
# contains this statement:
|
||||
#
|
||||
# /*
|
||||
# Unicode character type helpers.
|
||||
#
|
||||
# Written by Marc-Andre Lemburg (mal@lemburg.com).
|
||||
# Modified for Python 2.0 by Fredrik Lundh (fredrik@pythonware.com)
|
||||
#
|
||||
# Copyright (c) Corporation for National Research Initiatives.
|
||||
#
|
||||
# */
|
||||
|
||||
|
||||
# This enum contains the values defined in CPython's Objects/unicodectype.c that
|
||||
# provide masks for use against the various members of the typerecord
|
||||
#
|
||||
# See: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L13-L27 # noqa: E501
|
||||
#
|
||||
|
||||
|
||||
_Py_TAB = 0x9
|
||||
_Py_LINEFEED = 0xa
|
||||
_Py_CARRIAGE_RETURN = 0xd
|
||||
_Py_SPACE = 0x20
|
||||
|
||||
|
||||
class _PyUnicode_TyperecordMasks(IntEnum):
|
||||
ALPHA_MASK = 0x01
|
||||
DECIMAL_MASK = 0x02
|
||||
DIGIT_MASK = 0x04
|
||||
LOWER_MASK = 0x08
|
||||
LINEBREAK_MASK = 0x10
|
||||
SPACE_MASK = 0x20
|
||||
TITLE_MASK = 0x40
|
||||
UPPER_MASK = 0x80
|
||||
XID_START_MASK = 0x100
|
||||
XID_CONTINUE_MASK = 0x200
|
||||
PRINTABLE_MASK = 0x400
|
||||
NUMERIC_MASK = 0x800
|
||||
CASE_IGNORABLE_MASK = 0x1000
|
||||
CASED_MASK = 0x2000
|
||||
EXTENDED_CASE_MASK = 0x4000
|
||||
|
||||
|
||||
def _PyUnicode_gettyperecord(a):
|
||||
raise RuntimeError("Calling the Python definition is invalid")
|
||||
|
||||
|
||||
@intrinsic
|
||||
def _gettyperecord_impl(typingctx, codepoint):
|
||||
"""
|
||||
Provides the binding to numba_gettyperecord, returns a `typerecord`
|
||||
namedtuple of properties from the codepoint.
|
||||
"""
|
||||
if not isinstance(codepoint, types.Integer):
|
||||
raise TypingError("codepoint must be an integer")
|
||||
|
||||
def details(context, builder, signature, args):
|
||||
ll_void = context.get_value_type(types.void)
|
||||
ll_Py_UCS4 = context.get_value_type(_Py_UCS4)
|
||||
ll_intc = context.get_value_type(types.intc)
|
||||
ll_intc_ptr = ll_intc.as_pointer()
|
||||
ll_uchar = context.get_value_type(types.uchar)
|
||||
ll_uchar_ptr = ll_uchar.as_pointer()
|
||||
ll_ushort = context.get_value_type(types.ushort)
|
||||
ll_ushort_ptr = ll_ushort.as_pointer()
|
||||
fnty = llvmlite.ir.FunctionType(ll_void, [
|
||||
ll_Py_UCS4, # code
|
||||
ll_intc_ptr, # upper
|
||||
ll_intc_ptr, # lower
|
||||
ll_intc_ptr, # title
|
||||
ll_uchar_ptr, # decimal
|
||||
ll_uchar_ptr, # digit
|
||||
ll_ushort_ptr, # flags
|
||||
])
|
||||
fn = cgutils.get_or_insert_function(
|
||||
builder.module,
|
||||
fnty, name="numba_gettyperecord")
|
||||
upper = cgutils.alloca_once(builder, ll_intc, name='upper')
|
||||
lower = cgutils.alloca_once(builder, ll_intc, name='lower')
|
||||
title = cgutils.alloca_once(builder, ll_intc, name='title')
|
||||
decimal = cgutils.alloca_once(builder, ll_uchar, name='decimal')
|
||||
digit = cgutils.alloca_once(builder, ll_uchar, name='digit')
|
||||
flags = cgutils.alloca_once(builder, ll_ushort, name='flags')
|
||||
|
||||
byref = [ upper, lower, title, decimal, digit, flags]
|
||||
builder.call(fn, [args[0]] + byref)
|
||||
buf = []
|
||||
for x in byref:
|
||||
buf.append(builder.load(x))
|
||||
|
||||
res = context.make_tuple(builder, signature.return_type, tuple(buf))
|
||||
return impl_ret_untracked(context, builder, signature.return_type, res)
|
||||
|
||||
tupty = types.NamedTuple([types.intc, types.intc, types.intc, types.uchar,
|
||||
types.uchar, types.ushort], typerecord)
|
||||
sig = tupty(_Py_UCS4)
|
||||
return sig, details
|
||||
|
||||
|
||||
@overload(_PyUnicode_gettyperecord)
|
||||
def gettyperecord_impl(a):
|
||||
"""
|
||||
Provides a _PyUnicode_gettyperecord binding, for convenience it will accept
|
||||
single character strings and code points.
|
||||
"""
|
||||
if isinstance(a, types.UnicodeType):
|
||||
from numba.cpython.unicode import _get_code_point
|
||||
|
||||
def impl(a):
|
||||
if len(a) > 1:
|
||||
msg = "gettyperecord takes a single unicode character"
|
||||
raise ValueError(msg)
|
||||
code_point = _get_code_point(a, 0)
|
||||
data = _gettyperecord_impl(_Py_UCS4(code_point))
|
||||
return data
|
||||
return impl
|
||||
if isinstance(a, types.Integer):
|
||||
return lambda a: _gettyperecord_impl(_Py_UCS4(a))
|
||||
|
||||
|
||||
# whilst it's possible to grab the _PyUnicode_ExtendedCase symbol as it's global
|
||||
# it is safer to use a defined api:
|
||||
@intrinsic
|
||||
def _PyUnicode_ExtendedCase(typingctx, index):
|
||||
"""
|
||||
Accessor function for the _PyUnicode_ExtendedCase array, binds to
|
||||
numba_get_PyUnicode_ExtendedCase which wraps the array and does the lookup
|
||||
"""
|
||||
if not isinstance(index, types.Integer):
|
||||
raise TypingError("Expected an index")
|
||||
|
||||
def details(context, builder, signature, args):
|
||||
ll_Py_UCS4 = context.get_value_type(_Py_UCS4)
|
||||
ll_intc = context.get_value_type(types.intc)
|
||||
fnty = llvmlite.ir.FunctionType(ll_Py_UCS4, [ll_intc])
|
||||
fn = cgutils.get_or_insert_function(
|
||||
builder.module,
|
||||
fnty, name="numba_get_PyUnicode_ExtendedCase")
|
||||
return builder.call(fn, [args[0]])
|
||||
|
||||
sig = _Py_UCS4(types.intc)
|
||||
return sig, details
|
||||
|
||||
# The following functions are replications of the functions with the same name
|
||||
# in CPython's Objects/unicodectype.c
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L64-L71 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_ToTitlecase(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
if (ctype.flags & _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK):
|
||||
return _PyUnicode_ExtendedCase(ctype.title & 0xFFFF)
|
||||
return ch + ctype.title
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L76-L81 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_IsTitlecase(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
return ctype.flags & _PyUnicode_TyperecordMasks.TITLE_MASK != 0
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L86-L91 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_IsXidStart(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
return ctype.flags & _PyUnicode_TyperecordMasks.XID_START_MASK != 0
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L96-L101 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_IsXidContinue(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
return ctype.flags & _PyUnicode_TyperecordMasks.XID_CONTINUE_MASK != 0
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _PyUnicode_ToDecimalDigit(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
if ctype.flags & _PyUnicode_TyperecordMasks.DECIMAL_MASK:
|
||||
return ctype.decimal
|
||||
return -1
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L123-L1128 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_ToDigit(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
if ctype.flags & _PyUnicode_TyperecordMasks.DIGIT_MASK:
|
||||
return ctype.digit
|
||||
return -1
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L140-L145 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_IsNumeric(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
return ctype.flags & _PyUnicode_TyperecordMasks.NUMERIC_MASK != 0
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L160-L165 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_IsPrintable(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
return ctype.flags & _PyUnicode_TyperecordMasks.PRINTABLE_MASK != 0
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L170-L175 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_IsLowercase(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
return ctype.flags & _PyUnicode_TyperecordMasks.LOWER_MASK != 0
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L180-L185 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_IsUppercase(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
return ctype.flags & _PyUnicode_TyperecordMasks.UPPER_MASK != 0
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _PyUnicode_IsLineBreak(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
return ctype.flags & _PyUnicode_TyperecordMasks.LINEBREAK_MASK != 0
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _PyUnicode_ToUppercase(ch):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _PyUnicode_ToLowercase(ch):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodectype.c#L211-L225 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_ToLowerFull(ch, res):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
if (ctype.flags & _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK):
|
||||
index = ctype.lower & 0xFFFF
|
||||
n = ctype.lower >> 24
|
||||
for i in range(n):
|
||||
res[i] = _PyUnicode_ExtendedCase(index + i)
|
||||
return n
|
||||
res[0] = ch + ctype.lower
|
||||
return 1
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodectype.c#L227-L241 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_ToTitleFull(ch, res):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
if (ctype.flags & _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK):
|
||||
index = ctype.title & 0xFFFF
|
||||
n = ctype.title >> 24
|
||||
for i in range(n):
|
||||
res[i] = _PyUnicode_ExtendedCase(index + i)
|
||||
return n
|
||||
res[0] = ch + ctype.title
|
||||
return 1
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L243-L257 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_ToUpperFull(ch, res):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
if (ctype.flags & _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK):
|
||||
index = ctype.upper & 0xFFFF
|
||||
n = ctype.upper >> 24
|
||||
for i in range(n):
|
||||
# Perhaps needed to use unicode._set_code_point() here
|
||||
res[i] = _PyUnicode_ExtendedCase(index + i)
|
||||
return n
|
||||
res[0] = ch + ctype.upper
|
||||
return 1
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L259-L272 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_ToFoldedFull(ch, res):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
extended_case_mask = _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK
|
||||
if ctype.flags & extended_case_mask and (ctype.lower >> 20) & 7:
|
||||
index = (ctype.lower & 0xFFFF) + (ctype.lower >> 24)
|
||||
n = (ctype.lower >> 20) & 7
|
||||
for i in range(n):
|
||||
res[i] = _PyUnicode_ExtendedCase(index + i)
|
||||
return n
|
||||
return _PyUnicode_ToLowerFull(ch, res)
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L274-L279 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_IsCased(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
return ctype.flags & _PyUnicode_TyperecordMasks.CASED_MASK != 0
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L281-L286 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_IsCaseIgnorable(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
return ctype.flags & _PyUnicode_TyperecordMasks.CASE_IGNORABLE_MASK != 0
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L123-L135 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_IsDigit(ch):
|
||||
if _PyUnicode_ToDigit(ch) < 0:
|
||||
return 0
|
||||
return 1
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L106-L118 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_IsDecimalDigit(ch):
|
||||
if _PyUnicode_ToDecimalDigit(ch) < 0:
|
||||
return 0
|
||||
return 1
|
||||
|
||||
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L291-L296 # noqa: E501
|
||||
@register_jitable
|
||||
def _PyUnicode_IsSpace(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
return ctype.flags & _PyUnicode_TyperecordMasks.SPACE_MASK != 0
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _PyUnicode_IsAlpha(ch):
|
||||
ctype = _PyUnicode_gettyperecord(ch)
|
||||
return ctype.flags & _PyUnicode_TyperecordMasks.ALPHA_MASK != 0
|
||||
|
||||
|
||||
# End code related to/from CPython's unicodectype impl
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Start code related to/from CPython's pyctype
|
||||
|
||||
# From the definition in CPython's Include/pyctype.h
|
||||
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L5-L11 # noqa: E501
|
||||
class _PY_CTF(IntEnum):
|
||||
LOWER = 0x01
|
||||
UPPER = 0x02
|
||||
ALPHA = 0x01 | 0x02
|
||||
DIGIT = 0x04
|
||||
ALNUM = 0x01 | 0x02 | 0x04
|
||||
SPACE = 0x08
|
||||
XDIGIT = 0x10
|
||||
|
||||
|
||||
# From the definition in CPython's Python/pyctype.c
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Python/pyctype.c#L5 # noqa: E501
|
||||
_Py_ctype_table = np.array([
|
||||
0, # 0x0 '\x00'
|
||||
0, # 0x1 '\x01'
|
||||
0, # 0x2 '\x02'
|
||||
0, # 0x3 '\x03'
|
||||
0, # 0x4 '\x04'
|
||||
0, # 0x5 '\x05'
|
||||
0, # 0x6 '\x06'
|
||||
0, # 0x7 '\x07'
|
||||
0, # 0x8 '\x08'
|
||||
_PY_CTF.SPACE, # 0x9 '\t'
|
||||
_PY_CTF.SPACE, # 0xa '\n'
|
||||
_PY_CTF.SPACE, # 0xb '\v'
|
||||
_PY_CTF.SPACE, # 0xc '\f'
|
||||
_PY_CTF.SPACE, # 0xd '\r'
|
||||
0, # 0xe '\x0e'
|
||||
0, # 0xf '\x0f'
|
||||
0, # 0x10 '\x10'
|
||||
0, # 0x11 '\x11'
|
||||
0, # 0x12 '\x12'
|
||||
0, # 0x13 '\x13'
|
||||
0, # 0x14 '\x14'
|
||||
0, # 0x15 '\x15'
|
||||
0, # 0x16 '\x16'
|
||||
0, # 0x17 '\x17'
|
||||
0, # 0x18 '\x18'
|
||||
0, # 0x19 '\x19'
|
||||
0, # 0x1a '\x1a'
|
||||
0, # 0x1b '\x1b'
|
||||
0, # 0x1c '\x1c'
|
||||
0, # 0x1d '\x1d'
|
||||
0, # 0x1e '\x1e'
|
||||
0, # 0x1f '\x1f'
|
||||
_PY_CTF.SPACE, # 0x20 ' '
|
||||
0, # 0x21 '!'
|
||||
0, # 0x22 '"'
|
||||
0, # 0x23 '#'
|
||||
0, # 0x24 '$'
|
||||
0, # 0x25 '%'
|
||||
0, # 0x26 '&'
|
||||
0, # 0x27 "'"
|
||||
0, # 0x28 '('
|
||||
0, # 0x29 ')'
|
||||
0, # 0x2a '*'
|
||||
0, # 0x2b '+'
|
||||
0, # 0x2c ','
|
||||
0, # 0x2d '-'
|
||||
0, # 0x2e '.'
|
||||
0, # 0x2f '/'
|
||||
_PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x30 '0'
|
||||
_PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x31 '1'
|
||||
_PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x32 '2'
|
||||
_PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x33 '3'
|
||||
_PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x34 '4'
|
||||
_PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x35 '5'
|
||||
_PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x36 '6'
|
||||
_PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x37 '7'
|
||||
_PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x38 '8'
|
||||
_PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x39 '9'
|
||||
0, # 0x3a ':'
|
||||
0, # 0x3b ';'
|
||||
0, # 0x3c '<'
|
||||
0, # 0x3d '='
|
||||
0, # 0x3e '>'
|
||||
0, # 0x3f '?'
|
||||
0, # 0x40 '@'
|
||||
_PY_CTF.UPPER | _PY_CTF.XDIGIT, # 0x41 'A'
|
||||
_PY_CTF.UPPER | _PY_CTF.XDIGIT, # 0x42 'B'
|
||||
_PY_CTF.UPPER | _PY_CTF.XDIGIT, # 0x43 'C'
|
||||
_PY_CTF.UPPER | _PY_CTF.XDIGIT, # 0x44 'D'
|
||||
_PY_CTF.UPPER | _PY_CTF.XDIGIT, # 0x45 'E'
|
||||
_PY_CTF.UPPER | _PY_CTF.XDIGIT, # 0x46 'F'
|
||||
_PY_CTF.UPPER, # 0x47 'G'
|
||||
_PY_CTF.UPPER, # 0x48 'H'
|
||||
_PY_CTF.UPPER, # 0x49 'I'
|
||||
_PY_CTF.UPPER, # 0x4a 'J'
|
||||
_PY_CTF.UPPER, # 0x4b 'K'
|
||||
_PY_CTF.UPPER, # 0x4c 'L'
|
||||
_PY_CTF.UPPER, # 0x4d 'M'
|
||||
_PY_CTF.UPPER, # 0x4e 'N'
|
||||
_PY_CTF.UPPER, # 0x4f 'O'
|
||||
_PY_CTF.UPPER, # 0x50 'P'
|
||||
_PY_CTF.UPPER, # 0x51 'Q'
|
||||
_PY_CTF.UPPER, # 0x52 'R'
|
||||
_PY_CTF.UPPER, # 0x53 'S'
|
||||
_PY_CTF.UPPER, # 0x54 'T'
|
||||
_PY_CTF.UPPER, # 0x55 'U'
|
||||
_PY_CTF.UPPER, # 0x56 'V'
|
||||
_PY_CTF.UPPER, # 0x57 'W'
|
||||
_PY_CTF.UPPER, # 0x58 'X'
|
||||
_PY_CTF.UPPER, # 0x59 'Y'
|
||||
_PY_CTF.UPPER, # 0x5a 'Z'
|
||||
0, # 0x5b '['
|
||||
0, # 0x5c '\\'
|
||||
0, # 0x5d ']'
|
||||
0, # 0x5e '^'
|
||||
0, # 0x5f '_'
|
||||
0, # 0x60 '`'
|
||||
_PY_CTF.LOWER | _PY_CTF.XDIGIT, # 0x61 'a'
|
||||
_PY_CTF.LOWER | _PY_CTF.XDIGIT, # 0x62 'b'
|
||||
_PY_CTF.LOWER | _PY_CTF.XDIGIT, # 0x63 'c'
|
||||
_PY_CTF.LOWER | _PY_CTF.XDIGIT, # 0x64 'd'
|
||||
_PY_CTF.LOWER | _PY_CTF.XDIGIT, # 0x65 'e'
|
||||
_PY_CTF.LOWER | _PY_CTF.XDIGIT, # 0x66 'f'
|
||||
_PY_CTF.LOWER, # 0x67 'g'
|
||||
_PY_CTF.LOWER, # 0x68 'h'
|
||||
_PY_CTF.LOWER, # 0x69 'i'
|
||||
_PY_CTF.LOWER, # 0x6a 'j'
|
||||
_PY_CTF.LOWER, # 0x6b 'k'
|
||||
_PY_CTF.LOWER, # 0x6c 'l'
|
||||
_PY_CTF.LOWER, # 0x6d 'm'
|
||||
_PY_CTF.LOWER, # 0x6e 'n'
|
||||
_PY_CTF.LOWER, # 0x6f 'o'
|
||||
_PY_CTF.LOWER, # 0x70 'p'
|
||||
_PY_CTF.LOWER, # 0x71 'q'
|
||||
_PY_CTF.LOWER, # 0x72 'r'
|
||||
_PY_CTF.LOWER, # 0x73 's'
|
||||
_PY_CTF.LOWER, # 0x74 't'
|
||||
_PY_CTF.LOWER, # 0x75 'u'
|
||||
_PY_CTF.LOWER, # 0x76 'v'
|
||||
_PY_CTF.LOWER, # 0x77 'w'
|
||||
_PY_CTF.LOWER, # 0x78 'x'
|
||||
_PY_CTF.LOWER, # 0x79 'y'
|
||||
_PY_CTF.LOWER, # 0x7a 'z'
|
||||
0, # 0x7b '{'
|
||||
0, # 0x7c '|'
|
||||
0, # 0x7d '}'
|
||||
0, # 0x7e '~'
|
||||
0, # 0x7f '\x7f'
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
], dtype=np.intc)
|
||||
|
||||
|
||||
# From the definition in CPython's Python/pyctype.c
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Python/pyctype.c#L145 # noqa: E501
|
||||
_Py_ctype_tolower = np.array([
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
|
||||
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
|
||||
0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
|
||||
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
|
||||
0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
|
||||
0x40, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
|
||||
0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
|
||||
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77,
|
||||
0x78, 0x79, 0x7a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f,
|
||||
0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
|
||||
0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
|
||||
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77,
|
||||
0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
|
||||
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87,
|
||||
0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f,
|
||||
0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97,
|
||||
0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f,
|
||||
0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7,
|
||||
0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
|
||||
0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7,
|
||||
0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
|
||||
0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7,
|
||||
0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
|
||||
0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7,
|
||||
0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
|
||||
0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7,
|
||||
0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
|
||||
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7,
|
||||
0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff,
|
||||
], dtype=np.uint8)
|
||||
|
||||
|
||||
# From the definition in CPython's Python/pyctype.c
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Python/pyctype.c#L180
|
||||
_Py_ctype_toupper = np.array([
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
|
||||
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
|
||||
0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
|
||||
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
|
||||
0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
|
||||
0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47,
|
||||
0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f,
|
||||
0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57,
|
||||
0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f,
|
||||
0x60, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47,
|
||||
0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f,
|
||||
0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57,
|
||||
0x58, 0x59, 0x5a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
|
||||
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87,
|
||||
0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f,
|
||||
0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97,
|
||||
0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f,
|
||||
0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7,
|
||||
0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
|
||||
0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7,
|
||||
0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
|
||||
0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7,
|
||||
0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
|
||||
0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7,
|
||||
0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
|
||||
0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7,
|
||||
0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
|
||||
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7,
|
||||
0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff,
|
||||
], dtype=np.uint8)
|
||||
|
||||
|
||||
class _PY_CTF_LB(IntEnum):
|
||||
LINE_BREAK = 0x01
|
||||
LINE_FEED = 0x02
|
||||
CARRIAGE_RETURN = 0x04
|
||||
|
||||
|
||||
_Py_ctype_islinebreak = np.array([
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
_PY_CTF_LB.LINE_BREAK | _PY_CTF_LB.LINE_FEED, # 0xa '\n'
|
||||
_PY_CTF_LB.LINE_BREAK, # 0xb '\v'
|
||||
_PY_CTF_LB.LINE_BREAK, # 0xc '\f'
|
||||
_PY_CTF_LB.LINE_BREAK | _PY_CTF_LB.CARRIAGE_RETURN, # 0xd '\r'
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
_PY_CTF_LB.LINE_BREAK, # 0x1c '\x1c'
|
||||
_PY_CTF_LB.LINE_BREAK, # 0x1d '\x1d'
|
||||
_PY_CTF_LB.LINE_BREAK, # 0x1e '\x1e'
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
_PY_CTF_LB.LINE_BREAK, # 0x85 '\x85'
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0,
|
||||
], dtype=np.intc)
|
||||
|
||||
|
||||
# Translation of:
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pymacro.h#L25 # noqa: E501
|
||||
@register_jitable
|
||||
def _Py_CHARMASK(ch):
|
||||
"""
|
||||
Equivalent to the CPython macro `Py_CHARMASK()`, masks off all but the
|
||||
lowest 256 bits of ch.
|
||||
"""
|
||||
return types.uint8(ch) & types.uint8(0xff)
|
||||
|
||||
|
||||
# Translation of:
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L30 # noqa: E501
|
||||
@register_jitable
|
||||
def _Py_TOUPPER(ch):
|
||||
"""
|
||||
Equivalent to the CPython macro `Py_TOUPPER()` converts an ASCII range
|
||||
code point to the upper equivalent
|
||||
"""
|
||||
return _Py_ctype_toupper[_Py_CHARMASK(ch)]
|
||||
|
||||
|
||||
# Translation of:
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L29 # noqa: E501
|
||||
@register_jitable
|
||||
def _Py_TOLOWER(ch):
|
||||
"""
|
||||
Equivalent to the CPython macro `Py_TOLOWER()` converts an ASCII range
|
||||
code point to the lower equivalent
|
||||
"""
|
||||
return _Py_ctype_tolower[_Py_CHARMASK(ch)]
|
||||
|
||||
|
||||
# Translation of:
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L18 # noqa: E501
|
||||
@register_jitable
|
||||
def _Py_ISLOWER(ch):
|
||||
"""
|
||||
Equivalent to the CPython macro `Py_ISLOWER()`
|
||||
"""
|
||||
return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.LOWER
|
||||
|
||||
|
||||
# Translation of:
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L19 # noqa: E501
|
||||
@register_jitable
|
||||
def _Py_ISUPPER(ch):
|
||||
"""
|
||||
Equivalent to the CPython macro `Py_ISUPPER()`
|
||||
"""
|
||||
return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.UPPER
|
||||
|
||||
|
||||
# Translation of:
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L20 # noqa: E501
|
||||
@register_jitable
|
||||
def _Py_ISALPHA(ch):
|
||||
"""
|
||||
Equivalent to the CPython macro `Py_ISALPHA()`
|
||||
"""
|
||||
return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.ALPHA
|
||||
|
||||
|
||||
# Translation of:
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L21 # noqa: E501
|
||||
@register_jitable
|
||||
def _Py_ISDIGIT(ch):
|
||||
"""
|
||||
Equivalent to the CPython macro `Py_ISDIGIT()`
|
||||
"""
|
||||
return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.DIGIT
|
||||
|
||||
|
||||
# Translation of:
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L22 # noqa: E501
|
||||
@register_jitable
|
||||
def _Py_ISXDIGIT(ch):
|
||||
"""
|
||||
Equivalent to the CPython macro `Py_ISXDIGIT()`
|
||||
"""
|
||||
return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.XDIGIT
|
||||
|
||||
|
||||
# Translation of:
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L23 # noqa: E501
|
||||
@register_jitable
|
||||
def _Py_ISALNUM(ch):
|
||||
"""
|
||||
Equivalent to the CPython macro `Py_ISALNUM()`
|
||||
"""
|
||||
return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.ALNUM
|
||||
|
||||
|
||||
# Translation of:
|
||||
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L24 # noqa: E501
|
||||
@register_jitable
|
||||
def _Py_ISSPACE(ch):
|
||||
"""
|
||||
Equivalent to the CPython macro `Py_ISSPACE()`
|
||||
"""
|
||||
return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.SPACE
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _Py_ISLINEBREAK(ch):
|
||||
"""Check if character is ASCII line break"""
|
||||
return _Py_ctype_islinebreak[_Py_CHARMASK(ch)] & _PY_CTF_LB.LINE_BREAK
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _Py_ISLINEFEED(ch):
|
||||
"""Check if character is line feed `\n`"""
|
||||
return _Py_ctype_islinebreak[_Py_CHARMASK(ch)] & _PY_CTF_LB.LINE_FEED
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _Py_ISCARRIAGERETURN(ch):
|
||||
"""Check if character is carriage return `\r`"""
|
||||
return _Py_ctype_islinebreak[_Py_CHARMASK(ch)] & _PY_CTF_LB.CARRIAGE_RETURN
|
||||
|
||||
|
||||
# End code related to/from CPython's pyctype
|
||||
# ------------------------------------------------------------------------------
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,53 @@
|
||||
""" This module provides the unsafe things for targets/numbers.py
|
||||
"""
|
||||
from numba.core import types, errors
|
||||
from numba.core.extending import intrinsic
|
||||
|
||||
from llvmlite import ir
|
||||
|
||||
|
||||
@intrinsic
|
||||
def viewer(tyctx, val, viewty):
|
||||
""" Bitcast a scalar 'val' to the given type 'viewty'. """
|
||||
bits = val.bitwidth
|
||||
if isinstance(viewty.dtype, types.Integer):
|
||||
bitcastty = ir.IntType(bits)
|
||||
elif isinstance(viewty.dtype, types.Float):
|
||||
bitcastty = ir.FloatType() if bits == 32 else ir.DoubleType()
|
||||
else:
|
||||
assert 0, "unreachable"
|
||||
|
||||
def codegen(cgctx, builder, typ, args):
|
||||
flt = args[0]
|
||||
return builder.bitcast(flt, bitcastty)
|
||||
retty = viewty.dtype
|
||||
sig = retty(val, viewty)
|
||||
return sig, codegen
|
||||
|
||||
|
||||
@intrinsic
|
||||
def trailing_zeros(typeingctx, src):
|
||||
"""Counts trailing zeros in the binary representation of an integer."""
|
||||
if not isinstance(src, types.Integer):
|
||||
msg = ("trailing_zeros is only defined for integers, but value passed "
|
||||
f"was '{src}'.")
|
||||
raise errors.NumbaTypeError(msg)
|
||||
|
||||
def codegen(context, builder, signature, args):
|
||||
[src] = args
|
||||
return builder.cttz(src, ir.Constant(ir.IntType(1), 0))
|
||||
return src(src), codegen
|
||||
|
||||
|
||||
@intrinsic
|
||||
def leading_zeros(typeingctx, src):
|
||||
"""Counts leading zeros in the binary representation of an integer."""
|
||||
if not isinstance(src, types.Integer):
|
||||
msg = ("leading_zeros is only defined for integers, but value passed "
|
||||
f"was '{src}'.")
|
||||
raise errors.NumbaTypeError(msg)
|
||||
|
||||
def codegen(context, builder, signature, args):
|
||||
[src] = args
|
||||
return builder.ctlz(src, ir.Constant(ir.IntType(1), 0))
|
||||
return src(src), codegen
|
||||
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
This file provides internal compiler utilities that support certain special
|
||||
operations with tuple and workarounds for limitations enforced in userland.
|
||||
"""
|
||||
|
||||
from numba.core import types, typing, errors
|
||||
from numba.core.cgutils import alloca_once
|
||||
from numba.core.extending import intrinsic
|
||||
|
||||
|
||||
@intrinsic
|
||||
def tuple_setitem(typingctx, tup, idx, val):
|
||||
"""Return a copy of the tuple with item at *idx* replaced with *val*.
|
||||
|
||||
Operation: ``out = tup[:idx] + (val,) + tup[idx + 1:]
|
||||
|
||||
**Warning**
|
||||
|
||||
- No boundchecking.
|
||||
- The dtype of the tuple cannot be changed.
|
||||
*val* is always cast to the existing dtype of the tuple.
|
||||
"""
|
||||
def codegen(context, builder, signature, args):
|
||||
tup, idx, val = args
|
||||
stack = alloca_once(builder, tup.type)
|
||||
builder.store(tup, stack)
|
||||
# Unsafe load on unchecked bounds. Poison value maybe returned.
|
||||
offptr = builder.gep(stack, [idx.type(0), idx], inbounds=True)
|
||||
builder.store(val, offptr)
|
||||
return builder.load(stack)
|
||||
|
||||
sig = tup(tup, idx, tup.dtype)
|
||||
return sig, codegen
|
||||
|
||||
|
||||
@intrinsic
|
||||
def build_full_slice_tuple(tyctx, sz):
|
||||
"""Creates a sz-tuple of full slices."""
|
||||
if not isinstance(sz, types.IntegerLiteral):
|
||||
raise errors.RequireLiteralValue(sz)
|
||||
|
||||
size = int(sz.literal_value)
|
||||
tuple_type = types.UniTuple(dtype=types.slice2_type, count=size)
|
||||
sig = tuple_type(sz)
|
||||
|
||||
def codegen(context, builder, signature, args):
|
||||
def impl(length, empty_tuple):
|
||||
out = empty_tuple
|
||||
for i in range(length):
|
||||
out = tuple_setitem(out, i, slice(None, None))
|
||||
return out
|
||||
|
||||
inner_argtypes = [types.intp, tuple_type]
|
||||
inner_sig = typing.signature(tuple_type, *inner_argtypes)
|
||||
ll_idx_type = context.get_value_type(types.intp)
|
||||
# Allocate an empty tuple
|
||||
empty_tuple = context.get_constant_undef(tuple_type)
|
||||
inner_args = [ll_idx_type(size), empty_tuple]
|
||||
|
||||
res = context.compile_internal(builder, impl, inner_sig, inner_args)
|
||||
return res
|
||||
|
||||
return sig, codegen
|
||||
|
||||
|
||||
@intrinsic
|
||||
def unpack_single_tuple(tyctx, tup):
|
||||
"""This exists to handle the situation y = (*x,), the interpreter injects a
|
||||
call to it in the case of a single value unpack. It's not possible at
|
||||
interpreting time to differentiate between an unpack on a variable sized
|
||||
container e.g. list and a fixed one, e.g. tuple. This function handles the
|
||||
situation should it arise.
|
||||
"""
|
||||
# See issue #6534
|
||||
if not isinstance(tup, types.BaseTuple):
|
||||
msg = (f"Only tuples are supported when unpacking a single item, "
|
||||
f"got type: {tup}")
|
||||
raise errors.UnsupportedError(msg)
|
||||
|
||||
sig = tup(tup)
|
||||
|
||||
def codegen(context, builder, signature, args):
|
||||
return args[0] # there's only one tuple and it's a simple pass through
|
||||
return sig, codegen
|
||||
Reference in New Issue
Block a user