This commit is contained in:
2026-04-10 15:06:59 +02:00
parent 3031b7153b
commit e5a4711004
7806 changed files with 1918528 additions and 335 deletions

View File

@@ -0,0 +1 @@
from .jitclass import jitclass

View File

@@ -0,0 +1,332 @@
"""Provides Numba type, FunctionType, that makes functions as
instances of a first-class function type.
"""
from functools import partial
from numba.extending import typeof_impl
from numba.extending import models, register_model
from numba.extending import unbox, NativeValue, box
from numba.core.imputils import lower_constant, lower_cast
from numba.core.ccallback import CFunc
from numba.core import cgutils
from llvmlite import ir
from numba.core import types, errors
from numba.core.types import (FunctionType, UndefinedFunctionType,
FunctionPrototype, WrapperAddressProtocol)
from numba.core.dispatcher import Dispatcher
@typeof_impl.register(WrapperAddressProtocol)
@typeof_impl.register(CFunc)
def typeof_function_type(val, c):
if isinstance(val, CFunc):
sig = val._sig
elif isinstance(val, WrapperAddressProtocol):
sig = val.signature()
else:
raise NotImplementedError(
f'function type from {type(val).__name__}')
return FunctionType(sig)
@register_model(FunctionPrototype)
class FunctionProtoModel(models.PrimitiveModel):
"""FunctionProtoModel describes the signatures of first-class functions
"""
def __init__(self, dmm, fe_type):
if isinstance(fe_type, FunctionType):
ftype = fe_type.ftype
elif isinstance(fe_type, FunctionPrototype):
ftype = fe_type
else:
raise NotImplementedError((type(fe_type)))
retty = dmm.lookup(ftype.rtype).get_value_type()
args = [dmm.lookup(t).get_value_type() for t in ftype.atypes]
be_type = ir.PointerType(ir.FunctionType(retty, args))
super(FunctionProtoModel, self).__init__(dmm, fe_type, be_type)
@register_model(FunctionType)
@register_model(UndefinedFunctionType)
class FunctionModel(models.StructModel):
"""FunctionModel holds addresses of function implementations
"""
def __init__(self, dmm, fe_type):
members = [
# Address of cfunc wrapper function.
# This uses a C callconv and doesn't not support exceptions.
('c_addr', types.voidptr),
# Address of PyObject* referencing the Python function
# object:
('py_addr', types.voidptr),
# Address of the underlying function object.
# Calling through this function pointer supports all features of
# regular numba function as it follows the same Numba callconv.
('jit_addr', types.voidptr),
]
super(FunctionModel, self).__init__(dmm, fe_type, members)
@lower_constant(types.Dispatcher)
def lower_constant_dispatcher(context, builder, typ, pyval):
return context.add_dynamic_addr(builder, id(pyval),
info=type(pyval).__name__)
@lower_constant(FunctionType)
def lower_constant_function_type(context, builder, typ, pyval):
typ = typ.get_precise()
if isinstance(pyval, CFunc):
addr = pyval._wrapper_address
sfunc = cgutils.create_struct_proxy(typ)(context, builder)
sfunc.c_addr = context.add_dynamic_addr(builder, addr,
info=str(typ))
sfunc.py_addr = context.add_dynamic_addr(builder, id(pyval),
info=type(pyval).__name__)
return sfunc._getvalue()
if isinstance(pyval, Dispatcher):
sfunc = cgutils.create_struct_proxy(typ)(context, builder)
sfunc.py_addr = context.add_dynamic_addr(builder, id(pyval),
info=type(pyval).__name__)
return sfunc._getvalue()
if isinstance(pyval, WrapperAddressProtocol):
addr = pyval.__wrapper_address__()
assert typ.check_signature(pyval.signature())
sfunc = cgutils.create_struct_proxy(typ)(context, builder)
sfunc.c_addr = context.add_dynamic_addr(builder, addr,
info=str(typ))
sfunc.py_addr = context.add_dynamic_addr(builder, id(pyval),
info=type(pyval).__name__)
return sfunc._getvalue()
# TODO: implement support for pytypes.FunctionType, ctypes.CFUNCTYPE
raise NotImplementedError(
'lower_constant_struct_function_type({}, {}, {}, {})'
.format(context, builder, typ, pyval))
def _get_wrapper_address(func, sig):
"""Return the address of a compiled cfunc wrapper function of `func`.
Warning: The compiled function must be compatible with the given
signature `sig`. If it is not, then result of calling the compiled
function is undefined. The compatibility is ensured when passing
in a first-class function to a Numba njit compiled function either
as an argument or via namespace scoping.
Parameters
----------
func : object
A Numba cfunc or jit decoreated function or an object that
implements the wrapper address protocol (see note below).
sig : Signature
The expected function signature.
Returns
-------
addr : int
An address in memory (pointer value) of the compiled function
corresponding to the specified signature.
Note: wrapper address protocol
------------------------------
An object implements the wrapper address protocol iff the object
provides a callable attribute named __wrapper_address__ that takes
a Signature instance as the argument, and returns an integer
representing the address or pointer value of a compiled function
for the given signature.
"""
if not sig.is_precise():
# addr==-1 will indicate that no implementation is available
# for cases where type-inference did not identified the
# function type. For example, the type of an unused
# jit-decorated function argument will be undefined but also
# irrelevant.
addr = -1
elif hasattr(func, '__wrapper_address__'):
# func can be any object that implements the
# __wrapper_address__ protocol.
addr = func.__wrapper_address__()
elif isinstance(func, CFunc):
assert sig == func._sig
addr = func.address
elif isinstance(func, Dispatcher):
cres = func.get_compile_result(sig)
wrapper_name = cres.fndesc.llvm_cfunc_wrapper_name
addr = cres.library.get_pointer_to_function(wrapper_name)
else:
raise NotImplementedError(
f'get wrapper address of {type(func)} instance with {sig!r}')
if not isinstance(addr, int):
raise TypeError(
f'wrapper address must be integer, got {type(addr)} instance')
if addr <= 0 and addr != -1:
raise ValueError(f'wrapper address of {type(func)} instance must be'
f' a positive integer but got {addr} [sig={sig}]')
# print(f'_get_wrapper_address[{func}]({sig=}) -> {addr}')
return addr
def _get_jit_address(func, sig):
"""Similar to ``_get_wrapper_address()`` but get the `.jit_addr` instead.
"""
if isinstance(func, Dispatcher):
cres = func.get_compile_result(sig)
jit_name = cres.fndesc.llvm_func_name
addr = cres.library.get_pointer_to_function(jit_name)
else:
addr = 0
if not isinstance(addr, int):
raise TypeError(
f'jit address must be integer, got {type(addr)} instance')
return addr
def _lower_get_address(context, builder, func, sig, failure_mode,
*, function_name):
"""Low-level call to <function_name>(func, sig).
When calling this function, GIL must be acquired.
"""
pyapi = context.get_python_api(builder)
# Get the cfunc wrapper address. The code below trusts that the
# function numba.function._get_wrapper_address exists and can be
# called with two arguments. However, if an exception is raised in
# the function, then it will be caught and propagated to the
# caller.
modname = context.insert_const_string(builder.module, __name__)
numba_mod = pyapi.import_module(modname)
numba_func = pyapi.object_getattr_string(numba_mod, function_name)
pyapi.decref(numba_mod)
sig_obj = pyapi.unserialize(pyapi.serialize_object(sig))
addr = pyapi.call_function_objargs(numba_func, (func, sig_obj))
if failure_mode != 'ignore':
with builder.if_then(cgutils.is_null(builder, addr), likely=False):
# *function_name* has raised an exception, propagate it
# to the caller.
if failure_mode == 'return_exc':
context.call_conv.return_exc(builder)
elif failure_mode == 'return_null':
builder.ret(pyapi.get_null_object())
else:
raise NotImplementedError(failure_mode)
# else the caller will handle addr == NULL
return addr # new reference or NULL
lower_get_wrapper_address = partial(
_lower_get_address,
function_name="_get_wrapper_address",
)
lower_get_jit_address = partial(
_lower_get_address,
function_name="_get_jit_address",
)
@unbox(FunctionType)
def unbox_function_type(typ, obj, c):
typ = typ.get_precise()
sfunc = cgutils.create_struct_proxy(typ)(c.context, c.builder)
addr = lower_get_wrapper_address(
c.context, c.builder, obj, typ.signature, failure_mode='return_null')
sfunc.c_addr = c.pyapi.long_as_voidptr(addr)
c.pyapi.decref(addr)
llty = c.context.get_value_type(types.voidptr)
sfunc.py_addr = c.builder.ptrtoint(obj, llty)
addr = lower_get_jit_address(
c.context, c.builder, obj, typ.signature, failure_mode='return_null')
sfunc.jit_addr = c.pyapi.long_as_voidptr(addr)
c.pyapi.decref(addr)
return NativeValue(sfunc._getvalue())
@box(FunctionType)
def box_function_type(typ, val, c):
typ = typ.get_precise()
sfunc = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
pyaddr_ptr = cgutils.alloca_once(c.builder, c.pyapi.pyobj)
raw_ptr = c.builder.inttoptr(sfunc.py_addr, c.pyapi.pyobj)
with c.builder.if_then(cgutils.is_null(c.builder, raw_ptr),
likely=False):
cstr = f"first-class function {typ} parent object not set"
c.pyapi.err_set_string("PyExc_MemoryError", cstr)
c.builder.ret(c.pyapi.get_null_object())
c.builder.store(raw_ptr, pyaddr_ptr)
cfunc = c.builder.load(pyaddr_ptr)
c.pyapi.incref(cfunc)
return cfunc
@lower_cast(UndefinedFunctionType, FunctionType)
def lower_cast_function_type_to_function_type(
context, builder, fromty, toty, val):
return val
@lower_cast(types.Dispatcher, FunctionType)
def lower_cast_dispatcher_to_function_type(context, builder, fromty, toty, val):
toty = toty.get_precise()
sig = toty.signature
dispatcher = fromty.dispatcher
llvoidptr = context.get_value_type(types.voidptr)
sfunc = cgutils.create_struct_proxy(toty)(context, builder)
# Always store the python function
sfunc.py_addr = builder.ptrtoint(val, llvoidptr)
# Attempt to compile the Dispatcher to the expected function type
try:
cres = dispatcher.get_compile_result(sig)
except errors.NumbaError:
cres = None
# If compilation is successful, we can by-pass using GIL to get the cfunc
if cres is not None:
# Declare cfunc in the current module
wrapper_name = cres.fndesc.llvm_cfunc_wrapper_name
llfnptr = context.get_value_type(toty.ftype)
llfnty = llfnptr.pointee
fn = cgutils.get_or_insert_function(
builder.module, llfnty, wrapper_name,
)
addr = builder.bitcast(fn, llvoidptr)
# Store the cfunc
sfunc.c_addr = addr
# Store the jit func
fn = context.declare_function(builder.module, cres.fndesc)
sfunc.jit_addr = builder.bitcast(fn, llvoidptr)
# Link-in the dispatcher library
context.active_code_library.add_linking_library(cres.library)
else:
# Use lower_get_wrapper_address() to get the cfunc
lower_get_wrapper_address
pyapi = context.get_python_api(builder)
gil_state = pyapi.gil_ensure()
addr = lower_get_wrapper_address(
context, builder, val, toty.signature,
failure_mode='return_exc')
sfunc.c_addr = pyapi.long_as_voidptr(addr)
pyapi.decref(addr)
pyapi.gil_release(gil_state)
return sfunc._getvalue()

View File

@@ -0,0 +1,3 @@
from numba.experimental.jitclass.decorators import jitclass
from numba.experimental.jitclass import boxing # Has import-time side effect
from numba.experimental.jitclass import overloads # Has import-time side effect

View File

@@ -0,0 +1,598 @@
import inspect
import operator
import types as pytypes
import typing as pt
from collections import OrderedDict
from collections.abc import Sequence
from llvmlite import ir as llvmir
from numba import njit
from numba.core import cgutils, errors, imputils, types, utils
from numba.core.datamodel import default_manager, models
from numba.core.registry import cpu_target
from numba.core.typing import templates
from numba.core.typing.asnumbatype import as_numba_type
from numba.core.serialize import disable_pickling
from numba.experimental.jitclass import _box
##############################################################################
# Data model
class InstanceModel(models.StructModel):
def __init__(self, dmm, fe_typ):
cls_data_ty = types.ClassDataType(fe_typ)
# MemInfoPointer uses the `dtype` attribute to traverse for nested
# NRT MemInfo. Since we handle nested NRT MemInfo ourselves,
# we will replace provide MemInfoPointer with an opaque type
# so that it does not raise exception for nested meminfo.
dtype = types.Opaque('Opaque.' + str(cls_data_ty))
members = [
('meminfo', types.MemInfoPointer(dtype)),
('data', types.CPointer(cls_data_ty)),
]
super(InstanceModel, self).__init__(dmm, fe_typ, members)
class InstanceDataModel(models.StructModel):
def __init__(self, dmm, fe_typ):
clsty = fe_typ.class_type
members = [(_mangle_attr(k), v) for k, v in clsty.struct.items()]
super(InstanceDataModel, self).__init__(dmm, fe_typ, members)
default_manager.register(types.ClassInstanceType, InstanceModel)
default_manager.register(types.ClassDataType, InstanceDataModel)
default_manager.register(types.ClassType, models.OpaqueModel)
def _mangle_attr(name):
"""
Mangle attributes.
The resulting name does not startswith an underscore '_'.
"""
return 'm_' + name
##############################################################################
# Class object
_ctor_template = """
def ctor({args}):
return __numba_cls_({args})
"""
def _getargs(fn_sig):
"""
Returns list of positional and keyword argument names in order.
"""
params = fn_sig.parameters
args = []
for k, v in params.items():
if (v.kind & v.POSITIONAL_OR_KEYWORD) == v.POSITIONAL_OR_KEYWORD:
args.append(k)
else:
msg = "%s argument type unsupported in jitclass" % v.kind
raise errors.UnsupportedError(msg)
return args
@disable_pickling
class JitClassType(type):
"""
The type of any jitclass.
"""
def __new__(cls, name, bases, dct):
if len(bases) != 1:
raise TypeError("must have exactly one base class")
[base] = bases
if isinstance(base, JitClassType):
raise TypeError("cannot subclass from a jitclass")
assert 'class_type' in dct, 'missing "class_type" attr'
outcls = type.__new__(cls, name, bases, dct)
outcls._set_init()
return outcls
def _set_init(cls):
"""
Generate a wrapper for calling the constructor from pure Python.
Note the wrapper will only accept positional arguments.
"""
init = cls.class_type.instance_type.methods['__init__']
init_sig = utils.pysignature(init)
# get postitional and keyword arguments
# offset by one to exclude the `self` arg
args = _getargs(init_sig)[1:]
cls._ctor_sig = init_sig
ctor_source = _ctor_template.format(args=', '.join(args))
glbls = {"__numba_cls_": cls}
exec(ctor_source, glbls)
ctor = glbls['ctor']
cls._ctor = njit(ctor)
def __instancecheck__(cls, instance):
if isinstance(instance, _box.Box):
return instance._numba_type_.class_type is cls.class_type
return False
def __call__(cls, *args, **kwargs):
# The first argument of _ctor_sig is `cls`, which here
# is bound to None and then skipped when invoking the constructor.
bind = cls._ctor_sig.bind(None, *args, **kwargs)
bind.apply_defaults()
return cls._ctor(*bind.args[1:], **bind.kwargs)
##############################################################################
# Registration utils
def _validate_spec(spec):
for k, v in spec.items():
if not isinstance(k, str):
raise TypeError("spec keys should be strings, got %r" % (k,))
if not isinstance(v, types.Type):
raise TypeError("spec values should be Numba type instances, got %r"
% (v,))
def _fix_up_private_attr(clsname, spec):
"""
Apply the same changes to dunder names as CPython would.
"""
out = OrderedDict()
for k, v in spec.items():
if k.startswith('__') and not k.endswith('__'):
k = '_' + clsname + k
out[k] = v
return out
def _add_linking_libs(context, call):
"""
Add the required libs for the callable to allow inlining.
"""
libs = getattr(call, "libs", ())
if libs:
context.add_linking_libs(libs)
def register_class_type(cls, spec, class_ctor, builder):
"""
Internal function to create a jitclass.
Args
----
cls: the original class object (used as the prototype)
spec: the structural specification contains the field types.
class_ctor: the numba type to represent the jitclass
builder: the internal jitclass builder
"""
# Normalize spec
if spec is None:
spec = OrderedDict()
elif isinstance(spec, Sequence):
spec = OrderedDict(spec)
# Extend spec with class annotations.
for attr, py_type in pt.get_type_hints(cls).items():
if attr not in spec:
spec[attr] = as_numba_type(py_type)
_validate_spec(spec)
# Fix up private attribute names
spec = _fix_up_private_attr(cls.__name__, spec)
# Copy methods from base classes
clsdct = {}
for basecls in reversed(inspect.getmro(cls)):
clsdct.update(basecls.__dict__)
methods, props, static_methods, others = {}, {}, {}, {}
for k, v in clsdct.items():
if isinstance(v, pytypes.FunctionType):
methods[k] = v
elif isinstance(v, property):
props[k] = v
elif isinstance(v, staticmethod):
static_methods[k] = v
else:
others[k] = v
# Check for name shadowing
shadowed = (set(methods) | set(props) | set(static_methods)) & set(spec)
if shadowed:
raise NameError("name shadowing: {0}".format(', '.join(shadowed)))
docstring = others.pop('__doc__', "")
_drop_ignored_attrs(others)
if others:
msg = "class members are not yet supported: {0}"
members = ', '.join(others.keys())
raise TypeError(msg.format(members))
for k, v in props.items():
if v.fdel is not None:
raise TypeError("deleter is not supported: {0}".format(k))
jit_methods = {k: njit(v) for k, v in methods.items()}
jit_props = {}
for k, v in props.items():
dct = {}
if v.fget:
dct['get'] = njit(v.fget)
if v.fset:
dct['set'] = njit(v.fset)
jit_props[k] = dct
jit_static_methods = {
k: njit(v.__func__) for k, v in static_methods.items()}
# Instantiate class type
class_type = class_ctor(
cls,
ConstructorTemplate,
spec,
jit_methods,
jit_props,
jit_static_methods)
jit_class_dct = dict(class_type=class_type, __doc__=docstring)
jit_class_dct.update(jit_static_methods)
cls = JitClassType(cls.__name__, (cls,), jit_class_dct)
# Register resolution of the class object
typingctx = cpu_target.typing_context
typingctx.insert_global(cls, class_type)
# Register class
targetctx = cpu_target.target_context
builder(class_type, typingctx, targetctx).register()
as_numba_type.register(cls, class_type.instance_type)
return cls
class ConstructorTemplate(templates.AbstractTemplate):
"""
Base class for jitclass constructor templates.
"""
def generic(self, args, kws):
# Redirect resolution to __init__
instance_type = self.key.instance_type
ctor = instance_type.jit_methods['__init__']
boundargs = (instance_type.get_reference_type(),) + args
disp_type = types.Dispatcher(ctor)
sig = disp_type.get_call_type(self.context, boundargs, kws)
if not isinstance(sig.return_type, types.NoneType):
raise errors.NumbaTypeError(
f"__init__() should return None, not '{sig.return_type}'")
# Actual constructor returns an instance value (not None)
out = templates.signature(instance_type, *sig.args[1:])
return out
def _drop_ignored_attrs(dct):
# ignore anything defined by object
drop = set(['__weakref__',
'__module__',
'__dict__'])
if utils.PYVERSION in ((3, 13), (3, 14)):
# new in python 3.13
drop |= set(['__firstlineno__', '__static_attributes__'])
for att in ('__annotations__',
'__annotate_func__',
'__annotations_cache__'):
if att in dct:
drop.add(att)
for k, v in dct.items():
if isinstance(v, (pytypes.BuiltinFunctionType,
pytypes.BuiltinMethodType)):
drop.add(k)
elif getattr(v, '__objclass__', None) is object:
drop.add(k)
# If a class defines __eq__ but not __hash__, __hash__ is implicitly set to
# None. This is a class member, and class members are not presently
# supported.
if '__hash__' in dct and dct['__hash__'] is None:
drop.add('__hash__')
for k in drop:
dct.pop(k)
class ClassBuilder(object):
"""
A jitclass builder for a mutable jitclass. This will register
typing and implementation hooks to the given typing and target contexts.
"""
class_impl_registry = imputils.Registry('jitclass builder')
implemented_methods = set()
def __init__(self, class_type, typingctx, targetctx):
self.class_type = class_type
self.typingctx = typingctx
self.targetctx = targetctx
def register(self):
"""
Register to the frontend and backend.
"""
# Register generic implementations for all jitclasses
self._register_methods(self.class_impl_registry,
self.class_type.instance_type)
# NOTE other registrations are done at the top-level
# (see ctor_impl and attr_impl below)
self.targetctx.install_registry(self.class_impl_registry)
def _register_methods(self, registry, instance_type):
"""
Register method implementations.
This simply registers that the method names are valid methods. Inside
of imp() below we retrieve the actual method to run from the type of
the receiver argument (i.e. self).
"""
to_register = list(instance_type.jit_methods) + \
list(instance_type.jit_static_methods)
for meth in to_register:
# There's no way to retrieve the particular method name
# inside the implementation function, so we have to register a
# specific closure for each different name
if meth not in self.implemented_methods:
self._implement_method(registry, meth)
self.implemented_methods.add(meth)
def _implement_method(self, registry, attr):
# create a separate instance of imp method to avoid closure clashing
def get_imp():
def imp(context, builder, sig, args):
instance_type = sig.args[0]
if attr in instance_type.jit_methods:
method = instance_type.jit_methods[attr]
elif attr in instance_type.jit_static_methods:
method = instance_type.jit_static_methods[attr]
# imp gets called as a method, where the first argument is
# self. We drop this for a static method.
sig = sig.replace(args=sig.args[1:])
args = args[1:]
disp_type = types.Dispatcher(method)
call = context.get_function(disp_type, sig)
out = call(builder, args)
_add_linking_libs(context, call)
return imputils.impl_ret_new_ref(context, builder,
sig.return_type, out)
return imp
def _getsetitem_gen(getset):
_dunder_meth = "__%s__" % getset
op = getattr(operator, getset)
@templates.infer_global(op)
class GetSetItem(templates.AbstractTemplate):
def generic(self, args, kws):
instance = args[0]
if isinstance(instance, types.ClassInstanceType) and \
_dunder_meth in instance.jit_methods:
meth = instance.jit_methods[_dunder_meth]
disp_type = types.Dispatcher(meth)
sig = disp_type.get_call_type(self.context, args, kws)
return sig
# lower both {g,s}etitem and __{g,s}etitem__ to catch the calls
# from python and numba
imputils.lower_builtin((types.ClassInstanceType, _dunder_meth),
types.ClassInstanceType,
types.VarArg(types.Any))(get_imp())
imputils.lower_builtin(op,
types.ClassInstanceType,
types.VarArg(types.Any))(get_imp())
dunder_stripped = attr.strip('_')
if dunder_stripped in ("getitem", "setitem"):
_getsetitem_gen(dunder_stripped)
else:
registry.lower((types.ClassInstanceType, attr),
types.ClassInstanceType,
types.VarArg(types.Any))(get_imp())
@templates.infer_getattr
class ClassAttribute(templates.AttributeTemplate):
key = types.ClassInstanceType
def generic_resolve(self, instance, attr):
if attr in instance.struct:
# It's a struct field => the type is well-known
return instance.struct[attr]
elif attr in instance.jit_methods:
# It's a jitted method => typeinfer it
meth = instance.jit_methods[attr]
disp_type = types.Dispatcher(meth)
class MethodTemplate(templates.AbstractTemplate):
key = (self.key, attr)
def generic(self, args, kws):
args = (instance,) + tuple(args)
sig = disp_type.get_call_type(self.context, args, kws)
return sig.as_method()
return types.BoundFunction(MethodTemplate, instance)
elif attr in instance.jit_static_methods:
# It's a jitted method => typeinfer it
meth = instance.jit_static_methods[attr]
disp_type = types.Dispatcher(meth)
class StaticMethodTemplate(templates.AbstractTemplate):
key = (self.key, attr)
def generic(self, args, kws):
# Don't add instance as the first argument for a static
# method.
sig = disp_type.get_call_type(self.context, args, kws)
# sig itself does not include ClassInstanceType as it's
# first argument, so instead of calling sig.as_method()
# we insert the recvr. This is equivalent to
# sig.replace(args=(instance,) + sig.args).as_method().
return sig.replace(recvr=instance)
return types.BoundFunction(StaticMethodTemplate, instance)
elif attr in instance.jit_props:
# It's a jitted property => typeinfer its getter
impdct = instance.jit_props[attr]
getter = impdct['get']
disp_type = types.Dispatcher(getter)
sig = disp_type.get_call_type(self.context, (instance,), {})
return sig.return_type
@ClassBuilder.class_impl_registry.lower_getattr_generic(types.ClassInstanceType)
def get_attr_impl(context, builder, typ, value, attr):
"""
Generic getattr() for @jitclass instances.
"""
if attr in typ.struct:
# It's a struct field
inst = context.make_helper(builder, typ, value=value)
data_pointer = inst.data
data = context.make_data_helper(builder, typ.get_data_type(),
ref=data_pointer)
return imputils.impl_ret_borrowed(context, builder,
typ.struct[attr],
getattr(data, _mangle_attr(attr)))
elif attr in typ.jit_props:
# It's a jitted property
getter = typ.jit_props[attr]['get']
sig = templates.signature(None, typ)
dispatcher = types.Dispatcher(getter)
sig = dispatcher.get_call_type(context.typing_context, [typ], {})
call = context.get_function(dispatcher, sig)
out = call(builder, [value])
_add_linking_libs(context, call)
return imputils.impl_ret_new_ref(context, builder, sig.return_type, out)
raise NotImplementedError('attribute {0!r} not implemented'.format(attr))
@ClassBuilder.class_impl_registry.lower_setattr_generic(types.ClassInstanceType)
def set_attr_impl(context, builder, sig, args, attr):
"""
Generic setattr() for @jitclass instances.
"""
typ, valty = sig.args
target, val = args
if attr in typ.struct:
# It's a struct member
inst = context.make_helper(builder, typ, value=target)
data_ptr = inst.data
data = context.make_data_helper(builder, typ.get_data_type(),
ref=data_ptr)
# Get old value
attr_type = typ.struct[attr]
oldvalue = getattr(data, _mangle_attr(attr))
# Store n
setattr(data, _mangle_attr(attr), val)
context.nrt.incref(builder, attr_type, val)
# Delete old value
context.nrt.decref(builder, attr_type, oldvalue)
elif attr in typ.jit_props:
# It's a jitted property
setter = typ.jit_props[attr]['set']
disp_type = types.Dispatcher(setter)
sig = disp_type.get_call_type(context.typing_context,
(typ, valty), {})
call = context.get_function(disp_type, sig)
call(builder, (target, val))
_add_linking_libs(context, call)
else:
raise NotImplementedError(
'attribute {0!r} not implemented'.format(attr))
def imp_dtor(context, module, instance_type):
llvoidptr = context.get_value_type(types.voidptr)
llsize = context.get_value_type(types.uintp)
dtor_ftype = llvmir.FunctionType(llvmir.VoidType(),
[llvoidptr, llsize, llvoidptr])
fname = "_Dtor.{0}".format(instance_type.name)
dtor_fn = cgutils.get_or_insert_function(module, dtor_ftype, fname)
if dtor_fn.is_declaration:
# Define
builder = llvmir.IRBuilder(dtor_fn.append_basic_block())
alloc_fe_type = instance_type.get_data_type()
alloc_type = context.get_value_type(alloc_fe_type)
ptr = builder.bitcast(dtor_fn.args[0], alloc_type.as_pointer())
data = context.make_helper(builder, alloc_fe_type, ref=ptr)
context.nrt.decref(builder, alloc_fe_type, data._getvalue())
builder.ret_void()
return dtor_fn
@ClassBuilder.class_impl_registry.lower(types.ClassType,
types.VarArg(types.Any))
def ctor_impl(context, builder, sig, args):
"""
Generic constructor (__new__) for jitclasses.
"""
# Allocate the instance
inst_typ = sig.return_type
alloc_type = context.get_data_type(inst_typ.get_data_type())
alloc_size = context.get_abi_sizeof(alloc_type)
meminfo = context.nrt.meminfo_alloc_dtor(
builder,
context.get_constant(types.uintp, alloc_size),
imp_dtor(context, builder.module, inst_typ),
)
data_pointer = context.nrt.meminfo_data(builder, meminfo)
data_pointer = builder.bitcast(data_pointer,
alloc_type.as_pointer())
# Nullify all data
builder.store(cgutils.get_null_value(alloc_type),
data_pointer)
inst_struct = context.make_helper(builder, inst_typ)
inst_struct.meminfo = meminfo
inst_struct.data = data_pointer
# Call the jitted __init__
# TODO: extract the following into a common util
init_sig = (sig.return_type,) + sig.args
init = inst_typ.jit_methods['__init__']
disp_type = types.Dispatcher(init)
call = context.get_function(disp_type, types.void(*init_sig))
_add_linking_libs(context, call)
realargs = [inst_struct._getvalue()] + list(args)
call(builder, realargs)
# Prepare return value
ret = inst_struct._getvalue()
return imputils.impl_ret_new_ref(context, builder, inst_typ, ret)

View File

@@ -0,0 +1,275 @@
"""
Implement logic relating to wrapping (box) and unwrapping (unbox) instances
of jitclasses for use inside the python interpreter.
"""
from functools import wraps, partial
from llvmlite import ir
from numba.core import types, cgutils
from numba.core.decorators import njit
from numba.core.pythonapi import box, unbox, NativeValue
from numba.core.typing.typeof import typeof_impl
from numba.experimental.jitclass import _box
_getter_code_template = """
def accessor(__numba_self_):
return __numba_self_.{0}
"""
_setter_code_template = """
def mutator(__numba_self_, __numba_val):
__numba_self_.{0} = __numba_val
"""
_method_code_template = """
def method(__numba_self_, *args):
return __numba_self_.{method}(*args)
"""
def _generate_property(field, template, fname):
"""
Generate simple function that get/set a field of the instance
"""
source = template.format(field)
glbls = {}
exec(source, glbls)
return njit(glbls[fname])
_generate_getter = partial(_generate_property, template=_getter_code_template,
fname='accessor')
_generate_setter = partial(_generate_property, template=_setter_code_template,
fname='mutator')
def _generate_method(name, func):
"""
Generate a wrapper for calling a method. Note the wrapper will only
accept positional arguments.
"""
source = _method_code_template.format(method=name)
glbls = {}
exec(source, glbls)
method = njit(glbls['method'])
@wraps(func)
def wrapper(*args, **kwargs):
return method(*args, **kwargs)
return wrapper
_cache_specialized_box = {}
def _specialize_box(typ):
"""
Create a subclass of Box that is specialized to the jitclass.
This function caches the result to avoid code bloat.
"""
# Check cache
if typ in _cache_specialized_box:
return _cache_specialized_box[typ]
dct = {'__slots__': (),
'_numba_type_': typ,
'__doc__': typ.class_type.class_doc,
}
# Inject attributes as class properties
for field in typ.struct:
getter = _generate_getter(field)
setter = _generate_setter(field)
dct[field] = property(getter, setter)
# Inject properties as class properties
for field, impdct in typ.jit_props.items():
getter = None
setter = None
if 'get' in impdct:
getter = _generate_getter(field)
if 'set' in impdct:
setter = _generate_setter(field)
# get docstring from either the fget or fset
imp = impdct.get('get') or impdct.get('set') or None
doc = getattr(imp, '__doc__', None)
dct[field] = property(getter, setter, doc=doc)
# Inject methods as class members
supported_dunders = {
"__abs__",
"__annotate_func__",
"__bool__",
"__complex__",
"__contains__",
"__float__",
"__getitem__",
"__hash__",
"__index__",
"__invert__",
"__int__",
"__len__",
"__setitem__",
"__str__",
"__eq__",
"__ne__",
"__ge__",
"__gt__",
"__le__",
"__lt__",
"__add__",
"__floordiv__",
"__lshift__",
"__matmul__",
"__mod__",
"__mul__",
"__neg__",
"__pos__",
"__pow__",
"__rshift__",
"__sub__",
"__truediv__",
"__and__",
"__or__",
"__xor__",
"__iadd__",
"__ifloordiv__",
"__ilshift__",
"__imatmul__",
"__imod__",
"__imul__",
"__ipow__",
"__irshift__",
"__isub__",
"__itruediv__",
"__iand__",
"__ior__",
"__ixor__",
"__radd__",
"__rfloordiv__",
"__rlshift__",
"__rmatmul__",
"__rmod__",
"__rmul__",
"__rpow__",
"__rrshift__",
"__rsub__",
"__rtruediv__",
"__rand__",
"__ror__",
"__rxor__",
}
for name, func in typ.methods.items():
if name == "__init__":
continue
if (
name.startswith("__")
and name.endswith("__")
and name not in supported_dunders
):
raise TypeError(f"Method '{name}' is not supported.")
dct[name] = _generate_method(name, func)
# Inject static methods as class members
for name, func in typ.static_methods.items():
dct[name] = _generate_method(name, func)
# Create subclass
subcls = type(typ.classname, (_box.Box,), dct)
# Store to cache
_cache_specialized_box[typ] = subcls
# Pre-compile attribute getter.
# Note: This must be done after the "box" class is created because
# compiling the getter requires the "box" class to be defined.
for k, v in dct.items():
if isinstance(v, property):
prop = getattr(subcls, k)
if prop.fget is not None:
fget = prop.fget
fast_fget = fget.compile((typ,))
fget.disable_compile()
setattr(subcls, k,
property(fast_fget, prop.fset, prop.fdel,
doc=prop.__doc__))
return subcls
###############################################################################
# Implement box/unbox for call wrapper
@box(types.ClassInstanceType)
def _box_class_instance(typ, val, c):
meminfo, dataptr = cgutils.unpack_tuple(c.builder, val)
# Create Box instance
box_subclassed = _specialize_box(typ)
# Note: the ``box_subclassed`` is kept alive by the cache
voidptr_boxcls = c.context.add_dynamic_addr(
c.builder,
id(box_subclassed),
info="box_class_instance",
)
box_cls = c.builder.bitcast(voidptr_boxcls, c.pyapi.pyobj)
box = c.pyapi.call_function_objargs(box_cls, ())
# Initialize Box instance
llvoidptr = ir.IntType(8).as_pointer()
addr_meminfo = c.builder.bitcast(meminfo, llvoidptr)
addr_data = c.builder.bitcast(dataptr, llvoidptr)
def set_member(member_offset, value):
# Access member by byte offset
offset = c.context.get_constant(types.uintp, member_offset)
ptr = cgutils.pointer_add(c.builder, box, offset)
casted = c.builder.bitcast(ptr, llvoidptr.as_pointer())
c.builder.store(value, casted)
set_member(_box.box_meminfoptr_offset, addr_meminfo)
set_member(_box.box_dataptr_offset, addr_data)
return box
@unbox(types.ClassInstanceType)
def _unbox_class_instance(typ, val, c):
def access_member(member_offset):
# Access member by byte offset
offset = c.context.get_constant(types.uintp, member_offset)
llvoidptr = ir.IntType(8).as_pointer()
ptr = cgutils.pointer_add(c.builder, val, offset)
casted = c.builder.bitcast(ptr, llvoidptr.as_pointer())
return c.builder.load(casted)
struct_cls = cgutils.create_struct_proxy(typ)
inst = struct_cls(c.context, c.builder)
# load from Python object
ptr_meminfo = access_member(_box.box_meminfoptr_offset)
ptr_dataptr = access_member(_box.box_dataptr_offset)
# store to native structure
inst.meminfo = c.builder.bitcast(ptr_meminfo, inst.meminfo.type)
inst.data = c.builder.bitcast(ptr_dataptr, inst.data.type)
ret = inst._getvalue()
c.context.nrt.incref(c.builder, typ, ret)
return NativeValue(ret, is_error=c.pyapi.c_api_error())
# Add a typeof_impl implementation for boxed jitclasses to short-circut the
# various tests in typeof. This is needed for jitclasses which implement a
# custom hash method. Without this, typeof_impl will return None, and one of the
# later attempts to determine the type of the jitclass (before checking for
# _numba_type_) will look up the object in a dictionary, triggering the hash
# method. This will cause the dispatcher to determine the call signature of the
# jit decorated obj.__hash__ method, which will call typeof(obj), and thus
# infinite loop.
# This implementation is here instead of in typeof.py to avoid circular imports.
@typeof_impl.register(_box.Box)
def _typeof_jitclass_box(val, c):
return getattr(type(val), "_numba_type_")

View File

@@ -0,0 +1,88 @@
from numba.core import types, config
def jitclass(cls_or_spec=None, spec=None):
"""
A function for creating a jitclass.
Can be used as a decorator or function.
Different use cases will cause different arguments to be set.
If specified, ``spec`` gives the types of class fields.
It must be a dictionary or sequence.
With a dictionary, use collections.OrderedDict for stable ordering.
With a sequence, it must contain 2-tuples of (fieldname, fieldtype).
Any class annotations for field names not listed in spec will be added.
For class annotation `x: T` we will append ``("x", as_numba_type(T))`` to
the spec if ``x`` is not already a key in spec.
Examples
--------
1) ``cls_or_spec = None``, ``spec = None``
>>> @jitclass()
... class Foo:
... ...
2) ``cls_or_spec = None``, ``spec = spec``
>>> @jitclass(spec=spec)
... class Foo:
... ...
3) ``cls_or_spec = Foo``, ``spec = None``
>>> @jitclass
... class Foo:
... ...
4) ``cls_or_spec = spec``, ``spec = None``
In this case we update ``cls_or_spec, spec = None, cls_or_spec``.
>>> @jitclass(spec)
... class Foo:
... ...
5) ``cls_or_spec = Foo``, ``spec = spec``
>>> JitFoo = jitclass(Foo, spec)
Returns
-------
If used as a decorator, returns a callable that takes a class object and
returns a compiled version.
If used as a function, returns the compiled class (an instance of
``JitClassType``).
"""
if (cls_or_spec is not None and
spec is None and
not isinstance(cls_or_spec, type)):
# Used like
# @jitclass([("x", intp)])
# class Foo:
# ...
spec = cls_or_spec
cls_or_spec = None
def wrap(cls):
if config.DISABLE_JIT:
return cls
else:
from numba.experimental.jitclass.base import (register_class_type,
ClassBuilder)
cls_jitted = register_class_type(cls, spec, types.ClassType,
ClassBuilder)
# Preserve the module name of the original class
cls_jitted.__module__ = cls.__module__
return cls_jitted
if cls_or_spec is None:
return wrap
else:
return wrap(cls_or_spec)

View File

@@ -0,0 +1,239 @@
"""
Overloads for ClassInstanceType for built-in functions that call dunder methods
on an object.
"""
from functools import wraps
import inspect
import operator
from numba.core.extending import overload
from numba.core.types import ClassInstanceType
def _get_args(n_args):
assert n_args in (1, 2)
return list("xy")[:n_args]
def class_instance_overload(target):
"""
Decorator to add an overload for target that applies when the first argument
is a ClassInstanceType.
"""
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
if not isinstance(args[0], ClassInstanceType):
return
return func(*args, **kwargs)
if target is not complex:
# complex ctor needs special treatment as it uses kwargs
params = list(inspect.signature(wrapped).parameters)
assert params == _get_args(len(params))
return overload(target)(wrapped)
return decorator
def extract_template(template, name):
"""
Extract a code-generated function from a string template.
"""
namespace = {}
exec(template, namespace)
return namespace[name]
def register_simple_overload(func, *attrs, n_args=1,):
"""
Register an overload for func that checks for methods __attr__ for each
attr in attrs.
"""
# Use a template to set the signature correctly.
arg_names = _get_args(n_args)
template = f"""
def func({','.join(arg_names)}):
pass
"""
@wraps(extract_template(template, "func"))
def overload_func(*args, **kwargs):
options = [
try_call_method(args[0], f"__{attr}__", n_args)
for attr in attrs
]
return take_first(*options)
return class_instance_overload(func)(overload_func)
def try_call_method(cls_type, method, n_args=1):
"""
If method is defined for cls_type, return a callable that calls this method.
If not, return None.
"""
if method in cls_type.jit_methods:
arg_names = _get_args(n_args)
template = f"""
def func({','.join(arg_names)}):
return {arg_names[0]}.{method}({','.join(arg_names[1:])})
"""
return extract_template(template, "func")
def try_call_complex_method(cls_type, method):
""" __complex__ needs special treatment as the argument names are kwargs
and therefore specific in name and default value.
"""
if method in cls_type.jit_methods:
template = f"""
def func(real=0, imag=0):
return real.{method}()
"""
return extract_template(template, "func")
def take_first(*options):
"""
Take the first non-None option.
"""
assert all(o is None or inspect.isfunction(o) for o in options), options
for o in options:
if o is not None:
return o
@class_instance_overload(bool)
def class_bool(x):
using_bool_impl = try_call_method(x, "__bool__")
if '__len__' in x.jit_methods:
def using_len_impl(x):
return bool(len(x))
else:
using_len_impl = None
always_true_impl = lambda x: True
return take_first(using_bool_impl, using_len_impl, always_true_impl)
@class_instance_overload(complex)
def class_complex(real=0, imag=0):
return take_first(
try_call_complex_method(real, "__complex__"),
lambda real=0, imag=0: complex(float(real))
)
@class_instance_overload(operator.contains)
def class_contains(x, y):
# https://docs.python.org/3/reference/expressions.html#membership-test-operations
return try_call_method(x, "__contains__", 2)
# TODO: use __iter__ if defined.
@class_instance_overload(float)
def class_float(x):
options = [try_call_method(x, "__float__")]
if (
"__index__" in x.jit_methods
):
options.append(lambda x: float(x.__index__()))
return take_first(*options)
@class_instance_overload(int)
def class_int(x):
options = [try_call_method(x, "__int__")]
options.append(try_call_method(x, "__index__"))
return take_first(*options)
@class_instance_overload(str)
def class_str(x):
return take_first(
try_call_method(x, "__str__"),
lambda x: repr(x),
)
@class_instance_overload(operator.ne)
def class_ne(x, y):
# This doesn't use register_reflected_overload like the other operators
# because it falls back to inverting __eq__ rather than reflecting its
# arguments (as per the definition of the Python data model).
return take_first(
try_call_method(x, "__ne__", 2),
lambda x, y: not (x == y),
)
def register_reflected_overload(func, meth_forward, meth_reflected):
def class_lt(x, y):
normal_impl = try_call_method(x, f"__{meth_forward}__", 2)
if f"__{meth_reflected}__" in y.jit_methods:
def reflected_impl(x, y):
return y > x
else:
reflected_impl = None
return take_first(normal_impl, reflected_impl)
class_instance_overload(func)(class_lt)
register_simple_overload(abs, "abs")
register_simple_overload(len, "len")
register_simple_overload(hash, "hash")
# Comparison operators.
register_reflected_overload(operator.ge, "ge", "le")
register_reflected_overload(operator.gt, "gt", "lt")
register_reflected_overload(operator.le, "le", "ge")
register_reflected_overload(operator.lt, "lt", "gt")
# Note that eq is missing support for fallback to `x is y`, but `is` and
# `operator.is` are presently unsupported in general.
register_reflected_overload(operator.eq, "eq", "eq")
# Arithmetic operators.
register_simple_overload(operator.add, "add", n_args=2)
register_simple_overload(operator.floordiv, "floordiv", n_args=2)
register_simple_overload(operator.lshift, "lshift", n_args=2)
register_simple_overload(operator.mul, "mul", n_args=2)
register_simple_overload(operator.mod, "mod", n_args=2)
register_simple_overload(operator.neg, "neg")
register_simple_overload(operator.pos, "pos")
register_simple_overload(operator.invert, "invert")
register_simple_overload(operator.pow, "pow", n_args=2)
register_simple_overload(operator.rshift, "rshift", n_args=2)
register_simple_overload(operator.sub, "sub", n_args=2)
register_simple_overload(operator.truediv, "truediv", n_args=2)
# Inplace arithmetic operators.
register_simple_overload(operator.iadd, "iadd", "add", n_args=2)
register_simple_overload(operator.ifloordiv, "ifloordiv", "floordiv", n_args=2)
register_simple_overload(operator.ilshift, "ilshift", "lshift", n_args=2)
register_simple_overload(operator.imul, "imul", "mul", n_args=2)
register_simple_overload(operator.imod, "imod", "mod", n_args=2)
register_simple_overload(operator.ipow, "ipow", "pow", n_args=2)
register_simple_overload(operator.irshift, "irshift", "rshift", n_args=2)
register_simple_overload(operator.isub, "isub", "sub", n_args=2)
register_simple_overload(operator.itruediv, "itruediv", "truediv", n_args=2)
# Logical operators.
register_simple_overload(operator.and_, "and", n_args=2)
register_simple_overload(operator.or_, "or", n_args=2)
register_simple_overload(operator.xor, "xor", n_args=2)
# Inplace logical operators.
register_simple_overload(operator.iand, "iand", "and", n_args=2)
register_simple_overload(operator.ior, "ior", "or", n_args=2)
register_simple_overload(operator.ixor, "ixor", "xor", n_args=2)

View File

@@ -0,0 +1,400 @@
"""Utilities for defining a mutable struct.
A mutable struct is passed by reference;
hence, structref (a reference to a struct).
"""
import operator
from numba.core.cgutils import create_struct_proxy
from numba import njit
from numba.core import types, imputils, cgutils
from numba.core.datamodel import default_manager, models
from numba.core.extending import (
infer_getattr,
lower_getattr_generic,
lower_setattr_generic,
lower_builtin,
box,
unbox,
NativeValue,
intrinsic,
overload,
)
from numba.core.typing.templates import AttributeTemplate
class _Utils:
"""Internal builder-code utils for structref definitions.
"""
def __init__(self, context, builder, struct_type):
"""
Parameters
----------
context :
a numba target context
builder :
a llvmlite IRBuilder
struct_type : numba.core.types.StructRef
"""
self.context = context
self.builder = builder
self.struct_type = struct_type
def new_struct_ref(self, mi):
"""Encapsulate the MemInfo from a `StructRefPayload` in a `StructRef`
"""
context = self.context
builder = self.builder
struct_type = self.struct_type
st = cgutils.create_struct_proxy(struct_type)(context, builder)
st.meminfo = mi
return st
def get_struct_ref(self, val):
"""Return a helper for accessing a StructRefType
"""
context = self.context
builder = self.builder
struct_type = self.struct_type
return cgutils.create_struct_proxy(struct_type)(
context, builder, value=val
)
def get_data_pointer(self, val):
"""Get the data pointer to the payload from a `StructRefType`.
"""
context = self.context
builder = self.builder
struct_type = self.struct_type
structval = self.get_struct_ref(val)
meminfo = structval.meminfo
data_ptr = context.nrt.meminfo_data(builder, meminfo)
valtype = struct_type.get_data_type()
model = context.data_model_manager[valtype]
alloc_type = model.get_value_type()
data_ptr = builder.bitcast(data_ptr, alloc_type.as_pointer())
return data_ptr
def get_data_struct(self, val):
"""Get a getter/setter helper for accessing a `StructRefPayload`
"""
context = self.context
builder = self.builder
struct_type = self.struct_type
data_ptr = self.get_data_pointer(val)
valtype = struct_type.get_data_type()
dataval = cgutils.create_struct_proxy(valtype)(
context, builder, ref=data_ptr
)
return dataval
def define_attributes(struct_typeclass):
"""Define attributes on `struct_typeclass`.
Defines both setters and getters in jit-code.
This is called directly in `register()`.
"""
@infer_getattr
class StructAttribute(AttributeTemplate):
key = struct_typeclass
def generic_resolve(self, typ, attr):
if attr in typ.field_dict:
attrty = typ.field_dict[attr]
return attrty
@lower_getattr_generic(struct_typeclass)
def struct_getattr_impl(context, builder, typ, val, attr):
utils = _Utils(context, builder, typ)
dataval = utils.get_data_struct(val)
ret = getattr(dataval, attr)
fieldtype = typ.field_dict[attr]
return imputils.impl_ret_borrowed(context, builder, fieldtype, ret)
@lower_setattr_generic(struct_typeclass)
def struct_setattr_impl(context, builder, sig, args, attr):
[inst_type, val_type] = sig.args
[instance, val] = args
utils = _Utils(context, builder, inst_type)
dataval = utils.get_data_struct(instance)
# cast val to the correct type
field_type = inst_type.field_dict[attr]
casted = context.cast(builder, val, val_type, field_type)
# read old
old_value = getattr(dataval, attr)
# incref new value
context.nrt.incref(builder, val_type, casted)
# decref old value (must be last in case new value is old value)
context.nrt.decref(builder, val_type, old_value)
# write new
setattr(dataval, attr, casted)
def define_boxing(struct_type, obj_class):
"""Define the boxing & unboxing logic for `struct_type` to `obj_class`.
Defines both boxing and unboxing.
- boxing turns an instance of `struct_type` into a PyObject of `obj_class`
- unboxing turns an instance of `obj_class` into an instance of
`struct_type` in jit-code.
Use this directly instead of `define_proxy()` when the user does not
want any constructor to be defined.
"""
if struct_type is types.StructRef:
raise ValueError(f"cannot register {types.StructRef}")
obj_ctor = obj_class._numba_box_
@box(struct_type)
def box_struct_ref(typ, val, c):
"""
Convert a raw pointer to a Python int.
"""
utils = _Utils(c.context, c.builder, typ)
struct_ref = utils.get_struct_ref(val)
meminfo = struct_ref.meminfo
mip_type = types.MemInfoPointer(types.voidptr)
boxed_meminfo = c.box(mip_type, meminfo)
ctor_pyfunc = c.pyapi.unserialize(c.pyapi.serialize_object(obj_ctor))
ty_pyobj = c.pyapi.unserialize(c.pyapi.serialize_object(typ))
res = c.pyapi.call_function_objargs(
ctor_pyfunc, [ty_pyobj, boxed_meminfo],
)
c.pyapi.decref(ctor_pyfunc)
c.pyapi.decref(ty_pyobj)
c.pyapi.decref(boxed_meminfo)
return res
@unbox(struct_type)
def unbox_struct_ref(typ, obj, c):
mi_obj = c.pyapi.object_getattr_string(obj, "_meminfo")
mip_type = types.MemInfoPointer(types.voidptr)
mi = c.unbox(mip_type, mi_obj).value
utils = _Utils(c.context, c.builder, typ)
struct_ref = utils.new_struct_ref(mi)
out = struct_ref._getvalue()
c.pyapi.decref(mi_obj)
return NativeValue(out)
def define_constructor(py_class, struct_typeclass, fields):
"""Define the jit-code constructor for `struct_typeclass` using the
Python type `py_class` and the required `fields`.
Use this instead of `define_proxy()` if the user does not want boxing
logic defined.
"""
# Build source code for the constructor
params = ', '.join(fields)
indent = ' ' * 8
init_fields_buf = []
for k in fields:
init_fields_buf.append(f"st.{k} = {k}")
init_fields = f'\n{indent}'.join(init_fields_buf)
source = f"""
def ctor({params}):
struct_type = struct_typeclass(list(zip({list(fields)}, [{params}])))
def impl({params}):
st = new(struct_type)
{init_fields}
return st
return impl
"""
glbs = dict(struct_typeclass=struct_typeclass, new=new)
exec(source, glbs)
ctor = glbs['ctor']
# Make it an overload
overload(py_class)(ctor)
def define_proxy(py_class, struct_typeclass, fields):
"""Defines a PyObject proxy for a structref.
This makes `py_class` a valid constructor for creating a instance of
`struct_typeclass` that contains the members as defined by `fields`.
Parameters
----------
py_class : type
The Python class for constructing an instance of `struct_typeclass`.
struct_typeclass : numba.core.types.Type
The structref type class to bind to.
fields : Sequence[str]
A sequence of field names.
Returns
-------
None
"""
define_constructor(py_class, struct_typeclass, fields)
define_boxing(struct_typeclass, py_class)
def register(struct_type):
"""Register a `numba.core.types.StructRef` for use in jit-code.
This defines the data-model for lowering an instance of `struct_type`.
This defines attributes accessor and mutator for an instance of
`struct_type`.
Parameters
----------
struct_type : type
A subclass of `numba.core.types.StructRef`.
Returns
-------
struct_type : type
Returns the input argument so this can act like a decorator.
Examples
--------
.. code-block::
class MyStruct(numba.core.types.StructRef):
... # the simplest subclass can be empty
numba.experimental.structref.register(MyStruct)
"""
if struct_type is types.StructRef:
raise ValueError(f"cannot register {types.StructRef}")
default_manager.register(struct_type, models.StructRefModel)
define_attributes(struct_type)
return struct_type
@intrinsic
def new(typingctx, struct_type):
"""new(struct_type)
A jit-code only intrinsic. Used to allocate an **empty** mutable struct.
The fields are zero-initialized and must be set manually after calling
the function.
Example:
instance = new(MyStruct)
instance.field = field_value
"""
from numba.experimental.jitclass.base import imp_dtor
inst_type = struct_type.instance_type
def codegen(context, builder, signature, args):
# FIXME: mostly the same as jitclass ctor_impl()
model = context.data_model_manager[inst_type.get_data_type()]
alloc_type = model.get_value_type()
alloc_size = context.get_abi_sizeof(alloc_type)
meminfo = context.nrt.meminfo_alloc_dtor(
builder,
context.get_constant(types.uintp, alloc_size),
imp_dtor(context, builder.module, inst_type),
)
data_pointer = context.nrt.meminfo_data(builder, meminfo)
data_pointer = builder.bitcast(data_pointer, alloc_type.as_pointer())
# Nullify all data
builder.store(cgutils.get_null_value(alloc_type), data_pointer)
inst_struct = context.make_helper(builder, inst_type)
inst_struct.meminfo = meminfo
return inst_struct._getvalue()
sig = inst_type(struct_type)
return sig, codegen
class StructRefProxy:
"""A PyObject proxy to the Numba allocated structref data structure.
Notes
-----
* Subclasses should not define ``__init__``.
* Subclasses can override ``__new__``.
"""
__slots__ = ('_type', '_meminfo')
@classmethod
def _numba_box_(cls, ty, mi):
"""Called by boxing logic, the conversion of Numba internal
representation into a PyObject.
Parameters
----------
ty :
a Numba type instance.
mi :
a wrapped MemInfoPointer.
Returns
-------
instance :
a StructRefProxy instance.
"""
instance = super().__new__(cls)
instance._type = ty
instance._meminfo = mi
return instance
def __new__(cls, *args):
"""Construct a new instance of the structref.
This takes positional-arguments only due to limitation of the compiler.
The arguments are mapped to ``cls(*args)`` in jit-code.
"""
try:
# use cached ctor if available
ctor = cls.__numba_ctor
except AttributeError:
# lazily define the ctor
@njit
def ctor(*args):
return cls(*args)
# cache it to attribute to avoid recompilation
cls.__numba_ctor = ctor
return ctor(*args)
@property
def _numba_type_(self):
"""Returns the Numba type instance for this structref instance.
Subclasses should NOT override.
"""
return self._type
@lower_builtin(operator.is_, types.StructRef, types.StructRef)
def structref_is(context, builder, sig, args):
"""
Define the 'is' operator for structrefs by comparing the memory addresses.
This is the identity check for structref objects.
"""
a, b = args
aty, bty = sig.args
a_ptr = create_struct_proxy(aty)(context, builder, value=a).meminfo
b_ptr = create_struct_proxy(bty)(context, builder, value=b).meminfo
return builder.icmp_unsigned("==", a_ptr, b_ptr)