Videre
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def prepare_shape_strides_dtype(shape, strides, dtype, order):
|
||||
dtype = np.dtype(dtype)
|
||||
if isinstance(shape, int):
|
||||
shape = (shape,)
|
||||
if isinstance(strides, int):
|
||||
strides = (strides,)
|
||||
else:
|
||||
strides = strides or _fill_stride_by_order(shape, dtype, order)
|
||||
return shape, strides, dtype
|
||||
|
||||
|
||||
def _fill_stride_by_order(shape, dtype, order):
|
||||
nd = len(shape)
|
||||
if nd == 0:
|
||||
return ()
|
||||
strides = [0] * nd
|
||||
if order == 'C':
|
||||
strides[-1] = dtype.itemsize
|
||||
for d in reversed(range(nd - 1)):
|
||||
strides[d] = strides[d + 1] * shape[d + 1]
|
||||
elif order == 'F':
|
||||
strides[0] = dtype.itemsize
|
||||
for d in range(1, nd):
|
||||
strides[d] = strides[d - 1] * shape[d - 1]
|
||||
else:
|
||||
raise ValueError('must be either C/F order')
|
||||
return tuple(strides)
|
||||
Reference in New Issue
Block a user