This commit is contained in:
2026-04-10 15:06:59 +02:00
parent 3031b7153b
commit e5a4711004
7806 changed files with 1918528 additions and 335 deletions

View File

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
from numba.np.ufunc.decorators import Vectorize, GUVectorize, vectorize, guvectorize
from numba.np.ufunc._internal import PyUFunc_None, PyUFunc_Zero, PyUFunc_One
from numba.np.ufunc import _internal, array_exprs
from numba.np.ufunc.parallel import (threading_layer, get_num_threads,
set_num_threads, get_thread_id,
set_parallel_chunksize,
get_parallel_chunksize)
# Newer ufunc C-API builds expose PyUFunc_ReorderableNone; re-export it only
# when the compiled _internal extension actually provides it.
if hasattr(_internal, 'PyUFunc_ReorderableNone'):
    PyUFunc_ReorderableNone = _internal.PyUFunc_ReorderableNone
# Imported for their side effects / attribute access above; keep the public
# module namespace clean.
del _internal, array_exprs
def _init():
    """Register lazy loaders for the CUDA vectorizer targets.

    The CUDA implementations are imported on demand so that importing this
    package does not require a working CUDA installation.
    """
    def load_cuda_vectorize():
        from numba.cuda.vectorizers import CUDAVectorize
        return CUDAVectorize

    def load_cuda_guvectorize():
        from numba.cuda.vectorizers import CUDAGUFuncVectorize
        return CUDAGUFuncVectorize

    Vectorize.target_registry.ondemand['cuda'] = load_cuda_vectorize
    GUVectorize.target_registry.ondemand['cuda'] = load_cuda_guvectorize


_init()
del _init

View File

@@ -0,0 +1,428 @@
import ast
from collections import defaultdict, OrderedDict
import contextlib
import sys
from types import SimpleNamespace
import numpy as np
import operator
from numba.core import types, targetconfig, ir, rewrites, compiler
from numba.core.typing import npydecl
from numba.np.ufunc.dufunc import DUFunc
def _is_ufunc(func):
    """Return True when *func* is a NumPy ufunc or a Numba DUFunc."""
    return isinstance(func, np.ufunc) or isinstance(func, DUFunc)
@rewrites.register_rewrite('after-inference')
class RewriteArrayExprs(rewrites.Rewrite):
    '''The RewriteArrayExprs class is responsible for finding array
    expressions in Numba intermediate representation code, and
    rewriting those expressions to a single operation that will expand
    into something similar to a ufunc call.
    '''

    def __init__(self, state, *args, **kws):
        super(RewriteArrayExprs, self).__init__(state, *args, **kws)
        # Install a lowering hook if we are using this rewrite.
        special_ops = state.targetctx.special_ops
        if 'arrayexpr' not in special_ops:
            special_ops['arrayexpr'] = _lower_array_expr

    def match(self, func_ir, block, typemap, calltypes):
        """
        Using typing and a basic block, search the basic block for array
        expressions.
        Return True when one or more matches were found, False otherwise.
        """
        # We can trivially reject everything if there are no
        # calls in the type results.
        if len(calltypes) == 0:
            return False
        self.crnt_block = block
        self.typemap = typemap
        # { variable name: IR assignment (of a function call or operator) }
        self.array_assigns = OrderedDict()
        # { variable name: IR assignment (of a constant) }
        self.const_assigns = {}
        assignments = block.find_insts(ir.Assign)
        for instr in assignments:
            target_name = instr.target.name
            expr = instr.value
            # Does it assign an expression to an array variable?
            if (isinstance(expr, ir.Expr) and
                    isinstance(typemap.get(target_name, None), types.Array)):
                self._match_array_expr(instr, expr, target_name)
            elif isinstance(expr, ir.Const):
                # Track constants since we might need them for an
                # array expression.
                self.const_assigns[target_name] = expr
        return len(self.array_assigns) > 0

    def _match_array_expr(self, instr, expr, target_name):
        """
        Find whether the given assignment (*instr*) of an expression (*expr*)
        to variable *target_name* is an array expression.
        """
        # We've matched a subexpression assignment to an
        # array variable.  Now see if the expression is an
        # array expression.
        expr_op = expr.op
        array_assigns = self.array_assigns
        if ((expr_op in ('unary', 'binop')) and (
                expr.fn in npydecl.supported_array_operators)):
            # It is an array operator that maps to a ufunc.
            # check that all args have internal types
            if all(self.typemap[var.name].is_internal
                   for var in expr.list_vars()):
                array_assigns[target_name] = instr
        elif ((expr_op == 'call') and (expr.func.name in self.typemap)):
            # It could be a match for a known ufunc call.
            func_type = self.typemap[expr.func.name]
            if isinstance(func_type, types.Function):
                func_key = func_type.typing_key
                if _is_ufunc(func_key):
                    # If so, check whether an explicit output is passed.
                    if not self._has_explicit_output(expr, func_key):
                        # If not, match it as a (sub)expression.
                        array_assigns[target_name] = instr

    def _has_explicit_output(self, expr, func):
        """
        Return whether the *expr* call to *func* (a ufunc) features an
        explicit output argument.
        """
        nargs = len(expr.args) + len(expr.kws)
        if expr.vararg is not None:
            # XXX *args unsupported here, assume there may be an explicit
            # output
            return True
        return nargs > func.nin

    def _get_array_operator(self, ir_expr):
        # Return the operator (function) or ufunc backing *ir_expr*.
        ir_op = ir_expr.op
        if ir_op in ('unary', 'binop'):
            return ir_expr.fn
        elif ir_op == 'call':
            return self.typemap[ir_expr.func.name].typing_key
        raise NotImplementedError(
            "Don't know how to find the operator for '{0}' expressions.".format(
                ir_op))

    def _get_operands(self, ir_expr):
        '''Given a Numba IR expression, return the operands to the expression
        in order they appear in the expression.
        '''
        ir_op = ir_expr.op
        if ir_op == 'binop':
            return ir_expr.lhs, ir_expr.rhs
        elif ir_op == 'unary':
            return ir_expr.list_vars()
        elif ir_op == 'call':
            return ir_expr.args
        raise NotImplementedError(
            "Don't know how to find the operands for '{0}' expressions.".format(
                ir_op))

    def _translate_expr(self, ir_expr):
        '''Translate the given expression from Numba IR to an array expression
        tree.
        '''
        ir_op = ir_expr.op
        if ir_op == 'arrayexpr':
            # Already rewritten: reuse the nested (op, operands) tree.
            return ir_expr.expr
        # Substitute tracked constants for their variables where possible.
        operands_or_args = [self.const_assigns.get(op_var.name, op_var)
                            for op_var in self._get_operands(ir_expr)]
        return self._get_array_operator(ir_expr), operands_or_args

    def _handle_matches(self):
        '''Iterate over the matches, trying to find which instructions should
        be rewritten, deleted, or moved.
        '''
        replace_map = {}
        dead_vars = set()
        used_vars = defaultdict(int)
        for instr in self.array_assigns.values():
            expr = instr.value
            arr_inps = []
            # Note: arr_inps is filled in below, after the ir.Expr is built,
            # via the aliased list object.
            arr_expr = self._get_array_operator(expr), arr_inps
            new_expr = ir.Expr(op='arrayexpr',
                               loc=expr.loc,
                               expr=arr_expr,
                               ty=self.typemap[instr.target.name])
            new_instr = ir.Assign(new_expr, instr.target, instr.loc)
            replace_map[instr] = new_instr
            self.array_assigns[instr.target.name] = new_instr
            for operand in self._get_operands(expr):
                operand_name = operand.name
                if operand.is_temp and operand_name in self.array_assigns:
                    # Fuse a nested array expression into this one and mark
                    # the temporary child assignment for deletion.
                    child_assign = self.array_assigns[operand_name]
                    child_expr = child_assign.value
                    child_operands = child_expr.list_vars()
                    for operand in child_operands:
                        used_vars[operand.name] += 1
                    arr_inps.append(self._translate_expr(child_expr))
                    if child_assign.target.is_temp:
                        dead_vars.add(child_assign.target.name)
                        replace_map[child_assign] = None
                elif operand_name in self.const_assigns:
                    arr_inps.append(self.const_assigns[operand_name])
                else:
                    used_vars[operand.name] += 1
                    arr_inps.append(operand)
        return replace_map, dead_vars, used_vars

    def _get_final_replacement(self, replacement_map, instr):
        '''Find the final replacement instruction for a given initial
        instruction by chasing instructions in a map from instructions
        to replacement instructions.
        '''
        replacement = replacement_map[instr]
        while replacement in replacement_map:
            replacement = replacement_map[replacement]
        return replacement

    def apply(self):
        '''When we've found array expressions in a basic block, rewrite that
        block, returning a new, transformed block.
        '''
        # Part 1: Figure out what instructions should be rewritten
        # based on the matches found.
        replace_map, dead_vars, used_vars = self._handle_matches()
        # Part 2: Using the information above, rewrite the target
        # basic block.
        result = self.crnt_block.copy()
        result.clear()
        delete_map = {}
        for instr in self.crnt_block.body:
            if isinstance(instr, ir.Assign):
                if instr in replace_map:
                    replacement = self._get_final_replacement(
                        replace_map, instr)
                    if replacement:
                        result.append(replacement)
                        for var in replacement.value.list_vars():
                            var_name = var.name
                            # Re-emit any deferred ir.Del once its variable
                            # has been consumed by the fused expression.
                            if var_name in delete_map:
                                result.append(delete_map.pop(var_name))
                            if used_vars[var_name] > 0:
                                used_vars[var_name] -= 1
                else:
                    result.append(instr)
            elif isinstance(instr, ir.Del):
                instr_value = instr.value
                if used_vars[instr_value] > 0:
                    # Variable still referenced by a pending fused expression;
                    # defer its deletion until after that use.
                    used_vars[instr_value] -= 1
                    delete_map[instr_value] = instr
                elif instr_value not in dead_vars:
                    result.append(instr)
            else:
                result.append(instr)
        if delete_map:
            # Flush any deletions that were never re-emitted above.
            for instr in delete_map.values():
                result.insert_before_terminator(instr)
        return result
# Mappings from the operator functions recorded in an array-expression tree to
# the ast node classes used by _arr_expr_to_ast() to rebuild equivalent Python
# source for compilation.
_unaryops = {
    operator.pos: ast.UAdd,
    operator.neg: ast.USub,
    operator.invert: ast.Invert,
}

_binops = {
    operator.add: ast.Add,
    operator.sub: ast.Sub,
    operator.mul: ast.Mult,
    operator.truediv: ast.Div,
    operator.mod: ast.Mod,
    operator.or_: ast.BitOr,
    operator.rshift: ast.RShift,
    operator.xor: ast.BitXor,
    operator.lshift: ast.LShift,
    operator.and_: ast.BitAnd,
    operator.pow: ast.Pow,
    operator.floordiv: ast.FloorDiv,
}

# Comparisons lower to ast.Compare rather than ast.BinOp.
_cmpops = {
    operator.eq: ast.Eq,
    operator.ne: ast.NotEq,
    operator.lt: ast.Lt,
    operator.le: ast.LtE,
    operator.gt: ast.Gt,
    operator.ge: ast.GtE,
}
def _arr_expr_to_ast(expr):
    '''Build a Python expression AST from an array expression built by
    RewriteArrayExprs.

    Returns a pair ``(ast_node, env)`` where *env* maps generated names to
    the ufunc/DUFunc objects that must be in scope when the AST is compiled.
    '''
    if isinstance(expr, tuple):
        # Interior node: (operator-or-ufunc, [operand trees]).
        op, arr_expr_args = expr
        ast_args = []
        env = {}
        for arg in arr_expr_args:
            ast_arg, child_env = _arr_expr_to_ast(arg)
            ast_args.append(ast_arg)
            env.update(child_env)
        if op in npydecl.supported_array_operators:
            if len(ast_args) == 2:
                if op in _binops:
                    return ast.BinOp(
                        ast_args[0], _binops[op](), ast_args[1]), env
                if op in _cmpops:
                    return ast.Compare(
                        ast_args[0], [_cmpops[op]()], [ast_args[1]]), env
            else:
                assert op in _unaryops
                return ast.UnaryOp(_unaryops[op](), ast_args[0]), env
        elif _is_ufunc(op):
            # Ufunc calls become calls to a uniquely-named global; the actual
            # ufunc object travels via *env*.
            fn_name = "__ufunc_or_dufunc_{0}".format(
                hex(hash(op)).replace("-", "_"))
            fn_ast_name = ast.Name(fn_name, ast.Load())
            env[fn_name] = op  # Stash the ufunc or DUFunc in the environment
            ast_call = ast.Call(fn_ast_name, ast_args, [])
            return ast_call, env
    elif isinstance(expr, ir.Var):
        return ast.Name(expr.name, ast.Load(),
                        lineno=expr.loc.line,
                        col_offset=expr.loc.col if expr.loc.col else 0), {}
    elif isinstance(expr, ir.Const):
        return ast.Constant(expr.value), {}
    # Also reached for a binary tuple whose op is a supported operator but in
    # neither _binops nor _cmpops.
    raise NotImplementedError(
        "Don't know how to translate array expression '%r'" % (expr,))
@contextlib.contextmanager
def _legalize_parameter_names(var_list):
"""
Legalize names in the variable list for use as a Python function's
parameter names.
"""
var_map = OrderedDict()
for var in var_list:
old_name = var.name
new_name = var.scope.redefine(old_name, loc=var.loc).name
new_name = new_name.replace("$", "_").replace(".", "_")
# Caller should ensure the names are unique
if new_name in var_map:
raise AssertionError(f"{new_name!r} not unique")
var_map[new_name] = var, old_name
var.name = new_name
param_names = list(var_map)
try:
yield param_names
finally:
# Make sure the old names are restored, to avoid confusing
# other parts of Numba (see issue #1466)
for var, old_name in var_map.values():
var.name = old_name
class _EraseInvalidLineRanges(ast.NodeTransformer):
def generic_visit(self, node: ast.AST) -> ast.AST:
node = super().generic_visit(node)
if hasattr(node, "lineno"):
if getattr(node, "end_lineno", None) is not None:
if node.lineno > node.end_lineno:
del node.lineno
del node.end_lineno
return node
def _fix_invalid_lineno_ranges(astree: ast.AST):
    """Inplace fixes invalid lineno ranges.

    Synthesized AST nodes can end up with lineno > end_lineno, which the
    compiler rejects; erase such ranges and let fix_missing_locations
    repopulate them from the surrounding nodes.
    """
    # Make sure lineno and end_lineno are present
    ast.fix_missing_locations(astree)
    # Delete invalid lineno ranges
    _EraseInvalidLineRanges().visit(astree)
    # Make sure lineno and end_lineno are present
    ast.fix_missing_locations(astree)
def _lower_array_expr(lowerer, expr):
    '''Lower an array expression built by RewriteArrayExprs.

    Synthesizes a Python function from the expression tree, compiles it as a
    scalar kernel, and emits a ufunc-style broadcast loop calling it.
    '''
    expr_name = "__numba_array_expr_%s" % (hex(hash(expr)).replace("-", "_"))
    expr_filename = expr.loc.filename
    expr_var_list = expr.list_vars()
    # The expression may use a given variable several times, but we
    # should only create one parameter for it.
    expr_var_unique = sorted(set(expr_var_list), key=lambda var: var.name)
    # Arguments are the names external to the new closure
    expr_args = [var.name for var in expr_var_unique]
    # 1. Create an AST tree from the array expression.
    with _legalize_parameter_names(expr_var_unique) as expr_params:
        ast_args = [ast.arg(param_name, None)
                    for param_name in expr_params]
        # Parse a stub function to ensure the AST is populated with
        # reasonable defaults for the Python version.
        ast_module = ast.parse('def {0}(): return'.format(expr_name),
                               expr_filename, 'exec')
        assert hasattr(ast_module, 'body') and len(ast_module.body) == 1
        ast_fn = ast_module.body[0]
        ast_fn.args.args = ast_args
        # Replace the stub's bare `return` value with the expression tree;
        # namespace carries any ufunc objects referenced by name.
        ast_fn.body[0].value, namespace = _arr_expr_to_ast(expr.expr)
        _fix_invalid_lineno_ranges(ast_module)
    # 2. Compile the AST module and extract the Python function.
    code_obj = compile(ast_module, expr_filename, 'exec')
    exec(code_obj, namespace)
    impl = namespace[expr_name]
    # 3. Now compile a ufunc using the Python function as kernel.
    context = lowerer.context
    builder = lowerer.builder
    outer_sig = expr.ty(*(lowerer.typeof(name) for name in expr_args))
    # Element-wise signature: arrays contribute their dtype, scalars pass
    # through (Optionals unwrapped first).
    inner_sig_args = []
    for argty in outer_sig.args:
        if isinstance(argty, types.Optional):
            argty = argty.type
        if isinstance(argty, types.Array):
            inner_sig_args.append(argty.dtype)
        else:
            inner_sig_args.append(argty)
    inner_sig = outer_sig.return_type.dtype(*inner_sig_args)
    flags = targetconfig.ConfigStack().top_or_none()
    flags = compiler.Flags() if flags is None else flags.copy()  # make sure it's a clone or a fresh instance
    # Follow the Numpy error model.  Note this also allows e.g. vectorizing
    # division (issue #1223).
    flags.error_model = 'numpy'
    cres = context.compile_subroutine(builder, impl, inner_sig, flags=flags,
                                      caching=False)
    # Create kernel subclass calling our native function
    from numba.np import npyimpl

    class ExprKernel(npyimpl._Kernel):
        # Casts each argument to the scalar kernel's type, calls the
        # compiled subroutine, and casts the result back.
        def generate(self, *args):
            arg_zip = zip(args, self.outer_sig.args, inner_sig.args)
            cast_args = [self.cast(val, inty, outty)
                         for val, inty, outty in arg_zip]
            result = self.context.call_internal(
                builder, cres.fndesc, inner_sig, cast_args)
            return self.cast(result, inner_sig.return_type,
                             self.outer_sig.return_type)

    # create a fake ufunc object which is enough to trick numpy_ufunc_kernel
    ufunc = SimpleNamespace(nin=len(expr_args), nout=1, __name__=expr_name)
    ufunc.nargs = ufunc.nin + ufunc.nout
    args = [lowerer.loadvar(name) for name in expr_args]
    return npyimpl.numpy_ufunc_kernel(
        context, builder, outer_sig, args, ufunc, ExprKernel)

View File

@@ -0,0 +1,208 @@
import inspect
from numba.np.ufunc import _internal
from numba.np.ufunc.parallel import ParallelUFuncBuilder, ParallelGUFuncBuilder
from numba.core.registry import DelayedRegistry
from numba.np.ufunc import dufunc
from numba.np.ufunc import gufunc
class _BaseVectorize(object):
@classmethod
def get_identity(cls, kwargs):
return kwargs.pop('identity', None)
@classmethod
def get_cache(cls, kwargs):
return kwargs.pop('cache', False)
@classmethod
def get_writable_args(cls, kwargs):
return kwargs.pop('writable_args', ())
@classmethod
def get_target_implementation(cls, kwargs):
target = kwargs.pop('target', 'cpu')
try:
return cls.target_registry[target]
except KeyError:
raise ValueError("Unsupported target: %s" % target)
class Vectorize(_BaseVectorize):
    """Dispatch a ufunc build to the implementation for the chosen target."""

    target_registry = DelayedRegistry({'cpu': dufunc.DUFunc,
                                       'parallel': ParallelUFuncBuilder,})

    def __new__(cls, func, **kws):
        identity = cls.get_identity(kws)
        cache = cls.get_cache(kws)
        builder_cls = cls.get_target_implementation(kws)
        # Any keywords left after the pops above are target options.
        return builder_cls(func, identity=identity, cache=cache,
                           targetoptions=kws)
class GUVectorize(_BaseVectorize):
    """Dispatch a gufunc build to the implementation for the chosen target."""

    target_registry = DelayedRegistry({'cpu': gufunc.GUFunc,
                                       'parallel': ParallelGUFuncBuilder,})

    def __new__(cls, func, signature, **kws):
        identity = cls.get_identity(kws)
        cache = cls.get_cache(kws)
        builder_cls = cls.get_target_implementation(kws)
        writable_args = cls.get_writable_args(kws)
        common = dict(identity=identity, cache=cache,
                      writable_args=writable_args)
        if builder_cls is gufunc.GUFunc:
            # Only the dynamic GUFunc implementation understands
            # 'is_dynamic'; pop it before kws becomes target options.
            common['is_dynamic'] = kws.pop('is_dynamic', False)
        return builder_cls(func, signature, targetoptions=kws, **common)
def vectorize(ftylist_or_function=(), **kws):
    """vectorize(ftylist_or_function=(), target='cpu', identity=None, **kws)

    A decorator that creates a NumPy ufunc object using Numba compiled
    code.  When no arguments or only keyword arguments are given,
    vectorize will return a Numba dynamic ufunc (DUFunc) object, where
    compilation/specialization may occur at call-time.

    Args
    -----
    ftylist_or_function: function or iterable
        When the first argument is a function, signatures are dealt
        with at call-time.
        When the first argument is an iterable of type signatures,
        which are either function type object or a string describing
        the function type, signatures are finalized at decoration
        time.

    Keyword Args
    ------------
    target: str
            A string for code generation target.  Default to "cpu".
    identity: int, str, or None
        The identity (or unit) value for the element-wise function
        being implemented.  Allowed values are None (the default), 0, 1,
        and "reorderable".
    cache: bool
        Turns on caching.

    Returns
    --------
    A NumPy universal function

    Examples
    -------
        @vectorize(['float32(float32, float32)',
                    'float64(float64, float64)'], identity=0)
        def sum(a, b):
            return a + b

        @vectorize
        def sum(a, b):
            return a + b

        @vectorize(identity=1)
        def mul(a, b):
            return a * b
    """
    if isinstance(ftylist_or_function, str):
        # Common user mistake
        ftylist = [ftylist_or_function]
    elif inspect.isfunction(ftylist_or_function):
        # Bare @vectorize usage: return a dynamic DUFunc immediately.
        return dufunc.DUFunc(ftylist_or_function, **kws)
    elif ftylist_or_function is not None:
        ftylist = ftylist_or_function
    # NOTE(review): if ftylist_or_function is None, ftylist is left unbound
    # and wrap() below raises NameError — presumably None is unsupported
    # input; confirm before changing.

    def wrap(func):
        vec = Vectorize(func, **kws)
        for sig in ftylist:
            vec.add(sig)
        if len(ftylist) > 0:
            # Explicit signatures were given: freeze the set of loops.
            vec.disable_compile()
        return vec.build_ufunc()

    return wrap
def guvectorize(*args, **kwargs):
    """guvectorize(ftylist, signature, target='cpu', identity=None, **kws)

    A decorator to create NumPy generalized-ufunc object from Numba compiled
    code.

    Args
    -----
    ftylist: iterable
        An iterable of type signatures, which are either
        function type object or a string describing the
        function type.
    signature: str
        A NumPy generalized-ufunc signature.
        e.g. "(m, n), (n, p)->(m, p)"
    identity: int, str, or None
        The identity (or unit) value for the element-wise function
        being implemented.  Allowed values are None (the default), 0, 1,
        and "reorderable".
    cache: bool
        Turns on caching.
    writable_args: tuple
        a tuple of indices of input variables that are writable.
    target: str
        A string for code generation target.  Defaults to "cpu".

    Returns
    --------
    A NumPy generalized universal-function

    Example
    -------
        @guvectorize(['void(int32[:,:], int32[:,:], int32[:,:])',
                      'void(float32[:,:], float32[:,:], float32[:,:])'],
                      '(x, y),(x, y)->(x, y)')
        def add_2d_array(a, b, c):
            for i in range(c.shape[0]):
                for j in range(c.shape[1]):
                    c[i, j] = a[i, j] + b[i, j]
    """
    if len(args) == 1:
        # Only the gufunc core signature given: build a dynamic gufunc that
        # compiles loops at call time.
        ftylist = []
        signature = args[0]
        kwargs.setdefault('is_dynamic', True)
    elif len(args) == 2:
        ftylist = args[0]
        signature = args[1]
    else:
        raise TypeError('guvectorize() takes one or two positional arguments')

    if isinstance(ftylist, str):
        # Common user mistake
        ftylist = [ftylist]

    def wrap(func):
        guvec = GUVectorize(func, signature, **kwargs)
        for fty in ftylist:
            guvec.add(fty)
        if len(ftylist) > 0:
            # Explicit signatures were given: freeze the set of loops.
            guvec.disable_compile()
        return guvec.build_ufunc()

    return wrap

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,325 @@
from numba import typeof
from numba.core import types
from numba.np.ufunc.ufuncbuilder import GUFuncBuilder
from numba.np.ufunc.sigparse import parse_signature
from numba.np.ufunc.ufunc_base import UfuncBase, UfuncLowererBase
from numba.np.numpy_support import ufunc_find_matching_loop
from numba.core import serialize, errors
from numba.core.typing import npydecl
from numba.core.typing.templates import signature, AbstractTemplate
import functools
def make_gufunc_kernel(_dufunc):
    """Build a _Kernel subclass bound to *_dufunc* (a GUFunc instance)."""
    from numba.np import npyimpl

    class GUFuncKernel(npyimpl._Kernel):
        """
        npyimpl._Kernel subclass responsible for lowering a gufunc kernel
        (element-wise function) inside a broadcast loop (which is
        generated by npyimpl.numpy_gufunc_kernel()).
        """
        dufunc = _dufunc

        def __init__(self, context, builder, outer_sig):
            super().__init__(context, builder, outer_sig)
            # Resolve the already-compiled element-wise loop matching the
            # call-site dtypes.
            ewise_types = self.dufunc._get_ewise_dtypes(outer_sig.args)
            self.inner_sig, self.cres = self.dufunc.find_ewise_function(
                ewise_types)

        def cast(self, val, fromty, toty):
            # Handle the case where "fromty" is an array and "toty" a scalar
            if isinstance(fromty, types.Array) and not \
                    isinstance(toty, types.Array):
                return super().cast(val, fromty.dtype, toty)
            return super().cast(val, fromty, toty)

        def generate(self, *args):
            if self.cres.objectmode:
                msg = ('Calling a guvectorize function in object mode is not '
                       'supported yet.')
                raise errors.NumbaRuntimeError(msg)
            # Link in the library holding the compiled element-wise loop.
            self.context.add_linking_libs((self.cres.library,))
            return super().generate(*args)

    # Give each generated kernel class a distinguishable name for debugging.
    GUFuncKernel.__name__ += _dufunc.__name__
    return GUFuncKernel
class GUFuncLowerer(UfuncLowererBase):
    '''Callable class responsible for lowering calls to a specific gufunc.
    '''
    def __init__(self, gufunc):
        # Imported here to avoid a circular import at module load time.
        from numba.np import npyimpl
        super().__init__(gufunc,
                         make_gufunc_kernel,
                         npyimpl.numpy_gufunc_kernel)
class GUFunc(serialize.ReduceMixin, UfuncBase):
    """
    Dynamic generalized universal function (GUFunc)
    intended to act like a normal Numpy gufunc, but capable
    of call-time (just-in-time) compilation of fast loops
    specialized to inputs.
    """

    def __init__(self, py_func, signature, identity=None, cache=None,
                 is_dynamic=False, targetoptions=None, writable_args=()):
        if targetoptions is None:
            targetoptions = {}
        self.ufunc = None          # the underlying NumPy gufunc, built below
        self._frozen = False       # when True, no new loops may be compiled
        self._is_dynamic = is_dynamic
        self._identity = identity
        # GUFunc cannot inherit from GUFuncBuilder because "identity"
        # is a property of GUFunc. Thus, we hold a reference to a GUFuncBuilder
        # object here
        self.gufunc_builder = GUFuncBuilder(
            py_func, signature, identity, cache, targetoptions, writable_args)
        self.__name__ = self.gufunc_builder.py_func.__name__
        self.__doc__ = self.gufunc_builder.py_func.__doc__
        self._dispatcher = self.gufunc_builder.nb_func
        self._initialize(self._dispatcher)
        functools.update_wrapper(self, py_func)

    def _initialize(self, dispatcher):
        # Build the ufunc and register typing/lowering so this object can be
        # called from jitted code as well as from Python.
        self.build_ufunc()
        self._install_type()
        self._lower_me = GUFuncLowerer(self)
        self._install_cg()

    def _reduce_states(self):
        # Pickle support (serialize.ReduceMixin): capture everything needed
        # by _rebuild() below.
        gb = self.gufunc_builder
        dct = dict(
            py_func=gb.py_func,
            signature=gb.signature,
            identity=self._identity,
            cache=gb.cache,
            is_dynamic=self._is_dynamic,
            targetoptions=gb.targetoptions,
            writable_args=gb.writable_args,
            typesigs=gb._sigs,
            frozen=self._frozen,
        )
        return dct

    @classmethod
    def _rebuild(cls, py_func, signature, identity, cache, is_dynamic,
                 targetoptions, writable_args, typesigs, frozen):
        # Unpickle support: reconstruct and re-add the compiled signatures.
        self = cls(py_func=py_func, signature=signature, identity=identity,
                   cache=cache, is_dynamic=is_dynamic,
                   targetoptions=targetoptions, writable_args=writable_args)
        for sig in typesigs:
            self.add(sig)
        self.build_ufunc()
        self._frozen = frozen
        return self

    def __repr__(self):
        return f"<numba._GUFunc '{self.__name__}'>"

    def _install_type(self, typingctx=None):
        """Constructs and installs a typing class for a gufunc object in the
        input typing context.  If no typing context is given, then
        _install_type() installs into the typing context of the
        dispatcher object (should be same default context used by
        jit() and njit()).
        """
        if typingctx is None:
            typingctx = self._dispatcher.targetdescr.typing_context
        _ty_cls = type('GUFuncTyping_' + self.__name__,
                       (AbstractTemplate,),
                       dict(key=self, generic=self._type_me))
        typingctx.insert_user_function(self, _ty_cls)

    def add(self, fty):
        # Register an element-wise function type with the builder.
        self.gufunc_builder.add(fty)

    def build_ufunc(self):
        # (Re)build the underlying NumPy gufunc from all registered loops.
        self.ufunc = self.gufunc_builder.build_ufunc()
        return self

    def expected_ndims(self):
        # Returns (input core ndims, output core ndims) as tuples, e.g.
        # "(m,n),(n)->(m)" -> ((2, 1), (1,)).
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        return (tuple(map(len, parsed_sig[0])), tuple(map(len, parsed_sig[1])))

    def _type_me(self, argtys, kws):
        """
        Implement AbstractTemplate.generic() for the typing class
        built by gufunc._install_type().
        Return the call-site signature after either validating the
        element-wise signature or compiling for it.
        """
        assert not kws
        ufunc = self.ufunc
        sig = self.gufunc_builder.signature
        inp_ndims, out_ndims = self.expected_ndims()
        ndims = inp_ndims + out_ndims
        # NOTE(review): the comma makes the second expression the assert
        # *message*, so this only asserts argtys is non-empty — a length
        # comparison with ndims was likely intended; confirm before changing.
        assert len(argtys), len(ndims)
        for idx, arg in enumerate(argtys):
            if isinstance(arg, types.Array) and arg.ndim < ndims[idx]:
                kind = "Input" if idx < len(inp_ndims) else "Output"
                i = idx if idx < len(inp_ndims) else idx - len(inp_ndims)
                msg = (
                    f"{self.__name__}: {kind} operand {i} does not have "
                    f"enough dimensions (has {arg.ndim}, gufunc core with "
                    f"signature {sig} requires {ndims[idx]})")
                raise errors.TypingError(msg)
        _handle_inputs_result = npydecl.Numpy_rules_ufunc._handle_inputs(
            ufunc, argtys, kws)
        ewise_types, _, _, _ = _handle_inputs_result
        sig, _ = self.find_ewise_function(ewise_types)
        if sig is None:
            # Matching element-wise signature was not found; must
            # compile.
            if self._frozen:
                msg = f"cannot call {self} with types {argtys}"
                raise errors.TypingError(msg)
            self._compile_for_argtys(ewise_types)
            # double check to ensure there is a match
            sig, _ = self.find_ewise_function(ewise_types)
            # NOTE(review): sig is the first element of the pair, so it is
            # None (not (None, None)) on failure — this branch looks
            # unreachable and the assert below fires instead; confirm.
            if sig == (None, None):
                msg = f"Fail to compile {self.__name__} with types {argtys}"
                raise errors.TypingError(msg)
            assert sig is not None
        return signature(types.none, *argtys)

    def _compile_for_argtys(self, argtys, return_type=None):
        # Compile a new guvectorize function! Use the gufunc signature
        # i.e. (n,m),(m)->(n)
        # plus ewise_types to build a numba function type
        # (return_type is currently unused here.)
        fnty = self._get_function_type(*argtys)
        self.gufunc_builder.add(fnty)

    def match_signature(self, ewise_types, sig):
        # True when *sig*'s element-wise dtypes equal *ewise_types* exactly.
        dtypes = self._get_ewise_dtypes(sig.args)
        return tuple(dtypes) == tuple(ewise_types)

    @property
    def is_dynamic(self):
        return self._is_dynamic

    def _get_ewise_dtypes(self, args):
        # Map each argument (Numba type or Python value) to its element-wise
        # dtype: arrays contribute .dtype, scalars pass through.
        argtys = map(lambda arg: arg if isinstance(arg, types.Type) else
                     typeof(arg), args)
        tys = []
        for argty in argtys:
            if isinstance(argty, types.Array):
                tys.append(argty.dtype)
            else:
                tys.append(argty)
        return tys

    def _num_args_match(self, *args):
        # True when the call provides every input AND output operand.
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        return len(args) == len(parsed_sig[0]) + len(parsed_sig[1])

    def _get_function_type(self, *args):
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        # ewise_types is a list of [int32, int32, int32, ...]
        ewise_types = self._get_ewise_dtypes(args)
        # first time calling the gufunc
        # generate a signature based on input arguments
        l = []
        for idx, sig_dim in enumerate(parsed_sig[0]):
            ndim = len(sig_dim)
            if ndim == 0:  # append scalar
                l.append(ewise_types[idx])
            else:
                l.append(types.Array(ewise_types[idx], ndim, 'A'))
        offset = len(parsed_sig[0])
        # add return type to signature
        for idx, sig_dim in enumerate(parsed_sig[1]):
            retty = ewise_types[idx + offset]
            ret_ndim = len(sig_dim) or 1  # small hack to return scalars
            l.append(types.Array(retty, ret_ndim, 'A'))
        return types.none(*l)

    def __call__(self, *args, **kwargs):
        # If compilation is disabled OR it is NOT a dynamic gufunc
        # call the underlying gufunc
        if self._frozen or not self.is_dynamic:
            # Do not unwrap the ufunc if the argument is a wrapper that will
            # potentially pickle the ufunc after it receives it in
            # __array_ufunc__. The same logic in theory should be replicated
            # for reduce(), outer(), etc., but they're not implemented in dask.
            if args and _is_array_wrapper(args[0]):
                return args[0].__array_ufunc__(
                    self, "__call__", *args, **kwargs
                )
            else:
                return self.ufunc(*args, **kwargs)
        elif "out" in kwargs:
            # If "out" argument is supplied
            args += (kwargs.pop("out"),)
        if self._num_args_match(*args) is False:
            # It is not allowed to call a dynamic gufunc without
            # providing all the arguments
            # see: https://github.com/numba/numba/pull/5938#discussion_r506429392 # noqa: E501
            msg = (
                f"Too few arguments for function '{self.__name__}'. "
                "Note that the pattern `out = gufunc(Arg1, Arg2, ..., ArgN)` "
                "is not allowed. Use `gufunc(Arg1, Arg2, ..., ArgN, out) "
                "instead.")
            raise TypeError(msg)
        # at this point we know the gufunc is a dynamic one
        ewise = self._get_ewise_dtypes(args)
        if not (self.ufunc and ufunc_find_matching_loop(self.ufunc, ewise)):
            # A previous call (@njit -> @guvectorize) may have compiled a
            # version for the element-wise dtypes. In this case, we don't need
            # to compile it again, just build the (g)ufunc
            # NOTE(review): `not x != y` is just `x == y` — consider
            # simplifying; behavior is unchanged either way.
            if not self.find_ewise_function(ewise) != (None, None):
                sig = self._get_function_type(*args)
                self.add(sig)
            self.build_ufunc()
        return self.ufunc(*args, **kwargs)
def _is_array_wrapper(obj):
"""Return True if obj wraps around numpy or another numpy-like library
and is likely going to apply the ufunc to the wrapped array; False
otherwise.
At the moment, this returns True for
- dask.array.Array
- dask.dataframe.DataFrame
- dask.dataframe.Series
- xarray.DataArray
- xarray.Dataset
- xarray.Variable
- pint.Quantity
- other potential wrappers around dask array or dask dataframe
We may need to add other libraries that pickle ufuncs from their
__array_ufunc__ method in the future.
Note that the below test is a lot more naive than
`dask.base.is_dask_collection`
(https://github.com/dask/dask/blob/5949e54bc04158d215814586a44d51e0eb4a964d/dask/base.py#L209-L249), # noqa: E501
because it doesn't need to find out if we're actually dealing with
a dask collection, only that we're dealing with a wrapper.
Namely, it will return True for a pint.Quantity wrapping around a plain float, a
numpy.ndarray, or a dask.array.Array, and it's OK because in all cases
Quantity.__array_ufunc__ is going to forward the ufunc call inwards.
"""
return (
not isinstance(obj, type)
and hasattr(obj, "__dask_graph__")
and hasattr(obj, "__array_ufunc__")
)

View File

@@ -0,0 +1,761 @@
"""
This file implements the code-generator for parallel-vectorize.
ParallelUFunc is the platform independent base class for generating
the thread dispatcher. This thread dispatcher launches threads
that execute the generated function of UFuncCore.
UFuncCore is subclassed to specialize for the input/output types.
The actual workload is invoked inside the function generated by UFuncCore.
UFuncCore also defines a work-stealing mechanism that allows idle threads
to steal work from other threads.
"""
import os
import sys
import warnings
from threading import RLock as threadRLock
from ctypes import CFUNCTYPE, c_int, CDLL, POINTER, c_uint
import numpy as np
import llvmlite.binding as ll
from llvmlite import ir
from numba.np.numpy_support import as_dtype
from numba.core import types, cgutils, config, errors
from numba.core.typing import signature
from numba.np.ufunc.wrappers import _wrapper_info
from numba.np.ufunc import ufuncbuilder
from numba.extending import overload, intrinsic
# Host-platform flags used to pick the threading backend implementation.
_IS_OSX = sys.platform.startswith('darwin')
_IS_LINUX = sys.platform.startswith('linux')
_IS_WINDOWS = sys.platform.startswith('win32')
def get_thread_count():
    """Return the configured thread count (``NUMBA_NUM_THREADS``).

    Raises ValueError when the configured value is not positive.
    """
    count = config.NUMBA_NUM_THREADS
    if count < 1:
        raise ValueError("Number of threads specified must be > 0.")
    return count
# Validated thread count, snapshotted once at import time.
NUM_THREADS = get_thread_count()
def build_gufunc_kernel(library, ctx, info, sig, inner_ndim):
    """Wrap the original CPU ufunc/gufunc with a parallel dispatcher.

    This function will wrap gufuncs and ufuncs something like.

    Args
    ----
    ctx
        numba's codegen context
    info: (library, env, name)
        inner function info
    sig
        type signature of the gufunc
    inner_ndim
        inner dimension of the gufunc (this is len(sig.args) in the case of a
        ufunc)

    Returns
    -------
    wrapper_info : (library, env, name)
        The info for the gufunc wrapper.

    Details
    -------
    The kernel signature looks like this:

    void kernel(char **args, npy_intp *dimensions, npy_intp* steps, void* data)

    args - the input arrays + output arrays
    dimensions - the dimensions of the arrays
    steps - the step size for the array (this is like sizeof(type))
    data - any additional data

    The parallel backend then stages multiple calls to this kernel concurrently
    across a number of threads. Practically, for each item of work, the backend
    duplicates `dimensions` and adjusts the first entry to reflect the size of
    the item of work, it also forms up an array of pointers into the args for
    offsets to read/write from/to with respect to its position in the items of
    work. This allows the same kernel to be used for each item of work, with
    simply adjusted reads/writes/domain sizes and is safe by virtue of the
    domain partitioning.

    NOTE: The execution backend is passed the requested thread count, but it can
    choose to ignore it (TBB)!
    """
    assert isinstance(info, tuple)  # guard against old usage
    # Declare types and function
    byte_t = ir.IntType(8)
    byte_ptr_t = ir.PointerType(byte_t)
    byte_ptr_ptr_t = ir.PointerType(byte_ptr_t)
    intp_t = ctx.get_value_type(types.intp)
    intp_ptr_t = ir.PointerType(intp_t)
    # The wrapper kernel has the standard NumPy gufunc loop signature.
    fnty = ir.FunctionType(ir.VoidType(), [ir.PointerType(byte_ptr_t),
                                           ir.PointerType(intp_t),
                                           ir.PointerType(intp_t),
                                           byte_ptr_t])
    wrapperlib = ctx.codegen().create_library('parallelgufuncwrapper')
    mod = wrapperlib.create_ir_module('parallel.gufunc.wrapper')
    # Kernel name includes id(info.env) to keep it unique per environment.
    kernel_name = ".kernel.{}_{}".format(id(info.env), info.name)
    lfunc = ir.Function(mod, fnty, name=kernel_name)
    bb_entry = lfunc.append_basic_block('')
    # Function body starts
    builder = ir.IRBuilder(bb_entry)
    args, dimensions, steps, data = lfunc.args
    # Release the GIL (and ensure we have the GIL)
    # Note: numpy ufunc may not always release the GIL; thus,
    # we need to ensure we have the GIL.
    pyapi = ctx.get_python_api(builder)
    gil_state = pyapi.gil_ensure()
    thread_state = pyapi.save_thread()
    def as_void_ptr(arg):
        # Helper: cast any pointer argument to i8*.
        return builder.bitcast(arg, byte_ptr_t)
    # Array count depends on whether an "output" array is needed. In the case
    # of a void return type cf. gufunc it is the number of args, in the case of
    # a non-void return type cf. ufunc it is the number of args + 1 so as to
    # account for the output array.
    array_count = len(sig.args)
    if not isinstance(sig.return_type, types.NoneType):
        array_count += 1
    parallel_for_ty = ir.FunctionType(ir.VoidType(),
                                      [byte_ptr_t] * 5 + [intp_t, ] * 3)
    parallel_for = cgutils.get_or_insert_function(mod, parallel_for_ty,
                                                  'numba_parallel_for')
    # Reference inner-function and link
    innerfunc_fnty = ir.FunctionType(
        ir.VoidType(),
        [byte_ptr_ptr_t, intp_ptr_t, intp_ptr_t, byte_ptr_t],
    )
    tmp_voidptr = cgutils.get_or_insert_function(mod, innerfunc_fnty,
                                                 info.name,)
    wrapperlib.add_linking_library(info.library)
    # Ask the runtime for the currently masked thread count at kernel
    # execution time (not at codegen time).
    get_num_threads = cgutils.get_or_insert_function(
        builder.module,
        ir.FunctionType(ir.IntType(types.intp.bitwidth), []),
        "get_num_threads")
    num_threads = builder.call(get_num_threads, [])
    # Prepare call
    fnptr = builder.bitcast(tmp_voidptr, byte_ptr_t)
    innerargs = [as_void_ptr(x) for x
                 in [args, dimensions, steps, data]]
    builder.call(parallel_for, [fnptr] + innerargs +
                 [intp_t(x) for x in (inner_ndim, array_count)] + [num_threads])
    # Release the GIL
    pyapi.restore_thread(thread_state)
    pyapi.gil_release(gil_state)
    builder.ret_void()
    wrapperlib.add_ir_module(mod)
    wrapperlib.add_linking_library(library)
    return _wrapper_info(library=wrapperlib, name=lfunc.name, env=info.env)
# ------------------------------------------------------------------------------
class ParallelUFuncBuilder(ufuncbuilder.UFuncBuilder):
    """UFuncBuilder variant whose loops are wrapped with the parallel
    thread dispatcher."""

    def build(self, cres, sig):
        """Return (dtype numbers, wrapper pointer, keepalive) for one
        compiled signature *sig* described by CompileResult *cres*."""
        _launch_threads()
        # Builder wrapper for ufunc entry point
        ctx = cres.target_context
        signature = cres.signature
        library = cres.library
        fname = cres.fndesc.llvm_func_name
        info = build_ufunc_wrapper(library, ctx, fname, signature, cres)
        ptr = info.library.get_pointer_to_function(info.name)
        # Get dtypes
        dtypenums = [np.dtype(a.name).num for a in signature.args]
        dtypenums.append(np.dtype(signature.return_type.name).num)
        keepalive = ()
        return dtypenums, ptr, keepalive
def build_ufunc_wrapper(library, ctx, fname, signature, cres):
    """Build the parallel wrapper around a compiled element-wise ufunc core.

    First builds the regular (serial) ufunc wrapper, then wraps it in the
    parallel gufunc kernel dispatcher.
    """
    inner_info = ufuncbuilder.build_ufunc_wrapper(
        library, ctx, fname, signature, objmode=False, cres=cres,
    )
    # For a plain ufunc the inner dimension count is just the argument count.
    return build_gufunc_kernel(
        library, ctx, inner_info, signature, len(signature.args),
    )
# ---------------------------------------------------------------------------
class ParallelGUFuncBuilder(ufuncbuilder.GUFuncBuilder):
    """GUFuncBuilder variant targeting the parallel backend; forces
    nopython mode for the kernel."""

    def __init__(self, py_func, signature, identity=None, cache=False,
                 targetoptions=None, writable_args=()):
        # Force nopython mode
        if targetoptions is None:
            targetoptions = {}
        targetoptions.update(dict(nopython=True))
        super(
            ParallelGUFuncBuilder,
            self).__init__(
            py_func=py_func,
            signature=signature,
            identity=identity,
            cache=cache,
            targetoptions=targetoptions,
            writable_args=writable_args)

    def build(self, cres):
        """
        Returns (dtype numbers, function ptr, EnvironmentObject)
        """
        _launch_threads()
        # Build wrapper for ufunc entry point
        info = build_gufunc_wrapper(
            self.py_func, cres, self.sin, self.sout, cache=self.cache,
            is_parfors=False,
        )
        ptr = info.library.get_pointer_to_function(info.name)
        env = info.env
        # Get dtypes
        dtypenums = []
        for a in cres.signature.args:
            # Arrays contribute their element dtype; scalars contribute
            # their own type.
            if isinstance(a, types.Array):
                ty = a.dtype
            else:
                ty = a
            dtypenums.append(as_dtype(ty).num)
        return dtypenums, ptr, env
# This is not a member of the ParallelGUFuncBuilder function because it is
# called without an enclosing instance from parfors
def build_gufunc_wrapper(py_func, cres, sin, sout, cache, is_parfors):
    """Build gufunc wrapper for the given arguments.

    The *is_parfors* is a boolean indicating whether the gufunc is being
    built for use as a ParFors kernel. This changes codegen and caching
    behavior.
    """
    inner_info = ufuncbuilder.build_gufunc_wrapper(
        py_func, cres, sin, sout, cache=cache, is_parfors=is_parfors,
    )
    # The inner dimension count is the number of distinct symbols appearing
    # anywhere in the input and output parts of the layout signature.
    symbols = {sym for term in sin for sym in term}
    symbols |= {sym for term in sout for sym in term}
    return build_gufunc_kernel(
        cres.library, cres.target_context, inner_info, cres.signature,
        len(symbols),
    )
# ---------------------------------------------------------------------------
_backend_init_thread_lock = threadRLock()
_windows = sys.platform.startswith('win32')
class _nop(object):
"""A no-op contextmanager
"""
def __enter__(self):
pass
def __exit__(self, *args):
pass
_backend_init_process_lock = None
def _set_init_process_lock():
    """Create the process-level initialization lock.

    Uses a real multiprocessing RLock when the start method involves fork
    (or on Windows); otherwise a no-op lock suffices. Falls back to a
    no-op lock with a warning if the OS cannot provide the semaphore.
    """
    global _backend_init_process_lock
    try:
        # Force the use of an RLock in the case a fork was used to start the
        # process and thereby the init sequence, some of the threading backend
        # init sequences are not fork safe. Also, windows global mp locks seem
        # to be fine.
        with _backend_init_thread_lock: # protect part-initialized module access
            import multiprocessing
            if "fork" in multiprocessing.get_start_method() or _windows:
                ctx = multiprocessing.get_context()
                _backend_init_process_lock = ctx.RLock()
            else:
                _backend_init_process_lock = _nop()
    except OSError as e:
        # probably lack of /dev/shm for semaphore writes, warn the user
        msg = (
            "Could not obtain multiprocessing lock due to OS level error: %s\n"
            "A likely cause of this problem is '/dev/shm' is missing or "
            "read-only such that necessary semaphores cannot be written.\n"
            "*** The responsibility of ensuring multiprocessing safe access to "
            "this initialization sequence/module import is deferred to the "
            "user! ***\n"
        )
        warnings.warn(msg % str(e), errors.NumbaSystemWarning)
        _backend_init_process_lock = _nop()
_is_initialized = False
# this is set by _launch_threads
_threading_layer = None
def threading_layer():
"""
Get the name of the threading layer in use for parallel CPU targets
"""
if _threading_layer is None:
raise ValueError("Threading layer is not initialized.")
else:
return _threading_layer
def _check_tbb_version_compatible():
    """
    Checks that if TBB is present it is of a compatible version.

    Raises ImportError when the TBB shared library is missing, cannot be
    loaded, or reports an interface version older than 12060.
    """
    try:
        # first check that the TBB version is new enough
        if _IS_WINDOWS:
            libtbb_name = 'tbb12.dll'
        elif _IS_OSX:
            libtbb_name = 'libtbb.12.dylib'
        elif _IS_LINUX:
            libtbb_name = 'libtbb.so.12'
        else:
            raise ValueError("Unknown operating system")
        libtbb = CDLL(libtbb_name)
        version_func = libtbb.TBB_runtime_interface_version
        version_func.argtypes = []
        version_func.restype = c_int
        tbb_iface_ver = version_func()
        if tbb_iface_ver < 12060: # magic number from TBB
            msg = ("The TBB threading layer requires TBB "
                   "version 2021 update 6 or later i.e., "
                   "TBB_INTERFACE_VERSION >= 12060. Found "
                   "TBB_INTERFACE_VERSION = %s. The TBB "
                   "threading layer is disabled.") % tbb_iface_ver
            problem = errors.NumbaWarning(msg)
            warnings.warn(problem)
            raise ImportError("Problem with TBB. Reason: %s" % msg)
    except (ValueError, OSError) as e:
        # Translate as an ImportError for consistent error class use, this error
        # will never materialise
        raise ImportError("Problem with TBB. Reason: %s" % e)
def _launch_threads():
    """Select and initialize the threading layer backend, exactly once per
    process.

    On success this registers the backend's native symbols with llvmlite,
    launches the worker threads, and sets the module globals
    ``_threading_layer`` and ``_is_initialized``.
    """
    if not _backend_init_process_lock:
        _set_init_process_lock()
    with _backend_init_process_lock:
        with _backend_init_thread_lock:
            global _is_initialized
            if _is_initialized:
                return

            def select_known_backend(backend):
                """
                Loads a specific threading layer backend based on string
                """
                lib = None
                if backend.startswith("tbb"):
                    try:
                        # check if TBB is present and compatible
                        _check_tbb_version_compatible()
                        # now try and load the backend
                        from numba.np.ufunc import tbbpool as lib
                    except ImportError:
                        pass
                elif backend.startswith("omp"):
                    # TODO: Check that if MKL is present that it is a version
                    # that understands GNU OMP might be present
                    try:
                        from numba.np.ufunc import omppool as lib
                    except ImportError:
                        pass
                elif backend.startswith("workqueue"):
                    from numba.np.ufunc import workqueue as lib
                else:
                    msg = "Unknown value specified for threading layer: %s"
                    raise ValueError(msg % backend)
                return lib

            def select_from_backends(backends):
                """
                Selects from presented backends and returns the first working
                """
                lib = None
                for backend in backends:
                    lib = select_known_backend(backend)
                    if lib is not None:
                        break
                else:
                    backend = ''
                return lib, backend

            t = str(config.THREADING_LAYER).lower()
            namedbackends = config.THREADING_LAYER_PRIORITY
            # THREADING_LAYER_PRIORITY must name all three backends once.
            if not (len(namedbackends) == 3 and
                    set(namedbackends) == {'tbb', 'omp', 'workqueue'}):
                raise ValueError(
                    "THREADING_LAYER_PRIORITY invalid: %s. "
                    "It must be a permutation of "
                    "{'tbb', 'omp', 'workqueue'}"
                    % namedbackends
                )
            lib = None
            err_helpers = dict()
            err_helpers['TBB'] = ("Intel TBB is required, try:\n"
                                  "$ conda/pip install tbb")
            err_helpers['OSX_OMP'] = ("Intel OpenMP is required, try:\n"
                                      "$ conda/pip install intel-openmp")
            requirements = []

            def raise_with_hint(required):
                # Build an actionable error message from the accumulated
                # requirement hints.
                errmsg = "No threading layer could be loaded.\n%s"
                hintmsg = "HINT:\n%s"
                if len(required) == 0:
                    hint = ''
                if len(required) == 1:
                    hint = hintmsg % err_helpers[required[0]]
                if len(required) > 1:
                    options = '\nOR\n'.join([err_helpers[x] for x in required])
                    hint = hintmsg % ("One of:\n%s" % options)
                raise ValueError(errmsg % hint)

            if t in namedbackends:
                # Try and load the specific named backend
                lib = select_known_backend(t)
                if not lib:
                    # something is missing preventing a valid backend from
                    # loading, set requirements for hinting
                    if t == 'tbb':
                        requirements.append('TBB')
                    elif t == 'omp' and _IS_OSX:
                        requirements.append('OSX_OMP')
                libname = t
            elif t in ['threadsafe', 'forksafe', 'safe']:
                # User wants a specific behaviour...
                available = ['tbb']
                requirements.append('TBB')
                if t == "safe":
                    # "safe" is TBB, which is fork and threadsafe everywhere
                    pass
                elif t == "threadsafe":
                    if _IS_OSX:
                        requirements.append('OSX_OMP')
                    # omp is threadsafe everywhere
                    available.append('omp')
                elif t == "forksafe":
                    # everywhere apart from linux (GNU OpenMP) has a guaranteed
                    # forksafe OpenMP, as OpenMP has better performance, prefer
                    # this to workqueue
                    if not _IS_LINUX:
                        available.append('omp')
                        if _IS_OSX:
                            requirements.append('OSX_OMP')
                    # workqueue is forksafe everywhere
                    available.append('workqueue')
                else:  # unreachable
                    msg = "No threading layer available for purpose %s"
                    raise ValueError(msg % t)
                # select amongst available
                lib, libname = select_from_backends(available)
            elif t == 'default':
                # If default is supplied, try them in order, tbb, omp,
                # workqueue
                lib, libname = select_from_backends(namedbackends)
                if not lib:
                    # set requirements for hinting
                    requirements.append('TBB')
                    if _IS_OSX:
                        requirements.append('OSX_OMP')
            else:
                msg = "The threading layer requested '%s' is unknown to Numba."
                raise ValueError(msg % t)
            # No lib found, raise and hint
            if not lib:
                raise_with_hint(requirements)
            # Register the backend's native entry points with llvmlite so
            # generated code can call them by name.
            ll.add_symbol('numba_parallel_for', lib.parallel_for)
            ll.add_symbol('do_scheduling_signed', lib.do_scheduling_signed)
            ll.add_symbol('do_scheduling_unsigned', lib.do_scheduling_unsigned)
            ll.add_symbol('allocate_sched', lib.allocate_sched)
            ll.add_symbol('deallocate_sched', lib.deallocate_sched)
            launch_threads = CFUNCTYPE(None, c_int)(lib.launch_threads)
            launch_threads(NUM_THREADS)
            _load_threading_functions(lib)  # load late
            # set library name so it can be queried
            global _threading_layer
            _threading_layer = libname
            _is_initialized = True
def _load_threading_functions(lib):
    """Register the backend's thread-control symbols and build the
    module-level ctypes wrappers for them.

    Populates ``_set_num_threads``, ``_get_num_threads``, ``_get_thread_id``,
    ``_set_parallel_chunksize``, ``_get_parallel_chunksize`` and
    ``_get_sched_size`` from the loaded backend module *lib*.
    """
    ll.add_symbol('get_num_threads', lib.get_num_threads)
    ll.add_symbol('set_num_threads', lib.set_num_threads)
    ll.add_symbol('get_thread_id', lib.get_thread_id)
    global _set_num_threads
    _set_num_threads = CFUNCTYPE(None, c_int)(lib.set_num_threads)
    # Start with the full thread count unmasked.
    _set_num_threads(NUM_THREADS)
    global _get_num_threads
    _get_num_threads = CFUNCTYPE(c_int)(lib.get_num_threads)
    global _get_thread_id
    _get_thread_id = CFUNCTYPE(c_int)(lib.get_thread_id)
    ll.add_symbol('set_parallel_chunksize', lib.set_parallel_chunksize)
    ll.add_symbol('get_parallel_chunksize', lib.get_parallel_chunksize)
    ll.add_symbol('get_sched_size', lib.get_sched_size)
    global _set_parallel_chunksize
    _set_parallel_chunksize = CFUNCTYPE(c_uint,
                                        c_uint)(lib.set_parallel_chunksize)
    global _get_parallel_chunksize
    _get_parallel_chunksize = CFUNCTYPE(c_uint)(lib.get_parallel_chunksize)
    global _get_sched_size
    _get_sched_size = CFUNCTYPE(c_uint,
                                c_uint,
                                c_uint,
                                POINTER(c_int),
                                POINTER(c_int))(lib.get_sched_size)
# Some helpers to make set_num_threads jittable
def gen_snt_check():
    """Build the bounds-check helper used by set_num_threads.

    The thread-count limit is captured once at build time from
    ``numba.core.config.NUMBA_NUM_THREADS``.
    """
    from numba.core.config import NUMBA_NUM_THREADS
    msg = "The number of threads must be between 1 and %s" % NUMBA_NUM_THREADS
    def snt_check(n):
        # Reject thread counts outside [1, NUMBA_NUM_THREADS].
        if n > NUMBA_NUM_THREADS or n < 1:
            raise ValueError(msg)
    return snt_check
snt_check = gen_snt_check()
@overload(snt_check)
def ol_snt_check(n):
    # The pure-Python check is directly usable as the jitted implementation.
    return snt_check
def set_num_threads(n):
    """
    Set the number of threads to use for parallel execution.

    By default, all :obj:`numba.config.NUMBA_NUM_THREADS` threads are used.

    This functionality works by masking out threads that are not used.
    Therefore, the number of threads *n* must be less than or equal to
    :obj:`~.NUMBA_NUM_THREADS`, the total number of threads that are launched.
    See its documentation for more details.

    This function can be used inside of a jitted function.

    Parameters
    ----------
    n: The number of threads. Must be between 1 and NUMBA_NUM_THREADS.

    See Also
    --------
    get_num_threads, numba.config.NUMBA_NUM_THREADS,
    numba.config.NUMBA_DEFAULT_NUM_THREADS, :envvar:`NUMBA_NUM_THREADS`

    Raises
    ------
    TypeError
        If *n* is not an integer.
    ValueError
        If *n* is outside [1, NUMBA_NUM_THREADS].
    """
    _launch_threads()
    if not isinstance(n, (int, np.integer)):
        raise TypeError("The number of threads specified must be an integer")
    snt_check(n)
    _set_num_threads(n)
@overload(set_num_threads)
def ol_set_num_threads(n):
    """Jitted implementation of :func:`set_num_threads`."""
    _launch_threads()
    if not isinstance(n, types.Integer):
        msg = "The number of threads specified must be an integer"
        raise errors.TypingError(msg)
    def impl(n):
        snt_check(n)
        _set_num_threads(n)
    return impl
def get_num_threads():
    """
    Get the number of threads used for parallel execution.

    By default (if :func:`~.set_num_threads` is never called), all
    :obj:`numba.config.NUMBA_NUM_THREADS` threads are used.

    This number is less than or equal to the total number of threads that are
    launched, :obj:`numba.config.NUMBA_NUM_THREADS`.

    This function can be used inside of a jitted function.

    Returns
    -------
    The number of threads.

    See Also
    --------
    set_num_threads, numba.config.NUMBA_NUM_THREADS,
    numba.config.NUMBA_DEFAULT_NUM_THREADS, :envvar:`NUMBA_NUM_THREADS`
    """
    _launch_threads()
    num_threads = _get_num_threads()
    # A non-positive count indicates broken backend state.
    if num_threads <= 0:
        raise RuntimeError("Invalid number of threads. "
                           "This likely indicates a bug in Numba. "
                           "(thread_id=%s, num_threads=%s)" %
                           (get_thread_id(), num_threads))
    return num_threads
@overload(get_num_threads)
def ol_get_num_threads():
    """Jitted implementation of :func:`get_num_threads`."""
    _launch_threads()
    def impl():
        num_threads = _get_num_threads()
        if num_threads <= 0:
            # NOTE(review): debug info is printed here rather than embedded
            # in the exception — presumably because the message cannot be
            # formatted with runtime values in this context; confirm before
            # changing.
            print("Broken thread_id: ", get_thread_id())
            print("num_threads: ", num_threads)
            raise RuntimeError("Invalid number of threads. "
                               "This likely indicates a bug in Numba.")
        return num_threads
    return impl
@intrinsic
def _iget_num_threads(typingctx):
    # Intrinsic exposing the native ``get_num_threads`` symbol to jitted code.
    _launch_threads()
    def codegen(context, builder, signature, args):
        mod = builder.module
        fnty = ir.FunctionType(cgutils.intp_t, [])
        fn = cgutils.get_or_insert_function(mod, fnty, "get_num_threads")
        return builder.call(fn, [])
    return signature(types.intp), codegen
def get_thread_id():
    """
    Returns a unique ID for each thread in the range 0 (inclusive)
    to :func:`~.get_num_threads` (exclusive).
    """
    # Called from the interpreter directly, this should return 0
    # Called from a sequential JIT region, this should return 0
    # Called from a parallel JIT region, this should return 0..N
    # Called from objmode in a parallel JIT region, this should return 0..N
    _launch_threads()
    return _get_thread_id()
@overload(get_thread_id)
def ol_get_thread_id():
    """Jitted implementation of :func:`get_thread_id` via the intrinsic."""
    _launch_threads()
    def impl():
        return _iget_thread_id()
    return impl
@intrinsic
def _iget_thread_id(typingctx):
    # Intrinsic exposing the native ``get_thread_id`` symbol to jitted code.
    def codegen(context, builder, signature, args):
        mod = builder.module
        fnty = ir.FunctionType(cgutils.intp_t, [])
        fn = cgutils.get_or_insert_function(mod, fnty, "get_thread_id")
        return builder.call(fn, [])
    return signature(types.intp), codegen
# Optionally force backend initialization at import time when the
# NUMBA_DYLD_WORKAROUND env var is set to a truthy integer (presumably to
# work around a dynamic-loader issue — confirm before relying on this).
_DYLD_WORKAROUND_SET = 'NUMBA_DYLD_WORKAROUND' in os.environ
_DYLD_WORKAROUND_VAL = int(os.environ.get('NUMBA_DYLD_WORKAROUND', 0))
if _DYLD_WORKAROUND_SET and _DYLD_WORKAROUND_VAL:
    _launch_threads()
def set_parallel_chunksize(n):
    """Set the parallel chunksize in the native layer.

    Raises TypeError for non-integer *n* and ValueError for negative *n*.
    Returns the value reported by the native setter (NOTE(review): likely
    the previous chunksize given the c_uint restype — confirm).
    """
    _launch_threads()
    if not isinstance(n, (int, np.integer)):
        raise TypeError("The parallel chunksize must be an integer")
    if n < 0:
        raise ValueError("chunksize must be greater than or equal to zero")
    return _set_parallel_chunksize(n)
def get_parallel_chunksize():
    """Return the current parallel chunksize from the native layer."""
    _launch_threads()
    return _get_parallel_chunksize()
@overload(set_parallel_chunksize)
def ol_set_parallel_chunksize(n):
    """Jitted implementation of :func:`set_parallel_chunksize`."""
    _launch_threads()
    if not isinstance(n, types.Integer):
        msg = "The parallel chunksize must be an integer"
        raise errors.TypingError(msg)
    def impl(n):
        if n < 0:
            raise ValueError("chunksize must be greater than or equal to zero")
        return _set_parallel_chunksize(n)
    return impl
@overload(get_parallel_chunksize)
def ol_get_parallel_chunksize():
    """Jitted implementation of :func:`get_parallel_chunksize`."""
    _launch_threads()
    def impl():
        return _get_parallel_chunksize()
    return impl

View File

@@ -0,0 +1,63 @@
import tokenize
import string
def parse_signature(sig):
    '''Parse generalized ufunc signature.

    NOTE: ',' (COMMA) is a delimiter; not separator.
    This means trailing comma is legal.
    '''
    def _remove_whitespace(text):
        # Strip ALL whitespace, not just leading/trailing.
        return ''.join(ch for ch in text if ch not in string.whitespace)

    def _token_stream(text):
        # Feed the single-line source to the tokenizer; the generator's
        # exhaustion (StopIteration) signals EOF to tokenize.
        def single_line():
            yield text
        lines = single_line()
        return tokenize.generate_tokens(lambda: next(lines))

    def _parse_groups(text):
        # Yield one tuple of symbol names per parenthesised group.
        tokens = _token_stream(text)
        while True:
            tok = next(tokens)
            if tok[1] == '(':
                names = []
                while True:
                    tok = next(tokens)
                    if tok[1] == ')':
                        break
                    elif tok[0] == tokenize.NAME:
                        names.append(tok[1])
                    elif tok[1] == ',':
                        continue
                    else:
                        raise ValueError('bad token in signature "%s"' % tok[1])
                yield tuple(names)
                tok = next(tokens)
                if tok[1] == ',':
                    continue
                elif tokenize.ISEOF(tok[0]):
                    break
            elif tokenize.ISEOF(tok[0]):
                break
            else:
                raise ValueError('bad token in signature "%s"' % tok[1])

    ins, _, outs = _remove_whitespace(sig).partition('->')
    inputs = list(_parse_groups(ins))
    outputs = list(_parse_groups(outs))
    # Every symbol used by an output must be bound by some input group.
    defined = set()
    for grp in inputs:
        defined.update(grp)
    used = set()
    for grp in outputs:
        used.update(grp)
    undefined = used.difference(defined)
    if undefined:
        raise NameError('undefined output symbols: %s'
                        % ','.join(sorted(undefined)))
    return inputs, outputs

View File

@@ -0,0 +1,113 @@
from numba.np import numpy_support
from numba.core import types
class UfuncLowererBase:
    '''Callable object responsible for lowering calls to a specific gufunc.

    The kernel for the ufunc is built eagerly at construction time; calling
    the instance delegates to the supplied kernel-lowering callback.
    '''

    def __init__(self, ufunc, make_kernel_fn, make_ufunc_kernel_fn):
        self.ufunc = ufunc
        self.make_ufunc_kernel_fn = make_ufunc_kernel_fn
        # Build the kernel for this ufunc up front.
        self.kernel = make_kernel_fn(ufunc)
        # Libraries kept alive on behalf of lowered calls.
        self.libs = []

    def __call__(self, context, builder, sig, args):
        lower_fn = self.make_ufunc_kernel_fn
        return lower_fn(context, builder, sig, args, self.ufunc, self.kernel)
class UfuncBase:
    """Mixin providing the NumPy-ufunc-like surface for Numba ufunc objects.

    Introspection attributes are forwarded to the wrapped ``self.ufunc``;
    the remaining methods manage compiled signatures via
    ``self._dispatcher``. ``match_signature`` is provided by the concrete
    subclass — it is not defined here.
    """
    # --- Introspection properties forwarded to the underlying ufunc ---
    @property
    def nin(self):
        return self.ufunc.nin
    @property
    def nout(self):
        return self.ufunc.nout
    @property
    def nargs(self):
        return self.ufunc.nargs
    @property
    def ntypes(self):
        return self.ufunc.ntypes
    @property
    def types(self):
        return self.ufunc.types
    @property
    def identity(self):
        return self.ufunc.identity
    @property
    def signature(self):
        return self.ufunc.signature
    # --- Reduction/iteration methods forwarded to the underlying ufunc ---
    @property
    def accumulate(self):
        return self.ufunc.accumulate
    @property
    def at(self):
        return self.ufunc.at
    @property
    def outer(self):
        return self.ufunc.outer
    @property
    def reduce(self):
        return self.ufunc.reduce
    @property
    def reduceat(self):
        return self.ufunc.reduceat
    def disable_compile(self):
        """
        Disable the compilation of new signatures at call time.
        """
        # If disabling compilation then there must be at least one signature
        assert len(self._dispatcher.overloads) > 0
        self._frozen = True
    def _install_cg(self, targetctx=None):
        """
        Install an implementation function for a GUFunc/DUFunc object in the
        given target context.  If no target context is given, then
        _install_cg() installs into the target context of the
        dispatcher object (should be same default context used by
        jit() and njit()).
        """
        if targetctx is None:
            targetctx = self._dispatcher.targetdescr.target_context
        _any = types.Any
        _arr = types.Array
        # Either all outputs are explicit or none of them are
        sig0 = (_any,) * self.ufunc.nin + (_arr,) * self.ufunc.nout
        sig1 = (_any,) * self.ufunc.nin
        targetctx.insert_func_defn(
            [(self._lower_me, self, sig) for sig in (sig0, sig1)])
    def find_ewise_function(self, ewise_types):
        """
        Given a tuple of element-wise argument types, find a matching
        signature in the dispatcher.

        Return a 2-tuple containing the matching signature, and
        compilation result.  Will return two None's if no matching
        signature was found.
        """
        if self._frozen:
            # If we cannot compile, coerce to the best matching loop
            loop = numpy_support.ufunc_find_matching_loop(self, ewise_types)
            if loop is None:
                return None, None
            ewise_types = tuple(loop.inputs + loop.outputs)[:len(ewise_types)]
        for sig, cres in self._dispatcher.overloads.items():
            if self.match_signature(ewise_types, sig):
                return sig, cres
        return None, None

View File

@@ -0,0 +1,444 @@
# -*- coding: utf-8 -*-
import inspect
import warnings
from contextlib import contextmanager
from numba.core import config, targetconfig
from numba.core.decorators import jit
from numba.core.descriptors import TargetDescriptor
from numba.core.extending import is_jitted
from numba.core.errors import NumbaDeprecationWarning
from numba.core.options import TargetOptions, include_default_options
from numba.core.registry import cpu_target
from numba.core.target_extension import dispatcher_registry, target_registry
from numba.core import utils, types, serialize, compiler, sigutils
from numba.np.numpy_support import as_dtype
from numba.np.ufunc import _internal
from numba.np.ufunc.sigparse import parse_signature
from numba.np.ufunc.wrappers import build_ufunc_wrapper, build_gufunc_wrapper
from numba.core.caching import FunctionCache, NullCache
from numba.core.compiler_lock import global_compiler_lock
# Option set accepted by the ufunc/gufunc target, layered over the defaults.
_options_mixin = include_default_options(
    "nopython",
    "forceobj",
    "boundscheck",
    "fastmath",
    "writable_args"
)
class UFuncTargetOptions(_options_mixin, TargetOptions):
    """Target options for ufunc compilation; fills in permissive defaults
    for any flag the user did not set explicitly."""

    def finalize(self, flags, options):
        if not flags.is_set("enable_pyobject"):
            flags.enable_pyobject = True
        if not flags.is_set("enable_looplift"):
            flags.enable_looplift = True
        flags.inherit_if_not_set("nrt", default=True)
        if not flags.is_set("debuginfo"):
            flags.debuginfo = config.DEBUGINFO_DEFAULT
        if not flags.is_set("boundscheck"):
            # Bounds checking defaults to on whenever debug info is on.
            flags.boundscheck = flags.debuginfo
        flags.enable_pyobject_looplift = True
        flags.inherit_if_not_set("fastmath")
class UFuncTarget(TargetDescriptor):
    """Target descriptor for ufunc compilation.

    Typing and target contexts are shared with the CPU target.
    """
    options = UFuncTargetOptions

    def __init__(self):
        super().__init__('ufunc')

    @property
    def typing_context(self):
        return cpu_target.typing_context

    @property
    def target_context(self):
        return cpu_target.target_context


# Singleton used as UFuncDispatcher.targetdescr.
ufunc_target = UFuncTarget()
class UFuncDispatcher(serialize.ReduceMixin):
    """
    An object handling compilation of various signatures for a ufunc.
    """
    targetdescr = ufunc_target

    def __init__(self, py_func, locals=None, targetoptions=None):
        if locals is None:
            locals = {}
        if targetoptions is None:
            targetoptions = {}
        self.py_func = py_func
        # Maps signature -> CompileResult; UniqueDict rejects duplicates.
        self.overloads = utils.UniqueDict()
        self.targetoptions = targetoptions
        self.locals = locals
        self.cache = NullCache()

    def _reduce_states(self):
        """
        NOTE: part of ReduceMixin protocol
        """
        return dict(
            pyfunc=self.py_func,
            locals=self.locals,
            targetoptions=self.targetoptions,
        )

    @classmethod
    def _rebuild(cls, pyfunc, locals, targetoptions):
        """
        NOTE: part of ReduceMixin protocol
        """
        return cls(py_func=pyfunc, locals=locals, targetoptions=targetoptions)

    def enable_caching(self):
        # Replace the default NullCache with a real on-disk function cache.
        self.cache = FunctionCache(self.py_func)

    def compile(self, sig, locals=None, **targetoptions):
        """Compile *sig* with merged locals/target options and return the
        CompileResult."""
        if locals is None:
            locals = {}
        locs = self.locals.copy()
        locs.update(locals)
        # NOTE(review): 'locs' merges self.locals with the argument but is
        # never used; the raw 'locals' is passed to _compile_core below —
        # looks like a latent bug, confirm intent before changing.
        topt = self.targetoptions.copy()
        topt.update(targetoptions)
        flags = compiler.Flags()
        self.targetdescr.options.parse_as_flags(flags, topt)
        flags.no_cpython_wrapper = True
        flags.error_model = "numpy"
        # Disable loop lifting
        # The feature requires a real
        # python function
        flags.enable_looplift = False
        return self._compile_core(sig, flags, locals)

    def _compile_core(self, sig, flags, locals):
        """
        Trigger the compiler on the core function or load a previously
        compiled version from the cache. Returns the CompileResult.
        """
        typingctx = self.targetdescr.typing_context
        targetctx = self.targetdescr.target_context
        @contextmanager
        def store_overloads_on_success():
            # use to ensure overloads are stored on success
            try:
                yield
            except Exception:
                raise
            else:
                exists = self.overloads.get(cres.signature)
                if exists is None:
                    self.overloads[cres.signature] = cres
        # Use cache and compiler in a critical section
        with global_compiler_lock:
            with targetconfig.ConfigStack().enter(flags.copy()):
                with store_overloads_on_success():
                    # attempt look up of existing
                    cres = self.cache.load_overload(sig, targetctx)
                    if cres is not None:
                        return cres
                    # Compile
                    args, return_type = sigutils.normalize_signature(sig)
                    cres = compiler.compile_extra(typingctx, targetctx,
                                                  self.py_func, args=args,
                                                  return_type=return_type,
                                                  flags=flags, locals=locals)
                    # cache lookup failed before so safe to save
                    self.cache.save_overload(sig, cres)
                    return cres


# Register this dispatcher for the 'npyufunc' target.
dispatcher_registry[target_registry['npyufunc']] = UFuncDispatcher
# Utility functions
def _compile_element_wise_function(nb_func, targetoptions, sig):
    """Compile the element-wise kernel for *sig*.

    Returns the CompileResult together with the normalized argument types
    and return type so the caller can inspect/validate them.
    """
    # Trigger compilation of the kernel with the requested options.
    compile_result = nb_func.compile(sig, **targetoptions)
    arg_types, return_type = sigutils.normalize_signature(sig)
    return compile_result, arg_types, return_type
def _finalize_ufunc_signature(cres, args, return_type):
    '''Given a compilation result, argument types, and a return type,
    build a valid Numba signature after validating that it doesn't
    violate the constraints for the compilation mode.
    '''
    if return_type is None:
        if cres.objectmode:
            # Object mode is used and return type is not specified
            raise TypeError("return type must be specified for object mode")
        # Infer the return type from the compiled signature.
        return_type = cres.signature.return_type
    assert return_type != types.pyobject
    return return_type(*args)
def _build_element_wise_ufunc_wrapper(cres, signature):
    '''Build a wrapper for the ufunc loop entry point given by the
    compilation result object, using the element-wise signature.

    Returns (dtype numbers, wrapper pointer, environment).
    '''
    ctx = cres.target_context
    library = cres.library
    fname = cres.fndesc.llvm_func_name
    with global_compiler_lock:
        info = build_ufunc_wrapper(library, ctx, fname, signature,
                                   cres.objectmode, cres)
        ptr = info.library.get_pointer_to_function(info.name)
    # Get dtypes
    dtypenums = [as_dtype(a).num for a in signature.args]
    dtypenums.append(as_dtype(signature.return_type).num)
    return dtypenums, ptr, cres.environment
# Mapping from the user-facing identity spec to NumPy's low-level constants.
_identities = {
    0: _internal.PyUFunc_Zero,
    1: _internal.PyUFunc_One,
    None: _internal.PyUFunc_None,
    "reorderable": _internal.PyUFunc_ReorderableNone,
}


def parse_identity(identity):
    """
    Parse an identity value and return the corresponding low-level value
    for Numpy.
    """
    # try/except is deliberate: unhashable values still raise TypeError,
    # while unknown (hashable) values become ValueError.
    try:
        return _identities[identity]
    except KeyError:
        raise ValueError("Invalid identity value %r" % (identity,))
@contextmanager
def _suppress_deprecation_warning_nopython_not_supplied():
    """This suppresses the NumbaDeprecationWarning that occurs through the use
    of `jit` without the `nopython` kwarg. This use of `jit` occurs in a few
    places in the `{g,}ufunc` mechanism in Numba, predominantly to wrap the
    "kernel" function."""
    with warnings.catch_warnings():
        # Only the specific "nopython not supplied" message is filtered;
        # other deprecation warnings still propagate.
        warnings.filterwarnings('ignore',
                                category=NumbaDeprecationWarning,
                                message=(".*The 'nopython' keyword argument "
                                         "was not supplied*"),)
        yield
# Class definitions
class _BaseUFuncBuilder(object):
    """Shared signature-accumulation logic for the ufunc/gufunc builders."""

    def add(self, sig=None):
        """Compile *sig* and record the finalized signature and its
        CompileResult; returns the CompileResult."""
        if hasattr(self, 'targetoptions'):
            targetoptions = self.targetoptions
        else:
            targetoptions = self.nb_func.targetoptions
        cres, args, return_type = _compile_element_wise_function(
            self.nb_func, targetoptions, sig)
        sig = self._finalize_signature(cres, args, return_type)
        self._sigs.append(sig)
        self._cres[sig] = cres
        return cres

    def disable_compile(self):
        """
        Disable the compilation of new signatures at call time.
        """
        # Override this for implementations that support lazy compilation
class UFuncBuilder(_BaseUFuncBuilder):
    """Builds a NumPy ufunc object from an element-wise Python kernel."""

    def __init__(self, py_func, identity=None, cache=False, targetoptions=None):
        if targetoptions is None:
            targetoptions = {}
        if is_jitted(py_func):
            # Unwrap an already-jitted function to its underlying Python one.
            py_func = py_func.py_func
        self.py_func = py_func
        self.identity = parse_identity(identity)
        with _suppress_deprecation_warning_nopython_not_supplied():
            self.nb_func = jit(_target='npyufunc',
                               cache=cache,
                               **targetoptions)(py_func)
        self._sigs = []
        self._cres = {}

    def _finalize_signature(self, cres, args, return_type):
        '''Slated for deprecation, use ufuncbuilder._finalize_ufunc_signature()
        instead.
        '''
        return _finalize_ufunc_signature(cres, args, return_type)

    def build_ufunc(self):
        """Create the NumPy ufunc object from all signatures added so far."""
        with global_compiler_lock:
            dtypelist = []
            ptrlist = []
            if not self.nb_func:
                raise TypeError("No definition")
            # Get signature in the order they are added
            keepalive = []
            cres = None
            for sig in self._sigs:
                cres = self._cres[sig]
                dtypenums, ptr, env = self.build(cres, sig)
                dtypelist.append(dtypenums)
                ptrlist.append(int(ptr))
                keepalive.append((cres.library, env))
            datlist = [None] * len(ptrlist)
            if cres is None:
                # No compiled signature: derive arity from the Python function.
                argspec = inspect.getfullargspec(self.py_func)
                inct = len(argspec.args)
            else:
                inct = len(cres.signature.args)
            outct = 1
            # Becareful that fromfunc does not provide full error checking yet.
            # If typenum is out-of-bound, we have nasty memory corruptions.
            # For instance, -1 for typenum will cause segfault.
            # If elements of type-list (2nd arg) is tuple instead,
            # there will also memory corruption. (Seems like code rewrite.)
            ufunc = _internal.fromfunc(
                self.py_func.__name__, self.py_func.__doc__,
                ptrlist, dtypelist, inct, outct, datlist,
                keepalive, self.identity,
            )
            return ufunc

    def build(self, cres, signature):
        '''Slated for deprecation, use
        ufuncbuilder._build_element_wise_ufunc_wrapper().
        '''
        return _build_element_wise_ufunc_wrapper(cres, signature)
class GUFuncBuilder(_BaseUFuncBuilder):
    """Collects compiled signatures for a generalized-ufunc kernel and
    assembles them into a NumPy gufunc object via :meth:`build_ufunc`.

    *signature* is the NumPy gufunc shape signature (e.g. ``"(m,n),(n)->(m)"``);
    *writable_args* names (or indexes) input arguments the kernel may mutate.
    """
    # TODO handle scalar
    def __init__(self, py_func, signature, identity=None, cache=False,
                 targetoptions=None, writable_args=()):
        # Avoid a mutable default argument for targetoptions.
        if targetoptions is None:
            targetoptions = {}
        self.py_func = py_func
        self.identity = parse_identity(identity)
        # NOTE(review): targetoptions are stored below but not forwarded to
        # jit() here — presumably they are applied when individual signatures
        # are compiled; confirm against _BaseUFuncBuilder.add().
        with _suppress_deprecation_warning_nopython_not_supplied():
            self.nb_func = jit(_target='npyufunc', cache=cache)(py_func)
        self.signature = signature
        # Split the shape signature into input and output dimension symbols.
        self.sin, self.sout = parse_signature(signature)
        self.targetoptions = targetoptions
        self.cache = cache
        self._sigs = []   # signatures in the order they were added
        self._cres = {}   # signature -> compile result
        # Convert writable arg names to positional indexes.
        transform_arg = _get_transform_arg(py_func)
        self.writable_args = tuple([transform_arg(a) for a in writable_args])
    def _finalize_signature(self, cres, args, return_type):
        # Gufunc kernels write into output arrays; a non-void return is an
        # error (except in object mode).
        if not cres.objectmode and cres.signature.return_type != types.void:
            raise TypeError("gufunc kernel must have void return type")
        if return_type is None:
            return_type = types.void
        return return_type(*args)
    @global_compiler_lock
    def build_ufunc(self):
        """Build and return the NumPy gufunc from the registered signatures.

        Raises TypeError if there is no kernel definition to build from.
        """
        type_list = []
        func_list = []
        if not self.nb_func:
            raise TypeError("No definition")
        # Get signature in the order they are added
        keepalive = []
        for sig in self._sigs:
            cres = self._cres[sig]
            # Build the gufunc wrapper for this compiled signature.
            dtypenums, ptr, env = self.build(cres)
            type_list.append(dtypenums)
            func_list.append(int(ptr))
            keepalive.append((cres.library, env))
        datalist = [None] * len(func_list)
        nin = len(self.sin)
        nout = len(self.sout)
        # Pass envs to fromfuncsig to bind to the lifetime of the ufunc object
        ufunc = _internal.fromfunc(
            self.py_func.__name__, self.py_func.__doc__,
            func_list, type_list, nin, nout, datalist,
            keepalive, self.identity, self.signature, self.writable_args
        )
        return ufunc
    def build(self, cres):
        """
        Returns (dtype numbers, function ptr, EnvironmentObject)
        """
        # Builder wrapper for ufunc entry point
        signature = cres.signature
        info = build_gufunc_wrapper(
            self.py_func, cres, self.sin, self.sout,
            cache=self.cache, is_parfors=False,
        )
        env = info.env
        ptr = info.library.get_pointer_to_function(info.name)
        # Get dtypes
        dtypenums = []
        for a in signature.args:
            # Array arguments contribute their element dtype; scalars are
            # used as-is.
            if isinstance(a, types.Array):
                ty = a.dtype
            else:
                ty = a
            dtypenums.append(as_dtype(ty).num)
        return dtypenums, ptr, env
def _get_transform_arg(py_func):
"""Return function that transform arg into index"""
args = inspect.getfullargspec(py_func).args
pos_by_arg = {arg: i for i, arg in enumerate(args)}
def transform_arg(arg):
if isinstance(arg, int):
return arg
try:
return pos_by_arg[arg]
except KeyError:
msg = (f"Specified writable arg {arg} not found in arg list "
f"{args} for function {py_func.__qualname__}")
raise RuntimeError(msg)
return transform_arg

View File

@@ -0,0 +1,743 @@
from collections import namedtuple
import numpy as np
from llvmlite.ir import Constant, IRBuilder
from llvmlite import ir
from numba.core import types, cgutils
from numba.core.compiler_lock import global_compiler_lock
from numba.core.caching import make_library_cache, NullCache
# Result triple of the wrapper builders: the code library containing the
# compiled wrapper, its environment object, and the wrapper's symbol name.
_wrapper_info = namedtuple('_wrapper_info', ['library', 'env', 'name'])
def _build_ufunc_loop_body(load, store, context, func, builder, arrays, out,
                           offsets, store_offset, signature, pyapi, env):
    """Emit one iteration of the nopython-mode ufunc inner loop.

    *load* and *store* are caller-supplied callbacks that read the input
    elements and write the result (strided vs. aligned variants).  After the
    kernel call, each per-array byte offset is advanced by its stride.
    Returns the kernel's status code value.
    """
    elems = load()
    # Compute
    status, retval = context.call_conv.call_function(builder, func,
                                                     signature.return_type,
                                                     signature.args, elems)
    # Store
    with builder.if_else(status.is_ok, likely=True) as (if_ok, if_error):
        with if_ok:
            store(retval)
        with if_error:
            # Raising a Python exception requires the GIL, which the
            # nopython loop does not otherwise hold.
            gil = pyapi.gil_ensure()
            context.call_conv.raise_error(builder, pyapi, status)
            pyapi.gil_release(gil)
    # increment indices
    for off, ary in zip(offsets, arrays):
        builder.store(builder.add(builder.load(off), ary.step), off)
    builder.store(builder.add(builder.load(store_offset), out.step),
                  store_offset)
    return status.code
def _build_ufunc_loop_body_objmode(load, store, context, func, builder,
                                   arrays, out, offsets, store_offset,
                                   signature, env, pyapi):
    """Emit one iteration of the object-mode ufunc inner loop.

    Like :func:`_build_ufunc_loop_body` but all arguments and the return
    value are PyObjects; the GIL is assumed to be held by the caller.
    Returns the kernel's status code value.
    """
    elems = load()
    # Compute
    _objargs = [types.pyobject] * len(signature.args)
    # We need to push the error indicator to avoid it messing with
    # the ufunc's execution.  We restore it unless the ufunc raised
    # a new error.
    with pyapi.err_push(keep_new=True):
        status, retval = context.call_conv.call_function(builder, func,
                                                         types.pyobject,
                                                         _objargs, elems)
        # Release owned reference to arguments
        for elem in elems:
            pyapi.decref(elem)
    # NOTE: if an error occurred, it will be caught by the Numpy machinery
    # Store
    store(retval)
    # increment indices
    for off, ary in zip(offsets, arrays):
        builder.store(builder.add(builder.load(off), ary.step), off)
    builder.store(builder.add(builder.load(store_offset), out.step),
                  store_offset)
    return status.code
def build_slow_loop_body(context, func, builder, arrays, out, offsets,
                         store_offset, signature, pyapi, env):
    """Emit the generic (arbitrary-stride) nopython inner-loop body.

    Elements are read and written through explicit byte offsets so any
    stride is supported; delegates to :func:`_build_ufunc_loop_body`.
    """
    def load():
        # Read each input element via its current byte offset.
        return [ary.load_direct(builder.load(off))
                for off, ary in zip(offsets, arrays)]

    def store(retval):
        # Write the result at the current output byte offset.
        out.store_direct(retval, builder.load(store_offset))

    return _build_ufunc_loop_body(load, store, context, func, builder,
                                  arrays, out, offsets, store_offset,
                                  signature, pyapi, env=env)
def build_obj_loop_body(context, func, builder, arrays, out, offsets,
                        store_offset, signature, pyapi, envptr, env):
    """Emit the object-mode inner-loop body.

    Inputs are loaded through byte offsets, boxed into PyObjects, passed to
    the object-mode kernel, and the (unboxed) result is stored back.
    """
    env_body = context.get_env_body(builder, envptr)
    env_manager = pyapi.get_env_manager(env, env_body, envptr)
    def load():
        # Load
        elems = [ary.load_direct(builder.load(off))
                 for off, ary in zip(offsets, arrays)]
        # Box
        elems = [pyapi.from_native_value(t, v, env_manager)
                 for v, t in zip(elems, signature.args)]
        return elems
    def store(retval):
        is_ok = cgutils.is_not_null(builder, retval)
        # If an error is raised by the object mode ufunc, it will
        # simply get caught by the Numpy ufunc machinery.
        with builder.if_then(is_ok, likely=True):
            # Unbox
            native = pyapi.to_native_value(signature.return_type, retval)
            assert native.cleanup is None
            # Store
            out.store_direct(native.value, builder.load(store_offset))
            # Release owned reference
            pyapi.decref(retval)
    return _build_ufunc_loop_body_objmode(load, store, context, func, builder,
                                          arrays, out, offsets, store_offset,
                                          signature, envptr, pyapi)
def build_fast_loop_body(context, func, builder, arrays, out, offsets,
                         store_offset, signature, ind, pyapi, env):
    """Emit the unit-strided (contiguous) nopython inner-loop body.

    Elements are addressed directly by the loop index *ind*, which lets
    LLVM vectorize the loop; delegates to :func:`_build_ufunc_loop_body`.
    """
    def load():
        # Aligned, index-based loads for contiguous inputs.
        return [ary.load_aligned(ind) for ary in arrays]

    def store(retval):
        out.store_aligned(retval, ind)

    return _build_ufunc_loop_body(load, store, context, func, builder,
                                  arrays, out, offsets, store_offset,
                                  signature, pyapi, env=env)
def build_ufunc_wrapper(library, context, fname, signature, objmode, cres):
    """
    Wrap the scalar function *fname* with a loop that iterates over the
    NumPy ufunc inner-loop arguments ``(args, dims, steps, data)``.

    Emits a fast (vectorizable) path when every argument is unit-strided
    and a generic strided path otherwise; object-mode kernels use a single
    generic loop executed with the GIL held.

    Returns
    -------
    _wrapper_info
        (library, env, name) for the compiled wrapper.
    """
    assert isinstance(fname, str)
    byte_t = ir.IntType(8)
    byte_ptr_t = ir.PointerType(byte_t)
    byte_ptr_ptr_t = ir.PointerType(byte_ptr_t)
    intp_t = context.get_value_type(types.intp)
    intp_ptr_t = ir.PointerType(intp_t)
    # NumPy inner-loop prototype: void(char **args, npy_intp *dims,
    #                                  npy_intp *steps, void *data)
    fnty = ir.FunctionType(ir.VoidType(), [byte_ptr_ptr_t, intp_ptr_t,
                                           intp_ptr_t, byte_ptr_t])
    wrapperlib = context.codegen().create_library('ufunc_wrapper')
    wrapper_module = wrapperlib.create_ir_module('')
    if objmode:
        # Object-mode kernels take and return PyObjects.
        func_type = context.call_conv.get_function_type(
            types.pyobject, [types.pyobject] * len(signature.args))
    else:
        func_type = context.call_conv.get_function_type(
            signature.return_type, signature.args)
    func = ir.Function(wrapper_module, func_type, name=fname)
    func.attributes.add("alwaysinline")
    wrapper = ir.Function(wrapper_module, fnty, "__ufunc__." + func.name)
    arg_args, arg_dims, arg_steps, arg_data = wrapper.args
    arg_args.name = "args"
    arg_dims.name = "dims"
    arg_steps.name = "steps"
    arg_data.name = "data"
    builder = IRBuilder(wrapper.append_basic_block("entry"))
    # Prepare Environment
    envname = context.get_env_name(cres.fndesc)
    env = cres.environment
    envptr = builder.load(context.declare_env_global(builder.module, envname))
    # Emit loop
    loopcount = builder.load(arg_dims, name="loopcount")
    # Prepare inputs
    arrays = []
    for i, typ in enumerate(signature.args):
        arrays.append(UArrayArg(context, builder, arg_args, arg_steps, i, typ))
    # Prepare output (it follows the inputs in args/steps)
    out = UArrayArg(context, builder, arg_args, arg_steps, len(arrays),
                    signature.return_type)
    # Setup indices: one running byte offset per input, plus one for output
    offsets = []
    zero = context.get_constant(types.intp, 0)
    for _ in arrays:
        p = cgutils.alloca_once(builder, intp_t)
        offsets.append(p)
        builder.store(zero, p)
    store_offset = cgutils.alloca_once(builder, intp_t)
    builder.store(zero, store_offset)
    # Fast path is only valid when every input is unit-strided.
    unit_strided = cgutils.true_bit
    for ary in arrays:
        unit_strided = builder.and_(unit_strided, ary.is_unit_strided)
    pyapi = context.get_python_api(builder)
    if objmode:
        # General loop, run entirely with the GIL held
        gil = pyapi.gil_ensure()
        with cgutils.for_range(builder, loopcount, intp=intp_t):
            build_obj_loop_body(
                context, func, builder, arrays, out, offsets,
                store_offset, signature, pyapi, envptr, env,
            )
        pyapi.gil_release(gil)
        builder.ret_void()
    else:
        with builder.if_else(unit_strided) as (is_unit_strided, is_strided):
            with is_unit_strided:
                # Contiguous fast loop indexed by the loop counter
                with cgutils.for_range(builder, loopcount, intp=intp_t) as loop:
                    build_fast_loop_body(
                        context, func, builder, arrays, out, offsets,
                        store_offset, signature, loop.index, pyapi,
                        env=envptr,
                    )
            with is_strided:
                # General loop
                with cgutils.for_range(builder, loopcount, intp=intp_t):
                    build_slow_loop_body(
                        context, func, builder, arrays, out, offsets,
                        store_offset, signature, pyapi,
                        env=envptr,
                    )
        builder.ret_void()
    del builder
    # Link and finalize
    wrapperlib.add_ir_module(wrapper_module)
    wrapperlib.add_linking_library(library)
    return _wrapper_info(library=wrapperlib, env=env, name=wrapper.name)
class UArrayArg(object):
    """Wraps one ufunc argument (input or output) for the generated loop.

    Given the NumPy inner-loop ``args``/``steps`` parameters and an argument
    index *i*, extracts the typed data pointer and byte stride, and provides
    the load/store helpers used by the fast (aligned) and generic (strided)
    loop bodies.
    """
    def __init__(self, context, builder, args, steps, i, fe_type):
        self.context = context
        self.builder = builder
        self.fe_type = fe_type
        offset = self.context.get_constant(types.intp, i)
        # args[i] is the raw byte pointer to this argument's data.
        offseted_args = self.builder.load(builder.gep(args, [offset]))
        data_type = context.get_data_type(fe_type)
        self.dataptr = self.builder.bitcast(offseted_args,
                                            data_type.as_pointer())
        sizeof = self.context.get_abi_sizeof(data_type)
        self.abisize = self.context.get_constant(types.intp, sizeof)
        # steps[i] is the byte stride between consecutive elements.
        offseted_step = self.builder.gep(steps, [offset])
        self.step = self.builder.load(offseted_step)
        # Unit-strided (contiguous) iff the stride equals the item size;
        # this condition gates the vectorizable fast loop.
        # NOTE: the original code re-assigned ``self.builder = builder`` a
        # second time here; the redundant assignment has been removed.
        self.is_unit_strided = builder.icmp_unsigned('==',
                                                     self.abisize, self.step)
    def load_direct(self, byteoffset):
        """
        Generic load from the given *byteoffset*.  load_aligned() is
        preferred if possible.
        """
        ptr = cgutils.pointer_add(self.builder, self.dataptr, byteoffset)
        return self.context.unpack_value(self.builder, self.fe_type, ptr)
    def load_aligned(self, ind):
        # Using gep() instead of explicit pointer addition helps LLVM
        # vectorize the loop.
        ptr = self.builder.gep(self.dataptr, [ind])
        return self.context.unpack_value(self.builder, self.fe_type, ptr)
    def store_direct(self, value, byteoffset):
        """Generic store of *value* at the given *byteoffset*."""
        ptr = cgutils.pointer_add(self.builder, self.dataptr, byteoffset)
        self.context.pack_value(self.builder, self.fe_type, value, ptr)
    def store_aligned(self, value, ind):
        """Index-based store of *value*; counterpart of load_aligned()."""
        ptr = self.builder.gep(self.dataptr, [ind])
        self.context.pack_value(self.builder, self.fe_type, value, ptr)
# On-disk cache class for compiled gufunc wrapper libraries ('guf' prefix).
GufWrapperCache = make_library_cache('guf')
class _GufuncWrapper(object):
    """Builds the NumPy gufunc inner-loop wrapper around a compiled kernel.

    The wrapper has the standard ``(args, dims, steps, data)`` inner-loop
    prototype; each loop iteration slices per-argument array views and calls
    the compiled kernel.  Compiled wrappers may be cached on disk.
    """
    def __init__(self, py_func, cres, sin, sout, cache, is_parfors):
        """
        The *is_parfors* argument is a boolean that indicates if the GUfunc
        being built is to be used as a ParFors kernel. If True, it disables
        the caching on the wrapper as a separate unit because it will be linked
        into the caller function and cached along with it.
        """
        self.py_func = py_func
        self.cres = cres
        self.sin = sin
        self.sout = sout
        # Object mode is signalled by a pyobject return type.
        self.is_objectmode = self.signature.return_type == types.pyobject
        self.cache = (GufWrapperCache(py_func=self.py_func)
                      if cache else NullCache())
        self.is_parfors = bool(is_parfors)
    @property
    def library(self):
        # Code library of the compiled kernel.
        return self.cres.library
    @property
    def context(self):
        return self.cres.target_context
    @property
    def call_conv(self):
        return self.context.call_conv
    @property
    def signature(self):
        return self.cres.signature
    @property
    def fndesc(self):
        return self.cres.fndesc
    @property
    def env(self):
        return self.cres.environment
    def _wrapper_function_type(self):
        """Return the LLVM type of the NumPy inner-loop entry point:
        void(char **args, npy_intp *dims, npy_intp *steps, void *data)."""
        byte_t = ir.IntType(8)
        byte_ptr_t = ir.PointerType(byte_t)
        byte_ptr_ptr_t = ir.PointerType(byte_ptr_t)
        intp_t = self.context.get_value_type(types.intp)
        intp_ptr_t = ir.PointerType(intp_t)
        fnty = ir.FunctionType(ir.VoidType(), [byte_ptr_ptr_t, intp_ptr_t,
                                               intp_ptr_t, byte_ptr_t])
        return fnty
    def _build_wrapper(self, library, name):
        """
        The LLVM IRBuilder code to create the gufunc wrapper.
        The *library* arg is the CodeLibrary to which the wrapper should
        be added.  The *name* arg is the name of the wrapper function being
        created.
        """
        intp_t = self.context.get_value_type(types.intp)
        fnty = self._wrapper_function_type()
        wrapper_module = library.create_ir_module('_gufunc_wrapper')
        func_type = self.call_conv.get_function_type(self.fndesc.restype,
                                                     self.fndesc.argtypes)
        fname = self.fndesc.llvm_func_name
        func = ir.Function(wrapper_module, func_type, name=fname)
        func.attributes.add("alwaysinline")
        wrapper = ir.Function(wrapper_module, fnty, name)
        # The use of weak_odr linkage avoids the function being dropped due
        # to the order in which the wrappers and the user function are linked.
        wrapper.linkage = 'weak_odr'
        arg_args, arg_dims, arg_steps, arg_data = wrapper.args
        arg_args.name = "args"
        arg_dims.name = "dims"
        arg_steps.name = "steps"
        arg_data.name = "data"
        builder = IRBuilder(wrapper.append_basic_block("entry"))
        # dims[0] is the outer loop trip count.
        loopcount = builder.load(arg_dims, name="loopcount")
        pyapi = self.context.get_python_api(builder)
        # Unpack shapes
        # NOTE(review): unique_syms is computed but never used below.
        unique_syms = set()
        for grp in (self.sin, self.sout):
            for syms in grp:
                unique_syms |= set(syms)
        # Assign each input dimension symbol an index in first-seen order.
        sym_map = {}
        for syms in self.sin:
            for s in syms:
                if s not in sym_map:
                    sym_map[s] = len(sym_map)
        # Load each symbol's extent from dims[1 + index].
        sym_dim = {}
        for s, i in sym_map.items():
            sym_dim[s] = builder.load(builder.gep(arg_dims,
                                                  [self.context.get_constant(
                                                      types.intp,
                                                      i + 1)]))
        # Prepare inputs
        arrays = []
        # Per-dimension strides follow the per-argument core strides.
        step_offset = len(self.sin) + len(self.sout)
        for i, (typ, sym) in enumerate(zip(self.signature.args,
                                           self.sin + self.sout)):
            ary = GUArrayArg(self.context, builder, arg_args,
                             arg_steps, i, step_offset, typ, sym, sym_dim)
            step_offset += len(sym)
            arrays.append(ary)
        bbreturn = builder.append_basic_block('.return')
        # Prologue
        self.gen_prologue(builder, pyapi)
        # Loop
        with cgutils.for_range(builder, loopcount, intp=intp_t) as loop:
            args = [a.get_array_at_offset(loop.index) for a in arrays]
            innercall, error = self.gen_loop_body(builder, pyapi, func, args)
            # If error, escape
            cgutils.cbranch_or_continue(builder, error, bbreturn)
        builder.branch(bbreturn)
        builder.position_at_end(bbreturn)
        # Epilogue
        self.gen_epilogue(builder, pyapi)
        builder.ret_void()
        # Link
        library.add_ir_module(wrapper_module)
        library.add_linking_library(self.library)
    def _compile_wrapper(self, wrapper_name):
        """Compile (or load from cache) the wrapper library."""
        # Gufunc created by Parfors?
        if self.is_parfors:
            # No wrapper caching for parfors
            wrapperlib = self.context.codegen().create_library(str(self))
            # Build wrapper
            self._build_wrapper(wrapperlib, wrapper_name)
        # Non-parfors?
        else:
            # Use cache and compiler in a critical section
            wrapperlib = self.cache.load_overload(
                self.cres.signature, self.cres.target_context,
            )
            if wrapperlib is None:
                # Create library and enable caching
                wrapperlib = self.context.codegen().create_library(str(self))
                wrapperlib.enable_object_caching()
                # Build wrapper
                self._build_wrapper(wrapperlib, wrapper_name)
                # Cache
                self.cache.save_overload(self.cres.signature, wrapperlib)
        return wrapperlib
    @global_compiler_lock
    def build(self):
        """Compile the wrapper and return its (library, env, name) info."""
        wrapper_name = "__gufunc__." + self.fndesc.mangled_name
        wrapperlib = self._compile_wrapper(wrapper_name)
        return _wrapper_info(
            library=wrapperlib, env=self.env, name=wrapper_name,
        )
    def gen_loop_body(self, builder, pyapi, func, args):
        """Emit the per-iteration kernel call; returns (code, is_error)."""
        status, retval = self.call_conv.call_function(
            builder, func, self.signature.return_type, self.signature.args,
            args)
        with builder.if_then(status.is_error, likely=False):
            # Raising requires the GIL.
            gil = pyapi.gil_ensure()
            self.context.call_conv.raise_error(builder, pyapi, status)
            pyapi.gil_release(gil)
        return status.code, status.is_error
    def gen_prologue(self, builder, pyapi):
        pass  # Do nothing
    def gen_epilogue(self, builder, pyapi):
        pass  # Do nothing
class _GufuncObjectWrapper(_GufuncWrapper):
    """Gufunc wrapper variant for object-mode kernels: the whole loop runs
    with the GIL held and arguments are boxed as PyObjects."""
    def gen_loop_body(self, builder, pyapi, func, args):
        # Box arguments, call the object-mode kernel, and return the
        # (status code, error flag) pair expected by the driver loop.
        innercall, error = _prepare_call_to_object_mode(self.context,
                                                        builder, pyapi, func,
                                                        self.signature,
                                                        args)
        return innercall, error
    def gen_prologue(self, builder, pyapi):
        # Acquire the GIL
        self.gil = pyapi.gil_ensure()
    def gen_epilogue(self, builder, pyapi):
        # Release GIL
        pyapi.gil_release(self.gil)
def build_gufunc_wrapper(py_func, cres, sin, sout, cache, is_parfors):
    """Build the gufunc inner-loop wrapper for a compiled kernel.

    Picks the object-mode wrapper when the kernel returns a pyobject,
    otherwise the nopython one, and returns its build info
    (library, env, name).
    """
    if cres.signature.return_type == types.pyobject:
        wrapper_cls = _GufuncObjectWrapper
    else:
        wrapper_cls = _GufuncWrapper
    wrapper = wrapper_cls(py_func, cres, sin, sout, cache,
                          is_parfors=is_parfors)
    return wrapper.build()
def _prepare_call_to_object_mode(context, builder, pyapi, func,
                                 signature, args):
    """Box native arguments as PyObjects and call the object-mode kernel.

    Array arguments are wrapped via the ``numba_ndarray_new`` helper without
    full NRT reflection; other arguments use generic boxing.  Returns the
    (status code, error flag) pair for the driver loop.
    """
    mod = builder.module
    bb_core_return = builder.append_basic_block('ufunc.core.return')
    # Call to
    # PyObject* ndarray_new(int nd,
    #                       npy_intp *dims,   /* shape */
    #                       npy_intp *strides,
    #                       void* data,
    #                       int type_num,
    #                       int itemsize)
    ll_int = context.get_value_type(types.int32)
    ll_intp = context.get_value_type(types.intp)
    ll_intp_ptr = ir.PointerType(ll_intp)
    ll_voidptr = context.get_value_type(types.voidptr)
    ll_pyobj = context.get_value_type(types.pyobject)
    fnty = ir.FunctionType(ll_pyobj, [ll_int, ll_intp_ptr,
                                      ll_intp_ptr, ll_voidptr,
                                      ll_int, ll_int])
    fn_array_new = cgutils.get_or_insert_function(mod, fnty,
                                                  "numba_ndarray_new")
    # Convert each llarray into pyobject
    error_pointer = cgutils.alloca_once(builder, ir.IntType(1), name='error')
    builder.store(cgutils.true_bit, error_pointer)
    # The PyObject* arguments to the kernel function
    object_args = []
    object_pointers = []
    for i, (arg, argty) in enumerate(zip(args, signature.args)):
        # Allocate NULL-initialized slot for this argument
        objptr = cgutils.alloca_once(builder, ll_pyobj, zfill=True)
        object_pointers.append(objptr)
        if isinstance(argty, types.Array):
            # Special case arrays: we don't need full-blown NRT reflection
            # since the argument will be gone at the end of the kernel
            arycls = context.make_array(argty)
            array = arycls(context, builder, value=arg)
            zero = Constant(ll_int, 0)
            # Extract members of the llarray
            nd = Constant(ll_int, argty.ndim)
            dims = builder.gep(array._get_ptr_by_name('shape'), [zero, zero])
            strides = builder.gep(array._get_ptr_by_name('strides'),
                                  [zero, zero])
            data = builder.bitcast(array.data, ll_voidptr)
            dtype = np.dtype(str(argty.dtype))
            # Prepare other info for reconstruction of the PyArray
            type_num = Constant(ll_int, dtype.num)
            itemsize = Constant(ll_int, dtype.itemsize)
            # Call helper to reconstruct PyArray objects
            obj = builder.call(fn_array_new, [nd, dims, strides, data,
                                              type_num, itemsize])
        else:
            # Other argument types => use generic boxing
            obj = pyapi.from_native_value(argty, arg)
        builder.store(obj, objptr)
        object_args.append(obj)
        # Bail out to the return block if boxing failed (NULL object).
        obj_is_null = cgutils.is_null(builder, obj)
        builder.store(obj_is_null, error_pointer)
        cgutils.cbranch_or_continue(builder, obj_is_null, bb_core_return)
    # Call ufunc core function
    object_sig = [types.pyobject] * len(object_args)
    status, retval = context.call_conv.call_function(
        builder, func, types.pyobject, object_sig,
        object_args)
    builder.store(status.is_error, error_pointer)
    # Release returned object
    pyapi.decref(retval)
    builder.branch(bb_core_return)
    # At return block
    builder.position_at_end(bb_core_return)
    # Release argument objects
    for objptr in object_pointers:
        pyapi.decref(builder.load(objptr))
    innercall = status.code
    return innercall, builder.load(error_pointer)
class GUArrayArg(object):
    """Wraps one gufunc argument, exposing a per-iteration array view.

    *syms* is the argument's tuple of dimension symbols from the shape
    signature; *sym_dim* maps each symbol to its loaded extent.  The chosen
    loader strategy depends on whether the core type is a scalar, an array,
    or a 1D array standing in for a scalar "()" spec.
    """
    def __init__(self, context, builder, args, steps, i, step_offset,
                 typ, syms, sym_dim):
        self.context = context
        self.builder = builder
        offset = context.get_constant(types.intp, i)
        # args[i]: raw byte pointer to this argument's data.
        data = builder.load(builder.gep(args, [offset], name="data.ptr"),
                            name="data")
        self.data = data
        # steps[i]: byte stride between consecutive outer-loop iterations.
        core_step_ptr = builder.gep(steps, [offset], name="core.step.ptr")
        core_step = builder.load(core_step_ptr)
        if isinstance(typ, types.Array):
            # Empty symbol tuple means a "()" (scalar) shape spec.
            as_scalar = not syms
            # number of symbol in the shape spec should match the dimension
            # of the array type.
            if len(syms) != typ.ndim:
                if len(syms) == 0 and typ.ndim == 1:
                    # This is an exception for handling scalar argument.
                    # The type can be 1D array for scalar.
                    # In the future, we may deprecate this exception.
                    pass
                else:
                    raise TypeError("type and shape signature mismatch for arg "
                                    "#{0}".format(i + 1))
            ndim = typ.ndim
            shape = [sym_dim[s] for s in syms]
            # Load the per-dimension strides from steps[step_offset + j].
            strides = []
            for j in range(ndim):
                stepptr = builder.gep(steps,
                                      [context.get_constant(types.intp,
                                                            step_offset + j)],
                                      name="step.ptr")
                step = builder.load(stepptr)
                strides.append(step)
            ldcls = (_ArrayAsScalarArgLoader
                     if as_scalar
                     else _ArrayArgLoader)
            self._loader = ldcls(dtype=typ.dtype,
                                 ndim=ndim,
                                 core_step=core_step,
                                 as_scalar=as_scalar,
                                 shape=shape,
                                 strides=strides)
        else:
            # If typ is not an array
            if syms:
                raise TypeError("scalar type {0} given for non scalar "
                                "argument #{1}".format(typ, i + 1))
            self._loader = _ScalarArgLoader(dtype=typ, stride=core_step)
    def get_array_at_offset(self, ind):
        """Return this argument's value/view for outer-loop iteration *ind*."""
        return self._loader.load(context=self.context, builder=self.builder,
                                 data=self.data, ind=ind)
class _ScalarArgLoader(object):
    """
    Handle GFunc argument loading where a scalar type is used in the core
    function.
    Note: It still has a stride because the input to the gufunc can be an array
    for this argument.
    """
    def __init__(self, dtype, stride):
        self.dtype = dtype
        self.stride = stride
    def load(self, context, builder, data, ind):
        """Return the scalar element for outer-loop iteration *ind*."""
        # Load at base + ind * stride
        data = builder.gep(data, [builder.mul(ind, self.stride)])
        # Reinterpret the byte pointer as a pointer to the scalar type.
        dptr = builder.bitcast(data,
                               context.get_data_type(self.dtype).as_pointer())
        return builder.load(dptr)
class _ArrayArgLoader(object):
    """
    Handle GUFunc argument loading where an array is expected.
    """
    def __init__(self, dtype, ndim, core_step, as_scalar, shape, strides):
        self.dtype = dtype
        self.ndim = ndim
        # Byte stride between consecutive outer-loop iterations.
        self.core_step = core_step
        self.as_scalar = as_scalar
        self.shape = shape
        self.strides = strides
    def load(self, context, builder, data, ind):
        """Build and return a native array view over this argument's slice
        at outer-loop iteration *ind* (no copy; meminfo is None since the
        memory is owned by NumPy)."""
        arytyp = types.Array(dtype=self.dtype, ndim=self.ndim, layout="A")
        arycls = context.make_array(arytyp)
        array = arycls(context, builder)
        # Advance the base pointer by ind * core_step bytes.
        offseted_data = cgutils.pointer_add(builder,
                                            data,
                                            builder.mul(self.core_step,
                                                        ind))
        shape, strides = self._shape_and_strides(context, builder)
        itemsize = context.get_abi_sizeof(context.get_data_type(self.dtype))
        context.populate_array(array,
                               data=builder.bitcast(offseted_data,
                                                    array.data.type),
                               shape=shape,
                               strides=strides,
                               itemsize=context.get_constant(types.intp,
                                                             itemsize),
                               meminfo=None)
        return array._getvalue()
    def _shape_and_strides(self, context, builder):
        # Pack the symbol-derived extents and loaded strides into LLVM
        # aggregate values; subclasses may override for special shapes.
        shape = cgutils.pack_array(builder, self.shape)
        strides = cgutils.pack_array(builder, self.strides)
        return shape, strides
class _ArrayAsScalarArgLoader(_ArrayArgLoader):
    """
    Handle GUFunc argument loading where the shape signature specifies
    a scalar "()" but a 1D array is used for the type of the core function.
    """
    def _shape_and_strides(self, context, builder):
        # Set shape and strides for a 1D size 1 array: shape (1,) with a
        # zero stride, so the single element is viewed as a 1-element array.
        one = context.get_constant(types.intp, 1)
        zero = context.get_constant(types.intp, 0)
        shape = cgutils.pack_array(builder, [one])
        strides = cgutils.pack_array(builder, [zero])
        return shape, strides