Videre
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from numba.experimental.jitclass.decorators import jitclass
|
||||
from numba.experimental.jitclass import boxing # Has import-time side effect
|
||||
from numba.experimental.jitclass import overloads # Has import-time side effect
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,598 @@
|
||||
import inspect
|
||||
import operator
|
||||
import types as pytypes
|
||||
import typing as pt
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Sequence
|
||||
|
||||
from llvmlite import ir as llvmir
|
||||
from numba import njit
|
||||
from numba.core import cgutils, errors, imputils, types, utils
|
||||
from numba.core.datamodel import default_manager, models
|
||||
from numba.core.registry import cpu_target
|
||||
from numba.core.typing import templates
|
||||
from numba.core.typing.asnumbatype import as_numba_type
|
||||
from numba.core.serialize import disable_pickling
|
||||
from numba.experimental.jitclass import _box
|
||||
|
||||
##############################################################################
|
||||
# Data model
|
||||
|
||||
|
||||
class InstanceModel(models.StructModel):
|
||||
def __init__(self, dmm, fe_typ):
|
||||
cls_data_ty = types.ClassDataType(fe_typ)
|
||||
# MemInfoPointer uses the `dtype` attribute to traverse for nested
|
||||
# NRT MemInfo. Since we handle nested NRT MemInfo ourselves,
|
||||
# we will replace provide MemInfoPointer with an opaque type
|
||||
# so that it does not raise exception for nested meminfo.
|
||||
dtype = types.Opaque('Opaque.' + str(cls_data_ty))
|
||||
members = [
|
||||
('meminfo', types.MemInfoPointer(dtype)),
|
||||
('data', types.CPointer(cls_data_ty)),
|
||||
]
|
||||
super(InstanceModel, self).__init__(dmm, fe_typ, members)
|
||||
|
||||
|
||||
class InstanceDataModel(models.StructModel):
|
||||
def __init__(self, dmm, fe_typ):
|
||||
clsty = fe_typ.class_type
|
||||
members = [(_mangle_attr(k), v) for k, v in clsty.struct.items()]
|
||||
super(InstanceDataModel, self).__init__(dmm, fe_typ, members)
|
||||
|
||||
|
||||
default_manager.register(types.ClassInstanceType, InstanceModel)
|
||||
default_manager.register(types.ClassDataType, InstanceDataModel)
|
||||
default_manager.register(types.ClassType, models.OpaqueModel)
|
||||
|
||||
|
||||
def _mangle_attr(name):
|
||||
"""
|
||||
Mangle attributes.
|
||||
The resulting name does not startswith an underscore '_'.
|
||||
"""
|
||||
return 'm_' + name
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Class object
|
||||
|
||||
_ctor_template = """
|
||||
def ctor({args}):
|
||||
return __numba_cls_({args})
|
||||
"""
|
||||
|
||||
|
||||
def _getargs(fn_sig):
|
||||
"""
|
||||
Returns list of positional and keyword argument names in order.
|
||||
"""
|
||||
params = fn_sig.parameters
|
||||
args = []
|
||||
for k, v in params.items():
|
||||
if (v.kind & v.POSITIONAL_OR_KEYWORD) == v.POSITIONAL_OR_KEYWORD:
|
||||
args.append(k)
|
||||
else:
|
||||
msg = "%s argument type unsupported in jitclass" % v.kind
|
||||
raise errors.UnsupportedError(msg)
|
||||
return args
|
||||
|
||||
|
||||
@disable_pickling
|
||||
class JitClassType(type):
|
||||
"""
|
||||
The type of any jitclass.
|
||||
"""
|
||||
def __new__(cls, name, bases, dct):
|
||||
if len(bases) != 1:
|
||||
raise TypeError("must have exactly one base class")
|
||||
[base] = bases
|
||||
if isinstance(base, JitClassType):
|
||||
raise TypeError("cannot subclass from a jitclass")
|
||||
assert 'class_type' in dct, 'missing "class_type" attr'
|
||||
outcls = type.__new__(cls, name, bases, dct)
|
||||
outcls._set_init()
|
||||
return outcls
|
||||
|
||||
def _set_init(cls):
|
||||
"""
|
||||
Generate a wrapper for calling the constructor from pure Python.
|
||||
Note the wrapper will only accept positional arguments.
|
||||
"""
|
||||
init = cls.class_type.instance_type.methods['__init__']
|
||||
init_sig = utils.pysignature(init)
|
||||
# get postitional and keyword arguments
|
||||
# offset by one to exclude the `self` arg
|
||||
args = _getargs(init_sig)[1:]
|
||||
cls._ctor_sig = init_sig
|
||||
ctor_source = _ctor_template.format(args=', '.join(args))
|
||||
glbls = {"__numba_cls_": cls}
|
||||
exec(ctor_source, glbls)
|
||||
ctor = glbls['ctor']
|
||||
cls._ctor = njit(ctor)
|
||||
|
||||
def __instancecheck__(cls, instance):
|
||||
if isinstance(instance, _box.Box):
|
||||
return instance._numba_type_.class_type is cls.class_type
|
||||
return False
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
# The first argument of _ctor_sig is `cls`, which here
|
||||
# is bound to None and then skipped when invoking the constructor.
|
||||
bind = cls._ctor_sig.bind(None, *args, **kwargs)
|
||||
bind.apply_defaults()
|
||||
return cls._ctor(*bind.args[1:], **bind.kwargs)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Registration utils
|
||||
|
||||
def _validate_spec(spec):
|
||||
for k, v in spec.items():
|
||||
if not isinstance(k, str):
|
||||
raise TypeError("spec keys should be strings, got %r" % (k,))
|
||||
if not isinstance(v, types.Type):
|
||||
raise TypeError("spec values should be Numba type instances, got %r"
|
||||
% (v,))
|
||||
|
||||
|
||||
def _fix_up_private_attr(clsname, spec):
|
||||
"""
|
||||
Apply the same changes to dunder names as CPython would.
|
||||
"""
|
||||
out = OrderedDict()
|
||||
for k, v in spec.items():
|
||||
if k.startswith('__') and not k.endswith('__'):
|
||||
k = '_' + clsname + k
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
|
||||
def _add_linking_libs(context, call):
|
||||
"""
|
||||
Add the required libs for the callable to allow inlining.
|
||||
"""
|
||||
libs = getattr(call, "libs", ())
|
||||
if libs:
|
||||
context.add_linking_libs(libs)
|
||||
|
||||
|
||||
def register_class_type(cls, spec, class_ctor, builder):
|
||||
"""
|
||||
Internal function to create a jitclass.
|
||||
|
||||
Args
|
||||
----
|
||||
cls: the original class object (used as the prototype)
|
||||
spec: the structural specification contains the field types.
|
||||
class_ctor: the numba type to represent the jitclass
|
||||
builder: the internal jitclass builder
|
||||
"""
|
||||
# Normalize spec
|
||||
if spec is None:
|
||||
spec = OrderedDict()
|
||||
elif isinstance(spec, Sequence):
|
||||
spec = OrderedDict(spec)
|
||||
|
||||
# Extend spec with class annotations.
|
||||
for attr, py_type in pt.get_type_hints(cls).items():
|
||||
if attr not in spec:
|
||||
spec[attr] = as_numba_type(py_type)
|
||||
|
||||
_validate_spec(spec)
|
||||
|
||||
# Fix up private attribute names
|
||||
spec = _fix_up_private_attr(cls.__name__, spec)
|
||||
|
||||
# Copy methods from base classes
|
||||
clsdct = {}
|
||||
for basecls in reversed(inspect.getmro(cls)):
|
||||
clsdct.update(basecls.__dict__)
|
||||
|
||||
methods, props, static_methods, others = {}, {}, {}, {}
|
||||
for k, v in clsdct.items():
|
||||
if isinstance(v, pytypes.FunctionType):
|
||||
methods[k] = v
|
||||
elif isinstance(v, property):
|
||||
props[k] = v
|
||||
elif isinstance(v, staticmethod):
|
||||
static_methods[k] = v
|
||||
else:
|
||||
others[k] = v
|
||||
|
||||
# Check for name shadowing
|
||||
shadowed = (set(methods) | set(props) | set(static_methods)) & set(spec)
|
||||
if shadowed:
|
||||
raise NameError("name shadowing: {0}".format(', '.join(shadowed)))
|
||||
|
||||
docstring = others.pop('__doc__', "")
|
||||
_drop_ignored_attrs(others)
|
||||
if others:
|
||||
msg = "class members are not yet supported: {0}"
|
||||
members = ', '.join(others.keys())
|
||||
raise TypeError(msg.format(members))
|
||||
|
||||
for k, v in props.items():
|
||||
if v.fdel is not None:
|
||||
raise TypeError("deleter is not supported: {0}".format(k))
|
||||
|
||||
jit_methods = {k: njit(v) for k, v in methods.items()}
|
||||
|
||||
jit_props = {}
|
||||
for k, v in props.items():
|
||||
dct = {}
|
||||
if v.fget:
|
||||
dct['get'] = njit(v.fget)
|
||||
if v.fset:
|
||||
dct['set'] = njit(v.fset)
|
||||
jit_props[k] = dct
|
||||
|
||||
jit_static_methods = {
|
||||
k: njit(v.__func__) for k, v in static_methods.items()}
|
||||
|
||||
# Instantiate class type
|
||||
class_type = class_ctor(
|
||||
cls,
|
||||
ConstructorTemplate,
|
||||
spec,
|
||||
jit_methods,
|
||||
jit_props,
|
||||
jit_static_methods)
|
||||
|
||||
jit_class_dct = dict(class_type=class_type, __doc__=docstring)
|
||||
jit_class_dct.update(jit_static_methods)
|
||||
cls = JitClassType(cls.__name__, (cls,), jit_class_dct)
|
||||
|
||||
# Register resolution of the class object
|
||||
typingctx = cpu_target.typing_context
|
||||
typingctx.insert_global(cls, class_type)
|
||||
|
||||
# Register class
|
||||
targetctx = cpu_target.target_context
|
||||
builder(class_type, typingctx, targetctx).register()
|
||||
as_numba_type.register(cls, class_type.instance_type)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
class ConstructorTemplate(templates.AbstractTemplate):
|
||||
"""
|
||||
Base class for jitclass constructor templates.
|
||||
"""
|
||||
|
||||
def generic(self, args, kws):
|
||||
# Redirect resolution to __init__
|
||||
instance_type = self.key.instance_type
|
||||
ctor = instance_type.jit_methods['__init__']
|
||||
boundargs = (instance_type.get_reference_type(),) + args
|
||||
disp_type = types.Dispatcher(ctor)
|
||||
sig = disp_type.get_call_type(self.context, boundargs, kws)
|
||||
|
||||
if not isinstance(sig.return_type, types.NoneType):
|
||||
raise errors.NumbaTypeError(
|
||||
f"__init__() should return None, not '{sig.return_type}'")
|
||||
|
||||
# Actual constructor returns an instance value (not None)
|
||||
out = templates.signature(instance_type, *sig.args[1:])
|
||||
return out
|
||||
|
||||
|
||||
def _drop_ignored_attrs(dct):
|
||||
# ignore anything defined by object
|
||||
drop = set(['__weakref__',
|
||||
'__module__',
|
||||
'__dict__'])
|
||||
if utils.PYVERSION in ((3, 13), (3, 14)):
|
||||
# new in python 3.13
|
||||
drop |= set(['__firstlineno__', '__static_attributes__'])
|
||||
|
||||
for att in ('__annotations__',
|
||||
'__annotate_func__',
|
||||
'__annotations_cache__'):
|
||||
if att in dct:
|
||||
drop.add(att)
|
||||
|
||||
for k, v in dct.items():
|
||||
if isinstance(v, (pytypes.BuiltinFunctionType,
|
||||
pytypes.BuiltinMethodType)):
|
||||
drop.add(k)
|
||||
elif getattr(v, '__objclass__', None) is object:
|
||||
drop.add(k)
|
||||
|
||||
# If a class defines __eq__ but not __hash__, __hash__ is implicitly set to
|
||||
# None. This is a class member, and class members are not presently
|
||||
# supported.
|
||||
if '__hash__' in dct and dct['__hash__'] is None:
|
||||
drop.add('__hash__')
|
||||
|
||||
for k in drop:
|
||||
dct.pop(k)
|
||||
|
||||
|
||||
class ClassBuilder(object):
|
||||
"""
|
||||
A jitclass builder for a mutable jitclass. This will register
|
||||
typing and implementation hooks to the given typing and target contexts.
|
||||
"""
|
||||
class_impl_registry = imputils.Registry('jitclass builder')
|
||||
implemented_methods = set()
|
||||
|
||||
def __init__(self, class_type, typingctx, targetctx):
|
||||
self.class_type = class_type
|
||||
self.typingctx = typingctx
|
||||
self.targetctx = targetctx
|
||||
|
||||
def register(self):
|
||||
"""
|
||||
Register to the frontend and backend.
|
||||
"""
|
||||
# Register generic implementations for all jitclasses
|
||||
self._register_methods(self.class_impl_registry,
|
||||
self.class_type.instance_type)
|
||||
# NOTE other registrations are done at the top-level
|
||||
# (see ctor_impl and attr_impl below)
|
||||
self.targetctx.install_registry(self.class_impl_registry)
|
||||
|
||||
def _register_methods(self, registry, instance_type):
|
||||
"""
|
||||
Register method implementations.
|
||||
This simply registers that the method names are valid methods. Inside
|
||||
of imp() below we retrieve the actual method to run from the type of
|
||||
the receiver argument (i.e. self).
|
||||
"""
|
||||
to_register = list(instance_type.jit_methods) + \
|
||||
list(instance_type.jit_static_methods)
|
||||
for meth in to_register:
|
||||
|
||||
# There's no way to retrieve the particular method name
|
||||
# inside the implementation function, so we have to register a
|
||||
# specific closure for each different name
|
||||
if meth not in self.implemented_methods:
|
||||
self._implement_method(registry, meth)
|
||||
self.implemented_methods.add(meth)
|
||||
|
||||
def _implement_method(self, registry, attr):
|
||||
# create a separate instance of imp method to avoid closure clashing
|
||||
def get_imp():
|
||||
def imp(context, builder, sig, args):
|
||||
instance_type = sig.args[0]
|
||||
|
||||
if attr in instance_type.jit_methods:
|
||||
method = instance_type.jit_methods[attr]
|
||||
elif attr in instance_type.jit_static_methods:
|
||||
method = instance_type.jit_static_methods[attr]
|
||||
# imp gets called as a method, where the first argument is
|
||||
# self. We drop this for a static method.
|
||||
sig = sig.replace(args=sig.args[1:])
|
||||
args = args[1:]
|
||||
|
||||
disp_type = types.Dispatcher(method)
|
||||
call = context.get_function(disp_type, sig)
|
||||
out = call(builder, args)
|
||||
_add_linking_libs(context, call)
|
||||
return imputils.impl_ret_new_ref(context, builder,
|
||||
sig.return_type, out)
|
||||
return imp
|
||||
|
||||
def _getsetitem_gen(getset):
|
||||
_dunder_meth = "__%s__" % getset
|
||||
op = getattr(operator, getset)
|
||||
|
||||
@templates.infer_global(op)
|
||||
class GetSetItem(templates.AbstractTemplate):
|
||||
def generic(self, args, kws):
|
||||
instance = args[0]
|
||||
if isinstance(instance, types.ClassInstanceType) and \
|
||||
_dunder_meth in instance.jit_methods:
|
||||
meth = instance.jit_methods[_dunder_meth]
|
||||
disp_type = types.Dispatcher(meth)
|
||||
sig = disp_type.get_call_type(self.context, args, kws)
|
||||
return sig
|
||||
|
||||
# lower both {g,s}etitem and __{g,s}etitem__ to catch the calls
|
||||
# from python and numba
|
||||
imputils.lower_builtin((types.ClassInstanceType, _dunder_meth),
|
||||
types.ClassInstanceType,
|
||||
types.VarArg(types.Any))(get_imp())
|
||||
imputils.lower_builtin(op,
|
||||
types.ClassInstanceType,
|
||||
types.VarArg(types.Any))(get_imp())
|
||||
|
||||
dunder_stripped = attr.strip('_')
|
||||
if dunder_stripped in ("getitem", "setitem"):
|
||||
_getsetitem_gen(dunder_stripped)
|
||||
else:
|
||||
registry.lower((types.ClassInstanceType, attr),
|
||||
types.ClassInstanceType,
|
||||
types.VarArg(types.Any))(get_imp())
|
||||
|
||||
|
||||
@templates.infer_getattr
|
||||
class ClassAttribute(templates.AttributeTemplate):
|
||||
key = types.ClassInstanceType
|
||||
|
||||
def generic_resolve(self, instance, attr):
|
||||
if attr in instance.struct:
|
||||
# It's a struct field => the type is well-known
|
||||
return instance.struct[attr]
|
||||
|
||||
elif attr in instance.jit_methods:
|
||||
# It's a jitted method => typeinfer it
|
||||
meth = instance.jit_methods[attr]
|
||||
disp_type = types.Dispatcher(meth)
|
||||
|
||||
class MethodTemplate(templates.AbstractTemplate):
|
||||
key = (self.key, attr)
|
||||
|
||||
def generic(self, args, kws):
|
||||
args = (instance,) + tuple(args)
|
||||
sig = disp_type.get_call_type(self.context, args, kws)
|
||||
return sig.as_method()
|
||||
|
||||
return types.BoundFunction(MethodTemplate, instance)
|
||||
|
||||
elif attr in instance.jit_static_methods:
|
||||
# It's a jitted method => typeinfer it
|
||||
meth = instance.jit_static_methods[attr]
|
||||
disp_type = types.Dispatcher(meth)
|
||||
|
||||
class StaticMethodTemplate(templates.AbstractTemplate):
|
||||
key = (self.key, attr)
|
||||
|
||||
def generic(self, args, kws):
|
||||
# Don't add instance as the first argument for a static
|
||||
# method.
|
||||
sig = disp_type.get_call_type(self.context, args, kws)
|
||||
# sig itself does not include ClassInstanceType as it's
|
||||
# first argument, so instead of calling sig.as_method()
|
||||
# we insert the recvr. This is equivalent to
|
||||
# sig.replace(args=(instance,) + sig.args).as_method().
|
||||
return sig.replace(recvr=instance)
|
||||
|
||||
return types.BoundFunction(StaticMethodTemplate, instance)
|
||||
|
||||
elif attr in instance.jit_props:
|
||||
# It's a jitted property => typeinfer its getter
|
||||
impdct = instance.jit_props[attr]
|
||||
getter = impdct['get']
|
||||
disp_type = types.Dispatcher(getter)
|
||||
sig = disp_type.get_call_type(self.context, (instance,), {})
|
||||
return sig.return_type
|
||||
|
||||
|
||||
@ClassBuilder.class_impl_registry.lower_getattr_generic(types.ClassInstanceType)
|
||||
def get_attr_impl(context, builder, typ, value, attr):
|
||||
"""
|
||||
Generic getattr() for @jitclass instances.
|
||||
"""
|
||||
if attr in typ.struct:
|
||||
# It's a struct field
|
||||
inst = context.make_helper(builder, typ, value=value)
|
||||
data_pointer = inst.data
|
||||
data = context.make_data_helper(builder, typ.get_data_type(),
|
||||
ref=data_pointer)
|
||||
return imputils.impl_ret_borrowed(context, builder,
|
||||
typ.struct[attr],
|
||||
getattr(data, _mangle_attr(attr)))
|
||||
elif attr in typ.jit_props:
|
||||
# It's a jitted property
|
||||
getter = typ.jit_props[attr]['get']
|
||||
sig = templates.signature(None, typ)
|
||||
dispatcher = types.Dispatcher(getter)
|
||||
sig = dispatcher.get_call_type(context.typing_context, [typ], {})
|
||||
call = context.get_function(dispatcher, sig)
|
||||
out = call(builder, [value])
|
||||
_add_linking_libs(context, call)
|
||||
return imputils.impl_ret_new_ref(context, builder, sig.return_type, out)
|
||||
|
||||
raise NotImplementedError('attribute {0!r} not implemented'.format(attr))
|
||||
|
||||
|
||||
@ClassBuilder.class_impl_registry.lower_setattr_generic(types.ClassInstanceType)
|
||||
def set_attr_impl(context, builder, sig, args, attr):
|
||||
"""
|
||||
Generic setattr() for @jitclass instances.
|
||||
"""
|
||||
typ, valty = sig.args
|
||||
target, val = args
|
||||
|
||||
if attr in typ.struct:
|
||||
# It's a struct member
|
||||
inst = context.make_helper(builder, typ, value=target)
|
||||
data_ptr = inst.data
|
||||
data = context.make_data_helper(builder, typ.get_data_type(),
|
||||
ref=data_ptr)
|
||||
|
||||
# Get old value
|
||||
attr_type = typ.struct[attr]
|
||||
oldvalue = getattr(data, _mangle_attr(attr))
|
||||
|
||||
# Store n
|
||||
setattr(data, _mangle_attr(attr), val)
|
||||
context.nrt.incref(builder, attr_type, val)
|
||||
|
||||
# Delete old value
|
||||
context.nrt.decref(builder, attr_type, oldvalue)
|
||||
|
||||
elif attr in typ.jit_props:
|
||||
# It's a jitted property
|
||||
setter = typ.jit_props[attr]['set']
|
||||
disp_type = types.Dispatcher(setter)
|
||||
sig = disp_type.get_call_type(context.typing_context,
|
||||
(typ, valty), {})
|
||||
call = context.get_function(disp_type, sig)
|
||||
call(builder, (target, val))
|
||||
_add_linking_libs(context, call)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'attribute {0!r} not implemented'.format(attr))
|
||||
|
||||
|
||||
def imp_dtor(context, module, instance_type):
|
||||
llvoidptr = context.get_value_type(types.voidptr)
|
||||
llsize = context.get_value_type(types.uintp)
|
||||
dtor_ftype = llvmir.FunctionType(llvmir.VoidType(),
|
||||
[llvoidptr, llsize, llvoidptr])
|
||||
|
||||
fname = "_Dtor.{0}".format(instance_type.name)
|
||||
dtor_fn = cgutils.get_or_insert_function(module, dtor_ftype, fname)
|
||||
if dtor_fn.is_declaration:
|
||||
# Define
|
||||
builder = llvmir.IRBuilder(dtor_fn.append_basic_block())
|
||||
|
||||
alloc_fe_type = instance_type.get_data_type()
|
||||
alloc_type = context.get_value_type(alloc_fe_type)
|
||||
|
||||
ptr = builder.bitcast(dtor_fn.args[0], alloc_type.as_pointer())
|
||||
data = context.make_helper(builder, alloc_fe_type, ref=ptr)
|
||||
|
||||
context.nrt.decref(builder, alloc_fe_type, data._getvalue())
|
||||
|
||||
builder.ret_void()
|
||||
|
||||
return dtor_fn
|
||||
|
||||
|
||||
@ClassBuilder.class_impl_registry.lower(types.ClassType,
|
||||
types.VarArg(types.Any))
|
||||
def ctor_impl(context, builder, sig, args):
|
||||
"""
|
||||
Generic constructor (__new__) for jitclasses.
|
||||
"""
|
||||
# Allocate the instance
|
||||
inst_typ = sig.return_type
|
||||
alloc_type = context.get_data_type(inst_typ.get_data_type())
|
||||
alloc_size = context.get_abi_sizeof(alloc_type)
|
||||
|
||||
meminfo = context.nrt.meminfo_alloc_dtor(
|
||||
builder,
|
||||
context.get_constant(types.uintp, alloc_size),
|
||||
imp_dtor(context, builder.module, inst_typ),
|
||||
)
|
||||
data_pointer = context.nrt.meminfo_data(builder, meminfo)
|
||||
data_pointer = builder.bitcast(data_pointer,
|
||||
alloc_type.as_pointer())
|
||||
|
||||
# Nullify all data
|
||||
builder.store(cgutils.get_null_value(alloc_type),
|
||||
data_pointer)
|
||||
|
||||
inst_struct = context.make_helper(builder, inst_typ)
|
||||
inst_struct.meminfo = meminfo
|
||||
inst_struct.data = data_pointer
|
||||
|
||||
# Call the jitted __init__
|
||||
# TODO: extract the following into a common util
|
||||
init_sig = (sig.return_type,) + sig.args
|
||||
|
||||
init = inst_typ.jit_methods['__init__']
|
||||
disp_type = types.Dispatcher(init)
|
||||
call = context.get_function(disp_type, types.void(*init_sig))
|
||||
_add_linking_libs(context, call)
|
||||
realargs = [inst_struct._getvalue()] + list(args)
|
||||
call(builder, realargs)
|
||||
|
||||
# Prepare return value
|
||||
ret = inst_struct._getvalue()
|
||||
|
||||
return imputils.impl_ret_new_ref(context, builder, inst_typ, ret)
|
||||
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Implement logic relating to wrapping (box) and unwrapping (unbox) instances
|
||||
of jitclasses for use inside the python interpreter.
|
||||
"""
|
||||
|
||||
from functools import wraps, partial
|
||||
|
||||
from llvmlite import ir
|
||||
|
||||
from numba.core import types, cgutils
|
||||
from numba.core.decorators import njit
|
||||
from numba.core.pythonapi import box, unbox, NativeValue
|
||||
from numba.core.typing.typeof import typeof_impl
|
||||
from numba.experimental.jitclass import _box
|
||||
|
||||
|
||||
_getter_code_template = """
|
||||
def accessor(__numba_self_):
|
||||
return __numba_self_.{0}
|
||||
"""
|
||||
|
||||
_setter_code_template = """
|
||||
def mutator(__numba_self_, __numba_val):
|
||||
__numba_self_.{0} = __numba_val
|
||||
"""
|
||||
|
||||
_method_code_template = """
|
||||
def method(__numba_self_, *args):
|
||||
return __numba_self_.{method}(*args)
|
||||
"""
|
||||
|
||||
|
||||
def _generate_property(field, template, fname):
|
||||
"""
|
||||
Generate simple function that get/set a field of the instance
|
||||
"""
|
||||
source = template.format(field)
|
||||
glbls = {}
|
||||
exec(source, glbls)
|
||||
return njit(glbls[fname])
|
||||
|
||||
|
||||
_generate_getter = partial(_generate_property, template=_getter_code_template,
|
||||
fname='accessor')
|
||||
_generate_setter = partial(_generate_property, template=_setter_code_template,
|
||||
fname='mutator')
|
||||
|
||||
|
||||
def _generate_method(name, func):
|
||||
"""
|
||||
Generate a wrapper for calling a method. Note the wrapper will only
|
||||
accept positional arguments.
|
||||
"""
|
||||
source = _method_code_template.format(method=name)
|
||||
glbls = {}
|
||||
exec(source, glbls)
|
||||
method = njit(glbls['method'])
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
_cache_specialized_box = {}
|
||||
|
||||
|
||||
def _specialize_box(typ):
|
||||
"""
|
||||
Create a subclass of Box that is specialized to the jitclass.
|
||||
|
||||
This function caches the result to avoid code bloat.
|
||||
"""
|
||||
# Check cache
|
||||
if typ in _cache_specialized_box:
|
||||
return _cache_specialized_box[typ]
|
||||
dct = {'__slots__': (),
|
||||
'_numba_type_': typ,
|
||||
'__doc__': typ.class_type.class_doc,
|
||||
}
|
||||
# Inject attributes as class properties
|
||||
for field in typ.struct:
|
||||
getter = _generate_getter(field)
|
||||
setter = _generate_setter(field)
|
||||
dct[field] = property(getter, setter)
|
||||
# Inject properties as class properties
|
||||
for field, impdct in typ.jit_props.items():
|
||||
getter = None
|
||||
setter = None
|
||||
if 'get' in impdct:
|
||||
getter = _generate_getter(field)
|
||||
if 'set' in impdct:
|
||||
setter = _generate_setter(field)
|
||||
# get docstring from either the fget or fset
|
||||
imp = impdct.get('get') or impdct.get('set') or None
|
||||
doc = getattr(imp, '__doc__', None)
|
||||
dct[field] = property(getter, setter, doc=doc)
|
||||
# Inject methods as class members
|
||||
supported_dunders = {
|
||||
"__abs__",
|
||||
"__annotate_func__",
|
||||
"__bool__",
|
||||
"__complex__",
|
||||
"__contains__",
|
||||
"__float__",
|
||||
"__getitem__",
|
||||
"__hash__",
|
||||
"__index__",
|
||||
"__invert__",
|
||||
"__int__",
|
||||
"__len__",
|
||||
"__setitem__",
|
||||
"__str__",
|
||||
"__eq__",
|
||||
"__ne__",
|
||||
"__ge__",
|
||||
"__gt__",
|
||||
"__le__",
|
||||
"__lt__",
|
||||
"__add__",
|
||||
"__floordiv__",
|
||||
"__lshift__",
|
||||
"__matmul__",
|
||||
"__mod__",
|
||||
"__mul__",
|
||||
"__neg__",
|
||||
"__pos__",
|
||||
"__pow__",
|
||||
"__rshift__",
|
||||
"__sub__",
|
||||
"__truediv__",
|
||||
"__and__",
|
||||
"__or__",
|
||||
"__xor__",
|
||||
"__iadd__",
|
||||
"__ifloordiv__",
|
||||
"__ilshift__",
|
||||
"__imatmul__",
|
||||
"__imod__",
|
||||
"__imul__",
|
||||
"__ipow__",
|
||||
"__irshift__",
|
||||
"__isub__",
|
||||
"__itruediv__",
|
||||
"__iand__",
|
||||
"__ior__",
|
||||
"__ixor__",
|
||||
"__radd__",
|
||||
"__rfloordiv__",
|
||||
"__rlshift__",
|
||||
"__rmatmul__",
|
||||
"__rmod__",
|
||||
"__rmul__",
|
||||
"__rpow__",
|
||||
"__rrshift__",
|
||||
"__rsub__",
|
||||
"__rtruediv__",
|
||||
"__rand__",
|
||||
"__ror__",
|
||||
"__rxor__",
|
||||
}
|
||||
for name, func in typ.methods.items():
|
||||
if name == "__init__":
|
||||
continue
|
||||
if (
|
||||
name.startswith("__")
|
||||
and name.endswith("__")
|
||||
and name not in supported_dunders
|
||||
):
|
||||
raise TypeError(f"Method '{name}' is not supported.")
|
||||
dct[name] = _generate_method(name, func)
|
||||
|
||||
# Inject static methods as class members
|
||||
for name, func in typ.static_methods.items():
|
||||
dct[name] = _generate_method(name, func)
|
||||
|
||||
# Create subclass
|
||||
subcls = type(typ.classname, (_box.Box,), dct)
|
||||
# Store to cache
|
||||
_cache_specialized_box[typ] = subcls
|
||||
|
||||
# Pre-compile attribute getter.
|
||||
# Note: This must be done after the "box" class is created because
|
||||
# compiling the getter requires the "box" class to be defined.
|
||||
for k, v in dct.items():
|
||||
if isinstance(v, property):
|
||||
prop = getattr(subcls, k)
|
||||
if prop.fget is not None:
|
||||
fget = prop.fget
|
||||
fast_fget = fget.compile((typ,))
|
||||
fget.disable_compile()
|
||||
setattr(subcls, k,
|
||||
property(fast_fget, prop.fset, prop.fdel,
|
||||
doc=prop.__doc__))
|
||||
|
||||
return subcls
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Implement box/unbox for call wrapper
|
||||
|
||||
@box(types.ClassInstanceType)
|
||||
def _box_class_instance(typ, val, c):
|
||||
meminfo, dataptr = cgutils.unpack_tuple(c.builder, val)
|
||||
|
||||
# Create Box instance
|
||||
box_subclassed = _specialize_box(typ)
|
||||
# Note: the ``box_subclassed`` is kept alive by the cache
|
||||
voidptr_boxcls = c.context.add_dynamic_addr(
|
||||
c.builder,
|
||||
id(box_subclassed),
|
||||
info="box_class_instance",
|
||||
)
|
||||
box_cls = c.builder.bitcast(voidptr_boxcls, c.pyapi.pyobj)
|
||||
|
||||
box = c.pyapi.call_function_objargs(box_cls, ())
|
||||
|
||||
# Initialize Box instance
|
||||
llvoidptr = ir.IntType(8).as_pointer()
|
||||
addr_meminfo = c.builder.bitcast(meminfo, llvoidptr)
|
||||
addr_data = c.builder.bitcast(dataptr, llvoidptr)
|
||||
|
||||
def set_member(member_offset, value):
|
||||
# Access member by byte offset
|
||||
offset = c.context.get_constant(types.uintp, member_offset)
|
||||
ptr = cgutils.pointer_add(c.builder, box, offset)
|
||||
casted = c.builder.bitcast(ptr, llvoidptr.as_pointer())
|
||||
c.builder.store(value, casted)
|
||||
|
||||
set_member(_box.box_meminfoptr_offset, addr_meminfo)
|
||||
set_member(_box.box_dataptr_offset, addr_data)
|
||||
return box
|
||||
|
||||
|
||||
@unbox(types.ClassInstanceType)
|
||||
def _unbox_class_instance(typ, val, c):
|
||||
def access_member(member_offset):
|
||||
# Access member by byte offset
|
||||
offset = c.context.get_constant(types.uintp, member_offset)
|
||||
llvoidptr = ir.IntType(8).as_pointer()
|
||||
ptr = cgutils.pointer_add(c.builder, val, offset)
|
||||
casted = c.builder.bitcast(ptr, llvoidptr.as_pointer())
|
||||
return c.builder.load(casted)
|
||||
|
||||
struct_cls = cgutils.create_struct_proxy(typ)
|
||||
inst = struct_cls(c.context, c.builder)
|
||||
|
||||
# load from Python object
|
||||
ptr_meminfo = access_member(_box.box_meminfoptr_offset)
|
||||
ptr_dataptr = access_member(_box.box_dataptr_offset)
|
||||
|
||||
# store to native structure
|
||||
inst.meminfo = c.builder.bitcast(ptr_meminfo, inst.meminfo.type)
|
||||
inst.data = c.builder.bitcast(ptr_dataptr, inst.data.type)
|
||||
|
||||
ret = inst._getvalue()
|
||||
|
||||
c.context.nrt.incref(c.builder, typ, ret)
|
||||
|
||||
return NativeValue(ret, is_error=c.pyapi.c_api_error())
|
||||
|
||||
|
||||
# Add a typeof_impl implementation for boxed jitclasses to short-circut the
|
||||
# various tests in typeof. This is needed for jitclasses which implement a
|
||||
# custom hash method. Without this, typeof_impl will return None, and one of the
|
||||
# later attempts to determine the type of the jitclass (before checking for
|
||||
# _numba_type_) will look up the object in a dictionary, triggering the hash
|
||||
# method. This will cause the dispatcher to determine the call signature of the
|
||||
# jit decorated obj.__hash__ method, which will call typeof(obj), and thus
|
||||
# infinite loop.
|
||||
# This implementation is here instead of in typeof.py to avoid circular imports.
|
||||
@typeof_impl.register(_box.Box)
|
||||
def _typeof_jitclass_box(val, c):
|
||||
return getattr(type(val), "_numba_type_")
|
||||
@@ -0,0 +1,88 @@
|
||||
from numba.core import types, config
|
||||
|
||||
|
||||
def jitclass(cls_or_spec=None, spec=None):
|
||||
"""
|
||||
A function for creating a jitclass.
|
||||
Can be used as a decorator or function.
|
||||
|
||||
Different use cases will cause different arguments to be set.
|
||||
|
||||
If specified, ``spec`` gives the types of class fields.
|
||||
It must be a dictionary or sequence.
|
||||
With a dictionary, use collections.OrderedDict for stable ordering.
|
||||
With a sequence, it must contain 2-tuples of (fieldname, fieldtype).
|
||||
|
||||
Any class annotations for field names not listed in spec will be added.
|
||||
For class annotation `x: T` we will append ``("x", as_numba_type(T))`` to
|
||||
the spec if ``x`` is not already a key in spec.
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
1) ``cls_or_spec = None``, ``spec = None``
|
||||
|
||||
>>> @jitclass()
|
||||
... class Foo:
|
||||
... ...
|
||||
|
||||
2) ``cls_or_spec = None``, ``spec = spec``
|
||||
|
||||
>>> @jitclass(spec=spec)
|
||||
... class Foo:
|
||||
... ...
|
||||
|
||||
3) ``cls_or_spec = Foo``, ``spec = None``
|
||||
|
||||
>>> @jitclass
|
||||
... class Foo:
|
||||
... ...
|
||||
|
||||
4) ``cls_or_spec = spec``, ``spec = None``
|
||||
In this case we update ``cls_or_spec, spec = None, cls_or_spec``.
|
||||
|
||||
>>> @jitclass(spec)
|
||||
... class Foo:
|
||||
... ...
|
||||
|
||||
5) ``cls_or_spec = Foo``, ``spec = spec``
|
||||
|
||||
>>> JitFoo = jitclass(Foo, spec)
|
||||
|
||||
Returns
|
||||
-------
|
||||
If used as a decorator, returns a callable that takes a class object and
|
||||
returns a compiled version.
|
||||
If used as a function, returns the compiled class (an instance of
|
||||
``JitClassType``).
|
||||
"""
|
||||
|
||||
if (cls_or_spec is not None and
|
||||
spec is None and
|
||||
not isinstance(cls_or_spec, type)):
|
||||
# Used like
|
||||
# @jitclass([("x", intp)])
|
||||
# class Foo:
|
||||
# ...
|
||||
spec = cls_or_spec
|
||||
cls_or_spec = None
|
||||
|
||||
def wrap(cls):
|
||||
if config.DISABLE_JIT:
|
||||
return cls
|
||||
else:
|
||||
from numba.experimental.jitclass.base import (register_class_type,
|
||||
ClassBuilder)
|
||||
cls_jitted = register_class_type(cls, spec, types.ClassType,
|
||||
ClassBuilder)
|
||||
|
||||
# Preserve the module name of the original class
|
||||
cls_jitted.__module__ = cls.__module__
|
||||
|
||||
return cls_jitted
|
||||
|
||||
if cls_or_spec is None:
|
||||
return wrap
|
||||
else:
|
||||
return wrap(cls_or_spec)
|
||||
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
Overloads for ClassInstanceType for built-in functions that call dunder methods
|
||||
on an object.
|
||||
"""
|
||||
from functools import wraps
|
||||
import inspect
|
||||
import operator
|
||||
|
||||
from numba.core.extending import overload
|
||||
from numba.core.types import ClassInstanceType
|
||||
|
||||
|
||||
def _get_args(n_args):
|
||||
assert n_args in (1, 2)
|
||||
return list("xy")[:n_args]
|
||||
|
||||
|
||||
def class_instance_overload(target):
|
||||
"""
|
||||
Decorator to add an overload for target that applies when the first argument
|
||||
is a ClassInstanceType.
|
||||
"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
if not isinstance(args[0], ClassInstanceType):
|
||||
return
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if target is not complex:
|
||||
# complex ctor needs special treatment as it uses kwargs
|
||||
params = list(inspect.signature(wrapped).parameters)
|
||||
assert params == _get_args(len(params))
|
||||
return overload(target)(wrapped)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def extract_template(template, name):
|
||||
"""
|
||||
Extract a code-generated function from a string template.
|
||||
"""
|
||||
namespace = {}
|
||||
exec(template, namespace)
|
||||
return namespace[name]
|
||||
|
||||
|
||||
def register_simple_overload(func, *attrs, n_args=1,):
|
||||
"""
|
||||
Register an overload for func that checks for methods __attr__ for each
|
||||
attr in attrs.
|
||||
"""
|
||||
# Use a template to set the signature correctly.
|
||||
arg_names = _get_args(n_args)
|
||||
template = f"""
|
||||
def func({','.join(arg_names)}):
|
||||
pass
|
||||
"""
|
||||
|
||||
@wraps(extract_template(template, "func"))
|
||||
def overload_func(*args, **kwargs):
|
||||
options = [
|
||||
try_call_method(args[0], f"__{attr}__", n_args)
|
||||
for attr in attrs
|
||||
]
|
||||
return take_first(*options)
|
||||
|
||||
return class_instance_overload(func)(overload_func)
|
||||
|
||||
|
||||
def try_call_method(cls_type, method, n_args=1):
|
||||
"""
|
||||
If method is defined for cls_type, return a callable that calls this method.
|
||||
If not, return None.
|
||||
"""
|
||||
if method in cls_type.jit_methods:
|
||||
arg_names = _get_args(n_args)
|
||||
template = f"""
|
||||
def func({','.join(arg_names)}):
|
||||
return {arg_names[0]}.{method}({','.join(arg_names[1:])})
|
||||
"""
|
||||
return extract_template(template, "func")
|
||||
|
||||
|
||||
def try_call_complex_method(cls_type, method):
|
||||
""" __complex__ needs special treatment as the argument names are kwargs
|
||||
and therefore specific in name and default value.
|
||||
"""
|
||||
if method in cls_type.jit_methods:
|
||||
template = f"""
|
||||
def func(real=0, imag=0):
|
||||
return real.{method}()
|
||||
"""
|
||||
return extract_template(template, "func")
|
||||
|
||||
|
||||
def take_first(*options):
|
||||
"""
|
||||
Take the first non-None option.
|
||||
"""
|
||||
assert all(o is None or inspect.isfunction(o) for o in options), options
|
||||
for o in options:
|
||||
if o is not None:
|
||||
return o
|
||||
|
||||
|
||||
@class_instance_overload(bool)
|
||||
def class_bool(x):
|
||||
using_bool_impl = try_call_method(x, "__bool__")
|
||||
|
||||
if '__len__' in x.jit_methods:
|
||||
def using_len_impl(x):
|
||||
return bool(len(x))
|
||||
else:
|
||||
using_len_impl = None
|
||||
|
||||
always_true_impl = lambda x: True
|
||||
|
||||
return take_first(using_bool_impl, using_len_impl, always_true_impl)
|
||||
|
||||
|
||||
@class_instance_overload(complex)
|
||||
def class_complex(real=0, imag=0):
|
||||
return take_first(
|
||||
try_call_complex_method(real, "__complex__"),
|
||||
lambda real=0, imag=0: complex(float(real))
|
||||
)
|
||||
|
||||
|
||||
@class_instance_overload(operator.contains)
|
||||
def class_contains(x, y):
|
||||
# https://docs.python.org/3/reference/expressions.html#membership-test-operations
|
||||
return try_call_method(x, "__contains__", 2)
|
||||
# TODO: use __iter__ if defined.
|
||||
|
||||
|
||||
@class_instance_overload(float)
|
||||
def class_float(x):
|
||||
options = [try_call_method(x, "__float__")]
|
||||
|
||||
if (
|
||||
"__index__" in x.jit_methods
|
||||
):
|
||||
options.append(lambda x: float(x.__index__()))
|
||||
|
||||
return take_first(*options)
|
||||
|
||||
|
||||
@class_instance_overload(int)
|
||||
def class_int(x):
|
||||
options = [try_call_method(x, "__int__")]
|
||||
|
||||
options.append(try_call_method(x, "__index__"))
|
||||
|
||||
return take_first(*options)
|
||||
|
||||
|
||||
@class_instance_overload(str)
|
||||
def class_str(x):
|
||||
return take_first(
|
||||
try_call_method(x, "__str__"),
|
||||
lambda x: repr(x),
|
||||
)
|
||||
|
||||
|
||||
@class_instance_overload(operator.ne)
|
||||
def class_ne(x, y):
|
||||
# This doesn't use register_reflected_overload like the other operators
|
||||
# because it falls back to inverting __eq__ rather than reflecting its
|
||||
# arguments (as per the definition of the Python data model).
|
||||
return take_first(
|
||||
try_call_method(x, "__ne__", 2),
|
||||
lambda x, y: not (x == y),
|
||||
)
|
||||
|
||||
|
||||
def register_reflected_overload(func, meth_forward, meth_reflected):
|
||||
def class_lt(x, y):
|
||||
normal_impl = try_call_method(x, f"__{meth_forward}__", 2)
|
||||
|
||||
if f"__{meth_reflected}__" in y.jit_methods:
|
||||
def reflected_impl(x, y):
|
||||
return y > x
|
||||
else:
|
||||
reflected_impl = None
|
||||
|
||||
return take_first(normal_impl, reflected_impl)
|
||||
|
||||
class_instance_overload(func)(class_lt)
|
||||
|
||||
|
||||
register_simple_overload(abs, "abs")
|
||||
register_simple_overload(len, "len")
|
||||
register_simple_overload(hash, "hash")
|
||||
|
||||
# Comparison operators.
|
||||
register_reflected_overload(operator.ge, "ge", "le")
|
||||
register_reflected_overload(operator.gt, "gt", "lt")
|
||||
register_reflected_overload(operator.le, "le", "ge")
|
||||
register_reflected_overload(operator.lt, "lt", "gt")
|
||||
|
||||
# Note that eq is missing support for fallback to `x is y`, but `is` and
|
||||
# `operator.is` are presently unsupported in general.
|
||||
register_reflected_overload(operator.eq, "eq", "eq")
|
||||
|
||||
# Arithmetic operators.
|
||||
register_simple_overload(operator.add, "add", n_args=2)
|
||||
register_simple_overload(operator.floordiv, "floordiv", n_args=2)
|
||||
register_simple_overload(operator.lshift, "lshift", n_args=2)
|
||||
register_simple_overload(operator.mul, "mul", n_args=2)
|
||||
register_simple_overload(operator.mod, "mod", n_args=2)
|
||||
register_simple_overload(operator.neg, "neg")
|
||||
register_simple_overload(operator.pos, "pos")
|
||||
register_simple_overload(operator.invert, "invert")
|
||||
register_simple_overload(operator.pow, "pow", n_args=2)
|
||||
register_simple_overload(operator.rshift, "rshift", n_args=2)
|
||||
register_simple_overload(operator.sub, "sub", n_args=2)
|
||||
register_simple_overload(operator.truediv, "truediv", n_args=2)
|
||||
|
||||
# Inplace arithmetic operators.
|
||||
register_simple_overload(operator.iadd, "iadd", "add", n_args=2)
|
||||
register_simple_overload(operator.ifloordiv, "ifloordiv", "floordiv", n_args=2)
|
||||
register_simple_overload(operator.ilshift, "ilshift", "lshift", n_args=2)
|
||||
register_simple_overload(operator.imul, "imul", "mul", n_args=2)
|
||||
register_simple_overload(operator.imod, "imod", "mod", n_args=2)
|
||||
register_simple_overload(operator.ipow, "ipow", "pow", n_args=2)
|
||||
register_simple_overload(operator.irshift, "irshift", "rshift", n_args=2)
|
||||
register_simple_overload(operator.isub, "isub", "sub", n_args=2)
|
||||
register_simple_overload(operator.itruediv, "itruediv", "truediv", n_args=2)
|
||||
|
||||
# Logical operators.
|
||||
register_simple_overload(operator.and_, "and", n_args=2)
|
||||
register_simple_overload(operator.or_, "or", n_args=2)
|
||||
register_simple_overload(operator.xor, "xor", n_args=2)
|
||||
|
||||
# Inplace logical operators.
|
||||
register_simple_overload(operator.iand, "iand", "and", n_args=2)
|
||||
register_simple_overload(operator.ior, "ior", "or", n_args=2)
|
||||
register_simple_overload(operator.ixor, "ixor", "xor", n_args=2)
|
||||
Reference in New Issue
Block a user