This commit is contained in:
2026-04-10 15:06:59 +02:00
parent 3031b7153b
commit e5a4711004
7806 changed files with 1918528 additions and 335 deletions

View File

@@ -0,0 +1,3 @@
from numba.experimental.jitclass.decorators import jitclass
from numba.experimental.jitclass import boxing # Has import-time side effect
from numba.experimental.jitclass import overloads # Has import-time side effect

View File

@@ -0,0 +1,598 @@
import inspect
import operator
import types as pytypes
import typing as pt
from collections import OrderedDict
from collections.abc import Sequence
from llvmlite import ir as llvmir
from numba import njit
from numba.core import cgutils, errors, imputils, types, utils
from numba.core.datamodel import default_manager, models
from numba.core.registry import cpu_target
from numba.core.typing import templates
from numba.core.typing.asnumbatype import as_numba_type
from numba.core.serialize import disable_pickling
from numba.experimental.jitclass import _box
##############################################################################
# Data model
class InstanceModel(models.StructModel):
def __init__(self, dmm, fe_typ):
cls_data_ty = types.ClassDataType(fe_typ)
# MemInfoPointer uses the `dtype` attribute to traverse for nested
# NRT MemInfo. Since we handle nested NRT MemInfo ourselves,
# we will replace provide MemInfoPointer with an opaque type
# so that it does not raise exception for nested meminfo.
dtype = types.Opaque('Opaque.' + str(cls_data_ty))
members = [
('meminfo', types.MemInfoPointer(dtype)),
('data', types.CPointer(cls_data_ty)),
]
super(InstanceModel, self).__init__(dmm, fe_typ, members)
class InstanceDataModel(models.StructModel):
def __init__(self, dmm, fe_typ):
clsty = fe_typ.class_type
members = [(_mangle_attr(k), v) for k, v in clsty.struct.items()]
super(InstanceDataModel, self).__init__(dmm, fe_typ, members)
default_manager.register(types.ClassInstanceType, InstanceModel)
default_manager.register(types.ClassDataType, InstanceDataModel)
default_manager.register(types.ClassType, models.OpaqueModel)
def _mangle_attr(name):
"""
Mangle attributes.
The resulting name does not startswith an underscore '_'.
"""
return 'm_' + name
##############################################################################
# Class object
_ctor_template = """
def ctor({args}):
return __numba_cls_({args})
"""
def _getargs(fn_sig):
"""
Returns list of positional and keyword argument names in order.
"""
params = fn_sig.parameters
args = []
for k, v in params.items():
if (v.kind & v.POSITIONAL_OR_KEYWORD) == v.POSITIONAL_OR_KEYWORD:
args.append(k)
else:
msg = "%s argument type unsupported in jitclass" % v.kind
raise errors.UnsupportedError(msg)
return args
@disable_pickling
class JitClassType(type):
"""
The type of any jitclass.
"""
def __new__(cls, name, bases, dct):
if len(bases) != 1:
raise TypeError("must have exactly one base class")
[base] = bases
if isinstance(base, JitClassType):
raise TypeError("cannot subclass from a jitclass")
assert 'class_type' in dct, 'missing "class_type" attr'
outcls = type.__new__(cls, name, bases, dct)
outcls._set_init()
return outcls
def _set_init(cls):
"""
Generate a wrapper for calling the constructor from pure Python.
Note the wrapper will only accept positional arguments.
"""
init = cls.class_type.instance_type.methods['__init__']
init_sig = utils.pysignature(init)
# get postitional and keyword arguments
# offset by one to exclude the `self` arg
args = _getargs(init_sig)[1:]
cls._ctor_sig = init_sig
ctor_source = _ctor_template.format(args=', '.join(args))
glbls = {"__numba_cls_": cls}
exec(ctor_source, glbls)
ctor = glbls['ctor']
cls._ctor = njit(ctor)
def __instancecheck__(cls, instance):
if isinstance(instance, _box.Box):
return instance._numba_type_.class_type is cls.class_type
return False
def __call__(cls, *args, **kwargs):
# The first argument of _ctor_sig is `cls`, which here
# is bound to None and then skipped when invoking the constructor.
bind = cls._ctor_sig.bind(None, *args, **kwargs)
bind.apply_defaults()
return cls._ctor(*bind.args[1:], **bind.kwargs)
##############################################################################
# Registration utils
def _validate_spec(spec):
for k, v in spec.items():
if not isinstance(k, str):
raise TypeError("spec keys should be strings, got %r" % (k,))
if not isinstance(v, types.Type):
raise TypeError("spec values should be Numba type instances, got %r"
% (v,))
def _fix_up_private_attr(clsname, spec):
"""
Apply the same changes to dunder names as CPython would.
"""
out = OrderedDict()
for k, v in spec.items():
if k.startswith('__') and not k.endswith('__'):
k = '_' + clsname + k
out[k] = v
return out
def _add_linking_libs(context, call):
"""
Add the required libs for the callable to allow inlining.
"""
libs = getattr(call, "libs", ())
if libs:
context.add_linking_libs(libs)
def register_class_type(cls, spec, class_ctor, builder):
"""
Internal function to create a jitclass.
Args
----
cls: the original class object (used as the prototype)
spec: the structural specification contains the field types.
class_ctor: the numba type to represent the jitclass
builder: the internal jitclass builder
"""
# Normalize spec
if spec is None:
spec = OrderedDict()
elif isinstance(spec, Sequence):
spec = OrderedDict(spec)
# Extend spec with class annotations.
for attr, py_type in pt.get_type_hints(cls).items():
if attr not in spec:
spec[attr] = as_numba_type(py_type)
_validate_spec(spec)
# Fix up private attribute names
spec = _fix_up_private_attr(cls.__name__, spec)
# Copy methods from base classes
clsdct = {}
for basecls in reversed(inspect.getmro(cls)):
clsdct.update(basecls.__dict__)
methods, props, static_methods, others = {}, {}, {}, {}
for k, v in clsdct.items():
if isinstance(v, pytypes.FunctionType):
methods[k] = v
elif isinstance(v, property):
props[k] = v
elif isinstance(v, staticmethod):
static_methods[k] = v
else:
others[k] = v
# Check for name shadowing
shadowed = (set(methods) | set(props) | set(static_methods)) & set(spec)
if shadowed:
raise NameError("name shadowing: {0}".format(', '.join(shadowed)))
docstring = others.pop('__doc__', "")
_drop_ignored_attrs(others)
if others:
msg = "class members are not yet supported: {0}"
members = ', '.join(others.keys())
raise TypeError(msg.format(members))
for k, v in props.items():
if v.fdel is not None:
raise TypeError("deleter is not supported: {0}".format(k))
jit_methods = {k: njit(v) for k, v in methods.items()}
jit_props = {}
for k, v in props.items():
dct = {}
if v.fget:
dct['get'] = njit(v.fget)
if v.fset:
dct['set'] = njit(v.fset)
jit_props[k] = dct
jit_static_methods = {
k: njit(v.__func__) for k, v in static_methods.items()}
# Instantiate class type
class_type = class_ctor(
cls,
ConstructorTemplate,
spec,
jit_methods,
jit_props,
jit_static_methods)
jit_class_dct = dict(class_type=class_type, __doc__=docstring)
jit_class_dct.update(jit_static_methods)
cls = JitClassType(cls.__name__, (cls,), jit_class_dct)
# Register resolution of the class object
typingctx = cpu_target.typing_context
typingctx.insert_global(cls, class_type)
# Register class
targetctx = cpu_target.target_context
builder(class_type, typingctx, targetctx).register()
as_numba_type.register(cls, class_type.instance_type)
return cls
class ConstructorTemplate(templates.AbstractTemplate):
"""
Base class for jitclass constructor templates.
"""
def generic(self, args, kws):
# Redirect resolution to __init__
instance_type = self.key.instance_type
ctor = instance_type.jit_methods['__init__']
boundargs = (instance_type.get_reference_type(),) + args
disp_type = types.Dispatcher(ctor)
sig = disp_type.get_call_type(self.context, boundargs, kws)
if not isinstance(sig.return_type, types.NoneType):
raise errors.NumbaTypeError(
f"__init__() should return None, not '{sig.return_type}'")
# Actual constructor returns an instance value (not None)
out = templates.signature(instance_type, *sig.args[1:])
return out
def _drop_ignored_attrs(dct):
# ignore anything defined by object
drop = set(['__weakref__',
'__module__',
'__dict__'])
if utils.PYVERSION in ((3, 13), (3, 14)):
# new in python 3.13
drop |= set(['__firstlineno__', '__static_attributes__'])
for att in ('__annotations__',
'__annotate_func__',
'__annotations_cache__'):
if att in dct:
drop.add(att)
for k, v in dct.items():
if isinstance(v, (pytypes.BuiltinFunctionType,
pytypes.BuiltinMethodType)):
drop.add(k)
elif getattr(v, '__objclass__', None) is object:
drop.add(k)
# If a class defines __eq__ but not __hash__, __hash__ is implicitly set to
# None. This is a class member, and class members are not presently
# supported.
if '__hash__' in dct and dct['__hash__'] is None:
drop.add('__hash__')
for k in drop:
dct.pop(k)
class ClassBuilder(object):
"""
A jitclass builder for a mutable jitclass. This will register
typing and implementation hooks to the given typing and target contexts.
"""
class_impl_registry = imputils.Registry('jitclass builder')
implemented_methods = set()
def __init__(self, class_type, typingctx, targetctx):
self.class_type = class_type
self.typingctx = typingctx
self.targetctx = targetctx
def register(self):
"""
Register to the frontend and backend.
"""
# Register generic implementations for all jitclasses
self._register_methods(self.class_impl_registry,
self.class_type.instance_type)
# NOTE other registrations are done at the top-level
# (see ctor_impl and attr_impl below)
self.targetctx.install_registry(self.class_impl_registry)
def _register_methods(self, registry, instance_type):
"""
Register method implementations.
This simply registers that the method names are valid methods. Inside
of imp() below we retrieve the actual method to run from the type of
the receiver argument (i.e. self).
"""
to_register = list(instance_type.jit_methods) + \
list(instance_type.jit_static_methods)
for meth in to_register:
# There's no way to retrieve the particular method name
# inside the implementation function, so we have to register a
# specific closure for each different name
if meth not in self.implemented_methods:
self._implement_method(registry, meth)
self.implemented_methods.add(meth)
def _implement_method(self, registry, attr):
# create a separate instance of imp method to avoid closure clashing
def get_imp():
def imp(context, builder, sig, args):
instance_type = sig.args[0]
if attr in instance_type.jit_methods:
method = instance_type.jit_methods[attr]
elif attr in instance_type.jit_static_methods:
method = instance_type.jit_static_methods[attr]
# imp gets called as a method, where the first argument is
# self. We drop this for a static method.
sig = sig.replace(args=sig.args[1:])
args = args[1:]
disp_type = types.Dispatcher(method)
call = context.get_function(disp_type, sig)
out = call(builder, args)
_add_linking_libs(context, call)
return imputils.impl_ret_new_ref(context, builder,
sig.return_type, out)
return imp
def _getsetitem_gen(getset):
_dunder_meth = "__%s__" % getset
op = getattr(operator, getset)
@templates.infer_global(op)
class GetSetItem(templates.AbstractTemplate):
def generic(self, args, kws):
instance = args[0]
if isinstance(instance, types.ClassInstanceType) and \
_dunder_meth in instance.jit_methods:
meth = instance.jit_methods[_dunder_meth]
disp_type = types.Dispatcher(meth)
sig = disp_type.get_call_type(self.context, args, kws)
return sig
# lower both {g,s}etitem and __{g,s}etitem__ to catch the calls
# from python and numba
imputils.lower_builtin((types.ClassInstanceType, _dunder_meth),
types.ClassInstanceType,
types.VarArg(types.Any))(get_imp())
imputils.lower_builtin(op,
types.ClassInstanceType,
types.VarArg(types.Any))(get_imp())
dunder_stripped = attr.strip('_')
if dunder_stripped in ("getitem", "setitem"):
_getsetitem_gen(dunder_stripped)
else:
registry.lower((types.ClassInstanceType, attr),
types.ClassInstanceType,
types.VarArg(types.Any))(get_imp())
@templates.infer_getattr
class ClassAttribute(templates.AttributeTemplate):
key = types.ClassInstanceType
def generic_resolve(self, instance, attr):
if attr in instance.struct:
# It's a struct field => the type is well-known
return instance.struct[attr]
elif attr in instance.jit_methods:
# It's a jitted method => typeinfer it
meth = instance.jit_methods[attr]
disp_type = types.Dispatcher(meth)
class MethodTemplate(templates.AbstractTemplate):
key = (self.key, attr)
def generic(self, args, kws):
args = (instance,) + tuple(args)
sig = disp_type.get_call_type(self.context, args, kws)
return sig.as_method()
return types.BoundFunction(MethodTemplate, instance)
elif attr in instance.jit_static_methods:
# It's a jitted method => typeinfer it
meth = instance.jit_static_methods[attr]
disp_type = types.Dispatcher(meth)
class StaticMethodTemplate(templates.AbstractTemplate):
key = (self.key, attr)
def generic(self, args, kws):
# Don't add instance as the first argument for a static
# method.
sig = disp_type.get_call_type(self.context, args, kws)
# sig itself does not include ClassInstanceType as it's
# first argument, so instead of calling sig.as_method()
# we insert the recvr. This is equivalent to
# sig.replace(args=(instance,) + sig.args).as_method().
return sig.replace(recvr=instance)
return types.BoundFunction(StaticMethodTemplate, instance)
elif attr in instance.jit_props:
# It's a jitted property => typeinfer its getter
impdct = instance.jit_props[attr]
getter = impdct['get']
disp_type = types.Dispatcher(getter)
sig = disp_type.get_call_type(self.context, (instance,), {})
return sig.return_type
@ClassBuilder.class_impl_registry.lower_getattr_generic(types.ClassInstanceType)
def get_attr_impl(context, builder, typ, value, attr):
"""
Generic getattr() for @jitclass instances.
"""
if attr in typ.struct:
# It's a struct field
inst = context.make_helper(builder, typ, value=value)
data_pointer = inst.data
data = context.make_data_helper(builder, typ.get_data_type(),
ref=data_pointer)
return imputils.impl_ret_borrowed(context, builder,
typ.struct[attr],
getattr(data, _mangle_attr(attr)))
elif attr in typ.jit_props:
# It's a jitted property
getter = typ.jit_props[attr]['get']
sig = templates.signature(None, typ)
dispatcher = types.Dispatcher(getter)
sig = dispatcher.get_call_type(context.typing_context, [typ], {})
call = context.get_function(dispatcher, sig)
out = call(builder, [value])
_add_linking_libs(context, call)
return imputils.impl_ret_new_ref(context, builder, sig.return_type, out)
raise NotImplementedError('attribute {0!r} not implemented'.format(attr))
@ClassBuilder.class_impl_registry.lower_setattr_generic(types.ClassInstanceType)
def set_attr_impl(context, builder, sig, args, attr):
"""
Generic setattr() for @jitclass instances.
"""
typ, valty = sig.args
target, val = args
if attr in typ.struct:
# It's a struct member
inst = context.make_helper(builder, typ, value=target)
data_ptr = inst.data
data = context.make_data_helper(builder, typ.get_data_type(),
ref=data_ptr)
# Get old value
attr_type = typ.struct[attr]
oldvalue = getattr(data, _mangle_attr(attr))
# Store n
setattr(data, _mangle_attr(attr), val)
context.nrt.incref(builder, attr_type, val)
# Delete old value
context.nrt.decref(builder, attr_type, oldvalue)
elif attr in typ.jit_props:
# It's a jitted property
setter = typ.jit_props[attr]['set']
disp_type = types.Dispatcher(setter)
sig = disp_type.get_call_type(context.typing_context,
(typ, valty), {})
call = context.get_function(disp_type, sig)
call(builder, (target, val))
_add_linking_libs(context, call)
else:
raise NotImplementedError(
'attribute {0!r} not implemented'.format(attr))
def imp_dtor(context, module, instance_type):
llvoidptr = context.get_value_type(types.voidptr)
llsize = context.get_value_type(types.uintp)
dtor_ftype = llvmir.FunctionType(llvmir.VoidType(),
[llvoidptr, llsize, llvoidptr])
fname = "_Dtor.{0}".format(instance_type.name)
dtor_fn = cgutils.get_or_insert_function(module, dtor_ftype, fname)
if dtor_fn.is_declaration:
# Define
builder = llvmir.IRBuilder(dtor_fn.append_basic_block())
alloc_fe_type = instance_type.get_data_type()
alloc_type = context.get_value_type(alloc_fe_type)
ptr = builder.bitcast(dtor_fn.args[0], alloc_type.as_pointer())
data = context.make_helper(builder, alloc_fe_type, ref=ptr)
context.nrt.decref(builder, alloc_fe_type, data._getvalue())
builder.ret_void()
return dtor_fn
@ClassBuilder.class_impl_registry.lower(types.ClassType,
types.VarArg(types.Any))
def ctor_impl(context, builder, sig, args):
"""
Generic constructor (__new__) for jitclasses.
"""
# Allocate the instance
inst_typ = sig.return_type
alloc_type = context.get_data_type(inst_typ.get_data_type())
alloc_size = context.get_abi_sizeof(alloc_type)
meminfo = context.nrt.meminfo_alloc_dtor(
builder,
context.get_constant(types.uintp, alloc_size),
imp_dtor(context, builder.module, inst_typ),
)
data_pointer = context.nrt.meminfo_data(builder, meminfo)
data_pointer = builder.bitcast(data_pointer,
alloc_type.as_pointer())
# Nullify all data
builder.store(cgutils.get_null_value(alloc_type),
data_pointer)
inst_struct = context.make_helper(builder, inst_typ)
inst_struct.meminfo = meminfo
inst_struct.data = data_pointer
# Call the jitted __init__
# TODO: extract the following into a common util
init_sig = (sig.return_type,) + sig.args
init = inst_typ.jit_methods['__init__']
disp_type = types.Dispatcher(init)
call = context.get_function(disp_type, types.void(*init_sig))
_add_linking_libs(context, call)
realargs = [inst_struct._getvalue()] + list(args)
call(builder, realargs)
# Prepare return value
ret = inst_struct._getvalue()
return imputils.impl_ret_new_ref(context, builder, inst_typ, ret)

View File

@@ -0,0 +1,275 @@
"""
Implement logic relating to wrapping (box) and unwrapping (unbox) instances
of jitclasses for use inside the python interpreter.
"""
from functools import wraps, partial
from llvmlite import ir
from numba.core import types, cgutils
from numba.core.decorators import njit
from numba.core.pythonapi import box, unbox, NativeValue
from numba.core.typing.typeof import typeof_impl
from numba.experimental.jitclass import _box
_getter_code_template = """
def accessor(__numba_self_):
return __numba_self_.{0}
"""
_setter_code_template = """
def mutator(__numba_self_, __numba_val):
__numba_self_.{0} = __numba_val
"""
_method_code_template = """
def method(__numba_self_, *args):
return __numba_self_.{method}(*args)
"""
def _generate_property(field, template, fname):
"""
Generate simple function that get/set a field of the instance
"""
source = template.format(field)
glbls = {}
exec(source, glbls)
return njit(glbls[fname])
_generate_getter = partial(_generate_property, template=_getter_code_template,
fname='accessor')
_generate_setter = partial(_generate_property, template=_setter_code_template,
fname='mutator')
def _generate_method(name, func):
"""
Generate a wrapper for calling a method. Note the wrapper will only
accept positional arguments.
"""
source = _method_code_template.format(method=name)
glbls = {}
exec(source, glbls)
method = njit(glbls['method'])
@wraps(func)
def wrapper(*args, **kwargs):
return method(*args, **kwargs)
return wrapper
_cache_specialized_box = {}
def _specialize_box(typ):
"""
Create a subclass of Box that is specialized to the jitclass.
This function caches the result to avoid code bloat.
"""
# Check cache
if typ in _cache_specialized_box:
return _cache_specialized_box[typ]
dct = {'__slots__': (),
'_numba_type_': typ,
'__doc__': typ.class_type.class_doc,
}
# Inject attributes as class properties
for field in typ.struct:
getter = _generate_getter(field)
setter = _generate_setter(field)
dct[field] = property(getter, setter)
# Inject properties as class properties
for field, impdct in typ.jit_props.items():
getter = None
setter = None
if 'get' in impdct:
getter = _generate_getter(field)
if 'set' in impdct:
setter = _generate_setter(field)
# get docstring from either the fget or fset
imp = impdct.get('get') or impdct.get('set') or None
doc = getattr(imp, '__doc__', None)
dct[field] = property(getter, setter, doc=doc)
# Inject methods as class members
supported_dunders = {
"__abs__",
"__annotate_func__",
"__bool__",
"__complex__",
"__contains__",
"__float__",
"__getitem__",
"__hash__",
"__index__",
"__invert__",
"__int__",
"__len__",
"__setitem__",
"__str__",
"__eq__",
"__ne__",
"__ge__",
"__gt__",
"__le__",
"__lt__",
"__add__",
"__floordiv__",
"__lshift__",
"__matmul__",
"__mod__",
"__mul__",
"__neg__",
"__pos__",
"__pow__",
"__rshift__",
"__sub__",
"__truediv__",
"__and__",
"__or__",
"__xor__",
"__iadd__",
"__ifloordiv__",
"__ilshift__",
"__imatmul__",
"__imod__",
"__imul__",
"__ipow__",
"__irshift__",
"__isub__",
"__itruediv__",
"__iand__",
"__ior__",
"__ixor__",
"__radd__",
"__rfloordiv__",
"__rlshift__",
"__rmatmul__",
"__rmod__",
"__rmul__",
"__rpow__",
"__rrshift__",
"__rsub__",
"__rtruediv__",
"__rand__",
"__ror__",
"__rxor__",
}
for name, func in typ.methods.items():
if name == "__init__":
continue
if (
name.startswith("__")
and name.endswith("__")
and name not in supported_dunders
):
raise TypeError(f"Method '{name}' is not supported.")
dct[name] = _generate_method(name, func)
# Inject static methods as class members
for name, func in typ.static_methods.items():
dct[name] = _generate_method(name, func)
# Create subclass
subcls = type(typ.classname, (_box.Box,), dct)
# Store to cache
_cache_specialized_box[typ] = subcls
# Pre-compile attribute getter.
# Note: This must be done after the "box" class is created because
# compiling the getter requires the "box" class to be defined.
for k, v in dct.items():
if isinstance(v, property):
prop = getattr(subcls, k)
if prop.fget is not None:
fget = prop.fget
fast_fget = fget.compile((typ,))
fget.disable_compile()
setattr(subcls, k,
property(fast_fget, prop.fset, prop.fdel,
doc=prop.__doc__))
return subcls
###############################################################################
# Implement box/unbox for call wrapper
@box(types.ClassInstanceType)
def _box_class_instance(typ, val, c):
meminfo, dataptr = cgutils.unpack_tuple(c.builder, val)
# Create Box instance
box_subclassed = _specialize_box(typ)
# Note: the ``box_subclassed`` is kept alive by the cache
voidptr_boxcls = c.context.add_dynamic_addr(
c.builder,
id(box_subclassed),
info="box_class_instance",
)
box_cls = c.builder.bitcast(voidptr_boxcls, c.pyapi.pyobj)
box = c.pyapi.call_function_objargs(box_cls, ())
# Initialize Box instance
llvoidptr = ir.IntType(8).as_pointer()
addr_meminfo = c.builder.bitcast(meminfo, llvoidptr)
addr_data = c.builder.bitcast(dataptr, llvoidptr)
def set_member(member_offset, value):
# Access member by byte offset
offset = c.context.get_constant(types.uintp, member_offset)
ptr = cgutils.pointer_add(c.builder, box, offset)
casted = c.builder.bitcast(ptr, llvoidptr.as_pointer())
c.builder.store(value, casted)
set_member(_box.box_meminfoptr_offset, addr_meminfo)
set_member(_box.box_dataptr_offset, addr_data)
return box
@unbox(types.ClassInstanceType)
def _unbox_class_instance(typ, val, c):
def access_member(member_offset):
# Access member by byte offset
offset = c.context.get_constant(types.uintp, member_offset)
llvoidptr = ir.IntType(8).as_pointer()
ptr = cgutils.pointer_add(c.builder, val, offset)
casted = c.builder.bitcast(ptr, llvoidptr.as_pointer())
return c.builder.load(casted)
struct_cls = cgutils.create_struct_proxy(typ)
inst = struct_cls(c.context, c.builder)
# load from Python object
ptr_meminfo = access_member(_box.box_meminfoptr_offset)
ptr_dataptr = access_member(_box.box_dataptr_offset)
# store to native structure
inst.meminfo = c.builder.bitcast(ptr_meminfo, inst.meminfo.type)
inst.data = c.builder.bitcast(ptr_dataptr, inst.data.type)
ret = inst._getvalue()
c.context.nrt.incref(c.builder, typ, ret)
return NativeValue(ret, is_error=c.pyapi.c_api_error())
# Add a typeof_impl implementation for boxed jitclasses to short-circut the
# various tests in typeof. This is needed for jitclasses which implement a
# custom hash method. Without this, typeof_impl will return None, and one of the
# later attempts to determine the type of the jitclass (before checking for
# _numba_type_) will look up the object in a dictionary, triggering the hash
# method. This will cause the dispatcher to determine the call signature of the
# jit decorated obj.__hash__ method, which will call typeof(obj), and thus
# infinite loop.
# This implementation is here instead of in typeof.py to avoid circular imports.
@typeof_impl.register(_box.Box)
def _typeof_jitclass_box(val, c):
return getattr(type(val), "_numba_type_")

View File

@@ -0,0 +1,88 @@
from numba.core import types, config
def jitclass(cls_or_spec=None, spec=None):
"""
A function for creating a jitclass.
Can be used as a decorator or function.
Different use cases will cause different arguments to be set.
If specified, ``spec`` gives the types of class fields.
It must be a dictionary or sequence.
With a dictionary, use collections.OrderedDict for stable ordering.
With a sequence, it must contain 2-tuples of (fieldname, fieldtype).
Any class annotations for field names not listed in spec will be added.
For class annotation `x: T` we will append ``("x", as_numba_type(T))`` to
the spec if ``x`` is not already a key in spec.
Examples
--------
1) ``cls_or_spec = None``, ``spec = None``
>>> @jitclass()
... class Foo:
... ...
2) ``cls_or_spec = None``, ``spec = spec``
>>> @jitclass(spec=spec)
... class Foo:
... ...
3) ``cls_or_spec = Foo``, ``spec = None``
>>> @jitclass
... class Foo:
... ...
4) ``cls_or_spec = spec``, ``spec = None``
In this case we update ``cls_or_spec, spec = None, cls_or_spec``.
>>> @jitclass(spec)
... class Foo:
... ...
5) ``cls_or_spec = Foo``, ``spec = spec``
>>> JitFoo = jitclass(Foo, spec)
Returns
-------
If used as a decorator, returns a callable that takes a class object and
returns a compiled version.
If used as a function, returns the compiled class (an instance of
``JitClassType``).
"""
if (cls_or_spec is not None and
spec is None and
not isinstance(cls_or_spec, type)):
# Used like
# @jitclass([("x", intp)])
# class Foo:
# ...
spec = cls_or_spec
cls_or_spec = None
def wrap(cls):
if config.DISABLE_JIT:
return cls
else:
from numba.experimental.jitclass.base import (register_class_type,
ClassBuilder)
cls_jitted = register_class_type(cls, spec, types.ClassType,
ClassBuilder)
# Preserve the module name of the original class
cls_jitted.__module__ = cls.__module__
return cls_jitted
if cls_or_spec is None:
return wrap
else:
return wrap(cls_or_spec)

View File

@@ -0,0 +1,239 @@
"""
Overloads for ClassInstanceType for built-in functions that call dunder methods
on an object.
"""
from functools import wraps
import inspect
import operator
from numba.core.extending import overload
from numba.core.types import ClassInstanceType
def _get_args(n_args):
assert n_args in (1, 2)
return list("xy")[:n_args]
def class_instance_overload(target):
"""
Decorator to add an overload for target that applies when the first argument
is a ClassInstanceType.
"""
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
if not isinstance(args[0], ClassInstanceType):
return
return func(*args, **kwargs)
if target is not complex:
# complex ctor needs special treatment as it uses kwargs
params = list(inspect.signature(wrapped).parameters)
assert params == _get_args(len(params))
return overload(target)(wrapped)
return decorator
def extract_template(template, name):
"""
Extract a code-generated function from a string template.
"""
namespace = {}
exec(template, namespace)
return namespace[name]
def register_simple_overload(func, *attrs, n_args=1,):
"""
Register an overload for func that checks for methods __attr__ for each
attr in attrs.
"""
# Use a template to set the signature correctly.
arg_names = _get_args(n_args)
template = f"""
def func({','.join(arg_names)}):
pass
"""
@wraps(extract_template(template, "func"))
def overload_func(*args, **kwargs):
options = [
try_call_method(args[0], f"__{attr}__", n_args)
for attr in attrs
]
return take_first(*options)
return class_instance_overload(func)(overload_func)
def try_call_method(cls_type, method, n_args=1):
"""
If method is defined for cls_type, return a callable that calls this method.
If not, return None.
"""
if method in cls_type.jit_methods:
arg_names = _get_args(n_args)
template = f"""
def func({','.join(arg_names)}):
return {arg_names[0]}.{method}({','.join(arg_names[1:])})
"""
return extract_template(template, "func")
def try_call_complex_method(cls_type, method):
""" __complex__ needs special treatment as the argument names are kwargs
and therefore specific in name and default value.
"""
if method in cls_type.jit_methods:
template = f"""
def func(real=0, imag=0):
return real.{method}()
"""
return extract_template(template, "func")
def take_first(*options):
"""
Take the first non-None option.
"""
assert all(o is None or inspect.isfunction(o) for o in options), options
for o in options:
if o is not None:
return o
@class_instance_overload(bool)
def class_bool(x):
using_bool_impl = try_call_method(x, "__bool__")
if '__len__' in x.jit_methods:
def using_len_impl(x):
return bool(len(x))
else:
using_len_impl = None
always_true_impl = lambda x: True
return take_first(using_bool_impl, using_len_impl, always_true_impl)
@class_instance_overload(complex)
def class_complex(real=0, imag=0):
return take_first(
try_call_complex_method(real, "__complex__"),
lambda real=0, imag=0: complex(float(real))
)
@class_instance_overload(operator.contains)
def class_contains(x, y):
# https://docs.python.org/3/reference/expressions.html#membership-test-operations
return try_call_method(x, "__contains__", 2)
# TODO: use __iter__ if defined.
@class_instance_overload(float)
def class_float(x):
options = [try_call_method(x, "__float__")]
if (
"__index__" in x.jit_methods
):
options.append(lambda x: float(x.__index__()))
return take_first(*options)
@class_instance_overload(int)
def class_int(x):
options = [try_call_method(x, "__int__")]
options.append(try_call_method(x, "__index__"))
return take_first(*options)
@class_instance_overload(str)
def class_str(x):
return take_first(
try_call_method(x, "__str__"),
lambda x: repr(x),
)
@class_instance_overload(operator.ne)
def class_ne(x, y):
# This doesn't use register_reflected_overload like the other operators
# because it falls back to inverting __eq__ rather than reflecting its
# arguments (as per the definition of the Python data model).
return take_first(
try_call_method(x, "__ne__", 2),
lambda x, y: not (x == y),
)
def register_reflected_overload(func, meth_forward, meth_reflected):
def class_lt(x, y):
normal_impl = try_call_method(x, f"__{meth_forward}__", 2)
if f"__{meth_reflected}__" in y.jit_methods:
def reflected_impl(x, y):
return y > x
else:
reflected_impl = None
return take_first(normal_impl, reflected_impl)
class_instance_overload(func)(class_lt)
register_simple_overload(abs, "abs")
register_simple_overload(len, "len")
register_simple_overload(hash, "hash")
# Comparison operators.
register_reflected_overload(operator.ge, "ge", "le")
register_reflected_overload(operator.gt, "gt", "lt")
register_reflected_overload(operator.le, "le", "ge")
register_reflected_overload(operator.lt, "lt", "gt")
# Note that eq is missing support for fallback to `x is y`, but `is` and
# `operator.is` are presently unsupported in general.
register_reflected_overload(operator.eq, "eq", "eq")
# Arithmetic operators.
register_simple_overload(operator.add, "add", n_args=2)
register_simple_overload(operator.floordiv, "floordiv", n_args=2)
register_simple_overload(operator.lshift, "lshift", n_args=2)
register_simple_overload(operator.mul, "mul", n_args=2)
register_simple_overload(operator.mod, "mod", n_args=2)
register_simple_overload(operator.neg, "neg")
register_simple_overload(operator.pos, "pos")
register_simple_overload(operator.invert, "invert")
register_simple_overload(operator.pow, "pow", n_args=2)
register_simple_overload(operator.rshift, "rshift", n_args=2)
register_simple_overload(operator.sub, "sub", n_args=2)
register_simple_overload(operator.truediv, "truediv", n_args=2)
# Inplace arithmetic operators.
register_simple_overload(operator.iadd, "iadd", "add", n_args=2)
register_simple_overload(operator.ifloordiv, "ifloordiv", "floordiv", n_args=2)
register_simple_overload(operator.ilshift, "ilshift", "lshift", n_args=2)
register_simple_overload(operator.imul, "imul", "mul", n_args=2)
register_simple_overload(operator.imod, "imod", "mod", n_args=2)
register_simple_overload(operator.ipow, "ipow", "pow", n_args=2)
register_simple_overload(operator.irshift, "irshift", "rshift", n_args=2)
register_simple_overload(operator.isub, "isub", "sub", n_args=2)
register_simple_overload(operator.itruediv, "itruediv", "truediv", n_args=2)
# Logical operators.
register_simple_overload(operator.and_, "and", n_args=2)
register_simple_overload(operator.or_, "or", n_args=2)
register_simple_overload(operator.xor, "xor", n_args=2)
# Inplace logical operators.
register_simple_overload(operator.iand, "iand", "and", n_args=2)
register_simple_overload(operator.ior, "ior", "or", n_args=2)
register_simple_overload(operator.ixor, "ixor", "xor", n_args=2)