Videre
This commit is contained in:
@@ -0,0 +1,198 @@
|
||||
from llvmlite import ir
|
||||
|
||||
from numba import cuda, types
|
||||
from numba.core import cgutils
|
||||
from numba.core.errors import RequireLiteralValue, NumbaValueError
|
||||
from numba.core.typing import signature
|
||||
from numba.core.extending import overload_attribute
|
||||
from numba.cuda import nvvmutils
|
||||
from numba.cuda.extending import intrinsic
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Grid functions
|
||||
|
||||
def _type_grid_function(ndim):
    '''Build the typing signature used by ``grid`` and ``gridsize``.

    *ndim* must provide a ``literal_value`` attribute holding the requested
    number of dimensions.

    Returns a signature taking a single ``int32`` argument and returning:

    - ``int64`` when the literal value is 1;
    - a ``UniTuple`` of that many ``int64`` when the literal value is 2 or 3.

    Raises ``NumbaValueError`` for any other literal value.
    '''
    dims = ndim.literal_value

    # One dimension: a plain scalar result.
    if dims == 1:
        return signature(types.int64, types.int32)

    # Two or three dimensions: a homogeneous tuple with one entry per axis.
    if dims in (2, 3):
        return signature(types.UniTuple(types.int64, dims), types.int32)

    raise NumbaValueError('argument can only be 1, 2, 3')
|
||||
|
||||
|
||||
@intrinsic
def grid(typingctx, ndim):
    '''grid(ndim)

    Return the absolute position of the current thread within the entire
    grid of blocks. *ndim* should match the number of dimensions declared
    when the kernel was instantiated. A single integer is returned when
    *ndim* is 1; a tuple of *ndim* integers is returned when *ndim* is 2
    or 3.

    The first integer is computed as::

        cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x

    The remaining indices are computed the same way using the ``y`` and
    ``z`` attributes.
    '''

    # The result type depends on the value of ndim, so it has to be known
    # at typing time: anything but an integer literal is rejected.
    if not isinstance(ndim, types.IntegerLiteral):
        raise RequireLiteralValue(ndim)

    grid_sig = _type_grid_function(ndim)

    def codegen(context, builder, sig, args):
        result_ty = sig.return_type

        # Scalar result: only the first dimension is requested.
        if result_ty == types.int64:
            return nvvmutils.get_global_id(builder, dim=1)

        # Tuple result: request one id per dimension and pack them together.
        if isinstance(result_ty, types.UniTuple):
            return cgutils.pack_array(
                builder,
                nvvmutils.get_global_id(builder, dim=result_ty.count),
            )

        # No other return type is produced by _type_grid_function.
        return None

    return grid_sig, codegen
|
||||
|
||||
|
||||
@intrinsic
def gridsize(typingctx, ndim):
    '''gridsize(ndim)

    Return the absolute size (or shape), measured in threads, of the entire
    grid of blocks. *ndim* should match the number of dimensions declared
    when the kernel was instantiated. A single integer is returned when
    *ndim* is 1; a tuple of *ndim* integers is returned when *ndim* is 2
    or 3.

    The first integer is computed as::

        cuda.blockDim.x * cuda.gridDim.x

    The remaining sizes are computed the same way using the ``y`` and ``z``
    attributes.
    '''

    # The result type depends on the value of ndim, so it has to be known
    # at typing time: anything but an integer literal is rejected.
    if not isinstance(ndim, types.IntegerLiteral):
        raise RequireLiteralValue(ndim)

    size_sig = _type_grid_function(ndim)

    def _nthreads_for_dim(builder, dim):
        '''Emit the product of ``ntid.<dim>`` and ``nctaid.<dim>`` as an i64.'''
        wide_int = ir.IntType(64)
        threads_per_block = nvvmutils.call_sreg(builder, f"ntid.{dim}")
        blocks_per_grid = nvvmutils.call_sreg(builder, f"nctaid.{dim}")
        # Both factors are sign-extended to 64 bits before multiplying.
        wide_threads = builder.sext(threads_per_block, wide_int)
        wide_blocks = builder.sext(blocks_per_grid, wide_int)
        return builder.mul(wide_threads, wide_blocks)

    def codegen(context, builder, sig, args):
        result_ty = sig.return_type

        # The x extent is needed for every supported return type.
        size_x = _nthreads_for_dim(builder, 'x')
        if result_ty == types.int64:
            return size_x

        if not isinstance(result_ty, types.UniTuple):
            return None

        size_y = _nthreads_for_dim(builder, 'y')
        if result_ty.count == 2:
            return cgutils.pack_array(builder, (size_x, size_y))

        if result_ty.count == 3:
            size_z = _nthreads_for_dim(builder, 'z')
            return cgutils.pack_array(builder, (size_x, size_y, size_z))

        # No other tuple length is produced by _type_grid_function.
        return None

    return size_sig, codegen
|
||||
|
||||
|
||||
@intrinsic
def _warpsize(typingctx):
    # Lowering: the value comes straight from call_sreg for 'warpsize'.
    def codegen(context, builder, sig, args):
        return nvvmutils.call_sreg(builder, 'warpsize')

    # Typing: takes no arguments and yields an int32.
    return signature(types.int32), codegen
|
||||
|
||||
|
||||
@overload_attribute(types.Module(cuda), 'warpsize', target='cuda')
def cuda_warpsize(mod):
    '''
    The number of threads in a warp. Every architecture implemented so far
    uses a warp size of 32.
    '''
    # The attribute getter simply forwards to the _warpsize intrinsic.
    def warpsize_getter(mod):
        return _warpsize()

    return warpsize_getter
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# syncthreads
|
||||
|
||||
@intrinsic
def syncthreads(typingctx):
    '''
    Synchronize all threads in the same thread block. This behaves like a
    barrier in traditional multi-threaded programming: the call waits until
    every thread in the block has reached it, and only then returns control
    to all of its callers.
    '''

    def codegen(context, builder, sig, args):
        # Declare (or reuse) the void() barrier function in this module.
        barrier_ty = ir.FunctionType(ir.VoidType(), ())
        barrier = cgutils.get_or_insert_function(
            builder.module, barrier_ty, 'llvm.nvvm.barrier0'
        )
        builder.call(barrier, ())

        # The signature returns none, so hand back the context's dummy value.
        return context.get_dummy_value()

    return signature(types.none), codegen
|
||||
|
||||
|
||||
def _syncthreads_predicate(typingctx, predicate, fname):
    '''Shared typing/lowering for the predicate-taking syncthreads variants.

    *predicate* is the Numba type of the argument and *fname* is the name of
    the function to call. Returns ``None`` when *predicate* is not an
    ``Integer`` type; otherwise returns a ``(signature, codegen)`` pair whose
    signature takes one ``i4`` and returns an ``i4``.
    '''
    if not isinstance(predicate, types.Integer):
        return None

    def codegen(context, builder, sig, args):
        # The target function is declared as i32(i32) in the current module.
        int32_ty = ir.IntType(32)
        barrier_ty = ir.FunctionType(int32_ty, (int32_ty,))
        barrier = cgutils.get_or_insert_function(
            builder.module, barrier_ty, fname
        )
        return builder.call(barrier, args)

    return signature(types.i4, types.i4), codegen
|
||||
|
||||
|
||||
@intrinsic
def syncthreads_count(typingctx, predicate):
    '''
    syncthreads_count(predicate)

    A variant of numba.cuda.syncthreads whose return value is the number of
    threads for which predicate is true.
    '''
    return _syncthreads_predicate(
        typingctx, predicate, 'llvm.nvvm.barrier0.popc'
    )
|
||||
|
||||
|
||||
@intrinsic
def syncthreads_and(typingctx, predicate):
    '''
    syncthreads_and(predicate)

    A variant of numba.cuda.syncthreads that returns 1 when predicate is
    true for all threads, and 0 otherwise.
    '''
    return _syncthreads_predicate(
        typingctx, predicate, 'llvm.nvvm.barrier0.and'
    )
|
||||
|
||||
|
||||
@intrinsic
def syncthreads_or(typingctx, predicate):
    '''
    syncthreads_or(predicate)

    A variant of numba.cuda.syncthreads that returns 1 when predicate is
    true for any thread, and 0 otherwise.
    '''
    return _syncthreads_predicate(
        typingctx, predicate, 'llvm.nvvm.barrier0.or'
    )
|
||||
Reference in New Issue
Block a user