Videre
This commit is contained in:
@@ -0,0 +1 @@
|
||||
from .jitclass import jitclass
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,332 @@
|
||||
"""Provides Numba type, FunctionType, that makes functions as
|
||||
instances of a first-class function type.
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
from numba.extending import typeof_impl
|
||||
from numba.extending import models, register_model
|
||||
from numba.extending import unbox, NativeValue, box
|
||||
from numba.core.imputils import lower_constant, lower_cast
|
||||
from numba.core.ccallback import CFunc
|
||||
from numba.core import cgutils
|
||||
from llvmlite import ir
|
||||
from numba.core import types, errors
|
||||
from numba.core.types import (FunctionType, UndefinedFunctionType,
|
||||
FunctionPrototype, WrapperAddressProtocol)
|
||||
from numba.core.dispatcher import Dispatcher
|
||||
|
||||
|
||||
@typeof_impl.register(WrapperAddressProtocol)
|
||||
@typeof_impl.register(CFunc)
|
||||
def typeof_function_type(val, c):
|
||||
if isinstance(val, CFunc):
|
||||
sig = val._sig
|
||||
elif isinstance(val, WrapperAddressProtocol):
|
||||
sig = val.signature()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'function type from {type(val).__name__}')
|
||||
return FunctionType(sig)
|
||||
|
||||
|
||||
@register_model(FunctionPrototype)
|
||||
class FunctionProtoModel(models.PrimitiveModel):
|
||||
"""FunctionProtoModel describes the signatures of first-class functions
|
||||
"""
|
||||
def __init__(self, dmm, fe_type):
|
||||
if isinstance(fe_type, FunctionType):
|
||||
ftype = fe_type.ftype
|
||||
elif isinstance(fe_type, FunctionPrototype):
|
||||
ftype = fe_type
|
||||
else:
|
||||
raise NotImplementedError((type(fe_type)))
|
||||
retty = dmm.lookup(ftype.rtype).get_value_type()
|
||||
args = [dmm.lookup(t).get_value_type() for t in ftype.atypes]
|
||||
be_type = ir.PointerType(ir.FunctionType(retty, args))
|
||||
super(FunctionProtoModel, self).__init__(dmm, fe_type, be_type)
|
||||
|
||||
|
||||
@register_model(FunctionType)
|
||||
@register_model(UndefinedFunctionType)
|
||||
class FunctionModel(models.StructModel):
|
||||
"""FunctionModel holds addresses of function implementations
|
||||
"""
|
||||
def __init__(self, dmm, fe_type):
|
||||
members = [
|
||||
# Address of cfunc wrapper function.
|
||||
# This uses a C callconv and doesn't not support exceptions.
|
||||
('c_addr', types.voidptr),
|
||||
# Address of PyObject* referencing the Python function
|
||||
# object:
|
||||
('py_addr', types.voidptr),
|
||||
# Address of the underlying function object.
|
||||
# Calling through this function pointer supports all features of
|
||||
# regular numba function as it follows the same Numba callconv.
|
||||
('jit_addr', types.voidptr),
|
||||
]
|
||||
super(FunctionModel, self).__init__(dmm, fe_type, members)
|
||||
|
||||
|
||||
@lower_constant(types.Dispatcher)
|
||||
def lower_constant_dispatcher(context, builder, typ, pyval):
|
||||
return context.add_dynamic_addr(builder, id(pyval),
|
||||
info=type(pyval).__name__)
|
||||
|
||||
|
||||
@lower_constant(FunctionType)
|
||||
def lower_constant_function_type(context, builder, typ, pyval):
|
||||
typ = typ.get_precise()
|
||||
|
||||
if isinstance(pyval, CFunc):
|
||||
addr = pyval._wrapper_address
|
||||
sfunc = cgutils.create_struct_proxy(typ)(context, builder)
|
||||
sfunc.c_addr = context.add_dynamic_addr(builder, addr,
|
||||
info=str(typ))
|
||||
sfunc.py_addr = context.add_dynamic_addr(builder, id(pyval),
|
||||
info=type(pyval).__name__)
|
||||
return sfunc._getvalue()
|
||||
|
||||
if isinstance(pyval, Dispatcher):
|
||||
sfunc = cgutils.create_struct_proxy(typ)(context, builder)
|
||||
sfunc.py_addr = context.add_dynamic_addr(builder, id(pyval),
|
||||
info=type(pyval).__name__)
|
||||
return sfunc._getvalue()
|
||||
|
||||
if isinstance(pyval, WrapperAddressProtocol):
|
||||
addr = pyval.__wrapper_address__()
|
||||
assert typ.check_signature(pyval.signature())
|
||||
sfunc = cgutils.create_struct_proxy(typ)(context, builder)
|
||||
sfunc.c_addr = context.add_dynamic_addr(builder, addr,
|
||||
info=str(typ))
|
||||
sfunc.py_addr = context.add_dynamic_addr(builder, id(pyval),
|
||||
info=type(pyval).__name__)
|
||||
return sfunc._getvalue()
|
||||
|
||||
# TODO: implement support for pytypes.FunctionType, ctypes.CFUNCTYPE
|
||||
raise NotImplementedError(
|
||||
'lower_constant_struct_function_type({}, {}, {}, {})'
|
||||
.format(context, builder, typ, pyval))
|
||||
|
||||
|
||||
def _get_wrapper_address(func, sig):
|
||||
"""Return the address of a compiled cfunc wrapper function of `func`.
|
||||
|
||||
Warning: The compiled function must be compatible with the given
|
||||
signature `sig`. If it is not, then result of calling the compiled
|
||||
function is undefined. The compatibility is ensured when passing
|
||||
in a first-class function to a Numba njit compiled function either
|
||||
as an argument or via namespace scoping.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : object
|
||||
A Numba cfunc or jit decoreated function or an object that
|
||||
implements the wrapper address protocol (see note below).
|
||||
sig : Signature
|
||||
The expected function signature.
|
||||
|
||||
Returns
|
||||
-------
|
||||
addr : int
|
||||
An address in memory (pointer value) of the compiled function
|
||||
corresponding to the specified signature.
|
||||
|
||||
Note: wrapper address protocol
|
||||
------------------------------
|
||||
|
||||
An object implements the wrapper address protocol iff the object
|
||||
provides a callable attribute named __wrapper_address__ that takes
|
||||
a Signature instance as the argument, and returns an integer
|
||||
representing the address or pointer value of a compiled function
|
||||
for the given signature.
|
||||
|
||||
"""
|
||||
if not sig.is_precise():
|
||||
# addr==-1 will indicate that no implementation is available
|
||||
# for cases where type-inference did not identified the
|
||||
# function type. For example, the type of an unused
|
||||
# jit-decorated function argument will be undefined but also
|
||||
# irrelevant.
|
||||
addr = -1
|
||||
elif hasattr(func, '__wrapper_address__'):
|
||||
# func can be any object that implements the
|
||||
# __wrapper_address__ protocol.
|
||||
addr = func.__wrapper_address__()
|
||||
elif isinstance(func, CFunc):
|
||||
assert sig == func._sig
|
||||
addr = func.address
|
||||
elif isinstance(func, Dispatcher):
|
||||
cres = func.get_compile_result(sig)
|
||||
wrapper_name = cres.fndesc.llvm_cfunc_wrapper_name
|
||||
addr = cres.library.get_pointer_to_function(wrapper_name)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'get wrapper address of {type(func)} instance with {sig!r}')
|
||||
if not isinstance(addr, int):
|
||||
raise TypeError(
|
||||
f'wrapper address must be integer, got {type(addr)} instance')
|
||||
if addr <= 0 and addr != -1:
|
||||
raise ValueError(f'wrapper address of {type(func)} instance must be'
|
||||
f' a positive integer but got {addr} [sig={sig}]')
|
||||
# print(f'_get_wrapper_address[{func}]({sig=}) -> {addr}')
|
||||
return addr
|
||||
|
||||
|
||||
def _get_jit_address(func, sig):
|
||||
"""Similar to ``_get_wrapper_address()`` but get the `.jit_addr` instead.
|
||||
"""
|
||||
if isinstance(func, Dispatcher):
|
||||
cres = func.get_compile_result(sig)
|
||||
jit_name = cres.fndesc.llvm_func_name
|
||||
addr = cres.library.get_pointer_to_function(jit_name)
|
||||
else:
|
||||
addr = 0
|
||||
if not isinstance(addr, int):
|
||||
raise TypeError(
|
||||
f'jit address must be integer, got {type(addr)} instance')
|
||||
return addr
|
||||
|
||||
|
||||
def _lower_get_address(context, builder, func, sig, failure_mode,
|
||||
*, function_name):
|
||||
"""Low-level call to <function_name>(func, sig).
|
||||
|
||||
When calling this function, GIL must be acquired.
|
||||
"""
|
||||
pyapi = context.get_python_api(builder)
|
||||
|
||||
# Get the cfunc wrapper address. The code below trusts that the
|
||||
# function numba.function._get_wrapper_address exists and can be
|
||||
# called with two arguments. However, if an exception is raised in
|
||||
# the function, then it will be caught and propagated to the
|
||||
# caller.
|
||||
|
||||
modname = context.insert_const_string(builder.module, __name__)
|
||||
numba_mod = pyapi.import_module(modname)
|
||||
numba_func = pyapi.object_getattr_string(numba_mod, function_name)
|
||||
pyapi.decref(numba_mod)
|
||||
sig_obj = pyapi.unserialize(pyapi.serialize_object(sig))
|
||||
|
||||
addr = pyapi.call_function_objargs(numba_func, (func, sig_obj))
|
||||
|
||||
if failure_mode != 'ignore':
|
||||
with builder.if_then(cgutils.is_null(builder, addr), likely=False):
|
||||
# *function_name* has raised an exception, propagate it
|
||||
# to the caller.
|
||||
if failure_mode == 'return_exc':
|
||||
context.call_conv.return_exc(builder)
|
||||
elif failure_mode == 'return_null':
|
||||
builder.ret(pyapi.get_null_object())
|
||||
else:
|
||||
raise NotImplementedError(failure_mode)
|
||||
# else the caller will handle addr == NULL
|
||||
return addr # new reference or NULL
|
||||
|
||||
|
||||
lower_get_wrapper_address = partial(
|
||||
_lower_get_address,
|
||||
function_name="_get_wrapper_address",
|
||||
)
|
||||
|
||||
|
||||
lower_get_jit_address = partial(
|
||||
_lower_get_address,
|
||||
function_name="_get_jit_address",
|
||||
)
|
||||
|
||||
|
||||
@unbox(FunctionType)
|
||||
def unbox_function_type(typ, obj, c):
|
||||
typ = typ.get_precise()
|
||||
|
||||
sfunc = cgutils.create_struct_proxy(typ)(c.context, c.builder)
|
||||
|
||||
addr = lower_get_wrapper_address(
|
||||
c.context, c.builder, obj, typ.signature, failure_mode='return_null')
|
||||
sfunc.c_addr = c.pyapi.long_as_voidptr(addr)
|
||||
c.pyapi.decref(addr)
|
||||
|
||||
llty = c.context.get_value_type(types.voidptr)
|
||||
sfunc.py_addr = c.builder.ptrtoint(obj, llty)
|
||||
|
||||
addr = lower_get_jit_address(
|
||||
c.context, c.builder, obj, typ.signature, failure_mode='return_null')
|
||||
sfunc.jit_addr = c.pyapi.long_as_voidptr(addr)
|
||||
c.pyapi.decref(addr)
|
||||
|
||||
return NativeValue(sfunc._getvalue())
|
||||
|
||||
|
||||
@box(FunctionType)
|
||||
def box_function_type(typ, val, c):
|
||||
typ = typ.get_precise()
|
||||
|
||||
sfunc = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
|
||||
pyaddr_ptr = cgutils.alloca_once(c.builder, c.pyapi.pyobj)
|
||||
raw_ptr = c.builder.inttoptr(sfunc.py_addr, c.pyapi.pyobj)
|
||||
with c.builder.if_then(cgutils.is_null(c.builder, raw_ptr),
|
||||
likely=False):
|
||||
cstr = f"first-class function {typ} parent object not set"
|
||||
c.pyapi.err_set_string("PyExc_MemoryError", cstr)
|
||||
c.builder.ret(c.pyapi.get_null_object())
|
||||
c.builder.store(raw_ptr, pyaddr_ptr)
|
||||
cfunc = c.builder.load(pyaddr_ptr)
|
||||
c.pyapi.incref(cfunc)
|
||||
return cfunc
|
||||
|
||||
|
||||
@lower_cast(UndefinedFunctionType, FunctionType)
|
||||
def lower_cast_function_type_to_function_type(
|
||||
context, builder, fromty, toty, val):
|
||||
return val
|
||||
|
||||
|
||||
@lower_cast(types.Dispatcher, FunctionType)
|
||||
def lower_cast_dispatcher_to_function_type(context, builder, fromty, toty, val):
|
||||
toty = toty.get_precise()
|
||||
|
||||
sig = toty.signature
|
||||
dispatcher = fromty.dispatcher
|
||||
llvoidptr = context.get_value_type(types.voidptr)
|
||||
sfunc = cgutils.create_struct_proxy(toty)(context, builder)
|
||||
# Always store the python function
|
||||
sfunc.py_addr = builder.ptrtoint(val, llvoidptr)
|
||||
|
||||
# Attempt to compile the Dispatcher to the expected function type
|
||||
try:
|
||||
cres = dispatcher.get_compile_result(sig)
|
||||
except errors.NumbaError:
|
||||
cres = None
|
||||
|
||||
# If compilation is successful, we can by-pass using GIL to get the cfunc
|
||||
if cres is not None:
|
||||
# Declare cfunc in the current module
|
||||
wrapper_name = cres.fndesc.llvm_cfunc_wrapper_name
|
||||
llfnptr = context.get_value_type(toty.ftype)
|
||||
llfnty = llfnptr.pointee
|
||||
fn = cgutils.get_or_insert_function(
|
||||
builder.module, llfnty, wrapper_name,
|
||||
)
|
||||
addr = builder.bitcast(fn, llvoidptr)
|
||||
# Store the cfunc
|
||||
sfunc.c_addr = addr
|
||||
# Store the jit func
|
||||
fn = context.declare_function(builder.module, cres.fndesc)
|
||||
sfunc.jit_addr = builder.bitcast(fn, llvoidptr)
|
||||
# Link-in the dispatcher library
|
||||
context.active_code_library.add_linking_library(cres.library)
|
||||
|
||||
else:
|
||||
# Use lower_get_wrapper_address() to get the cfunc
|
||||
lower_get_wrapper_address
|
||||
pyapi = context.get_python_api(builder)
|
||||
|
||||
gil_state = pyapi.gil_ensure()
|
||||
addr = lower_get_wrapper_address(
|
||||
context, builder, val, toty.signature,
|
||||
failure_mode='return_exc')
|
||||
sfunc.c_addr = pyapi.long_as_voidptr(addr)
|
||||
pyapi.decref(addr)
|
||||
pyapi.gil_release(gil_state)
|
||||
|
||||
return sfunc._getvalue()
|
||||
@@ -0,0 +1,3 @@
|
||||
from numba.experimental.jitclass.decorators import jitclass
|
||||
from numba.experimental.jitclass import boxing # Has import-time side effect
|
||||
from numba.experimental.jitclass import overloads # Has import-time side effect
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,598 @@
|
||||
import inspect
|
||||
import operator
|
||||
import types as pytypes
|
||||
import typing as pt
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Sequence
|
||||
|
||||
from llvmlite import ir as llvmir
|
||||
from numba import njit
|
||||
from numba.core import cgutils, errors, imputils, types, utils
|
||||
from numba.core.datamodel import default_manager, models
|
||||
from numba.core.registry import cpu_target
|
||||
from numba.core.typing import templates
|
||||
from numba.core.typing.asnumbatype import as_numba_type
|
||||
from numba.core.serialize import disable_pickling
|
||||
from numba.experimental.jitclass import _box
|
||||
|
||||
##############################################################################
|
||||
# Data model
|
||||
|
||||
|
||||
class InstanceModel(models.StructModel):
|
||||
def __init__(self, dmm, fe_typ):
|
||||
cls_data_ty = types.ClassDataType(fe_typ)
|
||||
# MemInfoPointer uses the `dtype` attribute to traverse for nested
|
||||
# NRT MemInfo. Since we handle nested NRT MemInfo ourselves,
|
||||
# we will replace provide MemInfoPointer with an opaque type
|
||||
# so that it does not raise exception for nested meminfo.
|
||||
dtype = types.Opaque('Opaque.' + str(cls_data_ty))
|
||||
members = [
|
||||
('meminfo', types.MemInfoPointer(dtype)),
|
||||
('data', types.CPointer(cls_data_ty)),
|
||||
]
|
||||
super(InstanceModel, self).__init__(dmm, fe_typ, members)
|
||||
|
||||
|
||||
class InstanceDataModel(models.StructModel):
|
||||
def __init__(self, dmm, fe_typ):
|
||||
clsty = fe_typ.class_type
|
||||
members = [(_mangle_attr(k), v) for k, v in clsty.struct.items()]
|
||||
super(InstanceDataModel, self).__init__(dmm, fe_typ, members)
|
||||
|
||||
|
||||
default_manager.register(types.ClassInstanceType, InstanceModel)
|
||||
default_manager.register(types.ClassDataType, InstanceDataModel)
|
||||
default_manager.register(types.ClassType, models.OpaqueModel)
|
||||
|
||||
|
||||
def _mangle_attr(name):
|
||||
"""
|
||||
Mangle attributes.
|
||||
The resulting name does not startswith an underscore '_'.
|
||||
"""
|
||||
return 'm_' + name
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Class object
|
||||
|
||||
_ctor_template = """
|
||||
def ctor({args}):
|
||||
return __numba_cls_({args})
|
||||
"""
|
||||
|
||||
|
||||
def _getargs(fn_sig):
|
||||
"""
|
||||
Returns list of positional and keyword argument names in order.
|
||||
"""
|
||||
params = fn_sig.parameters
|
||||
args = []
|
||||
for k, v in params.items():
|
||||
if (v.kind & v.POSITIONAL_OR_KEYWORD) == v.POSITIONAL_OR_KEYWORD:
|
||||
args.append(k)
|
||||
else:
|
||||
msg = "%s argument type unsupported in jitclass" % v.kind
|
||||
raise errors.UnsupportedError(msg)
|
||||
return args
|
||||
|
||||
|
||||
@disable_pickling
|
||||
class JitClassType(type):
|
||||
"""
|
||||
The type of any jitclass.
|
||||
"""
|
||||
def __new__(cls, name, bases, dct):
|
||||
if len(bases) != 1:
|
||||
raise TypeError("must have exactly one base class")
|
||||
[base] = bases
|
||||
if isinstance(base, JitClassType):
|
||||
raise TypeError("cannot subclass from a jitclass")
|
||||
assert 'class_type' in dct, 'missing "class_type" attr'
|
||||
outcls = type.__new__(cls, name, bases, dct)
|
||||
outcls._set_init()
|
||||
return outcls
|
||||
|
||||
def _set_init(cls):
|
||||
"""
|
||||
Generate a wrapper for calling the constructor from pure Python.
|
||||
Note the wrapper will only accept positional arguments.
|
||||
"""
|
||||
init = cls.class_type.instance_type.methods['__init__']
|
||||
init_sig = utils.pysignature(init)
|
||||
# get postitional and keyword arguments
|
||||
# offset by one to exclude the `self` arg
|
||||
args = _getargs(init_sig)[1:]
|
||||
cls._ctor_sig = init_sig
|
||||
ctor_source = _ctor_template.format(args=', '.join(args))
|
||||
glbls = {"__numba_cls_": cls}
|
||||
exec(ctor_source, glbls)
|
||||
ctor = glbls['ctor']
|
||||
cls._ctor = njit(ctor)
|
||||
|
||||
def __instancecheck__(cls, instance):
|
||||
if isinstance(instance, _box.Box):
|
||||
return instance._numba_type_.class_type is cls.class_type
|
||||
return False
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
# The first argument of _ctor_sig is `cls`, which here
|
||||
# is bound to None and then skipped when invoking the constructor.
|
||||
bind = cls._ctor_sig.bind(None, *args, **kwargs)
|
||||
bind.apply_defaults()
|
||||
return cls._ctor(*bind.args[1:], **bind.kwargs)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Registration utils
|
||||
|
||||
def _validate_spec(spec):
|
||||
for k, v in spec.items():
|
||||
if not isinstance(k, str):
|
||||
raise TypeError("spec keys should be strings, got %r" % (k,))
|
||||
if not isinstance(v, types.Type):
|
||||
raise TypeError("spec values should be Numba type instances, got %r"
|
||||
% (v,))
|
||||
|
||||
|
||||
def _fix_up_private_attr(clsname, spec):
|
||||
"""
|
||||
Apply the same changes to dunder names as CPython would.
|
||||
"""
|
||||
out = OrderedDict()
|
||||
for k, v in spec.items():
|
||||
if k.startswith('__') and not k.endswith('__'):
|
||||
k = '_' + clsname + k
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
|
||||
def _add_linking_libs(context, call):
|
||||
"""
|
||||
Add the required libs for the callable to allow inlining.
|
||||
"""
|
||||
libs = getattr(call, "libs", ())
|
||||
if libs:
|
||||
context.add_linking_libs(libs)
|
||||
|
||||
|
||||
def register_class_type(cls, spec, class_ctor, builder):
|
||||
"""
|
||||
Internal function to create a jitclass.
|
||||
|
||||
Args
|
||||
----
|
||||
cls: the original class object (used as the prototype)
|
||||
spec: the structural specification contains the field types.
|
||||
class_ctor: the numba type to represent the jitclass
|
||||
builder: the internal jitclass builder
|
||||
"""
|
||||
# Normalize spec
|
||||
if spec is None:
|
||||
spec = OrderedDict()
|
||||
elif isinstance(spec, Sequence):
|
||||
spec = OrderedDict(spec)
|
||||
|
||||
# Extend spec with class annotations.
|
||||
for attr, py_type in pt.get_type_hints(cls).items():
|
||||
if attr not in spec:
|
||||
spec[attr] = as_numba_type(py_type)
|
||||
|
||||
_validate_spec(spec)
|
||||
|
||||
# Fix up private attribute names
|
||||
spec = _fix_up_private_attr(cls.__name__, spec)
|
||||
|
||||
# Copy methods from base classes
|
||||
clsdct = {}
|
||||
for basecls in reversed(inspect.getmro(cls)):
|
||||
clsdct.update(basecls.__dict__)
|
||||
|
||||
methods, props, static_methods, others = {}, {}, {}, {}
|
||||
for k, v in clsdct.items():
|
||||
if isinstance(v, pytypes.FunctionType):
|
||||
methods[k] = v
|
||||
elif isinstance(v, property):
|
||||
props[k] = v
|
||||
elif isinstance(v, staticmethod):
|
||||
static_methods[k] = v
|
||||
else:
|
||||
others[k] = v
|
||||
|
||||
# Check for name shadowing
|
||||
shadowed = (set(methods) | set(props) | set(static_methods)) & set(spec)
|
||||
if shadowed:
|
||||
raise NameError("name shadowing: {0}".format(', '.join(shadowed)))
|
||||
|
||||
docstring = others.pop('__doc__', "")
|
||||
_drop_ignored_attrs(others)
|
||||
if others:
|
||||
msg = "class members are not yet supported: {0}"
|
||||
members = ', '.join(others.keys())
|
||||
raise TypeError(msg.format(members))
|
||||
|
||||
for k, v in props.items():
|
||||
if v.fdel is not None:
|
||||
raise TypeError("deleter is not supported: {0}".format(k))
|
||||
|
||||
jit_methods = {k: njit(v) for k, v in methods.items()}
|
||||
|
||||
jit_props = {}
|
||||
for k, v in props.items():
|
||||
dct = {}
|
||||
if v.fget:
|
||||
dct['get'] = njit(v.fget)
|
||||
if v.fset:
|
||||
dct['set'] = njit(v.fset)
|
||||
jit_props[k] = dct
|
||||
|
||||
jit_static_methods = {
|
||||
k: njit(v.__func__) for k, v in static_methods.items()}
|
||||
|
||||
# Instantiate class type
|
||||
class_type = class_ctor(
|
||||
cls,
|
||||
ConstructorTemplate,
|
||||
spec,
|
||||
jit_methods,
|
||||
jit_props,
|
||||
jit_static_methods)
|
||||
|
||||
jit_class_dct = dict(class_type=class_type, __doc__=docstring)
|
||||
jit_class_dct.update(jit_static_methods)
|
||||
cls = JitClassType(cls.__name__, (cls,), jit_class_dct)
|
||||
|
||||
# Register resolution of the class object
|
||||
typingctx = cpu_target.typing_context
|
||||
typingctx.insert_global(cls, class_type)
|
||||
|
||||
# Register class
|
||||
targetctx = cpu_target.target_context
|
||||
builder(class_type, typingctx, targetctx).register()
|
||||
as_numba_type.register(cls, class_type.instance_type)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
class ConstructorTemplate(templates.AbstractTemplate):
|
||||
"""
|
||||
Base class for jitclass constructor templates.
|
||||
"""
|
||||
|
||||
def generic(self, args, kws):
|
||||
# Redirect resolution to __init__
|
||||
instance_type = self.key.instance_type
|
||||
ctor = instance_type.jit_methods['__init__']
|
||||
boundargs = (instance_type.get_reference_type(),) + args
|
||||
disp_type = types.Dispatcher(ctor)
|
||||
sig = disp_type.get_call_type(self.context, boundargs, kws)
|
||||
|
||||
if not isinstance(sig.return_type, types.NoneType):
|
||||
raise errors.NumbaTypeError(
|
||||
f"__init__() should return None, not '{sig.return_type}'")
|
||||
|
||||
# Actual constructor returns an instance value (not None)
|
||||
out = templates.signature(instance_type, *sig.args[1:])
|
||||
return out
|
||||
|
||||
|
||||
def _drop_ignored_attrs(dct):
|
||||
# ignore anything defined by object
|
||||
drop = set(['__weakref__',
|
||||
'__module__',
|
||||
'__dict__'])
|
||||
if utils.PYVERSION in ((3, 13), (3, 14)):
|
||||
# new in python 3.13
|
||||
drop |= set(['__firstlineno__', '__static_attributes__'])
|
||||
|
||||
for att in ('__annotations__',
|
||||
'__annotate_func__',
|
||||
'__annotations_cache__'):
|
||||
if att in dct:
|
||||
drop.add(att)
|
||||
|
||||
for k, v in dct.items():
|
||||
if isinstance(v, (pytypes.BuiltinFunctionType,
|
||||
pytypes.BuiltinMethodType)):
|
||||
drop.add(k)
|
||||
elif getattr(v, '__objclass__', None) is object:
|
||||
drop.add(k)
|
||||
|
||||
# If a class defines __eq__ but not __hash__, __hash__ is implicitly set to
|
||||
# None. This is a class member, and class members are not presently
|
||||
# supported.
|
||||
if '__hash__' in dct and dct['__hash__'] is None:
|
||||
drop.add('__hash__')
|
||||
|
||||
for k in drop:
|
||||
dct.pop(k)
|
||||
|
||||
|
||||
class ClassBuilder(object):
|
||||
"""
|
||||
A jitclass builder for a mutable jitclass. This will register
|
||||
typing and implementation hooks to the given typing and target contexts.
|
||||
"""
|
||||
class_impl_registry = imputils.Registry('jitclass builder')
|
||||
implemented_methods = set()
|
||||
|
||||
def __init__(self, class_type, typingctx, targetctx):
|
||||
self.class_type = class_type
|
||||
self.typingctx = typingctx
|
||||
self.targetctx = targetctx
|
||||
|
||||
def register(self):
|
||||
"""
|
||||
Register to the frontend and backend.
|
||||
"""
|
||||
# Register generic implementations for all jitclasses
|
||||
self._register_methods(self.class_impl_registry,
|
||||
self.class_type.instance_type)
|
||||
# NOTE other registrations are done at the top-level
|
||||
# (see ctor_impl and attr_impl below)
|
||||
self.targetctx.install_registry(self.class_impl_registry)
|
||||
|
||||
def _register_methods(self, registry, instance_type):
|
||||
"""
|
||||
Register method implementations.
|
||||
This simply registers that the method names are valid methods. Inside
|
||||
of imp() below we retrieve the actual method to run from the type of
|
||||
the receiver argument (i.e. self).
|
||||
"""
|
||||
to_register = list(instance_type.jit_methods) + \
|
||||
list(instance_type.jit_static_methods)
|
||||
for meth in to_register:
|
||||
|
||||
# There's no way to retrieve the particular method name
|
||||
# inside the implementation function, so we have to register a
|
||||
# specific closure for each different name
|
||||
if meth not in self.implemented_methods:
|
||||
self._implement_method(registry, meth)
|
||||
self.implemented_methods.add(meth)
|
||||
|
||||
def _implement_method(self, registry, attr):
|
||||
# create a separate instance of imp method to avoid closure clashing
|
||||
def get_imp():
|
||||
def imp(context, builder, sig, args):
|
||||
instance_type = sig.args[0]
|
||||
|
||||
if attr in instance_type.jit_methods:
|
||||
method = instance_type.jit_methods[attr]
|
||||
elif attr in instance_type.jit_static_methods:
|
||||
method = instance_type.jit_static_methods[attr]
|
||||
# imp gets called as a method, where the first argument is
|
||||
# self. We drop this for a static method.
|
||||
sig = sig.replace(args=sig.args[1:])
|
||||
args = args[1:]
|
||||
|
||||
disp_type = types.Dispatcher(method)
|
||||
call = context.get_function(disp_type, sig)
|
||||
out = call(builder, args)
|
||||
_add_linking_libs(context, call)
|
||||
return imputils.impl_ret_new_ref(context, builder,
|
||||
sig.return_type, out)
|
||||
return imp
|
||||
|
||||
def _getsetitem_gen(getset):
|
||||
_dunder_meth = "__%s__" % getset
|
||||
op = getattr(operator, getset)
|
||||
|
||||
@templates.infer_global(op)
|
||||
class GetSetItem(templates.AbstractTemplate):
|
||||
def generic(self, args, kws):
|
||||
instance = args[0]
|
||||
if isinstance(instance, types.ClassInstanceType) and \
|
||||
_dunder_meth in instance.jit_methods:
|
||||
meth = instance.jit_methods[_dunder_meth]
|
||||
disp_type = types.Dispatcher(meth)
|
||||
sig = disp_type.get_call_type(self.context, args, kws)
|
||||
return sig
|
||||
|
||||
# lower both {g,s}etitem and __{g,s}etitem__ to catch the calls
|
||||
# from python and numba
|
||||
imputils.lower_builtin((types.ClassInstanceType, _dunder_meth),
|
||||
types.ClassInstanceType,
|
||||
types.VarArg(types.Any))(get_imp())
|
||||
imputils.lower_builtin(op,
|
||||
types.ClassInstanceType,
|
||||
types.VarArg(types.Any))(get_imp())
|
||||
|
||||
dunder_stripped = attr.strip('_')
|
||||
if dunder_stripped in ("getitem", "setitem"):
|
||||
_getsetitem_gen(dunder_stripped)
|
||||
else:
|
||||
registry.lower((types.ClassInstanceType, attr),
|
||||
types.ClassInstanceType,
|
||||
types.VarArg(types.Any))(get_imp())
|
||||
|
||||
|
||||
@templates.infer_getattr
|
||||
class ClassAttribute(templates.AttributeTemplate):
|
||||
key = types.ClassInstanceType
|
||||
|
||||
def generic_resolve(self, instance, attr):
|
||||
if attr in instance.struct:
|
||||
# It's a struct field => the type is well-known
|
||||
return instance.struct[attr]
|
||||
|
||||
elif attr in instance.jit_methods:
|
||||
# It's a jitted method => typeinfer it
|
||||
meth = instance.jit_methods[attr]
|
||||
disp_type = types.Dispatcher(meth)
|
||||
|
||||
class MethodTemplate(templates.AbstractTemplate):
|
||||
key = (self.key, attr)
|
||||
|
||||
def generic(self, args, kws):
|
||||
args = (instance,) + tuple(args)
|
||||
sig = disp_type.get_call_type(self.context, args, kws)
|
||||
return sig.as_method()
|
||||
|
||||
return types.BoundFunction(MethodTemplate, instance)
|
||||
|
||||
elif attr in instance.jit_static_methods:
|
||||
# It's a jitted method => typeinfer it
|
||||
meth = instance.jit_static_methods[attr]
|
||||
disp_type = types.Dispatcher(meth)
|
||||
|
||||
class StaticMethodTemplate(templates.AbstractTemplate):
|
||||
key = (self.key, attr)
|
||||
|
||||
def generic(self, args, kws):
|
||||
# Don't add instance as the first argument for a static
|
||||
# method.
|
||||
sig = disp_type.get_call_type(self.context, args, kws)
|
||||
# sig itself does not include ClassInstanceType as it's
|
||||
# first argument, so instead of calling sig.as_method()
|
||||
# we insert the recvr. This is equivalent to
|
||||
# sig.replace(args=(instance,) + sig.args).as_method().
|
||||
return sig.replace(recvr=instance)
|
||||
|
||||
return types.BoundFunction(StaticMethodTemplate, instance)
|
||||
|
||||
elif attr in instance.jit_props:
|
||||
# It's a jitted property => typeinfer its getter
|
||||
impdct = instance.jit_props[attr]
|
||||
getter = impdct['get']
|
||||
disp_type = types.Dispatcher(getter)
|
||||
sig = disp_type.get_call_type(self.context, (instance,), {})
|
||||
return sig.return_type
|
||||
|
||||
|
||||
@ClassBuilder.class_impl_registry.lower_getattr_generic(types.ClassInstanceType)
|
||||
def get_attr_impl(context, builder, typ, value, attr):
|
||||
"""
|
||||
Generic getattr() for @jitclass instances.
|
||||
"""
|
||||
if attr in typ.struct:
|
||||
# It's a struct field
|
||||
inst = context.make_helper(builder, typ, value=value)
|
||||
data_pointer = inst.data
|
||||
data = context.make_data_helper(builder, typ.get_data_type(),
|
||||
ref=data_pointer)
|
||||
return imputils.impl_ret_borrowed(context, builder,
|
||||
typ.struct[attr],
|
||||
getattr(data, _mangle_attr(attr)))
|
||||
elif attr in typ.jit_props:
|
||||
# It's a jitted property
|
||||
getter = typ.jit_props[attr]['get']
|
||||
sig = templates.signature(None, typ)
|
||||
dispatcher = types.Dispatcher(getter)
|
||||
sig = dispatcher.get_call_type(context.typing_context, [typ], {})
|
||||
call = context.get_function(dispatcher, sig)
|
||||
out = call(builder, [value])
|
||||
_add_linking_libs(context, call)
|
||||
return imputils.impl_ret_new_ref(context, builder, sig.return_type, out)
|
||||
|
||||
raise NotImplementedError('attribute {0!r} not implemented'.format(attr))
|
||||
|
||||
|
||||
@ClassBuilder.class_impl_registry.lower_setattr_generic(types.ClassInstanceType)
|
||||
def set_attr_impl(context, builder, sig, args, attr):
|
||||
"""
|
||||
Generic setattr() for @jitclass instances.
|
||||
"""
|
||||
typ, valty = sig.args
|
||||
target, val = args
|
||||
|
||||
if attr in typ.struct:
|
||||
# It's a struct member
|
||||
inst = context.make_helper(builder, typ, value=target)
|
||||
data_ptr = inst.data
|
||||
data = context.make_data_helper(builder, typ.get_data_type(),
|
||||
ref=data_ptr)
|
||||
|
||||
# Get old value
|
||||
attr_type = typ.struct[attr]
|
||||
oldvalue = getattr(data, _mangle_attr(attr))
|
||||
|
||||
# Store n
|
||||
setattr(data, _mangle_attr(attr), val)
|
||||
context.nrt.incref(builder, attr_type, val)
|
||||
|
||||
# Delete old value
|
||||
context.nrt.decref(builder, attr_type, oldvalue)
|
||||
|
||||
elif attr in typ.jit_props:
|
||||
# It's a jitted property
|
||||
setter = typ.jit_props[attr]['set']
|
||||
disp_type = types.Dispatcher(setter)
|
||||
sig = disp_type.get_call_type(context.typing_context,
|
||||
(typ, valty), {})
|
||||
call = context.get_function(disp_type, sig)
|
||||
call(builder, (target, val))
|
||||
_add_linking_libs(context, call)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'attribute {0!r} not implemented'.format(attr))
|
||||
|
||||
|
||||
def imp_dtor(context, module, instance_type):
|
||||
llvoidptr = context.get_value_type(types.voidptr)
|
||||
llsize = context.get_value_type(types.uintp)
|
||||
dtor_ftype = llvmir.FunctionType(llvmir.VoidType(),
|
||||
[llvoidptr, llsize, llvoidptr])
|
||||
|
||||
fname = "_Dtor.{0}".format(instance_type.name)
|
||||
dtor_fn = cgutils.get_or_insert_function(module, dtor_ftype, fname)
|
||||
if dtor_fn.is_declaration:
|
||||
# Define
|
||||
builder = llvmir.IRBuilder(dtor_fn.append_basic_block())
|
||||
|
||||
alloc_fe_type = instance_type.get_data_type()
|
||||
alloc_type = context.get_value_type(alloc_fe_type)
|
||||
|
||||
ptr = builder.bitcast(dtor_fn.args[0], alloc_type.as_pointer())
|
||||
data = context.make_helper(builder, alloc_fe_type, ref=ptr)
|
||||
|
||||
context.nrt.decref(builder, alloc_fe_type, data._getvalue())
|
||||
|
||||
builder.ret_void()
|
||||
|
||||
return dtor_fn
|
||||
|
||||
|
||||
@ClassBuilder.class_impl_registry.lower(types.ClassType,
|
||||
types.VarArg(types.Any))
|
||||
def ctor_impl(context, builder, sig, args):
|
||||
"""
|
||||
Generic constructor (__new__) for jitclasses.
|
||||
"""
|
||||
# Allocate the instance
|
||||
inst_typ = sig.return_type
|
||||
alloc_type = context.get_data_type(inst_typ.get_data_type())
|
||||
alloc_size = context.get_abi_sizeof(alloc_type)
|
||||
|
||||
meminfo = context.nrt.meminfo_alloc_dtor(
|
||||
builder,
|
||||
context.get_constant(types.uintp, alloc_size),
|
||||
imp_dtor(context, builder.module, inst_typ),
|
||||
)
|
||||
data_pointer = context.nrt.meminfo_data(builder, meminfo)
|
||||
data_pointer = builder.bitcast(data_pointer,
|
||||
alloc_type.as_pointer())
|
||||
|
||||
# Nullify all data
|
||||
builder.store(cgutils.get_null_value(alloc_type),
|
||||
data_pointer)
|
||||
|
||||
inst_struct = context.make_helper(builder, inst_typ)
|
||||
inst_struct.meminfo = meminfo
|
||||
inst_struct.data = data_pointer
|
||||
|
||||
# Call the jitted __init__
|
||||
# TODO: extract the following into a common util
|
||||
init_sig = (sig.return_type,) + sig.args
|
||||
|
||||
init = inst_typ.jit_methods['__init__']
|
||||
disp_type = types.Dispatcher(init)
|
||||
call = context.get_function(disp_type, types.void(*init_sig))
|
||||
_add_linking_libs(context, call)
|
||||
realargs = [inst_struct._getvalue()] + list(args)
|
||||
call(builder, realargs)
|
||||
|
||||
# Prepare return value
|
||||
ret = inst_struct._getvalue()
|
||||
|
||||
return imputils.impl_ret_new_ref(context, builder, inst_typ, ret)
|
||||
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Implement logic relating to wrapping (box) and unwrapping (unbox) instances
|
||||
of jitclasses for use inside the python interpreter.
|
||||
"""
|
||||
|
||||
from functools import wraps, partial
|
||||
|
||||
from llvmlite import ir
|
||||
|
||||
from numba.core import types, cgutils
|
||||
from numba.core.decorators import njit
|
||||
from numba.core.pythonapi import box, unbox, NativeValue
|
||||
from numba.core.typing.typeof import typeof_impl
|
||||
from numba.experimental.jitclass import _box
|
||||
|
||||
|
||||
_getter_code_template = """
|
||||
def accessor(__numba_self_):
|
||||
return __numba_self_.{0}
|
||||
"""
|
||||
|
||||
_setter_code_template = """
|
||||
def mutator(__numba_self_, __numba_val):
|
||||
__numba_self_.{0} = __numba_val
|
||||
"""
|
||||
|
||||
_method_code_template = """
|
||||
def method(__numba_self_, *args):
|
||||
return __numba_self_.{method}(*args)
|
||||
"""
|
||||
|
||||
|
||||
def _generate_property(field, template, fname):
|
||||
"""
|
||||
Generate simple function that get/set a field of the instance
|
||||
"""
|
||||
source = template.format(field)
|
||||
glbls = {}
|
||||
exec(source, glbls)
|
||||
return njit(glbls[fname])
|
||||
|
||||
|
||||
_generate_getter = partial(_generate_property, template=_getter_code_template,
|
||||
fname='accessor')
|
||||
_generate_setter = partial(_generate_property, template=_setter_code_template,
|
||||
fname='mutator')
|
||||
|
||||
|
||||
def _generate_method(name, func):
|
||||
"""
|
||||
Generate a wrapper for calling a method. Note the wrapper will only
|
||||
accept positional arguments.
|
||||
"""
|
||||
source = _method_code_template.format(method=name)
|
||||
glbls = {}
|
||||
exec(source, glbls)
|
||||
method = njit(glbls['method'])
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
_cache_specialized_box = {}
|
||||
|
||||
|
||||
def _specialize_box(typ):
|
||||
"""
|
||||
Create a subclass of Box that is specialized to the jitclass.
|
||||
|
||||
This function caches the result to avoid code bloat.
|
||||
"""
|
||||
# Check cache
|
||||
if typ in _cache_specialized_box:
|
||||
return _cache_specialized_box[typ]
|
||||
dct = {'__slots__': (),
|
||||
'_numba_type_': typ,
|
||||
'__doc__': typ.class_type.class_doc,
|
||||
}
|
||||
# Inject attributes as class properties
|
||||
for field in typ.struct:
|
||||
getter = _generate_getter(field)
|
||||
setter = _generate_setter(field)
|
||||
dct[field] = property(getter, setter)
|
||||
# Inject properties as class properties
|
||||
for field, impdct in typ.jit_props.items():
|
||||
getter = None
|
||||
setter = None
|
||||
if 'get' in impdct:
|
||||
getter = _generate_getter(field)
|
||||
if 'set' in impdct:
|
||||
setter = _generate_setter(field)
|
||||
# get docstring from either the fget or fset
|
||||
imp = impdct.get('get') or impdct.get('set') or None
|
||||
doc = getattr(imp, '__doc__', None)
|
||||
dct[field] = property(getter, setter, doc=doc)
|
||||
# Inject methods as class members
|
||||
supported_dunders = {
|
||||
"__abs__",
|
||||
"__annotate_func__",
|
||||
"__bool__",
|
||||
"__complex__",
|
||||
"__contains__",
|
||||
"__float__",
|
||||
"__getitem__",
|
||||
"__hash__",
|
||||
"__index__",
|
||||
"__invert__",
|
||||
"__int__",
|
||||
"__len__",
|
||||
"__setitem__",
|
||||
"__str__",
|
||||
"__eq__",
|
||||
"__ne__",
|
||||
"__ge__",
|
||||
"__gt__",
|
||||
"__le__",
|
||||
"__lt__",
|
||||
"__add__",
|
||||
"__floordiv__",
|
||||
"__lshift__",
|
||||
"__matmul__",
|
||||
"__mod__",
|
||||
"__mul__",
|
||||
"__neg__",
|
||||
"__pos__",
|
||||
"__pow__",
|
||||
"__rshift__",
|
||||
"__sub__",
|
||||
"__truediv__",
|
||||
"__and__",
|
||||
"__or__",
|
||||
"__xor__",
|
||||
"__iadd__",
|
||||
"__ifloordiv__",
|
||||
"__ilshift__",
|
||||
"__imatmul__",
|
||||
"__imod__",
|
||||
"__imul__",
|
||||
"__ipow__",
|
||||
"__irshift__",
|
||||
"__isub__",
|
||||
"__itruediv__",
|
||||
"__iand__",
|
||||
"__ior__",
|
||||
"__ixor__",
|
||||
"__radd__",
|
||||
"__rfloordiv__",
|
||||
"__rlshift__",
|
||||
"__rmatmul__",
|
||||
"__rmod__",
|
||||
"__rmul__",
|
||||
"__rpow__",
|
||||
"__rrshift__",
|
||||
"__rsub__",
|
||||
"__rtruediv__",
|
||||
"__rand__",
|
||||
"__ror__",
|
||||
"__rxor__",
|
||||
}
|
||||
for name, func in typ.methods.items():
|
||||
if name == "__init__":
|
||||
continue
|
||||
if (
|
||||
name.startswith("__")
|
||||
and name.endswith("__")
|
||||
and name not in supported_dunders
|
||||
):
|
||||
raise TypeError(f"Method '{name}' is not supported.")
|
||||
dct[name] = _generate_method(name, func)
|
||||
|
||||
# Inject static methods as class members
|
||||
for name, func in typ.static_methods.items():
|
||||
dct[name] = _generate_method(name, func)
|
||||
|
||||
# Create subclass
|
||||
subcls = type(typ.classname, (_box.Box,), dct)
|
||||
# Store to cache
|
||||
_cache_specialized_box[typ] = subcls
|
||||
|
||||
# Pre-compile attribute getter.
|
||||
# Note: This must be done after the "box" class is created because
|
||||
# compiling the getter requires the "box" class to be defined.
|
||||
for k, v in dct.items():
|
||||
if isinstance(v, property):
|
||||
prop = getattr(subcls, k)
|
||||
if prop.fget is not None:
|
||||
fget = prop.fget
|
||||
fast_fget = fget.compile((typ,))
|
||||
fget.disable_compile()
|
||||
setattr(subcls, k,
|
||||
property(fast_fget, prop.fset, prop.fdel,
|
||||
doc=prop.__doc__))
|
||||
|
||||
return subcls
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Implement box/unbox for call wrapper
|
||||
|
||||
@box(types.ClassInstanceType)
|
||||
def _box_class_instance(typ, val, c):
|
||||
meminfo, dataptr = cgutils.unpack_tuple(c.builder, val)
|
||||
|
||||
# Create Box instance
|
||||
box_subclassed = _specialize_box(typ)
|
||||
# Note: the ``box_subclassed`` is kept alive by the cache
|
||||
voidptr_boxcls = c.context.add_dynamic_addr(
|
||||
c.builder,
|
||||
id(box_subclassed),
|
||||
info="box_class_instance",
|
||||
)
|
||||
box_cls = c.builder.bitcast(voidptr_boxcls, c.pyapi.pyobj)
|
||||
|
||||
box = c.pyapi.call_function_objargs(box_cls, ())
|
||||
|
||||
# Initialize Box instance
|
||||
llvoidptr = ir.IntType(8).as_pointer()
|
||||
addr_meminfo = c.builder.bitcast(meminfo, llvoidptr)
|
||||
addr_data = c.builder.bitcast(dataptr, llvoidptr)
|
||||
|
||||
def set_member(member_offset, value):
|
||||
# Access member by byte offset
|
||||
offset = c.context.get_constant(types.uintp, member_offset)
|
||||
ptr = cgutils.pointer_add(c.builder, box, offset)
|
||||
casted = c.builder.bitcast(ptr, llvoidptr.as_pointer())
|
||||
c.builder.store(value, casted)
|
||||
|
||||
set_member(_box.box_meminfoptr_offset, addr_meminfo)
|
||||
set_member(_box.box_dataptr_offset, addr_data)
|
||||
return box
|
||||
|
||||
|
||||
@unbox(types.ClassInstanceType)
|
||||
def _unbox_class_instance(typ, val, c):
|
||||
def access_member(member_offset):
|
||||
# Access member by byte offset
|
||||
offset = c.context.get_constant(types.uintp, member_offset)
|
||||
llvoidptr = ir.IntType(8).as_pointer()
|
||||
ptr = cgutils.pointer_add(c.builder, val, offset)
|
||||
casted = c.builder.bitcast(ptr, llvoidptr.as_pointer())
|
||||
return c.builder.load(casted)
|
||||
|
||||
struct_cls = cgutils.create_struct_proxy(typ)
|
||||
inst = struct_cls(c.context, c.builder)
|
||||
|
||||
# load from Python object
|
||||
ptr_meminfo = access_member(_box.box_meminfoptr_offset)
|
||||
ptr_dataptr = access_member(_box.box_dataptr_offset)
|
||||
|
||||
# store to native structure
|
||||
inst.meminfo = c.builder.bitcast(ptr_meminfo, inst.meminfo.type)
|
||||
inst.data = c.builder.bitcast(ptr_dataptr, inst.data.type)
|
||||
|
||||
ret = inst._getvalue()
|
||||
|
||||
c.context.nrt.incref(c.builder, typ, ret)
|
||||
|
||||
return NativeValue(ret, is_error=c.pyapi.c_api_error())
|
||||
|
||||
|
||||
# Add a typeof_impl implementation for boxed jitclasses to short-circut the
|
||||
# various tests in typeof. This is needed for jitclasses which implement a
|
||||
# custom hash method. Without this, typeof_impl will return None, and one of the
|
||||
# later attempts to determine the type of the jitclass (before checking for
|
||||
# _numba_type_) will look up the object in a dictionary, triggering the hash
|
||||
# method. This will cause the dispatcher to determine the call signature of the
|
||||
# jit decorated obj.__hash__ method, which will call typeof(obj), and thus
|
||||
# infinite loop.
|
||||
# This implementation is here instead of in typeof.py to avoid circular imports.
|
||||
@typeof_impl.register(_box.Box)
|
||||
def _typeof_jitclass_box(val, c):
|
||||
return getattr(type(val), "_numba_type_")
|
||||
@@ -0,0 +1,88 @@
|
||||
from numba.core import types, config
|
||||
|
||||
|
||||
def jitclass(cls_or_spec=None, spec=None):
|
||||
"""
|
||||
A function for creating a jitclass.
|
||||
Can be used as a decorator or function.
|
||||
|
||||
Different use cases will cause different arguments to be set.
|
||||
|
||||
If specified, ``spec`` gives the types of class fields.
|
||||
It must be a dictionary or sequence.
|
||||
With a dictionary, use collections.OrderedDict for stable ordering.
|
||||
With a sequence, it must contain 2-tuples of (fieldname, fieldtype).
|
||||
|
||||
Any class annotations for field names not listed in spec will be added.
|
||||
For class annotation `x: T` we will append ``("x", as_numba_type(T))`` to
|
||||
the spec if ``x`` is not already a key in spec.
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
1) ``cls_or_spec = None``, ``spec = None``
|
||||
|
||||
>>> @jitclass()
|
||||
... class Foo:
|
||||
... ...
|
||||
|
||||
2) ``cls_or_spec = None``, ``spec = spec``
|
||||
|
||||
>>> @jitclass(spec=spec)
|
||||
... class Foo:
|
||||
... ...
|
||||
|
||||
3) ``cls_or_spec = Foo``, ``spec = None``
|
||||
|
||||
>>> @jitclass
|
||||
... class Foo:
|
||||
... ...
|
||||
|
||||
4) ``cls_or_spec = spec``, ``spec = None``
|
||||
In this case we update ``cls_or_spec, spec = None, cls_or_spec``.
|
||||
|
||||
>>> @jitclass(spec)
|
||||
... class Foo:
|
||||
... ...
|
||||
|
||||
5) ``cls_or_spec = Foo``, ``spec = spec``
|
||||
|
||||
>>> JitFoo = jitclass(Foo, spec)
|
||||
|
||||
Returns
|
||||
-------
|
||||
If used as a decorator, returns a callable that takes a class object and
|
||||
returns a compiled version.
|
||||
If used as a function, returns the compiled class (an instance of
|
||||
``JitClassType``).
|
||||
"""
|
||||
|
||||
if (cls_or_spec is not None and
|
||||
spec is None and
|
||||
not isinstance(cls_or_spec, type)):
|
||||
# Used like
|
||||
# @jitclass([("x", intp)])
|
||||
# class Foo:
|
||||
# ...
|
||||
spec = cls_or_spec
|
||||
cls_or_spec = None
|
||||
|
||||
def wrap(cls):
|
||||
if config.DISABLE_JIT:
|
||||
return cls
|
||||
else:
|
||||
from numba.experimental.jitclass.base import (register_class_type,
|
||||
ClassBuilder)
|
||||
cls_jitted = register_class_type(cls, spec, types.ClassType,
|
||||
ClassBuilder)
|
||||
|
||||
# Preserve the module name of the original class
|
||||
cls_jitted.__module__ = cls.__module__
|
||||
|
||||
return cls_jitted
|
||||
|
||||
if cls_or_spec is None:
|
||||
return wrap
|
||||
else:
|
||||
return wrap(cls_or_spec)
|
||||
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
Overloads for ClassInstanceType for built-in functions that call dunder methods
|
||||
on an object.
|
||||
"""
|
||||
from functools import wraps
|
||||
import inspect
|
||||
import operator
|
||||
|
||||
from numba.core.extending import overload
|
||||
from numba.core.types import ClassInstanceType
|
||||
|
||||
|
||||
def _get_args(n_args):
|
||||
assert n_args in (1, 2)
|
||||
return list("xy")[:n_args]
|
||||
|
||||
|
||||
def class_instance_overload(target):
|
||||
"""
|
||||
Decorator to add an overload for target that applies when the first argument
|
||||
is a ClassInstanceType.
|
||||
"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
if not isinstance(args[0], ClassInstanceType):
|
||||
return
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if target is not complex:
|
||||
# complex ctor needs special treatment as it uses kwargs
|
||||
params = list(inspect.signature(wrapped).parameters)
|
||||
assert params == _get_args(len(params))
|
||||
return overload(target)(wrapped)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def extract_template(template, name):
|
||||
"""
|
||||
Extract a code-generated function from a string template.
|
||||
"""
|
||||
namespace = {}
|
||||
exec(template, namespace)
|
||||
return namespace[name]
|
||||
|
||||
|
||||
def register_simple_overload(func, *attrs, n_args=1,):
|
||||
"""
|
||||
Register an overload for func that checks for methods __attr__ for each
|
||||
attr in attrs.
|
||||
"""
|
||||
# Use a template to set the signature correctly.
|
||||
arg_names = _get_args(n_args)
|
||||
template = f"""
|
||||
def func({','.join(arg_names)}):
|
||||
pass
|
||||
"""
|
||||
|
||||
@wraps(extract_template(template, "func"))
|
||||
def overload_func(*args, **kwargs):
|
||||
options = [
|
||||
try_call_method(args[0], f"__{attr}__", n_args)
|
||||
for attr in attrs
|
||||
]
|
||||
return take_first(*options)
|
||||
|
||||
return class_instance_overload(func)(overload_func)
|
||||
|
||||
|
||||
def try_call_method(cls_type, method, n_args=1):
|
||||
"""
|
||||
If method is defined for cls_type, return a callable that calls this method.
|
||||
If not, return None.
|
||||
"""
|
||||
if method in cls_type.jit_methods:
|
||||
arg_names = _get_args(n_args)
|
||||
template = f"""
|
||||
def func({','.join(arg_names)}):
|
||||
return {arg_names[0]}.{method}({','.join(arg_names[1:])})
|
||||
"""
|
||||
return extract_template(template, "func")
|
||||
|
||||
|
||||
def try_call_complex_method(cls_type, method):
|
||||
""" __complex__ needs special treatment as the argument names are kwargs
|
||||
and therefore specific in name and default value.
|
||||
"""
|
||||
if method in cls_type.jit_methods:
|
||||
template = f"""
|
||||
def func(real=0, imag=0):
|
||||
return real.{method}()
|
||||
"""
|
||||
return extract_template(template, "func")
|
||||
|
||||
|
||||
def take_first(*options):
|
||||
"""
|
||||
Take the first non-None option.
|
||||
"""
|
||||
assert all(o is None or inspect.isfunction(o) for o in options), options
|
||||
for o in options:
|
||||
if o is not None:
|
||||
return o
|
||||
|
||||
|
||||
@class_instance_overload(bool)
|
||||
def class_bool(x):
|
||||
using_bool_impl = try_call_method(x, "__bool__")
|
||||
|
||||
if '__len__' in x.jit_methods:
|
||||
def using_len_impl(x):
|
||||
return bool(len(x))
|
||||
else:
|
||||
using_len_impl = None
|
||||
|
||||
always_true_impl = lambda x: True
|
||||
|
||||
return take_first(using_bool_impl, using_len_impl, always_true_impl)
|
||||
|
||||
|
||||
@class_instance_overload(complex)
|
||||
def class_complex(real=0, imag=0):
|
||||
return take_first(
|
||||
try_call_complex_method(real, "__complex__"),
|
||||
lambda real=0, imag=0: complex(float(real))
|
||||
)
|
||||
|
||||
|
||||
@class_instance_overload(operator.contains)
|
||||
def class_contains(x, y):
|
||||
# https://docs.python.org/3/reference/expressions.html#membership-test-operations
|
||||
return try_call_method(x, "__contains__", 2)
|
||||
# TODO: use __iter__ if defined.
|
||||
|
||||
|
||||
@class_instance_overload(float)
|
||||
def class_float(x):
|
||||
options = [try_call_method(x, "__float__")]
|
||||
|
||||
if (
|
||||
"__index__" in x.jit_methods
|
||||
):
|
||||
options.append(lambda x: float(x.__index__()))
|
||||
|
||||
return take_first(*options)
|
||||
|
||||
|
||||
@class_instance_overload(int)
|
||||
def class_int(x):
|
||||
options = [try_call_method(x, "__int__")]
|
||||
|
||||
options.append(try_call_method(x, "__index__"))
|
||||
|
||||
return take_first(*options)
|
||||
|
||||
|
||||
@class_instance_overload(str)
|
||||
def class_str(x):
|
||||
return take_first(
|
||||
try_call_method(x, "__str__"),
|
||||
lambda x: repr(x),
|
||||
)
|
||||
|
||||
|
||||
@class_instance_overload(operator.ne)
|
||||
def class_ne(x, y):
|
||||
# This doesn't use register_reflected_overload like the other operators
|
||||
# because it falls back to inverting __eq__ rather than reflecting its
|
||||
# arguments (as per the definition of the Python data model).
|
||||
return take_first(
|
||||
try_call_method(x, "__ne__", 2),
|
||||
lambda x, y: not (x == y),
|
||||
)
|
||||
|
||||
|
||||
def register_reflected_overload(func, meth_forward, meth_reflected):
|
||||
def class_lt(x, y):
|
||||
normal_impl = try_call_method(x, f"__{meth_forward}__", 2)
|
||||
|
||||
if f"__{meth_reflected}__" in y.jit_methods:
|
||||
def reflected_impl(x, y):
|
||||
return y > x
|
||||
else:
|
||||
reflected_impl = None
|
||||
|
||||
return take_first(normal_impl, reflected_impl)
|
||||
|
||||
class_instance_overload(func)(class_lt)
|
||||
|
||||
|
||||
register_simple_overload(abs, "abs")
|
||||
register_simple_overload(len, "len")
|
||||
register_simple_overload(hash, "hash")
|
||||
|
||||
# Comparison operators.
|
||||
register_reflected_overload(operator.ge, "ge", "le")
|
||||
register_reflected_overload(operator.gt, "gt", "lt")
|
||||
register_reflected_overload(operator.le, "le", "ge")
|
||||
register_reflected_overload(operator.lt, "lt", "gt")
|
||||
|
||||
# Note that eq is missing support for fallback to `x is y`, but `is` and
|
||||
# `operator.is` are presently unsupported in general.
|
||||
register_reflected_overload(operator.eq, "eq", "eq")
|
||||
|
||||
# Arithmetic operators.
|
||||
register_simple_overload(operator.add, "add", n_args=2)
|
||||
register_simple_overload(operator.floordiv, "floordiv", n_args=2)
|
||||
register_simple_overload(operator.lshift, "lshift", n_args=2)
|
||||
register_simple_overload(operator.mul, "mul", n_args=2)
|
||||
register_simple_overload(operator.mod, "mod", n_args=2)
|
||||
register_simple_overload(operator.neg, "neg")
|
||||
register_simple_overload(operator.pos, "pos")
|
||||
register_simple_overload(operator.invert, "invert")
|
||||
register_simple_overload(operator.pow, "pow", n_args=2)
|
||||
register_simple_overload(operator.rshift, "rshift", n_args=2)
|
||||
register_simple_overload(operator.sub, "sub", n_args=2)
|
||||
register_simple_overload(operator.truediv, "truediv", n_args=2)
|
||||
|
||||
# Inplace arithmetic operators.
|
||||
register_simple_overload(operator.iadd, "iadd", "add", n_args=2)
|
||||
register_simple_overload(operator.ifloordiv, "ifloordiv", "floordiv", n_args=2)
|
||||
register_simple_overload(operator.ilshift, "ilshift", "lshift", n_args=2)
|
||||
register_simple_overload(operator.imul, "imul", "mul", n_args=2)
|
||||
register_simple_overload(operator.imod, "imod", "mod", n_args=2)
|
||||
register_simple_overload(operator.ipow, "ipow", "pow", n_args=2)
|
||||
register_simple_overload(operator.irshift, "irshift", "rshift", n_args=2)
|
||||
register_simple_overload(operator.isub, "isub", "sub", n_args=2)
|
||||
register_simple_overload(operator.itruediv, "itruediv", "truediv", n_args=2)
|
||||
|
||||
# Logical operators.
|
||||
register_simple_overload(operator.and_, "and", n_args=2)
|
||||
register_simple_overload(operator.or_, "or", n_args=2)
|
||||
register_simple_overload(operator.xor, "xor", n_args=2)
|
||||
|
||||
# Inplace logical operators.
|
||||
register_simple_overload(operator.iand, "iand", "and", n_args=2)
|
||||
register_simple_overload(operator.ior, "ior", "or", n_args=2)
|
||||
register_simple_overload(operator.ixor, "ixor", "xor", n_args=2)
|
||||
@@ -0,0 +1,400 @@
|
||||
"""Utilities for defining a mutable struct.
|
||||
|
||||
A mutable struct is passed by reference;
|
||||
hence, structref (a reference to a struct).
|
||||
|
||||
"""
|
||||
import operator
|
||||
from numba.core.cgutils import create_struct_proxy
|
||||
from numba import njit
|
||||
from numba.core import types, imputils, cgutils
|
||||
from numba.core.datamodel import default_manager, models
|
||||
from numba.core.extending import (
|
||||
infer_getattr,
|
||||
lower_getattr_generic,
|
||||
lower_setattr_generic,
|
||||
lower_builtin,
|
||||
box,
|
||||
unbox,
|
||||
NativeValue,
|
||||
intrinsic,
|
||||
overload,
|
||||
)
|
||||
from numba.core.typing.templates import AttributeTemplate
|
||||
|
||||
|
||||
class _Utils:
|
||||
"""Internal builder-code utils for structref definitions.
|
||||
"""
|
||||
def __init__(self, context, builder, struct_type):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
context :
|
||||
a numba target context
|
||||
builder :
|
||||
a llvmlite IRBuilder
|
||||
struct_type : numba.core.types.StructRef
|
||||
"""
|
||||
self.context = context
|
||||
self.builder = builder
|
||||
self.struct_type = struct_type
|
||||
|
||||
def new_struct_ref(self, mi):
|
||||
"""Encapsulate the MemInfo from a `StructRefPayload` in a `StructRef`
|
||||
"""
|
||||
context = self.context
|
||||
builder = self.builder
|
||||
struct_type = self.struct_type
|
||||
|
||||
st = cgutils.create_struct_proxy(struct_type)(context, builder)
|
||||
st.meminfo = mi
|
||||
return st
|
||||
|
||||
def get_struct_ref(self, val):
|
||||
"""Return a helper for accessing a StructRefType
|
||||
"""
|
||||
context = self.context
|
||||
builder = self.builder
|
||||
struct_type = self.struct_type
|
||||
|
||||
return cgutils.create_struct_proxy(struct_type)(
|
||||
context, builder, value=val
|
||||
)
|
||||
|
||||
def get_data_pointer(self, val):
|
||||
"""Get the data pointer to the payload from a `StructRefType`.
|
||||
"""
|
||||
context = self.context
|
||||
builder = self.builder
|
||||
struct_type = self.struct_type
|
||||
|
||||
structval = self.get_struct_ref(val)
|
||||
meminfo = structval.meminfo
|
||||
data_ptr = context.nrt.meminfo_data(builder, meminfo)
|
||||
|
||||
valtype = struct_type.get_data_type()
|
||||
model = context.data_model_manager[valtype]
|
||||
alloc_type = model.get_value_type()
|
||||
data_ptr = builder.bitcast(data_ptr, alloc_type.as_pointer())
|
||||
return data_ptr
|
||||
|
||||
def get_data_struct(self, val):
|
||||
"""Get a getter/setter helper for accessing a `StructRefPayload`
|
||||
"""
|
||||
context = self.context
|
||||
builder = self.builder
|
||||
struct_type = self.struct_type
|
||||
|
||||
data_ptr = self.get_data_pointer(val)
|
||||
valtype = struct_type.get_data_type()
|
||||
dataval = cgutils.create_struct_proxy(valtype)(
|
||||
context, builder, ref=data_ptr
|
||||
)
|
||||
return dataval
|
||||
|
||||
|
||||
def define_attributes(struct_typeclass):
|
||||
"""Define attributes on `struct_typeclass`.
|
||||
|
||||
Defines both setters and getters in jit-code.
|
||||
|
||||
This is called directly in `register()`.
|
||||
"""
|
||||
@infer_getattr
|
||||
class StructAttribute(AttributeTemplate):
|
||||
key = struct_typeclass
|
||||
|
||||
def generic_resolve(self, typ, attr):
|
||||
if attr in typ.field_dict:
|
||||
attrty = typ.field_dict[attr]
|
||||
return attrty
|
||||
|
||||
@lower_getattr_generic(struct_typeclass)
|
||||
def struct_getattr_impl(context, builder, typ, val, attr):
|
||||
utils = _Utils(context, builder, typ)
|
||||
dataval = utils.get_data_struct(val)
|
||||
ret = getattr(dataval, attr)
|
||||
fieldtype = typ.field_dict[attr]
|
||||
return imputils.impl_ret_borrowed(context, builder, fieldtype, ret)
|
||||
|
||||
@lower_setattr_generic(struct_typeclass)
|
||||
def struct_setattr_impl(context, builder, sig, args, attr):
|
||||
[inst_type, val_type] = sig.args
|
||||
[instance, val] = args
|
||||
utils = _Utils(context, builder, inst_type)
|
||||
dataval = utils.get_data_struct(instance)
|
||||
# cast val to the correct type
|
||||
field_type = inst_type.field_dict[attr]
|
||||
casted = context.cast(builder, val, val_type, field_type)
|
||||
# read old
|
||||
old_value = getattr(dataval, attr)
|
||||
# incref new value
|
||||
context.nrt.incref(builder, val_type, casted)
|
||||
# decref old value (must be last in case new value is old value)
|
||||
context.nrt.decref(builder, val_type, old_value)
|
||||
# write new
|
||||
setattr(dataval, attr, casted)
|
||||
|
||||
|
||||
def define_boxing(struct_type, obj_class):
|
||||
"""Define the boxing & unboxing logic for `struct_type` to `obj_class`.
|
||||
|
||||
Defines both boxing and unboxing.
|
||||
|
||||
- boxing turns an instance of `struct_type` into a PyObject of `obj_class`
|
||||
- unboxing turns an instance of `obj_class` into an instance of
|
||||
`struct_type` in jit-code.
|
||||
|
||||
|
||||
Use this directly instead of `define_proxy()` when the user does not
|
||||
want any constructor to be defined.
|
||||
"""
|
||||
if struct_type is types.StructRef:
|
||||
raise ValueError(f"cannot register {types.StructRef}")
|
||||
|
||||
obj_ctor = obj_class._numba_box_
|
||||
|
||||
@box(struct_type)
|
||||
def box_struct_ref(typ, val, c):
|
||||
"""
|
||||
Convert a raw pointer to a Python int.
|
||||
"""
|
||||
utils = _Utils(c.context, c.builder, typ)
|
||||
struct_ref = utils.get_struct_ref(val)
|
||||
meminfo = struct_ref.meminfo
|
||||
|
||||
mip_type = types.MemInfoPointer(types.voidptr)
|
||||
boxed_meminfo = c.box(mip_type, meminfo)
|
||||
|
||||
ctor_pyfunc = c.pyapi.unserialize(c.pyapi.serialize_object(obj_ctor))
|
||||
ty_pyobj = c.pyapi.unserialize(c.pyapi.serialize_object(typ))
|
||||
|
||||
res = c.pyapi.call_function_objargs(
|
||||
ctor_pyfunc, [ty_pyobj, boxed_meminfo],
|
||||
)
|
||||
c.pyapi.decref(ctor_pyfunc)
|
||||
c.pyapi.decref(ty_pyobj)
|
||||
c.pyapi.decref(boxed_meminfo)
|
||||
return res
|
||||
|
||||
@unbox(struct_type)
|
||||
def unbox_struct_ref(typ, obj, c):
|
||||
mi_obj = c.pyapi.object_getattr_string(obj, "_meminfo")
|
||||
|
||||
mip_type = types.MemInfoPointer(types.voidptr)
|
||||
|
||||
mi = c.unbox(mip_type, mi_obj).value
|
||||
|
||||
utils = _Utils(c.context, c.builder, typ)
|
||||
struct_ref = utils.new_struct_ref(mi)
|
||||
out = struct_ref._getvalue()
|
||||
|
||||
c.pyapi.decref(mi_obj)
|
||||
return NativeValue(out)
|
||||
|
||||
|
||||
def define_constructor(py_class, struct_typeclass, fields):
|
||||
"""Define the jit-code constructor for `struct_typeclass` using the
|
||||
Python type `py_class` and the required `fields`.
|
||||
|
||||
Use this instead of `define_proxy()` if the user does not want boxing
|
||||
logic defined.
|
||||
"""
|
||||
# Build source code for the constructor
|
||||
params = ', '.join(fields)
|
||||
indent = ' ' * 8
|
||||
init_fields_buf = []
|
||||
for k in fields:
|
||||
init_fields_buf.append(f"st.{k} = {k}")
|
||||
init_fields = f'\n{indent}'.join(init_fields_buf)
|
||||
|
||||
source = f"""
|
||||
def ctor({params}):
|
||||
struct_type = struct_typeclass(list(zip({list(fields)}, [{params}])))
|
||||
def impl({params}):
|
||||
st = new(struct_type)
|
||||
{init_fields}
|
||||
return st
|
||||
return impl
|
||||
"""
|
||||
|
||||
glbs = dict(struct_typeclass=struct_typeclass, new=new)
|
||||
exec(source, glbs)
|
||||
ctor = glbs['ctor']
|
||||
# Make it an overload
|
||||
overload(py_class)(ctor)
|
||||
|
||||
|
||||
def define_proxy(py_class, struct_typeclass, fields):
|
||||
"""Defines a PyObject proxy for a structref.
|
||||
|
||||
This makes `py_class` a valid constructor for creating a instance of
|
||||
`struct_typeclass` that contains the members as defined by `fields`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
py_class : type
|
||||
The Python class for constructing an instance of `struct_typeclass`.
|
||||
struct_typeclass : numba.core.types.Type
|
||||
The structref type class to bind to.
|
||||
fields : Sequence[str]
|
||||
A sequence of field names.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
define_constructor(py_class, struct_typeclass, fields)
|
||||
define_boxing(struct_typeclass, py_class)
|
||||
|
||||
|
||||
def register(struct_type):
|
||||
"""Register a `numba.core.types.StructRef` for use in jit-code.
|
||||
|
||||
This defines the data-model for lowering an instance of `struct_type`.
|
||||
This defines attributes accessor and mutator for an instance of
|
||||
`struct_type`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
struct_type : type
|
||||
A subclass of `numba.core.types.StructRef`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
struct_type : type
|
||||
Returns the input argument so this can act like a decorator.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
.. code-block::
|
||||
|
||||
class MyStruct(numba.core.types.StructRef):
|
||||
... # the simplest subclass can be empty
|
||||
|
||||
numba.experimental.structref.register(MyStruct)
|
||||
|
||||
"""
|
||||
if struct_type is types.StructRef:
|
||||
raise ValueError(f"cannot register {types.StructRef}")
|
||||
default_manager.register(struct_type, models.StructRefModel)
|
||||
define_attributes(struct_type)
|
||||
return struct_type
|
||||
|
||||
|
||||
@intrinsic
|
||||
def new(typingctx, struct_type):
|
||||
"""new(struct_type)
|
||||
|
||||
A jit-code only intrinsic. Used to allocate an **empty** mutable struct.
|
||||
The fields are zero-initialized and must be set manually after calling
|
||||
the function.
|
||||
|
||||
Example:
|
||||
|
||||
instance = new(MyStruct)
|
||||
instance.field = field_value
|
||||
"""
|
||||
from numba.experimental.jitclass.base import imp_dtor
|
||||
|
||||
inst_type = struct_type.instance_type
|
||||
|
||||
def codegen(context, builder, signature, args):
|
||||
# FIXME: mostly the same as jitclass ctor_impl()
|
||||
model = context.data_model_manager[inst_type.get_data_type()]
|
||||
alloc_type = model.get_value_type()
|
||||
alloc_size = context.get_abi_sizeof(alloc_type)
|
||||
|
||||
meminfo = context.nrt.meminfo_alloc_dtor(
|
||||
builder,
|
||||
context.get_constant(types.uintp, alloc_size),
|
||||
imp_dtor(context, builder.module, inst_type),
|
||||
)
|
||||
data_pointer = context.nrt.meminfo_data(builder, meminfo)
|
||||
data_pointer = builder.bitcast(data_pointer, alloc_type.as_pointer())
|
||||
|
||||
# Nullify all data
|
||||
builder.store(cgutils.get_null_value(alloc_type), data_pointer)
|
||||
|
||||
inst_struct = context.make_helper(builder, inst_type)
|
||||
inst_struct.meminfo = meminfo
|
||||
|
||||
return inst_struct._getvalue()
|
||||
|
||||
sig = inst_type(struct_type)
|
||||
return sig, codegen
|
||||
|
||||
|
||||
class StructRefProxy:
|
||||
"""A PyObject proxy to the Numba allocated structref data structure.
|
||||
|
||||
Notes
|
||||
-----
|
||||
|
||||
* Subclasses should not define ``__init__``.
|
||||
* Subclasses can override ``__new__``.
|
||||
"""
|
||||
__slots__ = ('_type', '_meminfo')
|
||||
|
||||
@classmethod
|
||||
def _numba_box_(cls, ty, mi):
|
||||
"""Called by boxing logic, the conversion of Numba internal
|
||||
representation into a PyObject.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ty :
|
||||
a Numba type instance.
|
||||
mi :
|
||||
a wrapped MemInfoPointer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
instance :
|
||||
a StructRefProxy instance.
|
||||
"""
|
||||
instance = super().__new__(cls)
|
||||
instance._type = ty
|
||||
instance._meminfo = mi
|
||||
return instance
|
||||
|
||||
def __new__(cls, *args):
|
||||
"""Construct a new instance of the structref.
|
||||
|
||||
This takes positional-arguments only due to limitation of the compiler.
|
||||
The arguments are mapped to ``cls(*args)`` in jit-code.
|
||||
"""
|
||||
try:
|
||||
# use cached ctor if available
|
||||
ctor = cls.__numba_ctor
|
||||
except AttributeError:
|
||||
# lazily define the ctor
|
||||
@njit
|
||||
def ctor(*args):
|
||||
return cls(*args)
|
||||
# cache it to attribute to avoid recompilation
|
||||
cls.__numba_ctor = ctor
|
||||
return ctor(*args)
|
||||
|
||||
@property
|
||||
def _numba_type_(self):
|
||||
"""Returns the Numba type instance for this structref instance.
|
||||
|
||||
Subclasses should NOT override.
|
||||
"""
|
||||
return self._type
|
||||
|
||||
|
||||
@lower_builtin(operator.is_, types.StructRef, types.StructRef)
|
||||
def structref_is(context, builder, sig, args):
|
||||
"""
|
||||
Define the 'is' operator for structrefs by comparing the memory addresses.
|
||||
This is the identity check for structref objects.
|
||||
"""
|
||||
a, b = args
|
||||
aty, bty = sig.args
|
||||
a_ptr = create_struct_proxy(aty)(context, builder, value=a).meminfo
|
||||
b_ptr = create_struct_proxy(bty)(context, builder, value=b).meminfo
|
||||
return builder.icmp_unsigned("==", a_ptr, b_ptr)
|
||||
Reference in New Issue
Block a user