Videre
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from numba.np.ufunc.decorators import Vectorize, GUVectorize, vectorize, guvectorize
|
||||
from numba.np.ufunc._internal import PyUFunc_None, PyUFunc_Zero, PyUFunc_One
|
||||
from numba.np.ufunc import _internal, array_exprs
|
||||
from numba.np.ufunc.parallel import (threading_layer, get_num_threads,
|
||||
set_num_threads, get_thread_id,
|
||||
set_parallel_chunksize,
|
||||
get_parallel_chunksize)
|
||||
|
||||
|
||||
# Re-export PyUFunc_ReorderableNone only when the C extension provides it
# (its availability depends on how _internal was built).
if hasattr(_internal, 'PyUFunc_ReorderableNone'):
    PyUFunc_ReorderableNone = _internal.PyUFunc_ReorderableNone
# The submodules are not part of this package's public namespace;
# array_exprs is imported above for its registration side effects
# (it registers an 'after-inference' rewrite).
del _internal, array_exprs
|
||||
|
||||
|
||||
def _init():
    """Register lazily-imported CUDA target implementations.

    The numba.cuda package is only imported if/when the 'cuda' target is
    actually requested, via the registry's on-demand hooks.
    """

    def init_cuda_vectorize():
        # Deferred import: pulls in numba.cuda only on first 'cuda' use.
        from numba.cuda.vectorizers import CUDAVectorize
        return CUDAVectorize

    def init_cuda_guvectorize():
        # Deferred import, same rationale as above.
        from numba.cuda.vectorizers import CUDAGUFuncVectorize
        return CUDAGUFuncVectorize

    Vectorize.target_registry.ondemand['cuda'] = init_cuda_vectorize
    GUVectorize.target_registry.ondemand['cuda'] = init_cuda_guvectorize


# Run the registration once at import time, then drop the helper so it
# does not linger in the module namespace.
_init()
del _init
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,428 @@
|
||||
import ast
|
||||
from collections import defaultdict, OrderedDict
|
||||
import contextlib
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import operator
|
||||
|
||||
from numba.core import types, targetconfig, ir, rewrites, compiler
|
||||
from numba.core.typing import npydecl
|
||||
from numba.np.ufunc.dufunc import DUFunc
|
||||
|
||||
|
||||
def _is_ufunc(func):
    """Return True when *func* is a NumPy ufunc or a Numba DUFunc."""
    return isinstance(func, np.ufunc) or isinstance(func, DUFunc)
|
||||
|
||||
|
||||
@rewrites.register_rewrite('after-inference')
class RewriteArrayExprs(rewrites.Rewrite):
    '''The RewriteArrayExprs class is responsible for finding array
    expressions in Numba intermediate representation code, and
    rewriting those expressions to a single operation that will expand
    into something similar to a ufunc call.
    '''

    def __init__(self, state, *args, **kws):
        super(RewriteArrayExprs, self).__init__(state, *args, **kws)
        # Install a lowering hook if we are using this rewrite.
        special_ops = state.targetctx.special_ops
        if 'arrayexpr' not in special_ops:
            special_ops['arrayexpr'] = _lower_array_expr

    def match(self, func_ir, block, typemap, calltypes):
        """
        Using typing and a basic block, search the basic block for array
        expressions.
        Return True when one or more matches were found, False otherwise.
        """
        # We can trivially reject everything if there are no
        # calls in the type results.
        if len(calltypes) == 0:
            return False

        # Stash the block and typing info for use by apply().
        self.crnt_block = block
        self.typemap = typemap
        # { variable name: IR assignment (of a function call or operator) }
        self.array_assigns = OrderedDict()
        # { variable name: IR assignment (of a constant) }
        self.const_assigns = {}

        assignments = block.find_insts(ir.Assign)
        for instr in assignments:
            target_name = instr.target.name
            expr = instr.value
            # Does it assign an expression to an array variable?
            if (isinstance(expr, ir.Expr) and
                    isinstance(typemap.get(target_name, None), types.Array)):
                self._match_array_expr(instr, expr, target_name)
            elif isinstance(expr, ir.Const):
                # Track constants since we might need them for an
                # array expression.
                self.const_assigns[target_name] = expr

        return len(self.array_assigns) > 0

    def _match_array_expr(self, instr, expr, target_name):
        """
        Find whether the given assignment (*instr*) of an expression (*expr*)
        to variable *target_name* is an array expression.
        """
        # We've matched a subexpression assignment to an
        # array variable. Now see if the expression is an
        # array expression.
        expr_op = expr.op
        array_assigns = self.array_assigns

        if ((expr_op in ('unary', 'binop')) and (
                expr.fn in npydecl.supported_array_operators)):
            # It is an array operator that maps to a ufunc.
            # check that all args have internal types
            if all(self.typemap[var.name].is_internal
                   for var in expr.list_vars()):
                array_assigns[target_name] = instr

        elif ((expr_op == 'call') and (expr.func.name in self.typemap)):
            # It could be a match for a known ufunc call.
            func_type = self.typemap[expr.func.name]
            if isinstance(func_type, types.Function):
                func_key = func_type.typing_key
                if _is_ufunc(func_key):
                    # If so, check whether an explicit output is passed.
                    if not self._has_explicit_output(expr, func_key):
                        # If not, match it as a (sub)expression.
                        array_assigns[target_name] = instr

    def _has_explicit_output(self, expr, func):
        """
        Return whether the *expr* call to *func* (a ufunc) features an
        explicit output argument.
        """
        nargs = len(expr.args) + len(expr.kws)
        if expr.vararg is not None:
            # XXX *args unsupported here, assume there may be an explicit
            # output
            return True
        return nargs > func.nin

    def _get_array_operator(self, ir_expr):
        # Return the operator callable (for unary/binop) or the ufunc
        # (for calls) driving this IR expression.
        ir_op = ir_expr.op
        if ir_op in ('unary', 'binop'):
            return ir_expr.fn
        elif ir_op == 'call':
            return self.typemap[ir_expr.func.name].typing_key
        raise NotImplementedError(
            "Don't know how to find the operator for '{0}' expressions.".format(
                ir_op))

    def _get_operands(self, ir_expr):
        '''Given a Numba IR expression, return the operands to the expression
        in order they appear in the expression.
        '''
        ir_op = ir_expr.op
        if ir_op == 'binop':
            return ir_expr.lhs, ir_expr.rhs
        elif ir_op == 'unary':
            return ir_expr.list_vars()
        elif ir_op == 'call':
            return ir_expr.args
        raise NotImplementedError(
            "Don't know how to find the operands for '{0}' expressions.".format(
                ir_op))

    def _translate_expr(self, ir_expr):
        '''Translate the given expression from Numba IR to an array expression
        tree.
        '''
        ir_op = ir_expr.op
        if ir_op == 'arrayexpr':
            # Already translated by an earlier pass over this block.
            return ir_expr.expr
        # Substitute tracked constants for their variables where possible.
        operands_or_args = [self.const_assigns.get(op_var.name, op_var)
                            for op_var in self._get_operands(ir_expr)]
        return self._get_array_operator(ir_expr), operands_or_args

    def _handle_matches(self):
        '''Iterate over the matches, trying to find which instructions should
        be rewritten, deleted, or moved.
        '''
        replace_map = {}
        dead_vars = set()
        used_vars = defaultdict(int)
        for instr in self.array_assigns.values():
            expr = instr.value
            arr_inps = []
            # NOTE: arr_expr holds a reference to arr_inps, which is
            # appended to below — the tuple is completed incrementally.
            arr_expr = self._get_array_operator(expr), arr_inps
            new_expr = ir.Expr(op='arrayexpr',
                               loc=expr.loc,
                               expr=arr_expr,
                               ty=self.typemap[instr.target.name])
            new_instr = ir.Assign(new_expr, instr.target, instr.loc)
            replace_map[instr] = new_instr
            self.array_assigns[instr.target.name] = new_instr
            for operand in self._get_operands(expr):
                operand_name = operand.name
                if operand.is_temp and operand_name in self.array_assigns:
                    # A temporary fed by another matched array expression:
                    # fuse the child expression into this one.
                    child_assign = self.array_assigns[operand_name]
                    child_expr = child_assign.value
                    child_operands = child_expr.list_vars()
                    # NOTE(review): the inner loop reuses the name
                    # 'operand', shadowing the outer loop variable.
                    for operand in child_operands:
                        used_vars[operand.name] += 1
                    arr_inps.append(self._translate_expr(child_expr))
                    if child_assign.target.is_temp:
                        dead_vars.add(child_assign.target.name)
                        replace_map[child_assign] = None
                elif operand_name in self.const_assigns:
                    # Inline the constant value directly.
                    arr_inps.append(self.const_assigns[operand_name])
                else:
                    used_vars[operand.name] += 1
                    arr_inps.append(operand)
        return replace_map, dead_vars, used_vars

    def _get_final_replacement(self, replacement_map, instr):
        '''Find the final replacement instruction for a given initial
        instruction by chasing instructions in a map from instructions
        to replacement instructions.
        '''
        replacement = replacement_map[instr]
        while replacement in replacement_map:
            replacement = replacement_map[replacement]
        return replacement

    def apply(self):
        '''When we've found array expressions in a basic block, rewrite that
        block, returning a new, transformed block.
        '''
        # Part 1: Figure out what instructions should be rewritten
        # based on the matches found.
        replace_map, dead_vars, used_vars = self._handle_matches()
        # Part 2: Using the information above, rewrite the target
        # basic block.
        result = self.crnt_block.copy()
        result.clear()
        delete_map = {}
        for instr in self.crnt_block.body:
            if isinstance(instr, ir.Assign):
                if instr in replace_map:
                    replacement = self._get_final_replacement(
                        replace_map, instr)
                    if replacement:
                        result.append(replacement)
                        # Move deferred deletes of the operands to just
                        # after their (possibly relocated) last use.
                        for var in replacement.value.list_vars():
                            var_name = var.name
                            if var_name in delete_map:
                                result.append(delete_map.pop(var_name))
                            if used_vars[var_name] > 0:
                                used_vars[var_name] -= 1

                else:
                    result.append(instr)
            elif isinstance(instr, ir.Del):
                instr_value = instr.value
                if used_vars[instr_value] > 0:
                    # Still used later in the rewritten block; defer.
                    used_vars[instr_value] -= 1
                    delete_map[instr_value] = instr
                elif instr_value not in dead_vars:
                    result.append(instr)
            else:
                result.append(instr)
        # Flush any deletes that were never re-emitted above.
        if delete_map:
            for instr in delete_map.values():
                result.insert_before_terminator(instr)
        return result
|
||||
|
||||
|
||||
# Mappings from the operator-module callables found in Numba IR
# ('unary'/'binop' expressions) to the AST node classes used by
# _arr_expr_to_ast() when rebuilding a Python expression.

_unaryops = {
    operator.pos: ast.UAdd,
    operator.neg: ast.USub,
    operator.invert: ast.Invert,
}

_binops = {
    operator.add: ast.Add,
    operator.sub: ast.Sub,
    operator.mul: ast.Mult,
    operator.truediv: ast.Div,
    operator.mod: ast.Mod,
    operator.or_: ast.BitOr,
    operator.rshift: ast.RShift,
    operator.xor: ast.BitXor,
    operator.lshift: ast.LShift,
    operator.and_: ast.BitAnd,
    operator.pow: ast.Pow,
    operator.floordiv: ast.FloorDiv,
}


# Comparison operators map to ast.Compare nodes rather than ast.BinOp.
_cmpops = {
    operator.eq: ast.Eq,
    operator.ne: ast.NotEq,
    operator.lt: ast.Lt,
    operator.le: ast.LtE,
    operator.gt: ast.Gt,
    operator.ge: ast.GtE,
}
|
||||
|
||||
|
||||
def _arr_expr_to_ast(expr):
    '''Recursively convert an array-expression tree built by
    RewriteArrayExprs into a Python AST node, also returning a namespace
    mapping generated names to the ufunc/DUFunc objects the AST uses.
    '''
    if isinstance(expr, tuple):
        op, child_exprs = expr
        env = {}
        ast_args = []
        # Translate children first, merging their namespaces.
        for child in child_exprs:
            child_ast, child_env = _arr_expr_to_ast(child)
            ast_args.append(child_ast)
            env.update(child_env)
        if op in npydecl.supported_array_operators:
            if len(ast_args) == 2:
                if op in _binops:
                    node = ast.BinOp(ast_args[0], _binops[op](), ast_args[1])
                    return node, env
                if op in _cmpops:
                    node = ast.Compare(ast_args[0], [_cmpops[op]()],
                                       [ast_args[1]])
                    return node, env
            else:
                assert op in _unaryops
                return ast.UnaryOp(_unaryops[op](), ast_args[0]), env
        elif _is_ufunc(op):
            # Name the (D)UFunc uniquely and stash the object itself in
            # the environment so exec() can resolve the call.
            fn_name = "__ufunc_or_dufunc_{0}".format(
                hex(hash(op)).replace("-", "_"))
            env[fn_name] = op
            call_node = ast.Call(ast.Name(fn_name, ast.Load()), ast_args, [])
            return call_node, env
    elif isinstance(expr, ir.Var):
        name_node = ast.Name(expr.name, ast.Load(),
                             lineno=expr.loc.line,
                             col_offset=expr.loc.col if expr.loc.col else 0)
        return name_node, {}
    elif isinstance(expr, ir.Const):
        return ast.Constant(expr.value), {}
    raise NotImplementedError(
        "Don't know how to translate array expression '%r'" % (expr,))
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _legalize_parameter_names(var_list):
|
||||
"""
|
||||
Legalize names in the variable list for use as a Python function's
|
||||
parameter names.
|
||||
"""
|
||||
var_map = OrderedDict()
|
||||
for var in var_list:
|
||||
old_name = var.name
|
||||
new_name = var.scope.redefine(old_name, loc=var.loc).name
|
||||
new_name = new_name.replace("$", "_").replace(".", "_")
|
||||
# Caller should ensure the names are unique
|
||||
if new_name in var_map:
|
||||
raise AssertionError(f"{new_name!r} not unique")
|
||||
var_map[new_name] = var, old_name
|
||||
var.name = new_name
|
||||
param_names = list(var_map)
|
||||
try:
|
||||
yield param_names
|
||||
finally:
|
||||
# Make sure the old names are restored, to avoid confusing
|
||||
# other parts of Numba (see issue #1466)
|
||||
for var, old_name in var_map.values():
|
||||
var.name = old_name
|
||||
|
||||
|
||||
class _EraseInvalidLineRanges(ast.NodeTransformer):
|
||||
def generic_visit(self, node: ast.AST) -> ast.AST:
|
||||
node = super().generic_visit(node)
|
||||
if hasattr(node, "lineno"):
|
||||
if getattr(node, "end_lineno", None) is not None:
|
||||
if node.lineno > node.end_lineno:
|
||||
del node.lineno
|
||||
del node.end_lineno
|
||||
return node
|
||||
|
||||
|
||||
def _fix_invalid_lineno_ranges(astree: ast.AST):
    """Inplace fixes invalid lineno ranges.

    Synthesized AST nodes can end up with lineno > end_lineno; this
    erases such ranges and lets fix_missing_locations re-fill them.
    """
    # Make sure lineno and end_lineno are present
    ast.fix_missing_locations(astree)
    # Delete invalid lineno ranges
    _EraseInvalidLineRanges().visit(astree)
    # Make sure lineno and end_lineno are present
    ast.fix_missing_locations(astree)
|
||||
|
||||
|
||||
def _lower_array_expr(lowerer, expr):
    '''Lower an array expression built by RewriteArrayExprs.

    Builds a Python kernel function from the expression tree, compiles it
    as a subroutine, and emits a ufunc-style broadcast loop calling it.
    '''
    # Unique, legal function name derived from the expression identity.
    expr_name = "__numba_array_expr_%s" % (hex(hash(expr)).replace("-", "_"))
    expr_filename = expr.loc.filename
    expr_var_list = expr.list_vars()
    # The expression may use a given variable several times, but we
    # should only create one parameter for it.
    expr_var_unique = sorted(set(expr_var_list), key=lambda var: var.name)

    # Arguments are the names external to the new closure
    expr_args = [var.name for var in expr_var_unique]

    # 1. Create an AST tree from the array expression.
    with _legalize_parameter_names(expr_var_unique) as expr_params:
        ast_args = [ast.arg(param_name, None)
                    for param_name in expr_params]
        # Parse a stub function to ensure the AST is populated with
        # reasonable defaults for the Python version.
        ast_module = ast.parse('def {0}(): return'.format(expr_name),
                               expr_filename, 'exec')
        assert hasattr(ast_module, 'body') and len(ast_module.body) == 1
        ast_fn = ast_module.body[0]
        ast_fn.args.args = ast_args
        # Replace the stub's bare 'return' value with the translated
        # expression; namespace carries any ufunc objects it references.
        ast_fn.body[0].value, namespace = _arr_expr_to_ast(expr.expr)
        _fix_invalid_lineno_ranges(ast_module)

    # 2. Compile the AST module and extract the Python function.
    code_obj = compile(ast_module, expr_filename, 'exec')
    exec(code_obj, namespace)
    impl = namespace[expr_name]

    # 3. Now compile a ufunc using the Python function as kernel.

    context = lowerer.context
    builder = lowerer.builder
    outer_sig = expr.ty(*(lowerer.typeof(name) for name in expr_args))
    # The scalar (element-wise) signature: array args become their dtype.
    inner_sig_args = []
    for argty in outer_sig.args:
        if isinstance(argty, types.Optional):
            argty = argty.type
        if isinstance(argty, types.Array):
            inner_sig_args.append(argty.dtype)
        else:
            inner_sig_args.append(argty)
    inner_sig = outer_sig.return_type.dtype(*inner_sig_args)

    flags = targetconfig.ConfigStack().top_or_none()
    # make sure it's a clone or a fresh instance
    flags = compiler.Flags() if flags is None else flags.copy()
    # Follow the Numpy error model.  Note this also allows e.g. vectorizing
    # division (issue #1223).
    flags.error_model = 'numpy'
    cres = context.compile_subroutine(builder, impl, inner_sig, flags=flags,
                                      caching=False)

    # Create kernel subclass calling our native function
    from numba.np import npyimpl

    class ExprKernel(npyimpl._Kernel):
        # Casts each element to the inner (scalar) type, calls the
        # compiled subroutine, and casts the result back.
        def generate(self, *args):
            arg_zip = zip(args, self.outer_sig.args, inner_sig.args)
            cast_args = [self.cast(val, inty, outty)
                         for val, inty, outty in arg_zip]
            result = self.context.call_internal(
                builder, cres.fndesc, inner_sig, cast_args)
            return self.cast(result, inner_sig.return_type,
                             self.outer_sig.return_type)

    # create a fake ufunc object which is enough to trick numpy_ufunc_kernel
    ufunc = SimpleNamespace(nin=len(expr_args), nout=1, __name__=expr_name)
    ufunc.nargs = ufunc.nin + ufunc.nout

    args = [lowerer.loadvar(name) for name in expr_args]
    return npyimpl.numpy_ufunc_kernel(
        context, builder, outer_sig, args, ufunc, ExprKernel)
|
||||
@@ -0,0 +1,208 @@
|
||||
import inspect
|
||||
|
||||
from numba.np.ufunc import _internal
|
||||
from numba.np.ufunc.parallel import ParallelUFuncBuilder, ParallelGUFuncBuilder
|
||||
|
||||
from numba.core.registry import DelayedRegistry
|
||||
from numba.np.ufunc import dufunc
|
||||
from numba.np.ufunc import gufunc
|
||||
|
||||
|
||||
class _BaseVectorize(object):
|
||||
|
||||
@classmethod
|
||||
def get_identity(cls, kwargs):
|
||||
return kwargs.pop('identity', None)
|
||||
|
||||
@classmethod
|
||||
def get_cache(cls, kwargs):
|
||||
return kwargs.pop('cache', False)
|
||||
|
||||
@classmethod
|
||||
def get_writable_args(cls, kwargs):
|
||||
return kwargs.pop('writable_args', ())
|
||||
|
||||
@classmethod
|
||||
def get_target_implementation(cls, kwargs):
|
||||
target = kwargs.pop('target', 'cpu')
|
||||
try:
|
||||
return cls.target_registry[target]
|
||||
except KeyError:
|
||||
raise ValueError("Unsupported target: %s" % target)
|
||||
|
||||
|
||||
class Vectorize(_BaseVectorize):
    """Factory dispatching @vectorize to the builder for a target."""

    # Lazily-populated mapping from target name to ufunc builder.
    target_registry = DelayedRegistry({'cpu': dufunc.DUFunc,
                                       'parallel': ParallelUFuncBuilder,})

    def __new__(cls, func, **kws):
        # Consume the generic keywords; whatever remains in *kws* is
        # forwarded as target-specific options.
        builder_cls = cls.get_target_implementation(kws)
        identity = cls.get_identity(kws)
        cache = cls.get_cache(kws)
        return builder_cls(func, identity=identity, cache=cache,
                           targetoptions=kws)
|
||||
|
||||
|
||||
class GUVectorize(_BaseVectorize):
    """Factory dispatching @guvectorize to the builder for a target."""

    target_registry = DelayedRegistry({'cpu': gufunc.GUFunc,
                                       'parallel': ParallelGUFuncBuilder,})

    def __new__(cls, func, signature, **kws):
        identity = cls.get_identity(kws)
        cache = cls.get_cache(kws)
        builder_cls = cls.get_target_implementation(kws)
        writable_args = cls.get_writable_args(kws)
        common = dict(identity=identity, cache=cache,
                      writable_args=writable_args)
        if builder_cls is gufunc.GUFunc:
            # The dynamic GUFunc path also understands 'is_dynamic'.
            is_dyn = kws.pop('is_dynamic', False)
            return builder_cls(func, signature, is_dynamic=is_dyn,
                               targetoptions=kws, **common)
        else:
            return builder_cls(func, signature, targetoptions=kws, **common)
|
||||
|
||||
|
||||
def vectorize(ftylist_or_function=(), **kws):
    """vectorize(ftylist_or_function=(), target='cpu', identity=None, **kws)

    A decorator that creates a NumPy ufunc object using Numba compiled
    code. When no arguments or only keyword arguments are given,
    vectorize will return a Numba dynamic ufunc (DUFunc) object, where
    compilation/specialization may occur at call-time.

    Args
    -----
    ftylist_or_function: function or iterable

        When the first argument is a function, signatures are dealt
        with at call-time.

        When the first argument is an iterable of type signatures,
        which are either function type object or a string describing
        the function type, signatures are finalized at decoration
        time.

    Keyword Args
    ------------

    target: str
        A string for code generation target. Default to "cpu".

    identity: int, str, or None
        The identity (or unit) value for the element-wise function
        being implemented. Allowed values are None (the default), 0, 1,
        and "reorderable".

    cache: bool
        Turns on caching.


    Returns
    --------

    A NumPy universal function

    Examples
    -------
        @vectorize(['float32(float32, float32)',
                    'float64(float64, float64)'], identity=0)
        def sum(a, b):
            return a + b

        @vectorize
        def sum(a, b):
            return a + b

        @vectorize(identity=1)
        def mul(a, b):
            return a * b

    """
    if isinstance(ftylist_or_function, str):
        # Common user mistake
        ftylist = [ftylist_or_function]
    elif inspect.isfunction(ftylist_or_function):
        # Used as a bare decorator: build a dynamic DUFunc directly.
        return dufunc.DUFunc(ftylist_or_function, **kws)
    elif ftylist_or_function is not None:
        ftylist = ftylist_or_function
    # NOTE(review): if None is passed explicitly, 'ftylist' is left
    # unbound and wrap() below raises NameError — confirm intended.

    def wrap(func):
        vec = Vectorize(func, **kws)
        for sig in ftylist:
            vec.add(sig)
        # With explicit signatures, freeze the set of compiled loops.
        if len(ftylist) > 0:
            vec.disable_compile()
        return vec.build_ufunc()

    return wrap
|
||||
|
||||
|
||||
def guvectorize(*args, **kwargs):
    """guvectorize(ftylist, signature, target='cpu', identity=None, **kws)

    A decorator to create NumPy generalized-ufunc object from Numba compiled
    code.

    Args
    -----
    ftylist: iterable
        An iterable of type signatures, which are either
        function type object or a string describing the
        function type.

    signature: str
        A NumPy generalized-ufunc signature.
        e.g. "(m, n), (n, p)->(m, p)"

    identity: int, str, or None
        The identity (or unit) value for the element-wise function
        being implemented. Allowed values are None (the default), 0, 1,
        and "reorderable".

    cache: bool
        Turns on caching.

    writable_args: tuple
        a tuple of indices of input variables that are writable.

    target: str
        A string for code generation target. Defaults to "cpu".

    Returns
    --------

    A NumPy generalized universal-function

    Example
    -------
        @guvectorize(['void(int32[:,:], int32[:,:], int32[:,:])',
                      'void(float32[:,:], float32[:,:], float32[:,:])'],
                      '(x, y),(x, y)->(x, y)')
        def add_2d_array(a, b, c):
            for i in range(c.shape[0]):
                for j in range(c.shape[1]):
                    c[i, j] = a[i, j] + b[i, j]

    """
    if len(args) == 1:
        # Only the layout signature was given: dynamic gufunc whose
        # element-wise signatures are compiled on demand.
        ftylist = []
        signature = args[0]
        kwargs.setdefault('is_dynamic', True)
    elif len(args) == 2:
        ftylist = args[0]
        signature = args[1]
    else:
        raise TypeError('guvectorize() takes one or two positional arguments')

    if isinstance(ftylist, str):
        # Common user mistake
        ftylist = [ftylist]

    def wrap(func):
        guvec = GUVectorize(func, signature, **kwargs)
        for fty in ftylist:
            guvec.add(fty)
        # With explicit signatures, freeze the set of compiled loops.
        if len(ftylist) > 0:
            guvec.disable_compile()
        return guvec.build_ufunc()

    return wrap
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,325 @@
|
||||
from numba import typeof
|
||||
from numba.core import types
|
||||
from numba.np.ufunc.ufuncbuilder import GUFuncBuilder
|
||||
from numba.np.ufunc.sigparse import parse_signature
|
||||
from numba.np.ufunc.ufunc_base import UfuncBase, UfuncLowererBase
|
||||
from numba.np.numpy_support import ufunc_find_matching_loop
|
||||
from numba.core import serialize, errors
|
||||
from numba.core.typing import npydecl
|
||||
from numba.core.typing.templates import signature, AbstractTemplate
|
||||
import functools
|
||||
|
||||
|
||||
def make_gufunc_kernel(_dufunc):
    """Build a _Kernel subclass bound to *_dufunc* for gufunc lowering."""
    # Deferred import — NOTE(review): presumably avoids a circular
    # import with numba.np.npyimpl; confirm.
    from numba.np import npyimpl

    class GUFuncKernel(npyimpl._Kernel):
        """
        npyimpl._Kernel subclass responsible for lowering a gufunc kernel
        (element-wise function) inside a broadcast loop (which is
        generated by npyimpl.numpy_gufunc_kernel()).
        """
        dufunc = _dufunc

        def __init__(self, context, builder, outer_sig):
            super().__init__(context, builder, outer_sig)
            # Look up the element-wise loop already compiled for the
            # dtypes of this call.
            ewise_types = self.dufunc._get_ewise_dtypes(outer_sig.args)
            self.inner_sig, self.cres = self.dufunc.find_ewise_function(
                ewise_types)

        def cast(self, val, fromty, toty):
            # Handle the case where "fromty" is an array and "toty" a scalar
            if isinstance(fromty, types.Array) and not \
                    isinstance(toty, types.Array):
                return super().cast(val, fromty.dtype, toty)
            return super().cast(val, fromty, toty)

        def generate(self, *args):
            if self.cres.objectmode:
                msg = ('Calling a guvectorize function in object mode is not '
                       'supported yet.')
                raise errors.NumbaRuntimeError(msg)
            # Make the compiled loop's library visible to the caller.
            self.context.add_linking_libs((self.cres.library,))
            return super().generate(*args)

    # Distinguish kernel classes of different gufuncs when debugging.
    GUFuncKernel.__name__ += _dufunc.__name__
    return GUFuncKernel
|
||||
|
||||
|
||||
class GUFuncLowerer(UfuncLowererBase):
    """Callable object that lowers calls to one specific gufunc."""

    def __init__(self, gufunc):
        # Deferred import of the numpy implementation layer.
        from numba.np import npyimpl
        super().__init__(gufunc, make_gufunc_kernel,
                         npyimpl.numpy_gufunc_kernel)
|
||||
|
||||
|
||||
class GUFunc(serialize.ReduceMixin, UfuncBase):
|
||||
"""
|
||||
Dynamic generalized universal function (GUFunc)
|
||||
intended to act like a normal Numpy gufunc, but capable
|
||||
of call-time (just-in-time) compilation of fast loops
|
||||
specialized to inputs.
|
||||
"""
|
||||
|
||||
    def __init__(self, py_func, signature, identity=None, cache=None,
                 is_dynamic=False, targetoptions=None, writable_args=()):
        """Wrap *py_func* as a gufunc with the given layout *signature*
        (e.g. "(m,n),(n)->(m)"); *is_dynamic* enables call-time
        compilation of new element-wise signatures.
        """
        if targetoptions is None:
            targetoptions = {}
        self.ufunc = None          # built by build_ufunc()
        self._frozen = False       # True once compilation is disabled
        self._is_dynamic = is_dynamic
        self._identity = identity

        # GUFunc cannot inherit from GUFuncBuilder because "identity"
        # is a property of GUFunc. Thus, we hold a reference to a GUFuncBuilder
        # object here
        self.gufunc_builder = GUFuncBuilder(
            py_func, signature, identity, cache, targetoptions, writable_args)

        self.__name__ = self.gufunc_builder.py_func.__name__
        self.__doc__ = self.gufunc_builder.py_func.__doc__
        self._dispatcher = self.gufunc_builder.nb_func
        self._initialize(self._dispatcher)
        functools.update_wrapper(self, py_func)
|
||||
|
||||
    def _initialize(self, dispatcher):
        # Build the numpy gufunc, register its typing with the target
        # context, and install the lowering hook so jitted code can
        # call this gufunc directly.
        self.build_ufunc()
        self._install_type()
        self._lower_me = GUFuncLowerer(self)
        self._install_cg()
|
||||
|
||||
    def _reduce_states(self):
        """Pickle support (serialize.ReduceMixin): capture everything
        _rebuild() needs to reconstruct an equivalent GUFunc."""
        gb = self.gufunc_builder
        dct = dict(
            py_func=gb.py_func,
            signature=gb.signature,
            identity=self._identity,
            cache=gb.cache,
            is_dynamic=self._is_dynamic,
            targetoptions=gb.targetoptions,
            writable_args=gb.writable_args,
            typesigs=gb._sigs,
            frozen=self._frozen,
        )
        return dct
|
||||
|
||||
    @classmethod
    def _rebuild(cls, py_func, signature, identity, cache, is_dynamic,
                 targetoptions, writable_args, typesigs, frozen):
        """Pickle support: recreate the GUFunc, re-add the previously
        registered type signatures and restore the frozen state."""
        self = cls(py_func=py_func, signature=signature, identity=identity,
                   cache=cache, is_dynamic=is_dynamic,
                   targetoptions=targetoptions, writable_args=writable_args)
        for sig in typesigs:
            self.add(sig)
        self.build_ufunc()
        self._frozen = frozen
        return self
|
||||
|
||||
def __repr__(self):
|
||||
return f"<numba._GUFunc '{self.__name__}'>"
|
||||
|
||||
    def _install_type(self, typingctx=None):
        """Constructs and installs a typing class for a gufunc object in the
        input typing context. If no typing context is given, then
        _install_type() installs into the typing context of the
        dispatcher object (should be same default context used by
        jit() and njit()).
        """
        if typingctx is None:
            typingctx = self._dispatcher.targetdescr.typing_context
        # Dynamically build an AbstractTemplate whose generic() defers
        # to _type_me() on this instance.
        _ty_cls = type('GUFuncTyping_' + self.__name__,
                       (AbstractTemplate,),
                       dict(key=self, generic=self._type_me))
        typingctx.insert_user_function(self, _ty_cls)
|
||||
|
||||
    def add(self, fty):
        """Register a new element-wise type signature *fty* with the
        underlying builder."""
        self.gufunc_builder.add(fty)
|
||||
|
||||
    def build_ufunc(self):
        """(Re)build the underlying numpy gufunc object; returns self."""
        self.ufunc = self.gufunc_builder.build_ufunc()
        return self
|
||||
|
||||
def expected_ndims(self):
|
||||
parsed_sig = parse_signature(self.gufunc_builder.signature)
|
||||
return (tuple(map(len, parsed_sig[0])), tuple(map(len, parsed_sig[1])))
|
||||
|
||||
    def _type_me(self, argtys, kws):
        """
        Implement AbstractTemplate.generic() for the typing class
        built by gufunc._install_type().

        Return the call-site signature after either validating the
        element-wise signature or compiling for it.
        """
        assert not kws
        ufunc = self.ufunc
        sig = self.gufunc_builder.signature
        inp_ndims, out_ndims = self.expected_ndims()
        ndims = inp_ndims + out_ndims

        # NOTE(review): this is the 'assert cond, msg' form — it only
        # checks that argtys is non-empty and uses len(ndims) as the
        # message; a comparison may have been intended.
        assert len(argtys), len(ndims)
        # Validate that each array operand has at least the core
        # dimensionality the layout signature requires.
        for idx, arg in enumerate(argtys):
            if isinstance(arg, types.Array) and arg.ndim < ndims[idx]:
                kind = "Input" if idx < len(inp_ndims) else "Output"
                i = idx if idx < len(inp_ndims) else idx - len(inp_ndims)
                msg = (
                    f"{self.__name__}: {kind} operand {i} does not have "
                    f"enough dimensions (has {arg.ndim}, gufunc core with "
                    f"signature {sig} requires {ndims[idx]})")
                raise errors.TypingError(msg)

        _handle_inputs_result = npydecl.Numpy_rules_ufunc._handle_inputs(
            ufunc, argtys, kws)
        ewise_types, _, _, _ = _handle_inputs_result
        sig, _ = self.find_ewise_function(ewise_types)

        if sig is None:
            # Matching element-wise signature was not found; must
            # compile.
            if self._frozen:
                msg = f"cannot call {self} with types {argtys}"
                raise errors.TypingError(msg)
            self._compile_for_argtys(ewise_types)
            # double check to ensure there is a match
            sig, _ = self.find_ewise_function(ewise_types)
            # NOTE(review): 'sig' is the first element of the unpacked
            # pair here, so comparing it to the pair (None, None) looks
            # like it can never be true — confirm intent.
            if sig == (None, None):
                msg = f"Fail to compile {self.__name__} with types {argtys}"
                raise errors.TypingError(msg)

        assert sig is not None

        # A gufunc writes into its output operand(s); the call itself
        # returns none.
        return signature(types.none, *argtys)
|
||||
|
||||
    def _compile_for_argtys(self, argtys, return_type=None):
        """Register a new element-wise specialization for *argtys*.

        Note: *return_type* is currently unused here.
        """
        # Compile a new guvectorize function! Use the gufunc signature
        # i.e. (n,m),(m)->(n)
        # plus ewise_types to build a numba function type
        fnty = self._get_function_type(*argtys)
        self.gufunc_builder.add(fnty)
|
||||
|
||||
def match_signature(self, ewise_types, sig):
|
||||
dtypes = self._get_ewise_dtypes(sig.args)
|
||||
return tuple(dtypes) == tuple(ewise_types)
|
||||
|
||||
    @property
    def is_dynamic(self):
        # True when new signatures may still be added at call time.
        return self._is_dynamic
|
||||
|
||||
def _get_ewise_dtypes(self, args):
    """Return the element-wise numba type for each entry of *args*.

    Each argument is first coerced to a numba type (values that are not
    already ``types.Type`` instances go through ``typeof``); for array
    types the element dtype is taken, scalars pass through unchanged.
    """
    dtypes = []
    for arg in args:
        ty = arg if isinstance(arg, types.Type) else typeof(arg)
        dtypes.append(ty.dtype if isinstance(ty, types.Array) else ty)
    return dtypes
||||
def _num_args_match(self, *args):
    """True when *args* supplies every input AND output of the signature."""
    inputs, outputs = parse_signature(self.gufunc_builder.signature)
    expected = len(inputs) + len(outputs)
    return len(args) == expected
||||
def _get_function_type(self, *args):
    """Build the numba function type used to compile a gufunc kernel.

    Combines the parsed gufunc core signature with the element-wise
    dtypes of *args*: zero-dimensional input cores stay scalars, every
    other operand becomes a ``types.Array`` of the core's rank.  Output
    cores are always arrays; a scalar output is promoted to a 1-d array
    (small hack so scalars can be returned through an array argument).
    """
    input_dims, output_dims = parse_signature(self.gufunc_builder.signature)
    # one element-wise dtype per argument, inputs first then outputs
    ewise_types = self._get_ewise_dtypes(args)

    argtypes = []
    for idx, core in enumerate(input_dims):
        rank = len(core)
        if rank == 0:
            # scalar input: pass the dtype straight through
            argtypes.append(ewise_types[idx])
        else:
            argtypes.append(types.Array(ewise_types[idx], rank, 'A'))

    offset = len(input_dims)
    for idx, core in enumerate(output_dims):
        retty = ewise_types[idx + offset]
        # promote scalar outputs to 1-d arrays
        ret_rank = len(core) or 1
        argtypes.append(types.Array(retty, ret_rank, 'A'))

    return types.none(*argtypes)
||||
def __call__(self, *args, **kwargs):
    """Invoke the gufunc.

    For frozen or non-dynamic gufuncs this simply dispatches to the
    underlying NumPy ufunc (or to the wrapper's ``__array_ufunc__`` for
    dask/xarray-style array wrappers).  For dynamic gufuncs, a kernel for
    the call's element-wise dtypes is compiled and built on demand.
    """
    # If compilation is disabled OR it is NOT a dynamic gufunc
    # call the underlying gufunc
    if self._frozen or not self.is_dynamic:
        # Do not unwrap the ufunc if the argument is a wrapper that will
        # potentially pickle the ufunc after it receives it in
        # __array_ufunc__. The same logic in theory should be replicated
        # for reduce(), outer(), etc., but they're not implemented in dask.
        if args and _is_array_wrapper(args[0]):
            return args[0].__array_ufunc__(
                self, "__call__", *args, **kwargs
            )
        else:
            return self.ufunc(*args, **kwargs)
    elif "out" in kwargs:
        # If "out" argument is supplied, fold it into the positional
        # arguments: dynamic gufuncs require all operands positionally.
        args += (kwargs.pop("out"),)

    if self._num_args_match(*args) is False:
        # It is not allowed to call a dynamic gufunc without
        # providing all the arguments
        # see: https://github.com/numba/numba/pull/5938#discussion_r506429392 # noqa: E501
        msg = (
            f"Too few arguments for function '{self.__name__}'. "
            "Note that the pattern `out = gufunc(Arg1, Arg2, ..., ArgN)` "
            "is not allowed. Use `gufunc(Arg1, Arg2, ..., ArgN, out) "
            "instead.")
        raise TypeError(msg)

    # at this point we know the gufunc is a dynamic one
    ewise = self._get_ewise_dtypes(args)
    if not (self.ufunc and ufunc_find_matching_loop(self.ufunc, ewise)):
        # A previous call (@njit -> @guvectorize) may have compiled a
        # version for the element-wise dtypes. In this case, we don't need
        # to compile it again, just build the (g)ufunc
        if not self.find_ewise_function(ewise) != (None, None):
            # i.e. find_ewise_function(ewise) == (None, None): no
            # previously-compiled element-wise function exists, so
            # compile one for this call's types.
            sig = self._get_function_type(*args)
            self.add(sig)
        self.build_ufunc()

    return self.ufunc(*args, **kwargs)
||||
def _is_array_wrapper(obj):
|
||||
"""Return True if obj wraps around numpy or another numpy-like library
|
||||
and is likely going to apply the ufunc to the wrapped array; False
|
||||
otherwise.
|
||||
|
||||
At the moment, this returns True for
|
||||
|
||||
- dask.array.Array
|
||||
- dask.dataframe.DataFrame
|
||||
- dask.dataframe.Series
|
||||
- xarray.DataArray
|
||||
- xarray.Dataset
|
||||
- xarray.Variable
|
||||
- pint.Quantity
|
||||
- other potential wrappers around dask array or dask dataframe
|
||||
|
||||
We may need to add other libraries that pickle ufuncs from their
|
||||
__array_ufunc__ method in the future.
|
||||
|
||||
Note that the below test is a lot more naive than
|
||||
`dask.base.is_dask_collection`
|
||||
(https://github.com/dask/dask/blob/5949e54bc04158d215814586a44d51e0eb4a964d/dask/base.py#L209-L249), # noqa: E501
|
||||
because it doesn't need to find out if we're actually dealing with
|
||||
a dask collection, only that we're dealing with a wrapper.
|
||||
Namely, it will return True for a pint.Quantity wrapping around a plain float, a
|
||||
numpy.ndarray, or a dask.array.Array, and it's OK because in all cases
|
||||
Quantity.__array_ufunc__ is going to forward the ufunc call inwards.
|
||||
"""
|
||||
return (
|
||||
not isinstance(obj, type)
|
||||
and hasattr(obj, "__dask_graph__")
|
||||
and hasattr(obj, "__array_ufunc__")
|
||||
)
|
||||
Binary file not shown.
@@ -0,0 +1,761 @@
|
||||
"""
|
||||
This file implements the code-generator for parallel-vectorize.
|
||||
|
||||
ParallelUFunc is the platform independent base class for generating
|
||||
the thread dispatcher. This thread dispatcher launches threads
|
||||
that execute the generated function of UFuncCore.
|
||||
UFuncCore is subclassed to specialize for the input/output types.
|
||||
The actual workload is invoked inside the function generated by UFuncCore.
|
||||
UFuncCore also defines a work-stealing mechanism that allows idle threads
|
||||
to steal work from other threads.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from threading import RLock as threadRLock
|
||||
from ctypes import CFUNCTYPE, c_int, CDLL, POINTER, c_uint
|
||||
|
||||
import numpy as np
|
||||
|
||||
import llvmlite.binding as ll
|
||||
from llvmlite import ir
|
||||
|
||||
from numba.np.numpy_support import as_dtype
|
||||
from numba.core import types, cgutils, config, errors
|
||||
from numba.core.typing import signature
|
||||
from numba.np.ufunc.wrappers import _wrapper_info
|
||||
from numba.np.ufunc import ufuncbuilder
|
||||
from numba.extending import overload, intrinsic
|
||||
|
||||
_IS_OSX = sys.platform.startswith('darwin')
|
||||
_IS_LINUX = sys.platform.startswith('linux')
|
||||
_IS_WINDOWS = sys.platform.startswith('win32')
|
||||
|
||||
|
||||
def get_thread_count():
    """Return the configured worker-thread count.

    Reads ``config.NUMBA_NUM_THREADS`` and validates that it is positive.
    """
    count = config.NUMBA_NUM_THREADS
    if count < 1:
        raise ValueError("Number of threads specified must be > 0.")
    return count
|
||||
|
||||
# Total number of worker threads the backends will launch; fixed at import
# time from the validated configuration value.
NUM_THREADS = get_thread_count()
|
||||
def build_gufunc_kernel(library, ctx, info, sig, inner_ndim):
    """Wrap the original CPU ufunc/gufunc with a parallel dispatcher.
    This function wraps gufuncs and ufuncs in the same way.

    Args
    ----
    library
        the codegen library holding the inner (serial) kernel

    ctx
        numba's codegen context

    info: (library, env, name)
        inner function info

    sig
        type signature of the gufunc

    inner_ndim
        inner dimension of the gufunc (this is len(sig.args) in the case of a
        ufunc)

    Returns
    -------
    wrapper_info : (library, env, name)
        The info for the gufunc wrapper.

    Details
    -------

    The kernel signature looks like this:

    void kernel(char **args, npy_intp *dimensions, npy_intp* steps, void* data)

    args - the input arrays + output arrays
    dimensions - the dimensions of the arrays
    steps - the step size for the array (this is like sizeof(type))
    data - any additional data

    The parallel backend then stages multiple calls to this kernel concurrently
    across a number of threads. Practically, for each item of work, the backend
    duplicates `dimensions` and adjusts the first entry to reflect the size of
    the item of work, it also forms up an array of pointers into the args for
    offsets to read/write from/to with respect to its position in the items of
    work. This allows the same kernel to be used for each item of work, with
    simply adjusted reads/writes/domain sizes and is safe by virtue of the
    domain partitioning.

    NOTE: The execution backend is passed the requested thread count, but it can
    choose to ignore it (TBB)!
    """
    assert isinstance(info, tuple)  # guard against old usage
    # Declare types and function
    byte_t = ir.IntType(8)
    byte_ptr_t = ir.PointerType(byte_t)
    byte_ptr_ptr_t = ir.PointerType(byte_ptr_t)

    intp_t = ctx.get_value_type(types.intp)
    intp_ptr_t = ir.PointerType(intp_t)

    # The NumPy ufunc kernel prototype: (char**, intp*, intp*, char*) -> void
    fnty = ir.FunctionType(ir.VoidType(), [ir.PointerType(byte_ptr_t),
                                           ir.PointerType(intp_t),
                                           ir.PointerType(intp_t),
                                           byte_ptr_t])
    wrapperlib = ctx.codegen().create_library('parallelgufuncwrapper')
    mod = wrapperlib.create_ir_module('parallel.gufunc.wrapper')
    # id(info.env) disambiguates wrappers sharing the same inner-function name
    kernel_name = ".kernel.{}_{}".format(id(info.env), info.name)
    lfunc = ir.Function(mod, fnty, name=kernel_name)

    bb_entry = lfunc.append_basic_block('')

    # Function body starts
    builder = ir.IRBuilder(bb_entry)

    args, dimensions, steps, data = lfunc.args

    # Release the GIL (and ensure we have the GIL)
    # Note: numpy ufunc may not always release the GIL; thus,
    # we need to ensure we have the GIL.
    pyapi = ctx.get_python_api(builder)
    gil_state = pyapi.gil_ensure()
    thread_state = pyapi.save_thread()

    def as_void_ptr(arg):
        return builder.bitcast(arg, byte_ptr_t)

    # Array count depends on whether an "output" array is needed. In the case
    # of a void return type cf. gufunc it is the number of args, in the case of
    # a non-void return type cf. ufunc it is the number of args + 1 so as to
    # account for the output array.
    array_count = len(sig.args)
    if not isinstance(sig.return_type, types.NoneType):
        array_count += 1

    parallel_for_ty = ir.FunctionType(ir.VoidType(),
                                      [byte_ptr_t] * 5 + [intp_t, ] * 3)
    parallel_for = cgutils.get_or_insert_function(mod, parallel_for_ty,
                                                  'numba_parallel_for')

    # Reference inner-function and link
    innerfunc_fnty = ir.FunctionType(
        ir.VoidType(),
        [byte_ptr_ptr_t, intp_ptr_t, intp_ptr_t, byte_ptr_t],
    )
    tmp_voidptr = cgutils.get_or_insert_function(mod, innerfunc_fnty,
                                                 info.name,)
    wrapperlib.add_linking_library(info.library)

    # Resolve the native get_num_threads symbol so the dispatcher honours the
    # current thread mask at call time.
    get_num_threads = cgutils.get_or_insert_function(
        builder.module,
        ir.FunctionType(ir.IntType(types.intp.bitwidth), []),
        "get_num_threads")

    num_threads = builder.call(get_num_threads, [])

    # Prepare call
    fnptr = builder.bitcast(tmp_voidptr, byte_ptr_t)
    innerargs = [as_void_ptr(x) for x
                 in [args, dimensions, steps, data]]
    builder.call(parallel_for, [fnptr] + innerargs +
                 [intp_t(x) for x in (inner_ndim, array_count)] + [num_threads])

    # Release the GIL
    pyapi.restore_thread(thread_state)
    pyapi.gil_release(gil_state)

    builder.ret_void()

    wrapperlib.add_ir_module(mod)
    wrapperlib.add_linking_library(library)
    return _wrapper_info(library=wrapperlib, name=lfunc.name, env=info.env)
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
class ParallelUFuncBuilder(ufuncbuilder.UFuncBuilder):
    """UFunc builder that wraps compiled kernels in the parallel dispatcher."""

    def build(self, cres, sig):
        # NOTE: *sig* is unused; kept for interface compatibility.
        _launch_threads()

        # Wrap the compiled ufunc entry point with the thread dispatcher.
        info = build_ufunc_wrapper(cres.library, cres.target_context,
                                   cres.fndesc.llvm_func_name,
                                   cres.signature, cres)
        ptr = info.library.get_pointer_to_function(info.name)

        # Translate the numba argument/return types into NumPy dtype numbers.
        dtypenums = [np.dtype(t.name).num for t in cres.signature.args]
        dtypenums.append(np.dtype(cres.signature.return_type.name).num)
        keepalive = ()
        return dtypenums, ptr, keepalive
|
||||
|
||||
def build_ufunc_wrapper(library, ctx, fname, signature, cres):
    """Build the parallel wrapper for a scalar ufunc kernel.

    First builds the serial NumPy-style wrapper, then wraps it in the
    parallel dispatcher; for a ufunc the inner dimension count is simply
    the number of arguments.
    """
    inner = ufuncbuilder.build_ufunc_wrapper(library, ctx, fname, signature,
                                             objmode=False, cres=cres)
    return build_gufunc_kernel(library, ctx, inner, signature,
                               len(signature.args))
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ParallelGUFuncBuilder(ufuncbuilder.GUFuncBuilder):
    """GUFunc builder that targets the parallel thread dispatcher."""

    def __init__(self, py_func, signature, identity=None, cache=False,
                 targetoptions=None, writable_args=()):
        # Parallel gufuncs are always compiled in nopython mode.
        if targetoptions is None:
            targetoptions = {}
        targetoptions.update(nopython=True)
        super().__init__(py_func=py_func,
                         signature=signature,
                         identity=identity,
                         cache=cache,
                         targetoptions=targetoptions,
                         writable_args=writable_args)

    def build(self, cres):
        """
        Returns (dtype numbers, function ptr, EnvironmentObject)
        """
        _launch_threads()

        # Wrap the gufunc entry point with the parallel dispatcher.
        info = build_gufunc_wrapper(
            self.py_func, cres, self.sin, self.sout, cache=self.cache,
            is_parfors=False,
        )
        ptr = info.library.get_pointer_to_function(info.name)

        # Element dtypes: arrays contribute their element type, scalars
        # contribute themselves.
        dtypenums = [
            as_dtype(a.dtype if isinstance(a, types.Array) else a).num
            for a in cres.signature.args
        ]
        return dtypenums, ptr, info.env
|
||||
|
||||
# This is not a member of the ParallelGUFuncBuilder function because it is
|
||||
# called without an enclosing instance from parfors
|
||||
|
||||
def build_gufunc_wrapper(py_func, cres, sin, sout, cache, is_parfors):
    """Build gufunc wrapper for the given arguments.

    The *is_parfors* is a boolean indicating whether the gufunc is being
    built for use as a ParFors kernel. This changes codegen and caching
    behavior.
    """
    innerinfo = ufuncbuilder.build_gufunc_wrapper(
        py_func, cres, sin, sout, cache=cache, is_parfors=is_parfors,
    )
    # The inner dimension count is the number of distinct core-dimension
    # symbols across inputs and outputs.
    symbols = {sym for term in sin for sym in term}
    symbols |= {sym for term in sout for sym in term}

    return build_gufunc_kernel(
        cres.library, cres.target_context, innerinfo, cres.signature,
        len(symbols),
    )
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Serializes threading-backend initialization across threads of this process.
_backend_init_thread_lock = threadRLock()

# Cached platform flag (same check as _IS_WINDOWS above).
_windows = sys.platform.startswith('win32')
||||
class _nop(object):
|
||||
"""A no-op contextmanager
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
# Cross-process init lock; created lazily by _set_init_process_lock() the
# first time _launch_threads() runs.
_backend_init_process_lock = None
||||
def _set_init_process_lock():
    """Create the cross-process initialization lock.

    Sets the module-global ``_backend_init_process_lock`` to either a
    multiprocessing RLock or, when a real lock cannot or need not be
    created, a no-op context manager (``_nop``).
    """
    global _backend_init_process_lock
    try:
        # Force the use of an RLock in the case a fork was used to start the
        # process and thereby the init sequence, some of the threading backend
        # init sequences are not fork safe. Also, windows global mp locks seem
        # to be fine.
        with _backend_init_thread_lock:  # protect part-initialized module access
            import multiprocessing
            if "fork" in multiprocessing.get_start_method() or _windows:
                ctx = multiprocessing.get_context()
                _backend_init_process_lock = ctx.RLock()
            else:
                _backend_init_process_lock = _nop()

    except OSError as e:

        # probably lack of /dev/shm for semaphore writes, warn the user
        msg = (
            "Could not obtain multiprocessing lock due to OS level error: %s\n"
            "A likely cause of this problem is '/dev/shm' is missing or "
            "read-only such that necessary semaphores cannot be written.\n"
            "*** The responsibility of ensuring multiprocessing safe access to "
            "this initialization sequence/module import is deferred to the "
            "user! ***\n"
        )
        warnings.warn(msg % str(e), errors.NumbaSystemWarning)

        # Fall back to a no-op lock: init proceeds without cross-process safety.
        _backend_init_process_lock = _nop()
|
||||
|
||||
# Set to True once a threading backend has been loaded and launched.
_is_initialized = False

# this is set by _launch_threads (name of the selected backend)
_threading_layer = None
||||
def threading_layer():
    """
    Get the name of the threading layer in use for parallel CPU targets
    """
    # _threading_layer is assigned by _launch_threads(); None means the
    # backend selection has not run yet.
    if _threading_layer is not None:
        return _threading_layer
    raise ValueError("Threading layer is not initialized.")
|
||||
|
||||
def _check_tbb_version_compatible():
    """
    Checks that if TBB is present it is of a compatible version.

    Raises ImportError when TBB is absent, unloadable, or too old; returns
    silently when a compatible TBB is found.
    """
    try:
        # first check that the TBB version is new enough
        if _IS_WINDOWS:
            libtbb_name = 'tbb12.dll'
        elif _IS_OSX:
            libtbb_name = 'libtbb.12.dylib'
        elif _IS_LINUX:
            libtbb_name = 'libtbb.so.12'
        else:
            raise ValueError("Unknown operating system")
        libtbb = CDLL(libtbb_name)
        version_func = libtbb.TBB_runtime_interface_version
        version_func.argtypes = []
        version_func.restype = c_int
        tbb_iface_ver = version_func()
        if tbb_iface_ver < 12060:  # magic number from TBB
            msg = ("The TBB threading layer requires TBB "
                   "version 2021 update 6 or later i.e., "
                   "TBB_INTERFACE_VERSION >= 12060. Found "
                   "TBB_INTERFACE_VERSION = %s. The TBB "
                   "threading layer is disabled.") % tbb_iface_ver
            problem = errors.NumbaWarning(msg)
            warnings.warn(problem)
            # ImportError is NOT caught by the except clause below (it only
            # catches ValueError/OSError), so this propagates to the caller.
            raise ImportError("Problem with TBB. Reason: %s" % msg)
    except (ValueError, OSError) as e:
        # Translate as an ImportError for consistent error class use, this error
        # will never materialise
        raise ImportError("Problem with TBB. Reason: %s" % e)
|
||||
|
||||
def _launch_threads():
    """Select, load and start a threading-layer backend (idempotent).

    Reads ``config.THREADING_LAYER`` / ``config.THREADING_LAYER_PRIORITY``,
    loads the corresponding native pool module (tbb/omp/workqueue),
    registers its symbols with llvmlite, launches NUM_THREADS workers and
    records the chosen backend name in the module-global
    ``_threading_layer``.  Safe to call repeatedly; guarded by both a
    process-level and a thread-level lock.
    """
    if not _backend_init_process_lock:
        _set_init_process_lock()

    with _backend_init_process_lock:
        with _backend_init_thread_lock:
            global _is_initialized
            if _is_initialized:
                return

            def select_known_backend(backend):
                """
                Loads a specific threading layer backend based on string
                """
                lib = None
                if backend.startswith("tbb"):
                    try:
                        # check if TBB is present and compatible
                        _check_tbb_version_compatible()
                        # now try and load the backend
                        from numba.np.ufunc import tbbpool as lib
                    except ImportError:
                        pass
                elif backend.startswith("omp"):
                    # TODO: Check that if MKL is present that it is a version
                    # that understands GNU OMP might be present
                    try:
                        from numba.np.ufunc import omppool as lib
                    except ImportError:
                        pass
                elif backend.startswith("workqueue"):
                    from numba.np.ufunc import workqueue as lib
                else:
                    msg = "Unknown value specified for threading layer: %s"
                    raise ValueError(msg % backend)
                return lib

            def select_from_backends(backends):
                """
                Selects from presented backends and returns the first working
                """
                lib = None
                for backend in backends:
                    lib = select_known_backend(backend)
                    if lib is not None:
                        break
                else:
                    # no backend loaded; report an empty name
                    backend = ''
                return lib, backend

            t = str(config.THREADING_LAYER).lower()
            namedbackends = config.THREADING_LAYER_PRIORITY
            # The priority list must be exactly a permutation of the three
            # known backends.
            if not (len(namedbackends) == 3 and
                    set(namedbackends) == {'tbb', 'omp', 'workqueue'}):
                raise ValueError(
                    "THREADING_LAYER_PRIORITY invalid: %s. "
                    "It must be a permutation of "
                    "{'tbb', 'omp', 'workqueue'}"
                    % namedbackends
                )

            lib = None
            err_helpers = dict()
            err_helpers['TBB'] = ("Intel TBB is required, try:\n"
                                  "$ conda/pip install tbb")
            err_helpers['OSX_OMP'] = ("Intel OpenMP is required, try:\n"
                                      "$ conda/pip install intel-openmp")
            requirements = []

            def raise_with_hint(required):
                errmsg = "No threading layer could be loaded.\n%s"
                hintmsg = "HINT:\n%s"
                if len(required) == 0:
                    hint = ''
                if len(required) == 1:
                    hint = hintmsg % err_helpers[required[0]]
                if len(required) > 1:
                    options = '\nOR\n'.join([err_helpers[x] for x in required])
                    hint = hintmsg % ("One of:\n%s" % options)
                raise ValueError(errmsg % hint)

            if t in namedbackends:
                # Try and load the specific named backend
                lib = select_known_backend(t)
                if not lib:
                    # something is missing preventing a valid backend from
                    # loading, set requirements for hinting
                    if t == 'tbb':
                        requirements.append('TBB')
                    elif t == 'omp' and _IS_OSX:
                        requirements.append('OSX_OMP')
                libname = t
            elif t in ['threadsafe', 'forksafe', 'safe']:
                # User wants a specific behaviour...
                available = ['tbb']
                requirements.append('TBB')
                if t == "safe":
                    # "safe" is TBB, which is fork and threadsafe everywhere
                    pass
                elif t == "threadsafe":
                    if _IS_OSX:
                        requirements.append('OSX_OMP')
                    # omp is threadsafe everywhere
                    available.append('omp')
                elif t == "forksafe":
                    # everywhere apart from linux (GNU OpenMP) has a guaranteed
                    # forksafe OpenMP, as OpenMP has better performance, prefer
                    # this to workqueue
                    if not _IS_LINUX:
                        available.append('omp')
                        if _IS_OSX:
                            requirements.append('OSX_OMP')
                    # workqueue is forksafe everywhere
                    available.append('workqueue')
                else:  # unreachable
                    msg = "No threading layer available for purpose %s"
                    raise ValueError(msg % t)
                # select amongst available
                lib, libname = select_from_backends(available)
            elif t == 'default':
                # If default is supplied, try them in order, tbb, omp,
                # workqueue
                lib, libname = select_from_backends(namedbackends)
                if not lib:
                    # set requirements for hinting
                    requirements.append('TBB')
                    if _IS_OSX:
                        requirements.append('OSX_OMP')
            else:
                msg = "The threading layer requested '%s' is unknown to Numba."
                raise ValueError(msg % t)

            # No lib found, raise and hint
            if not lib:
                raise_with_hint(requirements)

            # Register the backend's native entry points with llvmlite so
            # generated code can call them by name.
            ll.add_symbol('numba_parallel_for', lib.parallel_for)
            ll.add_symbol('do_scheduling_signed', lib.do_scheduling_signed)
            ll.add_symbol('do_scheduling_unsigned', lib.do_scheduling_unsigned)
            ll.add_symbol('allocate_sched', lib.allocate_sched)
            ll.add_symbol('deallocate_sched', lib.deallocate_sched)

            launch_threads = CFUNCTYPE(None, c_int)(lib.launch_threads)
            launch_threads(NUM_THREADS)

            _load_threading_functions(lib)  # load late

            # set library name so it can be queried
            global _threading_layer
            _threading_layer = libname
            _is_initialized = True
|
||||
|
||||
def _load_threading_functions(lib):
    """Bind the backend's thread-control entry points.

    Registers the native symbols with llvmlite (for jitted code) and
    creates ctypes callables for the pure-Python API, storing them in
    module globals (``_set_num_threads``, ``_get_num_threads``, etc.).
    """
    ll.add_symbol('get_num_threads', lib.get_num_threads)
    ll.add_symbol('set_num_threads', lib.set_num_threads)
    ll.add_symbol('get_thread_id', lib.get_thread_id)

    global _set_num_threads
    _set_num_threads = CFUNCTYPE(None, c_int)(lib.set_num_threads)
    # Initialise the thread mask to the full launched thread count.
    _set_num_threads(NUM_THREADS)

    global _get_num_threads
    _get_num_threads = CFUNCTYPE(c_int)(lib.get_num_threads)

    global _get_thread_id
    _get_thread_id = CFUNCTYPE(c_int)(lib.get_thread_id)

    ll.add_symbol('set_parallel_chunksize', lib.set_parallel_chunksize)
    ll.add_symbol('get_parallel_chunksize', lib.get_parallel_chunksize)
    ll.add_symbol('get_sched_size', lib.get_sched_size)
    global _set_parallel_chunksize
    _set_parallel_chunksize = CFUNCTYPE(c_uint,
                                        c_uint)(lib.set_parallel_chunksize)
    global _get_parallel_chunksize
    _get_parallel_chunksize = CFUNCTYPE(c_uint)(lib.get_parallel_chunksize)
    global _get_sched_size
    _get_sched_size = CFUNCTYPE(c_uint,
                                c_uint,
                                c_uint,
                                POINTER(c_int),
                                POINTER(c_int))(lib.get_sched_size)
|
||||
|
||||
# Some helpers to make set_num_threads jittable
|
||||
|
||||
def gen_snt_check():
    """Create the validator used by set_num_threads.

    The error message is pre-built so the returned closure is also usable
    from jitted code via the overload below.
    """
    from numba.core.config import NUMBA_NUM_THREADS
    msg = "The number of threads must be between 1 and %s" % NUMBA_NUM_THREADS

    def snt_check(n):
        # valid range is the closed interval [1, NUMBA_NUM_THREADS]
        if n < 1 or n > NUMBA_NUM_THREADS:
            raise ValueError(msg)
    return snt_check
|
||||
|
||||
# Concrete range-checker bound at import time; also registered as its own
# JIT implementation via the @overload below.
snt_check = gen_snt_check()
|
||||
@overload(snt_check)
def ol_snt_check(n):
    # JIT overload: the pure-Python checker doubles as its own implementation.
    return snt_check
|
||||
|
||||
def set_num_threads(n):
    """
    Set the number of threads to use for parallel execution.

    By default, all :obj:`numba.config.NUMBA_NUM_THREADS` threads are used.

    This functionality works by masking out threads that are not used.
    Therefore, the number of threads *n* must be less than or equal to
    :obj:`~.NUMBA_NUM_THREADS`, the total number of threads that are launched.
    See its documentation for more details.

    This function can be used inside of a jitted function.

    Parameters
    ----------
    n: The number of threads. Must be between 1 and NUMBA_NUM_THREADS.

    See Also
    --------
    get_num_threads, numba.config.NUMBA_NUM_THREADS,
    numba.config.NUMBA_DEFAULT_NUM_THREADS, :envvar:`NUMBA_NUM_THREADS`

    """
    # Make sure the backend is up before touching the thread mask.
    _launch_threads()
    if not isinstance(n, (int, np.integer)):
        raise TypeError("The number of threads specified must be an integer")
    # snt_check raises ValueError when n is outside [1, NUMBA_NUM_THREADS].
    snt_check(n)
    _set_num_threads(n)
|
||||
|
||||
@overload(set_num_threads)
def ol_set_num_threads(n):
    # JIT counterpart of set_num_threads: type-check at compile time, then
    # range-check and forward to the native setter at run time.
    _launch_threads()
    if not isinstance(n, types.Integer):
        msg = "The number of threads specified must be an integer"
        raise errors.TypingError(msg)

    def impl(n):
        snt_check(n)
        _set_num_threads(n)
    return impl
|
||||
|
||||
def get_num_threads():
    """
    Get the number of threads used for parallel execution.

    By default (if :func:`~.set_num_threads` is never called), all
    :obj:`numba.config.NUMBA_NUM_THREADS` threads are used.

    This number is less than or equal to the total number of threads that are
    launched, :obj:`numba.config.NUMBA_NUM_THREADS`.

    This function can be used inside of a jitted function.

    Returns
    -------
    The number of threads.

    See Also
    --------
    set_num_threads, numba.config.NUMBA_NUM_THREADS,
    numba.config.NUMBA_DEFAULT_NUM_THREADS, :envvar:`NUMBA_NUM_THREADS`

    """
    _launch_threads()
    num_threads = _get_num_threads()
    # A non-positive thread mask should be impossible after initialization;
    # treat it as an internal error.
    if num_threads <= 0:
        raise RuntimeError("Invalid number of threads. "
                           "This likely indicates a bug in Numba. "
                           "(thread_id=%s, num_threads=%s)" %
                           (get_thread_id(), num_threads))
    return num_threads
|
||||
|
||||
@overload(get_num_threads)
def ol_get_num_threads():
    # JIT counterpart of get_num_threads.
    _launch_threads()

    def impl():
        num_threads = _get_num_threads()
        if num_threads <= 0:
            # print() is used because string formatting inside the raised
            # exception is limited in nopython mode.
            print("Broken thread_id: ", get_thread_id())
            print("num_threads: ", num_threads)
            raise RuntimeError("Invalid number of threads. "
                               "This likely indicates a bug in Numba.")
        return num_threads
    return impl
|
||||
|
||||
@intrinsic
def _iget_num_threads(typingctx):
    # Intrinsic: emit a direct call to the native ``get_num_threads``
    # symbol registered by _load_threading_functions, returning intp.
    _launch_threads()

    def codegen(context, builder, signature, args):
        mod = builder.module
        fnty = ir.FunctionType(cgutils.intp_t, [])
        fn = cgutils.get_or_insert_function(mod, fnty, "get_num_threads")
        return builder.call(fn, [])
    return signature(types.intp), codegen
|
||||
|
||||
def get_thread_id():
    """
    Returns a unique ID for each thread in the range 0 (inclusive)
    to :func:`~.get_num_threads` (exclusive).
    """
    # Called from the interpreter directly, this should return 0
    # Called from a sequential JIT region, this should return 0
    # Called from a parallel JIT region, this should return 0..N
    # Called from objmode in a parallel JIT region, this should return 0..N
    _launch_threads()
    return _get_thread_id()
|
||||
|
||||
@overload(get_thread_id)
def ol_get_thread_id():
    # JIT counterpart of get_thread_id: delegate to the intrinsic so the
    # native symbol is called directly from compiled code.
    _launch_threads()

    def impl():
        return _iget_thread_id()
    return impl
|
||||
|
||||
@intrinsic
def _iget_thread_id(typingctx):
    # Intrinsic: emit a direct call to the native ``get_thread_id`` symbol,
    # returning intp.
    def codegen(context, builder, signature, args):
        mod = builder.module
        fnty = ir.FunctionType(cgutils.intp_t, [])
        fn = cgutils.get_or_insert_function(mod, fnty, "get_thread_id")
        return builder.call(fn, [])
    return signature(types.intp), codegen
|
||||
|
||||
# Environment-variable escape hatch: when NUMBA_DYLD_WORKAROUND is set to a
# non-zero value, start the threading layer eagerly at import time
# (presumably to work around dynamic-loader issues — confirm against docs).
_DYLD_WORKAROUND_SET = 'NUMBA_DYLD_WORKAROUND' in os.environ
_DYLD_WORKAROUND_VAL = int(os.environ.get('NUMBA_DYLD_WORKAROUND', 0))

if _DYLD_WORKAROUND_SET and _DYLD_WORKAROUND_VAL:
    _launch_threads()
|
||||
|
||||
def set_parallel_chunksize(n):
    """Set the parallel scheduling chunksize to *n* (0 disables chunking).

    Returns the value reported by the native ``set_parallel_chunksize``
    call (presumably the previously configured chunksize — confirm).
    """
    _launch_threads()
    if not isinstance(n, (int, np.integer)):
        raise TypeError("The parallel chunksize must be an integer")
    if n < 0:
        raise ValueError("chunksize must be greater than or equal to zero")
    return _set_parallel_chunksize(n)
|
||||
|
||||
def get_parallel_chunksize():
    """Return the current parallel scheduling chunksize."""
    _launch_threads()
    return _get_parallel_chunksize()
|
||||
|
||||
@overload(set_parallel_chunksize)
def ol_set_parallel_chunksize(n):
    # JIT counterpart of set_parallel_chunksize: integer-type check at
    # compile time, range check at run time.
    _launch_threads()
    if not isinstance(n, types.Integer):
        msg = "The parallel chunksize must be an integer"
        raise errors.TypingError(msg)

    def impl(n):
        if n < 0:
            raise ValueError("chunksize must be greater than or equal to zero")
        return _set_parallel_chunksize(n)
    return impl
|
||||
|
||||
@overload(get_parallel_chunksize)
def ol_get_parallel_chunksize():
    # JIT counterpart of get_parallel_chunksize.
    _launch_threads()

    def impl():
        return _get_parallel_chunksize()
    return impl
||||
@@ -0,0 +1,63 @@
|
||||
import tokenize
|
||||
import string
|
||||
|
||||
|
||||
def parse_signature(sig):
|
||||
'''Parse generalized ufunc signature.
|
||||
|
||||
NOTE: ',' (COMMA) is a delimiter; not separator.
|
||||
This means trailing comma is legal.
|
||||
'''
|
||||
def stripws(s):
|
||||
return ''.join(c for c in s if c not in string.whitespace)
|
||||
|
||||
def tokenizer(src):
|
||||
def readline():
|
||||
yield src
|
||||
gen = readline()
|
||||
return tokenize.generate_tokens(lambda: next(gen))
|
||||
|
||||
def parse(src):
|
||||
tokgen = tokenizer(src)
|
||||
while True:
|
||||
tok = next(tokgen)
|
||||
if tok[1] == '(':
|
||||
symbols = []
|
||||
while True:
|
||||
tok = next(tokgen)
|
||||
if tok[1] == ')':
|
||||
break
|
||||
elif tok[0] == tokenize.NAME:
|
||||
symbols.append(tok[1])
|
||||
elif tok[1] == ',':
|
||||
continue
|
||||
else:
|
||||
raise ValueError('bad token in signature "%s"' % tok[1])
|
||||
yield tuple(symbols)
|
||||
tok = next(tokgen)
|
||||
if tok[1] == ',':
|
||||
continue
|
||||
elif tokenize.ISEOF(tok[0]):
|
||||
break
|
||||
elif tokenize.ISEOF(tok[0]):
|
||||
break
|
||||
else:
|
||||
raise ValueError('bad token in signature "%s"' % tok[1])
|
||||
|
||||
ins, _, outs = stripws(sig).partition('->')
|
||||
inputs = list(parse(ins))
|
||||
outputs = list(parse(outs))
|
||||
|
||||
# check that all output symbols are defined in the inputs
|
||||
isym = set()
|
||||
osym = set()
|
||||
for grp in inputs:
|
||||
isym |= set(grp)
|
||||
for grp in outputs:
|
||||
osym |= set(grp)
|
||||
|
||||
diff = osym.difference(isym)
|
||||
if diff:
|
||||
raise NameError('undefined output symbols: %s' % ','.join(sorted(diff)))
|
||||
|
||||
return inputs, outputs
|
||||
Binary file not shown.
@@ -0,0 +1,113 @@
|
||||
from numba.np import numpy_support
|
||||
from numba.core import types
|
||||
|
||||
|
||||
class UfuncLowererBase:
    """Callable object responsible for lowering calls to one specific gufunc."""

    def __init__(self, ufunc, make_kernel_fn, make_ufunc_kernel_fn):
        # The kernel object is built eagerly from the ufunc itself.
        self.ufunc = ufunc
        self.make_ufunc_kernel_fn = make_ufunc_kernel_fn
        self.kernel = make_kernel_fn(ufunc)
        # Code libraries kept alive on behalf of lowered code.
        self.libs = []

    def __call__(self, context, builder, sig, args):
        # Delegate the actual lowering to the kernel factory.
        lower = self.make_ufunc_kernel_fn
        return lower(context, builder, sig, args, self.ufunc, self.kernel)
|
||||
|
||||
|
||||
class UfuncBase:
    """Mixin exposing NumPy-ufunc-like introspection by delegating to the
    underlying ufunc object stored on ``self.ufunc``."""

    @property
    def nin(self):
        # Number of inputs.
        return self.ufunc.nin

    @property
    def nout(self):
        # Number of outputs.
        return self.ufunc.nout

    @property
    def nargs(self):
        # Total number of arguments (inputs + outputs).
        return self.ufunc.nargs

    @property
    def ntypes(self):
        return self.ufunc.ntypes

    @property
    def types(self):
        return self.ufunc.types

    @property
    def identity(self):
        return self.ufunc.identity

    @property
    def signature(self):
        return self.ufunc.signature

    @property
    def accumulate(self):
        return self.ufunc.accumulate

    @property
    def at(self):
        return self.ufunc.at

    @property
    def outer(self):
        return self.ufunc.outer

    @property
    def reduce(self):
        return self.ufunc.reduce

    @property
    def reduceat(self):
        return self.ufunc.reduceat

    def disable_compile(self):
        """
        Disable the compilation of new signatures at call time.
        """
        # Freezing only makes sense once at least one signature exists.
        assert len(self._dispatcher.overloads) > 0
        self._frozen = True

    def _install_cg(self, targetctx=None):
        """
        Install an implementation function for a GUFunc/DUFunc object in the
        given target context.  If no target context is given, then
        _install_cg() installs into the target context of the
        dispatcher object (should be same default context used by
        jit() and njit()).
        """
        if targetctx is None:
            targetctx = self._dispatcher.targetdescr.target_context
        any_t = types.Any
        array_t = types.Array
        # Either all outputs are explicit or none of them are.
        explicit_out = (any_t,) * self.ufunc.nin + (array_t,) * self.ufunc.nout
        implicit_out = (any_t,) * self.ufunc.nin
        defns = [(self._lower_me, self, sig)
                 for sig in (explicit_out, implicit_out)]
        targetctx.insert_func_defn(defns)

    def find_ewise_function(self, ewise_types):
        """
        Given a tuple of element-wise argument types, find a matching
        signature in the dispatcher.

        Return a 2-tuple containing the matching signature, and
        compilation result.  Will return two None's if no matching
        signature was found.
        """
        if self._frozen:
            # If we cannot compile, coerce to the best matching loop
            loop = numpy_support.ufunc_find_matching_loop(self, ewise_types)
            if loop is None:
                return None, None
            ewise_types = tuple(loop.inputs + loop.outputs)[:len(ewise_types)]
        match = next((pair for pair in self._dispatcher.overloads.items()
                      if self.match_signature(ewise_types, pair[0])), None)
        return (None, None) if match is None else match
|
||||
@@ -0,0 +1,444 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
|
||||
from numba.core import config, targetconfig
|
||||
from numba.core.decorators import jit
|
||||
from numba.core.descriptors import TargetDescriptor
|
||||
from numba.core.extending import is_jitted
|
||||
from numba.core.errors import NumbaDeprecationWarning
|
||||
from numba.core.options import TargetOptions, include_default_options
|
||||
from numba.core.registry import cpu_target
|
||||
from numba.core.target_extension import dispatcher_registry, target_registry
|
||||
from numba.core import utils, types, serialize, compiler, sigutils
|
||||
from numba.np.numpy_support import as_dtype
|
||||
from numba.np.ufunc import _internal
|
||||
from numba.np.ufunc.sigparse import parse_signature
|
||||
from numba.np.ufunc.wrappers import build_ufunc_wrapper, build_gufunc_wrapper
|
||||
from numba.core.caching import FunctionCache, NullCache
|
||||
from numba.core.compiler_lock import global_compiler_lock
|
||||
|
||||
|
||||
# Target options accepted by the (g)ufunc compilation pipeline.
_options_mixin = include_default_options(
    "nopython",
    "forceobj",
    "boundscheck",
    "fastmath",
    "writable_args"
)


class UFuncTargetOptions(_options_mixin, TargetOptions):
    """Options handling for the 'ufunc' pseudo-target."""

    def finalize(self, flags, options):
        # (g)ufunc kernels default to permitting object mode with loop
        # lifting so plain-Python kernels can still compile.
        if not flags.is_set("enable_pyobject"):
            flags.enable_pyobject = True

        if not flags.is_set("enable_looplift"):
            flags.enable_looplift = True

        flags.inherit_if_not_set("nrt", default=True)

        if not flags.is_set("debuginfo"):
            flags.debuginfo = config.DEBUGINFO_DEFAULT

        # Bounds checking follows debuginfo unless explicitly requested.
        if not flags.is_set("boundscheck"):
            flags.boundscheck = flags.debuginfo

        flags.enable_pyobject_looplift = True

        flags.inherit_if_not_set("fastmath")
|
||||
|
||||
|
||||
class UFuncTarget(TargetDescriptor):
    """Target descriptor for the 'ufunc' pseudo-target.

    Typing and code-generation contexts are borrowed from the CPU target.
    """
    options = UFuncTargetOptions

    def __init__(self):
        super().__init__('ufunc')

    @property
    def typing_context(self):
        # Shared with the default CPU target.
        return cpu_target.typing_context

    @property
    def target_context(self):
        # Shared with the default CPU target.
        return cpu_target.target_context


# Singleton descriptor used by the ufunc dispatcher below.
ufunc_target = UFuncTarget()
|
||||
|
||||
|
||||
class UFuncDispatcher(serialize.ReduceMixin):
    """
    An object handling compilation of various signatures for a ufunc.
    """
    targetdescr = ufunc_target

    def __init__(self, py_func, locals=None, targetoptions=None):
        # Avoid mutable default arguments.
        if locals is None:
            locals = {}
        if targetoptions is None:
            targetoptions = {}
        self.py_func = py_func
        # Map of signature -> CompileResult; UniqueDict forbids clobbering.
        self.overloads = utils.UniqueDict()
        self.targetoptions = targetoptions
        self.locals = locals
        self.cache = NullCache()

    def _reduce_states(self):
        """
        NOTE: part of ReduceMixin protocol
        """
        return dict(
            pyfunc=self.py_func,
            locals=self.locals,
            targetoptions=self.targetoptions,
        )

    @classmethod
    def _rebuild(cls, pyfunc, locals, targetoptions):
        """
        NOTE: part of ReduceMixin protocol
        """
        return cls(py_func=pyfunc, locals=locals, targetoptions=targetoptions)

    def enable_caching(self):
        """Switch from the no-op cache to an on-disk function cache."""
        self.cache = FunctionCache(self.py_func)

    def compile(self, sig, locals=None, **targetoptions):
        """Compile the wrapped function for signature *sig*.

        Call-site *locals*/*targetoptions* are merged over the values given
        at construction time.
        """
        if locals is None:
            locals = {}
        # Merge instance-level locals with the call-site ones.
        locs = self.locals.copy()
        locs.update(locals)

        topt = self.targetoptions.copy()
        topt.update(targetoptions)

        flags = compiler.Flags()
        self.targetdescr.options.parse_as_flags(flags, topt)

        flags.no_cpython_wrapper = True
        flags.error_model = "numpy"
        # Disable loop lifting
        # The feature requires a real
        # python function
        flags.enable_looplift = False

        # BUGFIX: previously the raw call-site `locals` was passed here,
        # silently discarding the merged `locs` computed above (instance
        # level locals were ignored).  Pass the merged dict instead.
        return self._compile_core(sig, flags, locs)

    def _compile_core(self, sig, flags, locals):
        """
        Trigger the compiler on the core function or load a previously
        compiled version from the cache. Returns the CompileResult.
        """
        typingctx = self.targetdescr.typing_context
        targetctx = self.targetdescr.target_context

        @contextmanager
        def store_overloads_on_success():
            # use to ensure overloads are stored on success
            try:
                yield
            except Exception:
                raise
            else:
                # `cres` is bound in the enclosing scope by the time the
                # context manager exits without error.
                exists = self.overloads.get(cres.signature)
                if exists is None:
                    self.overloads[cres.signature] = cres

        # Use cache and compiler in a critical section
        with global_compiler_lock:
            with targetconfig.ConfigStack().enter(flags.copy()):
                with store_overloads_on_success():
                    # attempt look up of existing
                    cres = self.cache.load_overload(sig, targetctx)
                    if cres is not None:
                        return cres

                    # Compile
                    args, return_type = sigutils.normalize_signature(sig)
                    cres = compiler.compile_extra(typingctx, targetctx,
                                                  self.py_func, args=args,
                                                  return_type=return_type,
                                                  flags=flags, locals=locals)

                    # cache lookup failed before so safe to save
                    self.cache.save_overload(sig, cres)

                    return cres


# Register this dispatcher for the 'npyufunc' target.
dispatcher_registry[target_registry['npyufunc']] = UFuncDispatcher
|
||||
|
||||
|
||||
# Utility functions
|
||||
|
||||
def _compile_element_wise_function(nb_func, targetoptions, sig):
    """Compile the kernel *nb_func* for signature *sig*.

    Returns a 3-tuple ``(cres, args, return_type)`` where *cres* is the
    CompileResult and *args*/*return_type* come from normalizing *sig*.
    """
    compile_result = nb_func.compile(sig, **targetoptions)
    arg_types, ret_type = sigutils.normalize_signature(sig)
    return compile_result, arg_types, ret_type
|
||||
|
||||
|
||||
def _finalize_ufunc_signature(cres, args, return_type):
    '''Given a compilation result, argument types, and a return type,
    build a valid Numba signature after validating that it doesn't
    violate the constraints for the compilation mode.
    '''
    if return_type is None:
        # Object mode cannot infer the return type; require it explicitly.
        if cres.objectmode:
            raise TypeError("return type must be specified for object mode")
        return_type = cres.signature.return_type

    # A raw pyobject return must never leak into a ufunc loop signature.
    assert return_type != types.pyobject
    return return_type(*args)
|
||||
|
||||
|
||||
def _build_element_wise_ufunc_wrapper(cres, signature):
    '''Build a wrapper for the ufunc loop entry point given by the
    compilation result object, using the element-wise signature.

    Returns ``(dtypenums, ptr, env)`` suitable for ``_internal.fromfunc``.
    '''
    ctx = cres.target_context
    library = cres.library
    fname = cres.fndesc.llvm_func_name

    with global_compiler_lock:
        info = build_ufunc_wrapper(library, ctx, fname, signature,
                                   cres.objectmode, cres)
        ptr = info.library.get_pointer_to_function(info.name)
        # Collect NumPy dtype numbers for all inputs plus the output.
        dtypenums = [as_dtype(t).num
                     for t in (*signature.args, signature.return_type)]
    return dtypenums, ptr, cres.environment
|
||||
|
||||
|
||||
# User-facing identity values mapped onto the low-level NumPy constants.
_identities = {
    0: _internal.PyUFunc_Zero,
    1: _internal.PyUFunc_One,
    None: _internal.PyUFunc_None,
    "reorderable": _internal.PyUFunc_ReorderableNone,
}


def parse_identity(identity):
    """
    Parse an identity value and return the corresponding low-level value
    for Numpy.
    """
    try:
        return _identities[identity]
    except KeyError:
        raise ValueError("Invalid identity value %r" % (identity,))
|
||||
|
||||
|
||||
@contextmanager
def _suppress_deprecation_warning_nopython_not_supplied():
    """This suppresses the NumbaDeprecationWarning that occurs through the use
    of `jit` without the `nopython` kwarg. This use of `jit` occurs in a few
    places in the `{g,}ufunc` mechanism in Numba, predominantly to wrap the
    "kernel" function."""
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            category=NumbaDeprecationWarning,
            message=(".*The 'nopython' keyword argument "
                     "was not supplied*"),
        )
        yield
|
||||
|
||||
|
||||
# Class definitions
|
||||
|
||||
class _BaseUFuncBuilder(object):
|
||||
|
||||
def add(self, sig=None):
|
||||
if hasattr(self, 'targetoptions'):
|
||||
targetoptions = self.targetoptions
|
||||
else:
|
||||
targetoptions = self.nb_func.targetoptions
|
||||
cres, args, return_type = _compile_element_wise_function(
|
||||
self.nb_func, targetoptions, sig)
|
||||
sig = self._finalize_signature(cres, args, return_type)
|
||||
self._sigs.append(sig)
|
||||
self._cres[sig] = cres
|
||||
return cres
|
||||
|
||||
def disable_compile(self):
|
||||
"""
|
||||
Disable the compilation of new signatures at call time.
|
||||
"""
|
||||
# Override this for implementations that support lazy compilation
|
||||
|
||||
|
||||
class UFuncBuilder(_BaseUFuncBuilder):
    """Builds a NumPy ufunc from an element-wise scalar kernel."""

    def __init__(self, py_func, identity=None, cache=False, targetoptions=None):
        targetoptions = {} if targetoptions is None else targetoptions
        # Unwrap an already-jitted kernel down to its Python function.
        if is_jitted(py_func):
            py_func = py_func.py_func
        self.py_func = py_func
        self.identity = parse_identity(identity)
        with _suppress_deprecation_warning_nopython_not_supplied():
            self.nb_func = jit(_target='npyufunc',
                               cache=cache,
                               **targetoptions)(py_func)
        self._sigs = []
        self._cres = {}

    def _finalize_signature(self, cres, args, return_type):
        '''Slated for deprecation, use ufuncbuilder._finalize_ufunc_signature()
        instead.
        '''
        return _finalize_ufunc_signature(cres, args, return_type)

    def build_ufunc(self):
        """Create the NumPy ufunc object from all signatures added so far."""
        with global_compiler_lock:
            if not self.nb_func:
                raise TypeError("No definition")

            dtypelist = []
            ptrlist = []
            keepalive = []
            cres = None
            # Wrap each compiled signature in insertion order.
            for sig in self._sigs:
                cres = self._cres[sig]
                dtypenums, ptr, env = self.build(cres, sig)
                dtypelist.append(dtypenums)
                ptrlist.append(int(ptr))
                keepalive.append((cres.library, env))

            datlist = [None] * len(ptrlist)

            if cres is None:
                # Nothing compiled yet: fall back to the Python function's
                # declared arity.
                inct = len(inspect.getfullargspec(self.py_func).args)
            else:
                inct = len(cres.signature.args)
            outct = 1

            # Becareful that fromfunc does not provide full error checking yet.
            # If typenum is out-of-bound, we have nasty memory corruptions.
            # For instance, -1 for typenum will cause segfault.
            # If elements of type-list (2nd arg) is tuple instead,
            # there will also memory corruption. (Seems like code rewrite.)
            ufunc = _internal.fromfunc(
                self.py_func.__name__, self.py_func.__doc__,
                ptrlist, dtypelist, inct, outct, datlist,
                keepalive, self.identity,
            )

            return ufunc

    def build(self, cres, signature):
        '''Slated for deprecation, use
        ufuncbuilder._build_element_wise_ufunc_wrapper().
        '''
        return _build_element_wise_ufunc_wrapper(cres, signature)
|
||||
|
||||
|
||||
class GUFuncBuilder(_BaseUFuncBuilder):
    # Builds a NumPy generalized ufunc from a kernel function plus a
    # gufunc layout signature (e.g. "(m,n),(n,p)->(m,p)").

    # TODO handle scalar
    def __init__(self, py_func, signature, identity=None, cache=False,
                 targetoptions=None, writable_args=()):
        if targetoptions is None:
            targetoptions = {}
        self.py_func = py_func
        self.identity = parse_identity(identity)
        with _suppress_deprecation_warning_nopython_not_supplied():
            self.nb_func = jit(_target='npyufunc', cache=cache)(py_func)
        # Keep the textual signature and its parsed input/output symbol
        # groups (sin/sout).
        self.signature = signature
        self.sin, self.sout = parse_signature(signature)
        self.targetoptions = targetoptions
        self.cache = cache
        self._sigs = []
        self._cres = {}

        # Normalize writable_args (names or positions) to positional indices.
        transform_arg = _get_transform_arg(py_func)
        self.writable_args = tuple([transform_arg(a) for a in writable_args])

    def _finalize_signature(self, cres, args, return_type):
        # gufunc kernels write into their output arrays; a non-void return
        # is only tolerated in object mode.
        if not cres.objectmode and cres.signature.return_type != types.void:
            raise TypeError("gufunc kernel must have void return type")

        if return_type is None:
            return_type = types.void

        return return_type(*args)

    @global_compiler_lock
    def build_ufunc(self):
        # Assemble the gufunc object from every compiled signature.
        type_list = []
        func_list = []
        if not self.nb_func:
            raise TypeError("No definition")

        # Get signature in the order they are added
        keepalive = []
        for sig in self._sigs:
            cres = self._cres[sig]
            dtypenums, ptr, env = self.build(cres)
            type_list.append(dtypenums)
            func_list.append(int(ptr))
            keepalive.append((cres.library, env))

        datalist = [None] * len(func_list)

        nin = len(self.sin)
        nout = len(self.sout)

        # Pass envs to fromfuncsig to bind to the lifetime of the ufunc object
        ufunc = _internal.fromfunc(
            self.py_func.__name__, self.py_func.__doc__,
            func_list, type_list, nin, nout, datalist,
            keepalive, self.identity, self.signature, self.writable_args
        )
        return ufunc

    def build(self, cres):
        """
        Returns (dtype numbers, function ptr, EnvironmentObject)
        """
        # Builder wrapper for ufunc entry point
        signature = cres.signature
        info = build_gufunc_wrapper(
            self.py_func, cres, self.sin, self.sout,
            cache=self.cache, is_parfors=False,
        )

        env = info.env
        ptr = info.library.get_pointer_to_function(info.name)
        # Get dtypes
        dtypenums = []
        for a in signature.args:
            # Array arguments contribute their element dtype; scalars
            # contribute themselves.
            if isinstance(a, types.Array):
                ty = a.dtype
            else:
                ty = a
            dtypenums.append(as_dtype(ty).num)
        return dtypenums, ptr, env
|
||||
|
||||
|
||||
def _get_transform_arg(py_func):
|
||||
"""Return function that transform arg into index"""
|
||||
args = inspect.getfullargspec(py_func).args
|
||||
pos_by_arg = {arg: i for i, arg in enumerate(args)}
|
||||
|
||||
def transform_arg(arg):
|
||||
if isinstance(arg, int):
|
||||
return arg
|
||||
|
||||
try:
|
||||
return pos_by_arg[arg]
|
||||
except KeyError:
|
||||
msg = (f"Specified writable arg {arg} not found in arg list "
|
||||
f"{args} for function {py_func.__qualname__}")
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return transform_arg
|
||||
Binary file not shown.
@@ -0,0 +1,743 @@
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from llvmlite.ir import Constant, IRBuilder
|
||||
from llvmlite import ir
|
||||
|
||||
from numba.core import types, cgutils
|
||||
from numba.core.compiler_lock import global_compiler_lock
|
||||
from numba.core.caching import make_library_cache, NullCache
|
||||
|
||||
|
||||
# Result bundle returned by the wrapper builders below.
_wrapper_info = namedtuple('_wrapper_info', ['library', 'env', 'name'])


def _build_ufunc_loop_body(load, store, context, func, builder, arrays, out,
                           offsets, store_offset, signature, pyapi, env):
    # Emit one iteration of the nopython ufunc inner loop: load operands
    # via *load*, call the kernel, store via *store*, then advance the
    # per-operand byte offsets.  Returns the kernel's status code value.
    elems = load()

    # Compute
    status, retval = context.call_conv.call_function(builder, func,
                                                     signature.return_type,
                                                     signature.args, elems)

    # Store
    with builder.if_else(status.is_ok, likely=True) as (if_ok, if_error):
        with if_ok:
            store(retval)
        with if_error:
            # Raising requires the GIL; acquire it just for the error path.
            gil = pyapi.gil_ensure()
            context.call_conv.raise_error(builder, pyapi, status)
            pyapi.gil_release(gil)

    # increment indices
    for off, ary in zip(offsets, arrays):
        builder.store(builder.add(builder.load(off), ary.step), off)

    builder.store(builder.add(builder.load(store_offset), out.step),
                  store_offset)

    return status.code
|
||||
|
||||
|
||||
def _build_ufunc_loop_body_objmode(load, store, context, func, builder,
                                   arrays, out, offsets, store_offset,
                                   signature, env, pyapi):
    # Object-mode variant of the inner loop body: all arguments are boxed
    # PyObjects and reference counts must be managed explicitly.
    elems = load()

    # Compute
    _objargs = [types.pyobject] * len(signature.args)
    # We need to push the error indicator to avoid it messing with
    # the ufunc's execution. We restore it unless the ufunc raised
    # a new error.
    with pyapi.err_push(keep_new=True):
        status, retval = context.call_conv.call_function(builder, func,
                                                         types.pyobject,
                                                         _objargs, elems)
        # Release owned reference to arguments
        for elem in elems:
            pyapi.decref(elem)
    # NOTE: if an error occurred, it will be caught by the Numpy machinery

    # Store
    store(retval)

    # increment indices
    for off, ary in zip(offsets, arrays):
        builder.store(builder.add(builder.load(off), ary.step), off)

    builder.store(builder.add(builder.load(store_offset), out.step),
                  store_offset)

    return status.code
|
||||
|
||||
|
||||
def build_slow_loop_body(context, func, builder, arrays, out, offsets,
                         store_offset, signature, pyapi, env):
    """Emit the generic (strided) nopython inner-loop body."""
    def load():
        # Read every input at its current byte offset.
        return [ary.load_direct(builder.load(off))
                for off, ary in zip(offsets, arrays)]

    def store(retval):
        out.store_direct(retval, builder.load(store_offset))

    return _build_ufunc_loop_body(load, store, context, func, builder, arrays,
                                  out, offsets, store_offset, signature, pyapi,
                                  env=env)
|
||||
|
||||
|
||||
def build_obj_loop_body(context, func, builder, arrays, out, offsets,
                        store_offset, signature, pyapi, envptr, env):
    # Object-mode inner loop: values are boxed to PyObjects before the call
    # and the result is unboxed on the way back out.
    env_body = context.get_env_body(builder, envptr)
    env_manager = pyapi.get_env_manager(env, env_body, envptr)

    def load():
        # Load
        elems = [ary.load_direct(builder.load(off))
                 for off, ary in zip(offsets, arrays)]
        # Box
        elems = [pyapi.from_native_value(t, v, env_manager)
                 for v, t in zip(elems, signature.args)]
        return elems

    def store(retval):
        is_ok = cgutils.is_not_null(builder, retval)
        # If an error is raised by the object mode ufunc, it will
        # simply get caught by the Numpy ufunc machinery.
        with builder.if_then(is_ok, likely=True):
            # Unbox
            native = pyapi.to_native_value(signature.return_type, retval)
            assert native.cleanup is None
            # Store
            out.store_direct(native.value, builder.load(store_offset))
            # Release owned reference
            pyapi.decref(retval)

    return _build_ufunc_loop_body_objmode(load, store, context, func, builder,
                                          arrays, out, offsets, store_offset,
                                          signature, envptr, pyapi)
|
||||
|
||||
|
||||
def build_fast_loop_body(context, func, builder, arrays, out, offsets,
                         store_offset, signature, ind, pyapi, env):
    """Emit the unit-strided (aligned, index-addressed) inner-loop body."""
    def load():
        # All operands are unit-strided, so index directly by *ind*.
        return [ary.load_aligned(ind) for ary in arrays]

    def store(retval):
        out.store_aligned(retval, ind)

    return _build_ufunc_loop_body(load, store, context, func, builder, arrays,
                                  out, offsets, store_offset, signature, pyapi,
                                  env=env)
|
||||
|
||||
|
||||
def build_ufunc_wrapper(library, context, fname, signature, objmode, cres):
    """
    Wrap the scalar function with a loop that iterates over the arguments

    Returns
    -------
    (library, env, name)
    """
    assert isinstance(fname, str)
    # LLVM types used by the NumPy ufunc loop ABI:
    # void loop(char **args, intp *dims, intp *steps, char *data)
    byte_t = ir.IntType(8)
    byte_ptr_t = ir.PointerType(byte_t)
    byte_ptr_ptr_t = ir.PointerType(byte_ptr_t)
    intp_t = context.get_value_type(types.intp)
    intp_ptr_t = ir.PointerType(intp_t)

    fnty = ir.FunctionType(ir.VoidType(), [byte_ptr_ptr_t, intp_ptr_t,
                                           intp_ptr_t, byte_ptr_t])

    wrapperlib = context.codegen().create_library('ufunc_wrapper')
    wrapper_module = wrapperlib.create_ir_module('')
    if objmode:
        # Object mode: every argument and the return are PyObjects.
        func_type = context.call_conv.get_function_type(
            types.pyobject, [types.pyobject] * len(signature.args))
    else:
        func_type = context.call_conv.get_function_type(
            signature.return_type, signature.args)

    # Declaration of the kernel; the definition is linked in below.
    func = ir.Function(wrapper_module, func_type, name=fname)
    func.attributes.add("alwaysinline")

    wrapper = ir.Function(wrapper_module, fnty, "__ufunc__." + func.name)
    arg_args, arg_dims, arg_steps, arg_data = wrapper.args
    arg_args.name = "args"
    arg_dims.name = "dims"
    arg_steps.name = "steps"
    arg_data.name = "data"

    builder = IRBuilder(wrapper.append_basic_block("entry"))

    # Prepare Environment
    envname = context.get_env_name(cres.fndesc)
    env = cres.environment
    envptr = builder.load(context.declare_env_global(builder.module, envname))

    # Emit loop
    loopcount = builder.load(arg_dims, name="loopcount")

    # Prepare inputs
    arrays = []
    for i, typ in enumerate(signature.args):
        arrays.append(UArrayArg(context, builder, arg_args, arg_steps, i, typ))

    # Prepare output
    out = UArrayArg(context, builder, arg_args, arg_steps, len(arrays),
                    signature.return_type)

    # Setup indices
    offsets = []
    zero = context.get_constant(types.intp, 0)
    for _ in arrays:
        p = cgutils.alloca_once(builder, intp_t)
        offsets.append(p)
        builder.store(zero, p)

    store_offset = cgutils.alloca_once(builder, intp_t)
    builder.store(zero, store_offset)

    # The fast path applies only when every operand is unit-strided.
    unit_strided = cgutils.true_bit
    for ary in arrays:
        unit_strided = builder.and_(unit_strided, ary.is_unit_strided)

    pyapi = context.get_python_api(builder)
    if objmode:
        # General loop
        # Object mode holds the GIL for the whole loop.
        gil = pyapi.gil_ensure()
        with cgutils.for_range(builder, loopcount, intp=intp_t):
            build_obj_loop_body(
                context, func, builder, arrays, out, offsets,
                store_offset, signature, pyapi, envptr, env,
            )
        pyapi.gil_release(gil)
        builder.ret_void()

    else:
        # Branch at runtime between the aligned fast loop and the
        # generic strided loop.
        with builder.if_else(unit_strided) as (is_unit_strided, is_strided):
            with is_unit_strided:
                with cgutils.for_range(builder, loopcount, intp=intp_t) as loop:
                    build_fast_loop_body(
                        context, func, builder, arrays, out, offsets,
                        store_offset, signature, loop.index, pyapi,
                        env=envptr,
                    )

            with is_strided:
                # General loop
                with cgutils.for_range(builder, loopcount, intp=intp_t):
                    build_slow_loop_body(
                        context, func, builder, arrays, out, offsets,
                        store_offset, signature, pyapi,
                        env=envptr,
                    )

        builder.ret_void()
    del builder

    # Link and finalize
    wrapperlib.add_ir_module(wrapper_module)
    wrapperlib.add_linking_library(library)
    return _wrapper_info(library=wrapperlib, env=env, name=wrapper.name)
|
||||
|
||||
|
||||
class UArrayArg(object):
    """Wraps one operand of a NumPy ufunc loop: its data pointer, byte
    stride, and helpers to load/store elements at byte offsets or aligned
    indices."""

    def __init__(self, context, builder, args, steps, i, fe_type):
        self.context = context
        self.builder = builder
        self.fe_type = fe_type
        # args[i] is the raw byte pointer to this operand's data.
        offset = self.context.get_constant(types.intp, i)
        offseted_args = self.builder.load(builder.gep(args, [offset]))
        data_type = context.get_data_type(fe_type)
        self.dataptr = self.builder.bitcast(offseted_args,
                                            data_type.as_pointer())
        sizeof = self.context.get_abi_sizeof(data_type)
        self.abisize = self.context.get_constant(types.intp, sizeof)
        # steps[i] is the byte stride between consecutive elements.
        offseted_step = self.builder.gep(steps, [offset])
        self.step = self.builder.load(offseted_step)
        # Unit-strided iff the stride equals the ABI size of one element.
        self.is_unit_strided = builder.icmp_unsigned('==',
                                                     self.abisize, self.step)
        # NOTE: a redundant trailing `self.builder = builder` (dead store,
        # already assigned above) was removed.

    def load_direct(self, byteoffset):
        """
        Generic load from the given *byteoffset*.  load_aligned() is
        preferred if possible.
        """
        ptr = cgutils.pointer_add(self.builder, self.dataptr, byteoffset)
        return self.context.unpack_value(self.builder, self.fe_type, ptr)

    def load_aligned(self, ind):
        # Using gep() instead of explicit pointer addition helps LLVM
        # vectorize the loop.
        ptr = self.builder.gep(self.dataptr, [ind])
        return self.context.unpack_value(self.builder, self.fe_type, ptr)

    def store_direct(self, value, byteoffset):
        """Generic store of *value* at the given *byteoffset*."""
        ptr = cgutils.pointer_add(self.builder, self.dataptr, byteoffset)
        self.context.pack_value(self.builder, self.fe_type, value, ptr)

    def store_aligned(self, value, ind):
        """Store *value* at aligned element index *ind*."""
        ptr = self.builder.gep(self.dataptr, [ind])
        self.context.pack_value(self.builder, self.fe_type, value, ptr)
|
||||
|
||||
|
||||
# On-disk cache class for compiled gufunc wrapper libraries.
GufWrapperCache = make_library_cache('guf')


class _GufuncWrapper(object):
    # Builds the NumPy gufunc loop wrapper around a compiled kernel.

    def __init__(self, py_func, cres, sin, sout, cache, is_parfors):
        """
        The *is_parfors* argument is a boolean that indicates if the GUfunc
        being built is to be used as a ParFors kernel. If True, it disables
        the caching on the wrapper as a separate unit because it will be linked
        into the caller function and cached along with it.
        """
        self.py_func = py_func
        self.cres = cres
        # Parsed input/output symbol groups from the gufunc signature.
        self.sin = sin
        self.sout = sout
        self.is_objectmode = self.signature.return_type == types.pyobject
        self.cache = (GufWrapperCache(py_func=self.py_func)
                      if cache else NullCache())
        self.is_parfors = bool(is_parfors)

    @property
    def library(self):
        # CodeLibrary holding the compiled kernel.
        return self.cres.library

    @property
    def context(self):
        return self.cres.target_context

    @property
    def call_conv(self):
        return self.context.call_conv

    @property
    def signature(self):
        return self.cres.signature

    @property
    def fndesc(self):
        return self.cres.fndesc

    @property
    def env(self):
        return self.cres.environment

    def _wrapper_function_type(self):
        # LLVM type of a NumPy loop function:
        # void loop(char **args, intp *dims, intp *steps, char *data)
        byte_t = ir.IntType(8)
        byte_ptr_t = ir.PointerType(byte_t)
        byte_ptr_ptr_t = ir.PointerType(byte_ptr_t)
        intp_t = self.context.get_value_type(types.intp)
        intp_ptr_t = ir.PointerType(intp_t)

        fnty = ir.FunctionType(ir.VoidType(), [byte_ptr_ptr_t, intp_ptr_t,
                                               intp_ptr_t, byte_ptr_t])
        return fnty

    def _build_wrapper(self, library, name):
        """
        The LLVM IRBuilder code to create the gufunc wrapper.
        The *library* arg is the CodeLibrary to which the wrapper should
        be added. The *name* arg is the name of the wrapper function being
        created.
        """
        intp_t = self.context.get_value_type(types.intp)
        fnty = self._wrapper_function_type()

        wrapper_module = library.create_ir_module('_gufunc_wrapper')
        func_type = self.call_conv.get_function_type(self.fndesc.restype,
                                                     self.fndesc.argtypes)
        fname = self.fndesc.llvm_func_name
        func = ir.Function(wrapper_module, func_type, name=fname)

        func.attributes.add("alwaysinline")
        wrapper = ir.Function(wrapper_module, fnty, name)
        # The use of weak_odr linkage avoids the function being dropped due
        # to the order in which the wrappers and the user function are linked.
        wrapper.linkage = 'weak_odr'
        arg_args, arg_dims, arg_steps, arg_data = wrapper.args
        arg_args.name = "args"
        arg_dims.name = "dims"
        arg_steps.name = "steps"
        arg_data.name = "data"

        builder = IRBuilder(wrapper.append_basic_block("entry"))
        loopcount = builder.load(arg_dims, name="loopcount")
        pyapi = self.context.get_python_api(builder)

        # Unpack shapes
        unique_syms = set()
        for grp in (self.sin, self.sout):
            for syms in grp:
                unique_syms |= set(syms)

        # Assign each distinct input symbol a stable index.
        sym_map = {}
        for syms in self.sin:
            for s in syms:
                if s not in sym_map:
                    sym_map[s] = len(sym_map)

        # dims[0] is the loop count; symbol extents follow at dims[i+1].
        sym_dim = {}
        for s, i in sym_map.items():
            sym_dim[s] = builder.load(builder.gep(arg_dims,
                                                  [self.context.get_constant(
                                                      types.intp,
                                                      i + 1)]))

        # Prepare inputs
        arrays = []
        step_offset = len(self.sin) + len(self.sout)
        for i, (typ, sym) in enumerate(zip(self.signature.args,
                                           self.sin + self.sout)):
            ary = GUArrayArg(self.context, builder, arg_args,
                             arg_steps, i, step_offset, typ, sym, sym_dim)
            step_offset += len(sym)
            arrays.append(ary)

        bbreturn = builder.append_basic_block('.return')

        # Prologue
        self.gen_prologue(builder, pyapi)

        # Loop
        with cgutils.for_range(builder, loopcount, intp=intp_t) as loop:
            args = [a.get_array_at_offset(loop.index) for a in arrays]
            innercall, error = self.gen_loop_body(builder, pyapi, func, args)
            # If error, escape
            cgutils.cbranch_or_continue(builder, error, bbreturn)

        builder.branch(bbreturn)
        builder.position_at_end(bbreturn)

        # Epilogue
        self.gen_epilogue(builder, pyapi)

        builder.ret_void()

        # Link
        library.add_ir_module(wrapper_module)
        library.add_linking_library(self.library)

    def _compile_wrapper(self, wrapper_name):
        # Gufunc created by Parfors?
        if self.is_parfors:
            # No wrapper caching for parfors
            wrapperlib = self.context.codegen().create_library(str(self))
            # Build wrapper
            self._build_wrapper(wrapperlib, wrapper_name)
        # Non-parfors?
        else:
            # Use cache and compiler in a critical section
            wrapperlib = self.cache.load_overload(
                self.cres.signature, self.cres.target_context,
            )
            if wrapperlib is None:
                # Create library and enable caching
                wrapperlib = self.context.codegen().create_library(str(self))
                wrapperlib.enable_object_caching()
                # Build wrapper
                self._build_wrapper(wrapperlib, wrapper_name)
                # Cache
                self.cache.save_overload(self.cres.signature, wrapperlib)

        return wrapperlib

    @global_compiler_lock
    def build(self):
        # Entry point: compile (or load) the wrapper and return its info.
        wrapper_name = "__gufunc__." + self.fndesc.mangled_name
        wrapperlib = self._compile_wrapper(wrapper_name)
        return _wrapper_info(
            library=wrapperlib, env=self.env, name=wrapper_name,
        )

    def gen_loop_body(self, builder, pyapi, func, args):
        # Call the kernel once; raise with the GIL held on error.
        status, retval = self.call_conv.call_function(
            builder, func, self.signature.return_type, self.signature.args,
            args)

        with builder.if_then(status.is_error, likely=False):
            gil = pyapi.gil_ensure()
            self.context.call_conv.raise_error(builder, pyapi, status)
            pyapi.gil_release(gil)

        return status.code, status.is_error

    def gen_prologue(self, builder, pyapi):
        pass # Do nothing

    def gen_epilogue(self, builder, pyapi):
        pass # Do nothing
|
||||
|
||||
|
||||
class _GufuncObjectWrapper(_GufuncWrapper):
    """Gufunc wrapper for object-mode kernels.

    The whole inner loop runs with the GIL held (acquired in the
    prologue, released in the epilogue) and each kernel call goes
    through the CPython API with boxed arguments.
    """

    def gen_loop_body(self, builder, pyapi, func, args):
        """Emit an object-mode call to the kernel for one iteration."""
        return _prepare_call_to_object_mode(
            self.context, builder, pyapi, func, self.signature, args)

    def gen_prologue(self, builder, pyapi):
        """Acquire the GIL before entering the loop; released in the epilogue."""
        self.gil = pyapi.gil_ensure()

    def gen_epilogue(self, builder, pyapi):
        """Release the GIL acquired by ``gen_prologue``."""
        pyapi.gil_release(self.gil)
|
||||
|
||||
|
||||
def build_gufunc_wrapper(py_func, cres, sin, sout, cache, is_parfors):
    """Build the gufunc wrapper around the compiled kernel *cres*.

    Kernels returning ``pyobject`` (object mode) get the GIL-holding
    wrapper; all others get the nopython wrapper.  Returns the wrapper
    info produced by the chosen wrapper's ``build`` method.
    """
    if cres.signature.return_type == types.pyobject:
        wrapper_class = _GufuncObjectWrapper
    else:
        wrapper_class = _GufuncWrapper
    wrapper = wrapper_class(py_func, cres, sin, sout, cache,
                            is_parfors=is_parfors)
    return wrapper.build()
|
||||
|
||||
|
||||
def _prepare_call_to_object_mode(context, builder, pyapi, func,
                                 signature, args):
    """Emit IR that invokes the object-mode kernel *func*.

    Each native argument is boxed into a PyObject — arrays go through the
    ``numba_ndarray_new`` C helper rather than full NRT reflection, since
    the objects only live for the duration of the call — then the kernel
    is called through the CPython calling convention, and every temporary
    object is decref'ed on the common return path.

    Returns ``(innercall, error)``: the kernel call's status code and an
    i1 flag that is true when boxing or the call itself failed.
    """
    mod = builder.module

    # Common exit block: argument decrefs happen here on both the error
    # path (early escape) and the success path.
    bb_core_return = builder.append_basic_block('ufunc.core.return')

    # Call to
    # PyObject* ndarray_new(int nd,
    #                       npy_intp *dims,   /* shape */
    #                       npy_intp *strides,
    #                       void* data,
    #                       int type_num,
    #                       int itemsize)
    ll_int = context.get_value_type(types.int32)
    ll_intp = context.get_value_type(types.intp)
    ll_intp_ptr = ir.PointerType(ll_intp)
    ll_voidptr = context.get_value_type(types.voidptr)
    ll_pyobj = context.get_value_type(types.pyobject)
    fnty = ir.FunctionType(ll_pyobj, [ll_int, ll_intp_ptr,
                                      ll_intp_ptr, ll_voidptr,
                                      ll_int, ll_int])

    fn_array_new = cgutils.get_or_insert_function(mod, fnty,
                                                  "numba_ndarray_new")

    # Convert each llarray into pyobject.
    # The error flag starts out set; it is overwritten before every branch
    # that can reach the return block.
    error_pointer = cgutils.alloca_once(builder, ir.IntType(1), name='error')
    builder.store(cgutils.true_bit, error_pointer)

    # The PyObject* arguments to the kernel function
    object_args = []
    object_pointers = []

    for i, (arg, argty) in enumerate(zip(args, signature.args)):
        # Allocate NULL-initialized slot for this argument so the cleanup
        # loop can safely decref slots that were never filled.
        objptr = cgutils.alloca_once(builder, ll_pyobj, zfill=True)
        object_pointers.append(objptr)

        if isinstance(argty, types.Array):
            # Special case arrays: we don't need full-blown NRT reflection
            # since the argument will be gone at the end of the kernel
            arycls = context.make_array(argty)
            array = arycls(context, builder, value=arg)

            zero = Constant(ll_int, 0)

            # Extract members of the llarray
            nd = Constant(ll_int, argty.ndim)
            dims = builder.gep(array._get_ptr_by_name('shape'), [zero, zero])
            strides = builder.gep(array._get_ptr_by_name('strides'),
                                  [zero, zero])
            data = builder.bitcast(array.data, ll_voidptr)
            dtype = np.dtype(str(argty.dtype))

            # Prepare other info for reconstruction of the PyArray
            type_num = Constant(ll_int, dtype.num)
            itemsize = Constant(ll_int, dtype.itemsize)

            # Call helper to reconstruct PyArray objects
            obj = builder.call(fn_array_new, [nd, dims, strides, data,
                                              type_num, itemsize])
        else:
            # Other argument types => use generic boxing
            obj = pyapi.from_native_value(argty, arg)

        builder.store(obj, objptr)
        object_args.append(obj)

        # A NULL result means boxing failed: record it and escape to the
        # common return block.
        obj_is_null = cgutils.is_null(builder, obj)
        builder.store(obj_is_null, error_pointer)
        cgutils.cbranch_or_continue(builder, obj_is_null, bb_core_return)

    # Call ufunc core function
    object_sig = [types.pyobject] * len(object_args)

    status, retval = context.call_conv.call_function(
        builder, func, types.pyobject, object_sig,
        object_args)
    builder.store(status.is_error, error_pointer)

    # Release returned object
    pyapi.decref(retval)

    builder.branch(bb_core_return)
    # At return block
    builder.position_at_end(bb_core_return)

    # Release argument objects
    for objptr in object_pointers:
        pyapi.decref(builder.load(objptr))

    innercall = status.code
    return innercall, builder.load(error_pointer)
|
||||
|
||||
|
||||
class GUArrayArg(object):
    """Adapter for a single gufunc argument.

    Reads the data pointer and the outer-loop step for argument *i* from
    the NumPy gufunc ``args``/``steps`` arrays, validates the argument's
    type against its shape-signature symbols, and selects a loader that
    reconstructs the argument value (scalar or array view) at each loop
    index.
    """

    def __init__(self, context, builder, args, steps, i, step_offset,
                 typ, syms, sym_dim):
        # context: target context; builder: IR builder.
        # args/steps: the gufunc kernel's pointer arrays; i: argument index.
        # step_offset: index into `steps` where this argument's
        # inner-dimension strides start.
        # syms: this argument's shape-signature symbols; sym_dim maps each
        # symbol to its runtime dimension value.

        self.context = context
        self.builder = builder

        offset = context.get_constant(types.intp, i)

        # Base data pointer for this argument.
        data = builder.load(builder.gep(args, [offset], name="data.ptr"),
                            name="data")
        self.data = data

        # Step (in bytes) between consecutive outer-loop iterations.
        core_step_ptr = builder.gep(steps, [offset], name="core.step.ptr")
        core_step = builder.load(core_step_ptr)

        if isinstance(typ, types.Array):
            # Empty shape signature "()" with an array type: present the
            # scalar through a one-element array view (see loader choice).
            as_scalar = not syms

            # number of symbol in the shape spec should match the dimension
            # of the array type.
            if len(syms) != typ.ndim:
                if len(syms) == 0 and typ.ndim == 1:
                    # This is an exception for handling scalar argument.
                    # The type can be 1D array for scalar.
                    # In the future, we may deprecate this exception.
                    pass
                else:
                    raise TypeError("type and shape signature mismatch for arg "
                                    "#{0}".format(i + 1))

            ndim = typ.ndim
            # Resolve each shape symbol to its runtime dimension value.
            shape = [sym_dim[s] for s in syms]
            strides = []

            # Inner-dimension strides for this argument start at
            # `step_offset` within the gufunc `steps` array.
            for j in range(ndim):
                stepptr = builder.gep(steps,
                                      [context.get_constant(types.intp,
                                                            step_offset + j)],
                                      name="step.ptr")
                step = builder.load(stepptr)
                strides.append(step)

            ldcls = (_ArrayAsScalarArgLoader
                     if as_scalar
                     else _ArrayArgLoader)

            self._loader = ldcls(dtype=typ.dtype,
                                 ndim=ndim,
                                 core_step=core_step,
                                 as_scalar=as_scalar,
                                 shape=shape,
                                 strides=strides)
        else:
            # If typ is not an array
            if syms:
                raise TypeError("scalar type {0} given for non scalar "
                                "argument #{1}".format(typ, i + 1))
            self._loader = _ScalarArgLoader(dtype=typ, stride=core_step)

    def get_array_at_offset(self, ind):
        """Return the LLVM value for this argument at loop index *ind*."""
        return self._loader.load(context=self.context, builder=self.builder,
                                 data=self.data, ind=ind)
|
||||
|
||||
|
||||
class _ScalarArgLoader(object):
    """Load a gufunc argument whose core signature expects a scalar.

    A stride is still carried because, at the gufunc level, the input
    for this argument may be an array that the wrapper steps through
    one element per outer-loop iteration.
    """

    def __init__(self, dtype, stride):
        self.dtype = dtype
        self.stride = stride

    def load(self, context, builder, data, ind):
        """Emit IR loading the scalar at ``data + ind * stride``."""
        byte_offset = builder.mul(ind, self.stride)
        elem_ptr = builder.gep(data, [byte_offset])
        typed_ptr = builder.bitcast(
            elem_ptr, context.get_data_type(self.dtype).as_pointer())
        return builder.load(typed_ptr)
|
||||
|
||||
|
||||
class _ArrayArgLoader(object):
    """
    Handle GUFunc argument loading where an array is expected.

    At each outer-loop iteration, reconstructs a Numba array structure
    over the raw data pointer supplied by NumPy ("A" layout, no meminfo,
    so the view carries no NRT ownership).
    """

    def __init__(self, dtype, ndim, core_step, as_scalar, shape, strides):
        # dtype/ndim: element type and rank of the view to build.
        # core_step: byte stride between consecutive outer-loop iterations.
        # as_scalar: whether the shape signature declared this argument "()".
        # shape/strides: per-dimension LLVM values for the inner view.
        self.dtype = dtype
        self.ndim = ndim
        self.core_step = core_step
        self.as_scalar = as_scalar
        self.shape = shape
        self.strides = strides

    def load(self, context, builder, data, ind):
        """Emit IR building the array view at ``data + ind * core_step``."""
        arytyp = types.Array(dtype=self.dtype, ndim=self.ndim, layout="A")
        arycls = context.make_array(arytyp)

        array = arycls(context, builder)
        # Advance the base pointer to this iteration's slice.
        offseted_data = cgutils.pointer_add(builder,
                                            data,
                                            builder.mul(self.core_step,
                                                        ind))

        shape, strides = self._shape_and_strides(context, builder)

        itemsize = context.get_abi_sizeof(context.get_data_type(self.dtype))
        # meminfo=None: this is a borrowed view; NumPy owns the memory.
        context.populate_array(array,
                               data=builder.bitcast(offseted_data,
                                                    array.data.type),
                               shape=shape,
                               strides=strides,
                               itemsize=context.get_constant(types.intp,
                                                             itemsize),
                               meminfo=None)

        return array._getvalue()

    def _shape_and_strides(self, context, builder):
        """Pack the per-dimension shape and stride values into LLVM arrays."""
        shape = cgutils.pack_array(builder, self.shape)
        strides = cgutils.pack_array(builder, self.strides)
        return shape, strides
|
||||
|
||||
|
||||
class _ArrayAsScalarArgLoader(_ArrayArgLoader):
    """Loader for arguments declared "()" (scalar) in the shape signature
    but typed as a 1D array in the core function: the value is presented
    as a one-element, zero-stride 1D array.
    """

    def _shape_and_strides(self, context, builder):
        """Return packed (shape, strides) for a single-element 1D view."""
        intp_one = context.get_constant(types.intp, 1)
        intp_zero = context.get_constant(types.intp, 0)
        packed_shape = cgutils.pack_array(builder, [intp_one])
        packed_strides = cgutils.pack_array(builder, [intp_zero])
        return packed_shape, packed_strides
|
||||
Reference in New Issue
Block a user