This commit is contained in:
2026-04-10 15:06:59 +02:00
parent 3031b7153b
commit e5a4711004
7806 changed files with 1918528 additions and 335 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,542 @@
"""
Implement the cmath module functions.
"""
import cmath
import math
from numba.core.imputils import Registry, impl_ret_untracked
from numba.core import types, cgutils
from numba.core.typing import signature
from numba.cpython import builtins, mathimpl
from numba.core.extending import overload
registry = Registry('cmathimpl')
lower = registry.lower
def is_nan(builder, z):
    """Return an LLVM i1 that is true iff either component of *z* is NaN."""
    # An unordered float comparison is true exactly when an operand is NaN.
    real, imag = z.real, z.imag
    return builder.fcmp_unordered('uno', real, imag)
def is_inf(builder, z):
    """Return an LLVM i1 that is true iff either component of *z* is infinite."""
    real_inf = mathimpl.is_inf(builder, z.real)
    imag_inf = mathimpl.is_inf(builder, z.imag)
    return builder.or_(real_inf, imag_inf)
def is_finite(builder, z):
    """Return an LLVM i1 that is true iff both components of *z* are finite."""
    real_fin = mathimpl.is_finite(builder, z.real)
    imag_fin = mathimpl.is_finite(builder, z.imag)
    return builder.and_(real_fin, imag_fin)
@lower(cmath.isnan, types.Complex)
def isnan_float_impl(context, builder, sig, args):
    """Lower ``cmath.isnan(z)`` for complex arguments."""
    [arg_type] = sig.args
    [arg_val] = args
    cplx = context.make_complex(builder, arg_type, value=arg_val)
    result = is_nan(builder, cplx)
    return impl_ret_untracked(context, builder, sig.return_type, result)
@lower(cmath.isinf, types.Complex)
def isinf_float_impl(context, builder, sig, args):
    """Lower ``cmath.isinf(z)`` for complex arguments."""
    [arg_type] = sig.args
    [arg_val] = args
    cplx = context.make_complex(builder, arg_type, value=arg_val)
    result = is_inf(builder, cplx)
    return impl_ret_untracked(context, builder, sig.return_type, result)
@lower(cmath.isfinite, types.Complex)
def isfinite_float_impl(context, builder, sig, args):
    """Lower ``cmath.isfinite(z)`` for complex arguments."""
    [arg_type] = sig.args
    [arg_val] = args
    cplx = context.make_complex(builder, arg_type, value=arg_val)
    result = is_finite(builder, cplx)
    return impl_ret_untracked(context, builder, sig.return_type, result)
@overload(cmath.rect)
def impl_cmath_rect(r, phi):
    """Implement ``cmath.rect(r, phi)``: polar -> cartesian conversion.

    Only defined for float arguments; mirrors CPython's special-case
    handling for non-finite inputs.
    """
    if all([isinstance(typ, types.Float) for typ in [r, phi]]):
        def impl(r, phi):
            if not math.isfinite(phi):
                if not r:
                    # cmath.rect(0, phi={inf, nan}) = 0
                    return abs(r)
                if math.isinf(r):
                    # cmath.rect(inf, phi={inf, nan}) = inf + j phi
                    return complex(r, phi)
            real = math.cos(phi)
            imag = math.sin(phi)
            if real == 0. and math.isinf(r):
                # 0 * inf would return NaN, we want to keep 0 but xor the sign
                real /= r
            else:
                real *= r
            if imag == 0. and math.isinf(r):
                # ditto
                imag /= r
            else:
                imag *= r
            return complex(real, imag)
        return impl
def intrinsic_complex_unary(inner_func):
    """Wrap a pure-Python kernel as a complex-unary lowering.

    *inner_func* must have the signature
    ``inner_func(x, y, x_is_finite, y_is_finite)`` where ``x``/``y`` are the
    real/imaginary parts of the complex argument.  The finiteness flags are
    precomputed in LLVM and passed in so the kernel itself does not need
    ``math.isfinite()``.
    """
    def wrapper(context, builder, sig, args):
        [typ] = sig.args
        [value] = args
        z = context.make_complex(builder, typ, value=value)
        x = z.real
        y = z.imag
        # Precompute finiteness of both components and pass it to the pure
        # Python implementation.
        x_is_finite = mathimpl.is_finite(builder, x)
        y_is_finite = mathimpl.is_finite(builder, y)
        inner_sig = signature(sig.return_type,
                              *(typ.underlying_float,) * 2 + (types.boolean,) * 2)
        res = context.compile_internal(builder, inner_func, inner_sig,
                                       (x, y, x_is_finite, y_is_finite))
        # Pass sig.return_type (not sig) for consistency with the other
        # lowerings in this file; impl_ret_untracked simply forwards `res`.
        return impl_ret_untracked(context, builder, sig.return_type, res)
    return wrapper
# Non-finite float constants used by the pure-Python kernels below.
NAN = float('nan')
INF = float('inf')
@lower(cmath.exp, types.Complex)
@intrinsic_complex_unary
def exp_impl(x, y, x_is_finite, y_is_finite):
    """cmath.exp(x + y j)

    The branches below mirror the non-finite special cases of the C
    implementation; x/y are the real/imaginary components and the
    *_is_finite flags are precomputed by intrinsic_complex_unary.
    """
    if x_is_finite:
        if y_is_finite:
            # Ordinary case: exp(x) * (cos(y) + j sin(y))
            c = math.cos(y)
            s = math.sin(y)
            r = math.exp(x)
            return complex(r * c, r * s)
        else:
            # Finite x, non-finite y: result is entirely NaN
            return complex(NAN, NAN)
    elif math.isnan(x):
        if y:
            return complex(x, x)  # nan + j nan
        else:
            return complex(x, y)  # nan + 0j
    elif x > 0.0:
        # x == +inf
        if y_is_finite:
            real = math.cos(y)
            imag = math.sin(y)
            # Avoid NaNs if math.cos(y) or math.sin(y) == 0
            # (e.g. cmath.exp(inf + 0j) == inf + 0j)
            if real != 0:
                real *= x
            if imag != 0:
                imag *= x
            return complex(real, imag)
        else:
            return complex(x, NAN)
    else:
        # x == -inf => magnitude underflows to zero
        if y_is_finite:
            r = math.exp(x)
            c = math.cos(y)
            s = math.sin(y)
            return complex(r * c, r * s)
        else:
            r = 0
            return complex(r, r)
@lower(cmath.log, types.Complex)
@intrinsic_complex_unary
def log_impl(x, y, x_is_finite, y_is_finite):
    """cmath.log(x + y j)"""
    # Principal branch: log|z| + j * arg(z)
    magnitude = math.hypot(x, y)
    return complex(math.log(magnitude), math.atan2(y, x))
@lower(cmath.log, types.Complex, types.Complex)
def log_base_impl(context, builder, sig, args):
    """Lower the two-argument form ``cmath.log(z, base)``.

    Uses the change-of-base identity log(z) / log(base).  The raw *args*
    tuple is handed to compile_internal unmodified, so no unpacking is
    needed (the previous unused ``[z, base] = args`` was removed).
    """
    def log_base(z, base):
        return cmath.log(z) / cmath.log(base)
    res = context.compile_internal(builder, log_base, sig, args)
    return impl_ret_untracked(context, builder, sig, res)
@overload(cmath.log10)
def impl_cmath_log10(z):
    """Implement ``cmath.log10(z)`` for complex arguments."""
    if not isinstance(z, types.Complex):
        return
    LN_10 = 2.302585092994045684
    def log10_impl(z):
        """cmath.log10(z)"""
        # Dividing the natural log componentwise gives better results on
        # +/-inf than cmath.log(z, 10).
        # See http://bugs.python.org/issue22544
        ln = cmath.log(z)
        return complex(ln.real / LN_10, ln.imag / LN_10)
    return log10_impl
@overload(cmath.phase)
def phase_impl(x):
    """cmath.phase(z): the argument (angle) of z, computed via atan2."""
    if isinstance(x, types.Complex):
        def impl(x):
            return math.atan2(x.imag, x.real)
        return impl
@overload(cmath.polar)
def polar_impl(x):
    """cmath.polar(z) -> (modulus, phase) tuple."""
    if isinstance(x, types.Complex):
        def impl(x):
            re, im = x.real, x.imag
            return math.hypot(re, im), math.atan2(im, re)
        return impl
@lower(cmath.sqrt, types.Complex)
def sqrt_impl(context, builder, sig, args):
    """Lower ``cmath.sqrt(z)`` using NumPy's npy_csqrt() algorithm."""
    # We risk spurious overflow for components >= FLT_MAX / (1 + sqrt(2)).
    SQRT2 = 1.414213562373095048801688724209698079E0
    ONE_PLUS_SQRT2 = (1. + SQRT2)
    theargflt = sig.args[0].underlying_float
    # Get a type specific maximum value so scaling for overflow is based on that
    MAX = mathimpl.DBL_MAX if theargflt.bitwidth == 64 else mathimpl.FLT_MAX
    # THRES will be double precision, should not impact typing as it's just
    # used for comparison, there *may* be a few values near THRES which
    # deviate from e.g. NumPy due to rounding that occurs in the computation
    # of this value in the case of a 32bit argument.
    THRES = MAX / ONE_PLUS_SQRT2
    def sqrt_impl(z):
        """cmath.sqrt(z)"""
        # This is NumPy's algorithm, see npy_csqrt() in npy_math_complex.c.src
        a = z.real
        b = z.imag
        if a == 0.0 and b == 0.0:
            # Preserve the sign of the imaginary part for +/-0j inputs
            return complex(abs(b), b)
        if math.isinf(b):
            return complex(abs(b), b)
        if math.isnan(a):
            return complex(a, a)
        if math.isinf(a):
            if a < 0.0:
                return complex(abs(b - b), math.copysign(a, b))
            else:
                return complex(a, math.copysign(b - b, b))
        # The remaining special case (b is NaN) is handled just fine by
        # the normal code path below.
        # Scale to avoid overflow
        if abs(a) >= THRES or abs(b) >= THRES:
            a *= 0.25
            b *= 0.25
            scale = True
        else:
            scale = False
        # Algorithm 312, CACM vol 10, Oct 1967
        if a >= 0:
            t = math.sqrt((a + math.hypot(a, b)) * 0.5)
            real = t
            imag = b / (2 * t)
        else:
            t = math.sqrt((-a + math.hypot(a, b)) * 0.5)
            real = abs(b) / (2 * t)
            imag = math.copysign(t, b)
        # Rescale (only the real part needs it, per the NumPy original)
        if scale:
            return complex(real * 2, imag)
        else:
            return complex(real, imag)
    res = context.compile_internal(builder, sqrt_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)
@lower(cmath.cos, types.Complex)
def cos_impl(context, builder, sig, args):
    """Lower cmath.cos via the identity cos(z) = cosh(j*z)."""
    def kernel(z):
        # cmath.cos(z) = cmath.cosh(z j)
        return cmath.cosh(complex(-z.imag, z.real))
    res = context.compile_internal(builder, kernel, sig, args)
    return impl_ret_untracked(context, builder, sig, res)
@overload(cmath.cosh)
def impl_cmath_cosh(z):
    """Implement ``cmath.cosh(z)`` for complex arguments."""
    if not isinstance(z, types.Complex):
        return
    def cosh_impl(z):
        """cmath.cosh(z)"""
        x = z.real
        y = z.imag
        if math.isinf(x):
            if math.isnan(y):
                # x = +inf, y = NaN => cmath.cosh(x + y j) = inf + Nan * j
                real = abs(x)
                imag = y
            elif y == 0.0:
                # x = +inf, y = 0 => cmath.cosh(x + y j) = inf + 0j
                real = abs(x)
                imag = y
            else:
                # Infinite magnitude: signs come from cos(y)/sin(y)
                real = math.copysign(x, math.cos(y))
                imag = math.copysign(x, math.sin(y))
            if x < 0.0:
                # x = -inf => negate imaginary part of result
                imag = -imag
            return complex(real, imag)
        # Ordinary case: cosh(x + yj) = cos(y)cosh(x) + j sin(y)sinh(x)
        return complex(math.cos(y) * math.cosh(x),
                       math.sin(y) * math.sinh(x))
    return cosh_impl
@lower(cmath.sin, types.Complex)
def sin_impl(context, builder, sig, args):
    """Lower cmath.sin via the identity sin(z) = -j * sinh(j*z)."""
    def kernel(z):
        # cmath.sin(z) = -j * cmath.sinh(z j)
        rotated = cmath.sinh(complex(-z.imag, z.real))
        return complex(rotated.imag, -rotated.real)
    res = context.compile_internal(builder, kernel, sig, args)
    return impl_ret_untracked(context, builder, sig, res)
@overload(cmath.sinh)
def impl_cmath_sinh(z):
    """Implement ``cmath.sinh(z)`` for complex arguments."""
    if not isinstance(z, types.Complex):
        return
    def sinh_impl(z):
        """cmath.sinh(z)"""
        x = z.real
        y = z.imag
        if math.isinf(x):
            if math.isnan(y):
                # x = +/-inf, y = NaN => cmath.sinh(x + y j) = x + NaN * j
                real = x
                imag = y
            else:
                # Keep exact zeros rather than multiplying 0 * inf (= NaN)
                real = math.cos(y)
                imag = math.sin(y)
                if real != 0.:
                    real *= x
                if imag != 0.:
                    imag *= abs(x)
            return complex(real, imag)
        # Ordinary case: sinh(x + yj) = cos(y)sinh(x) + j sin(y)cosh(x)
        return complex(math.cos(y) * math.sinh(x),
                       math.sin(y) * math.cosh(x))
    return sinh_impl
@lower(cmath.tan, types.Complex)
def tan_impl(context, builder, sig, args):
    """Lower cmath.tan via the identity tan(z) = -j * tanh(j*z)."""
    def kernel(z):
        # cmath.tan(z) = -j * cmath.tanh(z j)
        rotated = cmath.tanh(complex(-z.imag, z.real))
        return complex(rotated.imag, -rotated.real)
    res = context.compile_internal(builder, kernel, sig, args)
    return impl_ret_untracked(context, builder, sig, res)
@overload(cmath.tanh)
def impl_cmath_tanh(z):
    """Implement ``cmath.tanh(z)`` for complex arguments."""
    if not isinstance(z, types.Complex):
        return
    def tanh_impl(z):
        """cmath.tanh(z)"""
        x = z.real
        y = z.imag
        if math.isinf(x):
            # tanh saturates at +/-1 for infinite real part
            real = math.copysign(1., x)
            if math.isinf(y):
                imag = 0.
            else:
                imag = math.copysign(0., math.sin(2. * y))
            return complex(real, imag)
        # This is CPython's algorithm (see c_tanh() in cmathmodule.c).
        # XXX how to force float constants into single precision?
        tx = math.tanh(x)
        ty = math.tan(y)
        cx = 1. / math.cosh(x)
        txty = tx * ty
        denom = 1. + txty * txty
        return complex(
            tx * (1. + ty * ty) / denom,
            ((ty / denom) * cx) * cx)
    return tanh_impl
@lower(cmath.acos, types.Complex)
def acos_impl(context, builder, sig, args):
    """Lower ``cmath.acos(z)`` using CPython's c_acos() algorithm."""
    LN_4 = math.log(4)
    # Threshold above which the asymptotic formula is used to avoid overflow
    THRES = mathimpl.FLT_MAX / 4
    def acos_impl(z):
        """cmath.acos(z)"""
        # CPython's algorithm (see c_acos() in cmathmodule.c)
        if abs(z.real) > THRES or abs(z.imag) > THRES:
            # Avoid unnecessary overflow for large arguments
            # (also handles infinities gracefully)
            real = math.atan2(abs(z.imag), z.real)
            imag = math.copysign(
                math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4,
                -z.imag)
            return complex(real, imag)
        else:
            s1 = cmath.sqrt(complex(1. - z.real, -z.imag))
            s2 = cmath.sqrt(complex(1. + z.real, z.imag))
            real = 2. * math.atan2(s1.real, s2.real)
            imag = math.asinh(s2.real * s1.imag - s2.imag * s1.real)
            return complex(real, imag)
    res = context.compile_internal(builder, acos_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)
@overload(cmath.acosh)
def impl_cmath_acosh(z):
    """Implement ``cmath.acosh(z)`` for complex arguments."""
    if not isinstance(z, types.Complex):
        return
    LN_4 = math.log(4)
    # Threshold above which the asymptotic formula is used to avoid overflow
    THRES = mathimpl.FLT_MAX / 4
    def acosh_impl(z):
        """cmath.acosh(z)"""
        # CPython's algorithm (see c_acosh() in cmathmodule.c)
        if abs(z.real) > THRES or abs(z.imag) > THRES:
            # Avoid unnecessary overflow for large arguments
            # (also handles infinities gracefully)
            real = math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4
            imag = math.atan2(z.imag, z.real)
            return complex(real, imag)
        else:
            s1 = cmath.sqrt(complex(z.real - 1., z.imag))
            s2 = cmath.sqrt(complex(z.real + 1., z.imag))
            real = math.asinh(s1.real * s2.real + s1.imag * s2.imag)
            imag = 2. * math.atan2(s1.imag, s2.real)
            return complex(real, imag)
    # Condensed formula (NumPy)
    #return cmath.log(z + cmath.sqrt(z + 1.) * cmath.sqrt(z - 1.))
    return acosh_impl
@lower(cmath.asinh, types.Complex)
def asinh_impl(context, builder, sig, args):
    """Lower ``cmath.asinh(z)`` using CPython's c_asinh() algorithm."""
    LN_4 = math.log(4)
    # Threshold above which the asymptotic formula is used to avoid overflow
    THRES = mathimpl.FLT_MAX / 4
    def asinh_impl(z):
        """cmath.asinh(z)"""
        # CPython's algorithm (see c_asinh() in cmathmodule.c)
        if abs(z.real) > THRES or abs(z.imag) > THRES:
            real = math.copysign(
                math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4,
                z.real)
            imag = math.atan2(z.imag, abs(z.real))
            return complex(real, imag)
        else:
            s1 = cmath.sqrt(complex(1. + z.imag, -z.real))
            s2 = cmath.sqrt(complex(1. - z.imag, z.real))
            real = math.asinh(s1.real * s2.imag - s2.real * s1.imag)
            imag = math.atan2(z.imag, s1.real * s2.real - s1.imag * s2.imag)
            return complex(real, imag)
    res = context.compile_internal(builder, asinh_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)
@lower(cmath.asin, types.Complex)
def asin_impl(context, builder, sig, args):
    """Lower cmath.asin via the identity asin(z) = -j * asinh(j*z)."""
    def kernel(z):
        # cmath.asin(z) = -j * cmath.asinh(z j)
        rotated = cmath.asinh(complex(-z.imag, z.real))
        return complex(rotated.imag, -rotated.real)
    res = context.compile_internal(builder, kernel, sig, args)
    return impl_ret_untracked(context, builder, sig, res)
@lower(cmath.atan, types.Complex)
def atan_impl(context, builder, sig, args):
    """Lower ``cmath.atan(z)`` via the identity atan(z) = -j * atanh(j*z)."""
    def atan_impl(z):
        """cmath.atan(z) = -j * cmath.atanh(z j)"""
        r = cmath.atanh(complex(-z.imag, z.real))
        if math.isinf(z.real) and math.isnan(z.imag):
            # XXX this is odd but necessary
            # (deliberate sign deviation for the inf + nan*j input)
            return complex(r.imag, r.real)
        else:
            return complex(r.imag, -r.real)
    res = context.compile_internal(builder, atan_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)
@lower(cmath.atanh, types.Complex)
def atanh_impl(context, builder, sig, args):
    """Lower ``cmath.atanh(z)`` using CPython's c_atanh() algorithm."""
    LN_4 = math.log(4)
    # Overflow/underflow thresholds for the special-case branches
    THRES_LARGE = math.sqrt(mathimpl.FLT_MAX / 4)
    THRES_SMALL = math.sqrt(mathimpl.FLT_MIN)
    # pi / 2 ("PI times 1/2")
    PI_12 = math.pi / 2
    def atanh_impl(z):
        """cmath.atanh(z)"""
        # CPython's algorithm (see c_atanh() in cmathmodule.c)
        if z.real < 0.:
            # Reduce to case where z.real >= 0., using atanh(z) = -atanh(-z).
            negate = True
            z = -z
        else:
            negate = False
        ay = abs(z.imag)
        if math.isnan(z.real) or z.real > THRES_LARGE or ay > THRES_LARGE:
            if math.isinf(z.imag):
                real = math.copysign(0., z.real)
            elif math.isinf(z.real):
                real = 0.
            else:
                # may be safe from overflow, depending on hypot's implementation...
                h = math.hypot(z.real * 0.5, z.imag * 0.5)
                real = z.real/4./h/h
            # All large-magnitude inputs map to +/- pi/2 on the imaginary axis
            imag = -math.copysign(PI_12, -z.imag)
        elif z.real == 1. and ay < THRES_SMALL:
            # C99 standard says: atanh(1+/-0.) should be inf +/- 0j
            if ay == 0.:
                real = INF
                imag = z.imag
            else:
                real = -math.log(math.sqrt(ay) /
                                 math.sqrt(math.hypot(ay, 2.)))
                imag = math.copysign(math.atan2(2., -ay) / 2, z.imag)
        else:
            # Ordinary case
            sqay = ay * ay
            zr1 = 1 - z.real
            real = math.log1p(4. * z.real / (zr1 * zr1 + sqay)) * 0.25
            imag = -math.atan2(-2. * z.imag,
                               zr1 * (1 + z.real) - sqay) * 0.5
        if math.isnan(z.imag):
            imag = NAN
        if negate:
            return complex(-real, -imag)
        else:
            return complex(real, imag)
    res = context.compile_internal(builder, atanh_impl, sig, args)
    return impl_ret_untracked(context, builder, sig, res)

View File

@@ -0,0 +1,89 @@
"""
Implementation of enums.
"""
import operator
from numba.core.imputils import (lower_builtin, lower_getattr,
lower_getattr_generic, lower_cast,
lower_constant, impl_ret_untracked)
from numba.core import types
from numba.core.extending import overload_method
@lower_builtin(operator.eq, types.EnumMember, types.EnumMember)
def enum_eq(context, builder, sig, args):
    """Lower ``==`` between two enum members by comparing their payloads."""
    lhs_ty, rhs_ty = sig.args
    lhs, rhs = args
    cmp = context.generic_compare(builder, operator.eq,
                                  (lhs_ty.dtype, rhs_ty.dtype), (lhs, rhs))
    return impl_ret_untracked(context, builder, sig.return_type, cmp)
@lower_builtin(operator.is_, types.EnumMember, types.EnumMember)
def enum_is(context, builder, sig, args):
    """Lower ``is`` between enum members.

    Members of distinct enum types are never identical; for the same type,
    identity reduces to payload equality.
    """
    lhs_ty, rhs_ty = sig.args
    lhs, rhs = args
    if lhs_ty != rhs_ty:
        res = context.get_constant(sig.return_type, False)
    else:
        res = context.generic_compare(builder, operator.eq,
                                      (lhs_ty.dtype, rhs_ty.dtype), (lhs, rhs))
    return impl_ret_untracked(context, builder, sig.return_type, res)
@lower_builtin(operator.ne, types.EnumMember, types.EnumMember)
def enum_ne(context, builder, sig, args):
    """Lower ``!=`` between two enum members by comparing their payloads."""
    lhs_ty, rhs_ty = sig.args
    lhs, rhs = args
    cmp = context.generic_compare(builder, operator.ne,
                                  (lhs_ty.dtype, rhs_ty.dtype), (lhs, rhs))
    return impl_ret_untracked(context, builder, sig.return_type, cmp)
@lower_getattr(types.EnumMember, 'value')
def enum_value(context, builder, ty, val):
    # An enum member is represented by its underlying value, so the
    # ``.value`` attribute is the identity at the LLVM level.
    return val
@lower_cast(types.IntEnumMember, types.Integer)
def int_enum_to_int(context, builder, fromty, toty, val):
    """
    Convert an IntEnum member to its raw integer value.
    """
    # Delegate to the normal integer cast on the member's payload type.
    return context.cast(builder, val, fromty.dtype, toty)
@lower_constant(types.EnumMember)
def enum_constant(context, builder, ty, pyval):
    """
    Return a LLVM constant representing enum member *pyval*.
    """
    # Constants are materialized from the member's underlying .value.
    return context.get_constant_generic(builder, ty.dtype, pyval.value)
@lower_getattr_generic(types.EnumClass)
def enum_class_getattr(context, builder, ty, val, attr):
    """
    Return an enum member by attribute name.
    """
    # Resolve the member on the Python enum class at lowering time and
    # materialize its value as an LLVM constant.
    member = getattr(ty.instance_class, attr)
    return context.get_constant_generic(builder, ty.dtype, member.value)
@lower_builtin('static_getitem', types.EnumClass, types.StringLiteral)
def enum_class_getitem(context, builder, sig, args):
    """
    Return an enum member by index name.
    """
    # The index must be a string literal so the lookup can happen at
    # lowering time; the member's value becomes an LLVM constant.
    enum_cls_typ, idx = sig.args
    member = enum_cls_typ.instance_class[idx.literal_value]
    return context.get_constant_generic(builder, enum_cls_typ.dtype,
                                        member.value)
@overload_method(types.IntEnumMember, '__hash__')
def intenum_hash(val):
    """``__hash__`` for IntEnum members: hash the underlying integer value,
    so it matches ``int.__hash__`` as in CPython."""
    def impl(val):
        return hash(val.value)
    return impl

View File

@@ -0,0 +1,743 @@
"""
Hash implementations for Numba types
"""
import math
import numpy as np
import sys
import ctypes
import warnings
from collections import namedtuple
import llvmlite.binding as ll
from llvmlite import ir
from numba import literal_unroll
from numba.core.extending import (
overload, overload_method, intrinsic, register_jitable)
from numba.core import errors
from numba.core import types
from numba.core.unsafe.bytes import grab_byte, grab_uint64_t
from numba.cpython.randomimpl import (const_int, get_next_int, get_next_int32,
get_state_ptr)
# This is Py_hash_t, which is a Py_ssize_t, which has sizeof(size_t):
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Include/pyport.h#L91-L96 # noqa: E501
_hash_width = sys.hash_info.width
# Signed/unsigned Numba integer types matching Py_hash_t / Py_uhash_t.
_Py_hash_t = getattr(types, 'int%s' % _hash_width)
_Py_uhash_t = getattr(types, 'uint%s' % _hash_width)
# Constants from CPython source, obtained by various means:
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Include/pyhash.h # noqa: E501
_PyHASH_INF = sys.hash_info.inf
_PyHASH_NAN = sys.hash_info.nan
_PyHASH_MODULUS = _Py_uhash_t(sys.hash_info.modulus)
_PyHASH_BITS = 31 if types.intp.bitwidth == 32 else 61  # mersenne primes
_PyHASH_MULTIPLIER = 0xf4243  # 1000003UL
_PyHASH_IMAG = _PyHASH_MULTIPLIER
_PyLong_SHIFT = sys.int_info.bits_per_digit
_Py_HASH_CUTOFF = sys.hash_info.cutoff
_Py_hashfunc_name = sys.hash_info.algorithm
# This stub/overload pair are used to force branch pruning to remove the dead
# branch based on the potential `None` type of the hash_func which works better
# if the predicate for the prune in an ir.Arg. The obj is an arg to allow for
# a custom error message.
def _defer_hash(hash_func):
pass
@overload(_defer_hash)
def ol_defer_hash(obj, hash_func):
    """Call ``hash_func()``, or raise TypeError when it is typed as None.

    The raise sits behind a branch on an argument so that Numba's branch
    pruning removes it when ``hash_func`` is an actual function.
    """
    err_msg = f"unhashable type: '{obj}'"
    def impl(obj, hash_func):
        if hash_func is None:
            raise TypeError(err_msg)
        else:
            return hash_func()
    return impl
# hash(obj) is implemented by calling obj.__hash__()
@overload(hash)
def hash_overload(obj):
    """Implement the ``hash`` builtin by deferring to ``obj.__hash__``."""
    attempt_generic_msg = ("No __hash__ is defined for object of type "
                          f"'{obj}' and a generic hash() cannot be "
                          "performed as there is no suitable object "
                          "represention in Numba compiled code!")
    def impl(obj):
        if hasattr(obj, '__hash__'):
            # _defer_hash raises at compile time if __hash__ is None
            return _defer_hash(obj, getattr(obj, '__hash__'))
        else:
            raise TypeError(attempt_generic_msg)
    return impl
@register_jitable
def process_return(val):
    """Coerce *val* to Py_hash_t; CPython reserves -1 as an error marker,
    so a computed hash of -1 is mapped to -2."""
    ret = _Py_hash_t(val)
    return int(-2) if ret == int(-1) else ret
# This is a translation of CPython's _Py_HashDouble:
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Python/pyhash.c#L34-L129 # noqa: E501
# NOTE: In Python 3.10 hash of nan is now hash of the pointer to the PyObject
# containing said nan. Numba cannot replicate this as there is no object, so it
# elects to replicate the behaviour i.e. hash of nan is something "unique" which
# satisfies https://bugs.python.org/issue43475.
@register_jitable(locals={'x': _Py_uhash_t,
                          'y': _Py_uhash_t,
                          'm': types.double,
                          'e': types.intc,
                          'sign': types.intc,
                          '_PyHASH_MODULUS': _Py_uhash_t,
                          '_PyHASH_BITS': types.intc})
def _Py_HashDouble(v):
    """Hash a double *v*, translating CPython's _Py_HashDouble."""
    if not np.isfinite(v):
        if (np.isinf(v)):
            if (v > 0):
                return _PyHASH_INF
            else:
                return -_PyHASH_INF
        else:
            # Python 3.10 does not use `_PyHASH_NAN`.
            # https://github.com/python/cpython/blob/2c4792264f9218692a1bd87398a60591f756b171/Python/pyhash.c#L102 # noqa: E501
            # Numba returns a pseudo-random number to reflect the spirit of the
            # change.
            x = _prng_random_hash()
            return process_return(x)
    m, e = math.frexp(v)
    sign = 1
    if (m < 0):
        sign = -1
        m = -m
    # process 28 bits at a time;  this should work well both for binary
    # and hexadecimal floating point.
    x = 0
    while (m):
        x = ((x << 28) & _PyHASH_MODULUS) | x >> (_PyHASH_BITS - 28)
        m *= 268435456.0  # /* 2**28 */
        e -= 28
        y = int(m)  # /* pull out integer part */
        m -= y
        x += y
        if x >= _PyHASH_MODULUS:
            x -= _PyHASH_MODULUS
    # /* adjust for the exponent;  first reduce it modulo _PyHASH_BITS */
    if e >= 0:
        e = e % _PyHASH_BITS
    else:
        e = _PyHASH_BITS - 1 - ((-1 - e) % _PyHASH_BITS)
    x = ((x << e) & _PyHASH_MODULUS) | x >> (_PyHASH_BITS - e)
    x = x * sign
    return process_return(x)
@intrinsic
def _fpext(tyctx, val):
    """Intrinsic: widen a float32 to float64 via LLVM's fpext."""
    def impl(cgctx, builder, signature, args):
        val = args[0]
        return builder.fpext(val, ir.DoubleType())
    sig = types.float64(types.float32)
    return sig, impl
@intrinsic
def _prng_random_hash(tyctx):
    """Intrinsic: draw a pseudo-random Py_hash_t from the internal PRNG.

    Used to hash NaN (Python 3.10+ semantics, see _Py_HashDouble above).
    """
    def impl(cgctx, builder, signature, args):
        state_ptr = get_state_ptr(cgctx, builder, "internal")
        bits = const_int(_hash_width)
        # Why not just use get_next_int() with the correct bitwidth?
        # get_next_int() always returns an i64, because the bitwidth it is
        # passed may not be a compile-time constant, so it needs to allocate
        # the largest unit of storage that may be required. Therefore, if the
        # hash width is 32, then we need to use get_next_int32() to ensure we
        # don't return a wider-than-expected hash, even if everything above
        # the low 32 bits would have been zero.
        if _hash_width == 32:
            value = get_next_int32(cgctx, builder, state_ptr)
        else:
            value = get_next_int(cgctx, builder, state_ptr, bits, False)
        return value
    sig = _Py_hash_t()
    return sig, impl
# This is a translation of CPython's long_hash, but restricted to the numerical
# domain reachable by int64/uint64 (i.e. no BigInt like support):
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Objects/longobject.c#L2934-L2989 # noqa: E501
# obdigit is a uint32_t which is typedef'd to digit
# int32_t is typedef'd to sdigit
@register_jitable(locals={'x': _Py_uhash_t,
                          'p1': _Py_uhash_t,
                          'p2': _Py_uhash_t,
                          'p3': _Py_uhash_t,
                          'p4': _Py_uhash_t,
                          '_PyHASH_MODULUS': _Py_uhash_t,
                          '_PyHASH_BITS': types.int32,
                          '_PyLong_SHIFT': types.int32,})
def _long_impl(val):
    """Hash a wide unsigned integer the way CPython hashes a PyLong."""
    # This function assumes val came from a long int repr with val being a
    # uint64_t this means having to split the input into PyLong_SHIFT size
    # chunks in an unsigned hash wide type, max numba can handle is a 64bit int
    # mask to select low _PyLong_SHIFT bits
    _tmp_shift = 32 - _PyLong_SHIFT
    mask_shift = (~types.uint32(0x0)) >> _tmp_shift
    # a 64bit wide max means Numba only needs 3 x 30 bit values max,
    # or 5 x 15 bit values max on 32bit platforms
    i = (64 // _PyLong_SHIFT) + 1
    # alg as per hash_long
    x = 0
    p3 = (_PyHASH_BITS - _PyLong_SHIFT)
    for idx in range(i - 1, -1, -1):
        # Rotate x left by _PyLong_SHIFT bits modulo _PyHASH_MODULUS
        p1 = x << _PyLong_SHIFT
        p2 = p1 & _PyHASH_MODULUS
        p4 = x >> p3
        x = p2 | p4
        # the shift and mask splits out the `ob_digit` parts of a Long repr
        x += types.uint32((val >> idx * _PyLong_SHIFT) & mask_shift)
        if x >= _PyHASH_MODULUS:
            x -= _PyHASH_MODULUS
    return _Py_hash_t(x)
# This has no CPython equivalent, CPython uses long_hash.
@overload_method(types.Integer, '__hash__')
@overload_method(types.Boolean, '__hash__')
def int_hash(val):
    """``__hash__`` for integers and booleans, modelled on CPython's
    long_hash but restricted to the int64/uint64 domain."""
    # Hash CPython computes for the most negative int64; platform dependent.
    _HASH_I64_MIN = -2 if sys.maxsize <= 2 ** 32 else -4
    _SIGNED_MIN = types.int64(-0x8000000000000000)
    # Find a suitable type to hold a "big" value, i.e. iinfo(ty).min/max
    # this is to ensure e.g. int32.min is handled ok as it's abs() is its value
    _BIG = types.int64 if getattr(val, 'signed', False) else types.uint64
    # this is a bit involved due to the CPython repr of ints
    def impl(val):
        # If the magnitude is under PyHASH_MODULUS, just return the
        # value val as the hash, couple of special cases if val == val:
        # 1. it's 0, in which case return 0
        # 2. it's signed int minimum value, return the value CPython computes
        # but Numba cannot as there's no type wide enough to hold the shifts.
        #
        # If the magnitude is greater than PyHASH_MODULUS then... if the value
        # is negative then negate it switch the sign on the hash once computed
        # and use the standard wide unsigned hash implementation
        val = _BIG(val)
        mag = abs(val)
        if mag < _PyHASH_MODULUS:
            if val == 0:
                ret = 0
            elif val == _SIGNED_MIN:  # e.g. int64 min, -0x8000000000000000
                ret = _Py_hash_t(_HASH_I64_MIN)
            else:
                ret = _Py_hash_t(val)
        else:
            needs_negate = False
            if val < 0:
                val = -val
                needs_negate = True
            ret = _long_impl(val)
            if needs_negate:
                ret = -ret
        return process_return(ret)
    return impl
# This is a translation of CPython's float_hash:
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Objects/floatobject.c#L528-L532 # noqa: E501
@overload_method(types.Float, '__hash__')
def float_hash(val):
    """``__hash__`` for floats; float32 is widened to float64 first so that
    equal values hash equally regardless of width."""
    if val.bitwidth == 64:
        def impl(val):
            return _Py_HashDouble(val)
    else:
        def impl(val):
            # widen the 32bit float to 64bit
            widened = np.float64(_fpext(val))
            return _Py_HashDouble(widened)
    return impl
# This is a translation of CPython's complex_hash:
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Objects/complexobject.c#L408-L428 # noqa: E501
@overload_method(types.Complex, '__hash__')
def complex_hash(val):
    """``__hash__`` for complex numbers, mirroring CPython's complex_hash."""
    def impl(val):
        # When the imaginary part is 0 its hash is 0, so the combination
        # degenerates to hash(val.real).  That keeps hash(x + 0j) == hash(x),
        # which is required because equal numbers of different types must
        # hash equally.
        combined = hash(val.real) + _PyHASH_IMAG * hash(val.imag)
        return process_return(combined)
    return impl
# Python 3.8 strengthened its hash alg for tuples.
# This is a translation of CPython's tuplehash for Python >=3.8
# https://github.com/python/cpython/blob/b738237d6792acba85b1f6e6c8993a812c7fd815/Objects/tupleobject.c#L338-L391 # noqa: E501
# These consts are needed for this alg variant, they are from:
# https://github.com/python/cpython/blob/b738237d6792acba85b1f6e6c8993a812c7fd815/Objects/tupleobject.c#L353-L363 # noqa: E501
if _Py_uhash_t.bitwidth // 8 > 4:
    # 64-bit build: xxHash64-derived primes and a 31-bit rotate
    _PyHASH_XXPRIME_1 = _Py_uhash_t(11400714785074694791)
    _PyHASH_XXPRIME_2 = _Py_uhash_t(14029467366897019727)
    _PyHASH_XXPRIME_5 = _Py_uhash_t(2870177450012600261)
    @register_jitable(locals={'x': types.uint64})
    def _PyHASH_XXROTATE(x):
        # Rotate left 31 bits
        return ((x << types.uint64(31)) | (x >> types.uint64(33)))
else:
    # 32-bit build: xxHash32-derived primes and a 13-bit rotate
    _PyHASH_XXPRIME_1 = _Py_uhash_t(2654435761)
    _PyHASH_XXPRIME_2 = _Py_uhash_t(2246822519)
    _PyHASH_XXPRIME_5 = _Py_uhash_t(374761393)
    @register_jitable(locals={'x': types.uint64})
    def _PyHASH_XXROTATE(x):
        # Rotate left 13 bits
        return ((x << types.uint64(13)) | (x >> types.uint64(19)))
@register_jitable(locals={'acc': _Py_uhash_t, 'lane': _Py_uhash_t,
                          '_PyHASH_XXPRIME_5': _Py_uhash_t,
                          '_PyHASH_XXPRIME_1': _Py_uhash_t,
                          'tl': _Py_uhash_t})
def _tuple_hash(tup):
    """Hash a tuple using CPython's (>=3.8) xxHash-based tuplehash."""
    tl = len(tup)
    acc = _PyHASH_XXPRIME_5
    # literal_unroll allows iterating a heterogeneous tuple
    for x in literal_unroll(tup):
        lane = hash(x)
        # -1 signals a hashing error, propagate it
        if lane == _Py_uhash_t(-1):
            return -1
        acc += lane * _PyHASH_XXPRIME_2
        acc = _PyHASH_XXROTATE(acc)
        acc *= _PyHASH_XXPRIME_1
    # Fold in the length so prefixes hash differently
    acc += tl ^ (_PyHASH_XXPRIME_5 ^ _Py_uhash_t(3527539))
    if acc == _Py_uhash_t(-1):
        return process_return(1546275796)
    return process_return(acc)
@overload_method(types.BaseTuple, '__hash__')
def tuple_hash(val):
    """``__hash__`` for all tuple types, deferring to the shared kernel."""
    def tuple_hash_impl(val):
        return _tuple_hash(val)
    return tuple_hash_impl
# ------------------------------------------------------------------------------
# String/bytes hashing needs hashseed info, this is from:
# https://stackoverflow.com/a/41088757
# with thanks to Martijn Pieters
#
# Developer note:
# CPython makes use of an internal "hashsecret" which is essentially a struct
# containing some state that is set on CPython initialization and contains magic
# numbers used particularly in unicode/string hashing. This code binds to the
# Python runtime libraries in use by the current process and reads the
# "hashsecret" state so that it can be used by Numba. As this is done at runtime
# the behaviour and influence of the PYTHONHASHSEED environment variable is
# accommodated.
from ctypes import ( # noqa
c_size_t,
c_ubyte,
c_uint64,
pythonapi,
Structure,
Union,
) # noqa
class FNV(Structure):
    """ctypes mirror of the FNV member of CPython's _Py_HashSecret_t."""
    _fields_ = [
        ('prefix', c_size_t),
        ('suffix', c_size_t)
    ]
class SIPHASH(Structure):
    """ctypes mirror of the SipHash key member of _Py_HashSecret_t."""
    _fields_ = [
        ('k0', c_uint64),
        ('k1', c_uint64),
    ]
class DJBX33A(Structure):
    """ctypes mirror of the DJBX33A member (small-string optimization)."""
    _fields_ = [
        ('padding', c_ubyte * 16),
        ('suffix', c_size_t),
    ]
class EXPAT(Structure):
    """ctypes mirror of the expat hash-salt member of _Py_HashSecret_t."""
    _fields_ = [
        ('padding', c_ubyte * 16),
        ('hashsalt', c_size_t),
    ]
class _Py_HashSecret_t(Union):
    """ctypes mirror of CPython's _Py_HashSecret_t union (24 bytes)."""
    _fields_ = [
        # ensure 24 bytes
        ('uc', c_ubyte * 24),
        # two Py_hash_t for FNV
        ('fnv', FNV),
        # two uint64 for SipHash24
        ('siphash', SIPHASH),
        # a different (!) Py_hash_t for small string optimization
        ('djbx33a', DJBX33A),
        ('expat', EXPAT),
    ]
# Maps a hashsecret member to its injected LLVM symbol name and ctypes value.
_hashsecret_entry = namedtuple('_hashsecret_entry', ['symbol', 'value'])
# Only a few members are needed at present
def _build_hashsecret():
    """Read the hash secret from the Python process.

    Returns
    -------
    info : dict
        - keys are "djbx33a_suffix", "siphash_k0", "siphash_k1".
        - values are the namedtuple[symbol:str, value:int]
    """
    # Read hashsecret and inject it into the LLVM symbol map under the
    # prefix `_numba_hashsecret_`.
    pyhashsecret = _Py_HashSecret_t.in_dll(pythonapi, '_Py_HashSecret')
    info = {}
    def inject(name, val):
        symbol_name = "_numba_hashsecret_{}".format(name)
        val = ctypes.c_uint64(val)
        addr = ctypes.addressof(val)
        ll.add_symbol(symbol_name, addr)
        # storing the ctypes object in `info` keeps its buffer alive for
        # the lifetime of the process, so the injected address stays valid
        info[name] = _hashsecret_entry(symbol=symbol_name, value=val)
    inject('djbx33a_suffix', pyhashsecret.djbx33a.suffix)
    inject('siphash_k0', pyhashsecret.siphash.k0)
    inject('siphash_k1', pyhashsecret.siphash.k1)
    return info
# Snapshot the process' hash secret once at import time.
_hashsecret = _build_hashsecret()
# ------------------------------------------------------------------------------
if _Py_hashfunc_name in ('siphash13', 'siphash24', 'fnv'):
# Check for use of the FNV hashing alg, warn users that it's not implemented
# and functionality relying of properties derived from hashing will be fine
# but hash values themselves are likely to be different.
if _Py_hashfunc_name == 'fnv':
msg = ("FNV hashing is not implemented in Numba. See PEP 456 "
"https://www.python.org/dev/peps/pep-0456/ "
"for rationale over not using FNV. Numba will continue to work, "
"but hashes for built in types will be computed using "
"siphash24. This will permit e.g. dictionaries to continue to "
"behave as expected, however anything relying on the value of "
"the hash opposed to hash as a derived property is likely to "
"not work as expected.")
warnings.warn(msg)
# This is a translation of CPython's siphash24 function:
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Python/pyhash.c#L287-L413 # noqa: E501
# and also, since Py 3.11, a translation of CPython's siphash13 function:
# https://github.com/python/cpython/blob/9dda9020abcf0d51d59b283a89c58c8e1fb0f574/Python/pyhash.c#L376-L424
# the only differences are in the use of SINGLE_ROUND in siphash13 vs.
# DOUBLE_ROUND in siphash24, and that siphash13 has an extra "ROUND" applied
# just before the final XORing of components to create the return value.
# /* *********************************************************************
# <MIT License>
# Copyright (c) 2013 Marek Majkowski <marek@popcount.org>
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
# </MIT License>
# Original location:
# https://github.com/majek/csiphash/
# Solution inspired by code from:
# Samuel Neves (supercop/crypto_auth/siphash24/little)
# djb (supercop/crypto_auth/siphash24/little2)
# Jean-Philippe Aumasson (https://131002.net/siphash/siphash24.c)
# Modified for Python by Christian Heimes:
# - C89 / MSVC compatibility
# - _rotl64() on Windows
# - letoh64() fallback
# */
@register_jitable(locals={'x': types.uint64,
                          'b': types.uint64, })
def _ROTATE(x, b):
    # 64-bit rotate-left of x by b bits; mirrors the ROTATE macro of the
    # SipHash reference implementation used by CPython's pyhash.c.
    return types.uint64(((x) << (b)) | ((x) >> (types.uint64(64) - (b))))
@register_jitable(locals={'a': types.uint64,
                          'b': types.uint64,
                          'c': types.uint64,
                          'd': types.uint64,
                          's': types.uint64,
                          't': types.uint64, })
def _HALF_ROUND(a, b, c, d, s, t):
    # One SipHash half-round: two add-rotate-xor mixes with rotation
    # amounts *s* and *t*, plus the fixed 32-bit rotation of lane 'a'.
    a += b
    c += d
    b = _ROTATE(b, s) ^ a
    d = _ROTATE(d, t) ^ c
    a = _ROTATE(a, 32)
    return a, b, c, d
@register_jitable(locals={'v0': types.uint64,
                          'v1': types.uint64,
                          'v2': types.uint64,
                          'v3': types.uint64, })
def _SINGLE_ROUND(v0, v1, v2, v3):
    # One full SipRound: two half-rounds.  The permuted lane order in the
    # second call (v2, v1, v0, v3) is deliberate and mirrors the
    # SINGLE_ROUND macro in CPython's pyhash.c.
    v0, v1, v2, v3 = _HALF_ROUND(v0, v1, v2, v3, 13, 16)
    v2, v1, v0, v3 = _HALF_ROUND(v2, v1, v0, v3, 17, 21)
    return v0, v1, v2, v3
@register_jitable(locals={'v0': types.uint64,
                          'v1': types.uint64,
                          'v2': types.uint64,
                          'v3': types.uint64, })
def _DOUBLE_ROUND(v0, v1, v2, v3):
    # Two SipRounds back to back (the compression step of SipHash-2-4).
    for _ in range(2):
        v0, v1, v2, v3 = _SINGLE_ROUND(v0, v1, v2, v3)
    return v0, v1, v2, v3
def _gen_siphash(alg):
    # Build a SipHash implementation for the requested variant:
    # 'siphash13' uses one compression round plus an extra finalization
    # round; 'siphash24' uses two compression rounds (see the CPython
    # siphash13/siphash24 functions referenced above).
    if alg == 'siphash13':
        _ROUNDER = _SINGLE_ROUND
        _EXTRA_ROUND = True
    elif alg == 'siphash24':
        _ROUNDER = _DOUBLE_ROUND
        _EXTRA_ROUND = False
    else:
        assert 0, 'unreachable'

    @register_jitable(locals={'v0': types.uint64,
                              'v1': types.uint64,
                              'v2': types.uint64,
                              'v3': types.uint64,
                              'b': types.uint64,
                              'mi': types.uint64,
                              't': types.uint64,
                              'mask': types.uint64,
                              'jmp': types.uint64,
                              'ohexefef': types.uint64})
    def _siphash(k0, k1, src, src_sz):
        # The message length (mod 256) is encoded in the top byte of the
        # final block.
        b = types.uint64(src_sz) << 56
        # Initial state: the 128-bit key XORed with the SipHash constants.
        v0 = k0 ^ types.uint64(0x736f6d6570736575)
        v1 = k1 ^ types.uint64(0x646f72616e646f6d)
        v2 = k0 ^ types.uint64(0x6c7967656e657261)
        v3 = k1 ^ types.uint64(0x7465646279746573)

        # Compress whole 8-byte words.
        idx = 0
        while (src_sz >= 8):
            mi = grab_uint64_t(src, idx)
            idx += 1
            src_sz -= 8
            v3 ^= mi
            v0, v1, v2, v3 = _ROUNDER(v0, v1, v2, v3)
            v0 ^= mi

        # Pack the remaining 0..7 tail bytes into t (little-endian).  The
        # successive non-exclusive `if src_sz >= k` tests emulate the C
        # switch fallthrough:
        # https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Python/pyhash.c#L390-L400 # noqa: E501
        t = types.uint64(0x0)
        boffset = idx * 8
        ohexefef = types.uint64(0xff)
        if src_sz >= 7:
            jmp = (6 * 8)
            mask = ~types.uint64(ohexefef << jmp)
            t = (t & mask) | (types.uint64(grab_byte(src, boffset + 6))
                              << jmp)
        if src_sz >= 6:
            jmp = (5 * 8)
            mask = ~types.uint64(ohexefef << jmp)
            t = (t & mask) | (types.uint64(grab_byte(src, boffset + 5))
                              << jmp)
        if src_sz >= 5:
            jmp = (4 * 8)
            mask = ~types.uint64(ohexefef << jmp)
            t = (t & mask) | (types.uint64(grab_byte(src, boffset + 4))
                              << jmp)
        if src_sz >= 4:
            # The "case 4" of the C switch loads a whole 32-bit word.
            t &= types.uint64(0xffffffff00000000)
            for i in range(4):
                jmp = i * 8
                mask = ~types.uint64(ohexefef << jmp)
                t = (t & mask) | (types.uint64(grab_byte(src, boffset + i))
                                  << jmp)
        if src_sz >= 3:
            jmp = (2 * 8)
            mask = ~types.uint64(ohexefef << jmp)
            t = (t & mask) | (types.uint64(grab_byte(src, boffset + 2))
                              << jmp)
        if src_sz >= 2:
            jmp = (1 * 8)
            mask = ~types.uint64(ohexefef << jmp)
            t = (t & mask) | (types.uint64(grab_byte(src, boffset + 1))
                              << jmp)
        if src_sz >= 1:
            mask = ~(ohexefef)
            t = (t & mask) | (types.uint64(grab_byte(src, boffset + 0)))

        # Finalization.
        b |= t
        v3 ^= b
        v0, v1, v2, v3 = _ROUNDER(v0, v1, v2, v3)
        v0 ^= b
        v2 ^= ohexefef
        v0, v1, v2, v3 = _ROUNDER(v0, v1, v2, v3)
        v0, v1, v2, v3 = _ROUNDER(v0, v1, v2, v3)
        if _EXTRA_ROUND:
            # siphash13 applies one extra round before the final XOR.
            v0, v1, v2, v3 = _ROUNDER(v0, v1, v2, v3)
        t = (v0 ^ v1) ^ (v2 ^ v3)
        return t

    return _siphash
_siphash13 = _gen_siphash('siphash13')
_siphash24 = _gen_siphash('siphash24')
_siphasher = _siphash13 if _Py_hashfunc_name == 'siphash13' else _siphash24
else:
msg = "Unsupported hashing algorithm in use %s" % _Py_hashfunc_name
raise ValueError(msg)
@intrinsic
def _inject_hashsecret_read(tyctx, name):
    """Emit code to load the hashsecret.

    *name* must be a string literal naming a ``_hashsecret`` entry; the
    corresponding process-wide global symbol is read as a uint64.
    """
    if not isinstance(name, types.StringLiteral):
        raise errors.TypingError("requires literal string")

    sym = _hashsecret[name.literal_value].symbol
    resty = types.uint64
    sig = resty(name)

    def impl(cgctx, builder, sig, args):
        mod = builder.module
        try:
            # Search for existing global
            gv = mod.get_global(sym)
        except KeyError:
            # Inject the symbol if not already exist.
            gv = ir.GlobalVariable(mod, ir.IntType(64), name=sym)
        v = builder.load(gv)
        return v

    return sig, impl
def _load_hashsecret(name):
    # Pure-Python path: read the hashsecret entry's value directly.
    return _hashsecret[name].value


@overload(_load_hashsecret)
def _impl_load_hashsecret(name):
    # nopython path: emit a load of the hashsecret global symbol instead.
    def imp(name):
        return _inject_hashsecret_read(name)
    return imp
# This is a translation of CPythons's _Py_HashBytes:
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Python/pyhash.c#L145-L191 # noqa: E501
@register_jitable(locals={'_hash': _Py_uhash_t})
def _Py_HashBytes(val, _len):
    """Hash a raw byte buffer *val* of length *_len*, mirroring CPython's
    _Py_HashBytes: inline DJBX33A for inputs shorter than
    _Py_HASH_CUTOFF, the configured SipHash variant otherwise."""
    if (_len == 0):
        return process_return(0)

    if (_len < _Py_HASH_CUTOFF):
        # TODO: this branch needs testing, needs a CPython setup for it!
        # /* Optimize hashing of very small strings with inline DJBX33A. */
        _hash = _Py_uhash_t(5381)  # /* DJBX33A starts with 5381 */
        for idx in range(_len):
            # hash = hash * 33 + byte
            _hash = ((_hash << 5) + _hash) + np.uint8(grab_byte(val, idx))

        _hash ^= _len
        _hash ^= _load_hashsecret('djbx33a_suffix')
    else:
        tmp = _siphasher(types.uint64(_load_hashsecret('siphash_k0')),
                         types.uint64(_load_hashsecret('siphash_k1')),
                         val, _len)
        _hash = process_return(tmp)
    return process_return(_hash)
# This is an approximate translation of CPython's unicode_hash:
# https://github.com/python/cpython/blob/d1dd6be613381b996b9071443ef081de8e5f3aff/Objects/unicodeobject.c#L11635-L11663 # noqa: E501
@overload_method(types.UnicodeType, '__hash__')
def unicode_hash(val):
    """__hash__ for unicode strings: hash the underlying data buffer
    (kind byte-width times length bytes) with _Py_HashBytes, returning the
    cached value when one has already been computed."""
    from numba.cpython.unicode import _kind_to_byte_width

    def impl(val):
        kindwidth = _kind_to_byte_width(val._kind)
        _len = len(val)
        # use the cache if possible
        current_hash = val._hash
        if current_hash != -1:
            return current_hash
        else:
            # cannot write hash value to cache in the unicode struct due to
            # pass by value on the struct making the struct member immutable
            return _Py_HashBytes(val._data, kindwidth * _len)
    return impl

View File

@@ -0,0 +1,266 @@
# A port of https://github.com/python/cpython/blob/e42b7051/Lib/heapq.py
import heapq as hq
from numba.core import types
from numba.core.errors import TypingError
from numba.core.extending import overload, register_jitable
@register_jitable
def _siftdown(heap, startpos, pos):
    """Restore the min-heap invariant by bubbling ``heap[pos]`` up toward
    *startpos* until its parent is no larger than it."""
    item = heap[pos]
    while pos > startpos:
        parent_idx = (pos - 1) >> 1
        parent_val = heap[parent_idx]
        if not (item < parent_val):
            break
        heap[pos] = parent_val
        pos = parent_idx
    heap[pos] = item
@register_jitable
def _siftup(heap, pos):
    """Sift ``heap[pos]`` down to a leaf, then bubble it back up —
    CPython's two-phase strategy for a min-heap."""
    end = len(heap)
    root = pos
    item = heap[pos]
    child = 2 * pos + 1
    while child < end:
        right = child + 1
        # Pick the smaller of the two children.
        if right < end and not heap[child] < heap[right]:
            child = right
        heap[pos] = heap[child]
        pos = child
        child = 2 * pos + 1
    heap[pos] = item
    _siftdown(heap, root, pos)
@register_jitable
def _siftdown_max(heap, startpos, pos):
    """Max-heap variant of ``_siftdown``: bubble ``heap[pos]`` up while
    its parent compares less than it."""
    item = heap[pos]
    while pos > startpos:
        parent_idx = (pos - 1) >> 1
        parent_val = heap[parent_idx]
        if not (parent_val < item):
            break
        heap[pos] = parent_val
        pos = parent_idx
    heap[pos] = item
@register_jitable
def _siftup_max(heap, pos):
    """Max-heap variant of ``_siftup``: sift ``heap[pos]`` down to a leaf,
    then bubble it back up."""
    end = len(heap)
    root = pos
    item = heap[pos]
    child = 2 * pos + 1
    while child < end:
        right = child + 1
        # Pick the larger of the two children.
        if right < end and not heap[right] < heap[child]:
            child = right
        heap[pos] = heap[child]
        pos = child
        child = 2 * pos + 1
    heap[pos] = item
    _siftdown_max(heap, root, pos)
@register_jitable
def reversed_range(x):
    # analogous to reversed(range(x)): yields x-1, x-2, ..., 0
    return range(x - 1, -1, -1)
@register_jitable
def _heapify_max(x):
    """Transform list *x* into a max-heap, in place, in O(len(x))."""
    # Sift every internal node down, starting from the last parent.
    for i in range(len(x) // 2 - 1, -1, -1):
        _siftup_max(x, i)
@register_jitable
def _heapreplace_max(heap, item):
    """Pop the largest element and push *item* in a single sift;
    returns the popped element."""
    top = heap[0]
    heap[0] = item
    _siftup_max(heap, 0)
    return top
def assert_heap_type(heap):
    """Raise TypingError unless *heap* is a reflected or typed list whose
    element type supports '<' (complex numbers do not)."""
    if not isinstance(heap, (types.List, types.ListType)):
        raise TypingError('heap argument must be a list')
    if isinstance(heap.dtype, types.Complex):
        raise TypingError("'<' not supported between instances "
                          "of 'complex' and 'complex'")
def assert_item_type_consistent_with_heap_type(heap, item):
    """Raise TypingError when *item*'s type differs from the heap dtype."""
    if heap.dtype != item:
        raise TypingError('heap type must be the same as item type')
@overload(hq.heapify)
def hq_heapify(x):
    """Overload of heapq.heapify for typed/reflected lists."""
    assert_heap_type(x)

    def hq_heapify_impl(x):
        # Sift every internal node down, starting from the last parent.
        for i in range(len(x) // 2 - 1, -1, -1):
            _siftup(x, i)
    return hq_heapify_impl
@overload(hq.heappop)
def hq_heappop(heap):
    """Overload of heapq.heappop: pop and return the smallest element."""
    assert_heap_type(heap)

    def hq_heappop_impl(heap):
        tail = heap.pop()    # raises on an empty heap, as in CPython
        if not heap:
            # The popped element was the only one.
            return tail
        smallest = heap[0]
        heap[0] = tail
        _siftup(heap, 0)
        return smallest
    return hq_heappop_impl
@overload(hq.heappush)
def heappush(heap, item):
    """Overload of heapq.heappush: append *item*, then restore the
    invariant by bubbling it up from the tail."""
    assert_heap_type(heap)
    assert_item_type_consistent_with_heap_type(heap, item)

    def hq_heappush_impl(heap, item):
        heap.append(item)
        _siftdown(heap, 0, len(heap) - 1)
    return hq_heappush_impl
@overload(hq.heapreplace)
def heapreplace(heap, item):
    """Overload of heapq.heapreplace: pop the smallest element, then push
    *item* — a single sift instead of a pop followed by a push."""
    assert_heap_type(heap)
    assert_item_type_consistent_with_heap_type(heap, item)

    def hq_heapreplace(heap, item):
        smallest = heap[0]    # raises on an empty heap, as in CPython
        heap[0] = item
        _siftup(heap, 0)
        return smallest
    return hq_heapreplace
@overload(hq.heappushpop)
def heappushpop(heap, item):
    """Overload of heapq.heappushpop: push *item* then pop the smallest,
    done with at most one sift."""
    assert_heap_type(heap)
    assert_item_type_consistent_with_heap_type(heap, item)

    def hq_heappushpop_impl(heap, item):
        # Only swap and sift when the heap root is smaller than the item;
        # otherwise the pushed item would be popped right back.
        if heap and heap[0] < item:
            item, heap[0] = heap[0], item
            _siftup(heap, 0)
        return item
    return hq_heappushpop_impl
def check_input_types(n, iterable):
    """Typing-time validation shared by the nsmallest/nlargest overloads."""
    if not isinstance(n, (types.Integer, types.Boolean)):
        raise TypingError("First argument 'n' must be an integer")
    # heapq also accepts 1.0 (but not 0.0, 2.0, 3.0...) but
    # this isn't replicated
    if not isinstance(iterable, (types.Sequence, types.Array, types.ListType)):
        raise TypingError("Second argument 'iterable' must be iterable")
@overload(hq.nsmallest)
def nsmallest(n, iterable):
    """Overload of heapq.nsmallest; a port of CPython's implementation."""
    check_input_types(n, iterable)

    def hq_nsmallest_impl(n, iterable):
        if n == 0:
            # Empty list typed like the iterable's elements.
            return [iterable[0] for _ in range(0)]
        elif n == 1:
            out = min(iterable)
            return [out]

        size = len(iterable)
        if n >= size:
            return sorted(iterable)[:n]

        # Maintain a max-heap of the n smallest (element, insertion-order)
        # pairs; the order index keeps ties stable.
        it = iter(iterable)
        result = [(elem, i) for i, elem in zip(range(n), it)]
        _heapify_max(result)
        top = result[0][0]
        order = n
        for elem in it:
            if elem < top:
                _heapreplace_max(result, (elem, order))
                top, _order = result[0]
                order += 1
        result.sort()
        return [elem for (elem, order) in result]
    return hq_nsmallest_impl
@overload(hq.nlargest)
def nlargest(n, iterable):
    """Overload of heapq.nlargest; a port of CPython's implementation."""
    check_input_types(n, iterable)

    def hq_nlargest_impl(n, iterable):
        if n == 0:
            # Empty list typed like the iterable's elements.
            return [iterable[0] for _ in range(0)]
        elif n == 1:
            out = max(iterable)
            return [out]

        size = len(iterable)
        if n >= size:
            return sorted(iterable)[::-1][:n]

        # Maintain a min-heap of the n largest (element, -insertion-order)
        # pairs; the negated order index keeps ties stable.
        it = iter(iterable)
        result = [(elem, i) for i, elem in zip(range(0, -n, -1), it)]
        hq.heapify(result)
        top = result[0][0]
        order = -n
        for elem in it:
            if top < elem:
                hq.heapreplace(result, (elem, order))
                top, _order = result[0]
                order -= 1
        result.sort(reverse=True)
        return [elem for (elem, order) in result]
    return hq_nlargest_impl

View File

@@ -0,0 +1,140 @@
"""
Implementation of various iterable and iterator types.
"""
from numba.core import types, cgutils
from numba.core.imputils import (
lower_builtin, iternext_impl, call_iternext, call_getiter,
impl_ret_borrowed, impl_ret_new_ref, RefType)
@lower_builtin('getiter', types.IteratorType)
def iterator_getiter(context, builder, sig, args):
    # iter(it) applied to an iterator returns the iterator itself
    # (borrowed reference).
    [it] = args
    return impl_ret_borrowed(context, builder, sig.return_type, it)
#-------------------------------------------------------------------------------
# builtin `enumerate` implementation
@lower_builtin(enumerate, types.IterableType)
@lower_builtin(enumerate, types.IterableType, types.Integer)
def make_enumerate_object(context, builder, sig, args):
    """Construct an enumerate object over an iterable, with an optional
    integer start value (cast to intp, defaulting to 0)."""
    assert len(args) == 1 or len(args) == 2 # enumerate(it) or enumerate(it, start)
    srcty = sig.args[0]

    if len(args) == 1:
        src = args[0]
        start_val = context.get_constant(types.intp, 0)
    elif len(args) == 2:
        src = args[0]
        start_val = context.cast(builder, args[1], sig.args[1], types.intp)

    iterobj = call_getiter(context, builder, srcty, src)

    enum = context.make_helper(builder, sig.return_type)

    # The counter lives in an alloca so that 'iternext' can mutate it.
    countptr = cgutils.alloca_once(builder, start_val.type)
    builder.store(start_val, countptr)

    enum.count = countptr
    enum.iter = iterobj

    res = enum._getvalue()
    return impl_ret_new_ref(context, builder, sig.return_type, res)
@lower_builtin('iternext', types.EnumerateType)
@iternext_impl(RefType.NEW)
def iternext_enumerate(context, builder, sig, args, result):
    """'iternext' for enumerate objects: yields (count, value) pairs."""
    [enumty] = sig.args
    [enum] = args

    enum = context.make_helper(builder, enumty, value=enum)

    # Increment the stored counter unconditionally, before polling the
    # source iterator; the pre-increment value is what gets yielded.
    count = builder.load(enum.count)
    ncount = builder.add(count, context.get_constant(types.intp, 1))
    builder.store(ncount, enum.count)

    srcres = call_iternext(context, builder, enumty.source_type, enum.iter)
    is_valid = srcres.is_valid()
    result.set_valid(is_valid)

    with builder.if_then(is_valid):
        srcval = srcres.yielded_value()
        result.yield_(context.make_tuple(builder, enumty.yield_type,
                                         [count, srcval]))
#-------------------------------------------------------------------------------
# builtin `zip` implementation
@lower_builtin(zip, types.VarArg(types.Any))
def make_zip_object(context, builder, sig, args):
    """Construct a zip object holding one iterator per source argument."""
    zip_type = sig.return_type

    assert len(args) == len(zip_type.source_types)

    zipobj = context.make_helper(builder, zip_type)

    for i, (arg, srcty) in enumerate(zip(args, sig.args)):
        zipobj[i] = call_getiter(context, builder, srcty, arg)

    res = zipobj._getvalue()
    return impl_ret_new_ref(context, builder, sig.return_type, res)
@lower_builtin('iternext', types.ZipType)
@iternext_impl(RefType.NEW)
def iternext_zip(context, builder, sig, args, result):
    """'iternext' for zip objects: valid only while every source iterator
    yields a value."""
    [zip_type] = sig.args
    [zipobj] = args

    zipobj = context.make_helper(builder, zip_type, value=zipobj)

    if len(zipobj) == 0:
        # zip() is an empty iterator
        result.set_exhausted()
        return

    # Accumulate the yielded tuple and the overall validity flag in
    # allocas so the conditional blocks below can update them.
    p_ret_tup = cgutils.alloca_once(builder,
                                    context.get_value_type(zip_type.yield_type))
    p_is_valid = cgutils.alloca_once_value(builder, value=cgutils.true_bit)

    for i, (iterobj, srcty) in enumerate(zip(zipobj, zip_type.source_types)):
        is_valid = builder.load(p_is_valid)
        # Avoid calling the remaining iternext if an iterator has been exhausted
        with builder.if_then(is_valid):
            srcres = call_iternext(context, builder, srcty, iterobj)
            is_valid = builder.and_(is_valid, srcres.is_valid())
            builder.store(is_valid, p_is_valid)
            val = srcres.yielded_value()
            ptr = cgutils.gep_inbounds(builder, p_ret_tup, 0, i)
            builder.store(val, ptr)

    is_valid = builder.load(p_is_valid)
    result.set_valid(is_valid)

    with builder.if_then(is_valid):
        result.yield_(builder.load(p_ret_tup))
#-------------------------------------------------------------------------------
# generator implementation
@lower_builtin('iternext', types.Generator)
@iternext_impl(RefType.BORROWED)
def iternext_generator(context, builder, sig, args, result):
    """'iternext' for generator objects.

    Renamed from ``iternext_zip`` (a copy-paste leftover that shadowed the
    zip implementation above); registration happens through the decorators,
    so the function name itself is not part of the lowering interface.
    """
    genty, = sig.args
    gen, = args
    impl = context.get_generator_impl(genty)
    status, retval = impl(context, builder, sig, args)
    # Carry over any libraries the generator implementation must link.
    context.add_linking_libs(getattr(impl, 'libs', ()))

    with cgutils.if_likely(builder, status.is_ok):
        result.set_valid(True)
        result.yield_(retval)
    with cgutils.if_unlikely(builder, status.is_stop_iteration):
        result.set_exhausted()
    # Any error other than StopIteration is propagated to the caller.
    with cgutils.if_unlikely(builder,
                             builder.and_(status.is_error,
                                          builder.not_(status.is_stop_iteration))):
        context.call_conv.return_status_propagate(builder, status)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,471 @@
"""
Provide math calls that uses intrinsics or libc math functions.
"""
import math
import operator
import sys
import numpy as np
import llvmlite.ir
from llvmlite.ir import Constant
from numba.core.imputils import Registry, impl_ret_untracked
from numba import typeof
from numba.core import types, utils, config, cgutils
from numba.core.extending import overload
from numba.core.typing import signature
from numba.cpython.unsafe.numbers import trailing_zeros
registry = Registry('mathimpl')
lower = registry.lower
# Helpers, shared with cmathimpl.
_NP_FLT_FINFO = np.finfo(np.dtype('float32'))
FLT_MAX = _NP_FLT_FINFO.max
FLT_MIN = _NP_FLT_FINFO.tiny
_NP_DBL_FINFO = np.finfo(np.dtype('float64'))
DBL_MAX = _NP_DBL_FINFO.max
DBL_MIN = _NP_DBL_FINFO.tiny
FLOAT_ABS_MASK = 0x7fffffff
FLOAT_SIGN_MASK = 0x80000000
DOUBLE_ABS_MASK = 0x7fffffffffffffff
DOUBLE_SIGN_MASK = 0x8000000000000000
def is_nan(builder, val):
    """
    Return a condition testing whether *val* is a NaN.
    """
    # 'fcmp uno x, x' is true iff x is unordered with itself, i.e. NaN.
    return builder.fcmp_unordered('uno', val, val)


def is_inf(builder, val):
    """
    Return a condition testing whether *val* is an infinite.
    """
    pos_inf = Constant(val.type, float("+inf"))
    neg_inf = Constant(val.type, float("-inf"))
    isposinf = builder.fcmp_ordered('==', val, pos_inf)
    isneginf = builder.fcmp_ordered('==', val, neg_inf)
    return builder.or_(isposinf, isneginf)


def is_finite(builder, val):
    """
    Return a condition testing whether *val* is a finite.
    """
    # is_finite(x) <=> x - x != NaN
    # (inf - inf and nan - nan are both NaN; finite - finite is ordered)
    val_minus_val = builder.fsub(val, val)
    return builder.fcmp_ordered('ord', val_minus_val, val_minus_val)
# Reinterpreting bitcasts between IEEE floats and same-width integers:
# the bit pattern is preserved, not the numeric value.

def f64_as_int64(builder, val):
    """
    Bitcast a double into a 64-bit integer.
    """
    assert val.type == llvmlite.ir.DoubleType()
    return builder.bitcast(val, llvmlite.ir.IntType(64))


def int64_as_f64(builder, val):
    """
    Bitcast a 64-bit integer into a double.
    """
    assert val.type == llvmlite.ir.IntType(64)
    return builder.bitcast(val, llvmlite.ir.DoubleType())


def f32_as_int32(builder, val):
    """
    Bitcast a float into a 32-bit integer.
    """
    assert val.type == llvmlite.ir.FloatType()
    return builder.bitcast(val, llvmlite.ir.IntType(32))


def int32_as_f32(builder, val):
    """
    Bitcast a 32-bit integer into a float.
    """
    assert val.type == llvmlite.ir.IntType(32)
    return builder.bitcast(val, llvmlite.ir.FloatType())
def negate_real(builder, val):
    """
    Negate real number *val*, with proper handling of zeros.
    """
    # The negative zero forces LLVM to handle signed zeros properly.
    # (A plain 0.0 - x would not negate the sign of a zero input.)
    return builder.fsub(Constant(val.type, -0.0), val)
def call_fp_intrinsic(builder, name, args):
    """
    Call a LLVM intrinsic floating-point operation.
    """
    mod = builder.module
    # Declare (or reuse) the intrinsic overload matching the argument types.
    intr = mod.declare_intrinsic(name, [a.type for a in args])
    return builder.call(intr, args)
def _unary_int_input_wrapper_impl(wrapped_impl):
    """
    Return an implementation factory to convert the single integral input
    argument to a float64, then defer to the *wrapped_impl*.
    """
    def implementer(context, builder, sig, args):
        val, = args
        input_type = sig.args[0]
        # Promote the integer argument to float64 before delegating.
        fpval = context.cast(builder, val, input_type, types.float64)
        inner_sig = signature(types.float64, types.float64)
        res = wrapped_impl(context, builder, inner_sig, (fpval,))
        # Cast the float64 result back to the expected return type.
        return context.cast(builder, res, types.float64, sig.return_type)

    return implementer


def unary_math_int_impl(fn, float_impl):
    # Register *float_impl*, wrapped with an int->float64 cast, as the
    # integer-input lowering of *fn*.
    impl = _unary_int_input_wrapper_impl(float_impl)
    lower(fn, types.Integer)(impl)
def unary_math_intr(fn, intrcode):
    """
    Implement the math function *fn* using the LLVM intrinsic *intrcode*.

    Returns the float implementation (also registered for integer inputs
    via an int->float64 wrapper).
    """
    @lower(fn, types.Float)
    def float_impl(context, builder, sig, args):
        res = call_fp_intrinsic(builder, intrcode, args)
        return impl_ret_untracked(context, builder, sig.return_type, res)

    unary_math_int_impl(fn, float_impl)
    return float_impl
def unary_math_extern(fn, f32extern, f64extern, int_restype=False):
    """
    Register implementations of Python function *fn* using the
    external function named *f32extern* and *f64extern* (for float32
    and float64 inputs, respectively).
    If *int_restype* is true, then the function's return value should be
    integral, otherwise floating-point.
    """
    # NOTE(review): f_restype is computed but never used below; presumably
    # a leftover - confirm whether int_restype still has any effect here.
    f_restype = types.int64 if int_restype else None

    def float_impl(context, builder, sig, args):
        """
        Implement *fn* for a types.Float input.
        """
        [val] = args
        mod = builder.module  # NOTE(review): unused local
        input_type = sig.args[0]
        lty = context.get_value_type(input_type)
        func_name = {
            types.float32: f32extern,
            types.float64: f64extern,
        }[input_type]
        fnty = llvmlite.ir.FunctionType(lty, [lty])
        fn = cgutils.insert_pure_function(builder.module, fnty, name=func_name)
        res = builder.call(fn, (val,))
        # Cast in case the declared return type differs from the input type.
        res = context.cast(builder, res, input_type, sig.return_type)
        return impl_ret_untracked(context, builder, sig.return_type, res)

    lower(fn, types.Float)(float_impl)

    # Implement wrapper for integer inputs
    unary_math_int_impl(fn, float_impl)

    return float_impl
# Functions lowered directly to LLVM intrinsics.
unary_math_intr(math.fabs, 'llvm.fabs')
exp_impl = unary_math_intr(math.exp, 'llvm.exp')
if sys.version_info >= (3, 11):
    # math.exp2 only exists on Python 3.11+.
    exp2_impl = unary_math_intr(math.exp2, 'llvm.exp2')
log_impl = unary_math_intr(math.log, 'llvm.log')
log10_impl = unary_math_intr(math.log10, 'llvm.log10')
# NOTE(review): math.log2 is registered twice - here as an intrinsic and
# again below via unary_math_extern; confirm which registration wins.
log2_impl = unary_math_intr(math.log2, 'llvm.log2')
sin_impl = unary_math_intr(math.sin, 'llvm.sin')
cos_impl = unary_math_intr(math.cos, 'llvm.cos')

# Functions lowered to external libm (or numba helper library) calls.
log1p_impl = unary_math_extern(math.log1p, "log1pf", "log1p")
expm1_impl = unary_math_extern(math.expm1, "expm1f", "expm1")
erf_impl = unary_math_extern(math.erf, "erff", "erf")
erfc_impl = unary_math_extern(math.erfc, "erfcf", "erfc")
tan_impl = unary_math_extern(math.tan, "tanf", "tan")
asin_impl = unary_math_extern(math.asin, "asinf", "asin")
acos_impl = unary_math_extern(math.acos, "acosf", "acos")
atan_impl = unary_math_extern(math.atan, "atanf", "atan")
asinh_impl = unary_math_extern(math.asinh, "asinhf", "asinh")
acosh_impl = unary_math_extern(math.acosh, "acoshf", "acosh")
atanh_impl = unary_math_extern(math.atanh, "atanhf", "atanh")
sinh_impl = unary_math_extern(math.sinh, "sinhf", "sinh")
cosh_impl = unary_math_extern(math.cosh, "coshf", "cosh")
tanh_impl = unary_math_extern(math.tanh, "tanhf", "tanh")
log2_impl = unary_math_extern(math.log2, "log2f", "log2")
ceil_impl = unary_math_extern(math.ceil, "ceilf", "ceil", True)
floor_impl = unary_math_extern(math.floor, "floorf", "floor", True)
gamma_impl = unary_math_extern(math.gamma, "numba_gammaf", "numba_gamma") # work-around
sqrt_impl = unary_math_extern(math.sqrt, "sqrtf", "sqrt")
trunc_impl = unary_math_extern(math.trunc, "truncf", "trunc", True)
lgamma_impl = unary_math_extern(math.lgamma, "lgammaf", "lgamma")
@lower(math.isnan, types.Float)
def isnan_float_impl(context, builder, sig, args):
    [val] = args
    res = is_nan(builder, val)
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower(math.isnan, types.Integer)
def isnan_int_impl(context, builder, sig, args):
    # Integers are never NaN: constant false.
    res = cgutils.false_bit
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower(math.isinf, types.Float)
def isinf_float_impl(context, builder, sig, args):
    [val] = args
    res = is_inf(builder, val)
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower(math.isinf, types.Integer)
def isinf_int_impl(context, builder, sig, args):
    # Integers are never infinite: constant false.
    res = cgutils.false_bit
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower(math.isfinite, types.Float)
def isfinite_float_impl(context, builder, sig, args):
    [val] = args
    res = is_finite(builder, val)
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower(math.isfinite, types.Integer)
def isfinite_int_impl(context, builder, sig, args):
    # Integers are always finite: constant true.
    res = cgutils.true_bit
    return impl_ret_untracked(context, builder, sig.return_type, res)
@lower(math.copysign, types.Float, types.Float)
def copysign_float_impl(context, builder, sig, args):
    # Lower to the llvm.copysign.* intrinsic for the argument's width.
    lty = args[0].type
    mod = builder.module
    fn = cgutils.get_or_insert_function(mod, llvmlite.ir.FunctionType(lty, (lty, lty)),
                                        'llvm.copysign.%s' % lty.intrinsic_name)
    res = builder.call(fn, args)
    return impl_ret_untracked(context, builder, sig.return_type, res)
# -----------------------------------------------------------------------------
@lower(math.frexp, types.Float)
def frexp_impl(context, builder, sig, args):
    # Call numba_frexp/numba_frexpf from the numba helper library; the
    # exponent comes back through an out-pointer and the two results are
    # packed into an anonymous (mantissa, exponent) struct.
    val, = args
    fltty = context.get_data_type(sig.args[0])
    intty = context.get_data_type(sig.return_type[1])
    expptr = cgutils.alloca_once(builder, intty, name='exp')
    fnty = llvmlite.ir.FunctionType(fltty, (fltty, llvmlite.ir.PointerType(intty)))
    fname = {
        "float": "numba_frexpf",
        "double": "numba_frexp",
    }[str(fltty)]
    fn = cgutils.get_or_insert_function(builder.module, fnty, fname)
    res = builder.call(fn, (val, expptr))
    res = cgutils.make_anonymous_struct(builder, (res, builder.load(expptr)))
    return impl_ret_untracked(context, builder, sig.return_type, res)
@lower(math.ldexp, types.Float, types.intc)
def ldexp_impl(context, builder, sig, args):
    # Call numba_ldexp/numba_ldexpf (pure) from the numba helper library.
    val, exp = args
    fltty, intty = map(context.get_data_type, sig.args)
    fnty = llvmlite.ir.FunctionType(fltty, (fltty, intty))
    fname = {
        "float": "numba_ldexpf",
        "double": "numba_ldexp",
    }[str(fltty)]
    fn = cgutils.insert_pure_function(builder.module, fnty, name=fname)
    res = builder.call(fn, (val, exp))
    return impl_ret_untracked(context, builder, sig.return_type, res)
# -----------------------------------------------------------------------------
@lower(math.atan2, types.int64, types.int64)
def atan2_s64_impl(context, builder, sig, args):
    # Promote signed ints to float64 and defer to the float implementation.
    [y, x] = args
    y = builder.sitofp(y, llvmlite.ir.DoubleType())
    x = builder.sitofp(x, llvmlite.ir.DoubleType())
    fsig = signature(types.float64, types.float64, types.float64)
    return atan2_float_impl(context, builder, fsig, (y, x))


@lower(math.atan2, types.uint64, types.uint64)
def atan2_u64_impl(context, builder, sig, args):
    # As above, but with an unsigned int -> float conversion.
    [y, x] = args
    y = builder.uitofp(y, llvmlite.ir.DoubleType())
    x = builder.uitofp(x, llvmlite.ir.DoubleType())
    fsig = signature(types.float64, types.float64, types.float64)
    return atan2_float_impl(context, builder, fsig, (y, x))


@lower(math.atan2, types.Float, types.Float)
def atan2_float_impl(context, builder, sig, args):
    # Call the external libm atan2/atan2f function.
    assert len(args) == 2
    mod = builder.module
    ty = sig.args[0]
    lty = context.get_value_type(ty)
    func_name = {
        types.float32: "atan2f",
        types.float64: "atan2"
    }[ty]
    fnty = llvmlite.ir.FunctionType(lty, (lty, lty))
    fn = cgutils.insert_pure_function(builder.module, fnty, name=func_name)
    res = builder.call(fn, args)
    return impl_ret_untracked(context, builder, sig.return_type, res)
# -----------------------------------------------------------------------------
@lower(math.hypot, types.int64, types.int64)
def hypot_s64_impl(context, builder, sig, args):
    # Promote signed ints to float64 and defer to the float implementation.
    [x, y] = args
    y = builder.sitofp(y, llvmlite.ir.DoubleType())
    x = builder.sitofp(x, llvmlite.ir.DoubleType())
    fsig = signature(types.float64, types.float64, types.float64)
    res = hypot_float_impl(context, builder, fsig, (x, y))
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower(math.hypot, types.uint64, types.uint64)
def hypot_u64_impl(context, builder, sig, args):
    # As above, but with an unsigned int -> float conversion.
    [x, y] = args
    y = builder.uitofp(y, llvmlite.ir.DoubleType())
    x = builder.uitofp(x, llvmlite.ir.DoubleType())
    fsig = signature(types.float64, types.float64, types.float64)
    res = hypot_float_impl(context, builder, fsig, (x, y))
    return impl_ret_untracked(context, builder, sig.return_type, res)


@lower(math.hypot, types.Float, types.Float)
def hypot_float_impl(context, builder, sig, args):
    xty, yty = sig.args
    assert xty == yty == sig.return_type
    x, y = args

    # Windows has alternate names for hypot/hypotf, see
    # https://msdn.microsoft.com/fr-fr/library/a9yb3dbt%28v=vs.80%29.aspx
    fname = {
        types.float32: "_hypotf" if sys.platform == 'win32' else "hypotf",
        types.float64: "_hypot" if sys.platform == 'win32' else "hypot",
    }[xty]
    plat_hypot = types.ExternalFunction(fname, sig)

    if sys.platform == 'win32' and config.MACHINE_BITS == 32:
        inf = xty(float('inf'))

        def hypot_impl(x, y):
            # NOTE(review): the inf special-case presumably works around
            # the 32-bit Windows _hypot behaviour for infinite inputs -
            # confirm against the platform docs.
            if math.isinf(x) or math.isinf(y):
                return inf

            return plat_hypot(x, y)
    else:
        def hypot_impl(x, y):
            return plat_hypot(x, y)

    res = context.compile_internal(builder, hypot_impl, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, res)
# -----------------------------------------------------------------------------
@lower(math.radians, types.Float)
def radians_float_impl(context, builder, sig, args):
    # radians(x) = x * (pi / 180)
    [x] = args
    coef = context.get_constant(sig.return_type, math.pi / 180)
    res = builder.fmul(x, coef)
    return impl_ret_untracked(context, builder, sig.return_type, res)

unary_math_int_impl(math.radians, radians_float_impl)

# -----------------------------------------------------------------------------

@lower(math.degrees, types.Float)
def degrees_float_impl(context, builder, sig, args):
    # degrees(x) = x * (180 / pi)
    [x] = args
    coef = context.get_constant(sig.return_type, 180 / math.pi)
    res = builder.fmul(x, coef)
    return impl_ret_untracked(context, builder, sig.return_type, res)

unary_math_int_impl(math.degrees, degrees_float_impl)
# -----------------------------------------------------------------------------
@lower(math.pow, types.Float, types.Float)
@lower(math.pow, types.Float, types.Integer)
def pow_impl(context, builder, sig, args):
    # math.pow delegates to the lowering of the ** operator.
    impl = context.get_function(operator.pow, sig)
    return impl(builder, args)
# -----------------------------------------------------------------------------
@lower(math.nextafter, types.Float, types.Float)
def nextafter_impl(context, builder, sig, args):
    # Call the external libm nextafter/nextafterf function.
    assert len(args) == 2
    ty = sig.args[0]
    lty = context.get_value_type(ty)
    func_name = {
        types.float32: "nextafterf",
        types.float64: "nextafter"
    }[ty]
    fnty = llvmlite.ir.FunctionType(lty, (lty, lty))
    fn = cgutils.insert_pure_function(builder.module, fnty, name=func_name)
    res = builder.call(fn, args)
    return impl_ret_untracked(context, builder, sig.return_type, res)
# -----------------------------------------------------------------------------
def _unsigned(T):
    """Convert integer to unsigned integer of equivalent width."""
    # Stub: the real implementation is provided by the overload below.
    pass


@overload(_unsigned)
def _unsigned_impl(T):
    if T in types.unsigned_domain:
        # Already unsigned: identity.
        return lambda T: T
    elif T in types.signed_domain:
        # Reinterpret as the unsigned type of the same bit width.
        newT = getattr(types, 'uint{}'.format(T.bitwidth))
        return lambda T: newT(T)
    # NOTE(review): non-integer types fall through with no implementation
    # (returns None); presumably unreachable from gcd - confirm.
def gcd_impl(context, builder, sig, args):
    """Lowering of math.gcd for a pair of same-typed integers."""
    xty, yty = sig.args
    assert xty == yty == sig.return_type
    x, y = args

    def gcd(a, b):
        """
        Stein's algorithm, heavily cribbed from Julia implementation.
        """
        T = type(a)
        if a == 0: return abs(b)
        if b == 0: return abs(a)
        za = trailing_zeros(a)
        zb = trailing_zeros(b)
        # k = number of common factors of 2.
        k = min(za, zb)
        # Uses np.*_shift instead of operators due to return types
        u = _unsigned(abs(np.right_shift(a, za)))
        v = _unsigned(abs(np.right_shift(b, zb)))
        while u != v:
            if u > v:
                u, v = v, u
            v -= u
            v = np.right_shift(v, trailing_zeros(v))
        # Restore the common power of two.
        r = np.left_shift(T(u), k)
        return r

    res = context.compile_internal(builder, gcd, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, res)


lower(math.gcd, types.Integer, types.Integer)(gcd_impl)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,82 @@
"""
This file implements print functionality for the CPU.
"""
from numba.core import types, typing, cgutils
from numba.core.imputils import Registry, impl_ret_untracked
registry = Registry('printimpl')
lower = registry.lower
# NOTE: the current implementation relies on CPython API even in
# nopython mode.
@lower("print_item", types.Literal)
def print_item_impl(context, builder, sig, args):
"""
Print a single constant value.
"""
ty, = sig.args
val = ty.literal_value
pyapi = context.get_python_api(builder)
strobj = pyapi.unserialize(pyapi.serialize_object(val))
pyapi.print_object(strobj)
pyapi.decref(strobj)
res = context.get_dummy_value()
return impl_ret_untracked(context, builder, sig.return_type, res)
@lower("print_item", types.Any)
def print_item_impl(context, builder, sig, args):
"""
Print a single native value by boxing it in a Python object and
invoking the Python interpreter's print routine.
"""
ty, = sig.args
val, = args
pyapi = context.get_python_api(builder)
env_manager = context.get_env_manager(builder)
if context.enable_nrt:
context.nrt.incref(builder, ty, val)
obj = pyapi.from_native_value(ty, val, env_manager)
with builder.if_else(cgutils.is_not_null(builder, obj), likely=True) as (if_ok, if_error):
with if_ok:
pyapi.print_object(obj)
pyapi.decref(obj)
with if_error:
cstr = context.insert_const_string(builder.module,
"the print() function")
strobj = pyapi.string_from_string(cstr)
pyapi.err_write_unraisable(strobj)
pyapi.decref(strobj)
res = context.get_dummy_value()
return impl_ret_untracked(context, builder, sig.return_type, res)
@lower(print, types.VarArg(types.Any))
def print_varargs_impl(context, builder, sig, args):
    """
    An entire print() call: print each item via the "print_item"
    lowering, separated by spaces and terminated with a newline, while
    holding the GIL.
    """
    pyapi = context.get_python_api(builder)
    gil = pyapi.gil_ensure()

    for i, (argtype, argval) in enumerate(zip(sig.args, args)):
        signature = typing.signature(types.none, argtype)
        imp = context.get_function("print_item", signature)
        imp(builder, [argval])
        if i < len(args) - 1:
            pyapi.print_string(' ')
    pyapi.print_string('\n')

    pyapi.gil_release(gil)
    res = context.get_dummy_value()
    return impl_ret_untracked(context, builder, sig.return_type, res)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,241 @@
"""
Implementation of the range object for fixed-size integers.
"""
import operator
from numba import prange
from numba.core import types, cgutils, errors, config
from numba.core.imputils import (lower_builtin, lower_cast,
iterator_impl, impl_ret_untracked)
from numba.core.typing import signature
from numba.core.extending import intrinsic, overload, overload_attribute, register_jitable
from numba.parfors.parfor import internal_prange
def make_range_iterator(typ):
    """
    Return the structure-proxy class used to represent *typ*, an
    instance of types.RangeIteratorType.
    """
    return cgutils.create_struct_proxy(typ)
def make_range_impl(int_type, range_state_type, range_iter_type):
    """
    Register the lowering for range()/prange() over *int_type*:
    the 1/2/3-argument constructors, len(), __iter__ and the iterator's
    iternext, all working on *range_state_type* / *range_iter_type*.
    """
    # Structure proxy for the (start, stop, step) range state.
    RangeState = cgutils.create_struct_proxy(range_state_type)
    @lower_builtin(range, int_type)
    @lower_builtin(prange, int_type)
    @lower_builtin(internal_prange, int_type)
    def range1_impl(context, builder, sig, args):
        """
        range(stop: int) -> range object
        """
        [stop] = args
        state = RangeState(context, builder)
        # start and step default to 0 and 1 respectively.
        state.start = context.get_constant(int_type, 0)
        state.stop = stop
        state.step = context.get_constant(int_type, 1)
        return impl_ret_untracked(context,
                                  builder,
                                  range_state_type,
                                  state._getvalue())
    @lower_builtin(range, int_type, int_type)
    @lower_builtin(prange, int_type, int_type)
    @lower_builtin(internal_prange, int_type, int_type)
    def range2_impl(context, builder, sig, args):
        """
        range(start: int, stop: int) -> range object
        """
        start, stop = args
        state = RangeState(context, builder)
        state.start = start
        state.stop = stop
        # step defaults to 1.
        state.step = context.get_constant(int_type, 1)
        return impl_ret_untracked(context,
                                  builder,
                                  range_state_type,
                                  state._getvalue())
    @lower_builtin(range, int_type, int_type, int_type)
    @lower_builtin(prange, int_type, int_type, int_type)
    @lower_builtin(internal_prange, int_type, int_type, int_type)
    def range3_impl(context, builder, sig, args):
        """
        range(start: int, stop: int, step: int) -> range object
        """
        [start, stop, step] = args
        state = RangeState(context, builder)
        state.start = start
        state.stop = stop
        state.step = step
        return impl_ret_untracked(context,
                                  builder,
                                  range_state_type,
                                  state._getvalue())
    @lower_builtin(len, range_state_type)
    def range_len(context, builder, sig, args):
        """
        len(range): build a throwaway iterator and read its trip count.
        """
        (value,) = args
        state = RangeState(context, builder, value)
        res = RangeIter.from_range_state(context, builder, state)
        return impl_ret_untracked(context, builder, int_type, builder.load(res.count))
    @lower_builtin('getiter', range_state_type)
    def getiter_range32_impl(context, builder, sig, args):
        """
        range.__iter__
        """
        (value,) = args
        state = RangeState(context, builder, value)
        res = RangeIter.from_range_state(context, builder, state)._getvalue()
        return impl_ret_untracked(context, builder, range_iter_type, res)
    @iterator_impl(range_state_type, range_iter_type)
    class RangeIter(make_range_iterator(range_iter_type)):
        # Iterator state: a pointer to the current value (`iter`), the
        # stop/step values, and a pointer to the remaining trip count.
        @classmethod
        def from_range_state(cls, context, builder, state):
            """
            Create a RangeIter initialized from the given RangeState *state*.
            """
            self = cls(context, builder)
            start = state.start
            stop = state.stop
            step = state.step
            startptr = cgutils.alloca_once(builder, start.type)
            builder.store(start, startptr)
            countptr = cgutils.alloca_once(builder, start.type)
            self.iter = startptr
            self.stop = stop
            self.step = step
            self.count = countptr
            diff = builder.sub(stop, start)
            zero = context.get_constant(int_type, 0)
            one = context.get_constant(int_type, 1)
            pos_diff = builder.icmp_signed('>', diff, zero)
            pos_step = builder.icmp_signed('>', step, zero)
            # If (stop - start) and step have opposite signs, the range
            # is empty.
            sign_differs = builder.xor(pos_diff, pos_step)
            zero_step = builder.icmp_unsigned('==', step, zero)
            with cgutils.if_unlikely(builder, zero_step):
                # step shouldn't be zero
                context.call_conv.return_user_exc(builder, ValueError,
                                                  ("range() arg 3 must not be zero",))
            with builder.if_else(sign_differs) as (then, orelse):
                with then:
                    builder.store(zero, self.count)
                with orelse:
                    # Trip count = ceil(diff / step): add one when the
                    # sign-adjusted remainder is non-zero.
                    rem = builder.srem(diff, step)
                    rem = builder.select(pos_diff, rem, builder.neg(rem))
                    uneven = builder.icmp_signed('>', rem, zero)
                    newcount = builder.add(builder.sdiv(diff, step),
                                           builder.select(uneven, one, zero))
                    builder.store(newcount, self.count)
            return self
        def iternext(self, context, builder, result):
            """Yield the current value and advance, while count > 0."""
            zero = context.get_constant(int_type, 0)
            countptr = self.count
            count = builder.load(countptr)
            is_valid = builder.icmp_signed('>', count, zero)
            result.set_valid(is_valid)
            with builder.if_then(is_valid):
                value = builder.load(self.iter)
                result.yield_(value)
                one = context.get_constant(int_type, 1)
                # nsw: the count was computed from the same operands, so
                # the decrement cannot wrap.
                builder.store(builder.sub(count, one, flags=["nsw"]), countptr)
                builder.store(builder.add(value, self.step), self.iter)
# Instantiate the range implementation for each supported index type.
# Each entry maps an integer type to its (range state, range iterator)
# Numba types.
range_impl_map = {
    types.int32 : (types.range_state32_type, types.range_iter32_type),
    types.int64 : (types.range_state64_type, types.range_iter64_type),
    types.uint64 : (types.unsigned_range_state64_type, types.unsigned_range_iter64_type)
}
for int_type, state_types in range_impl_map.items():
    make_range_impl(int_type, *state_types)
@lower_cast(types.RangeType, types.RangeType)
def range_to_range(context, builder, fromty, toty, val):
    """
    Cast a range object to a range over a different integer dtype by
    casting each of its (start, stop, step) members.
    """
    members = cgutils.unpack_tuple(builder, val, 3)
    casted = [context.cast(builder, m, fromty.dtype, toty.dtype)
              for m in members]
    return cgutils.make_anonymous_struct(builder, casted)
def make_range_attr(index, attribute):
    """
    Register a read-only attribute named *attribute* on range objects,
    returning the member at position *index* of the underlying
    (start, stop, step) structure.
    """
    @intrinsic
    def rangetype_attr_getter(typingctx, a):
        if isinstance(a, types.RangeType):
            def codegen(context, builder, sig, args):
                (val,) = args
                # A range value is an anonymous 3-member struct.
                items = cgutils.unpack_tuple(builder, val, 3)
                return impl_ret_untracked(context, builder, sig.return_type,
                                          items[index])
            return signature(a.dtype, a), codegen
    @overload_attribute(types.RangeType, attribute)
    def range_attr(rnge):
        # The attribute getter simply forwards to the intrinsic above.
        def get(rnge):
            return rangetype_attr_getter(rnge)
        return get
@register_jitable
def impl_contains_helper(robj, val):
    """
    Membership test for integral *val* in range *robj*: the value must
    lie within the iteration bounds and be reachable from start by a
    whole number of steps.
    """
    if robj.step > 0:
        if val < robj.start or val >= robj.stop:
            return False
    elif robj.step < 0:
        if val <= robj.stop or val > robj.start:
            return False
    return (val - robj.start) % robj.step == 0
@overload(operator.contains)
def impl_contains(robj, val):
    """
    Implement `val in range_obj` for integer, boolean, float and complex
    operands; any other non-number operand is never contained.
    """
    if not isinstance(robj, types.RangeType):
        return
    if isinstance(val, (types.Integer, types.Boolean)):
        return impl_contains_helper
    if isinstance(val, types.Float):
        def impl_float(robj, val):
            # A float is only in the range when it is a whole number.
            if val % 1 != 0:
                return False
            else:
                return impl_contains_helper(robj, int(val))
        return impl_float
    if isinstance(val, types.Complex):
        def impl_complex(robj, val):
            # A complex is only in the range when it is a real integer.
            if val.imag != 0:
                return False
            elif val.real % 1 != 0:
                return False
            else:
                return impl_contains_helper(robj, int(val.real))
        return impl_complex
    if not isinstance(val, types.Number):
        def impl_false(robj, val):
            return False
        return impl_false
# Expose .start, .stop and .step as attributes on range objects.
for ix, attr in enumerate(('start', 'stop', 'step')):
    make_range_attr(index=ix, attribute=attr)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,302 @@
"""
Implement slices and various slice computations.
"""
from itertools import zip_longest
from llvmlite import ir
from numba.core import cgutils, types, typing, utils
from numba.core.imputils import (impl_ret_borrowed, impl_ret_new_ref,
impl_ret_untracked, iternext_impl,
lower_builtin, lower_cast, lower_constant,
lower_getattr)
def fix_index(builder, idx, size):
    """
    Normalize a possibly-negative index *idx* against *size*: negative
    indices get *size* added to them, non-negative ones pass through.
    """
    zero = ir.Constant(size.type, 0)
    is_neg = builder.icmp_signed('<', idx, zero)
    wrapped = builder.add(idx, size)
    return builder.select(is_neg, wrapped, idx)
def fix_slice(builder, slice, size):
    """
    Fix *slice* start and stop to be valid (inclusive and exclusive, resp)
    indexing bounds for a sequence of the given *size*.

    Mutates the slice structure in place.  (The parameter name `slice`
    shadows the builtin of the same name.)
    """
    # See PySlice_GetIndicesEx()
    zero = ir.Constant(size.type, 0)
    minus_one = ir.Constant(size.type, -1)
    def fix_bound(bound_name, lower_repl, upper_repl):
        # Wrap a negative bound, then clamp it into the valid window,
        # writing each intermediate result back into *slice*.
        bound = getattr(slice, bound_name)
        bound = fix_index(builder, bound, size)
        # Store value
        setattr(slice, bound_name, bound)
        # Still negative? => clamp to lower_repl
        underflow = builder.icmp_signed('<', bound, zero)
        with builder.if_then(underflow, likely=False):
            setattr(slice, bound_name, lower_repl)
        # Greater than size? => clamp to upper_repl
        overflow = builder.icmp_signed('>=', bound, size)
        with builder.if_then(overflow, likely=False):
            setattr(slice, bound_name, upper_repl)
    # The clamping window depends on the step's sign.
    with builder.if_else(cgutils.is_neg_int(builder, slice.step)) as (if_neg_step, if_pos_step):
        with if_pos_step:
            # < 0 => 0; >= size => size
            fix_bound('start', zero, size)
            fix_bound('stop', zero, size)
        with if_neg_step:
            # < 0 => -1; >= size => size - 1
            lower = minus_one
            upper = builder.add(size, minus_one)
            fix_bound('start', lower, upper)
            fix_bound('stop', lower, upper)
def get_slice_length(builder, slicestruct):
    """
    Given a slice, compute the number of indices it spans, i.e. the
    number of iterations that for_range_slice() will execute.

    Pseudo-code (step is assumed non-zero, cf. PySlice_GetIndicesEx()):
        if step > 0:
            return 0 if stop <= start else (stop - start - 1) // step + 1
        else:
            return 0 if stop >= start else (stop - start + 1) // step + 1
    """
    begin = slicestruct.start
    end = slicestruct.stop
    stride = slicestruct.step
    one = ir.Constant(begin.type, 1)
    zero = ir.Constant(begin.type, 0)
    step_is_neg = cgutils.is_neg_int(builder, stride)
    span = builder.sub(end, begin)
    # Nominal (non-empty) case: bias the span by one toward zero before
    # the signed division so the rounding matches the pseudo-code above.
    dividend = builder.select(step_is_neg,
                              builder.add(span, one),
                              builder.sub(span, one))
    nominal_length = builder.add(one, builder.sdiv(dividend, stride))
    # Empty case: span and step have opposite signs, or span is zero.
    is_empty = builder.select(step_is_neg,
                              builder.icmp_signed('>=', span, zero),
                              builder.icmp_signed('<=', span, zero))
    return builder.select(is_empty, zero, nominal_length)
def get_slice_bounds(builder, slicestruct):
    """
    Return the [lower, upper) indexing bounds of a slice.

    This is a bit pessimal, e.g. it will return [1, 5) instead of
    [1, 4) for `1:5:2`.
    """
    begin = slicestruct.start
    end = slicestruct.stop
    zero = begin.type(0)
    one = begin.type(1)
    step_is_neg = builder.icmp_signed('<', slicestruct.step, zero)
    # For a negative step, iteration covers (stop, start] i.e.
    # [stop + 1, start + 1) in forward order.
    low = builder.select(step_is_neg, builder.add(end, one), begin)
    high = builder.select(step_is_neg, builder.add(begin, one), end)
    return low, high
def fix_stride(builder, slice, stride):
    """
    Fix the given stride for the slice's step.

    Returns *stride* multiplied by the slice step, i.e. the effective
    element stride when walking the sliced dimension.
    """
    # NOTE: parameter name `slice` shadows the builtin of the same name.
    return builder.mul(slice.step, stride)
def guard_invalid_slice(context, builder, typ, slicestruct):
    """
    Guard against *slicestruct* having a zero step (and raise ValueError).

    Slice types without an explicit step are always valid, so no guard
    is emitted for them.
    """
    if typ.has_step:
        cgutils.guard_null(context, builder, slicestruct.step,
                           (ValueError, "slice step cannot be zero"))
def get_defaults(context):
    """
    Return the default values for a slice's members as a 5-tuple:
    (start for positive step, start for negative step,
     stop for positive step, stop for negative step, step).

    The bounds are the extreme signed intp values for the target's
    address size.
    """
    intp_max = (1 << (context.address_size - 1)) - 1
    intp_min = -intp_max - 1
    return (0, intp_max, intp_max, intp_min, 1)
#---------------------------------------------------------------------------
# The slice structure
@lower_builtin(slice, types.VarArg(types.Any))
def slice_constructor_impl(context, builder, sig, args):
    """
    Implement the slice(...) constructor for 0 to 3 arguments.

    Omitted (None) arguments are replaced by the intp-extreme defaults
    from get_defaults(); which default applies for start/stop depends on
    the sign of the step.  A single non-None argument is the stop value.
    Returns the packed slice structure value.
    """
    # Fix: removed an unused local (`ty = sig.return_type`) that was
    # assigned but never read.
    (
        default_start_pos,
        default_start_neg,
        default_stop_pos,
        default_stop_neg,
        default_step,
    ) = [context.get_constant(types.intp, x) for x in get_defaults(context)]
    slice_args = [None] * 3
    # Fetch non-None arguments
    if len(args) == 1 and sig.args[0] is not types.none:
        # slice(stop)
        slice_args[1] = args[0]
    else:
        for i, (ty, val) in enumerate(zip(sig.args, args)):
            if ty is not types.none:
                slice_args[i] = val
    # Fill omitted arguments
    def get_arg_value(i, default):
        val = slice_args[i]
        if val is None:
            return default
        else:
            return val
    # The step must be resolved first: it selects which start/stop
    # defaults apply.
    step = get_arg_value(2, default_step)
    is_step_negative = builder.icmp_signed('<', step,
                                           context.get_constant(types.intp, 0))
    default_stop = builder.select(is_step_negative,
                                  default_stop_neg, default_stop_pos)
    default_start = builder.select(is_step_negative,
                                   default_start_neg, default_start_pos)
    stop = get_arg_value(1, default_stop)
    start = get_arg_value(0, default_start)
    sli = context.make_helper(builder, sig.return_type)
    sli.start = start
    sli.stop = stop
    sli.step = step
    res = sli._getvalue()
    return impl_ret_untracked(context, builder, sig.return_type, res)
@lower_getattr(types.SliceType, "start")
def slice_start_impl(context, builder, typ, value):
    """Fetch the .start member of a slice value."""
    return context.make_helper(builder, typ, value).start
@lower_getattr(types.SliceType, "stop")
def slice_stop_impl(context, builder, typ, value):
    """Fetch the .stop member of a slice value."""
    return context.make_helper(builder, typ, value).stop
@lower_getattr(types.SliceType, "step")
def slice_step_impl(context, builder, typ, value):
    """
    Fetch the .step member of a slice value; slice types without an
    explicit step always report a step of 1.
    """
    if not typ.has_step:
        return context.get_constant(types.intp, 1)
    return context.make_helper(builder, typ, value).step
@lower_builtin("slice.indices", types.SliceType, types.Integer)
def slice_indices(context, builder, sig, args):
    """
    Implement slice.indices(length): validate *length* and the step,
    clamp the slice bounds via fix_slice, and return (start, stop, step).
    """
    length = args[1]
    sli = context.make_helper(builder, sig.args[0], args[0])
    # Mirror CPython: a negative length is a ValueError.
    with builder.if_then(cgutils.is_neg_int(builder, length), likely=False):
        context.call_conv.return_user_exc(
            builder, ValueError,
            ("length should not be negative",)
        )
    # A zero step is invalid, as for regular slicing.
    with builder.if_then(cgutils.is_scalar_zero(builder, sli.step), likely=False):
        context.call_conv.return_user_exc(
            builder, ValueError,
            ("slice step cannot be zero",)
        )
    fix_slice(builder, sli, length)
    return context.make_tuple(
        builder,
        sig.return_type,
        (sli.start, sli.stop, sli.step)
    )
def make_slice_from_constant(context, builder, ty, pyval):
    """
    Lower the Python slice object *pyval* into a constant slice value of
    Numba type *ty*, substituting the intp-extreme defaults for any
    member that is None (choosing by the step's sign).
    """
    sli = context.make_helper(builder, ty)
    lty = context.get_value_type(types.intp)
    (
        default_start_pos,
        default_start_neg,
        default_stop_pos,
        default_stop_neg,
        default_step,
    ) = [context.get_constant(types.intp, x) for x in get_defaults(context)]
    # Resolve the step first: its sign selects the start/stop defaults.
    step = pyval.step
    if step is None:
        step_is_neg = False
        step = default_step
    else:
        step_is_neg = step < 0
        step = lty(step)
    start = pyval.start
    if start is None:
        if step_is_neg:
            start = default_start_neg
        else:
            start = default_start_pos
    else:
        start = lty(start)
    stop = pyval.stop
    if stop is None:
        if step_is_neg:
            stop = default_stop_neg
        else:
            stop = default_stop_pos
    else:
        stop = lty(stop)
    sli.start = start
    sli.stop = stop
    sli.step = step
    return sli._getvalue()
@lower_constant(types.SliceType)
def constant_slice(context, builder, ty, pyval):
    """
    Lower a constant slice; literal slice types are unwrapped to their
    underlying slice type first.
    """
    typ = ty.literal_type if isinstance(ty, types.Literal) else ty
    return make_slice_from_constant(context, builder, typ, pyval)
@lower_cast(types.misc.SliceLiteral, types.SliceType)
def cast_from_literal(context, builder, fromty, toty, val):
    """Cast a slice literal to a run-time slice value."""
    return make_slice_from_constant(context, builder, toty,
                                    fromty.literal_value)

View File

@@ -0,0 +1,412 @@
"""
Implementation of tuple objects
"""
import operator
from numba.core.imputils import (lower_builtin, lower_getattr_generic,
lower_cast, lower_constant, iternext_impl,
impl_ret_borrowed, impl_ret_untracked,
RefType)
from numba.core import typing, types, cgutils
from numba.core.extending import overload_method, overload, intrinsic
@lower_builtin(types.NamedTupleClass, types.VarArg(types.Any))
def namedtuple_constructor(context, builder, sig, args):
    """
    Build a namedtuple value.  A namedtuple shares the representation of
    a plain tuple; each constructor argument is cast from its argument
    type to the corresponding element type of the return type (this
    handles e.g. a literal in the args where the return type holds a
    plain type).
    """
    casted = [context.cast(builder, arg, sig.args[i], sig.return_type[i])
              for i, arg in enumerate(args)]
    res = context.make_tuple(builder, sig.return_type, tuple(casted))
    # The tuple's contents are borrowed
    return impl_ret_borrowed(context, builder, sig.return_type, res)
@lower_builtin(operator.add, types.BaseTuple, types.BaseTuple)
def tuple_add(context, builder, sig, args):
    """
    Concatenate two tuples: unpack both operands and re-pack the
    combined elements into the result type.
    """
    lhs, rhs = args
    elems = cgutils.unpack_tuple(builder, lhs) + cgutils.unpack_tuple(builder, rhs)
    res = context.make_tuple(builder, sig.return_type, elems)
    # The tuple's contents are borrowed
    return impl_ret_borrowed(context, builder, sig.return_type, res)
def tuple_cmp_ordered(context, builder, op, sig, args):
    """
    Lower a lexicographic ordered comparison (*op* is one of
    operator.lt/le/gt/ge) between two tuples: scan element-wise, and at
    the first unequal pair the result is op(a, b); if all shared
    elements are equal, the result is op applied to the lengths.
    """
    tu, tv = sig.args
    u, v = args
    res = cgutils.alloca_once_value(builder, cgutils.true_bit)
    bbend = builder.append_basic_block("cmp_end")
    for i, (ta, tb) in enumerate(zip(tu.types, tv.types)):
        a = builder.extract_value(u, i)
        b = builder.extract_value(v, i)
        not_equal = context.generic_compare(builder, operator.ne, (ta, tb), (a, b))
        with builder.if_then(not_equal):
            # First difference decides the comparison: store and exit.
            pred = context.generic_compare(builder, op, (ta, tb), (a, b))
            builder.store(pred, res)
            builder.branch(bbend)
    # Everything matched equal => compare lengths
    len_compare = op(len(tu.types), len(tv.types))
    pred = context.get_constant(types.boolean, len_compare)
    builder.store(pred, res)
    builder.branch(bbend)
    builder.position_at_end(bbend)
    return builder.load(res)
@lower_builtin(operator.eq, types.BaseTuple, types.BaseTuple)
def tuple_eq(context, builder, sig, args):
    """
    Tuple equality: tuples of different lengths are never equal; equal
    lengths compare element-wise, and-ing the results together.
    """
    ty_u, ty_v = sig.args
    val_u, val_v = args
    if len(ty_u.types) != len(ty_v.types):
        res = context.get_constant(types.boolean, False)
        return impl_ret_untracked(context, builder, sig.return_type, res)
    acc = context.get_constant(types.boolean, True)
    for pos, (ta, tb) in enumerate(zip(ty_u.types, ty_v.types)):
        elem_a = builder.extract_value(val_u, pos)
        elem_b = builder.extract_value(val_v, pos)
        elem_eq = context.generic_compare(builder, operator.eq, (ta, tb),
                                          (elem_a, elem_b))
        acc = builder.and_(acc, elem_eq)
    return impl_ret_untracked(context, builder, sig.return_type, acc)
@lower_builtin(operator.ne, types.BaseTuple, types.BaseTuple)
def tuple_ne(context, builder, sig, args):
    """Tuple inequality: the negation of tuple_eq."""
    eq = tuple_eq(context, builder, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type,
                              builder.not_(eq))
@lower_builtin(operator.lt, types.BaseTuple, types.BaseTuple)
def tuple_lt(context, builder, sig, args):
    """Lexicographic < between two tuples."""
    pred = tuple_cmp_ordered(context, builder, operator.lt, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, pred)
@lower_builtin(operator.le, types.BaseTuple, types.BaseTuple)
def tuple_le(context, builder, sig, args):
    """Lexicographic <= between two tuples."""
    pred = tuple_cmp_ordered(context, builder, operator.le, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, pred)
@lower_builtin(operator.gt, types.BaseTuple, types.BaseTuple)
def tuple_gt(context, builder, sig, args):
    """Lexicographic > between two tuples."""
    pred = tuple_cmp_ordered(context, builder, operator.gt, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, pred)
@lower_builtin(operator.ge, types.BaseTuple, types.BaseTuple)
def tuple_ge(context, builder, sig, args):
    """Lexicographic >= between two tuples."""
    pred = tuple_cmp_ordered(context, builder, operator.ge, sig, args)
    return impl_ret_untracked(context, builder, sig.return_type, pred)
# for hashing see hashing.py
@lower_getattr_generic(types.BaseNamedTuple)
def namedtuple_getattr(context, builder, typ, value, attr):
    """
    Fetch a namedtuple's field by name, resolving the field name to its
    positional index.
    """
    pos = typ.fields.index(attr)
    field = builder.extract_value(value, pos)
    return impl_ret_borrowed(context, builder, typ[pos], field)
@lower_constant(types.UniTuple)
@lower_constant(types.NamedUniTuple)
def unituple_constant(context, builder, ty, pyval):
    """
    Create a homogeneous tuple constant: every element is lowered with
    the shared dtype and packed into an array-like struct.
    """
    elems = [context.get_constant_generic(builder, ty.dtype, item)
             for item in pyval]
    packed = cgutils.pack_array(builder, elems)
    return impl_ret_borrowed(context, builder, ty, packed)
# NOTE(review): this reuses the name `unituple_constant` from the
# homogeneous variant above.  Both remain registered via their
# decorators, but the module attribute now refers to this definition; a
# distinct name (e.g. `tuple_constant`) would aid debugging.
@lower_constant(types.Tuple)
@lower_constant(types.NamedTuple)
def unituple_constant(context, builder, ty, pyval):
    """
    Create a heterogeneous tuple constant.
    """
    # Each element is lowered with its own per-position type.
    consts = [context.get_constant_generic(builder, ty.types[i], v)
              for i, v in enumerate(pyval)]
    return impl_ret_borrowed(
        context, builder, ty, cgutils.pack_struct(builder, consts),
    )
#------------------------------------------------------------------------------
# Tuple iterators
@lower_builtin('getiter', types.UniTuple)
@lower_builtin('getiter', types.NamedUniTuple)
def getiter_unituple(context, builder, sig, args):
    """
    Create an iterator over a homogeneous tuple: an index slot
    (initialized to 0) paired with the tuple value itself.
    """
    [tupty] = sig.args
    [tup] = args
    itervalue = context.make_helper(builder, types.UniTupleIter(tupty))
    zero = context.get_constant(types.intp, 0)
    idx_slot = cgutils.alloca_once_value(builder, zero)
    itervalue.index = idx_slot
    itervalue.tuple = tup
    return impl_ret_borrowed(context, builder, sig.return_type,
                             itervalue._getvalue())
@lower_builtin('iternext', types.UniTupleIter)
@iternext_impl(RefType.BORROWED)
def iternext_unituple(context, builder, sig, args, result):
    """
    Advance a homogeneous tuple iterator: while the index is within
    bounds, yield the element at the current index and bump the index.
    """
    [tupiterty] = sig.args
    [tupiter] = args
    iterval = context.make_helper(builder, tupiterty, value=tupiter)
    tup = iterval.tuple
    idxptr = iterval.index
    idx = builder.load(idxptr)
    count = context.get_constant(types.intp, tupiterty.container.count)
    is_valid = builder.icmp_signed('<', idx, count)
    result.set_valid(is_valid)
    with builder.if_then(is_valid):
        # Reuse the regular getitem lowering to fetch the element.
        getitem_sig = typing.signature(tupiterty.container.dtype,
                                       tupiterty.container,
                                       types.intp)
        getitem_out = getitem_unituple(context, builder, getitem_sig,
                                       [tup, idx])
        # As a iternext_impl function, this will incref the yielded value.
        # We need to release the new reference from getitem_unituple.
        if context.enable_nrt:
            context.nrt.decref(builder, tupiterty.container.dtype, getitem_out)
        result.yield_(getitem_out)
        nidx = builder.add(idx, context.get_constant(types.intp, 1))
        builder.store(nidx, iterval.index)
@overload(operator.getitem)
def getitem_literal_idx(tup, idx):
    """
    Overloads BaseTuple getitem for a literal integer index, covering
    cases where constant inference and RewriteConstGetitems cannot
    replace the call with a static_getitem.
    """
    if not (isinstance(tup, types.BaseTuple)
            and isinstance(idx, types.IntegerLiteral)):
        return None
    literal_index = idx.literal_value
    def impl(tup, idx):
        # The literal value is baked into the closure, so the item
        # access is resolved with a constant index.
        return tup[literal_index]
    return impl
@lower_builtin('typed_getitem', types.BaseTuple, types.Any)
def getitem_typed(context, builder, sig, args):
    """
    Implement tuple[i] for a run-time index when the result type is
    already known ("typed_getitem"): a switch over all valid (and
    negative) indices, with the selected element funnelled through a
    void*-typed phi node.  Out-of-range indices raise IndexError.
    """
    tupty, _ = sig.args
    tup, idx = args
    errmsg_oob = ("tuple index out of range",)
    if len(tupty) == 0:
        # Empty tuple.
        # Always branch and raise IndexError
        with builder.if_then(cgutils.true_bit):
            context.call_conv.return_user_exc(builder, IndexError,
                                              errmsg_oob)
        # This is unreachable in runtime,
        # but it exists to not terminate the current basicblock.
        res = context.get_constant_null(sig.return_type)
        return impl_ret_untracked(context, builder,
                                  sig.return_type, res)
    else:
        # The tuple is not empty
        bbelse = builder.append_basic_block("typed_switch.else")
        bbend = builder.append_basic_block("typed_switch.end")
        switch = builder.switch(idx, bbelse)
        # The default case raises IndexError.
        with builder.goto_block(bbelse):
            context.call_conv.return_user_exc(builder, IndexError,
                                              errmsg_oob)
        lrtty = context.get_value_type(sig.return_type)
        voidptrty = context.get_value_type(types.voidptr)
        with builder.goto_block(bbend):
            phinode = builder.phi(voidptrty)
        for i in range(tupty.count):
            ki = context.get_constant(types.intp, i)
            bbi = builder.append_basic_block("typed_switch.%d" % i)
            switch.add_case(ki, bbi)
            # handle negative indexing, create case (-tuple.count + i) to
            # reference same block as i
            kin = context.get_constant(types.intp, -tupty.count + i)
            switch.add_case(kin, bbi)
            with builder.goto_block(bbi):
                value = builder.extract_value(tup, i)
                # Dragon warning...
                # The fact the code has made it this far suggests that type
                # inference decided whatever was being done with the item pulled
                # from the tuple was legitimate, it is not the job of lowering
                # to argue about that. However, here lies a problem, the tuple
                # lowering is implemented as a switch table with each case
                # writing to a phi node slot that is returned. The type of this
                # phi node slot needs to be "correct" for the current type but
                # it also needs to survive stores being made to it from the
                # other cases that will in effect never run. To do this a stack
                # slot is made for each case for the specific type and then cast
                # to a void pointer type, this is then added as an incoming on
                # the phi node, at the end of the switch the phi node is then
                # cast back to the required return type for this typed_getitem.
                # The only further complication is that if the value is not a
                # pointer then the void* juggle won't work so a cast is made
                # prior to store, again, that type inference has permitted it
                # suggests this is safe.
                # End Dragon warning...
                DOCAST = context.typing_context.unify_types(sig.args[0][i],
                                                            sig.return_type) == sig.return_type
                if DOCAST:
                    value_slot = builder.alloca(lrtty,
                                                name="TYPED_VALUE_SLOT%s" % i)
                    casted = context.cast(builder, value, sig.args[0][i],
                                          sig.return_type)
                    builder.store(casted, value_slot)
                else:
                    value_slot = builder.alloca(value.type,
                                                name="TYPED_VALUE_SLOT%s" % i)
                    builder.store(value, value_slot)
                phinode.add_incoming(builder.bitcast(value_slot, voidptrty),
                                     bbi)
                builder.branch(bbend)
        builder.position_at_end(bbend)
        # Cast the void* phi result back to the declared return type.
        res = builder.bitcast(phinode, lrtty.as_pointer())
        res = builder.load(res)
        return impl_ret_borrowed(context, builder, sig.return_type, res)
@lower_builtin(operator.getitem, types.UniTuple, types.intp)
@lower_builtin(operator.getitem, types.UniTuple, types.uintp)
@lower_builtin(operator.getitem, types.NamedUniTuple, types.intp)
@lower_builtin(operator.getitem, types.NamedUniTuple, types.uintp)
def getitem_unituple(context, builder, sig, args):
    """
    Implement tuple[i] on homogeneous tuples for a run-time index: a
    switch over every valid (and negative) index, merging the selected
    element through a phi node.  Out-of-range indices raise IndexError.
    """
    tupty, _ = sig.args
    tup, idx = args
    errmsg_oob = ("tuple index out of range",)
    if len(tupty) == 0:
        # Empty tuple.
        # Always branch and raise IndexError
        with builder.if_then(cgutils.true_bit):
            context.call_conv.return_user_exc(builder, IndexError,
                                              errmsg_oob)
        # This is unreachable in runtime,
        # but it exists to not terminate the current basicblock.
        res = context.get_constant_null(sig.return_type)
        return impl_ret_untracked(context, builder,
                                  sig.return_type, res)
    else:
        # The tuple is not empty
        bbelse = builder.append_basic_block("switch.else")
        bbend = builder.append_basic_block("switch.end")
        switch = builder.switch(idx, bbelse)
        # The default case raises IndexError.
        with builder.goto_block(bbelse):
            context.call_conv.return_user_exc(builder, IndexError,
                                              errmsg_oob)
        lrtty = context.get_value_type(tupty.dtype)
        with builder.goto_block(bbend):
            phinode = builder.phi(lrtty)
        for i in range(tupty.count):
            ki = context.get_constant(types.intp, i)
            bbi = builder.append_basic_block("switch.%d" % i)
            switch.add_case(ki, bbi)
            # handle negative indexing, create case (-tuple.count + i) to
            # reference same block as i
            kin = context.get_constant(types.intp, -tupty.count + i)
            switch.add_case(kin, bbi)
            with builder.goto_block(bbi):
                value = builder.extract_value(tup, i)
                builder.branch(bbend)
                phinode.add_incoming(value, bbi)
        builder.position_at_end(bbend)
        res = phinode
        assert sig.return_type == tupty.dtype
        return impl_ret_borrowed(context, builder, sig.return_type, res)
@lower_builtin('static_getitem', types.LiteralStrKeyDict, types.StringLiteral)
@lower_builtin('static_getitem', types.LiteralList, types.IntegerLiteral)
@lower_builtin('static_getitem', types.LiteralList, types.SliceLiteral)
@lower_builtin('static_getitem', types.BaseTuple, types.IntegerLiteral)
@lower_builtin('static_getitem', types.BaseTuple, types.SliceLiteral)
def static_getitem_tuple(context, builder, sig, args):
    """
    Implement indexing with a compile-time-constant index: an int picks
    a single element, a slice re-packs the selected elements, and a
    string key indexes a LiteralStrKeyDict by field position.
    """
    tupty, idxty = sig.args
    tup, idx = args
    if isinstance(idx, int):
        # Wrap negative constant indices, then bounds-check at compile
        # time (out of range is a compile error, not a run-time one).
        if idx < 0:
            idx += len(tupty)
        if not 0 <= idx < len(tupty):
            raise IndexError("cannot index at %d in %s" % (idx, tupty))
        res = builder.extract_value(tup, idx)
    elif isinstance(idx, slice):
        # Apply the Python slice to the unpacked elements directly.
        items = cgutils.unpack_tuple(builder, tup)[idx]
        res = context.make_tuple(builder, sig.return_type, items)
    elif isinstance(tupty, types.LiteralStrKeyDict):
        # pretend to be a dictionary
        idx_val = idxty.literal_value
        idx_offset = tupty.fields.index(idx_val)
        res = builder.extract_value(tup, idx_offset)
    else:
        raise NotImplementedError("unexpected index %r for %s"
                                  % (idx, sig.args[0]))
    return impl_ret_borrowed(context, builder, sig.return_type, res)
#------------------------------------------------------------------------------
# Implicit conversion
@lower_cast(types.BaseTuple, types.BaseTuple)
def tuple_to_tuple(context, builder, fromty, toty, val):
    """
    Implicit tuple-to-tuple conversion: cast each element to the target
    element type.  Named tuples and length mismatches never reach this
    point (the typing layer rejects them).
    """
    if isinstance(fromty, types.BaseNamedTuple) or isinstance(toty, types.BaseNamedTuple):
        # Disallowed by typing layer
        raise NotImplementedError
    if len(fromty) != len(toty):
        # Disallowed by typing layer
        raise NotImplementedError
    elems = cgutils.unpack_tuple(builder, val, len(fromty))
    casted = [context.cast(builder, elem, src_ty, dst_ty)
              for elem, src_ty, dst_ty in zip(elems, fromty, toty)]
    return context.make_tuple(builder, toty, casted)
#------------------------------------------------------------------------------
# Methods
@overload_method(types.BaseTuple, 'index')
def tuple_index(tup, value):
    """
    Implement tuple.index(value): return the position of the first
    element equal to *value*, raising ValueError when absent.
    """
    def impl(tup, value):
        for pos in range(len(tup)):
            if tup[pos] == value:
                return pos
        raise ValueError("tuple.index(x): x not in tuple")
    return impl
@overload(operator.contains)
def in_seq_empty_tuple(x, y):
    """
    Membership in an empty (heterogeneous) tuple is always False.
    """
    if isinstance(x, types.Tuple) and not x.types:
        return lambda x, y: False

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,765 @@
"""
This module contains support functions for more advanced unicode operations.
This is not a public API and is for Numba internal use only. Most of the
functions are relatively straightforward translations of the functions with the
same name in CPython.
"""
from collections import namedtuple
from enum import IntEnum
import llvmlite.ir
import numpy as np
from numba.core import types, cgutils
from numba.core.imputils import (impl_ret_untracked)
from numba.core.extending import overload, intrinsic, register_jitable
from numba.core.errors import TypingError
# This is equivalent to the struct `_PyUnicode_TypeRecord defined in CPython's
# Objects/unicodectype.c
typerecord = namedtuple('typerecord',
'upper lower title decimal digit flags')
# The Py_UCS4 type from CPython:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/unicodeobject.h#L112 # noqa: E501
_Py_UCS4 = types.uint32
# ------------------------------------------------------------------------------
# Start code related to/from CPython's unicodectype impl
#
# NOTE: the original source at:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c # noqa: E501
# contains this statement:
#
# /*
# Unicode character type helpers.
#
# Written by Marc-Andre Lemburg (mal@lemburg.com).
# Modified for Python 2.0 by Fredrik Lundh (fredrik@pythonware.com)
#
# Copyright (c) Corporation for National Research Initiatives.
#
# */
# This enum contains the values defined in CPython's Objects/unicodectype.c that
# provide masks for use against the various members of the typerecord
#
# See: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L13-L27 # noqa: E501
#
_Py_TAB = 0x9
_Py_LINEFEED = 0xa
_Py_CARRIAGE_RETURN = 0xd
_Py_SPACE = 0x20
class _PyUnicode_TyperecordMasks(IntEnum):
    """
    Bit masks applied against the `flags` member of a typerecord.
    Values mirror the definitions in CPython's Objects/unicodectype.c
    (see the link in the header comment above).
    """
    ALPHA_MASK = 0x01
    DECIMAL_MASK = 0x02
    DIGIT_MASK = 0x04
    LOWER_MASK = 0x08
    LINEBREAK_MASK = 0x10
    SPACE_MASK = 0x20
    TITLE_MASK = 0x40
    UPPER_MASK = 0x80
    XID_START_MASK = 0x100
    XID_CONTINUE_MASK = 0x200
    PRINTABLE_MASK = 0x400
    NUMERIC_MASK = 0x800
    CASE_IGNORABLE_MASK = 0x1000
    CASED_MASK = 0x2000
    EXTENDED_CASE_MASK = 0x4000
def _PyUnicode_gettyperecord(a):
raise RuntimeError("Calling the Python definition is invalid")
@intrinsic
def _gettyperecord_impl(typingctx, codepoint):
    """
    Provides the binding to numba_gettyperecord, returns a `typerecord`
    namedtuple of properties from the codepoint.
    """
    if not isinstance(codepoint, types.Integer):
        raise TypingError("codepoint must be an integer")
    def details(context, builder, signature, args):
        # Declare the C function's argument/return LLVM types; all
        # results come back through out-pointers.
        ll_void = context.get_value_type(types.void)
        ll_Py_UCS4 = context.get_value_type(_Py_UCS4)
        ll_intc = context.get_value_type(types.intc)
        ll_intc_ptr = ll_intc.as_pointer()
        ll_uchar = context.get_value_type(types.uchar)
        ll_uchar_ptr = ll_uchar.as_pointer()
        ll_ushort = context.get_value_type(types.ushort)
        ll_ushort_ptr = ll_ushort.as_pointer()
        fnty = llvmlite.ir.FunctionType(ll_void, [
            ll_Py_UCS4, # code
            ll_intc_ptr, # upper
            ll_intc_ptr, # lower
            ll_intc_ptr, # title
            ll_uchar_ptr, # decimal
            ll_uchar_ptr, # digit
            ll_ushort_ptr, # flags
        ])
        fn = cgutils.get_or_insert_function(
            builder.module,
            fnty, name="numba_gettyperecord")
        # Stack slots for the seven out-parameters.
        upper = cgutils.alloca_once(builder, ll_intc, name='upper')
        lower = cgutils.alloca_once(builder, ll_intc, name='lower')
        title = cgutils.alloca_once(builder, ll_intc, name='title')
        decimal = cgutils.alloca_once(builder, ll_uchar, name='decimal')
        digit = cgutils.alloca_once(builder, ll_uchar, name='digit')
        flags = cgutils.alloca_once(builder, ll_ushort, name='flags')
        byref = [ upper, lower, title, decimal, digit, flags]
        builder.call(fn, [args[0]] + byref)
        # Load the filled-in out-params and pack them into the result
        # namedtuple.
        buf = []
        for x in byref:
            buf.append(builder.load(x))
        res = context.make_tuple(builder, signature.return_type, tuple(buf))
        return impl_ret_untracked(context, builder, signature.return_type, res)
    tupty = types.NamedTuple([types.intc, types.intc, types.intc, types.uchar,
                              types.uchar, types.ushort], typerecord)
    sig = tupty(_Py_UCS4)
    return sig, details
@overload(_PyUnicode_gettyperecord)
def gettyperecord_impl(a):
    """
    Provides a _PyUnicode_gettyperecord binding, for convenience it will accept
    single character strings and code points.
    """
    if isinstance(a, types.Integer):
        # A code point: cast straight to Py_UCS4 and look it up.
        return lambda a: _gettyperecord_impl(_Py_UCS4(a))
    if isinstance(a, types.UnicodeType):
        from numba.cpython.unicode import _get_code_point

        def string_impl(a):
            # Only single-character strings are meaningful here.
            if len(a) > 1:
                msg = "gettyperecord takes a single unicode character"
                raise ValueError(msg)
            return _gettyperecord_impl(_Py_UCS4(_get_code_point(a, 0)))
        return string_impl
# whilst it's possible to grab the _PyUnicode_ExtendedCase symbol as it's global
# it is safer to use a defined api:
@intrinsic
def _PyUnicode_ExtendedCase(typingctx, index):
    """
    Accessor function for the _PyUnicode_ExtendedCase array, binds to
    numba_get_PyUnicode_ExtendedCase which wraps the array and does the lookup

    Raises TypingError if `index` is not typed as an integer.
    """
    if not isinstance(index, types.Integer):
        raise TypingError("Expected an index")

    def details(context, builder, signature, args):
        ll_Py_UCS4 = context.get_value_type(_Py_UCS4)
        ll_intc = context.get_value_type(types.intc)
        # C signature: Py_UCS4 numba_get_PyUnicode_ExtendedCase(int index)
        fnty = llvmlite.ir.FunctionType(ll_Py_UCS4, [ll_intc])
        fn = cgutils.get_or_insert_function(
            builder.module,
            fnty, name="numba_get_PyUnicode_ExtendedCase")
        return builder.call(fn, [args[0]])

    sig = _Py_UCS4(types.intc)
    return sig, details
# The following functions are replications of the functions with the same name
# in CPython's Objects/unicodectype.c

# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L64-L71 # noqa: E501
@register_jitable
def _PyUnicode_ToTitlecase(ch):
    """Return the titlecase mapping of code point `ch`."""
    rec = _PyUnicode_gettyperecord(ch)
    if rec.flags & _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK:
        # Extended case: the low 16 bits index the extended-case table.
        return _PyUnicode_ExtendedCase(rec.title & 0xFFFF)
    else:
        # Simple case: `title` is a delta from the input code point.
        return ch + rec.title
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L76-L81 # noqa: E501
@register_jitable
def _PyUnicode_IsTitlecase(ch):
    """Return whether the code point carries the TITLE flag."""
    rec = _PyUnicode_gettyperecord(ch)
    return (rec.flags & _PyUnicode_TyperecordMasks.TITLE_MASK) != 0

# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L86-L91 # noqa: E501
@register_jitable
def _PyUnicode_IsXidStart(ch):
    """Return whether the code point may start an identifier (XID_Start)."""
    rec = _PyUnicode_gettyperecord(ch)
    return (rec.flags & _PyUnicode_TyperecordMasks.XID_START_MASK) != 0

# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L96-L101 # noqa: E501
@register_jitable
def _PyUnicode_IsXidContinue(ch):
    """Return whether the code point may continue an identifier."""
    rec = _PyUnicode_gettyperecord(ch)
    return (rec.flags & _PyUnicode_TyperecordMasks.XID_CONTINUE_MASK) != 0
@register_jitable
def _PyUnicode_ToDecimalDigit(ch):
    """Return the decimal digit value of the code point, or -1."""
    rec = _PyUnicode_gettyperecord(ch)
    if not (rec.flags & _PyUnicode_TyperecordMasks.DECIMAL_MASK):
        return -1
    return rec.decimal

# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L123-L128 # noqa: E501
@register_jitable
def _PyUnicode_ToDigit(ch):
    """Return the digit value of the code point, or -1."""
    rec = _PyUnicode_gettyperecord(ch)
    if not (rec.flags & _PyUnicode_TyperecordMasks.DIGIT_MASK):
        return -1
    return rec.digit
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L140-L145 # noqa: E501
@register_jitable
def _PyUnicode_IsNumeric(ch):
    """Return whether the code point carries the NUMERIC flag."""
    rec = _PyUnicode_gettyperecord(ch)
    return (rec.flags & _PyUnicode_TyperecordMasks.NUMERIC_MASK) != 0

# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L160-L165 # noqa: E501
@register_jitable
def _PyUnicode_IsPrintable(ch):
    """Return whether the code point carries the PRINTABLE flag."""
    rec = _PyUnicode_gettyperecord(ch)
    return (rec.flags & _PyUnicode_TyperecordMasks.PRINTABLE_MASK) != 0

# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L170-L175 # noqa: E501
@register_jitable
def _PyUnicode_IsLowercase(ch):
    """Return whether the code point carries the LOWER flag."""
    rec = _PyUnicode_gettyperecord(ch)
    return (rec.flags & _PyUnicode_TyperecordMasks.LOWER_MASK) != 0

# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L180-L185 # noqa: E501
@register_jitable
def _PyUnicode_IsUppercase(ch):
    """Return whether the code point carries the UPPER flag."""
    rec = _PyUnicode_gettyperecord(ch)
    return (rec.flags & _PyUnicode_TyperecordMasks.UPPER_MASK) != 0

@register_jitable
def _PyUnicode_IsLineBreak(ch):
    """Return whether the code point carries the LINEBREAK flag."""
    rec = _PyUnicode_gettyperecord(ch)
    return (rec.flags & _PyUnicode_TyperecordMasks.LINEBREAK_MASK) != 0
# NOTE(review): the two simple (single code point) case mappings below are
# deliberately unimplemented — the *Full variants appear to be what is used;
# confirm before wiring any caller to these.
@register_jitable
def _PyUnicode_ToUppercase(ch):
    raise NotImplementedError

@register_jitable
def _PyUnicode_ToLowercase(ch):
    raise NotImplementedError
# From: https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodectype.c#L211-L225 # noqa: E501
@register_jitable
def _PyUnicode_ToLowerFull(ch, res):
    """Write the full lowercase expansion of `ch` into `res`; return the
    number of code points written."""
    rec = _PyUnicode_gettyperecord(ch)
    if rec.flags & _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK:
        # Extended case: low 16 bits index the table, top byte is the length.
        base = rec.lower & 0xFFFF
        count = rec.lower >> 24
        for k in range(count):
            res[k] = _PyUnicode_ExtendedCase(base + k)
        return count
    # Simple case: single code point at a fixed delta.
    res[0] = ch + rec.lower
    return 1

# From: https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodectype.c#L227-L241 # noqa: E501
@register_jitable
def _PyUnicode_ToTitleFull(ch, res):
    """Write the full titlecase expansion of `ch` into `res`; return the
    number of code points written."""
    rec = _PyUnicode_gettyperecord(ch)
    if rec.flags & _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK:
        base = rec.title & 0xFFFF
        count = rec.title >> 24
        for k in range(count):
            res[k] = _PyUnicode_ExtendedCase(base + k)
        return count
    res[0] = ch + rec.title
    return 1

# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L243-L257 # noqa: E501
@register_jitable
def _PyUnicode_ToUpperFull(ch, res):
    """Write the full uppercase expansion of `ch` into `res`; return the
    number of code points written."""
    rec = _PyUnicode_gettyperecord(ch)
    if rec.flags & _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK:
        base = rec.upper & 0xFFFF
        count = rec.upper >> 24
        for k in range(count):
            # Perhaps needed to use unicode._set_code_point() here
            res[k] = _PyUnicode_ExtendedCase(base + k)
        return count
    res[0] = ch + rec.upper
    return 1
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L259-L272 # noqa: E501
@register_jitable
def _PyUnicode_ToFoldedFull(ch, res):
    """Write the full case-fold of `ch` into `res`; return the number of code
    points written. Falls back to the full lowercase mapping when no distinct
    fold is recorded."""
    rec = _PyUnicode_gettyperecord(ch)
    has_ext = rec.flags & _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK
    # Bits 20-22 of `lower` hold the length of a distinct fold sequence.
    if has_ext and (rec.lower >> 20) & 7:
        base = (rec.lower & 0xFFFF) + (rec.lower >> 24)
        count = (rec.lower >> 20) & 7
        for k in range(count):
            res[k] = _PyUnicode_ExtendedCase(base + k)
        return count
    return _PyUnicode_ToLowerFull(ch, res)
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L274-L279 # noqa: E501
@register_jitable
def _PyUnicode_IsCased(ch):
    """Return whether the code point carries the CASED flag."""
    rec = _PyUnicode_gettyperecord(ch)
    return (rec.flags & _PyUnicode_TyperecordMasks.CASED_MASK) != 0

# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L281-L286 # noqa: E501
@register_jitable
def _PyUnicode_IsCaseIgnorable(ch):
    """Return whether the code point carries the CASE_IGNORABLE flag."""
    rec = _PyUnicode_gettyperecord(ch)
    return (rec.flags & _PyUnicode_TyperecordMasks.CASE_IGNORABLE_MASK) != 0

# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L123-L135 # noqa: E501
@register_jitable
def _PyUnicode_IsDigit(ch):
    """Return 1 if the code point has a digit value, else 0."""
    return 1 if _PyUnicode_ToDigit(ch) >= 0 else 0

# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L106-L118 # noqa: E501
@register_jitable
def _PyUnicode_IsDecimalDigit(ch):
    """Return 1 if the code point has a decimal digit value, else 0."""
    return 1 if _PyUnicode_ToDecimalDigit(ch) >= 0 else 0

# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L291-L296 # noqa: E501
@register_jitable
def _PyUnicode_IsSpace(ch):
    """Return whether the code point carries the SPACE flag."""
    rec = _PyUnicode_gettyperecord(ch)
    return (rec.flags & _PyUnicode_TyperecordMasks.SPACE_MASK) != 0

@register_jitable
def _PyUnicode_IsAlpha(ch):
    """Return whether the code point carries the ALPHA flag."""
    rec = _PyUnicode_gettyperecord(ch)
    return (rec.flags & _PyUnicode_TyperecordMasks.ALPHA_MASK) != 0
# End code related to/from CPython's unicodectype impl
# ------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Start code related to/from CPython's pyctype
# From the definition in CPython's Include/pyctype.h
# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L5-L11 # noqa: E501
class _PY_CTF(IntEnum):
    """Character-type flag bits, mirroring CPython's Include/pyctype.h."""
    LOWER = 0x01
    UPPER = 0x02
    ALPHA = 0x01 | 0x02  # LOWER | UPPER
    DIGIT = 0x04
    ALNUM = 0x01 | 0x02 | 0x04  # LOWER | UPPER | DIGIT
    SPACE = 0x08
    XDIGIT = 0x10
# From the definition in CPython's Python/pyctype.c
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Python/pyctype.c#L5 # noqa: E501
#
# Built programmatically, but element-for-element identical to CPython's
# 256-entry _Py_ctype_table literal: a per-byte bitmask of _PY_CTF flags
# covering the ASCII range; every entry >= 0x7f is zero.
_Py_ctype_table = np.zeros(256, dtype=np.intc)
_Py_ctype_table[0x09:0x0e] = _PY_CTF.SPACE  # '\t', '\n', '\v', '\f', '\r'
_Py_ctype_table[0x20] = _PY_CTF.SPACE  # ' '
_Py_ctype_table[0x30:0x3a] = _PY_CTF.DIGIT | _PY_CTF.XDIGIT  # '0'-'9'
_Py_ctype_table[0x41:0x5b] = _PY_CTF.UPPER  # 'A'-'Z'
_Py_ctype_table[0x41:0x47] |= _PY_CTF.XDIGIT  # 'A'-'F' are hex digits too
_Py_ctype_table[0x61:0x7b] = _PY_CTF.LOWER  # 'a'-'z'
_Py_ctype_table[0x61:0x67] |= _PY_CTF.XDIGIT  # 'a'-'f' are hex digits too
# From the definition in CPython's Python/pyctype.c
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Python/pyctype.c#L145 # noqa: E501
# Identity mapping for every byte except 'A'-'Z' (0x41-0x5a), which map to
# 'a'-'z' (0x61-0x7a).
_Py_ctype_tolower = np.array([
    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
    0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
    0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
    0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
    0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
    0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
    0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
    0x40, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
    0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
    0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77,
    0x78, 0x79, 0x7a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f,
    0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
    0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
    0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77,
    0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
    0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87,
    0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f,
    0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97,
    0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f,
    0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7,
    0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
    0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7,
    0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
    0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7,
    0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
    0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7,
    0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
    0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7,
    0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
    0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7,
    0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff,
], dtype=np.uint8)
# From the definition in CPython's Python/pyctype.c
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Python/pyctype.c#L180
# Identity mapping for every byte except 'a'-'z' (0x61-0x7a), which map to
# 'A'-'Z' (0x41-0x5a).
_Py_ctype_toupper = np.array([
    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
    0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
    0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
    0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
    0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
    0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
    0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
    0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47,
    0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f,
    0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57,
    0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f,
    0x60, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47,
    0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f,
    0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57,
    0x58, 0x59, 0x5a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
    0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87,
    0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f,
    0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97,
    0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f,
    0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7,
    0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
    0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7,
    0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
    0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7,
    0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
    0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7,
    0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
    0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7,
    0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
    0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7,
    0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff,
], dtype=np.uint8)
class _PY_CTF_LB(IntEnum):
    """Flag bits used in the `_Py_ctype_islinebreak` lookup table below."""
    LINE_BREAK = 0x01
    LINE_FEED = 0x02
    CARRIAGE_RETURN = 0x04
# Built programmatically, but element-for-element identical to the original
# 256-entry literal: flags for each byte that acts as an ASCII/Latin-1 line
# break ('\n', '\v', '\f', '\r', 0x1c-0x1e file/group/record separators and
# 0x85 NEL); '\n' and '\r' also carry their dedicated flag bits.
_Py_ctype_islinebreak = np.zeros(256, dtype=np.intc)
_Py_ctype_islinebreak[0x0a] = _PY_CTF_LB.LINE_BREAK | _PY_CTF_LB.LINE_FEED
_Py_ctype_islinebreak[0x0b] = _PY_CTF_LB.LINE_BREAK
_Py_ctype_islinebreak[0x0c] = _PY_CTF_LB.LINE_BREAK
_Py_ctype_islinebreak[0x0d] = (_PY_CTF_LB.LINE_BREAK
                               | _PY_CTF_LB.CARRIAGE_RETURN)
_Py_ctype_islinebreak[0x1c:0x1f] = _PY_CTF_LB.LINE_BREAK
_Py_ctype_islinebreak[0x85] = _PY_CTF_LB.LINE_BREAK
# Translation of:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pymacro.h#L25 # noqa: E501
@register_jitable
def _Py_CHARMASK(ch):
    """
    Equivalent to the CPython macro `Py_CHARMASK()`: masks `ch` down to its
    lowest 8 bits, i.e. a value in the range 0-255.
    """
    return types.uint8(ch) & types.uint8(0xff)
# Translation of:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L30 # noqa: E501
@register_jitable
def _Py_TOUPPER(ch):
    """
    Equivalent to the CPython macro `Py_TOUPPER()` converts an ASCII range
    code point to the upper equivalent
    """
    idx = _Py_CHARMASK(ch)
    return _Py_ctype_toupper[idx]

# Translation of:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L29 # noqa: E501
@register_jitable
def _Py_TOLOWER(ch):
    """
    Equivalent to the CPython macro `Py_TOLOWER()` converts an ASCII range
    code point to the lower equivalent
    """
    idx = _Py_CHARMASK(ch)
    return _Py_ctype_tolower[idx]
# Translation of:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L18 # noqa: E501
@register_jitable
def _Py_ISLOWER(ch):
    """
    Equivalent to the CPython macro `Py_ISLOWER()`.
    Like the macro, returns the (truthy/zero) masked flag value, not a bool.
    """
    flags = _Py_ctype_table[_Py_CHARMASK(ch)]
    return flags & _PY_CTF.LOWER

# Translation of:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L19 # noqa: E501
@register_jitable
def _Py_ISUPPER(ch):
    """
    Equivalent to the CPython macro `Py_ISUPPER()`.
    """
    flags = _Py_ctype_table[_Py_CHARMASK(ch)]
    return flags & _PY_CTF.UPPER

# Translation of:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L20 # noqa: E501
@register_jitable
def _Py_ISALPHA(ch):
    """
    Equivalent to the CPython macro `Py_ISALPHA()`.
    """
    flags = _Py_ctype_table[_Py_CHARMASK(ch)]
    return flags & _PY_CTF.ALPHA

# Translation of:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L21 # noqa: E501
@register_jitable
def _Py_ISDIGIT(ch):
    """
    Equivalent to the CPython macro `Py_ISDIGIT()`.
    """
    flags = _Py_ctype_table[_Py_CHARMASK(ch)]
    return flags & _PY_CTF.DIGIT

# Translation of:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L22 # noqa: E501
@register_jitable
def _Py_ISXDIGIT(ch):
    """
    Equivalent to the CPython macro `Py_ISXDIGIT()`.
    """
    flags = _Py_ctype_table[_Py_CHARMASK(ch)]
    return flags & _PY_CTF.XDIGIT

# Translation of:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L23 # noqa: E501
@register_jitable
def _Py_ISALNUM(ch):
    """
    Equivalent to the CPython macro `Py_ISALNUM()`.
    """
    flags = _Py_ctype_table[_Py_CHARMASK(ch)]
    return flags & _PY_CTF.ALNUM

# Translation of:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L24 # noqa: E501
@register_jitable
def _Py_ISSPACE(ch):
    """
    Equivalent to the CPython macro `Py_ISSPACE()`.
    """
    flags = _Py_ctype_table[_Py_CHARMASK(ch)]
    return flags & _PY_CTF.SPACE
@register_jitable
def _Py_ISLINEBREAK(ch):
    r"""Check if character is an ASCII/Latin-1 line break"""
    lb_flags = _Py_ctype_islinebreak[_Py_CHARMASK(ch)]
    return lb_flags & _PY_CTF_LB.LINE_BREAK

@register_jitable
def _Py_ISLINEFEED(ch):
    r"""Check if character is line feed `\n`"""
    lb_flags = _Py_ctype_islinebreak[_Py_CHARMASK(ch)]
    return lb_flags & _PY_CTF_LB.LINE_FEED

@register_jitable
def _Py_ISCARRIAGERETURN(ch):
    r"""Check if character is carriage return `\r`"""
    lb_flags = _Py_ctype_islinebreak[_Py_CHARMASK(ch)]
    return lb_flags & _PY_CTF_LB.CARRIAGE_RETURN
# End code related to/from CPython's pyctype
# ------------------------------------------------------------------------------

View File

@@ -0,0 +1,53 @@
""" This module provides the unsafe things for targets/numbers.py
"""
from numba.core import types, errors
from numba.core.extending import intrinsic
from llvmlite import ir
@intrinsic
def viewer(tyctx, val, viewty):
    """ Bitcast a scalar 'val' to the given type 'viewty'. """
    # NOTE(review): assumes `val` is an Integer or Float scalar whose bitwidth
    # matches `viewty.dtype` (32 -> float, otherwise double) — confirm callers
    # guarantee this.
    bits = val.bitwidth
    if isinstance(viewty.dtype, types.Integer):
        bitcastty = ir.IntType(bits)
    elif isinstance(viewty.dtype, types.Float):
        bitcastty = ir.FloatType() if bits == 32 else ir.DoubleType()
    else:
        assert 0, "unreachable"

    def codegen(cgctx, builder, typ, args):
        flt = args[0]
        # Reinterpret the raw bits in-place; no value conversion occurs.
        return builder.bitcast(flt, bitcastty)

    retty = viewty.dtype
    sig = retty(val, viewty)
    return sig, codegen
@intrinsic
def trailing_zeros(typingctx, src):
    """Counts trailing zeros in the binary representation of an integer."""
    if not isinstance(src, types.Integer):
        msg = ("trailing_zeros is only defined for integers, but value passed "
               f"was '{src}'.")
        raise errors.NumbaTypeError(msg)

    def codegen(context, builder, signature, args):
        [value] = args
        # The i1 `0` operand tells LLVM that a zero input has a defined result
        # (the bitwidth) — see the `llvm.cttz` intrinsic documentation.
        zero_is_defined = ir.Constant(ir.IntType(1), 0)
        return builder.cttz(value, zero_is_defined)

    return src(src), codegen
@intrinsic
def leading_zeros(typingctx, src):
    """Counts leading zeros in the binary representation of an integer."""
    if not isinstance(src, types.Integer):
        msg = ("leading_zeros is only defined for integers, but value passed "
               f"was '{src}'.")
        raise errors.NumbaTypeError(msg)

    def codegen(context, builder, signature, args):
        [value] = args
        # The i1 `0` operand tells LLVM that a zero input has a defined result
        # (the bitwidth) — see the `llvm.ctlz` intrinsic documentation.
        zero_is_defined = ir.Constant(ir.IntType(1), 0)
        return builder.ctlz(value, zero_is_defined)

    return src(src), codegen

View File

@@ -0,0 +1,84 @@
"""
This file provides internal compiler utilities that support certain special
operations with tuple and workarounds for limitations enforced in userland.
"""
from numba.core import types, typing, errors
from numba.core.cgutils import alloca_once
from numba.core.extending import intrinsic
@intrinsic
def tuple_setitem(typingctx, tup, idx, val):
    """Return a copy of the tuple with item at *idx* replaced with *val*.

    Operation: ``out = tup[:idx] + (val,) + tup[idx + 1:]``

    **Warning**

    - No boundchecking.
    - The dtype of the tuple cannot be changed.
      *val* is always cast to the existing dtype of the tuple.
    """
    def codegen(context, builder, signature, args):
        tup, idx, val = args
        # Spill the tuple to a stack slot so a single element can be stored.
        stack = alloca_once(builder, tup.type)
        builder.store(tup, stack)
        # Unsafe load on unchecked bounds. Poison value maybe returned.
        offptr = builder.gep(stack, [idx.type(0), idx], inbounds=True)
        builder.store(val, offptr)
        return builder.load(stack)

    # `tup.dtype` in the signature is what casts *val* to the tuple's dtype.
    sig = tup(tup, idx, tup.dtype)
    return sig, codegen
@intrinsic
def build_full_slice_tuple(tyctx, sz):
    """Creates a sz-tuple of full slices (i.e. ``(slice(None),) * sz``).

    Raises RequireLiteralValue unless `sz` is a compile-time integer literal.
    """
    if not isinstance(sz, types.IntegerLiteral):
        raise errors.RequireLiteralValue(sz)
    size = int(sz.literal_value)
    tuple_type = types.UniTuple(dtype=types.slice2_type, count=size)
    sig = tuple_type(sz)

    def codegen(context, builder, signature, args):
        # Compiled helper: fill an undef tuple with full slices, one by one.
        def impl(length, empty_tuple):
            out = empty_tuple
            for i in range(length):
                out = tuple_setitem(out, i, slice(None, None))
            return out

        inner_argtypes = [types.intp, tuple_type]
        inner_sig = typing.signature(tuple_type, *inner_argtypes)
        ll_idx_type = context.get_value_type(types.intp)
        # Allocate an empty tuple
        empty_tuple = context.get_constant_undef(tuple_type)
        inner_args = [ll_idx_type(size), empty_tuple]
        res = context.compile_internal(builder, impl, inner_sig, inner_args)
        return res

    return sig, codegen
@intrinsic
def unpack_single_tuple(tyctx, tup):
    """This exists to handle the situation y = (*x,), the interpreter injects a
    call to it in the case of a single value unpack. It's not possible at
    interpreting time to differentiate between an unpack on a variable sized
    container e.g. list and a fixed one, e.g. tuple. This function handles the
    situation should it arise.

    Raises UnsupportedError for any non-tuple argument.
    """
    # See issue #6534
    if not isinstance(tup, types.BaseTuple):
        msg = (f"Only tuples are supported when unpacking a single item, "
               f"got type: {tup}")
        raise errors.UnsupportedError(msg)
    sig = tup(tup)

    def codegen(context, builder, signature, args):
        return args[0]  # there's only one tuple and it's a simple pass through

    return sig, codegen