Videre
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Module containing private utility functions
|
||||
===========================================
|
||||
|
||||
The ``scipy._lib`` namespace is empty (for now). Tests for all
|
||||
utilities in submodules of ``_lib`` can be run with::
|
||||
|
||||
from scipy import _lib
|
||||
_lib.test()
|
||||
|
||||
"""
|
||||
from scipy._lib._testutils import PytestTester
|
||||
test = PytestTester(__name__)
|
||||
del PytestTester
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,9 @@
|
||||
# DO NOT RENAME THIS FILE
|
||||
# This is a hook for array_api_extra/src/array_api_extra/_lib/_compat.py
|
||||
# to override functions of array_api_compat.
|
||||
|
||||
from .array_api_compat import * # noqa: F403
|
||||
from ._array_api_override import array_namespace as scipy_array_namespace
|
||||
|
||||
# overrides array_api_compat.array_namespace inside array-api-extra
|
||||
array_namespace = scipy_array_namespace # type: ignore[assignment]
|
||||
@@ -0,0 +1,294 @@
|
||||
"""Generate flat tables showing Array API capabilities for use in docs.
|
||||
|
||||
These tables are intended for presenting Array API capabilities across
|
||||
a wide number of functions at once. Rows correspond to functions and
|
||||
columns correspond to library/device/option combinations.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from enum import auto, Enum
|
||||
from importlib import import_module
|
||||
from types import ModuleType
|
||||
|
||||
from scipy._lib._array_api import xp_capabilities_table
|
||||
from scipy._lib._array_api import _make_sphinx_capabilities
|
||||
|
||||
# For undocumented aliases of public functions which are kept around for
|
||||
# backwards compatibility reasons. These should be excluded from the
|
||||
# tables since they would be redundant. There are also no docs pages to
|
||||
# link entries to.
|
||||
ALIASES = {
|
||||
"scipy.linalg": {
|
||||
# Alias of scipy.linalg.solve_continuous_lyapunov
|
||||
"solve_lyapunov",
|
||||
},
|
||||
"scipy.ndimage": {
|
||||
# Alias of scipy.ndimage.sum_labels
|
||||
"sum",
|
||||
},
|
||||
"scipy.special": {
|
||||
# Alias of scipy.special.jv
|
||||
"jn",
|
||||
# Alias of scipy.special.roots_legendre
|
||||
"p_roots",
|
||||
# Alias of scipy.special.roots_chebyt
|
||||
"t_roots",
|
||||
# Alias of scipy.special.roots_chebyu
|
||||
"u_roots",
|
||||
# Alias of scipy.special.roots_chebyc
|
||||
"c_roots",
|
||||
# Alias of scipy.special.roots_chebys
|
||||
"s_roots",
|
||||
# Alias of scipy.special.roots_jacobi
|
||||
"j_roots",
|
||||
# Alias of scipy.special.roots_laguerre
|
||||
"l_roots",
|
||||
# Alias of scipy.special.roots_genlaguerre
|
||||
"la_roots",
|
||||
# Alias of scipy.special.roots_hermite
|
||||
"h_roots",
|
||||
# Alias of scipy.special.roots_hermitenorm
|
||||
"he_roots",
|
||||
# Alias of scipy.special.roots_gegenbauer
|
||||
"cg_roots",
|
||||
# Alias of scipy.special.roots_sh_legendre
|
||||
"ps_roots",
|
||||
# Alias of scipy.special.roots_sh_chebyt
|
||||
"ts_roots",
|
||||
# Alias of scipy.special.roots_chebyu
|
||||
"us_roots",
|
||||
# Alias of scipy.special.roots_sh_jacobi
|
||||
"js_roots",
|
||||
}
|
||||
}
|
||||
|
||||
# Shortened names for use in table.
|
||||
BACKEND_NAMES_MAP = {
|
||||
"jax.numpy": "jax",
|
||||
"dask.array": "dask",
|
||||
}
|
||||
|
||||
|
||||
class BackendSupportStatus(Enum):
|
||||
YES = auto()
|
||||
NO = auto()
|
||||
OUT_OF_SCOPE = auto()
|
||||
UNKNOWN = auto()
|
||||
|
||||
|
||||
def _process_capabilities_table_entry(entry: dict | None) -> dict[str, dict[str, bool]]:
|
||||
"""Returns dict showing alternative backend support in easy to consume form.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
entry : Optional[dict]
|
||||
A dict with the structure of the values of the dict
|
||||
scipy._lib._array_api.xp_capabilities_table. If None, it is
|
||||
assumped that no alternative backends are supported.
|
||||
Default: None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, dict[str, bool]]
|
||||
The output dict currently has keys "cpu", "gpu", "jit" and "lazy".
|
||||
The value associated to each key is itself a dict. The keys of
|
||||
the inner dicts correspond to backends, with bool values stating
|
||||
whether or not the backend is supported with a given device or
|
||||
mode. Inapplicable backends do not appear in the inner dicts
|
||||
(e.g. since cupy is gpu-only, it does not appear in the inner
|
||||
dict keyed on "cpu"). Only alternative backends to NumPy are
|
||||
included since NumPY support should be guaranteed.
|
||||
|
||||
"""
|
||||
# This is a template for the output format. If more backends and
|
||||
# backend options are added, it will need to be updated manually.
|
||||
# Entries start as boolean, but upon returning, will take values
|
||||
# from the BackendSupportStatus Enum.
|
||||
output = {
|
||||
"cpu": {"torch": False, "jax": False, "dask": False},
|
||||
"gpu": {"cupy": False, "torch": False, "jax": False},
|
||||
"jit": {"jax": False},
|
||||
"lazy": {"dask": False},
|
||||
}
|
||||
S = BackendSupportStatus
|
||||
if entry is None:
|
||||
# If there is no entry, assume no alternative backends are supported.
|
||||
# If the list of supported backends will grows, this hard-coded dict
|
||||
# will need to be updated.
|
||||
return {
|
||||
outer_key: {inner_key: S.UNKNOWN for inner_key in outer_value}
|
||||
for outer_key, outer_value in output.items()
|
||||
}
|
||||
|
||||
if entry["out_of_scope"]:
|
||||
# None is used to signify out-of-scope functions.
|
||||
return {
|
||||
outer_key: {inner_key: S.OUT_OF_SCOPE for inner_key in outer_value}
|
||||
for outer_key, outer_value in output.items()
|
||||
}
|
||||
|
||||
# For now, use _make_sphinx_capabilities because that's where
|
||||
# the relevant logic for determining what is and isn't
|
||||
# supported based on xp_capabilities_table entries lives.
|
||||
# Perhaps this logic should be decoupled from sphinx.
|
||||
for backend, capabilities in _make_sphinx_capabilities(**entry).items():
|
||||
if backend in {"array_api_strict", "numpy"}:
|
||||
continue
|
||||
backend = BACKEND_NAMES_MAP.get(backend, backend)
|
||||
cpu, gpu = capabilities.cpu, capabilities.gpu
|
||||
if cpu is not None:
|
||||
if backend not in output["cpu"]:
|
||||
raise ValueError(
|
||||
"Input capabilities table entry contains unhandled"
|
||||
f" backend {backend} on cpu."
|
||||
)
|
||||
output["cpu"][backend] = cpu
|
||||
if gpu is not None:
|
||||
if backend not in output["gpu"]:
|
||||
raise ValueError(
|
||||
"Input capabilities table entry contains unhandled"
|
||||
f" backend {backend} on gpu."
|
||||
)
|
||||
output["gpu"][backend] = gpu
|
||||
if backend == "jax":
|
||||
output["jit"]["jax"] = entry["jax_jit"] and output["cpu"]["jax"]
|
||||
if backend == "dask.array":
|
||||
support_lazy = not entry["allow_dask_compute"] and output["dask"]
|
||||
output["lazy"]["dask"] = support_lazy
|
||||
return {
|
||||
outer_key: {
|
||||
inner_key: S.YES if inner_value else S.NO
|
||||
for inner_key, inner_value in outer_value.items()
|
||||
}
|
||||
for outer_key, outer_value in output.items()
|
||||
}
|
||||
|
||||
|
||||
def is_named_function_like_object(obj):
|
||||
return (
|
||||
not isinstance(obj, ModuleType | type)
|
||||
and callable(obj) and hasattr(obj, "__name__")
|
||||
)
|
||||
|
||||
|
||||
def make_flat_capabilities_table(
|
||||
modules: str | list[str],
|
||||
backend_type: str,
|
||||
/,
|
||||
*,
|
||||
capabilities_table: list[str] | None = None,
|
||||
) -> list[dict[str, str]]:
|
||||
"""Generate full table of array api capabilities across public functions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
modules : str | list[str]
|
||||
A string containing single SciPy module, (e.g `scipy.stats`, `scipy.fft`)
|
||||
or a list of such strings.
|
||||
|
||||
backend_type : {'cpu', 'gpu', 'jit', 'lazy'}
|
||||
|
||||
capabilities_table : Optional[list[str]]
|
||||
Table in the form of `scipy._lib._array_api.xp_capabilities_table`.
|
||||
If None, uses `scipy._lib._array_api.xp_capabilities_table`.
|
||||
Default: None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
output : list[dict[str, str]]
|
||||
`output` is a table in dict format
|
||||
(keys corresponding to column names). The first column is "module".
|
||||
The other columns correspond to supported backends for the given
|
||||
`backend_type`, e.g. jax.numpy, torch, and dask on cpu.
|
||||
numpy is excluded because it should always be supported.
|
||||
See the helper function
|
||||
`_process_capabilities_table_entry` above).
|
||||
|
||||
"""
|
||||
if backend_type not in {"cpu", "gpu", "jit", "lazy"}:
|
||||
raise ValueError(f"Received unhandled backend type {backend_type}")
|
||||
|
||||
if isinstance(modules, str):
|
||||
modules = [modules]
|
||||
|
||||
if capabilities_table is None:
|
||||
capabilities_table = xp_capabilities_table
|
||||
|
||||
output = []
|
||||
|
||||
for module_name in modules:
|
||||
module = import_module(module_name)
|
||||
public_things = module.__all__
|
||||
for name in public_things:
|
||||
if name in ALIASES.get(module_name, {}):
|
||||
# Skip undocumented aliases that are kept
|
||||
# for backwards compatibility reasons.
|
||||
continue
|
||||
thing = getattr(module, name)
|
||||
if not is_named_function_like_object(thing):
|
||||
continue
|
||||
entry = xp_capabilities_table.get(thing, None)
|
||||
capabilities = _process_capabilities_table_entry(entry)[backend_type]
|
||||
row = {"module": module_name}
|
||||
row.update({"function": name})
|
||||
row.update(capabilities)
|
||||
output.append(row)
|
||||
return output
|
||||
|
||||
|
||||
def calculate_table_statistics(
|
||||
flat_table: list[dict[str, str]]
|
||||
) -> dict[str, tuple[dict[str, str], bool]]:
|
||||
"""Get counts of what is supported per module.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
flat_table : list[dict[str, str]]
|
||||
A table as returned by `make_flat_capabilities_table`
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, tuple[dict[str, str], bool]]
|
||||
dict mapping module names to 2-tuples containing an inner dict and a
|
||||
bool. The inner dicts have a key "total" along with keys for each
|
||||
backend column of the supplied flat capabilities table. The value
|
||||
corresponding to total is the total count of functions in the given
|
||||
module, and the value associated to the other keys is the count of
|
||||
functions that support that particular backend. The bool is False if
|
||||
the calculation may be innacurate due to missing xp_capabilities
|
||||
decorators, and True if all functions for that particular module have
|
||||
been decorated with xp_capabilities.
|
||||
"""
|
||||
if not flat_table:
|
||||
return []
|
||||
|
||||
counter = defaultdict(lambda: defaultdict(int))
|
||||
|
||||
S = BackendSupportStatus
|
||||
# Keep track of which modules have functions with missing xp_capabilities
|
||||
# decorators so this information can be passed back to the caller.
|
||||
missing_xp_capabilities = set()
|
||||
for entry in flat_table:
|
||||
entry = entry.copy()
|
||||
entry.pop("function")
|
||||
module = entry.pop("module")
|
||||
current_counter = counter[module]
|
||||
|
||||
# By design, all backends and options must be considered out-of-scope
|
||||
# if one is, so just pick an arbitrary entry here to test if function is
|
||||
# in-scope.
|
||||
if next(iter(entry.values())) != S.OUT_OF_SCOPE:
|
||||
current_counter["total"] += 1
|
||||
for key, value in entry.items():
|
||||
# Functions missing xp_capabilities will be tabulated as
|
||||
# unsupported, but may actually be supported. There is a
|
||||
# note about this in the documentation and this function is
|
||||
# set up to return information needed to put asterisks next
|
||||
# to percentages impacted by missing xp_capabilities decorators.
|
||||
current_counter[key] += 1 if value == S.YES else 0
|
||||
if value == S.UNKNOWN:
|
||||
missing_xp_capabilities.add(module)
|
||||
return {
|
||||
key: (dict(value), key not in missing_xp_capabilities)
|
||||
for key, value in counter.items()
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Extra testing functions that forbid 0d-input, see #21044
|
||||
|
||||
While the xp_assert_* functions generally aim to follow the conventions of the
|
||||
underlying `xp` library, NumPy in particular is inconsistent in its handling
|
||||
of scalars vs. 0d-arrays, see https://github.com/numpy/numpy/issues/24897.
|
||||
|
||||
For example, this means that the following operations (as of v2.0.1) currently
|
||||
return scalars, even though a 0d-array would often be more appropriate:
|
||||
|
||||
import numpy as np
|
||||
np.array(0) * 2 # scalar, not 0d array
|
||||
- np.array(0) # scalar, not 0d-array
|
||||
np.sin(np.array(0)) # scalar, not 0d array
|
||||
np.mean([1, 2, 3]) # scalar, not 0d array
|
||||
|
||||
Libraries like CuPy tend to return a 0d-array in scenarios like those above,
|
||||
and even `xp.asarray(0)[()]` remains a 0d-array there. To deal with the reality
|
||||
of the inconsistencies present in NumPy, as well as 20+ years of code on top,
|
||||
the `xp_assert_*` functions here enforce consistency in the only way that
|
||||
doesn't go against the tide, i.e. by forbidding 0d-arrays as the return type.
|
||||
|
||||
However, when scalars are not generally the expected NumPy return type,
|
||||
it remains preferable to use the assert functions from
|
||||
the `scipy._lib._array_api` module, which have less surprising behaviour.
|
||||
"""
|
||||
from scipy._lib._array_api import array_namespace, is_numpy
|
||||
from scipy._lib._array_api import (xp_assert_close as xp_assert_close_base,
|
||||
xp_assert_equal as xp_assert_equal_base,
|
||||
xp_assert_less as xp_assert_less_base)
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
def _check_scalar(actual, desired, *, xp=None, **kwargs):
|
||||
__tracebackhide__ = True # Hide traceback for py.test
|
||||
|
||||
if xp is None:
|
||||
xp = array_namespace(actual)
|
||||
|
||||
# necessary to handle non-numpy scalars, e.g. bare `0.0` has no shape
|
||||
desired = xp.asarray(desired)
|
||||
|
||||
# Only NumPy distinguishes between scalars and arrays;
|
||||
# shape check in xp_assert_* is sufficient except for shape == ()
|
||||
if not (is_numpy(xp) and desired.shape == ()):
|
||||
return
|
||||
|
||||
_msg = ("Result is a NumPy 0d-array. Many SciPy functions intend to follow "
|
||||
"the convention of many NumPy functions, returning a scalar when a "
|
||||
"0d-array would be correct. The specialized `xp_assert_*` functions "
|
||||
"in the `scipy._lib._array_api_no_0d` module err on the side of "
|
||||
"caution and do not accept 0d-arrays by default. If the correct "
|
||||
"result may legitimately be a 0d-array, pass `check_0d=True`, "
|
||||
"or use the `xp_assert_*` functions from `scipy._lib._array_api`.")
|
||||
assert xp.isscalar(actual), _msg
|
||||
|
||||
|
||||
def xp_assert_equal(actual, desired, *, check_0d=False, **kwargs):
|
||||
# in contrast to xp_assert_equal_base, this defaults to check_0d=False,
|
||||
# but will do an extra check in that case, which forbids 0d-arrays for `actual`
|
||||
__tracebackhide__ = True # Hide traceback for py.test
|
||||
|
||||
# array-ness (check_0d == True) is taken care of by the *_base functions
|
||||
if not check_0d:
|
||||
_check_scalar(actual, desired, **kwargs)
|
||||
return xp_assert_equal_base(actual, desired, check_0d=check_0d, **kwargs)
|
||||
|
||||
|
||||
def xp_assert_close(actual, desired, *, check_0d=False, **kwargs):
|
||||
# as for xp_assert_equal
|
||||
__tracebackhide__ = True
|
||||
|
||||
if not check_0d:
|
||||
_check_scalar(actual, desired, **kwargs)
|
||||
return xp_assert_close_base(actual, desired, check_0d=check_0d, **kwargs)
|
||||
|
||||
|
||||
def xp_assert_less(actual, desired, *, check_0d=False, **kwargs):
|
||||
# as for xp_assert_equal
|
||||
__tracebackhide__ = True
|
||||
|
||||
if not check_0d:
|
||||
_check_scalar(actual, desired, **kwargs)
|
||||
return xp_assert_less_base(actual, desired, check_0d=check_0d, **kwargs)
|
||||
|
||||
|
||||
def assert_array_almost_equal(actual, desired, decimal=6, *args, **kwds):
|
||||
"""Backwards compatible replacement. In new code, use xp_assert_close instead.
|
||||
"""
|
||||
rtol, atol = 0, 1.5*10**(-decimal)
|
||||
return xp_assert_close(actual, desired,
|
||||
atol=atol, rtol=rtol, check_dtype=False, check_shape=False,
|
||||
*args, **kwds)
|
||||
|
||||
|
||||
def assert_almost_equal(actual, desired, decimal=7, *args, **kwds):
|
||||
"""Backwards compatible replacement. In new code, use xp_assert_close instead.
|
||||
"""
|
||||
rtol, atol = 0, 1.5*10**(-decimal)
|
||||
return xp_assert_close(actual, desired,
|
||||
atol=atol, rtol=rtol, check_dtype=False, check_shape=False,
|
||||
*args, **kwds)
|
||||
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Override functions from array_api_compat, for use by array-api-extra
|
||||
and internally.
|
||||
|
||||
See also _array_api_compat_vendor.py
|
||||
"""
|
||||
import enum
|
||||
import os
|
||||
|
||||
from functools import lru_cache
|
||||
from types import ModuleType
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from scipy._lib import array_api_compat
|
||||
import scipy._lib.array_api_compat.numpy as np_compat
|
||||
from scipy._lib.array_api_compat import is_array_api_obj, is_jax_array
|
||||
from scipy._lib._sparse import SparseABC
|
||||
|
||||
|
||||
Array: TypeAlias = Any # To be changed to a Protocol later (see array-api#589)
|
||||
ArrayLike: TypeAlias = Array | npt.ArrayLike
|
||||
|
||||
# To enable array API and strict array-like input validation
|
||||
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False)
|
||||
# To control the default device - for use in the test suite only
|
||||
SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")
|
||||
|
||||
|
||||
class _ArrayClsInfo(enum.Enum):
|
||||
skip = 0
|
||||
numpy = 1
|
||||
array_like = 2
|
||||
unknown = 3
|
||||
|
||||
|
||||
@lru_cache(100)
|
||||
def _validate_array_cls(cls: type) -> _ArrayClsInfo:
|
||||
if issubclass(cls, (list, tuple)):
|
||||
return _ArrayClsInfo.array_like
|
||||
|
||||
# this comes from `_util._asarray_validated`
|
||||
if issubclass(cls, SparseABC):
|
||||
msg = ('Sparse arrays/matrices are not supported by this function. '
|
||||
'Perhaps one of the `scipy.sparse.linalg` functions '
|
||||
'would work instead.')
|
||||
raise ValueError(msg)
|
||||
|
||||
if issubclass(cls, np.ma.MaskedArray):
|
||||
raise TypeError("Inputs of type `numpy.ma.MaskedArray` are not supported.")
|
||||
|
||||
if issubclass(cls, np.matrix):
|
||||
raise TypeError("Inputs of type `numpy.matrix` are not supported.")
|
||||
|
||||
if issubclass(cls, (np.ndarray, np.generic)):
|
||||
return _ArrayClsInfo.numpy
|
||||
|
||||
# Note: this must happen after the test for np.generic, because
|
||||
# np.float64 and np.complex128 are subclasses of float and complex respectively.
|
||||
# This matches the behavior of array_api_compat.
|
||||
if issubclass(cls, (int, float, complex, bool, type(None))):
|
||||
return _ArrayClsInfo.skip
|
||||
|
||||
return _ArrayClsInfo.unknown
|
||||
|
||||
|
||||
def array_namespace(*arrays: Array) -> ModuleType:
|
||||
"""Get the array API compatible namespace for the arrays xs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
*arrays : sequence of array_like
|
||||
Arrays used to infer the common namespace.
|
||||
|
||||
Returns
|
||||
-------
|
||||
namespace : module
|
||||
Common namespace.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Wrapper around `array_api_compat.array_namespace`.
|
||||
|
||||
1. Check for the global switch `SCIPY_ARRAY_API`. If disabled, just
|
||||
return array_api_compat.numpy namespace and skip all compliance checks.
|
||||
|
||||
2. Check for known-bad array classes.
|
||||
The following subclasses are not supported and raise and error:
|
||||
|
||||
- `numpy.ma.MaskedArray`
|
||||
- `numpy.matrix`
|
||||
- NumPy arrays which do not have a boolean or numerical dtype
|
||||
- `scipy.sparse` arrays
|
||||
|
||||
3. Coerce array-likes to NumPy arrays and check their dtype.
|
||||
Note that non-scalar array-likes can't be mixed with non-NumPy Array
|
||||
API objects; e.g.
|
||||
|
||||
- `array_namespace([1, 2])` returns NumPy namespace;
|
||||
- `array_namespace(np.asarray([1, 2], [3, 4])` returns NumPy namespace;
|
||||
- `array_namespace(cp.asarray([1, 2], [3, 4])` raises an error.
|
||||
"""
|
||||
if not SCIPY_ARRAY_API:
|
||||
# here we could wrap the namespace if needed
|
||||
return np_compat
|
||||
|
||||
numpy_arrays = []
|
||||
api_arrays = []
|
||||
|
||||
for array in arrays:
|
||||
arr_info = _validate_array_cls(type(array))
|
||||
if arr_info is _ArrayClsInfo.skip:
|
||||
pass
|
||||
|
||||
elif arr_info is _ArrayClsInfo.numpy:
|
||||
if array.dtype.kind in 'iufcb': # Numeric or bool
|
||||
numpy_arrays.append(array)
|
||||
elif array.dtype.kind == 'V' and is_jax_array(array):
|
||||
# Special case for JAX zero gradient arrays;
|
||||
# see array_api_compat._common._helpers._is_jax_zero_gradient_array
|
||||
api_arrays.append(array) # JAX zero gradient array
|
||||
else:
|
||||
raise TypeError(f"An argument has dtype `{array.dtype!r}`; "
|
||||
"only boolean and numerical dtypes are supported.")
|
||||
|
||||
elif arr_info is _ArrayClsInfo.unknown and is_array_api_obj(array):
|
||||
api_arrays.append(array)
|
||||
|
||||
else:
|
||||
# list, tuple, or arbitrary object
|
||||
try:
|
||||
array = np.asanyarray(array)
|
||||
except TypeError:
|
||||
raise TypeError("An argument is neither array API compatible nor "
|
||||
"coercible by NumPy.")
|
||||
if array.dtype.kind not in 'iufcb': # Numeric or bool
|
||||
raise TypeError(f"An argument has dtype `{array.dtype!r}`; "
|
||||
"only boolean and numerical dtypes are supported.")
|
||||
numpy_arrays.append(array)
|
||||
|
||||
# When there are exclusively NumPy and ArrayLikes, skip calling
|
||||
# array_api_compat.array_namespace for performance.
|
||||
if not api_arrays:
|
||||
return np_compat
|
||||
|
||||
# In case of mix of NumPy/ArrayLike and non-NumPy Array API arrays,
|
||||
# let array_api_compat.array_namespace raise an error.
|
||||
return array_api_compat.array_namespace(*numpy_arrays, *api_arrays)
|
||||
@@ -0,0 +1,229 @@
|
||||
import sys as _sys
|
||||
from keyword import iskeyword as _iskeyword
|
||||
|
||||
|
||||
def _validate_names(typename, field_names, extra_field_names):
|
||||
"""
|
||||
Ensure that all the given names are valid Python identifiers that
|
||||
do not start with '_'. Also check that there are no duplicates
|
||||
among field_names + extra_field_names.
|
||||
"""
|
||||
for name in [typename] + field_names + extra_field_names:
|
||||
if not isinstance(name, str):
|
||||
raise TypeError('typename and all field names must be strings')
|
||||
if not name.isidentifier():
|
||||
raise ValueError('typename and all field names must be valid '
|
||||
f'identifiers: {name!r}')
|
||||
if _iskeyword(name):
|
||||
raise ValueError('typename and all field names cannot be a '
|
||||
f'keyword: {name!r}')
|
||||
|
||||
seen = set()
|
||||
for name in field_names + extra_field_names:
|
||||
if name.startswith('_'):
|
||||
raise ValueError('Field names cannot start with an underscore: '
|
||||
f'{name!r}')
|
||||
if name in seen:
|
||||
raise ValueError(f'Duplicate field name: {name!r}')
|
||||
seen.add(name)
|
||||
|
||||
|
||||
# Note: This code is adapted from CPython:Lib/collections/__init__.py
|
||||
def _make_tuple_bunch(typename, field_names, extra_field_names=None,
|
||||
module=None):
|
||||
"""
|
||||
Create a namedtuple-like class with additional attributes.
|
||||
|
||||
This function creates a subclass of tuple that acts like a namedtuple
|
||||
and that has additional attributes.
|
||||
|
||||
The additional attributes are listed in `extra_field_names`. The
|
||||
values assigned to these attributes are not part of the tuple.
|
||||
|
||||
The reason this function exists is to allow functions in SciPy
|
||||
that currently return a tuple or a namedtuple to returned objects
|
||||
that have additional attributes, while maintaining backwards
|
||||
compatibility.
|
||||
|
||||
This should only be used to enhance *existing* functions in SciPy.
|
||||
New functions are free to create objects as return values without
|
||||
having to maintain backwards compatibility with an old tuple or
|
||||
namedtuple return value.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
typename : str
|
||||
The name of the type.
|
||||
field_names : list of str
|
||||
List of names of the values to be stored in the tuple. These names
|
||||
will also be attributes of instances, so the values in the tuple
|
||||
can be accessed by indexing or as attributes. At least one name
|
||||
is required. See the Notes for additional restrictions.
|
||||
extra_field_names : list of str, optional
|
||||
List of names of values that will be stored as attributes of the
|
||||
object. See the notes for additional restrictions.
|
||||
|
||||
Returns
|
||||
-------
|
||||
cls : type
|
||||
The new class.
|
||||
|
||||
Notes
|
||||
-----
|
||||
There are restrictions on the names that may be used in `field_names`
|
||||
and `extra_field_names`:
|
||||
|
||||
* The names must be unique--no duplicates allowed.
|
||||
* The names must be valid Python identifiers, and must not begin with
|
||||
an underscore.
|
||||
* The names must not be Python keywords (e.g. 'def', 'and', etc., are
|
||||
not allowed).
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from scipy._lib._bunch import _make_tuple_bunch
|
||||
|
||||
Create a class that acts like a namedtuple with length 2 (with field
|
||||
names `x` and `y`) that will also have the attributes `w` and `beta`:
|
||||
|
||||
>>> Result = _make_tuple_bunch('Result', ['x', 'y'], ['w', 'beta'])
|
||||
|
||||
`Result` is the new class. We call it with keyword arguments to create
|
||||
a new instance with given values.
|
||||
|
||||
>>> result1 = Result(x=1, y=2, w=99, beta=0.5)
|
||||
>>> result1
|
||||
Result(x=1, y=2, w=99, beta=0.5)
|
||||
|
||||
`result1` acts like a tuple of length 2:
|
||||
|
||||
>>> len(result1)
|
||||
2
|
||||
>>> result1[:]
|
||||
(1, 2)
|
||||
|
||||
The values assigned when the instance was created are available as
|
||||
attributes:
|
||||
|
||||
>>> result1.y
|
||||
2
|
||||
>>> result1.beta
|
||||
0.5
|
||||
"""
|
||||
if len(field_names) == 0:
|
||||
raise ValueError('field_names must contain at least one name')
|
||||
|
||||
if extra_field_names is None:
|
||||
extra_field_names = []
|
||||
_validate_names(typename, field_names, extra_field_names)
|
||||
|
||||
typename = _sys.intern(str(typename))
|
||||
field_names = tuple(map(_sys.intern, field_names))
|
||||
extra_field_names = tuple(map(_sys.intern, extra_field_names))
|
||||
|
||||
all_names = field_names + extra_field_names
|
||||
arg_list = ', '.join(field_names)
|
||||
full_list = ', '.join(all_names)
|
||||
repr_fmt = ''.join(('(',
|
||||
', '.join(f'{name}=%({name})r' for name in all_names),
|
||||
')'))
|
||||
tuple_new = tuple.__new__
|
||||
_dict, _tuple, _zip = dict, tuple, zip
|
||||
|
||||
# Create all the named tuple methods to be added to the class namespace
|
||||
|
||||
s = f"""\
|
||||
def __new__(_cls, {arg_list}, **extra_fields):
|
||||
return _tuple_new(_cls, ({arg_list},))
|
||||
|
||||
def __init__(self, {arg_list}, **extra_fields):
|
||||
for key in self._extra_fields:
|
||||
if key not in extra_fields:
|
||||
raise TypeError("missing keyword argument '%s'" % (key,))
|
||||
for key, val in extra_fields.items():
|
||||
if key not in self._extra_fields:
|
||||
raise TypeError("unexpected keyword argument '%s'" % (key,))
|
||||
self.__dict__[key] = val
|
||||
|
||||
def __setattr__(self, key, val):
|
||||
if key in {repr(field_names)}:
|
||||
raise AttributeError("can't set attribute %r of class %r"
|
||||
% (key, self.__class__.__name__))
|
||||
else:
|
||||
self.__dict__[key] = val
|
||||
"""
|
||||
del arg_list
|
||||
namespace = {'_tuple_new': tuple_new,
|
||||
'__builtins__': dict(TypeError=TypeError,
|
||||
AttributeError=AttributeError),
|
||||
'__name__': f'namedtuple_{typename}'}
|
||||
exec(s, namespace)
|
||||
__new__ = namespace['__new__']
|
||||
__new__.__doc__ = f'Create new instance of {typename}({full_list})'
|
||||
__init__ = namespace['__init__']
|
||||
__init__.__doc__ = f'Instantiate instance of {typename}({full_list})'
|
||||
__setattr__ = namespace['__setattr__']
|
||||
|
||||
def __repr__(self):
|
||||
'Return a nicely formatted representation string'
|
||||
return self.__class__.__name__ + repr_fmt % self._asdict()
|
||||
|
||||
def _asdict(self):
|
||||
'Return a new dict which maps field names to their values.'
|
||||
out = _dict(_zip(self._fields, self))
|
||||
out.update(self.__dict__)
|
||||
return out
|
||||
|
||||
def __getnewargs_ex__(self):
|
||||
'Return self as a plain tuple. Used by copy and pickle.'
|
||||
return _tuple(self), self.__dict__
|
||||
|
||||
# Modify function metadata to help with introspection and debugging
|
||||
for method in (__new__, __repr__, _asdict, __getnewargs_ex__):
|
||||
method.__qualname__ = f'{typename}.{method.__name__}'
|
||||
|
||||
# Build-up the class namespace dictionary
|
||||
# and use type() to build the result class
|
||||
class_namespace = {
|
||||
'__doc__': f'{typename}({full_list})',
|
||||
'_fields': field_names,
|
||||
'__new__': __new__,
|
||||
'__init__': __init__,
|
||||
'__repr__': __repr__,
|
||||
'__setattr__': __setattr__,
|
||||
'_asdict': _asdict,
|
||||
'_extra_fields': extra_field_names,
|
||||
'__getnewargs_ex__': __getnewargs_ex__,
|
||||
# _field_defaults and _replace are added to get Polars to detect
|
||||
# a bunch object as a namedtuple. See gh-22450
|
||||
'_field_defaults': {},
|
||||
'_replace': None,
|
||||
}
|
||||
for index, name in enumerate(field_names):
|
||||
|
||||
def _get(self, index=index):
|
||||
return self[index]
|
||||
class_namespace[name] = property(_get)
|
||||
for name in extra_field_names:
|
||||
|
||||
def _get(self, name=name):
|
||||
return self.__dict__[name]
|
||||
class_namespace[name] = property(_get)
|
||||
|
||||
result = type(typename, (tuple,), class_namespace)
|
||||
|
||||
# For pickling to work, the __module__ variable needs to be set to the
|
||||
# frame where the named tuple is created. Bypass this step in environments
|
||||
# where sys._getframe is not defined (Jython for example) or sys._getframe
|
||||
# is not defined for arguments greater than 0 (IronPython), or where the
|
||||
# user has specified a particular module.
|
||||
if module is None:
|
||||
try:
|
||||
module = _sys._getframe(1).f_globals.get('__name__', '__main__')
|
||||
except (AttributeError, ValueError):
|
||||
pass
|
||||
if module is not None:
|
||||
result.__module__ = module
|
||||
__new__.__module__ = module
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,251 @@
|
||||
from . import _ccallback_c
|
||||
|
||||
import ctypes
|
||||
|
||||
PyCFuncPtr = ctypes.CFUNCTYPE(ctypes.c_void_p).__bases__[0]
|
||||
|
||||
ffi = None
|
||||
|
||||
class CData:
|
||||
pass
|
||||
|
||||
def _import_cffi():
|
||||
global ffi, CData
|
||||
|
||||
if ffi is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
import cffi
|
||||
ffi = cffi.FFI()
|
||||
CData = ffi.CData
|
||||
except ImportError:
|
||||
ffi = False
|
||||
|
||||
|
||||
class LowLevelCallable(tuple):
|
||||
"""
|
||||
Low-level callback function.
|
||||
|
||||
Some functions in SciPy take as arguments callback functions, which
|
||||
can either be python callables or low-level compiled functions. Using
|
||||
compiled callback functions can improve performance somewhat by
|
||||
avoiding wrapping data in Python objects.
|
||||
|
||||
Such low-level functions in SciPy are wrapped in `LowLevelCallable`
|
||||
objects, which can be constructed from function pointers obtained from
|
||||
ctypes, cffi, Cython, or contained in Python `PyCapsule` objects.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Functions accepting low-level callables:
|
||||
|
||||
`scipy.integrate.quad`, `scipy.ndimage.generic_filter`,
|
||||
`scipy.ndimage.generic_filter1d`, `scipy.ndimage.geometric_transform`
|
||||
|
||||
Usage examples:
|
||||
|
||||
:ref:`ndimage-ccallbacks`, :ref:`quad-callbacks`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
function : {PyCapsule, ctypes function pointer, cffi function pointer}
|
||||
Low-level callback function.
|
||||
user_data : {PyCapsule, ctypes void pointer, cffi void pointer}
|
||||
User data to pass on to the callback function.
|
||||
signature : str, optional
|
||||
Signature of the function. If omitted, determined from *function*,
|
||||
if possible.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
function
|
||||
Callback function given.
|
||||
user_data
|
||||
User data given.
|
||||
signature
|
||||
Signature of the function.
|
||||
|
||||
Methods
|
||||
-------
|
||||
from_cython
|
||||
Class method for constructing callables from Cython C-exported
|
||||
functions.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The argument ``function`` can be one of:
|
||||
|
||||
- PyCapsule, whose name contains the C function signature
|
||||
- ctypes function pointer
|
||||
- cffi function pointer
|
||||
|
||||
The signature of the low-level callback must match one of those expected
|
||||
by the routine it is passed to.
|
||||
|
||||
If constructing low-level functions from a PyCapsule, the name of the
|
||||
capsule must be the corresponding signature, in the format::
|
||||
|
||||
return_type (arg1_type, arg2_type, ...)
|
||||
|
||||
For example::
|
||||
|
||||
"void (double)"
|
||||
"double (double, int *, void *)"
|
||||
|
||||
The context of a PyCapsule passed in as ``function`` is used as ``user_data``,
|
||||
if an explicit value for ``user_data`` was not given.
|
||||
|
||||
"""
|
||||
|
||||
# Make the class immutable
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, function, user_data=None, signature=None):
|
||||
# We need to hold a reference to the function & user data,
|
||||
# to prevent them going out of scope
|
||||
item = cls._parse_callback(function, user_data, signature)
|
||||
return tuple.__new__(cls, (item, function, user_data))
|
||||
|
||||
def __repr__(self):
|
||||
return f"LowLevelCallable({self.function!r}, {self.user_data!r})"
|
||||
|
||||
@property
|
||||
def function(self):
|
||||
return tuple.__getitem__(self, 1)
|
||||
|
||||
@property
|
||||
def user_data(self):
|
||||
return tuple.__getitem__(self, 2)
|
||||
|
||||
@property
|
||||
def signature(self):
|
||||
return _ccallback_c.get_capsule_signature(tuple.__getitem__(self, 0))
|
||||
|
||||
def __getitem__(self, idx):
|
||||
raise ValueError()
|
||||
|
||||
@classmethod
|
||||
def from_cython(cls, module, name, user_data=None, signature=None):
|
||||
"""
|
||||
Create a low-level callback function from an exported Cython function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
module : module
|
||||
Cython module where the exported function resides
|
||||
name : str
|
||||
Name of the exported function
|
||||
user_data : {PyCapsule, ctypes void pointer, cffi void pointer}, optional
|
||||
User data to pass on to the callback function.
|
||||
signature : str, optional
|
||||
Signature of the function. If omitted, determined from *function*.
|
||||
|
||||
"""
|
||||
try:
|
||||
function = module.__pyx_capi__[name]
|
||||
except AttributeError as e:
|
||||
message = "Given module is not a Cython module with __pyx_capi__ attribute"
|
||||
raise ValueError(message) from e
|
||||
except KeyError as e:
|
||||
message = f"No function {name!r} found in __pyx_capi__ of the module"
|
||||
raise ValueError(message) from e
|
||||
return cls(function, user_data, signature)
|
||||
|
||||
@classmethod
|
||||
def _parse_callback(cls, obj, user_data=None, signature=None):
|
||||
_import_cffi()
|
||||
|
||||
if isinstance(obj, LowLevelCallable):
|
||||
func = tuple.__getitem__(obj, 0)
|
||||
elif isinstance(obj, PyCFuncPtr):
|
||||
func, signature = _get_ctypes_func(obj, signature)
|
||||
elif isinstance(obj, CData):
|
||||
func, signature = _get_cffi_func(obj, signature)
|
||||
elif _ccallback_c.check_capsule(obj):
|
||||
func = obj
|
||||
else:
|
||||
raise ValueError("Given input is not a callable or a "
|
||||
"low-level callable (pycapsule/ctypes/cffi)")
|
||||
|
||||
if isinstance(user_data, ctypes.c_void_p):
|
||||
context = _get_ctypes_data(user_data)
|
||||
elif isinstance(user_data, CData):
|
||||
context = _get_cffi_data(user_data)
|
||||
elif user_data is None:
|
||||
context = 0
|
||||
elif _ccallback_c.check_capsule(user_data):
|
||||
context = user_data
|
||||
else:
|
||||
raise ValueError("Given user data is not a valid "
|
||||
"low-level void* pointer (pycapsule/ctypes/cffi)")
|
||||
|
||||
return _ccallback_c.get_raw_capsule(func, signature, context)
|
||||
|
||||
|
||||
#
|
||||
# ctypes helpers
|
||||
#
|
||||
|
||||
def _get_ctypes_func(func, signature=None):
|
||||
# Get function pointer
|
||||
func_ptr = ctypes.cast(func, ctypes.c_void_p).value
|
||||
|
||||
# Construct function signature
|
||||
if signature is None:
|
||||
signature = _typename_from_ctypes(func.restype) + " ("
|
||||
for j, arg in enumerate(func.argtypes):
|
||||
if j == 0:
|
||||
signature += _typename_from_ctypes(arg)
|
||||
else:
|
||||
signature += ", " + _typename_from_ctypes(arg)
|
||||
signature += ")"
|
||||
|
||||
return func_ptr, signature
|
||||
|
||||
|
||||
def _typename_from_ctypes(item):
|
||||
if item is None:
|
||||
return "void"
|
||||
elif item is ctypes.c_void_p:
|
||||
return "void *"
|
||||
|
||||
name = item.__name__
|
||||
|
||||
pointer_level = 0
|
||||
while name.startswith("LP_"):
|
||||
pointer_level += 1
|
||||
name = name[3:]
|
||||
|
||||
if name.startswith('c_'):
|
||||
name = name[2:]
|
||||
|
||||
if pointer_level > 0:
|
||||
name += " " + "*"*pointer_level
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def _get_ctypes_data(data):
|
||||
# Get voidp pointer
|
||||
return ctypes.cast(data, ctypes.c_void_p).value
|
||||
|
||||
|
||||
#
|
||||
# CFFI helpers
|
||||
#
|
||||
|
||||
def _get_cffi_func(func, signature=None):
|
||||
# Get function pointer
|
||||
func_ptr = ffi.cast('uintptr_t', func)
|
||||
|
||||
# Get signature
|
||||
if signature is None:
|
||||
signature = ffi.getctype(ffi.typeof(func)).replace('(*)', ' ')
|
||||
|
||||
return func_ptr, signature
|
||||
|
||||
|
||||
def _get_cffi_data(data):
|
||||
# Get pointer
|
||||
return ffi.cast('uintptr_t', data)
|
||||
Binary file not shown.
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Disjoint set data structure
|
||||
"""
|
||||
|
||||
|
||||
class DisjointSet:
|
||||
""" Disjoint set data structure for incremental connectivity queries.
|
||||
|
||||
.. versionadded:: 1.6.0
|
||||
|
||||
Attributes
|
||||
----------
|
||||
n_subsets : int
|
||||
The number of subsets.
|
||||
|
||||
Methods
|
||||
-------
|
||||
add
|
||||
merge
|
||||
connected
|
||||
subset
|
||||
subset_size
|
||||
subsets
|
||||
__getitem__
|
||||
|
||||
Notes
|
||||
-----
|
||||
This class implements the disjoint set [1]_, also known as the *union-find*
|
||||
or *merge-find* data structure. The *find* operation (implemented in
|
||||
`__getitem__`) implements the *path halving* variant. The *merge* method
|
||||
implements the *merge by size* variant.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] https://en.wikipedia.org/wiki/Disjoint-set_data_structure
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from scipy.cluster.hierarchy import DisjointSet
|
||||
|
||||
Initialize a disjoint set:
|
||||
|
||||
>>> disjoint_set = DisjointSet([1, 2, 3, 'a', 'b'])
|
||||
|
||||
Merge some subsets:
|
||||
|
||||
>>> disjoint_set.merge(1, 2)
|
||||
True
|
||||
>>> disjoint_set.merge(3, 'a')
|
||||
True
|
||||
>>> disjoint_set.merge('a', 'b')
|
||||
True
|
||||
>>> disjoint_set.merge('b', 'b')
|
||||
False
|
||||
|
||||
Find root elements:
|
||||
|
||||
>>> disjoint_set[2]
|
||||
1
|
||||
>>> disjoint_set['b']
|
||||
3
|
||||
|
||||
Test connectivity:
|
||||
|
||||
>>> disjoint_set.connected(1, 2)
|
||||
True
|
||||
>>> disjoint_set.connected(1, 'b')
|
||||
False
|
||||
|
||||
List elements in disjoint set:
|
||||
|
||||
>>> list(disjoint_set)
|
||||
[1, 2, 3, 'a', 'b']
|
||||
|
||||
Get the subset containing 'a':
|
||||
|
||||
>>> disjoint_set.subset('a')
|
||||
{'a', 3, 'b'}
|
||||
|
||||
Get the size of the subset containing 'a' (without actually instantiating
|
||||
the subset):
|
||||
|
||||
>>> disjoint_set.subset_size('a')
|
||||
3
|
||||
|
||||
Get all subsets in the disjoint set:
|
||||
|
||||
>>> disjoint_set.subsets()
|
||||
[{1, 2}, {'a', 3, 'b'}]
|
||||
"""
|
||||
def __init__(self, elements=None):
|
||||
self.n_subsets = 0
|
||||
self._sizes = {}
|
||||
self._parents = {}
|
||||
# _nbrs is a circular linked list which links connected elements.
|
||||
self._nbrs = {}
|
||||
# _indices tracks the element insertion order in `__iter__`.
|
||||
self._indices = {}
|
||||
if elements is not None:
|
||||
for x in elements:
|
||||
self.add(x)
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator of the elements in the disjoint set.
|
||||
|
||||
Elements are ordered by insertion order.
|
||||
"""
|
||||
return iter(self._indices)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._indices)
|
||||
|
||||
def __contains__(self, x):
|
||||
return x in self._indices
|
||||
|
||||
def __getitem__(self, x):
|
||||
"""Find the root element of `x`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : hashable object
|
||||
Input element.
|
||||
|
||||
Returns
|
||||
-------
|
||||
root : hashable object
|
||||
Root element of `x`.
|
||||
"""
|
||||
if x not in self._indices:
|
||||
raise KeyError(x)
|
||||
|
||||
# find by "path halving"
|
||||
parents = self._parents
|
||||
while self._indices[x] != self._indices[parents[x]]:
|
||||
parents[x] = parents[parents[x]]
|
||||
x = parents[x]
|
||||
return x
|
||||
|
||||
def add(self, x):
|
||||
"""Add element `x` to disjoint set
|
||||
"""
|
||||
if x in self._indices:
|
||||
return
|
||||
|
||||
self._sizes[x] = 1
|
||||
self._parents[x] = x
|
||||
self._nbrs[x] = x
|
||||
self._indices[x] = len(self._indices)
|
||||
self.n_subsets += 1
|
||||
|
||||
def merge(self, x, y):
|
||||
"""Merge the subsets of `x` and `y`.
|
||||
|
||||
The smaller subset (the child) is merged into the larger subset (the
|
||||
parent). If the subsets are of equal size, the root element which was
|
||||
first inserted into the disjoint set is selected as the parent.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x, y : hashable object
|
||||
Elements to merge.
|
||||
|
||||
Returns
|
||||
-------
|
||||
merged : bool
|
||||
True if `x` and `y` were in disjoint sets, False otherwise.
|
||||
"""
|
||||
xr = self[x]
|
||||
yr = self[y]
|
||||
if self._indices[xr] == self._indices[yr]:
|
||||
return False
|
||||
|
||||
sizes = self._sizes
|
||||
if (sizes[xr], self._indices[yr]) < (sizes[yr], self._indices[xr]):
|
||||
xr, yr = yr, xr
|
||||
self._parents[yr] = xr
|
||||
self._sizes[xr] += self._sizes[yr]
|
||||
self._nbrs[xr], self._nbrs[yr] = self._nbrs[yr], self._nbrs[xr]
|
||||
self.n_subsets -= 1
|
||||
return True
|
||||
|
||||
def connected(self, x, y):
|
||||
"""Test whether `x` and `y` are in the same subset.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x, y : hashable object
|
||||
Elements to test.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : bool
|
||||
True if `x` and `y` are in the same set, False otherwise.
|
||||
"""
|
||||
return self._indices[self[x]] == self._indices[self[y]]
|
||||
|
||||
def subset(self, x):
|
||||
"""Get the subset containing `x`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : hashable object
|
||||
Input element.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : set
|
||||
Subset containing `x`.
|
||||
"""
|
||||
if x not in self._indices:
|
||||
raise KeyError(x)
|
||||
|
||||
result = [x]
|
||||
nxt = self._nbrs[x]
|
||||
while self._indices[nxt] != self._indices[x]:
|
||||
result.append(nxt)
|
||||
nxt = self._nbrs[nxt]
|
||||
return set(result)
|
||||
|
||||
def subset_size(self, x):
|
||||
"""Get the size of the subset containing `x`.
|
||||
|
||||
Note that this method is faster than ``len(self.subset(x))`` because
|
||||
the size is directly read off an internal field, without the need to
|
||||
instantiate the full subset.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : hashable object
|
||||
Input element.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : int
|
||||
Size of the subset containing `x`.
|
||||
"""
|
||||
return self._sizes[self[x]]
|
||||
|
||||
def subsets(self):
|
||||
"""Get all the subsets in the disjoint set.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : list
|
||||
Subsets in the disjoint set.
|
||||
"""
|
||||
result = []
|
||||
visited = set()
|
||||
for x in self:
|
||||
if x not in visited:
|
||||
xset = self.subset(x)
|
||||
visited.update(xset)
|
||||
result.append(xset)
|
||||
return result
|
||||
@@ -0,0 +1,761 @@
|
||||
# copied from numpydoc/docscrape.py, commit 97a6026508e0dd5382865672e9563a72cc113bd2
|
||||
"""Extract reference documentation from the NumPy source tree."""
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import pydoc
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
from collections import namedtuple
|
||||
from collections.abc import Callable, Mapping
|
||||
from functools import cached_property
|
||||
from warnings import warn
|
||||
|
||||
|
||||
def strip_blank_lines(l):
|
||||
"Remove leading and trailing blank lines from a list of lines"
|
||||
while l and not l[0].strip():
|
||||
del l[0]
|
||||
while l and not l[-1].strip():
|
||||
del l[-1]
|
||||
return l
|
||||
|
||||
|
||||
class Reader:
|
||||
"""A line-based string reader."""
|
||||
|
||||
def __init__(self, data):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
data : str
|
||||
String with lines separated by '\\n'.
|
||||
|
||||
"""
|
||||
if isinstance(data, list):
|
||||
self._str = data
|
||||
else:
|
||||
self._str = data.split("\n") # store string as list of lines
|
||||
|
||||
self.reset()
|
||||
|
||||
def __getitem__(self, n):
|
||||
return self._str[n]
|
||||
|
||||
def reset(self):
|
||||
self._l = 0 # current line nr
|
||||
|
||||
def read(self):
|
||||
if not self.eof():
|
||||
out = self[self._l]
|
||||
self._l += 1
|
||||
return out
|
||||
else:
|
||||
return ""
|
||||
|
||||
def seek_next_non_empty_line(self):
|
||||
for l in self[self._l :]:
|
||||
if l.strip():
|
||||
break
|
||||
else:
|
||||
self._l += 1
|
||||
|
||||
def eof(self):
|
||||
return self._l >= len(self._str)
|
||||
|
||||
def read_to_condition(self, condition_func):
|
||||
start = self._l
|
||||
for line in self[start:]:
|
||||
if condition_func(line):
|
||||
return self[start : self._l]
|
||||
self._l += 1
|
||||
if self.eof():
|
||||
return self[start : self._l + 1]
|
||||
return []
|
||||
|
||||
def read_to_next_empty_line(self):
|
||||
self.seek_next_non_empty_line()
|
||||
|
||||
def is_empty(line):
|
||||
return not line.strip()
|
||||
|
||||
return self.read_to_condition(is_empty)
|
||||
|
||||
def read_to_next_unindented_line(self):
|
||||
def is_unindented(line):
|
||||
return line.strip() and (len(line.lstrip()) == len(line))
|
||||
|
||||
return self.read_to_condition(is_unindented)
|
||||
|
||||
def peek(self, n=0):
|
||||
if self._l + n < len(self._str):
|
||||
return self[self._l + n]
|
||||
else:
|
||||
return ""
|
||||
|
||||
def is_empty(self):
|
||||
return not "".join(self._str).strip()
|
||||
|
||||
|
||||
class ParseError(Exception):
|
||||
def __str__(self):
|
||||
message = self.args[0]
|
||||
if hasattr(self, "docstring"):
|
||||
message = f"{message} in {self.docstring!r}"
|
||||
return message
|
||||
|
||||
|
||||
Parameter = namedtuple("Parameter", ["name", "type", "desc"])
|
||||
|
||||
|
||||
class NumpyDocString(Mapping):
|
||||
"""Parses a numpydoc string to an abstract representation
|
||||
|
||||
Instances define a mapping from section title to structured data.
|
||||
|
||||
"""
|
||||
|
||||
sections = {
|
||||
"Signature": "",
|
||||
"Summary": [""],
|
||||
"Extended Summary": [],
|
||||
"Parameters": [],
|
||||
"Attributes": [],
|
||||
"Methods": [],
|
||||
"Returns": [],
|
||||
"Yields": [],
|
||||
"Receives": [],
|
||||
"Other Parameters": [],
|
||||
"Raises": [],
|
||||
"Warns": [],
|
||||
"Warnings": [],
|
||||
"See Also": [],
|
||||
"Notes": [],
|
||||
"References": "",
|
||||
"Examples": "",
|
||||
"index": {},
|
||||
}
|
||||
|
||||
def __init__(self, docstring, config=None):
|
||||
orig_docstring = docstring
|
||||
docstring = textwrap.dedent(docstring).split("\n")
|
||||
|
||||
self._doc = Reader(docstring)
|
||||
self._parsed_data = copy.deepcopy(self.sections)
|
||||
|
||||
try:
|
||||
self._parse()
|
||||
except ParseError as e:
|
||||
e.docstring = orig_docstring
|
||||
raise
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._parsed_data[key]
|
||||
|
||||
def __setitem__(self, key, val):
|
||||
if key not in self._parsed_data:
|
||||
self._error_location(f"Unknown section {key}", error=False)
|
||||
else:
|
||||
self._parsed_data[key] = val
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._parsed_data)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._parsed_data)
|
||||
|
||||
def _is_at_section(self):
|
||||
self._doc.seek_next_non_empty_line()
|
||||
|
||||
if self._doc.eof():
|
||||
return False
|
||||
|
||||
l1 = self._doc.peek().strip() # e.g. Parameters
|
||||
|
||||
if l1.startswith(".. index::"):
|
||||
return True
|
||||
|
||||
l2 = self._doc.peek(1).strip() # ---------- or ==========
|
||||
if len(l2) >= 3 and (set(l2) in ({"-"}, {"="})) and len(l2) != len(l1):
|
||||
snip = "\n".join(self._doc._str[:2]) + "..."
|
||||
self._error_location(
|
||||
f"potentially wrong underline length... \n{l1} \n{l2} in \n{snip}",
|
||||
error=False,
|
||||
)
|
||||
return l2.startswith("-" * len(l1)) or l2.startswith("=" * len(l1))
|
||||
|
||||
def _strip(self, doc):
|
||||
i = 0
|
||||
j = 0
|
||||
for i, line in enumerate(doc):
|
||||
if line.strip():
|
||||
break
|
||||
|
||||
for j, line in enumerate(doc[::-1]):
|
||||
if line.strip():
|
||||
break
|
||||
|
||||
return doc[i : len(doc) - j]
|
||||
|
||||
def _read_to_next_section(self):
|
||||
section = self._doc.read_to_next_empty_line()
|
||||
|
||||
while not self._is_at_section() and not self._doc.eof():
|
||||
if not self._doc.peek(-1).strip(): # previous line was empty
|
||||
section += [""]
|
||||
|
||||
section += self._doc.read_to_next_empty_line()
|
||||
|
||||
return section
|
||||
|
||||
def _read_sections(self):
|
||||
while not self._doc.eof():
|
||||
data = self._read_to_next_section()
|
||||
name = data[0].strip()
|
||||
|
||||
if name.startswith(".."): # index section
|
||||
yield name, data[1:]
|
||||
elif len(data) < 2:
|
||||
yield StopIteration
|
||||
else:
|
||||
yield name, self._strip(data[2:])
|
||||
|
||||
def _parse_param_list(self, content, single_element_is_type=False):
|
||||
content = dedent_lines(content)
|
||||
r = Reader(content)
|
||||
params = []
|
||||
while not r.eof():
|
||||
header = r.read().strip()
|
||||
if " : " in header:
|
||||
arg_name, arg_type = header.split(" : ", maxsplit=1)
|
||||
else:
|
||||
# NOTE: param line with single element should never have a
|
||||
# a " :" before the description line, so this should probably
|
||||
# warn.
|
||||
if header.endswith(" :"):
|
||||
header = header[:-2]
|
||||
if single_element_is_type:
|
||||
arg_name, arg_type = "", header
|
||||
else:
|
||||
arg_name, arg_type = header, ""
|
||||
|
||||
desc = r.read_to_next_unindented_line()
|
||||
desc = dedent_lines(desc)
|
||||
desc = strip_blank_lines(desc)
|
||||
|
||||
params.append(Parameter(arg_name, arg_type, desc))
|
||||
|
||||
return params
|
||||
|
||||
# See also supports the following formats.
|
||||
#
|
||||
# <FUNCNAME>
|
||||
# <FUNCNAME> SPACE* COLON SPACE+ <DESC> SPACE*
|
||||
# <FUNCNAME> ( COMMA SPACE+ <FUNCNAME>)+ (COMMA | PERIOD)? SPACE*
|
||||
# <FUNCNAME> ( COMMA SPACE+ <FUNCNAME>)* SPACE* COLON SPACE+ <DESC> SPACE*
|
||||
|
||||
# <FUNCNAME> is one of
|
||||
# <PLAIN_FUNCNAME>
|
||||
# COLON <ROLE> COLON BACKTICK <PLAIN_FUNCNAME> BACKTICK
|
||||
# where
|
||||
# <PLAIN_FUNCNAME> is a legal function name, and
|
||||
# <ROLE> is any nonempty sequence of word characters.
|
||||
# Examples: func_f1 :meth:`func_h1` :obj:`~baz.obj_r` :class:`class_j`
|
||||
# <DESC> is a string describing the function.
|
||||
|
||||
_role = r":(?P<role>(py:)?\w+):"
|
||||
_funcbacktick = r"`(?P<name>(?:~\w+\.)?[a-zA-Z0-9_\.-]+)`"
|
||||
_funcplain = r"(?P<name2>[a-zA-Z0-9_\.-]+)"
|
||||
_funcname = r"(" + _role + _funcbacktick + r"|" + _funcplain + r")"
|
||||
_funcnamenext = _funcname.replace("role", "rolenext")
|
||||
_funcnamenext = _funcnamenext.replace("name", "namenext")
|
||||
_description = r"(?P<description>\s*:(\s+(?P<desc>\S+.*))?)?\s*$"
|
||||
_func_rgx = re.compile(r"^\s*" + _funcname + r"\s*")
|
||||
_line_rgx = re.compile(
|
||||
r"^\s*"
|
||||
+ r"(?P<allfuncs>"
|
||||
+ _funcname # group for all function names
|
||||
+ r"(?P<morefuncs>([,]\s+"
|
||||
+ _funcnamenext
|
||||
+ r")*)"
|
||||
+ r")"
|
||||
+ r"(?P<trailing>[,\.])?" # end of "allfuncs"
|
||||
+ _description # Some function lists have a trailing comma (or period) '\s*'
|
||||
)
|
||||
|
||||
# Empty <DESC> elements are replaced with '..'
|
||||
empty_description = ".."
|
||||
|
||||
def _parse_see_also(self, content):
|
||||
"""
|
||||
func_name : Descriptive text
|
||||
continued text
|
||||
another_func_name : Descriptive text
|
||||
func_name1, func_name2, :meth:`func_name`, func_name3
|
||||
|
||||
"""
|
||||
|
||||
content = dedent_lines(content)
|
||||
|
||||
items = []
|
||||
|
||||
def parse_item_name(text):
|
||||
"""Match ':role:`name`' or 'name'."""
|
||||
m = self._func_rgx.match(text)
|
||||
if not m:
|
||||
self._error_location(f"Error parsing See Also entry {line!r}")
|
||||
role = m.group("role")
|
||||
name = m.group("name") if role else m.group("name2")
|
||||
return name, role, m.end()
|
||||
|
||||
rest = []
|
||||
for line in content:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
line_match = self._line_rgx.match(line)
|
||||
description = None
|
||||
if line_match:
|
||||
description = line_match.group("desc")
|
||||
if line_match.group("trailing") and description:
|
||||
self._error_location(
|
||||
"Unexpected comma or period after function list at index %d of "
|
||||
'line "%s"' % (line_match.end("trailing"), line),
|
||||
error=False,
|
||||
)
|
||||
if not description and line.startswith(" "):
|
||||
rest.append(line.strip())
|
||||
elif line_match:
|
||||
funcs = []
|
||||
text = line_match.group("allfuncs")
|
||||
while True:
|
||||
if not text.strip():
|
||||
break
|
||||
name, role, match_end = parse_item_name(text)
|
||||
funcs.append((name, role))
|
||||
text = text[match_end:].strip()
|
||||
if text and text[0] == ",":
|
||||
text = text[1:].strip()
|
||||
rest = list(filter(None, [description]))
|
||||
items.append((funcs, rest))
|
||||
else:
|
||||
self._error_location(f"Error parsing See Also entry {line!r}")
|
||||
return items
|
||||
|
||||
def _parse_index(self, section, content):
|
||||
"""
|
||||
.. index:: default
|
||||
:refguide: something, else, and more
|
||||
|
||||
"""
|
||||
|
||||
def strip_each_in(lst):
|
||||
return [s.strip() for s in lst]
|
||||
|
||||
out = {}
|
||||
section = section.split("::")
|
||||
if len(section) > 1:
|
||||
out["default"] = strip_each_in(section[1].split(","))[0]
|
||||
for line in content:
|
||||
line = line.split(":")
|
||||
if len(line) > 2:
|
||||
out[line[1]] = strip_each_in(line[2].split(","))
|
||||
return out
|
||||
|
||||
def _parse_summary(self):
|
||||
"""Grab signature (if given) and summary"""
|
||||
if self._is_at_section():
|
||||
return
|
||||
|
||||
# If several signatures present, take the last one
|
||||
while True:
|
||||
summary = self._doc.read_to_next_empty_line()
|
||||
summary_str = " ".join([s.strip() for s in summary]).strip()
|
||||
compiled = re.compile(r"^([\w., ]+=)?\s*[\w\.]+\(.*\)$")
|
||||
if compiled.match(summary_str):
|
||||
self["Signature"] = summary_str
|
||||
if not self._is_at_section():
|
||||
continue
|
||||
break
|
||||
|
||||
if summary is not None:
|
||||
self["Summary"] = summary
|
||||
|
||||
if not self._is_at_section():
|
||||
self["Extended Summary"] = self._read_to_next_section()
|
||||
|
||||
def _parse(self):
|
||||
self._doc.reset()
|
||||
self._parse_summary()
|
||||
|
||||
sections = list(self._read_sections())
|
||||
section_names = {section for section, content in sections}
|
||||
|
||||
has_yields = "Yields" in section_names
|
||||
# We could do more tests, but we are not. Arbitrarily.
|
||||
if not has_yields and "Receives" in section_names:
|
||||
msg = "Docstring contains a Receives section but not Yields."
|
||||
raise ValueError(msg)
|
||||
|
||||
for section, content in sections:
|
||||
if not section.startswith(".."):
|
||||
section = (s.capitalize() for s in section.split(" "))
|
||||
section = " ".join(section)
|
||||
if self.get(section):
|
||||
self._error_location(
|
||||
"The section %s appears twice in %s"
|
||||
% (section, "\n".join(self._doc._str))
|
||||
)
|
||||
|
||||
if section in ("Parameters", "Other Parameters", "Attributes", "Methods"):
|
||||
self[section] = self._parse_param_list(content)
|
||||
elif section in ("Returns", "Yields", "Raises", "Warns", "Receives"):
|
||||
self[section] = self._parse_param_list(
|
||||
content, single_element_is_type=True
|
||||
)
|
||||
elif section.startswith(".. index::"):
|
||||
self["index"] = self._parse_index(section, content)
|
||||
elif section == "See Also":
|
||||
self["See Also"] = self._parse_see_also(content)
|
||||
else:
|
||||
self[section] = content
|
||||
|
||||
@property
|
||||
def _obj(self):
|
||||
if hasattr(self, "_cls"):
|
||||
return self._cls
|
||||
elif hasattr(self, "_f"):
|
||||
return self._f
|
||||
return None
|
||||
|
||||
def _error_location(self, msg, error=True):
|
||||
if self._obj is not None:
|
||||
# we know where the docs came from:
|
||||
try:
|
||||
filename = inspect.getsourcefile(self._obj)
|
||||
except TypeError:
|
||||
filename = None
|
||||
# Make UserWarning more descriptive via object introspection.
|
||||
# Skip if introspection fails
|
||||
name = getattr(self._obj, "__name__", None)
|
||||
if name is None:
|
||||
name = getattr(getattr(self._obj, "__class__", None), "__name__", None)
|
||||
if name is not None:
|
||||
msg += f" in the docstring of {name}"
|
||||
msg += f" in {filename}." if filename else ""
|
||||
if error:
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
warn(msg, stacklevel=3)
|
||||
|
||||
# string conversion routines
|
||||
|
||||
def _str_header(self, name, symbol="-"):
|
||||
return [name, len(name) * symbol]
|
||||
|
||||
def _str_indent(self, doc, indent=4):
|
||||
return [" " * indent + line for line in doc]
|
||||
|
||||
def _str_signature(self):
|
||||
if self["Signature"]:
|
||||
return [self["Signature"].replace("*", r"\*")] + [""]
|
||||
return [""]
|
||||
|
||||
def _str_summary(self):
|
||||
if self["Summary"]:
|
||||
return self["Summary"] + [""]
|
||||
return []
|
||||
|
||||
def _str_extended_summary(self):
|
||||
if self["Extended Summary"]:
|
||||
return self["Extended Summary"] + [""]
|
||||
return []
|
||||
|
||||
def _str_param_list(self, name):
|
||||
out = []
|
||||
if self[name]:
|
||||
out += self._str_header(name)
|
||||
for param in self[name]:
|
||||
parts = []
|
||||
if param.name:
|
||||
parts.append(param.name)
|
||||
if param.type:
|
||||
parts.append(param.type)
|
||||
out += [" : ".join(parts)]
|
||||
if param.desc and "".join(param.desc).strip():
|
||||
out += self._str_indent(param.desc)
|
||||
out += [""]
|
||||
return out
|
||||
|
||||
def _str_section(self, name):
|
||||
out = []
|
||||
if self[name]:
|
||||
out += self._str_header(name)
|
||||
out += self[name]
|
||||
out += [""]
|
||||
return out
|
||||
|
||||
def _str_see_also(self, func_role):
|
||||
if not self["See Also"]:
|
||||
return []
|
||||
out = []
|
||||
out += self._str_header("See Also")
|
||||
out += [""]
|
||||
last_had_desc = True
|
||||
for funcs, desc in self["See Also"]:
|
||||
assert isinstance(funcs, list)
|
||||
links = []
|
||||
for func, role in funcs:
|
||||
if role:
|
||||
link = f":{role}:`{func}`"
|
||||
elif func_role:
|
||||
link = f":{func_role}:`{func}`"
|
||||
else:
|
||||
link = f"`{func}`_"
|
||||
links.append(link)
|
||||
link = ", ".join(links)
|
||||
out += [link]
|
||||
if desc:
|
||||
out += self._str_indent([" ".join(desc)])
|
||||
last_had_desc = True
|
||||
else:
|
||||
last_had_desc = False
|
||||
out += self._str_indent([self.empty_description])
|
||||
|
||||
if last_had_desc:
|
||||
out += [""]
|
||||
out += [""]
|
||||
return out
|
||||
|
||||
def _str_index(self):
|
||||
idx = self["index"]
|
||||
out = []
|
||||
output_index = False
|
||||
default_index = idx.get("default", "")
|
||||
if default_index:
|
||||
output_index = True
|
||||
out += [f".. index:: {default_index}"]
|
||||
for section, references in idx.items():
|
||||
if section == "default":
|
||||
continue
|
||||
output_index = True
|
||||
out += [f" :{section}: {', '.join(references)}"]
|
||||
if output_index:
|
||||
return out
|
||||
return ""
|
||||
|
||||
def __str__(self, func_role=""):
|
||||
out = []
|
||||
out += self._str_signature()
|
||||
out += self._str_summary()
|
||||
out += self._str_extended_summary()
|
||||
out += self._str_param_list("Parameters")
|
||||
for param_list in ("Attributes", "Methods"):
|
||||
out += self._str_param_list(param_list)
|
||||
for param_list in (
|
||||
"Returns",
|
||||
"Yields",
|
||||
"Receives",
|
||||
"Other Parameters",
|
||||
"Raises",
|
||||
"Warns",
|
||||
):
|
||||
out += self._str_param_list(param_list)
|
||||
out += self._str_section("Warnings")
|
||||
out += self._str_see_also(func_role)
|
||||
for s in ("Notes", "References", "Examples"):
|
||||
out += self._str_section(s)
|
||||
out += self._str_index()
|
||||
return "\n".join(out)
|
||||
|
||||
|
||||
def dedent_lines(lines):
|
||||
"""Deindent a list of lines maximally"""
|
||||
return textwrap.dedent("\n".join(lines)).split("\n")
|
||||
|
||||
|
||||
class FunctionDoc(NumpyDocString):
|
||||
def __init__(self, func, role="func", doc=None, config=None):
|
||||
self._f = func
|
||||
self._role = role # e.g. "func" or "meth"
|
||||
|
||||
if doc is None:
|
||||
if func is None:
|
||||
raise ValueError("No function or docstring given")
|
||||
doc = inspect.getdoc(func) or ""
|
||||
if config is None:
|
||||
config = {}
|
||||
NumpyDocString.__init__(self, doc, config)
|
||||
|
||||
def get_func(self):
|
||||
func_name = getattr(self._f, "__name__", self.__class__.__name__)
|
||||
if inspect.isclass(self._f):
|
||||
func = getattr(self._f, "__call__", self._f.__init__)
|
||||
else:
|
||||
func = self._f
|
||||
return func, func_name
|
||||
|
||||
def __str__(self):
|
||||
out = ""
|
||||
|
||||
func, func_name = self.get_func()
|
||||
|
||||
roles = {"func": "function", "meth": "method"}
|
||||
|
||||
if self._role:
|
||||
if self._role not in roles:
|
||||
print(f"Warning: invalid role {self._role}")
|
||||
out += f".. {roles.get(self._role, '')}:: {func_name}\n \n\n"
|
||||
|
||||
out += super().__str__(func_role=self._role)
|
||||
return out
|
||||
|
||||
|
||||
class ObjDoc(NumpyDocString):
|
||||
def __init__(self, obj, doc=None, config=None):
|
||||
self._f = obj
|
||||
if config is None:
|
||||
config = {}
|
||||
NumpyDocString.__init__(self, doc, config=config)
|
||||
|
||||
|
||||
class ClassDoc(NumpyDocString):
|
||||
extra_public_methods = ["__call__"]
|
||||
|
||||
def __init__(self, cls, doc=None, modulename="", func_doc=FunctionDoc, config=None):
|
||||
if not inspect.isclass(cls) and cls is not None:
|
||||
raise ValueError(f"Expected a class or None, but got {cls!r}")
|
||||
self._cls = cls
|
||||
|
||||
if "sphinx" in sys.modules:
|
||||
from sphinx.ext.autodoc import ALL
|
||||
else:
|
||||
ALL = object()
|
||||
|
||||
if config is None:
|
||||
config = {}
|
||||
self.show_inherited_members = config.get("show_inherited_class_members", True)
|
||||
|
||||
if modulename and not modulename.endswith("."):
|
||||
modulename += "."
|
||||
self._mod = modulename
|
||||
|
||||
if doc is None:
|
||||
if cls is None:
|
||||
raise ValueError("No class or documentation string given")
|
||||
doc = pydoc.getdoc(cls)
|
||||
|
||||
NumpyDocString.__init__(self, doc)
|
||||
|
||||
_members = config.get("members", [])
|
||||
if _members is ALL:
|
||||
_members = None
|
||||
_exclude = config.get("exclude-members", [])
|
||||
|
||||
if config.get("show_class_members", True) and _exclude is not ALL:
|
||||
|
||||
def splitlines_x(s):
|
||||
if not s:
|
||||
return []
|
||||
else:
|
||||
return s.splitlines()
|
||||
|
||||
for field, items in [
|
||||
("Methods", self.methods),
|
||||
("Attributes", self.properties),
|
||||
]:
|
||||
if not self[field]:
|
||||
doc_list = []
|
||||
for name in sorted(items):
|
||||
if name in _exclude or (_members and name not in _members):
|
||||
continue
|
||||
try:
|
||||
doc_item = pydoc.getdoc(getattr(self._cls, name))
|
||||
doc_list.append(Parameter(name, "", splitlines_x(doc_item)))
|
||||
except AttributeError:
|
||||
pass # method doesn't exist
|
||||
self[field] = doc_list
|
||||
|
||||
@property
|
||||
def methods(self):
|
||||
if self._cls is None:
|
||||
return []
|
||||
return [
|
||||
name
|
||||
for name, func in inspect.getmembers(self._cls)
|
||||
if (
|
||||
(not name.startswith("_") or name in self.extra_public_methods)
|
||||
and isinstance(func, Callable)
|
||||
and self._is_show_member(name)
|
||||
)
|
||||
]
|
||||
|
||||
@property
|
||||
def properties(self):
|
||||
if self._cls is None:
|
||||
return []
|
||||
return [
|
||||
name
|
||||
for name, func in inspect.getmembers(self._cls)
|
||||
if (
|
||||
not name.startswith("_")
|
||||
and not self._should_skip_member(name, self._cls)
|
||||
and (
|
||||
func is None
|
||||
or isinstance(func, property | cached_property)
|
||||
or inspect.isdatadescriptor(func)
|
||||
)
|
||||
and self._is_show_member(name)
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _should_skip_member(name, klass):
|
||||
return (
|
||||
# Namedtuples should skip everything in their ._fields as the
|
||||
# docstrings for each of the members is: "Alias for field number X"
|
||||
issubclass(klass, tuple)
|
||||
and hasattr(klass, "_asdict")
|
||||
and hasattr(klass, "_fields")
|
||||
and name in klass._fields
|
||||
)
|
||||
|
||||
def _is_show_member(self, name):
|
||||
return (
|
||||
# show all class members
|
||||
self.show_inherited_members
|
||||
# or class member is not inherited
|
||||
or name in self._cls.__dict__
|
||||
)
|
||||
|
||||
|
||||
def get_doc_object(
|
||||
obj,
|
||||
what=None,
|
||||
doc=None,
|
||||
config=None,
|
||||
class_doc=ClassDoc,
|
||||
func_doc=FunctionDoc,
|
||||
obj_doc=ObjDoc,
|
||||
):
|
||||
if what is None:
|
||||
if inspect.isclass(obj):
|
||||
what = "class"
|
||||
elif inspect.ismodule(obj):
|
||||
what = "module"
|
||||
elif isinstance(obj, Callable):
|
||||
what = "function"
|
||||
else:
|
||||
what = "object"
|
||||
if config is None:
|
||||
config = {}
|
||||
|
||||
if what == "class":
|
||||
return class_doc(obj, func_doc=func_doc, doc=doc, config=config)
|
||||
elif what in ("function", "method"):
|
||||
return func_doc(obj, doc=doc, config=config)
|
||||
else:
|
||||
if doc is None:
|
||||
doc = pydoc.getdoc(obj)
|
||||
return obj_doc(obj, doc, config=config)
|
||||
@@ -0,0 +1,346 @@
|
||||
# `_elementwise_iterative_method.py` includes tools for writing functions that
|
||||
# - are vectorized to work elementwise on arrays,
|
||||
# - implement non-trivial, iterative algorithms with a callback interface, and
|
||||
# - return rich objects with iteration count, termination status, etc.
|
||||
#
|
||||
# Examples include:
|
||||
# `scipy.optimize._chandrupatla._chandrupatla for scalar rootfinding,
|
||||
# `scipy.optimize._chandrupatla._chandrupatla_minimize for scalar minimization,
|
||||
# `scipy.optimize._differentiate._differentiate for numerical differentiation,
|
||||
# `scipy.optimize._bracket._bracket_root for finding rootfinding brackets,
|
||||
# `scipy.optimize._bracket._bracket_minimize for finding minimization brackets,
|
||||
# `scipy.integrate._tanhsinh._tanhsinh` for numerical quadrature,
|
||||
# `scipy.differentiate.derivative` for finite difference based differentiation.
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from ._util import _RichResult, _call_callback_maybe_halt
|
||||
from ._array_api import array_namespace, xp_size, xp_result_type
|
||||
import scipy._lib.array_api_extra as xpx
|
||||
|
||||
_ESIGNERR = -1
|
||||
_ECONVERR = -2
|
||||
_EVALUEERR = -3
|
||||
_ECALLBACK = -4
|
||||
_EINPUTERR = -5
|
||||
_ECONVERGED = 0
|
||||
_EINPROGRESS = 1
|
||||
|
||||
def _initialize(func, xs, args, complex_ok=False, preserve_shape=None, xp=None):
|
||||
"""Initialize abscissa, function, and args arrays for elementwise function
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
An elementwise function with signature
|
||||
|
||||
func(x: ndarray, *args) -> ndarray
|
||||
|
||||
where each element of ``x`` is a finite real and ``args`` is a tuple,
|
||||
which may contain an arbitrary number of arrays that are broadcastable
|
||||
with ``x``.
|
||||
xs : tuple of arrays
|
||||
Finite real abscissa arrays. Must be broadcastable.
|
||||
args : tuple, optional
|
||||
Additional positional arguments to be passed to `func`.
|
||||
preserve_shape : bool, default:False
|
||||
When ``preserve_shape=False`` (default), `func` may be passed
|
||||
arguments of any shape; `_scalar_optimization_loop` is permitted
|
||||
to reshape and compress arguments at will. When
|
||||
``preserve_shape=False``, arguments passed to `func` must have shape
|
||||
`shape` or ``shape + (n,)``, where ``n`` is any integer.
|
||||
xp : namespace
|
||||
Namespace of array arguments in `xs`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xs, fs, args : tuple of arrays
|
||||
Broadcasted, writeable, 1D abscissa and function value arrays (or
|
||||
NumPy floats, if appropriate). The dtypes of the `xs` and `fs` are
|
||||
`xfat`; the dtype of the `args` are unchanged.
|
||||
shape : tuple of ints
|
||||
Original shape of broadcasted arrays.
|
||||
xfat : NumPy dtype
|
||||
Result dtype of abscissae, function values, and args determined using
|
||||
`np.result_type`, except integer types are promoted to `np.float64`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the result dtype is not that of a real scalar
|
||||
|
||||
Notes
|
||||
-----
|
||||
Useful for initializing the input of SciPy functions that accept
|
||||
an elementwise callable, abscissae, and arguments; e.g.
|
||||
`scipy.optimize._chandrupatla`.
|
||||
"""
|
||||
nx = len(xs)
|
||||
xp = array_namespace(*xs) if xp is None else xp
|
||||
|
||||
# Try to preserve `dtype`, but we need to ensure that the arguments are at
|
||||
# least floats before passing them into the function; integers can overflow
|
||||
# and cause failure.
|
||||
# There might be benefit to combining the `xs` into a single array and
|
||||
# calling `func` once on the combined array. For now, keep them separate.
|
||||
xat = xp_result_type(*xs, force_floating=True, xp=xp)
|
||||
xas = xp.broadcast_arrays(*xs, *args) # broadcast and rename
|
||||
xs, args = xas[:nx], xas[nx:]
|
||||
xs = [xp.asarray(x, dtype=xat) for x in xs] # use copy=False when implemented
|
||||
fs = [xp.asarray(func(x, *args)) for x in xs]
|
||||
shape = xs[0].shape
|
||||
fshape = fs[0].shape
|
||||
|
||||
if preserve_shape:
|
||||
# bind original shape/func now to avoid late-binding gotcha
|
||||
def func(x, *args, shape=shape, func=func, **kwargs):
|
||||
i = (0,)*(len(fshape) - len(shape))
|
||||
return func(x[i], *args, **kwargs)
|
||||
shape = np.broadcast_shapes(fshape, shape) # just shapes; use of NumPy OK
|
||||
xs = [xp.broadcast_to(x, shape) for x in xs]
|
||||
args = [xp.broadcast_to(arg, shape) for arg in args]
|
||||
|
||||
message = ("The shape of the array returned by `func` must be the same as "
|
||||
"the broadcasted shape of `x` and all other `args`.")
|
||||
if preserve_shape is not None: # only in tanhsinh for now
|
||||
message = f"When `preserve_shape=False`, {message.lower()}"
|
||||
shapes_equal = [f.shape == shape for f in fs]
|
||||
if not all(shapes_equal): # use Python all to reduce overhead
|
||||
raise ValueError(message)
|
||||
|
||||
# These algorithms tend to mix the dtypes of the abscissae and function
|
||||
# values, so figure out what the result will be and convert them all to
|
||||
# that type from the outset.
|
||||
xfat = xp.result_type(*([f.dtype for f in fs] + [xat]))
|
||||
if not complex_ok and not xp.isdtype(xfat, "real floating"):
|
||||
raise ValueError("Abscissae and function output must be real numbers.")
|
||||
xs = [xp.asarray(x, dtype=xfat, copy=True) for x in xs]
|
||||
fs = [xp.asarray(f, dtype=xfat, copy=True) for f in fs]
|
||||
|
||||
# To ensure that we can do indexing, we'll work with at least 1d arrays,
|
||||
# but remember the appropriate shape of the output.
|
||||
xs = [xp.reshape(x, (-1,)) for x in xs]
|
||||
fs = [xp.reshape(f, (-1,)) for f in fs]
|
||||
args = [xp.reshape(xp.asarray(arg, copy=True), (-1,)) for arg in args]
|
||||
return func, xs, fs, args, shape, xfat, xp
|
||||
|
||||
|
||||
def _loop(work, callback, shape, maxiter, func, args, dtype, pre_func_eval,
|
||||
post_func_eval, check_termination, post_termination_check,
|
||||
customize_result, res_work_pairs, xp, preserve_shape=False):
|
||||
"""Main loop of a vectorized scalar optimization algorithm
|
||||
|
||||
Parameters
|
||||
----------
|
||||
work : _RichResult
|
||||
All variables that need to be retained between iterations. Must
|
||||
contain attributes `nit`, `nfev`, and `success`. All arrays are
|
||||
subject to being "compressed" if `preserve_shape is False`; nest
|
||||
arrays that should not be compressed inside another object (e.g.
|
||||
`dict` or `_RichResult`).
|
||||
callback : callable
|
||||
User-specified callback function
|
||||
shape : tuple of ints
|
||||
The shape of all output arrays
|
||||
maxiter :
|
||||
Maximum number of iterations of the algorithm
|
||||
func : callable
|
||||
The user-specified callable that is being optimized or solved
|
||||
args : tuple
|
||||
Additional positional arguments to be passed to `func`.
|
||||
dtype : NumPy dtype
|
||||
The common dtype of all abscissae and function values
|
||||
pre_func_eval : callable
|
||||
A function that accepts `work` and returns `x`, the active elements
|
||||
of `x` at which `func` will be evaluated. May modify attributes
|
||||
of `work` with any algorithmic steps that need to happen
|
||||
at the beginning of an iteration, before `func` is evaluated,
|
||||
post_func_eval : callable
|
||||
A function that accepts `x`, `func(x)`, and `work`. May modify
|
||||
attributes of `work` with any algorithmic steps that need to happen
|
||||
in the middle of an iteration, after `func` is evaluated but before
|
||||
the termination check.
|
||||
check_termination : callable
|
||||
A function that accepts `work` and returns `stop`, a boolean array
|
||||
indicating which of the active elements have met a termination
|
||||
condition.
|
||||
post_termination_check : callable
|
||||
A function that accepts `work`. May modify `work` with any algorithmic
|
||||
steps that need to happen after the termination check and before the
|
||||
end of the iteration.
|
||||
customize_result : callable
|
||||
A function that accepts `res` and `shape` and returns `shape`. May
|
||||
modify `res` (in-place) according to preferences (e.g. rearrange
|
||||
elements between attributes) and modify `shape` if needed.
|
||||
res_work_pairs : list of (str, str)
|
||||
Identifies correspondence between attributes of `res` and attributes
|
||||
of `work`; i.e., attributes of active elements of `work` will be
|
||||
copied to the appropriate indices of `res` when appropriate. The order
|
||||
determines the order in which _RichResult attributes will be
|
||||
pretty-printed.
|
||||
preserve_shape : bool, default: False
|
||||
Whether to compress the attributes of `work` (to avoid unnecessary
|
||||
computation on elements that have already converged).
|
||||
|
||||
Returns
|
||||
-------
|
||||
res : _RichResult
|
||||
The final result object
|
||||
|
||||
Notes
|
||||
-----
|
||||
Besides providing structure, this framework provides several important
|
||||
services for a vectorized optimization algorithm.
|
||||
|
||||
- It handles common tasks involving iteration count, function evaluation
|
||||
count, a user-specified callback, and associated termination conditions.
|
||||
- It compresses the attributes of `work` to eliminate unnecessary
|
||||
computation on elements that have already converged.
|
||||
|
||||
"""
|
||||
if xp is None:
|
||||
raise NotImplementedError("Must provide xp.")
|
||||
|
||||
cb_terminate = False
|
||||
|
||||
# Initialize the result object and active element index array
|
||||
n_elements = math.prod(shape)
|
||||
active = xp.arange(n_elements) # in-progress element indices
|
||||
res_dict = {i: xp.zeros(n_elements, dtype=dtype) for i, j in res_work_pairs}
|
||||
res_dict['success'] = xp.zeros(n_elements, dtype=xp.bool)
|
||||
res_dict['status'] = xp.full(n_elements, xp.asarray(_EINPROGRESS), dtype=xp.int32)
|
||||
res_dict['nit'] = xp.zeros(n_elements, dtype=xp.int32)
|
||||
res_dict['nfev'] = xp.zeros(n_elements, dtype=xp.int32)
|
||||
res = _RichResult(res_dict)
|
||||
work.args = args
|
||||
|
||||
active = _check_termination(work, res, res_work_pairs, active,
|
||||
check_termination, preserve_shape, xp)
|
||||
|
||||
if callback is not None:
|
||||
temp = _prepare_result(work, res, res_work_pairs, active, shape,
|
||||
customize_result, preserve_shape, xp)
|
||||
if _call_callback_maybe_halt(callback, temp):
|
||||
cb_terminate = True
|
||||
|
||||
while work.nit < maxiter and xp_size(active) and not cb_terminate and n_elements:
|
||||
x = pre_func_eval(work)
|
||||
|
||||
if work.args and work.args[0].ndim != x.ndim:
|
||||
# `x` always starts as 1D. If the SciPy function that uses
|
||||
# _loop added dimensions to `x`, we need to
|
||||
# add them to the elements of `args`.
|
||||
args = []
|
||||
for arg in work.args:
|
||||
n_new_dims = x.ndim - arg.ndim
|
||||
new_shape = arg.shape + (1,)*n_new_dims
|
||||
args.append(xp.reshape(arg, new_shape))
|
||||
work.args = args
|
||||
|
||||
x_shape = x.shape
|
||||
if preserve_shape:
|
||||
x = xp.reshape(x, (shape + (-1,)))
|
||||
f = func(x, *work.args)
|
||||
f = xp.asarray(f, dtype=dtype)
|
||||
if preserve_shape:
|
||||
x = xp.reshape(x, x_shape)
|
||||
f = xp.reshape(f, x_shape)
|
||||
work.nfev += 1 if x.ndim == 1 else x.shape[-1]
|
||||
|
||||
post_func_eval(x, f, work)
|
||||
|
||||
work.nit += 1
|
||||
active = _check_termination(work, res, res_work_pairs, active,
|
||||
check_termination, preserve_shape, xp)
|
||||
|
||||
if callback is not None:
|
||||
temp = _prepare_result(work, res, res_work_pairs, active, shape,
|
||||
customize_result, preserve_shape, xp)
|
||||
if _call_callback_maybe_halt(callback, temp):
|
||||
cb_terminate = True
|
||||
break
|
||||
if xp_size(active) == 0:
|
||||
break
|
||||
|
||||
post_termination_check(work)
|
||||
|
||||
work.status = xpx.at(work.status)[:].set(_ECALLBACK if cb_terminate else _ECONVERR)
|
||||
return _prepare_result(work, res, res_work_pairs, active, shape,
|
||||
customize_result, preserve_shape, xp)
|
||||
|
||||
|
||||
def _check_termination(work, res, res_work_pairs, active, check_termination,
|
||||
preserve_shape, xp):
|
||||
# Checks termination conditions, updates elements of `res` with
|
||||
# corresponding elements of `work`, and compresses `work`.
|
||||
|
||||
stop = check_termination(work)
|
||||
|
||||
if xp.any(stop):
|
||||
# update the active elements of the result object with the active
|
||||
# elements for which a termination condition has been met
|
||||
_update_active(work, res, res_work_pairs, active, stop, preserve_shape, xp)
|
||||
|
||||
if preserve_shape:
|
||||
stop = stop[active]
|
||||
|
||||
proceed = ~stop
|
||||
active = active[proceed]
|
||||
|
||||
if not preserve_shape:
|
||||
# compress the arrays to avoid unnecessary computation
|
||||
for key, val in work.items():
|
||||
# `continued_fraction` hacks `n`; improve if this becomes a problem
|
||||
if key in {'args', 'n'}:
|
||||
continue
|
||||
work[key] = val[proceed] if getattr(val, 'ndim', 0) > 0 else val
|
||||
work.args = [arg[proceed] for arg in work.args]
|
||||
|
||||
return active
|
||||
|
||||
|
||||
def _update_active(work, res, res_work_pairs, active, mask, preserve_shape, xp):
|
||||
# Update `active` indices of the arrays in result object `res` with the
|
||||
# contents of the scalars and arrays in `update_dict`. When provided,
|
||||
# `mask` is a boolean array applied both to the arrays in `update_dict`
|
||||
# that are to be used and to the arrays in `res` that are to be updated.
|
||||
update_dict = {key1: work[key2] for key1, key2 in res_work_pairs}
|
||||
update_dict['success'] = work.status == 0
|
||||
|
||||
if mask is not None:
|
||||
if preserve_shape:
|
||||
active_mask = xp.zeros_like(mask)
|
||||
active_mask = xpx.at(active_mask)[active].set(True)
|
||||
active_mask = active_mask & mask
|
||||
for key, val in update_dict.items():
|
||||
val = val[active_mask] if getattr(val, 'ndim', 0) > 0 else val
|
||||
res[key] = xpx.at(res[key])[active_mask].set(val)
|
||||
else:
|
||||
active_mask = active[mask]
|
||||
for key, val in update_dict.items():
|
||||
val = val[mask] if getattr(val, 'ndim', 0) > 0 else val
|
||||
res[key] = xpx.at(res[key])[active_mask].set(val)
|
||||
else:
|
||||
for key, val in update_dict.items():
|
||||
if preserve_shape and getattr(val, 'ndim', 0) > 0:
|
||||
val = val[active]
|
||||
res[key] = xpx.at(res[key])[active].set(val)
|
||||
|
||||
|
||||
def _prepare_result(work, res, res_work_pairs, active, shape, customize_result,
|
||||
preserve_shape, xp):
|
||||
# Prepare the result object `res` by creating a copy, copying the latest
|
||||
# data from work, running the provided result customization function,
|
||||
# and reshaping the data to the original shapes.
|
||||
res = res.copy()
|
||||
_update_active(work, res, res_work_pairs, active, None, preserve_shape, xp)
|
||||
|
||||
shape = customize_result(res, shape)
|
||||
|
||||
for key, val in res.items():
|
||||
# this looks like it won't work for xp != np if val is not numeric
|
||||
temp = xp.reshape(val, shape)
|
||||
res[key] = temp[()] if temp.ndim == 0 else temp
|
||||
|
||||
res['_order_keys'] = ['success'] + [i for i, j in res_work_pairs]
|
||||
return _RichResult(**res)
|
||||
Binary file not shown.
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Module for testing automatic garbage collection of objects
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated/
|
||||
|
||||
set_gc_state - enable or disable garbage collection
|
||||
gc_state - context manager for given state of garbage collector
|
||||
assert_deallocated - context manager to check for circular references on object
|
||||
|
||||
"""
|
||||
import weakref
|
||||
import gc
|
||||
|
||||
from contextlib import contextmanager
|
||||
from platform import python_implementation
|
||||
|
||||
__all__ = ['set_gc_state', 'gc_state', 'assert_deallocated']
|
||||
|
||||
|
||||
IS_PYPY = python_implementation() == 'PyPy'
|
||||
|
||||
|
||||
class ReferenceError(AssertionError):
|
||||
pass
|
||||
|
||||
|
||||
def set_gc_state(state):
|
||||
""" Set status of garbage collector """
|
||||
if gc.isenabled() == state:
|
||||
return
|
||||
if state:
|
||||
gc.enable()
|
||||
else:
|
||||
gc.disable()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def gc_state(state):
|
||||
""" Context manager to set state of garbage collector to `state`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state : bool
|
||||
True for gc enabled, False for disabled
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> with gc_state(False):
|
||||
... assert not gc.isenabled()
|
||||
>>> with gc_state(True):
|
||||
... assert gc.isenabled()
|
||||
"""
|
||||
orig_state = gc.isenabled()
|
||||
set_gc_state(state)
|
||||
yield
|
||||
set_gc_state(orig_state)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def assert_deallocated(func, *args, **kwargs):
|
||||
"""Context manager to check that object is deallocated
|
||||
|
||||
This is useful for checking that an object can be freed directly by
|
||||
reference counting, without requiring gc to break reference cycles.
|
||||
GC is disabled inside the context manager.
|
||||
|
||||
This check is not available on PyPy.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
Callable to create object to check
|
||||
\\*args : sequence
|
||||
positional arguments to `func` in order to create object to check
|
||||
\\*\\*kwargs : dict
|
||||
keyword arguments to `func` in order to create object to check
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> class C: pass
|
||||
>>> with assert_deallocated(C) as c:
|
||||
... # do something
|
||||
... del c
|
||||
|
||||
>>> class C:
|
||||
... def __init__(self):
|
||||
... self._circular = self # Make circular reference
|
||||
>>> with assert_deallocated(C) as c: #doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
... # do something
|
||||
... del c
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ReferenceError: Remaining reference(s) to object
|
||||
"""
|
||||
if IS_PYPY:
|
||||
raise RuntimeError("assert_deallocated is unavailable on PyPy")
|
||||
|
||||
with gc_state(False):
|
||||
obj = func(*args, **kwargs)
|
||||
ref = weakref.ref(obj)
|
||||
yield obj
|
||||
del obj
|
||||
if ref() is not None:
|
||||
raise ReferenceError("Remaining reference(s) to object")
|
||||
@@ -0,0 +1,487 @@
|
||||
"""Utility to compare pep440 compatible version strings.
|
||||
|
||||
The LooseVersion and StrictVersion classes that distutils provides don't
|
||||
work; they don't recognize anything like alpha/beta/rc/dev versions.
|
||||
"""
|
||||
|
||||
# Copyright (c) Donald Stufft and individual contributors.
|
||||
# All rights reserved.
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
# POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import collections
|
||||
import itertools
|
||||
import re
|
||||
|
||||
|
||||
__all__ = [
|
||||
"parse", "Version", "LegacyVersion", "InvalidVersion", "VERSION_PATTERN",
|
||||
]
|
||||
|
||||
|
||||
# BEGIN packaging/_structures.py
|
||||
|
||||
|
||||
class Infinity:
|
||||
def __repr__(self):
|
||||
return "Infinity"
|
||||
|
||||
def __hash__(self):
|
||||
return hash(repr(self))
|
||||
|
||||
def __lt__(self, other):
|
||||
return False
|
||||
|
||||
def __le__(self, other):
|
||||
return False
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, self.__class__)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not isinstance(other, self.__class__)
|
||||
|
||||
def __gt__(self, other):
|
||||
return True
|
||||
|
||||
def __ge__(self, other):
|
||||
return True
|
||||
|
||||
def __neg__(self):
|
||||
return NegativeInfinity
|
||||
|
||||
|
||||
Infinity = Infinity()
|
||||
|
||||
|
||||
class NegativeInfinity:
|
||||
def __repr__(self):
|
||||
return "-Infinity"
|
||||
|
||||
def __hash__(self):
|
||||
return hash(repr(self))
|
||||
|
||||
def __lt__(self, other):
|
||||
return True
|
||||
|
||||
def __le__(self, other):
|
||||
return True
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, self.__class__)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not isinstance(other, self.__class__)
|
||||
|
||||
def __gt__(self, other):
|
||||
return False
|
||||
|
||||
def __ge__(self, other):
|
||||
return False
|
||||
|
||||
def __neg__(self):
|
||||
return Infinity
|
||||
|
||||
|
||||
# BEGIN packaging/version.py
|
||||
|
||||
|
||||
NegativeInfinity = NegativeInfinity()
|
||||
|
||||
_Version = collections.namedtuple(
|
||||
"_Version",
|
||||
["epoch", "release", "dev", "pre", "post", "local"],
|
||||
)
|
||||
|
||||
|
||||
def parse(version):
|
||||
"""
|
||||
Parse the given version string and return either a :class:`Version` object
|
||||
or a :class:`LegacyVersion` object depending on if the given version is
|
||||
a valid PEP 440 version or a legacy version.
|
||||
"""
|
||||
try:
|
||||
return Version(version)
|
||||
except InvalidVersion:
|
||||
return LegacyVersion(version)
|
||||
|
||||
|
||||
class InvalidVersion(ValueError):
|
||||
"""
|
||||
An invalid version was found, users should refer to PEP 440.
|
||||
"""
|
||||
|
||||
|
||||
class _BaseVersion:
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self._key)
|
||||
|
||||
def __lt__(self, other):
|
||||
return self._compare(other, lambda s, o: s < o)
|
||||
|
||||
def __le__(self, other):
|
||||
return self._compare(other, lambda s, o: s <= o)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self._compare(other, lambda s, o: s == o)
|
||||
|
||||
def __ge__(self, other):
|
||||
return self._compare(other, lambda s, o: s >= o)
|
||||
|
||||
def __gt__(self, other):
|
||||
return self._compare(other, lambda s, o: s > o)
|
||||
|
||||
def __ne__(self, other):
|
||||
return self._compare(other, lambda s, o: s != o)
|
||||
|
||||
def _compare(self, other, method):
|
||||
if not isinstance(other, _BaseVersion):
|
||||
return NotImplemented
|
||||
|
||||
return method(self._key, other._key)
|
||||
|
||||
|
||||
class LegacyVersion(_BaseVersion):
|
||||
|
||||
def __init__(self, version):
|
||||
self._version = str(version)
|
||||
self._key = _legacy_cmpkey(self._version)
|
||||
|
||||
def __str__(self):
|
||||
return self._version
|
||||
|
||||
def __repr__(self):
|
||||
return f"<LegacyVersion({repr(str(self))})>"
|
||||
|
||||
@property
|
||||
def public(self):
|
||||
return self._version
|
||||
|
||||
@property
|
||||
def base_version(self):
|
||||
return self._version
|
||||
|
||||
@property
|
||||
def local(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_prerelease(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_postrelease(self):
|
||||
return False
|
||||
|
||||
|
||||
_legacy_version_component_re = re.compile(
|
||||
r"(\d+ | [a-z]+ | \.| -)", re.VERBOSE,
|
||||
)
|
||||
|
||||
_legacy_version_replacement_map = {
|
||||
"pre": "c", "preview": "c", "-": "final-", "rc": "c", "dev": "@",
|
||||
}
|
||||
|
||||
|
||||
def _parse_version_parts(s):
|
||||
for part in _legacy_version_component_re.split(s):
|
||||
part = _legacy_version_replacement_map.get(part, part)
|
||||
|
||||
if not part or part == ".":
|
||||
continue
|
||||
|
||||
if part[:1] in "0123456789":
|
||||
# pad for numeric comparison
|
||||
yield part.zfill(8)
|
||||
else:
|
||||
yield "*" + part
|
||||
|
||||
# ensure that alpha/beta/candidate are before final
|
||||
yield "*final"
|
||||
|
||||
|
||||
def _legacy_cmpkey(version):
|
||||
# We hardcode an epoch of -1 here. A PEP 440 version can only have an epoch
|
||||
# greater than or equal to 0. This will effectively put the LegacyVersion,
|
||||
# which uses the defacto standard originally implemented by setuptools,
|
||||
# as before all PEP 440 versions.
|
||||
epoch = -1
|
||||
|
||||
# This scheme is taken from pkg_resources.parse_version setuptools prior to
|
||||
# its adoption of the packaging library.
|
||||
parts = []
|
||||
for part in _parse_version_parts(version.lower()):
|
||||
if part.startswith("*"):
|
||||
# remove "-" before a prerelease tag
|
||||
if part < "*final":
|
||||
while parts and parts[-1] == "*final-":
|
||||
parts.pop()
|
||||
|
||||
# remove trailing zeros from each series of numeric parts
|
||||
while parts and parts[-1] == "00000000":
|
||||
parts.pop()
|
||||
|
||||
parts.append(part)
|
||||
parts = tuple(parts)
|
||||
|
||||
return epoch, parts
|
||||
|
||||
|
||||
# Deliberately not anchored to the start and end of the string, to make it
|
||||
# easier for 3rd party code to reuse
|
||||
VERSION_PATTERN = r"""
|
||||
v?
|
||||
(?:
|
||||
(?:(?P<epoch>[0-9]+)!)? # epoch
|
||||
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
|
||||
(?P<pre> # pre-release
|
||||
[-_\.]?
|
||||
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
|
||||
[-_\.]?
|
||||
(?P<pre_n>[0-9]+)?
|
||||
)?
|
||||
(?P<post> # post release
|
||||
(?:-(?P<post_n1>[0-9]+))
|
||||
|
|
||||
(?:
|
||||
[-_\.]?
|
||||
(?P<post_l>post|rev|r)
|
||||
[-_\.]?
|
||||
(?P<post_n2>[0-9]+)?
|
||||
)
|
||||
)?
|
||||
(?P<dev> # dev release
|
||||
[-_\.]?
|
||||
(?P<dev_l>dev)
|
||||
[-_\.]?
|
||||
(?P<dev_n>[0-9]+)?
|
||||
)?
|
||||
)
|
||||
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
|
||||
"""
|
||||
|
||||
|
||||
class Version(_BaseVersion):
|
||||
|
||||
_regex = re.compile(
|
||||
r"^\s*" + VERSION_PATTERN + r"\s*$",
|
||||
re.VERBOSE | re.IGNORECASE,
|
||||
)
|
||||
|
||||
def __init__(self, version):
|
||||
# Validate the version and parse it into pieces
|
||||
match = self._regex.search(version)
|
||||
if not match:
|
||||
raise InvalidVersion(f"Invalid version: '{version}'")
|
||||
|
||||
# Store the parsed out pieces of the version
|
||||
self._version = _Version(
|
||||
epoch=int(match.group("epoch")) if match.group("epoch") else 0,
|
||||
release=tuple(int(i) for i in match.group("release").split(".")),
|
||||
pre=_parse_letter_version(
|
||||
match.group("pre_l"),
|
||||
match.group("pre_n"),
|
||||
),
|
||||
post=_parse_letter_version(
|
||||
match.group("post_l"),
|
||||
match.group("post_n1") or match.group("post_n2"),
|
||||
),
|
||||
dev=_parse_letter_version(
|
||||
match.group("dev_l"),
|
||||
match.group("dev_n"),
|
||||
),
|
||||
local=_parse_local_version(match.group("local")),
|
||||
)
|
||||
|
||||
# Generate a key which will be used for sorting
|
||||
self._key = _cmpkey(
|
||||
self._version.epoch,
|
||||
self._version.release,
|
||||
self._version.pre,
|
||||
self._version.post,
|
||||
self._version.dev,
|
||||
self._version.local,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Version({repr(str(self))})>"
|
||||
|
||||
def __str__(self):
|
||||
parts = []
|
||||
|
||||
# Epoch
|
||||
if self._version.epoch != 0:
|
||||
parts.append(f"{self._version.epoch}!")
|
||||
|
||||
# Release segment
|
||||
parts.append(".".join(str(x) for x in self._version.release))
|
||||
|
||||
# Pre-release
|
||||
if self._version.pre is not None:
|
||||
parts.append("".join(str(x) for x in self._version.pre))
|
||||
|
||||
# Post-release
|
||||
if self._version.post is not None:
|
||||
parts.append(f".post{self._version.post[1]}")
|
||||
|
||||
# Development release
|
||||
if self._version.dev is not None:
|
||||
parts.append(f".dev{self._version.dev[1]}")
|
||||
|
||||
# Local version segment
|
||||
if self._version.local is not None:
|
||||
parts.append(
|
||||
"+{}".format(".".join(str(x) for x in self._version.local))
|
||||
)
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
@property
|
||||
def public(self):
|
||||
return str(self).split("+", 1)[0]
|
||||
|
||||
@property
|
||||
def base_version(self):
|
||||
parts = []
|
||||
|
||||
# Epoch
|
||||
if self._version.epoch != 0:
|
||||
parts.append(f"{self._version.epoch}!")
|
||||
|
||||
# Release segment
|
||||
parts.append(".".join(str(x) for x in self._version.release))
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
@property
|
||||
def local(self):
|
||||
version_string = str(self)
|
||||
if "+" in version_string:
|
||||
return version_string.split("+", 1)[1]
|
||||
|
||||
@property
|
||||
def is_prerelease(self):
|
||||
return bool(self._version.dev or self._version.pre)
|
||||
|
||||
@property
|
||||
def is_postrelease(self):
|
||||
return bool(self._version.post)
|
||||
|
||||
|
||||
def _parse_letter_version(letter, number):
|
||||
if letter:
|
||||
# We assume there is an implicit 0 in a pre-release if there is
|
||||
# no numeral associated with it.
|
||||
if number is None:
|
||||
number = 0
|
||||
|
||||
# We normalize any letters to their lower-case form
|
||||
letter = letter.lower()
|
||||
|
||||
# We consider some words to be alternate spellings of other words and
|
||||
# in those cases we want to normalize the spellings to our preferred
|
||||
# spelling.
|
||||
if letter == "alpha":
|
||||
letter = "a"
|
||||
elif letter == "beta":
|
||||
letter = "b"
|
||||
elif letter in ["c", "pre", "preview"]:
|
||||
letter = "rc"
|
||||
elif letter in ["rev", "r"]:
|
||||
letter = "post"
|
||||
|
||||
return letter, int(number)
|
||||
if not letter and number:
|
||||
# We assume that if we are given a number but not given a letter,
|
||||
# then this is using the implicit post release syntax (e.g., 1.0-1)
|
||||
letter = "post"
|
||||
|
||||
return letter, int(number)
|
||||
|
||||
|
||||
_local_version_seperators = re.compile(r"[\._-]")
|
||||
|
||||
|
||||
def _parse_local_version(local):
|
||||
"""
|
||||
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
|
||||
"""
|
||||
if local is not None:
|
||||
return tuple(
|
||||
part.lower() if not part.isdigit() else int(part)
|
||||
for part in _local_version_seperators.split(local)
|
||||
)
|
||||
|
||||
|
||||
def _cmpkey(epoch, release, pre, post, dev, local):
|
||||
# When we compare a release version, we want to compare it with all of the
|
||||
# trailing zeros removed. So we'll use a reverse the list, drop all the now
|
||||
# leading zeros until we come to something non-zero, then take the rest,
|
||||
# re-reverse it back into the correct order, and make it a tuple and use
|
||||
# that for our sorting key.
|
||||
release = tuple(
|
||||
reversed(list(
|
||||
itertools.dropwhile(
|
||||
lambda x: x == 0,
|
||||
reversed(release),
|
||||
)
|
||||
))
|
||||
)
|
||||
|
||||
# We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
|
||||
# We'll do this by abusing the pre-segment, but we _only_ want to do this
|
||||
# if there is no pre- or a post-segment. If we have one of those, then
|
||||
# the normal sorting rules will handle this case correctly.
|
||||
if pre is None and post is None and dev is not None:
|
||||
pre = -Infinity
|
||||
# Versions without a pre-release (except as noted above) should sort after
|
||||
# those with one.
|
||||
elif pre is None:
|
||||
pre = Infinity
|
||||
|
||||
# Versions without a post-segment should sort before those with one.
|
||||
if post is None:
|
||||
post = -Infinity
|
||||
|
||||
# Versions without a development segment should sort after those with one.
|
||||
if dev is None:
|
||||
dev = Infinity
|
||||
|
||||
if local is None:
|
||||
# Versions without a local segment should sort before those with one.
|
||||
local = -Infinity
|
||||
else:
|
||||
# Versions with a local segment need that segment parsed to implement
|
||||
# the sorting rules in PEP440.
|
||||
# - Alphanumeric segments sort before numeric segments
|
||||
# - Alphanumeric segments sort lexicographically
|
||||
# - Numeric segments sort numerically
|
||||
# - Shorter versions sort before longer versions when the prefixes
|
||||
# match exactly
|
||||
local = tuple(
|
||||
(i, "") if isinstance(i, int) else (-Infinity, i)
|
||||
for i in local
|
||||
)
|
||||
|
||||
return epoch, release, pre, post, dev, local
|
||||
@@ -0,0 +1,56 @@
|
||||
"""PUBLIC_MODULES was once included in scipy._lib.tests.test_public_api.
|
||||
|
||||
It has been separated into this file so that this list of public modules
|
||||
could be used when generating tables showing support for alternative
|
||||
array API backends across modules in
|
||||
scipy/doc/source/array_api_capabilities.py.
|
||||
"""
|
||||
|
||||
# Historically SciPy has not used leading underscores for private submodules
|
||||
# much. This has resulted in lots of things that look like public modules
|
||||
# (i.e. things that can be imported as `import scipy.somesubmodule.somefile`),
|
||||
# but were never intended to be public. The PUBLIC_MODULES list contains
|
||||
# modules that are either public because they were meant to be, or because they
|
||||
# contain public functions/objects that aren't present in any other namespace
|
||||
# for whatever reason and therefore should be treated as public.
|
||||
PUBLIC_MODULES = ["scipy." + s for s in [
|
||||
"cluster",
|
||||
"cluster.vq",
|
||||
"cluster.hierarchy",
|
||||
"constants",
|
||||
"datasets",
|
||||
"differentiate",
|
||||
"fft",
|
||||
"fftpack",
|
||||
"integrate",
|
||||
"interpolate",
|
||||
"io",
|
||||
"io.arff",
|
||||
"io.matlab",
|
||||
"io.wavfile",
|
||||
"linalg",
|
||||
"linalg.blas",
|
||||
"linalg.cython_blas",
|
||||
"linalg.lapack",
|
||||
"linalg.cython_lapack",
|
||||
"linalg.interpolative",
|
||||
"ndimage",
|
||||
"odr",
|
||||
"optimize",
|
||||
"optimize.elementwise",
|
||||
"signal",
|
||||
"signal.windows",
|
||||
"sparse",
|
||||
"sparse.linalg",
|
||||
"sparse.csgraph",
|
||||
"spatial",
|
||||
"spatial.distance",
|
||||
"spatial.transform",
|
||||
"special",
|
||||
"stats",
|
||||
"stats.contingency",
|
||||
"stats.distributions",
|
||||
"stats.mstats",
|
||||
"stats.qmc",
|
||||
"stats.sampling"
|
||||
]]
|
||||
@@ -0,0 +1,41 @@
|
||||
from abc import ABC
|
||||
|
||||
__all__ = ["SparseABC", "issparse"]
|
||||
|
||||
|
||||
class SparseABC(ABC):
|
||||
pass
|
||||
|
||||
|
||||
def issparse(x):
|
||||
"""Is `x` of a sparse array or sparse matrix type?
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x
|
||||
object to check for being a sparse array or sparse matrix
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if `x` is a sparse array or a sparse matrix, False otherwise
|
||||
|
||||
Notes
|
||||
-----
|
||||
Use `isinstance(x, sp.sparse.sparray)` to check between an array or matrix.
|
||||
Use `a.format` to check the sparse format, e.g. `a.format == 'csr'`.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> from scipy.sparse import csr_array, csr_matrix, issparse
|
||||
>>> issparse(csr_matrix([[5]]))
|
||||
True
|
||||
>>> issparse(csr_array([[5]]))
|
||||
True
|
||||
>>> issparse(np.array([[5]]))
|
||||
False
|
||||
>>> issparse(5)
|
||||
False
|
||||
"""
|
||||
return isinstance(x, SparseABC)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,373 @@
|
||||
"""
|
||||
Generic test utilities.
|
||||
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import sysconfig
|
||||
import threading
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
|
||||
try:
|
||||
# Need type: ignore[import-untyped] for mypy >= 1.6
|
||||
import cython # type: ignore[import-untyped]
|
||||
from Cython.Compiler.Version import ( # type: ignore[import-untyped]
|
||||
version as cython_version,
|
||||
)
|
||||
except ImportError:
|
||||
cython = None
|
||||
else:
|
||||
from scipy._lib import _pep440
|
||||
required_version = '3.0.8'
|
||||
if _pep440.parse(cython_version) < _pep440.Version(required_version):
|
||||
# too old or wrong cython, skip Cython API tests
|
||||
cython = None
|
||||
|
||||
|
||||
__all__ = ['PytestTester', 'check_free_memory', '_TestPythranFunc', 'IS_MUSL']
|
||||
|
||||
|
||||
IS_MUSL = False
|
||||
# alternate way is
|
||||
# from packaging.tags import sys_tags
|
||||
# _tags = list(sys_tags())
|
||||
# if 'musllinux' in _tags[0].platform:
|
||||
_v = sysconfig.get_config_var('HOST_GNU_TYPE') or ''
|
||||
if 'musl' in _v:
|
||||
IS_MUSL = True
|
||||
|
||||
|
||||
IS_EDITABLE = 'editable' in scipy.__path__[0]
|
||||
|
||||
|
||||
class FPUModeChangeWarning(RuntimeWarning):
|
||||
"""Warning about FPU mode change"""
|
||||
pass
|
||||
|
||||
|
||||
class PytestTester:
|
||||
"""
|
||||
Run tests for this namespace
|
||||
|
||||
``scipy.test()`` runs tests for all of SciPy, with the default settings.
|
||||
When used from a submodule (e.g., ``scipy.cluster.test()``, only the tests
|
||||
for that namespace are run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label : {'fast', 'full'}, optional
|
||||
Whether to run only the fast tests, or also those marked as slow.
|
||||
Default is 'fast'.
|
||||
verbose : int, optional
|
||||
Test output verbosity. Default is 1.
|
||||
extra_argv : list, optional
|
||||
Arguments to pass through to Pytest.
|
||||
doctests : bool, optional
|
||||
Whether to run doctests or not. Default is False.
|
||||
coverage : bool, optional
|
||||
Whether to run tests with code coverage measurements enabled.
|
||||
Default is False.
|
||||
tests : list of str, optional
|
||||
List of module names to run tests for. By default, uses the module
|
||||
from which the ``test`` function is called.
|
||||
parallel : int, optional
|
||||
Run tests in parallel with pytest-xdist, if number given is larger than
|
||||
1. Default is 1.
|
||||
|
||||
"""
|
||||
def __init__(self, module_name):
|
||||
self.module_name = module_name
|
||||
|
||||
def __call__(self, label="fast", verbose=1, extra_argv=None, doctests=False,
|
||||
coverage=False, tests=None, parallel=None):
|
||||
import pytest
|
||||
|
||||
module = sys.modules[self.module_name]
|
||||
module_path = os.path.abspath(module.__path__[0])
|
||||
|
||||
pytest_args = ['--showlocals', '--tb=short']
|
||||
|
||||
if extra_argv is None:
|
||||
extra_argv = []
|
||||
pytest_args += extra_argv
|
||||
if any(arg == "-m" or arg == "--markers" for arg in extra_argv):
|
||||
# Likely conflict with default --mode=fast
|
||||
raise ValueError("Must specify -m before --")
|
||||
|
||||
if verbose and int(verbose) > 1:
|
||||
pytest_args += ["-" + "v"*(int(verbose)-1)]
|
||||
|
||||
if coverage:
|
||||
pytest_args += ["--cov=" + module_path]
|
||||
|
||||
if label == "fast":
|
||||
pytest_args += ["-m", "not slow"]
|
||||
elif label != "full":
|
||||
pytest_args += ["-m", label]
|
||||
|
||||
if tests is None:
|
||||
tests = [self.module_name]
|
||||
|
||||
if parallel is not None and parallel > 1:
|
||||
if _pytest_has_xdist():
|
||||
pytest_args += ['-n', str(parallel)]
|
||||
else:
|
||||
import warnings
|
||||
warnings.warn('Could not run tests in parallel because '
|
||||
'pytest-xdist plugin is not available.',
|
||||
stacklevel=2)
|
||||
|
||||
pytest_args += ['--pyargs'] + list(tests)
|
||||
|
||||
try:
|
||||
code = pytest.main(pytest_args)
|
||||
except SystemExit as exc:
|
||||
code = exc.code
|
||||
|
||||
return (code == 0)
|
||||
|
||||
|
||||
class _TestPythranFunc:
|
||||
'''
|
||||
These are situations that can be tested in our pythran tests:
|
||||
- A function with multiple array arguments and then
|
||||
other positional and keyword arguments.
|
||||
- A function with array-like keywords (e.g. `def somefunc(x0, x1=None)`.
|
||||
Note: list/tuple input is not yet tested!
|
||||
|
||||
`self.arguments`: A dictionary which key is the index of the argument,
|
||||
value is tuple(array value, all supported dtypes)
|
||||
`self.partialfunc`: A function used to freeze some non-array argument
|
||||
that of no interests in the original function
|
||||
'''
|
||||
ALL_INTEGER = [np.int8, np.int16, np.int32, np.int64, np.intc, np.intp]
|
||||
ALL_FLOAT = [np.float32, np.float64]
|
||||
ALL_COMPLEX = [np.complex64, np.complex128]
|
||||
|
||||
def setup_method(self):
|
||||
self.arguments = {}
|
||||
self.partialfunc = None
|
||||
self.expected = None
|
||||
|
||||
def get_optional_args(self, func):
|
||||
# get optional arguments with its default value,
|
||||
# used for testing keywords
|
||||
signature = inspect.signature(func)
|
||||
optional_args = {}
|
||||
for k, v in signature.parameters.items():
|
||||
if v.default is not inspect.Parameter.empty:
|
||||
optional_args[k] = v.default
|
||||
return optional_args
|
||||
|
||||
def get_max_dtype_list_length(self):
|
||||
# get the max supported dtypes list length in all arguments
|
||||
max_len = 0
|
||||
for arg_idx in self.arguments:
|
||||
cur_len = len(self.arguments[arg_idx][1])
|
||||
if cur_len > max_len:
|
||||
max_len = cur_len
|
||||
return max_len
|
||||
|
||||
def get_dtype(self, dtype_list, dtype_idx):
|
||||
# get the dtype from dtype_list via index
|
||||
# if the index is out of range, then return the last dtype
|
||||
if dtype_idx > len(dtype_list)-1:
|
||||
return dtype_list[-1]
|
||||
else:
|
||||
return dtype_list[dtype_idx]
|
||||
|
||||
def test_all_dtypes(self):
|
||||
for type_idx in range(self.get_max_dtype_list_length()):
|
||||
args_array = []
|
||||
for arg_idx in self.arguments:
|
||||
new_dtype = self.get_dtype(self.arguments[arg_idx][1],
|
||||
type_idx)
|
||||
args_array.append(self.arguments[arg_idx][0].astype(new_dtype))
|
||||
self.pythranfunc(*args_array)
|
||||
|
||||
def test_views(self):
|
||||
args_array = []
|
||||
for arg_idx in self.arguments:
|
||||
args_array.append(self.arguments[arg_idx][0][::-1][::-1])
|
||||
self.pythranfunc(*args_array)
|
||||
|
||||
def test_strided(self):
|
||||
args_array = []
|
||||
for arg_idx in self.arguments:
|
||||
args_array.append(np.repeat(self.arguments[arg_idx][0],
|
||||
2, axis=0)[::2])
|
||||
self.pythranfunc(*args_array)
|
||||
|
||||
|
||||
def _pytest_has_xdist():
|
||||
"""
|
||||
Check if the pytest-xdist plugin is installed, providing parallel tests
|
||||
"""
|
||||
# Check xdist exists without importing, otherwise pytests emits warnings
|
||||
from importlib.util import find_spec
|
||||
return find_spec('xdist') is not None
|
||||
|
||||
|
||||
def check_free_memory(free_mb):
|
||||
"""
|
||||
Check *free_mb* of memory is available, otherwise do pytest.skip
|
||||
"""
|
||||
import pytest
|
||||
|
||||
try:
|
||||
mem_free = _parse_size(os.environ['SCIPY_AVAILABLE_MEM'])
|
||||
msg = '{} MB memory required, but environment SCIPY_AVAILABLE_MEM={}'.format(
|
||||
free_mb, os.environ['SCIPY_AVAILABLE_MEM'])
|
||||
except KeyError:
|
||||
mem_free = _get_mem_available()
|
||||
if mem_free is None:
|
||||
pytest.skip("Could not determine available memory; set SCIPY_AVAILABLE_MEM "
|
||||
"variable to free memory in MB to run the test.")
|
||||
msg = f'{free_mb} MB memory required, but {mem_free/1e6} MB available'
|
||||
|
||||
if mem_free < free_mb * 1e6:
|
||||
pytest.skip(msg)
|
||||
|
||||
|
||||
def _parse_size(size_str):
|
||||
suffixes = {'': 1e6,
|
||||
'b': 1.0,
|
||||
'k': 1e3, 'M': 1e6, 'G': 1e9, 'T': 1e12,
|
||||
'kb': 1e3, 'Mb': 1e6, 'Gb': 1e9, 'Tb': 1e12,
|
||||
'kib': 1024.0, 'Mib': 1024.0**2, 'Gib': 1024.0**3, 'Tib': 1024.0**4}
|
||||
m = re.match(r'^\s*(\d+)\s*({})\s*$'.format('|'.join(suffixes.keys())),
|
||||
size_str,
|
||||
re.I)
|
||||
if not m or m.group(2) not in suffixes:
|
||||
raise ValueError("Invalid size string")
|
||||
|
||||
return float(m.group(1)) * suffixes[m.group(2)]
|
||||
|
||||
|
||||
def _get_mem_available():
|
||||
"""
|
||||
Get information about memory available, not counting swap.
|
||||
"""
|
||||
try:
|
||||
import psutil
|
||||
return psutil.virtual_memory().available
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
if sys.platform.startswith('linux'):
|
||||
info = {}
|
||||
with open('/proc/meminfo') as f:
|
||||
for line in f:
|
||||
p = line.split()
|
||||
info[p[0].strip(':').lower()] = float(p[1]) * 1e3
|
||||
|
||||
if 'memavailable' in info:
|
||||
# Linux >= 3.14
|
||||
return info['memavailable']
|
||||
else:
|
||||
return info['memfree'] + info['cached']
|
||||
|
||||
return None
|
||||
|
||||
def _test_cython_extension(tmp_path, srcdir):
|
||||
"""
|
||||
Helper function to test building and importing Cython modules that
|
||||
make use of the Cython APIs for BLAS, LAPACK, optimize, and special.
|
||||
"""
|
||||
import pytest
|
||||
try:
|
||||
subprocess.check_call(["meson", "--version"])
|
||||
except FileNotFoundError:
|
||||
pytest.skip("No usable 'meson' found")
|
||||
|
||||
# Make safe for being called by multiple threads within one test
|
||||
tmp_path = tmp_path / str(threading.get_ident())
|
||||
|
||||
# build the examples in a temporary directory
|
||||
mod_name = os.path.split(srcdir)[1]
|
||||
shutil.copytree(srcdir, tmp_path / mod_name)
|
||||
build_dir = tmp_path / mod_name / 'tests' / '_cython_examples'
|
||||
target_dir = build_dir / 'build'
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Ensure we use the correct Python interpreter even when `meson` is
|
||||
# installed in a different Python environment (see numpy#24956)
|
||||
native_file = str(build_dir / 'interpreter-native-file.ini')
|
||||
with open(native_file, 'w') as f:
|
||||
f.write("[binaries]\n")
|
||||
f.write(f"python = '{sys.executable}'")
|
||||
|
||||
if sys.platform == "win32":
|
||||
subprocess.check_call(["meson", "setup",
|
||||
"--buildtype=release",
|
||||
"--native-file", native_file,
|
||||
"--vsenv", str(build_dir)],
|
||||
cwd=target_dir,
|
||||
)
|
||||
else:
|
||||
subprocess.check_call(["meson", "setup",
|
||||
"--native-file", native_file, str(build_dir)],
|
||||
cwd=target_dir
|
||||
)
|
||||
subprocess.check_call(["meson", "compile", "-vv"], cwd=target_dir)
|
||||
|
||||
# import without adding the directory to sys.path
|
||||
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
||||
|
||||
def load(modname):
|
||||
so = (target_dir / modname).with_suffix(suffix)
|
||||
spec = spec_from_file_location(modname, so)
|
||||
mod = module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
# test that the module can be imported
|
||||
return load("extending"), load("extending_cpp")
|
||||
|
||||
|
||||
def _run_concurrent_barrier(n_workers, fn, *args, **kwargs):
|
||||
"""
|
||||
Run a given function concurrently across a given number of threads.
|
||||
|
||||
This is equivalent to using a ThreadPoolExecutor, but using the threading
|
||||
primitives instead. This function ensures that the closure passed by
|
||||
parameter gets called concurrently by setting up a barrier before it gets
|
||||
called before any of the threads.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
n_workers: int
|
||||
Number of concurrent threads to spawn.
|
||||
fn: callable
|
||||
Function closure to execute concurrently. Its first argument will
|
||||
be the thread id.
|
||||
*args: tuple
|
||||
Variable number of positional arguments to pass to the function.
|
||||
**kwargs: dict
|
||||
Keyword arguments to pass to the function.
|
||||
"""
|
||||
barrier = threading.Barrier(n_workers)
|
||||
|
||||
def closure(i, *args, **kwargs):
|
||||
barrier.wait()
|
||||
fn(i, *args, **kwargs)
|
||||
|
||||
workers = []
|
||||
for i in range(0, n_workers):
|
||||
workers.append(threading.Thread(
|
||||
target=closure,
|
||||
args=(i,) + args, kwargs=kwargs))
|
||||
|
||||
for worker in workers:
|
||||
worker.start()
|
||||
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
@@ -0,0 +1,86 @@
|
||||
''' Contexts for *with* statement providing temporary directories
|
||||
'''
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from shutil import rmtree
|
||||
from tempfile import mkdtemp
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tempdir():
|
||||
"""Create and return a temporary directory. This has the same
|
||||
behavior as mkdtemp but can be used as a context manager.
|
||||
|
||||
Upon exiting the context, the directory and everything contained
|
||||
in it are removed.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import os
|
||||
>>> with tempdir() as tmpdir:
|
||||
... fname = os.path.join(tmpdir, 'example_file.txt')
|
||||
... with open(fname, 'wt') as fobj:
|
||||
... _ = fobj.write('a string\\n')
|
||||
>>> os.path.exists(tmpdir)
|
||||
False
|
||||
"""
|
||||
d = mkdtemp()
|
||||
yield d
|
||||
rmtree(d)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def in_tempdir():
|
||||
''' Create, return, and change directory to a temporary directory
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import os
|
||||
>>> my_cwd = os.getcwd()
|
||||
>>> with in_tempdir() as tmpdir:
|
||||
... _ = open('test.txt', 'wt').write('some text')
|
||||
... assert os.path.isfile('test.txt')
|
||||
... assert os.path.isfile(os.path.join(tmpdir, 'test.txt'))
|
||||
>>> os.path.exists(tmpdir)
|
||||
False
|
||||
>>> os.getcwd() == my_cwd
|
||||
True
|
||||
'''
|
||||
pwd = os.getcwd()
|
||||
d = mkdtemp()
|
||||
os.chdir(d)
|
||||
yield d
|
||||
os.chdir(pwd)
|
||||
rmtree(d)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def in_dir(dir=None):
|
||||
""" Change directory to given directory for duration of ``with`` block
|
||||
|
||||
Useful when you want to use `in_tempdir` for the final test, but
|
||||
you are still debugging. For example, you may want to do this in the end:
|
||||
|
||||
>>> with in_tempdir() as tmpdir:
|
||||
... # do something complicated which might break
|
||||
... pass
|
||||
|
||||
But, indeed, the complicated thing does break, and meanwhile, the
|
||||
``in_tempdir`` context manager wiped out the directory with the
|
||||
temporary files that you wanted for debugging. So, while debugging, you
|
||||
replace with something like:
|
||||
|
||||
>>> with in_dir() as tmpdir: # Use working directory by default
|
||||
... # do something complicated which might break
|
||||
... pass
|
||||
|
||||
You can then look at the temporary file outputs to debug what is happening,
|
||||
fix, and finally replace ``in_dir`` with ``in_tempdir`` again.
|
||||
"""
|
||||
cwd = os.getcwd()
|
||||
if dir is None:
|
||||
yield cwd
|
||||
return
|
||||
os.chdir(dir)
|
||||
yield dir
|
||||
os.chdir(cwd)
|
||||
@@ -0,0 +1,29 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2018, Quansight-Labs
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
.. note:
|
||||
If you are looking for overrides for NumPy-specific methods, see the
|
||||
documentation for :obj:`unumpy`. This page explains how to write
|
||||
back-ends and multimethods.
|
||||
|
||||
``uarray`` is built around a back-end protocol, and overridable multimethods.
|
||||
It is necessary to define multimethods for back-ends to be able to override them.
|
||||
See the documentation of :obj:`generate_multimethod` on how to write multimethods.
|
||||
|
||||
|
||||
|
||||
Let's start with the simplest:
|
||||
|
||||
``__ua_domain__`` defines the back-end *domain*. The domain consists of period-
|
||||
separated string consisting of the modules you extend plus the submodule. For
|
||||
example, if a submodule ``module2.submodule`` extends ``module1``
|
||||
(i.e., it exposes dispatchables marked as types available in ``module1``),
|
||||
then the domain string should be ``"module1.module2.submodule"``.
|
||||
|
||||
|
||||
For the purpose of this demonstration, we'll be creating an object and setting
|
||||
its attributes directly. However, note that you can use a module or your own type
|
||||
as a backend as well.
|
||||
|
||||
>>> class Backend: pass
|
||||
>>> be = Backend()
|
||||
>>> be.__ua_domain__ = "ua_examples"
|
||||
|
||||
It might be useful at this point to sidetrack to the documentation of
|
||||
:obj:`generate_multimethod` to find out how to generate a multimethod
|
||||
overridable by :obj:`uarray`. Needless to say, writing a backend and
|
||||
creating multimethods are mostly orthogonal activities, and knowing
|
||||
one doesn't necessarily require knowledge of the other, although it
|
||||
is certainly helpful. We expect core API designers/specifiers to write the
|
||||
multimethods, and implementors to override them. But, as is often the case,
|
||||
similar people write both.
|
||||
|
||||
Without further ado, here's an example multimethod:
|
||||
|
||||
>>> import uarray as ua
|
||||
>>> from uarray import Dispatchable
|
||||
>>> def override_me(a, b):
|
||||
... return Dispatchable(a, int),
|
||||
>>> def override_replacer(args, kwargs, dispatchables):
|
||||
... return (dispatchables[0], args[1]), {}
|
||||
>>> overridden_me = ua.generate_multimethod(
|
||||
... override_me, override_replacer, "ua_examples"
|
||||
... )
|
||||
|
||||
Next comes the part about overriding the multimethod. This requires
|
||||
the ``__ua_function__`` protocol, and the ``__ua_convert__``
|
||||
protocol. The ``__ua_function__`` protocol has the signature
|
||||
``(method, args, kwargs)`` where ``method`` is the passed
|
||||
multimethod, ``args``/``kwargs`` specify the arguments and ``dispatchables``
|
||||
is the list of converted dispatchables passed in.
|
||||
|
||||
>>> def __ua_function__(method, args, kwargs):
|
||||
... return method.__name__, args, kwargs
|
||||
>>> be.__ua_function__ = __ua_function__
|
||||
|
||||
The other protocol of interest is the ``__ua_convert__`` protocol. It has the
|
||||
signature ``(dispatchables, coerce)``. When ``coerce`` is ``False``, conversion
|
||||
between the formats should ideally be an ``O(1)`` operation, but it means that
|
||||
no memory copying should be involved, only views of the existing data.
|
||||
|
||||
>>> def __ua_convert__(dispatchables, coerce):
|
||||
... for d in dispatchables:
|
||||
... if d.type is int:
|
||||
... if coerce and d.coercible:
|
||||
... yield str(d.value)
|
||||
... else:
|
||||
... yield d.value
|
||||
>>> be.__ua_convert__ = __ua_convert__
|
||||
|
||||
Now that we have defined the backend, the next thing to do is to call the multimethod.
|
||||
|
||||
>>> with ua.set_backend(be):
|
||||
... overridden_me(1, "2")
|
||||
('override_me', (1, '2'), {})
|
||||
|
||||
Note that the marked type has no effect on the actual type of the passed object.
|
||||
We can also coerce the type of the input.
|
||||
|
||||
>>> with ua.set_backend(be, coerce=True):
|
||||
... overridden_me(1, "2")
|
||||
... overridden_me(1.0, "2")
|
||||
('override_me', ('1', '2'), {})
|
||||
('override_me', ('1.0', '2'), {})
|
||||
|
||||
Another feature is that if you remove ``__ua_convert__``, the arguments are not
|
||||
converted at all and it's up to the backend to handle that.
|
||||
|
||||
>>> del be.__ua_convert__
|
||||
>>> with ua.set_backend(be):
|
||||
... overridden_me(1, "2")
|
||||
('override_me', (1, '2'), {})
|
||||
|
||||
You also have the option to return ``NotImplemented``, in which case processing moves on
|
||||
to the next back-end, which in this case, doesn't exist. The same applies to
|
||||
``__ua_convert__``.
|
||||
|
||||
>>> be.__ua_function__ = lambda *a, **kw: NotImplemented
|
||||
>>> with ua.set_backend(be):
|
||||
... overridden_me(1, "2")
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
uarray.BackendNotImplementedError: ...
|
||||
|
||||
The last possibility is if we don't have ``__ua_convert__``, in which case the job is
|
||||
left up to ``__ua_function__``, but putting things back into arrays after conversion
|
||||
will not be possible.
|
||||
"""
|
||||
|
||||
from ._backend import *
|
||||
__version__ = '0.8.8.dev0+aa94c5a4.scipy'
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,707 @@
|
||||
import typing
|
||||
import types
|
||||
import inspect
|
||||
import functools
|
||||
from . import _uarray
|
||||
import copyreg
|
||||
import pickle
|
||||
import contextlib
|
||||
import threading
|
||||
|
||||
from ._uarray import ( # type: ignore
|
||||
BackendNotImplementedError,
|
||||
_Function,
|
||||
_SkipBackendContext,
|
||||
_SetBackendContext,
|
||||
_BackendState,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"set_backend",
|
||||
"set_global_backend",
|
||||
"skip_backend",
|
||||
"register_backend",
|
||||
"determine_backend",
|
||||
"determine_backend_multi",
|
||||
"clear_backends",
|
||||
"create_multimethod",
|
||||
"generate_multimethod",
|
||||
"_Function",
|
||||
"BackendNotImplementedError",
|
||||
"Dispatchable",
|
||||
"wrap_single_convertor",
|
||||
"wrap_single_convertor_instance",
|
||||
"all_of_type",
|
||||
"mark_as",
|
||||
"set_state",
|
||||
"get_state",
|
||||
"reset_state",
|
||||
"_BackendState",
|
||||
"_SkipBackendContext",
|
||||
"_SetBackendContext",
|
||||
]
|
||||
|
||||
ArgumentExtractorType = typing.Callable[..., tuple["Dispatchable", ...]]
|
||||
ArgumentReplacerType = typing.Callable[
|
||||
[tuple, dict, tuple], tuple[tuple, dict]
|
||||
]
|
||||
|
||||
def unpickle_function(mod_name, qname, self_):
|
||||
import importlib
|
||||
|
||||
try:
|
||||
module = importlib.import_module(mod_name)
|
||||
qname = qname.split(".")
|
||||
func = module
|
||||
for q in qname:
|
||||
func = getattr(func, q)
|
||||
|
||||
if self_ is not None:
|
||||
func = types.MethodType(func, self_)
|
||||
|
||||
return func
|
||||
except (ImportError, AttributeError) as e:
|
||||
from pickle import UnpicklingError
|
||||
|
||||
raise UnpicklingError from e
|
||||
|
||||
|
||||
def pickle_function(func):
|
||||
mod_name = getattr(func, "__module__", None)
|
||||
qname = getattr(func, "__qualname__", None)
|
||||
self_ = getattr(func, "__self__", None)
|
||||
|
||||
try:
|
||||
test = unpickle_function(mod_name, qname, self_)
|
||||
except pickle.UnpicklingError:
|
||||
test = None
|
||||
|
||||
if test is not func:
|
||||
raise pickle.PicklingError(
|
||||
f"Can't pickle {func}: it's not the same object as {test}"
|
||||
)
|
||||
|
||||
return unpickle_function, (mod_name, qname, self_)
|
||||
|
||||
|
||||
def pickle_state(state):
|
||||
return _uarray._BackendState._unpickle, state._pickle()
|
||||
|
||||
|
||||
def pickle_set_backend_context(ctx):
|
||||
return _SetBackendContext, ctx._pickle()
|
||||
|
||||
|
||||
def pickle_skip_backend_context(ctx):
|
||||
return _SkipBackendContext, ctx._pickle()
|
||||
|
||||
|
||||
copyreg.pickle(_Function, pickle_function)
|
||||
copyreg.pickle(_uarray._BackendState, pickle_state)
|
||||
copyreg.pickle(_SetBackendContext, pickle_set_backend_context)
|
||||
copyreg.pickle(_SkipBackendContext, pickle_skip_backend_context)
|
||||
|
||||
|
||||
def get_state():
|
||||
"""
|
||||
Returns an opaque object containing the current state of all the backends.
|
||||
|
||||
Can be used for synchronization between threads/processes.
|
||||
|
||||
See Also
|
||||
--------
|
||||
set_state
|
||||
Sets the state returned by this function.
|
||||
"""
|
||||
return _uarray.get_state()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def reset_state():
|
||||
"""
|
||||
Returns a context manager that resets all state once exited.
|
||||
|
||||
See Also
|
||||
--------
|
||||
set_state
|
||||
Context manager that sets the backend state.
|
||||
get_state
|
||||
Gets a state to be set by this context manager.
|
||||
"""
|
||||
with set_state(get_state()):
|
||||
yield
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_state(state):
|
||||
"""
|
||||
A context manager that sets the state of the backends to one returned by :obj:`get_state`.
|
||||
|
||||
See Also
|
||||
--------
|
||||
get_state
|
||||
Gets a state to be set by this context manager.
|
||||
""" # noqa: E501
|
||||
old_state = get_state()
|
||||
_uarray.set_state(state)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_uarray.set_state(old_state, True)
|
||||
|
||||
|
||||
def create_multimethod(*args, **kwargs):
|
||||
"""
|
||||
Creates a decorator for generating multimethods.
|
||||
|
||||
This function creates a decorator that can be used with an argument
|
||||
extractor in order to generate a multimethod. Other than for the
|
||||
argument extractor, all arguments are passed on to
|
||||
:obj:`generate_multimethod`.
|
||||
|
||||
See Also
|
||||
--------
|
||||
generate_multimethod
|
||||
Generates a multimethod.
|
||||
"""
|
||||
|
||||
def wrapper(a):
|
||||
return generate_multimethod(a, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def generate_multimethod(
|
||||
argument_extractor: ArgumentExtractorType,
|
||||
argument_replacer: ArgumentReplacerType,
|
||||
domain: str,
|
||||
default: typing.Callable | None = None,
|
||||
):
|
||||
"""
|
||||
Generates a multimethod.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
argument_extractor : ArgumentExtractorType
|
||||
A callable which extracts the dispatchable arguments. Extracted arguments
|
||||
should be marked by the :obj:`Dispatchable` class. It has the same signature
|
||||
as the desired multimethod.
|
||||
argument_replacer : ArgumentReplacerType
|
||||
A callable with the signature (args, kwargs, dispatchables), which should also
|
||||
return an (args, kwargs) pair with the dispatchables replaced inside the
|
||||
args/kwargs.
|
||||
domain : str
|
||||
A string value indicating the domain of this multimethod.
|
||||
default: Optional[Callable], optional
|
||||
The default implementation of this multimethod, where ``None`` (the default)
|
||||
specifies there is no default implementation.
|
||||
|
||||
Examples
|
||||
--------
|
||||
In this example, ``a`` is to be dispatched over, so we return it, while marking it
|
||||
as an ``int``.
|
||||
The trailing comma is needed because the args have to be returned as an iterable.
|
||||
|
||||
>>> def override_me(a, b):
|
||||
... return Dispatchable(a, int),
|
||||
|
||||
Next, we define the argument replacer that replaces the dispatchables inside
|
||||
args/kwargs with the supplied ones.
|
||||
|
||||
>>> def override_replacer(args, kwargs, dispatchables):
|
||||
... return (dispatchables[0], args[1]), {}
|
||||
|
||||
Next, we define the multimethod.
|
||||
|
||||
>>> overridden_me = generate_multimethod(
|
||||
... override_me, override_replacer, "ua_examples"
|
||||
... )
|
||||
|
||||
Notice that there's no default implementation, unless you supply one.
|
||||
|
||||
>>> overridden_me(1, "a")
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
uarray.BackendNotImplementedError: ...
|
||||
|
||||
>>> overridden_me2 = generate_multimethod(
|
||||
... override_me, override_replacer, "ua_examples", default=lambda x, y: (x, y)
|
||||
... )
|
||||
>>> overridden_me2(1, "a")
|
||||
(1, 'a')
|
||||
|
||||
See Also
|
||||
--------
|
||||
uarray
|
||||
See the module documentation for how to override the method by creating
|
||||
backends.
|
||||
"""
|
||||
kw_defaults, arg_defaults, opts = get_defaults(argument_extractor)
|
||||
ua_func = _Function(
|
||||
argument_extractor,
|
||||
argument_replacer,
|
||||
domain,
|
||||
arg_defaults,
|
||||
kw_defaults,
|
||||
default,
|
||||
)
|
||||
|
||||
return functools.update_wrapper(ua_func, argument_extractor)
|
||||
|
||||
|
||||
def set_backend(backend, coerce=False, only=False):
|
||||
"""
|
||||
A context manager that sets the preferred backend.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend
|
||||
The backend to set.
|
||||
coerce
|
||||
Whether or not to coerce to a specific backend's types. Implies ``only``.
|
||||
only
|
||||
Whether or not this should be the last backend to try.
|
||||
|
||||
See Also
|
||||
--------
|
||||
skip_backend: A context manager that allows skipping of backends.
|
||||
set_global_backend: Set a single, global backend for a domain.
|
||||
"""
|
||||
tid = threading.get_native_id()
|
||||
try:
|
||||
return backend.__ua_cache__[tid, "set", coerce, only]
|
||||
except AttributeError:
|
||||
backend.__ua_cache__ = {}
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
ctx = _SetBackendContext(backend, coerce, only)
|
||||
backend.__ua_cache__[tid, "set", coerce, only] = ctx
|
||||
return ctx
|
||||
|
||||
|
||||
def skip_backend(backend):
|
||||
"""
|
||||
A context manager that allows one to skip a given backend from processing
|
||||
entirely. This allows one to use another backend's code in a library that
|
||||
is also a consumer of the same backend.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend
|
||||
The backend to skip.
|
||||
|
||||
See Also
|
||||
--------
|
||||
set_backend: A context manager that allows setting of backends.
|
||||
set_global_backend: Set a single, global backend for a domain.
|
||||
"""
|
||||
tid = threading.get_native_id()
|
||||
try:
|
||||
return backend.__ua_cache__[tid, "skip"]
|
||||
except AttributeError:
|
||||
backend.__ua_cache__ = {}
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
ctx = _SkipBackendContext(backend)
|
||||
backend.__ua_cache__[tid, "skip"] = ctx
|
||||
return ctx
|
||||
|
||||
|
||||
def get_defaults(f):
|
||||
sig = inspect.signature(f)
|
||||
kw_defaults = {}
|
||||
arg_defaults = []
|
||||
opts = set()
|
||||
for k, v in sig.parameters.items():
|
||||
if v.default is not inspect.Parameter.empty:
|
||||
kw_defaults[k] = v.default
|
||||
if v.kind in (
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
):
|
||||
arg_defaults.append(v.default)
|
||||
opts.add(k)
|
||||
|
||||
return kw_defaults, tuple(arg_defaults), opts
|
||||
|
||||
|
||||
def set_global_backend(backend, coerce=False, only=False, *, try_last=False):
|
||||
"""
|
||||
This utility method replaces the default backend for permanent use. It
|
||||
will be tried in the list of backends automatically, unless the
|
||||
``only`` flag is set on a backend. This will be the first tried
|
||||
backend outside the :obj:`set_backend` context manager.
|
||||
|
||||
Note that this method is not thread-safe.
|
||||
|
||||
.. warning::
|
||||
We caution library authors against using this function in
|
||||
their code. We do *not* support this use-case. This function
|
||||
is meant to be used only by users themselves, or by a reference
|
||||
implementation, if one exists.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend
|
||||
The backend to register.
|
||||
coerce : bool
|
||||
Whether to coerce input types when trying this backend.
|
||||
only : bool
|
||||
If ``True``, no more backends will be tried if this fails.
|
||||
Implied by ``coerce=True``.
|
||||
try_last : bool
|
||||
If ``True``, the global backend is tried after registered backends.
|
||||
|
||||
See Also
|
||||
--------
|
||||
set_backend: A context manager that allows setting of backends.
|
||||
skip_backend: A context manager that allows skipping of backends.
|
||||
"""
|
||||
_uarray.set_global_backend(backend, coerce, only, try_last)
|
||||
|
||||
|
||||
def register_backend(backend):
|
||||
"""
|
||||
This utility method sets registers backend for permanent use. It
|
||||
will be tried in the list of backends automatically, unless the
|
||||
``only`` flag is set on a backend.
|
||||
|
||||
Note that this method is not thread-safe.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend
|
||||
The backend to register.
|
||||
"""
|
||||
_uarray.register_backend(backend)
|
||||
|
||||
|
||||
def clear_backends(domain, registered=True, globals=False):
|
||||
"""
|
||||
This utility method clears registered backends.
|
||||
|
||||
.. warning::
|
||||
We caution library authors against using this function in
|
||||
their code. We do *not* support this use-case. This function
|
||||
is meant to be used only by users themselves.
|
||||
|
||||
.. warning::
|
||||
Do NOT use this method inside a multimethod call, or the
|
||||
program is likely to crash.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
domain : Optional[str]
|
||||
The domain for which to de-register backends. ``None`` means
|
||||
de-register for all domains.
|
||||
registered : bool
|
||||
Whether or not to clear registered backends. See :obj:`register_backend`.
|
||||
globals : bool
|
||||
Whether or not to clear global backends. See :obj:`set_global_backend`.
|
||||
|
||||
See Also
|
||||
--------
|
||||
register_backend : Register a backend globally.
|
||||
set_global_backend : Set a global backend.
|
||||
"""
|
||||
_uarray.clear_backends(domain, registered, globals)
|
||||
|
||||
|
||||
class Dispatchable:
|
||||
"""
|
||||
A utility class which marks an argument with a specific dispatch type.
|
||||
|
||||
|
||||
Attributes
|
||||
----------
|
||||
value
|
||||
The value of the Dispatchable.
|
||||
|
||||
type
|
||||
The type of the Dispatchable.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> x = Dispatchable(1, str)
|
||||
>>> x
|
||||
<Dispatchable: type=<class 'str'>, value=1>
|
||||
|
||||
See Also
|
||||
--------
|
||||
all_of_type
|
||||
Marks all unmarked parameters of a function.
|
||||
|
||||
mark_as
|
||||
Allows one to create a utility function to mark as a given type.
|
||||
"""
|
||||
|
||||
def __init__(self, value, dispatch_type, coercible=True):
|
||||
self.value = value
|
||||
self.type = dispatch_type
|
||||
self.coercible = coercible
|
||||
|
||||
def __getitem__(self, index):
|
||||
return (self.type, self.value)[index]
|
||||
|
||||
def __str__(self):
|
||||
return f"<{type(self).__name__}: type={self.type!r}, value={self.value!r}>"
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
def mark_as(dispatch_type):
|
||||
"""
|
||||
Creates a utility function to mark something as a specific type.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> mark_int = mark_as(int)
|
||||
>>> mark_int(1)
|
||||
<Dispatchable: type=<class 'int'>, value=1>
|
||||
"""
|
||||
return functools.partial(Dispatchable, dispatch_type=dispatch_type)
|
||||
|
||||
|
||||
def all_of_type(arg_type):
|
||||
"""
|
||||
Marks all unmarked arguments as a given type.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> @all_of_type(str)
|
||||
... def f(a, b):
|
||||
... return a, Dispatchable(b, int)
|
||||
>>> f('a', 1)
|
||||
(<Dispatchable: type=<class 'str'>, value='a'>,
|
||||
<Dispatchable: type=<class 'int'>, value=1>)
|
||||
"""
|
||||
|
||||
def outer(func):
|
||||
@functools.wraps(func)
|
||||
def inner(*args, **kwargs):
|
||||
extracted_args = func(*args, **kwargs)
|
||||
return tuple(
|
||||
Dispatchable(arg, arg_type)
|
||||
if not isinstance(arg, Dispatchable)
|
||||
else arg
|
||||
for arg in extracted_args
|
||||
)
|
||||
|
||||
return inner
|
||||
|
||||
return outer
|
||||
|
||||
|
||||
def wrap_single_convertor(convert_single):
|
||||
"""
|
||||
Wraps a ``__ua_convert__`` defined for a single element to all elements.
|
||||
If any of them return ``NotImplemented``, the operation is assumed to be
|
||||
undefined.
|
||||
|
||||
Accepts a signature of (value, type, coerce).
|
||||
"""
|
||||
|
||||
@functools.wraps(convert_single)
|
||||
def __ua_convert__(dispatchables, coerce):
|
||||
converted = []
|
||||
for d in dispatchables:
|
||||
c = convert_single(d.value, d.type, coerce and d.coercible)
|
||||
|
||||
if c is NotImplemented:
|
||||
return NotImplemented
|
||||
|
||||
converted.append(c)
|
||||
|
||||
return converted
|
||||
|
||||
return __ua_convert__
|
||||
|
||||
|
||||
def wrap_single_convertor_instance(convert_single):
|
||||
"""
|
||||
Wraps a ``__ua_convert__`` defined for a single element to all elements.
|
||||
If any of them return ``NotImplemented``, the operation is assumed to be
|
||||
undefined.
|
||||
|
||||
Accepts a signature of (value, type, coerce).
|
||||
"""
|
||||
|
||||
@functools.wraps(convert_single)
|
||||
def __ua_convert__(self, dispatchables, coerce):
|
||||
converted = []
|
||||
for d in dispatchables:
|
||||
c = convert_single(self, d.value, d.type, coerce and d.coercible)
|
||||
|
||||
if c is NotImplemented:
|
||||
return NotImplemented
|
||||
|
||||
converted.append(c)
|
||||
|
||||
return converted
|
||||
|
||||
return __ua_convert__
|
||||
|
||||
|
||||
def determine_backend(value, dispatch_type, *, domain, only=True, coerce=False):
|
||||
"""Set the backend to the first active backend that supports ``value``
|
||||
|
||||
This is useful for functions that call multimethods without any dispatchable
|
||||
arguments. You can use :func:`determine_backend` to ensure the same backend
|
||||
is used everywhere in a block of multimethod calls.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value
|
||||
The value being tested
|
||||
dispatch_type
|
||||
The dispatch type associated with ``value``, aka
|
||||
":ref:`marking <MarkingGlossary>`".
|
||||
domain: string
|
||||
The domain to query for backends and set.
|
||||
coerce: bool
|
||||
Whether or not to allow coercion to the backend's types. Implies ``only``.
|
||||
only: bool
|
||||
Whether or not this should be the last backend to try.
|
||||
|
||||
See Also
|
||||
--------
|
||||
set_backend: For when you know which backend to set
|
||||
|
||||
Notes
|
||||
-----
|
||||
|
||||
Support is determined by the ``__ua_convert__`` protocol. Backends not
|
||||
supporting the type must return ``NotImplemented`` from their
|
||||
``__ua_convert__`` if they don't support input of that type.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
Suppose we have two backends ``BackendA`` and ``BackendB`` each supporting
|
||||
different types, ``TypeA`` and ``TypeB``. Neither supporting the other type:
|
||||
|
||||
>>> with ua.set_backend(ex.BackendA):
|
||||
... ex.call_multimethod(ex.TypeB(), ex.TypeB())
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
uarray.BackendNotImplementedError: ...
|
||||
|
||||
Now consider a multimethod that creates a new object of ``TypeA``, or
|
||||
``TypeB`` depending on the active backend.
|
||||
|
||||
>>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
|
||||
... res = ex.creation_multimethod()
|
||||
... ex.call_multimethod(res, ex.TypeA())
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
uarray.BackendNotImplementedError: ...
|
||||
|
||||
``res`` is an object of ``TypeB`` because ``BackendB`` is set in the
|
||||
innermost with statement. So, ``call_multimethod`` fails since the types
|
||||
don't match.
|
||||
|
||||
Instead, we need to first find a backend suitable for all of our objects.
|
||||
|
||||
>>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
|
||||
... x = ex.TypeA()
|
||||
... with ua.determine_backend(x, "mark", domain="ua_examples"):
|
||||
... res = ex.creation_multimethod()
|
||||
... ex.call_multimethod(res, x)
|
||||
TypeA
|
||||
|
||||
"""
|
||||
dispatchables = (Dispatchable(value, dispatch_type, coerce),)
|
||||
backend = _uarray.determine_backend(domain, dispatchables, coerce)
|
||||
|
||||
return set_backend(backend, coerce=coerce, only=only)
|
||||
|
||||
|
||||
def determine_backend_multi(
|
||||
dispatchables, *, domain, only=True, coerce=False, **kwargs
|
||||
):
|
||||
"""Set a backend supporting all ``dispatchables``
|
||||
|
||||
This is useful for functions that call multimethods without any dispatchable
|
||||
arguments. You can use :func:`determine_backend_multi` to ensure the same
|
||||
backend is used everywhere in a block of multimethod calls involving
|
||||
multiple arrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dispatchables: Sequence[Union[uarray.Dispatchable, Any]]
|
||||
The dispatchables that must be supported
|
||||
domain: string
|
||||
The domain to query for backends and set.
|
||||
coerce: bool
|
||||
Whether or not to allow coercion to the backend's types. Implies ``only``.
|
||||
only: bool
|
||||
Whether or not this should be the last backend to try.
|
||||
dispatch_type: Optional[Any]
|
||||
The default dispatch type associated with ``dispatchables``, aka
|
||||
":ref:`marking <MarkingGlossary>`".
|
||||
|
||||
See Also
|
||||
--------
|
||||
determine_backend: For a single dispatch value
|
||||
set_backend: For when you know which backend to set
|
||||
|
||||
Notes
|
||||
-----
|
||||
|
||||
Support is determined by the ``__ua_convert__`` protocol. Backends not
|
||||
supporting the type must return ``NotImplemented`` from their
|
||||
``__ua_convert__`` if they don't support input of that type.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
:func:`determine_backend` allows the backend to be set from a single
|
||||
object. :func:`determine_backend_multi` allows multiple objects to be
|
||||
checked simultaneously for support in the backend. Suppose we have a
|
||||
``BackendAB`` which supports ``TypeA`` and ``TypeB`` in the same call,
|
||||
and a ``BackendBC`` that doesn't support ``TypeA``.
|
||||
|
||||
>>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
|
||||
... a, b = ex.TypeA(), ex.TypeB()
|
||||
... with ua.determine_backend_multi(
|
||||
... [ua.Dispatchable(a, "mark"), ua.Dispatchable(b, "mark")],
|
||||
... domain="ua_examples"
|
||||
... ):
|
||||
... res = ex.creation_multimethod()
|
||||
... ex.call_multimethod(res, a, b)
|
||||
TypeA
|
||||
|
||||
This won't call ``BackendBC`` because it doesn't support ``TypeA``.
|
||||
|
||||
We can also use leave out the ``ua.Dispatchable`` if we specify the
|
||||
default ``dispatch_type`` for the ``dispatchables`` argument.
|
||||
|
||||
>>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
|
||||
... a, b = ex.TypeA(), ex.TypeB()
|
||||
... with ua.determine_backend_multi(
|
||||
... [a, b], dispatch_type="mark", domain="ua_examples"
|
||||
... ):
|
||||
... res = ex.creation_multimethod()
|
||||
... ex.call_multimethod(res, a, b)
|
||||
TypeA
|
||||
|
||||
"""
|
||||
if "dispatch_type" in kwargs:
|
||||
disp_type = kwargs.pop("dispatch_type")
|
||||
dispatchables = tuple(
|
||||
d if isinstance(d, Dispatchable) else Dispatchable(d, disp_type)
|
||||
for d in dispatchables
|
||||
)
|
||||
else:
|
||||
dispatchables = tuple(dispatchables)
|
||||
if not all(isinstance(d, Dispatchable) for d in dispatchables):
|
||||
raise TypeError("dispatchables must be instances of uarray.Dispatchable")
|
||||
|
||||
if len(kwargs) != 0:
|
||||
raise TypeError(f"Received unexpected keyword arguments: {kwargs}")
|
||||
|
||||
backend = _uarray.determine_backend(domain, dispatchables, coerce)
|
||||
|
||||
return set_backend(backend, coerce=coerce, only=only)
|
||||
Binary file not shown.
1251
linedance-app/venv/lib/python3.12/site-packages/scipy/_lib/_util.py
Normal file
1251
linedance-app/venv/lib/python3.12/site-packages/scipy/_lib/_util.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
NumPy Array API compatibility library
|
||||
|
||||
This is a small wrapper around NumPy, CuPy, JAX, sparse and others that are
|
||||
compatible with the Array API standard https://data-apis.org/array-api/latest/.
|
||||
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
|
||||
|
||||
Unlike array_api_strict, this is not a strict minimal implementation of the
|
||||
Array API, but rather just an extension of the main NumPy namespace with
|
||||
changes needed to be compliant with the Array API. See
|
||||
https://numpy.org/doc/stable/reference/array_api.html for a full list of
|
||||
changes. In particular, unlike array_api_strict, this package does not use a
|
||||
separate Array object, but rather just uses numpy.ndarray directly.
|
||||
|
||||
Library authors using the Array API may wish to test against array_api_strict
|
||||
to ensure they are not using functionality outside of the standard, but prefer
|
||||
this implementation for the default when working with NumPy arrays.
|
||||
|
||||
"""
|
||||
__version__ = '1.13.0'
|
||||
|
||||
from .common import * # noqa: F401, F403
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Internal helpers
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from inspect import signature
|
||||
from types import ModuleType
|
||||
from typing import TypeVar
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
|
||||
"""
|
||||
Decorator to automatically replace xp with the corresponding array module.
|
||||
|
||||
Use like
|
||||
|
||||
import numpy as np
|
||||
|
||||
@get_xp(np)
|
||||
def func(x, /, xp, kwarg=None):
|
||||
return xp.func(x, kwarg=kwarg)
|
||||
|
||||
Note that xp must be a keyword argument and come after all non-keyword
|
||||
arguments.
|
||||
|
||||
"""
|
||||
|
||||
def inner(f: Callable[..., _T], /) -> Callable[..., _T]:
|
||||
@wraps(f)
|
||||
def wrapped_f(*args: object, **kwargs: object) -> object:
|
||||
return f(*args, xp=xp, **kwargs)
|
||||
|
||||
sig = signature(f)
|
||||
new_sig = sig.replace(
|
||||
parameters=[par for i, par in sig.parameters.items() if i != "xp"]
|
||||
)
|
||||
|
||||
if wrapped_f.__doc__ is None:
|
||||
wrapped_f.__doc__ = f"""\
|
||||
Array API compatibility wrapper for {f.__name__}.
|
||||
|
||||
See the corresponding documentation in NumPy/CuPy and/or the array API
|
||||
specification for more details.
|
||||
|
||||
"""
|
||||
wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
|
||||
return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType]
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]:
|
||||
"""Import everything from module, updating globals().
|
||||
Returns __all__.
|
||||
"""
|
||||
mod = importlib.import_module(mod_name)
|
||||
# Neither of these two methods is sufficient by itself,
|
||||
# depending on various idiosyncrasies of the libraries we're wrapping.
|
||||
objs = {}
|
||||
exec(f"from {mod.__name__} import *", objs)
|
||||
|
||||
for n in dir(mod):
|
||||
if not n.startswith("_") and hasattr(mod, n):
|
||||
objs[n] = getattr(mod, n)
|
||||
|
||||
globals_.update(objs)
|
||||
return list(objs)
|
||||
|
||||
|
||||
__all__ = ["get_xp", "clone_module"]
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
@@ -0,0 +1 @@
|
||||
from ._helpers import * # noqa: F403
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,702 @@
|
||||
"""
|
||||
These are functions that are just aliases of existing functions in NumPy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, cast
|
||||
|
||||
from ._helpers import _check_device, array_namespace
|
||||
from ._helpers import device as _get_device
|
||||
from ._helpers import is_cupy_namespace
|
||||
from ._typing import Array, Device, DType, Namespace
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# TODO: import from typing (requires Python >=3.13)
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
# These functions are modified from the NumPy versions.
|
||||
|
||||
# Creation functions add the device keyword (which does nothing for NumPy and Dask)
|
||||
|
||||
|
||||
def arange(
|
||||
start: float,
|
||||
/,
|
||||
stop: float | None = None,
|
||||
step: float = 1,
|
||||
*,
|
||||
xp: Namespace,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def empty(
|
||||
shape: int | tuple[int, ...],
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.empty(shape, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def empty_like(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.empty_like(x, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def eye(
|
||||
n_rows: int,
|
||||
n_cols: int | None = None,
|
||||
/,
|
||||
*,
|
||||
xp: Namespace,
|
||||
k: int = 0,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def full(
|
||||
shape: int | tuple[int, ...],
|
||||
fill_value: complex,
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def full_like(
|
||||
x: Array,
|
||||
/,
|
||||
fill_value: complex,
|
||||
*,
|
||||
xp: Namespace,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def linspace(
|
||||
start: float,
|
||||
stop: float,
|
||||
/,
|
||||
num: int,
|
||||
*,
|
||||
xp: Namespace,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
endpoint: bool = True,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
|
||||
|
||||
|
||||
def ones(
|
||||
shape: int | tuple[int, ...],
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.ones(shape, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def ones_like(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.ones_like(x, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def zeros(
|
||||
shape: int | tuple[int, ...],
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.zeros(shape, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def zeros_like(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
_check_device(xp, device)
|
||||
return xp.zeros_like(x, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
# np.unique() is split into four functions in the array API:
|
||||
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
|
||||
# to remove polymorphic return types).
|
||||
|
||||
# The functions here return namedtuples (np.unique() returns a normal
|
||||
# tuple).
|
||||
|
||||
|
||||
# Note that these named tuples aren't actually part of the standard namespace,
|
||||
# but I don't see any issue with exporting the names here regardless.
|
||||
class UniqueAllResult(NamedTuple):
|
||||
values: Array
|
||||
indices: Array
|
||||
inverse_indices: Array
|
||||
counts: Array
|
||||
|
||||
|
||||
class UniqueCountsResult(NamedTuple):
|
||||
values: Array
|
||||
counts: Array
|
||||
|
||||
|
||||
class UniqueInverseResult(NamedTuple):
|
||||
values: Array
|
||||
inverse_indices: Array
|
||||
|
||||
|
||||
def _unique_kwargs(xp: Namespace) -> dict[str, bool]:
|
||||
# Older versions of NumPy and CuPy do not have equal_nan. Rather than
|
||||
# trying to parse version numbers, just check if equal_nan is in the
|
||||
# signature.
|
||||
s = inspect.signature(xp.unique)
|
||||
if "equal_nan" in s.parameters:
|
||||
return {"equal_nan": False}
|
||||
return {}
|
||||
|
||||
|
||||
def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult:
|
||||
kwargs = _unique_kwargs(xp)
|
||||
values, indices, inverse_indices, counts = xp.unique(
|
||||
x,
|
||||
return_counts=True,
|
||||
return_index=True,
|
||||
return_inverse=True,
|
||||
**kwargs,
|
||||
)
|
||||
# np.unique() flattens inverse indices, but they need to share x's shape
|
||||
# See https://github.com/numpy/numpy/issues/20638
|
||||
inverse_indices = inverse_indices.reshape(x.shape)
|
||||
return UniqueAllResult(
|
||||
values,
|
||||
indices,
|
||||
inverse_indices,
|
||||
counts,
|
||||
)
|
||||
|
||||
|
||||
def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult:
|
||||
kwargs = _unique_kwargs(xp)
|
||||
res = xp.unique(
|
||||
x, return_counts=True, return_index=False, return_inverse=False, **kwargs
|
||||
)
|
||||
|
||||
return UniqueCountsResult(*res)
|
||||
|
||||
|
||||
def unique_inverse(x: Array, /, xp: Namespace) -> UniqueInverseResult:
|
||||
kwargs = _unique_kwargs(xp)
|
||||
values, inverse_indices = xp.unique(
|
||||
x,
|
||||
return_counts=False,
|
||||
return_index=False,
|
||||
return_inverse=True,
|
||||
**kwargs,
|
||||
)
|
||||
# xp.unique() flattens inverse indices, but they need to share x's shape
|
||||
# See https://github.com/numpy/numpy/issues/20638
|
||||
inverse_indices = inverse_indices.reshape(x.shape)
|
||||
return UniqueInverseResult(values, inverse_indices)
|
||||
|
||||
|
||||
def unique_values(x: Array, /, xp: Namespace) -> Array:
|
||||
kwargs = _unique_kwargs(xp)
|
||||
return xp.unique(
|
||||
x,
|
||||
return_counts=False,
|
||||
return_index=False,
|
||||
return_inverse=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# These functions have different keyword argument names
|
||||
|
||||
|
||||
def std(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int | tuple[int, ...] | None = None,
|
||||
correction: float = 0.0, # correction instead of ddof
|
||||
keepdims: bool = False,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
|
||||
|
||||
|
||||
def var(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int | tuple[int, ...] | None = None,
|
||||
correction: float = 0.0, # correction instead of ddof
|
||||
keepdims: bool = False,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
|
||||
|
||||
|
||||
# cumulative_sum is renamed from cumsum, and adds the include_initial keyword
|
||||
# argument
|
||||
|
||||
|
||||
def cumulative_sum(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int | None = None,
|
||||
dtype: DType | None = None,
|
||||
include_initial: bool = False,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
wrapped_xp = array_namespace(x)
|
||||
|
||||
# TODO: The standard is not clear about what should happen when x.ndim == 0.
|
||||
if axis is None:
|
||||
if x.ndim > 1:
|
||||
raise ValueError(
|
||||
"axis must be specified in cumulative_sum for more than one dimension"
|
||||
)
|
||||
axis = 0
|
||||
|
||||
res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
|
||||
|
||||
# np.cumsum does not support include_initial
|
||||
if include_initial:
|
||||
initial_shape = list(x.shape)
|
||||
initial_shape[axis] = 1
|
||||
res = xp.concatenate(
|
||||
[
|
||||
wrapped_xp.zeros(
|
||||
shape=initial_shape, dtype=res.dtype, device=_get_device(res)
|
||||
),
|
||||
res,
|
||||
],
|
||||
axis=axis,
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
def cumulative_prod(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int | None = None,
|
||||
dtype: DType | None = None,
|
||||
include_initial: bool = False,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
wrapped_xp = array_namespace(x)
|
||||
|
||||
if axis is None:
|
||||
if x.ndim > 1:
|
||||
raise ValueError(
|
||||
"axis must be specified in cumulative_prod for more than one dimension"
|
||||
)
|
||||
axis = 0
|
||||
|
||||
res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
|
||||
|
||||
# np.cumprod does not support include_initial
|
||||
if include_initial:
|
||||
initial_shape = list(x.shape)
|
||||
initial_shape[axis] = 1
|
||||
res = xp.concatenate(
|
||||
[
|
||||
wrapped_xp.ones(
|
||||
shape=initial_shape, dtype=res.dtype, device=_get_device(res)
|
||||
),
|
||||
res,
|
||||
],
|
||||
axis=axis,
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
# The min and max argument names in clip are different and not optional in numpy, and type
|
||||
# promotion behavior is different.
|
||||
def clip(
|
||||
x: Array,
|
||||
/,
|
||||
min: float | Array | None = None,
|
||||
max: float | Array | None = None,
|
||||
*,
|
||||
xp: Namespace,
|
||||
# TODO: np.clip has other ufunc kwargs
|
||||
out: Array | None = None,
|
||||
) -> Array:
|
||||
def _isscalar(a: object) -> TypeIs[float | None]:
|
||||
return isinstance(a, int | float) or a is None
|
||||
|
||||
min_shape = () if _isscalar(min) else min.shape
|
||||
max_shape = () if _isscalar(max) else max.shape
|
||||
|
||||
wrapped_xp = array_namespace(x)
|
||||
|
||||
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
|
||||
|
||||
# np.clip does type promotion but the array API clip requires that the
|
||||
# output have the same dtype as x. We do this instead of just downcasting
|
||||
# the result of xp.clip() to handle some corner cases better (e.g.,
|
||||
# avoiding uint64 -> float64 promotion).
|
||||
|
||||
# Note: cases where min or max overflow (integer) or round (float) in the
|
||||
# wrong direction when downcasting to x.dtype are unspecified. This code
|
||||
# just does whatever NumPy does when it downcasts in the assignment, but
|
||||
# other behavior could be preferred, especially for integers. For example,
|
||||
# this code produces:
|
||||
|
||||
# >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
|
||||
# -128
|
||||
|
||||
# but an answer of 0 might be preferred. See
|
||||
# https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
|
||||
|
||||
# At least handle the case of Python integers correctly (see
|
||||
# https://github.com/numpy/numpy/pull/26892).
|
||||
if wrapped_xp.isdtype(x.dtype, "integral"):
|
||||
if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
|
||||
min = None
|
||||
if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
|
||||
max = None
|
||||
|
||||
dev = _get_device(x)
|
||||
if out is None:
|
||||
out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
|
||||
assert out is not None # workaround for a type-narrowing issue in pyright
|
||||
out[()] = x
|
||||
|
||||
if min is not None:
|
||||
a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev)
|
||||
a = xp.broadcast_to(a, result_shape)
|
||||
ia = (out < a) | xp.isnan(a)
|
||||
out[ia] = a[ia]
|
||||
|
||||
if max is not None:
|
||||
b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev)
|
||||
b = xp.broadcast_to(b, result_shape)
|
||||
ib = (out > b) | xp.isnan(b)
|
||||
out[ib] = b[ib]
|
||||
|
||||
# Return a scalar for 0-D
|
||||
return out[()]
|
||||
|
||||
|
||||
# Unlike transpose(), the axes argument to permute_dims() is required.
|
||||
def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array:
|
||||
return xp.transpose(x, axes)
|
||||
|
||||
|
||||
# np.reshape calls the keyword argument 'newshape' instead of 'shape'
|
||||
def reshape(
|
||||
x: Array,
|
||||
/,
|
||||
shape: tuple[int, ...],
|
||||
xp: Namespace,
|
||||
*,
|
||||
copy: bool | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
if copy is True:
|
||||
x = x.copy()
|
||||
elif copy is False:
|
||||
y = x.view()
|
||||
y.shape = shape
|
||||
return y
|
||||
return xp.reshape(x, shape, **kwargs)
|
||||
|
||||
|
||||
# The descending keyword is new in sort and argsort, and 'kind' replaced with
|
||||
# 'stable'
|
||||
def argsort(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int = -1,
|
||||
descending: bool = False,
|
||||
stable: bool = True,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
# Note: this keyword argument is different, and the default is different.
|
||||
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
|
||||
# as the default whereas cupy.sort uses kind=None.
|
||||
if stable:
|
||||
kwargs["kind"] = "stable"
|
||||
if not descending:
|
||||
res = xp.argsort(x, axis=axis, **kwargs)
|
||||
else:
|
||||
# As NumPy has no native descending sort, we imitate it here. Note that
|
||||
# simply flipping the results of xp.argsort(x, ...) would not
|
||||
# respect the relative order like it would in native descending sorts.
|
||||
res = xp.flip(
|
||||
xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs),
|
||||
axis=axis,
|
||||
)
|
||||
# Rely on flip()/argsort() to validate axis
|
||||
normalised_axis = axis if axis >= 0 else x.ndim + axis
|
||||
max_i = x.shape[normalised_axis] - 1
|
||||
res = max_i - res
|
||||
return res
|
||||
|
||||
|
||||
def sort(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int = -1,
|
||||
descending: bool = False,
|
||||
stable: bool = True,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
# Note: this keyword argument is different, and the default is different.
|
||||
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
|
||||
# as the default whereas cupy.sort uses kind=None.
|
||||
if stable:
|
||||
kwargs["kind"] = "stable"
|
||||
res = xp.sort(x, axis=axis, **kwargs)
|
||||
if descending:
|
||||
res = xp.flip(res, axis=axis)
|
||||
return res
|
||||
|
||||
|
||||
# nonzero should error for zero-dimensional arrays
|
||||
def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
|
||||
if x.ndim == 0:
|
||||
raise ValueError("nonzero() does not support zero-dimensional arrays")
|
||||
return xp.nonzero(x, **kwargs)
|
||||
|
||||
|
||||
# linear algebra functions
|
||||
|
||||
|
||||
def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
|
||||
return xp.matmul(x1, x2, **kwargs)
|
||||
|
||||
|
||||
# Unlike transpose, matrix_transpose only transposes the last two axes.
|
||||
def matrix_transpose(x: Array, /, xp: Namespace) -> Array:
|
||||
if x.ndim < 2:
|
||||
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
|
||||
return xp.swapaxes(x, -1, -2)
|
||||
|
||||
|
||||
def tensordot(
|
||||
x1: Array,
|
||||
x2: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axes: int | tuple[Sequence[int], Sequence[int]] = 2,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
return xp.tensordot(x1, x2, axes=axes, **kwargs)
|
||||
|
||||
|
||||
def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array:
|
||||
if x1.shape[axis] != x2.shape[axis]:
|
||||
raise ValueError("x1 and x2 must have the same size along the given axis")
|
||||
|
||||
if hasattr(xp, "broadcast_tensors"):
|
||||
_broadcast = xp.broadcast_tensors
|
||||
else:
|
||||
_broadcast = xp.broadcast_arrays
|
||||
|
||||
x1_ = xp.moveaxis(x1, axis, -1)
|
||||
x2_ = xp.moveaxis(x2, axis, -1)
|
||||
x1_, x2_ = _broadcast(x1_, x2_)
|
||||
|
||||
res = xp.conj(x1_[..., None, :]) @ x2_[..., None]
|
||||
return res[..., 0, 0]
|
||||
|
||||
|
||||
# isdtype is a new function in the 2022.12 array API specification.
|
||||
|
||||
|
||||
def isdtype(
|
||||
dtype: DType,
|
||||
kind: DType | str | tuple[DType | str, ...],
|
||||
xp: Namespace,
|
||||
*,
|
||||
_tuple: bool = True, # Disallow nested tuples
|
||||
) -> bool:
|
||||
"""
|
||||
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
|
||||
|
||||
Note that outside of this function, this compat library does not yet fully
|
||||
support complex numbers.
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
|
||||
for more details
|
||||
"""
|
||||
if isinstance(kind, tuple) and _tuple:
|
||||
return any(
|
||||
isdtype(dtype, k, xp, _tuple=False)
|
||||
for k in cast("tuple[DType | str, ...]", kind)
|
||||
)
|
||||
elif isinstance(kind, str):
|
||||
if kind == "bool":
|
||||
return dtype == xp.bool_
|
||||
elif kind == "signed integer":
|
||||
return xp.issubdtype(dtype, xp.signedinteger)
|
||||
elif kind == "unsigned integer":
|
||||
return xp.issubdtype(dtype, xp.unsignedinteger)
|
||||
elif kind == "integral":
|
||||
return xp.issubdtype(dtype, xp.integer)
|
||||
elif kind == "real floating":
|
||||
return xp.issubdtype(dtype, xp.floating)
|
||||
elif kind == "complex floating":
|
||||
return xp.issubdtype(dtype, xp.complexfloating)
|
||||
elif kind == "numeric":
|
||||
return xp.issubdtype(dtype, xp.number)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized data type kind: {kind!r}")
|
||||
else:
|
||||
# This will allow things that aren't required by the spec, like
|
||||
# isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be
|
||||
# more strict here to match the type annotation? Note that the
|
||||
# array_api_strict implementation will be very strict.
|
||||
return dtype == kind
|
||||
|
||||
|
||||
# unstack is a new function in the 2023.12 array API standard
|
||||
def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]:
|
||||
if x.ndim == 0:
|
||||
raise ValueError("Input array must be at least 1-d.")
|
||||
return tuple(xp.moveaxis(x, axis, 0))
|
||||
|
||||
|
||||
# numpy 1.26 does not use the standard definition for sign on complex numbers
|
||||
|
||||
|
||||
def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
|
||||
if isdtype(x.dtype, "complex floating", xp=xp):
|
||||
out = (x / xp.abs(x, **kwargs))[...]
|
||||
# sign(0) = 0 but the above formula would give nan
|
||||
out[x == 0j] = 0j
|
||||
else:
|
||||
out = xp.sign(x, **kwargs)
|
||||
# CuPy sign() does not propagate nans. See
|
||||
# https://github.com/data-apis/array-api-compat/issues/136
|
||||
if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
|
||||
out[xp.isnan(x)] = xp.nan
|
||||
return out[()]
|
||||
|
||||
|
||||
def finfo(type_: DType | Array, /, xp: Namespace) -> Any:
|
||||
# It is surprisingly difficult to recognize a dtype apart from an array.
|
||||
# np.int64 is not the same as np.asarray(1).dtype!
|
||||
try:
|
||||
return xp.finfo(type_)
|
||||
except (ValueError, TypeError):
|
||||
return xp.finfo(type_.dtype)
|
||||
|
||||
|
||||
def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
|
||||
try:
|
||||
return xp.iinfo(type_)
|
||||
except (ValueError, TypeError):
|
||||
return xp.iinfo(type_.dtype)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"arange",
|
||||
"empty",
|
||||
"empty_like",
|
||||
"eye",
|
||||
"full",
|
||||
"full_like",
|
||||
"linspace",
|
||||
"ones",
|
||||
"ones_like",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
"UniqueAllResult",
|
||||
"UniqueCountsResult",
|
||||
"UniqueInverseResult",
|
||||
"unique_all",
|
||||
"unique_counts",
|
||||
"unique_inverse",
|
||||
"unique_values",
|
||||
"std",
|
||||
"var",
|
||||
"cumulative_sum",
|
||||
"cumulative_prod",
|
||||
"clip",
|
||||
"permute_dims",
|
||||
"reshape",
|
||||
"argsort",
|
||||
"sort",
|
||||
"nonzero",
|
||||
"matmul",
|
||||
"matrix_transpose",
|
||||
"tensordot",
|
||||
"vecdot",
|
||||
"isdtype",
|
||||
"unstack",
|
||||
"sign",
|
||||
"finfo",
|
||||
"iinfo",
|
||||
]
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
@@ -0,0 +1,213 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from ._typing import Array, Device, DType, Namespace
|
||||
|
||||
_Norm: TypeAlias = Literal["backward", "ortho", "forward"]
|
||||
|
||||
# Note: NumPy fft functions improperly upcast float32 and complex64 to
|
||||
# complex128, which is why we require wrapping them all here.
|
||||
|
||||
def fft(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
n: int | None = None,
|
||||
axis: int = -1,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
|
||||
if x.dtype in [xp.float32, xp.complex64]:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def ifft(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
n: int | None = None,
|
||||
axis: int = -1,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
|
||||
if x.dtype in [xp.float32, xp.complex64]:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def fftn(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
s: Sequence[int] | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
|
||||
if x.dtype in [xp.float32, xp.complex64]:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def ifftn(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
s: Sequence[int] | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
|
||||
if x.dtype in [xp.float32, xp.complex64]:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def rfft(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
n: int | None = None,
|
||||
axis: int = -1,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
|
||||
if x.dtype == xp.float32:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def irfft(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
n: int | None = None,
|
||||
axis: int = -1,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
|
||||
if x.dtype == xp.complex64:
|
||||
return res.astype(xp.float32)
|
||||
return res
|
||||
|
||||
def rfftn(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
s: Sequence[int] | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
|
||||
if x.dtype == xp.float32:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def irfftn(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
s: Sequence[int] | None = None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
|
||||
if x.dtype == xp.complex64:
|
||||
return res.astype(xp.float32)
|
||||
return res
|
||||
|
||||
def hfft(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
n: int | None = None,
|
||||
axis: int = -1,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
|
||||
if x.dtype in [xp.float32, xp.complex64]:
|
||||
return res.astype(xp.float32)
|
||||
return res
|
||||
|
||||
def ihfft(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
n: int | None = None,
|
||||
axis: int = -1,
|
||||
norm: _Norm = "backward",
|
||||
) -> Array:
|
||||
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
|
||||
if x.dtype in [xp.float32, xp.complex64]:
|
||||
return res.astype(xp.complex64)
|
||||
return res
|
||||
|
||||
def fftfreq(
|
||||
n: int,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
d: float = 1.0,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
) -> Array:
|
||||
if device not in ["cpu", None]:
|
||||
raise ValueError(f"Unsupported device {device!r}")
|
||||
res = xp.fft.fftfreq(n, d=d)
|
||||
if dtype is not None:
|
||||
return res.astype(dtype)
|
||||
return res
|
||||
|
||||
def rfftfreq(
|
||||
n: int,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
d: float = 1.0,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
) -> Array:
|
||||
if device not in ["cpu", None]:
|
||||
raise ValueError(f"Unsupported device {device!r}")
|
||||
res = xp.fft.rfftfreq(n, d=d)
|
||||
if dtype is not None:
|
||||
return res.astype(dtype)
|
||||
return res
|
||||
|
||||
def fftshift(
|
||||
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
|
||||
) -> Array:
|
||||
return xp.fft.fftshift(x, axes=axes)
|
||||
|
||||
def ifftshift(
|
||||
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
|
||||
) -> Array:
|
||||
return xp.fft.ifftshift(x, axes=axes)
|
||||
|
||||
__all__ = [
|
||||
"fft",
|
||||
"ifft",
|
||||
"fftn",
|
||||
"ifftn",
|
||||
"rfft",
|
||||
"irfft",
|
||||
"rfftn",
|
||||
"irfftn",
|
||||
"hfft",
|
||||
"ihfft",
|
||||
"fftfreq",
|
||||
"rfftfreq",
|
||||
"fftshift",
|
||||
"ifftshift",
|
||||
]
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,230 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Literal, NamedTuple, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
if np.__version__[0] == "2":
|
||||
from numpy.lib.array_utils import normalize_axis_tuple
|
||||
else:
|
||||
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]
|
||||
|
||||
from .._internal import get_xp
|
||||
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
|
||||
from ._typing import Array, DType, JustFloat, JustInt, Namespace
|
||||
|
||||
|
||||
# These are in the main NumPy namespace but not in numpy.linalg
|
||||
def cross(
|
||||
x1: Array,
|
||||
x2: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int = -1,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
return xp.cross(x1, x2, axis=axis, **kwargs)
|
||||
|
||||
def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
|
||||
return xp.outer(x1, x2, **kwargs)
|
||||
|
||||
class EighResult(NamedTuple):
|
||||
eigenvalues: Array
|
||||
eigenvectors: Array
|
||||
|
||||
class QRResult(NamedTuple):
|
||||
Q: Array
|
||||
R: Array
|
||||
|
||||
class SlogdetResult(NamedTuple):
|
||||
sign: Array
|
||||
logabsdet: Array
|
||||
|
||||
class SVDResult(NamedTuple):
|
||||
U: Array
|
||||
S: Array
|
||||
Vh: Array
|
||||
|
||||
# These functions are the same as their NumPy counterparts except they return
|
||||
# a namedtuple.
|
||||
def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult:
|
||||
return EighResult(*xp.linalg.eigh(x, **kwargs))
|
||||
|
||||
def qr(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
mode: Literal["reduced", "complete"] = "reduced",
|
||||
**kwargs: object,
|
||||
) -> QRResult:
|
||||
return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs))
|
||||
|
||||
def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult:
|
||||
return SlogdetResult(*xp.linalg.slogdet(x, **kwargs))
|
||||
|
||||
def svd(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
full_matrices: bool = True,
|
||||
**kwargs: object,
|
||||
) -> SVDResult:
|
||||
return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs))
|
||||
|
||||
# These functions have additional keyword arguments
|
||||
|
||||
# The upper keyword argument is new from NumPy
|
||||
def cholesky(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
upper: bool = False,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
L = xp.linalg.cholesky(x, **kwargs)
|
||||
if upper:
|
||||
U = get_xp(xp)(matrix_transpose)(L)
|
||||
if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
|
||||
U = xp.conj(U) # pyright: ignore[reportConstantRedefinition]
|
||||
return U
|
||||
return L
|
||||
|
||||
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
|
||||
# Note that it has a different semantic meaning from tol and rcond.
|
||||
def matrix_rank(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
rtol: float | Array | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
# this is different from xp.linalg.matrix_rank, which supports 1
|
||||
# dimensional arrays.
|
||||
if x.ndim < 2:
|
||||
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
|
||||
S: Array = get_xp(xp)(svdvals)(x, **kwargs)
|
||||
if rtol is None:
|
||||
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
|
||||
else:
|
||||
# this is different from xp.linalg.matrix_rank, which does not
|
||||
# multiply the tolerance by the largest singular value.
|
||||
tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis]
|
||||
return xp.count_nonzero(S > tol, axis=-1)
|
||||
|
||||
def pinv(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
rtol: float | Array | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
# this is different from xp.linalg.pinv, which does not multiply the
|
||||
# default tolerance by max(M, N).
|
||||
if rtol is None:
|
||||
rtol = max(x.shape[-2:]) * xp.finfo(x.dtype).eps
|
||||
return xp.linalg.pinv(x, rcond=rtol, **kwargs)
|
||||
|
||||
# These functions are new in the array API spec
|
||||
|
||||
def matrix_norm(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
keepdims: bool = False,
|
||||
ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro",
|
||||
) -> Array:
|
||||
return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
|
||||
|
||||
# svdvals is not in NumPy (but it is in SciPy). It is equivalent to
|
||||
# xp.linalg.svd(compute_uv=False).
|
||||
def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]:
|
||||
return xp.linalg.svd(x, compute_uv=False)
|
||||
|
||||
def vector_norm(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
axis: int | tuple[int, ...] | None = None,
|
||||
keepdims: bool = False,
|
||||
ord: JustInt | JustFloat = 2,
|
||||
) -> Array:
|
||||
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
|
||||
# when axis=None and the input is 2-D, so to force a vector norm, we make
|
||||
# it so the input is 1-D (for axis=None), or reshape so that norm is done
|
||||
# on a single dimension.
|
||||
if axis is None:
|
||||
# Note: xp.linalg.norm() doesn't handle 0-D arrays
|
||||
_x = x.ravel()
|
||||
_axis = 0
|
||||
elif isinstance(axis, tuple):
|
||||
# Note: The axis argument supports any number of axes, whereas
|
||||
# xp.linalg.norm() only supports a single axis for vector norm.
|
||||
normalized_axis = cast(
|
||||
"tuple[int, ...]",
|
||||
normalize_axis_tuple(axis, x.ndim), # pyright: ignore[reportCallIssue]
|
||||
)
|
||||
rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
|
||||
newshape = axis + rest
|
||||
_x = xp.transpose(x, newshape).reshape(
|
||||
(math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest]))
|
||||
_axis = 0
|
||||
else:
|
||||
_x = x
|
||||
_axis = axis
|
||||
|
||||
res = xp.linalg.norm(_x, axis=_axis, ord=ord)
|
||||
|
||||
if keepdims:
|
||||
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
|
||||
# above to avoid matrix norm logic.
|
||||
shape = list(x.shape)
|
||||
axes = cast(
|
||||
"tuple[int, ...]",
|
||||
normalize_axis_tuple( # pyright: ignore[reportCallIssue]
|
||||
range(x.ndim) if axis is None else axis,
|
||||
x.ndim,
|
||||
),
|
||||
)
|
||||
for i in axes:
|
||||
shape[i] = 1
|
||||
res = xp.reshape(res, tuple(shape))
|
||||
|
||||
return res
|
||||
|
||||
# xp.diagonal and xp.trace operate on the first two axes whereas these
|
||||
# operates on the last two
|
||||
|
||||
def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array:
|
||||
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
|
||||
|
||||
def trace(
|
||||
x: Array,
|
||||
/,
|
||||
xp: Namespace,
|
||||
*,
|
||||
offset: int = 0,
|
||||
dtype: DType | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
return xp.asarray(
|
||||
xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)
|
||||
)
|
||||
|
||||
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
|
||||
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
|
||||
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
|
||||
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
|
||||
'trace']
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
@@ -0,0 +1,189 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from types import ModuleType as Namespace
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Literal,
|
||||
Protocol,
|
||||
TypeAlias,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
final,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import Incomplete
|
||||
|
||||
SupportsBufferProtocol: TypeAlias = Incomplete
|
||||
Array: TypeAlias = Incomplete
|
||||
Device: TypeAlias = Incomplete
|
||||
DType: TypeAlias = Incomplete
|
||||
else:
|
||||
SupportsBufferProtocol = object
|
||||
Array = object
|
||||
Device = object
|
||||
DType = object
|
||||
|
||||
|
||||
_T_co = TypeVar("_T_co", covariant=True)
|
||||
|
||||
|
||||
# These "Just" types are equivalent to the `Just` type from the `optype` library,
|
||||
# apart from them not being `@runtime_checkable`.
|
||||
# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
|
||||
# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
|
||||
@final
|
||||
class JustInt(Protocol): # type: ignore[misc]
|
||||
@property # type: ignore[override]
|
||||
def __class__(self, /) -> type[int]: ...
|
||||
@__class__.setter
|
||||
def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
|
||||
|
||||
|
||||
@final
|
||||
class JustFloat(Protocol): # type: ignore[misc]
|
||||
@property # type: ignore[override]
|
||||
def __class__(self, /) -> type[float]: ...
|
||||
@__class__.setter
|
||||
def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
|
||||
|
||||
|
||||
@final
|
||||
class JustComplex(Protocol): # type: ignore[misc]
|
||||
@property # type: ignore[override]
|
||||
def __class__(self, /) -> type[complex]: ...
|
||||
@__class__.setter
|
||||
def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
|
||||
|
||||
|
||||
class NestedSequence(Protocol[_T_co]):
|
||||
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
|
||||
def __len__(self, /) -> int: ...
|
||||
|
||||
|
||||
class SupportsArrayNamespace(Protocol[_T_co]):
|
||||
def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
|
||||
|
||||
|
||||
class HasShape(Protocol[_T_co]):
|
||||
@property
|
||||
def shape(self, /) -> _T_co: ...
|
||||
|
||||
|
||||
# Return type of `__array_namespace_info__.default_dtypes`
|
||||
Capabilities = TypedDict(
|
||||
"Capabilities",
|
||||
{
|
||||
"boolean indexing": bool,
|
||||
"data-dependent shapes": bool,
|
||||
"max dimensions": int,
|
||||
},
|
||||
)
|
||||
|
||||
# Return type of `__array_namespace_info__.default_dtypes`
|
||||
DefaultDTypes = TypedDict(
|
||||
"DefaultDTypes",
|
||||
{
|
||||
"real floating": DType,
|
||||
"complex floating": DType,
|
||||
"integral": DType,
|
||||
"indexing": DType,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
_DTypeKind: TypeAlias = Literal[
|
||||
"bool",
|
||||
"signed integer",
|
||||
"unsigned integer",
|
||||
"integral",
|
||||
"real floating",
|
||||
"complex floating",
|
||||
"numeric",
|
||||
]
|
||||
# Type of the `kind` parameter in `__array_namespace_info__.dtypes`
|
||||
DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="bool")`
|
||||
class DTypesBool(TypedDict):
|
||||
bool: DType
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="signed integer")`
|
||||
class DTypesSigned(TypedDict):
|
||||
int8: DType
|
||||
int16: DType
|
||||
int32: DType
|
||||
int64: DType
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="unsigned integer")`
|
||||
class DTypesUnsigned(TypedDict):
|
||||
uint8: DType
|
||||
uint16: DType
|
||||
uint32: DType
|
||||
uint64: DType
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="integral")`
|
||||
class DTypesIntegral(DTypesSigned, DTypesUnsigned):
|
||||
pass
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="real floating")`
|
||||
class DTypesReal(TypedDict):
|
||||
float32: DType
|
||||
float64: DType
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="complex floating")`
|
||||
class DTypesComplex(TypedDict):
|
||||
complex64: DType
|
||||
complex128: DType
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind="numeric")`
|
||||
class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex):
|
||||
pass
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind=None)` (default)
|
||||
class DTypesAll(DTypesBool, DTypesNumeric):
|
||||
pass
|
||||
|
||||
|
||||
# `__array_namespace_info__.dtypes(kind=?)` (fallback)
|
||||
DTypesAny: TypeAlias = Mapping[str, DType]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Array",
|
||||
"Capabilities",
|
||||
"DType",
|
||||
"DTypeKind",
|
||||
"DTypesAny",
|
||||
"DTypesAll",
|
||||
"DTypesBool",
|
||||
"DTypesNumeric",
|
||||
"DTypesIntegral",
|
||||
"DTypesSigned",
|
||||
"DTypesUnsigned",
|
||||
"DTypesReal",
|
||||
"DTypesComplex",
|
||||
"DefaultDTypes",
|
||||
"Device",
|
||||
"HasShape",
|
||||
"Namespace",
|
||||
"JustInt",
|
||||
"JustFloat",
|
||||
"JustComplex",
|
||||
"NestedSequence",
|
||||
"SupportsArrayNamespace",
|
||||
"SupportsBufferProtocol",
|
||||
]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
@@ -0,0 +1,24 @@
|
||||
from typing import Final
|
||||
from cupy import * # noqa: F403
|
||||
|
||||
# from cupy import * doesn't overwrite these builtin names
|
||||
from cupy import abs, max, min, round # noqa: F401
|
||||
|
||||
# These imports may overwrite names from the import * above.
|
||||
from ._aliases import * # noqa: F403
|
||||
from ._info import __array_namespace_info__ # noqa: F401
|
||||
|
||||
# See the comment in the numpy __init__.py
|
||||
__import__(__package__ + '.linalg')
|
||||
__import__(__package__ + '.fft')
|
||||
|
||||
__array_api_version__: Final = '2024.12'
|
||||
|
||||
__all__ = sorted(
|
||||
{name for name in globals() if not name.startswith("__")}
|
||||
- {"Final", "_aliases", "_info", "_typing"}
|
||||
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
|
||||
)
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,168 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from builtins import bool as py_bool
|
||||
|
||||
import cupy as cp
|
||||
|
||||
from ..common import _aliases, _helpers
|
||||
from ..common._typing import NestedSequence, SupportsBufferProtocol
|
||||
from .._internal import get_xp
|
||||
from ._typing import Array, Device, DType
|
||||
|
||||
bool = cp.bool_
|
||||
|
||||
# Basic renames
|
||||
acos = cp.arccos
|
||||
acosh = cp.arccosh
|
||||
asin = cp.arcsin
|
||||
asinh = cp.arcsinh
|
||||
atan = cp.arctan
|
||||
atan2 = cp.arctan2
|
||||
atanh = cp.arctanh
|
||||
bitwise_left_shift = cp.left_shift
|
||||
bitwise_invert = cp.invert
|
||||
bitwise_right_shift = cp.right_shift
|
||||
concat = cp.concatenate
|
||||
pow = cp.power
|
||||
|
||||
arange = get_xp(cp)(_aliases.arange)
|
||||
empty = get_xp(cp)(_aliases.empty)
|
||||
empty_like = get_xp(cp)(_aliases.empty_like)
|
||||
eye = get_xp(cp)(_aliases.eye)
|
||||
full = get_xp(cp)(_aliases.full)
|
||||
full_like = get_xp(cp)(_aliases.full_like)
|
||||
linspace = get_xp(cp)(_aliases.linspace)
|
||||
ones = get_xp(cp)(_aliases.ones)
|
||||
ones_like = get_xp(cp)(_aliases.ones_like)
|
||||
zeros = get_xp(cp)(_aliases.zeros)
|
||||
zeros_like = get_xp(cp)(_aliases.zeros_like)
|
||||
UniqueAllResult = get_xp(cp)(_aliases.UniqueAllResult)
|
||||
UniqueCountsResult = get_xp(cp)(_aliases.UniqueCountsResult)
|
||||
UniqueInverseResult = get_xp(cp)(_aliases.UniqueInverseResult)
|
||||
unique_all = get_xp(cp)(_aliases.unique_all)
|
||||
unique_counts = get_xp(cp)(_aliases.unique_counts)
|
||||
unique_inverse = get_xp(cp)(_aliases.unique_inverse)
|
||||
unique_values = get_xp(cp)(_aliases.unique_values)
|
||||
std = get_xp(cp)(_aliases.std)
|
||||
var = get_xp(cp)(_aliases.var)
|
||||
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
|
||||
cumulative_prod = get_xp(cp)(_aliases.cumulative_prod)
|
||||
clip = get_xp(cp)(_aliases.clip)
|
||||
permute_dims = get_xp(cp)(_aliases.permute_dims)
|
||||
reshape = get_xp(cp)(_aliases.reshape)
|
||||
argsort = get_xp(cp)(_aliases.argsort)
|
||||
sort = get_xp(cp)(_aliases.sort)
|
||||
nonzero = get_xp(cp)(_aliases.nonzero)
|
||||
matmul = get_xp(cp)(_aliases.matmul)
|
||||
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
|
||||
tensordot = get_xp(cp)(_aliases.tensordot)
|
||||
sign = get_xp(cp)(_aliases.sign)
|
||||
finfo = get_xp(cp)(_aliases.finfo)
|
||||
iinfo = get_xp(cp)(_aliases.iinfo)
|
||||
|
||||
|
||||
# asarray also adds the copy keyword, which is not present in numpy 1.0.
|
||||
def asarray(
|
||||
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
|
||||
/,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
copy: py_bool | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility wrapper for asarray().
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
with cp.cuda.Device(device):
|
||||
if copy is None:
|
||||
return cp.asarray(obj, dtype=dtype, **kwargs)
|
||||
else:
|
||||
res = cp.array(obj, dtype=dtype, copy=copy, **kwargs)
|
||||
if not copy and res is not obj:
|
||||
raise ValueError("Unable to avoid copy while creating an array as requested")
|
||||
return res
|
||||
|
||||
|
||||
def astype(
|
||||
x: Array,
|
||||
dtype: DType,
|
||||
/,
|
||||
*,
|
||||
copy: py_bool = True,
|
||||
device: Device | None = None,
|
||||
) -> Array:
|
||||
if device is None:
|
||||
return x.astype(dtype=dtype, copy=copy)
|
||||
out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device)
|
||||
return out.copy() if copy and out is x else out
|
||||
|
||||
|
||||
# cupy.count_nonzero does not have keepdims
|
||||
def count_nonzero(
|
||||
x: Array,
|
||||
axis: int | tuple[int, ...] | None = None,
|
||||
keepdims: py_bool = False,
|
||||
) -> Array:
|
||||
result = cp.count_nonzero(x, axis)
|
||||
if keepdims:
|
||||
if axis is None:
|
||||
return cp.reshape(result, [1]*x.ndim)
|
||||
return cp.expand_dims(result, axis)
|
||||
return result
|
||||
|
||||
# ceil, floor, and trunc return integers for integer inputs
|
||||
|
||||
def ceil(x: Array, /) -> Array:
|
||||
if cp.issubdtype(x.dtype, cp.integer):
|
||||
return x.copy()
|
||||
return cp.ceil(x)
|
||||
|
||||
|
||||
def floor(x: Array, /) -> Array:
|
||||
if cp.issubdtype(x.dtype, cp.integer):
|
||||
return x.copy()
|
||||
return cp.floor(x)
|
||||
|
||||
|
||||
def trunc(x: Array, /) -> Array:
|
||||
if cp.issubdtype(x.dtype, cp.integer):
|
||||
return x.copy()
|
||||
return cp.trunc(x)
|
||||
|
||||
|
||||
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
|
||||
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
|
||||
return cp.take_along_axis(x, indices, axis=axis)
|
||||
|
||||
|
||||
# These functions are completely new here. If the library already has them
|
||||
# (i.e., numpy 2.0), use the library version instead of our wrapper.
|
||||
if hasattr(cp, 'vecdot'):
|
||||
vecdot = cp.vecdot
|
||||
else:
|
||||
vecdot = get_xp(cp)(_aliases.vecdot)
|
||||
|
||||
if hasattr(cp, 'isdtype'):
|
||||
isdtype = cp.isdtype
|
||||
else:
|
||||
isdtype = get_xp(cp)(_aliases.isdtype)
|
||||
|
||||
if hasattr(cp, 'unstack'):
|
||||
unstack = cp.unstack
|
||||
else:
|
||||
unstack = get_xp(cp)(_aliases.unstack)
|
||||
|
||||
__all__ = _aliases.__all__ + ['asarray', 'astype',
|
||||
'acos', 'acosh', 'asin', 'asinh', 'atan',
|
||||
'atan2', 'atanh', 'bitwise_left_shift',
|
||||
'bitwise_invert', 'bitwise_right_shift',
|
||||
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
|
||||
'ceil', 'floor', 'trunc', 'take_along_axis']
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
@@ -0,0 +1,336 @@
|
||||
"""
|
||||
Array API Inspection namespace
|
||||
|
||||
This is the namespace for inspection functions as defined by the array API
|
||||
standard. See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html for
|
||||
more details.
|
||||
|
||||
"""
|
||||
from cupy import (
|
||||
dtype,
|
||||
cuda,
|
||||
bool_ as bool,
|
||||
intp,
|
||||
int8,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
uint8,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
float32,
|
||||
float64,
|
||||
complex64,
|
||||
complex128,
|
||||
)
|
||||
|
||||
|
||||
class __array_namespace_info__:
|
||||
"""
|
||||
Get the array API inspection namespace for CuPy.
|
||||
|
||||
The array API inspection namespace defines the following functions:
|
||||
|
||||
- capabilities()
|
||||
- default_device()
|
||||
- default_dtypes()
|
||||
- dtypes()
|
||||
- devices()
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html
|
||||
for more details.
|
||||
|
||||
Returns
|
||||
-------
|
||||
info : ModuleType
|
||||
The array API inspection namespace for CuPy.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': cupy.float64,
|
||||
'complex floating': cupy.complex128,
|
||||
'integral': cupy.int64,
|
||||
'indexing': cupy.int64}
|
||||
|
||||
"""
|
||||
|
||||
__module__ = 'cupy'
|
||||
|
||||
def capabilities(self):
|
||||
"""
|
||||
Return a dictionary of array API library capabilities.
|
||||
|
||||
The resulting dictionary has the following keys:
|
||||
|
||||
- **"boolean indexing"**: boolean indicating whether an array library
|
||||
supports boolean indexing. Always ``True`` for CuPy.
|
||||
|
||||
- **"data-dependent shapes"**: boolean indicating whether an array
|
||||
library supports data-dependent output shapes. Always ``True`` for
|
||||
CuPy.
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
|
||||
for more details.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
capabilities : dict
|
||||
A dictionary of array API library capabilities.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.capabilities()
|
||||
{'boolean indexing': True,
|
||||
'data-dependent shapes': True,
|
||||
'max dimensions': 64}
|
||||
|
||||
"""
|
||||
return {
|
||||
"boolean indexing": True,
|
||||
"data-dependent shapes": True,
|
||||
"max dimensions": 64,
|
||||
}
|
||||
|
||||
def default_device(self):
|
||||
"""
|
||||
The default device used for new CuPy arrays.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
device : Device
|
||||
The default device used for new CuPy arrays.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_device()
|
||||
Device(0)
|
||||
|
||||
Notes
|
||||
-----
|
||||
This method returns the static default device when CuPy is initialized.
|
||||
However, the *current* device used by creation functions (``empty`` etc.)
|
||||
can be changed globally or with a context manager.
|
||||
|
||||
See Also
|
||||
--------
|
||||
https://github.com/data-apis/array-api/issues/835
|
||||
"""
|
||||
return cuda.Device(0)
|
||||
|
||||
def default_dtypes(self, *, device=None):
|
||||
"""
|
||||
The default data types used for new CuPy arrays.
|
||||
|
||||
For CuPy, this always returns the following dictionary:
|
||||
|
||||
- **"real floating"**: ``cupy.float64``
|
||||
- **"complex floating"**: ``cupy.complex128``
|
||||
- **"integral"**: ``cupy.intp``
|
||||
- **"indexing"**: ``cupy.intp``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : str, optional
|
||||
The device to get the default data types for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary describing the default data types used for new CuPy
|
||||
arrays.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': cupy.float64,
|
||||
'complex floating': cupy.complex128,
|
||||
'integral': cupy.int64,
|
||||
'indexing': cupy.int64}
|
||||
|
||||
"""
|
||||
# TODO: Does this depend on device?
|
||||
return {
|
||||
"real floating": dtype(float64),
|
||||
"complex floating": dtype(complex128),
|
||||
"integral": dtype(intp),
|
||||
"indexing": dtype(intp),
|
||||
}
|
||||
|
||||
def dtypes(self, *, device=None, kind=None):
|
||||
"""
|
||||
The array API data types supported by CuPy.
|
||||
|
||||
Note that this function only returns data types that are defined by
|
||||
the array API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : str, optional
|
||||
The device to get the data types for.
|
||||
kind : str or tuple of str, optional
|
||||
The kind of data types to return. If ``None``, all data types are
|
||||
returned. If a string, only data types of that kind are returned.
|
||||
If a tuple, a dictionary containing the union of the given kinds
|
||||
is returned. The following kinds are supported:
|
||||
|
||||
- ``'bool'``: boolean data types (i.e., ``bool``).
|
||||
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
|
||||
``int16``, ``int32``, ``int64``).
|
||||
- ``'unsigned integer'``: unsigned integer data types (i.e.,
|
||||
``uint8``, ``uint16``, ``uint32``, ``uint64``).
|
||||
- ``'integral'``: integer data types. Shorthand for ``('signed
|
||||
integer', 'unsigned integer')``.
|
||||
- ``'real floating'``: real-valued floating-point data types
|
||||
(i.e., ``float32``, ``float64``).
|
||||
- ``'complex floating'``: complex floating-point data types (i.e.,
|
||||
``complex64``, ``complex128``).
|
||||
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
|
||||
'real floating', 'complex floating')``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary mapping the names of data types to the corresponding
|
||||
CuPy data types.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.dtypes(kind='signed integer')
|
||||
{'int8': cupy.int8,
|
||||
'int16': cupy.int16,
|
||||
'int32': cupy.int32,
|
||||
'int64': cupy.int64}
|
||||
|
||||
"""
|
||||
# TODO: Does this depend on device?
|
||||
if kind is None:
|
||||
return {
|
||||
"bool": dtype(bool),
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if kind == "bool":
|
||||
return {"bool": bool}
|
||||
if kind == "signed integer":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
}
|
||||
if kind == "unsigned integer":
|
||||
return {
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
}
|
||||
if kind == "integral":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
}
|
||||
if kind == "real floating":
|
||||
return {
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
}
|
||||
if kind == "complex floating":
|
||||
return {
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if kind == "numeric":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if isinstance(kind, tuple):
|
||||
res = {}
|
||||
for k in kind:
|
||||
res.update(self.dtypes(kind=k))
|
||||
return res
|
||||
raise ValueError(f"unsupported kind: {kind!r}")
|
||||
|
||||
def devices(self):
|
||||
"""
|
||||
The devices supported by CuPy.
|
||||
|
||||
Returns
|
||||
-------
|
||||
devices : list[Device]
|
||||
The devices supported by CuPy.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes
|
||||
|
||||
"""
|
||||
return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())]
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ["Array", "DType", "Device"]
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import cupy as cp
|
||||
from cupy import ndarray as Array
|
||||
from cupy.cuda.device import Device
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# NumPy 1.x on Python 3.10 fails to parse np.dtype[]
|
||||
DType = cp.dtype[
|
||||
cp.intp
|
||||
| cp.int8
|
||||
| cp.int16
|
||||
| cp.int32
|
||||
| cp.int64
|
||||
| cp.uint8
|
||||
| cp.uint16
|
||||
| cp.uint32
|
||||
| cp.uint64
|
||||
| cp.float32
|
||||
| cp.float64
|
||||
| cp.complex64
|
||||
| cp.complex128
|
||||
| cp.bool_
|
||||
]
|
||||
else:
|
||||
DType = cp.dtype
|
||||
@@ -0,0 +1,36 @@
|
||||
from cupy.fft import * # noqa: F403
|
||||
|
||||
# cupy.fft doesn't have __all__. If it is added, replace this with
|
||||
#
|
||||
# from cupy.fft import __all__ as linalg_all
|
||||
_n: dict[str, object] = {}
|
||||
exec("from cupy.fft import *", _n)
|
||||
del _n["__builtins__"]
|
||||
fft_all = list(_n)
|
||||
del _n
|
||||
|
||||
from ..common import _fft
|
||||
from .._internal import get_xp
|
||||
|
||||
import cupy as cp
|
||||
|
||||
fft = get_xp(cp)(_fft.fft)
|
||||
ifft = get_xp(cp)(_fft.ifft)
|
||||
fftn = get_xp(cp)(_fft.fftn)
|
||||
ifftn = get_xp(cp)(_fft.ifftn)
|
||||
rfft = get_xp(cp)(_fft.rfft)
|
||||
irfft = get_xp(cp)(_fft.irfft)
|
||||
rfftn = get_xp(cp)(_fft.rfftn)
|
||||
irfftn = get_xp(cp)(_fft.irfftn)
|
||||
hfft = get_xp(cp)(_fft.hfft)
|
||||
ihfft = get_xp(cp)(_fft.ihfft)
|
||||
fftfreq = get_xp(cp)(_fft.fftfreq)
|
||||
rfftfreq = get_xp(cp)(_fft.rfftfreq)
|
||||
fftshift = get_xp(cp)(_fft.fftshift)
|
||||
ifftshift = get_xp(cp)(_fft.ifftshift)
|
||||
|
||||
__all__ = fft_all + _fft.__all__
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
from cupy.linalg import * # noqa: F403
|
||||
# cupy.linalg doesn't have __all__. If it is added, replace this with
|
||||
#
|
||||
# from cupy.linalg import __all__ as linalg_all
|
||||
_n: dict[str, object] = {}
|
||||
exec('from cupy.linalg import *', _n)
|
||||
del _n['__builtins__']
|
||||
linalg_all = list(_n)
|
||||
del _n
|
||||
|
||||
from ..common import _linalg
|
||||
from .._internal import get_xp
|
||||
|
||||
import cupy as cp
|
||||
|
||||
# These functions are in both the main and linalg namespaces
|
||||
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
|
||||
|
||||
cross = get_xp(cp)(_linalg.cross)
|
||||
outer = get_xp(cp)(_linalg.outer)
|
||||
EighResult = _linalg.EighResult
|
||||
QRResult = _linalg.QRResult
|
||||
SlogdetResult = _linalg.SlogdetResult
|
||||
SVDResult = _linalg.SVDResult
|
||||
eigh = get_xp(cp)(_linalg.eigh)
|
||||
qr = get_xp(cp)(_linalg.qr)
|
||||
slogdet = get_xp(cp)(_linalg.slogdet)
|
||||
svd = get_xp(cp)(_linalg.svd)
|
||||
cholesky = get_xp(cp)(_linalg.cholesky)
|
||||
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
|
||||
pinv = get_xp(cp)(_linalg.pinv)
|
||||
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
|
||||
svdvals = get_xp(cp)(_linalg.svdvals)
|
||||
diagonal = get_xp(cp)(_linalg.diagonal)
|
||||
trace = get_xp(cp)(_linalg.trace)
|
||||
|
||||
# These functions are completely new here. If the library already has them
|
||||
# (i.e., numpy 2.0), use the library version instead of our wrapper.
|
||||
if hasattr(cp.linalg, 'vector_norm'):
|
||||
vector_norm = cp.linalg.vector_norm
|
||||
else:
|
||||
vector_norm = get_xp(cp)(_linalg.vector_norm)
|
||||
|
||||
__all__ = linalg_all + _linalg.__all__
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
Binary file not shown.
@@ -0,0 +1,26 @@
|
||||
from typing import Final
|
||||
|
||||
from ..._internal import clone_module
|
||||
|
||||
__all__ = clone_module("dask.array", globals())
|
||||
|
||||
# These imports may overwrite names from the import * above.
|
||||
from . import _aliases
|
||||
from ._aliases import * # type: ignore[assignment] # noqa: F403
|
||||
from ._info import __array_namespace_info__ # noqa: F401
|
||||
|
||||
__array_api_version__: Final = "2024.12"
|
||||
del Final
|
||||
|
||||
# See the comment in the numpy __init__.py
|
||||
__import__(__package__ + '.linalg')
|
||||
__import__(__package__ + '.fft')
|
||||
|
||||
__all__ = sorted(
|
||||
set(__all__)
|
||||
| set(_aliases.__all__)
|
||||
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
|
||||
)
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,369 @@
|
||||
# pyright: reportPrivateUsage=false
|
||||
# pyright: reportUnknownArgumentType=false
|
||||
# pyright: reportUnknownMemberType=false
|
||||
# pyright: reportUnknownVariableType=false
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from builtins import bool as py_bool
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import dask.array as da
|
||||
import numpy as np
|
||||
from numpy import bool_ as bool
|
||||
from numpy import (
|
||||
can_cast,
|
||||
complex64,
|
||||
complex128,
|
||||
float32,
|
||||
float64,
|
||||
int8,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
result_type,
|
||||
uint8,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
)
|
||||
|
||||
from ..._internal import get_xp
|
||||
from ...common import _aliases, _helpers, array_namespace
|
||||
from ...common._typing import (
|
||||
Array,
|
||||
Device,
|
||||
DType,
|
||||
NestedSequence,
|
||||
SupportsBufferProtocol,
|
||||
)
|
||||
|
||||
isdtype = get_xp(np)(_aliases.isdtype)
|
||||
unstack = get_xp(da)(_aliases.unstack)
|
||||
|
||||
|
||||
# da.astype doesn't respect copy=True
|
||||
def astype(
|
||||
x: Array,
|
||||
dtype: DType,
|
||||
/,
|
||||
*,
|
||||
copy: py_bool = True,
|
||||
device: Device | None = None,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility wrapper for astype().
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
# TODO: respect device keyword?
|
||||
_helpers._check_device(da, device)
|
||||
|
||||
if not copy and dtype == x.dtype:
|
||||
return x
|
||||
x = x.astype(dtype)
|
||||
return x.copy() if copy else x
|
||||
|
||||
|
||||
# Common aliases
|
||||
|
||||
|
||||
# This arange func is modified from the common one to
|
||||
# not pass stop/step as keyword arguments, which will cause
|
||||
# an error with dask
|
||||
def arange(
|
||||
start: float,
|
||||
/,
|
||||
stop: float | None = None,
|
||||
step: float = 1,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility wrapper for arange().
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
# TODO: respect device keyword?
|
||||
_helpers._check_device(da, device)
|
||||
|
||||
args: list[Any] = [start]
|
||||
if stop is not None:
|
||||
args.append(stop)
|
||||
else:
|
||||
# stop is None, so start is actually stop
|
||||
# prepend the default value for start which is 0
|
||||
args.insert(0, 0)
|
||||
args.append(step)
|
||||
|
||||
return da.arange(*args, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
eye = get_xp(da)(_aliases.eye)
|
||||
linspace = get_xp(da)(_aliases.linspace)
|
||||
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
|
||||
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
|
||||
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
|
||||
unique_all = get_xp(da)(_aliases.unique_all)
|
||||
unique_counts = get_xp(da)(_aliases.unique_counts)
|
||||
unique_inverse = get_xp(da)(_aliases.unique_inverse)
|
||||
unique_values = get_xp(da)(_aliases.unique_values)
|
||||
permute_dims = get_xp(da)(_aliases.permute_dims)
|
||||
std = get_xp(da)(_aliases.std)
|
||||
var = get_xp(da)(_aliases.var)
|
||||
cumulative_sum = get_xp(da)(_aliases.cumulative_sum)
|
||||
cumulative_prod = get_xp(da)(_aliases.cumulative_prod)
|
||||
empty = get_xp(da)(_aliases.empty)
|
||||
empty_like = get_xp(da)(_aliases.empty_like)
|
||||
full = get_xp(da)(_aliases.full)
|
||||
full_like = get_xp(da)(_aliases.full_like)
|
||||
ones = get_xp(da)(_aliases.ones)
|
||||
ones_like = get_xp(da)(_aliases.ones_like)
|
||||
zeros = get_xp(da)(_aliases.zeros)
|
||||
zeros_like = get_xp(da)(_aliases.zeros_like)
|
||||
reshape = get_xp(da)(_aliases.reshape)
|
||||
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
|
||||
vecdot = get_xp(da)(_aliases.vecdot)
|
||||
nonzero = get_xp(da)(_aliases.nonzero)
|
||||
matmul = get_xp(np)(_aliases.matmul)
|
||||
tensordot = get_xp(np)(_aliases.tensordot)
|
||||
sign = get_xp(np)(_aliases.sign)
|
||||
finfo = get_xp(np)(_aliases.finfo)
|
||||
iinfo = get_xp(np)(_aliases.iinfo)
|
||||
|
||||
|
||||
# asarray also adds the copy keyword, which is not present in numpy 1.0.
|
||||
def asarray(
|
||||
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
|
||||
/,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
copy: py_bool | None = None,
|
||||
**kwargs: object,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility wrapper for asarray().
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
# TODO: respect device keyword?
|
||||
_helpers._check_device(da, device)
|
||||
|
||||
if isinstance(obj, da.Array):
|
||||
if dtype is not None and dtype != obj.dtype:
|
||||
if copy is False:
|
||||
raise ValueError("Unable to avoid copy when changing dtype")
|
||||
obj = obj.astype(dtype)
|
||||
return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
if copy is False:
|
||||
raise ValueError(
|
||||
"Unable to avoid copy when converting a non-dask object to dask"
|
||||
)
|
||||
|
||||
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
|
||||
# see https://github.com/dask/dask/pull/11524/
|
||||
obj = np.array(obj, dtype=dtype, copy=True)
|
||||
return da.from_array(obj)
|
||||
|
||||
|
||||
# Element wise aliases
|
||||
from dask.array import arccos as acos
|
||||
from dask.array import arccosh as acosh
|
||||
from dask.array import arcsin as asin
|
||||
from dask.array import arcsinh as asinh
|
||||
from dask.array import arctan as atan
|
||||
from dask.array import arctan2 as atan2
|
||||
from dask.array import arctanh as atanh
|
||||
|
||||
# Other
|
||||
from dask.array import concatenate as concat
|
||||
from dask.array import invert as bitwise_invert
|
||||
from dask.array import left_shift as bitwise_left_shift
|
||||
from dask.array import power as pow
|
||||
from dask.array import right_shift as bitwise_right_shift
|
||||
|
||||
|
||||
# dask.array.clip does not work unless all three arguments are provided.
|
||||
# Furthermore, the masking workaround in common._aliases.clip cannot work with
|
||||
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
|
||||
# now).
|
||||
def clip(
|
||||
x: Array,
|
||||
/,
|
||||
min: float | Array | None = None,
|
||||
max: float | Array | None = None,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility wrapper for clip().
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
|
||||
def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]:
|
||||
return a is None or isinstance(a, (int, float))
|
||||
|
||||
min_shape = () if _isscalar(min) else min.shape
|
||||
max_shape = () if _isscalar(max) else max.shape
|
||||
|
||||
# TODO: This won't handle dask unknown shapes
|
||||
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)
|
||||
|
||||
if min is not None:
|
||||
min = da.broadcast_to(da.asarray(min), result_shape)
|
||||
if max is not None:
|
||||
max = da.broadcast_to(da.asarray(max), result_shape)
|
||||
|
||||
if min is None and max is None:
|
||||
return da.positive(x)
|
||||
|
||||
if min is None:
|
||||
return astype(da.minimum(x, max), x.dtype)
|
||||
if max is None:
|
||||
return astype(da.maximum(x, min), x.dtype)
|
||||
|
||||
return astype(da.minimum(da.maximum(x, min), max), x.dtype)
|
||||
|
||||
|
||||
def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]:
|
||||
"""
|
||||
Make sure that Array is not broken into multiple chunks along axis.
|
||||
|
||||
Returns
|
||||
-------
|
||||
x : Array
|
||||
The input Array with a single chunk along axis.
|
||||
restore : Callable[Array, Array]
|
||||
function to apply to the output to rechunk it back into reasonable chunks
|
||||
"""
|
||||
if axis < 0:
|
||||
axis += x.ndim
|
||||
if x.numblocks[axis] < 2:
|
||||
return x, lambda x: x
|
||||
|
||||
# Break chunks on other axes in an attempt to keep chunk size low
|
||||
x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})
|
||||
|
||||
# Rather than reconstructing the original chunks, which can be a
|
||||
# very expensive affair, just break down oversized chunks without
|
||||
# incurring in any transfers over the network.
|
||||
# This has the downside of a risk of overchunking if the array is
|
||||
# then used in operations against other arrays that match the
|
||||
# original chunking pattern.
|
||||
return x, lambda x: x.rechunk()
|
||||
|
||||
|
||||
def sort(
|
||||
x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: int = -1,
|
||||
descending: py_bool = False,
|
||||
stable: py_bool = True,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility layer around the lack of sort() in Dask.
|
||||
|
||||
Warnings
|
||||
--------
|
||||
This function temporarily rechunks the array along `axis` to a single chunk.
|
||||
This can be extremely inefficient and can lead to out-of-memory errors.
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
x, restore = _ensure_single_chunk(x, axis)
|
||||
|
||||
meta_xp = array_namespace(x._meta)
|
||||
x = da.map_blocks(
|
||||
meta_xp.sort,
|
||||
x,
|
||||
axis=axis,
|
||||
meta=x._meta,
|
||||
dtype=x.dtype,
|
||||
descending=descending,
|
||||
stable=stable,
|
||||
)
|
||||
|
||||
return restore(x)
|
||||
|
||||
|
||||
def argsort(
|
||||
x: Array,
|
||||
/,
|
||||
*,
|
||||
axis: int = -1,
|
||||
descending: py_bool = False,
|
||||
stable: py_bool = True,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility layer around the lack of argsort() in Dask.
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
|
||||
Warnings
|
||||
--------
|
||||
This function temporarily rechunks the array along `axis` into a single chunk.
|
||||
This can be extremely inefficient and can lead to out-of-memory errors.
|
||||
"""
|
||||
x, restore = _ensure_single_chunk(x, axis)
|
||||
|
||||
meta_xp = array_namespace(x._meta)
|
||||
dtype = meta_xp.argsort(x._meta).dtype
|
||||
meta = meta_xp.astype(x._meta, dtype)
|
||||
x = da.map_blocks(
|
||||
meta_xp.argsort,
|
||||
x,
|
||||
axis=axis,
|
||||
meta=meta,
|
||||
dtype=dtype,
|
||||
descending=descending,
|
||||
stable=stable,
|
||||
)
|
||||
|
||||
return restore(x)
|
||||
|
||||
|
||||
# dask.array.count_nonzero does not have keepdims
|
||||
def count_nonzero(
|
||||
x: Array,
|
||||
axis: int | None = None,
|
||||
keepdims: py_bool = False,
|
||||
) -> Array:
|
||||
result = da.count_nonzero(x, axis)
|
||||
if keepdims:
|
||||
if axis is None:
|
||||
return da.reshape(result, [1] * x.ndim)
|
||||
return da.expand_dims(result, axis)
|
||||
return result
|
||||
|
||||
|
||||
__all__ = [
|
||||
"count_nonzero",
|
||||
"bool",
|
||||
"int8", "int16", "int32", "int64",
|
||||
"uint8", "uint16", "uint32", "uint64",
|
||||
"float32", "float64",
|
||||
"complex64", "complex128",
|
||||
"asarray", "astype", "can_cast", "result_type",
|
||||
"pow",
|
||||
"concat",
|
||||
"acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh",
|
||||
"bitwise_left_shift", "bitwise_right_shift", "bitwise_invert",
|
||||
] # fmt: skip
|
||||
__all__ += _aliases.__all__
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
@@ -0,0 +1,407 @@
|
||||
"""
|
||||
Array API Inspection namespace
|
||||
|
||||
This is the namespace for inspection functions as defined by the array API
|
||||
standard. See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html for
|
||||
more details.
|
||||
|
||||
"""
|
||||
|
||||
# pyright: reportPrivateUsage=false
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, TypeAlias, overload
|
||||
|
||||
import dask.array as da
|
||||
from numpy import bool_ as bool
|
||||
from numpy import (
|
||||
complex64,
|
||||
complex128,
|
||||
dtype,
|
||||
float32,
|
||||
float64,
|
||||
int8,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
intp,
|
||||
uint8,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
)
|
||||
|
||||
from ...common._helpers import _DASK_DEVICE, _check_device, _dask_device
|
||||
from ...common._typing import (
|
||||
Capabilities,
|
||||
DefaultDTypes,
|
||||
DType,
|
||||
DTypeKind,
|
||||
DTypesAll,
|
||||
DTypesAny,
|
||||
DTypesBool,
|
||||
DTypesComplex,
|
||||
DTypesIntegral,
|
||||
DTypesNumeric,
|
||||
DTypesReal,
|
||||
DTypesSigned,
|
||||
DTypesUnsigned,
|
||||
)
|
||||
Device: TypeAlias = Literal["cpu"] | _dask_device
|
||||
|
||||
|
||||
class __array_namespace_info__:
|
||||
"""
|
||||
Get the array API inspection namespace for Dask.
|
||||
|
||||
The array API inspection namespace defines the following functions:
|
||||
|
||||
- capabilities()
|
||||
- default_device()
|
||||
- default_dtypes()
|
||||
- dtypes()
|
||||
- devices()
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html
|
||||
for more details.
|
||||
|
||||
Returns
|
||||
-------
|
||||
info : ModuleType
|
||||
The array API inspection namespace for Dask.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': dask.float64,
|
||||
'complex floating': dask.complex128,
|
||||
'integral': dask.int64,
|
||||
'indexing': dask.int64}
|
||||
|
||||
"""
|
||||
|
||||
__module__ = "dask.array"
|
||||
|
||||
def capabilities(self) -> Capabilities:
|
||||
"""
|
||||
Return a dictionary of array API library capabilities.
|
||||
|
||||
The resulting dictionary has the following keys:
|
||||
|
||||
- **"boolean indexing"**: boolean indicating whether an array library
|
||||
supports boolean indexing.
|
||||
|
||||
Dask support boolean indexing as long as both the index
|
||||
and the indexed arrays have known shapes.
|
||||
Note however that the output .shape and .size properties
|
||||
will contain a non-compliant math.nan instead of None.
|
||||
|
||||
- **"data-dependent shapes"**: boolean indicating whether an array
|
||||
library supports data-dependent output shapes.
|
||||
|
||||
Dask implements unique_values et.al.
|
||||
Note however that the output .shape and .size properties
|
||||
will contain a non-compliant math.nan instead of None.
|
||||
|
||||
- **"max dimensions"**: integer indicating the maximum number of
|
||||
dimensions supported by the array library.
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
|
||||
for more details.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
capabilities : dict
|
||||
A dictionary of array API library capabilities.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.capabilities()
|
||||
{'boolean indexing': True,
|
||||
'data-dependent shapes': True,
|
||||
'max dimensions': 64}
|
||||
|
||||
"""
|
||||
return {
|
||||
"boolean indexing": True,
|
||||
"data-dependent shapes": True,
|
||||
"max dimensions": 64,
|
||||
}
|
||||
|
||||
def default_device(self) -> Device:
|
||||
"""
|
||||
The default device used for new Dask arrays.
|
||||
|
||||
For Dask, this always returns ``'cpu'``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
device : Device
|
||||
The default device used for new Dask arrays.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_device()
|
||||
'cpu'
|
||||
|
||||
"""
|
||||
return "cpu"
|
||||
|
||||
def default_dtypes(self, /, *, device: Device | None = None) -> DefaultDTypes:
|
||||
"""
|
||||
The default data types used for new Dask arrays.
|
||||
|
||||
For Dask, this always returns the following dictionary:
|
||||
|
||||
- **"real floating"**: ``numpy.float64``
|
||||
- **"complex floating"**: ``numpy.complex128``
|
||||
- **"integral"**: ``numpy.intp``
|
||||
- **"indexing"**: ``numpy.intp``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : str, optional
|
||||
The device to get the default data types for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary describing the default data types used for new Dask
|
||||
arrays.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': dask.float64,
|
||||
'complex floating': dask.complex128,
|
||||
'integral': dask.int64,
|
||||
'indexing': dask.int64}
|
||||
|
||||
"""
|
||||
_check_device(da, device)
|
||||
return {
|
||||
"real floating": dtype(float64),
|
||||
"complex floating": dtype(complex128),
|
||||
"integral": dtype(intp),
|
||||
"indexing": dtype(intp),
|
||||
}
|
||||
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: Device | None = None, kind: None = None
|
||||
) -> DTypesAll: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: Device | None = None, kind: Literal["bool"]
|
||||
) -> DTypesBool: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: Device | None = None, kind: Literal["signed integer"]
|
||||
) -> DTypesSigned: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: Device | None = None, kind: Literal["unsigned integer"]
|
||||
) -> DTypesUnsigned: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: Device | None = None, kind: Literal["integral"]
|
||||
) -> DTypesIntegral: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: Device | None = None, kind: Literal["real floating"]
|
||||
) -> DTypesReal: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: Device | None = None, kind: Literal["complex floating"]
|
||||
) -> DTypesComplex: ...
|
||||
@overload
|
||||
def dtypes(
|
||||
self, /, *, device: Device | None = None, kind: Literal["numeric"]
|
||||
) -> DTypesNumeric: ...
|
||||
def dtypes(
|
||||
self, /, *, device: Device | None = None, kind: DTypeKind | None = None
|
||||
) -> DTypesAny:
|
||||
"""
|
||||
The array API data types supported by Dask.
|
||||
|
||||
Note that this function only returns data types that are defined by
|
||||
the array API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : str, optional
|
||||
The device to get the data types for.
|
||||
kind : str or tuple of str, optional
|
||||
The kind of data types to return. If ``None``, all data types are
|
||||
returned. If a string, only data types of that kind are returned.
|
||||
If a tuple, a dictionary containing the union of the given kinds
|
||||
is returned. The following kinds are supported:
|
||||
|
||||
- ``'bool'``: boolean data types (i.e., ``bool``).
|
||||
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
|
||||
``int16``, ``int32``, ``int64``).
|
||||
- ``'unsigned integer'``: unsigned integer data types (i.e.,
|
||||
``uint8``, ``uint16``, ``uint32``, ``uint64``).
|
||||
- ``'integral'``: integer data types. Shorthand for ``('signed
|
||||
integer', 'unsigned integer')``.
|
||||
- ``'real floating'``: real-valued floating-point data types
|
||||
(i.e., ``float32``, ``float64``).
|
||||
- ``'complex floating'``: complex floating-point data types (i.e.,
|
||||
``complex64``, ``complex128``).
|
||||
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
|
||||
'real floating', 'complex floating')``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary mapping the names of data types to the corresponding
|
||||
Dask data types.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.dtypes(kind='signed integer')
|
||||
{'int8': dask.int8,
|
||||
'int16': dask.int16,
|
||||
'int32': dask.int32,
|
||||
'int64': dask.int64}
|
||||
|
||||
"""
|
||||
_check_device(da, device)
|
||||
if kind is None:
|
||||
return {
|
||||
"bool": dtype(bool),
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if kind == "bool":
|
||||
return {"bool": bool}
|
||||
if kind == "signed integer":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
}
|
||||
if kind == "unsigned integer":
|
||||
return {
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
}
|
||||
if kind == "integral":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
}
|
||||
if kind == "real floating":
|
||||
return {
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
}
|
||||
if kind == "complex floating":
|
||||
return {
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if kind == "numeric":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if isinstance(kind, tuple):
|
||||
res: dict[str, DType] = {}
|
||||
for k in kind:
|
||||
res.update(self.dtypes(kind=k))
|
||||
return res
|
||||
raise ValueError(f"unsupported kind: {kind!r}")
|
||||
|
||||
def devices(self) -> list[Device]:
|
||||
"""
|
||||
The devices supported by Dask.
|
||||
|
||||
For Dask, this always returns ``['cpu', DASK_DEVICE]``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
devices : list[Device]
|
||||
The devices supported by Dask.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = xp.__array_namespace_info__()
|
||||
>>> info.devices()
|
||||
['cpu', DASK_DEVICE]
|
||||
|
||||
"""
|
||||
return ["cpu", _DASK_DEVICE]
|
||||
@@ -0,0 +1,16 @@
|
||||
from ..._internal import clone_module
|
||||
|
||||
__all__ = clone_module("dask.array.fft", globals())
|
||||
|
||||
from ...common import _fft
|
||||
from ..._internal import get_xp
|
||||
|
||||
import dask.array as da
|
||||
|
||||
fftfreq = get_xp(da)(_fft.fftfreq)
|
||||
rfftfreq = get_xp(da)(_fft.rfftfreq)
|
||||
|
||||
__all__ += ["fftfreq", "rfftfreq"]
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import dask.array as da
|
||||
|
||||
# The `matmul` and `tensordot` functions are in both the main and linalg namespaces
|
||||
from dask.array import matmul, outer, tensordot
|
||||
|
||||
# Exports
|
||||
from ..._internal import clone_module, get_xp
|
||||
from ...common import _linalg
|
||||
from ...common._typing import Array
|
||||
|
||||
__all__ = clone_module("dask.array.linalg", globals())
|
||||
|
||||
from ._aliases import matrix_transpose, vecdot
|
||||
|
||||
EighResult = _linalg.EighResult
|
||||
QRResult = _linalg.QRResult
|
||||
SlogdetResult = _linalg.SlogdetResult
|
||||
SVDResult = _linalg.SVDResult
|
||||
# TODO: use the QR wrapper once dask
|
||||
# supports the mode keyword on QR
|
||||
# https://github.com/dask/dask/issues/10388
|
||||
#qr = get_xp(da)(_linalg.qr)
|
||||
def qr( # type: ignore[no-redef]
|
||||
x: Array,
|
||||
mode: Literal["reduced", "complete"] = "reduced",
|
||||
**kwargs: object,
|
||||
) -> QRResult:
|
||||
if mode != "reduced":
|
||||
raise ValueError("dask arrays only support using mode='reduced'")
|
||||
return QRResult(*da.linalg.qr(x, **kwargs))
|
||||
trace = get_xp(da)(_linalg.trace)
|
||||
cholesky = get_xp(da)(_linalg.cholesky)
|
||||
matrix_rank = get_xp(da)(_linalg.matrix_rank)
|
||||
matrix_norm = get_xp(da)(_linalg.matrix_norm)
|
||||
|
||||
|
||||
# Wrap the svd functions to not pass full_matrices to dask
|
||||
# when full_matrices=False (as that is the default behavior for dask),
|
||||
# and dask doesn't have the full_matrices keyword
|
||||
def svd(x: Array, full_matrices: bool = True, **kwargs: object) -> SVDResult: # type: ignore[no-redef]
|
||||
if full_matrices:
|
||||
raise ValueError("full_matrics=True is not supported by dask.")
|
||||
return da.linalg.svd(x, coerce_signs=False, **kwargs)
|
||||
|
||||
def svdvals(x: Array) -> Array:
|
||||
# TODO: can't avoid computing U or V for dask
|
||||
_, s, _ = svd(x)
|
||||
return s
|
||||
|
||||
vector_norm = get_xp(da)(_linalg.vector_norm)
|
||||
diagonal = get_xp(da)(_linalg.diagonal)
|
||||
|
||||
__all__ += ["trace", "outer", "matmul", "tensordot",
|
||||
"matrix_transpose", "vecdot", "EighResult",
|
||||
"QRResult", "SlogdetResult", "SVDResult", "qr",
|
||||
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
|
||||
"vector_norm", "diagonal"]
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
@@ -0,0 +1,38 @@
|
||||
# ruff: noqa: PLC0414
|
||||
from typing import Final
|
||||
|
||||
from .._internal import clone_module
|
||||
|
||||
# This needs to be loaded explicitly before cloning
|
||||
import numpy.typing # noqa: F401
|
||||
|
||||
__all__ = clone_module("numpy", globals())
|
||||
|
||||
# These imports may overwrite names from the import * above.
|
||||
from . import _aliases
|
||||
from ._aliases import * # type: ignore[assignment,no-redef] # noqa: F403
|
||||
from ._info import __array_namespace_info__ # noqa: F401
|
||||
|
||||
# Don't know why, but we have to do an absolute import to import linalg. If we
|
||||
# instead do
|
||||
#
|
||||
# from . import linalg
|
||||
#
|
||||
# It doesn't overwrite np.linalg from above. The import is generated
|
||||
# dynamically so that the library can be vendored.
|
||||
__import__(__package__ + ".linalg")
|
||||
|
||||
__import__(__package__ + ".fft")
|
||||
|
||||
from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401
|
||||
|
||||
__array_api_version__: Final = "2024.12"
|
||||
|
||||
__all__ = sorted(
|
||||
set(__all__)
|
||||
| set(_aliases.__all__)
|
||||
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
|
||||
)
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,191 @@
|
||||
# pyright: reportPrivateUsage=false
|
||||
from __future__ import annotations
|
||||
|
||||
from builtins import bool as py_bool
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .._internal import get_xp
|
||||
from ..common import _aliases, _helpers
|
||||
from ..common._typing import NestedSequence, SupportsBufferProtocol
|
||||
from ._typing import Array, Device, DType
|
||||
|
||||
bool = np.bool_
|
||||
|
||||
# Basic renames
|
||||
acos = np.arccos
|
||||
acosh = np.arccosh
|
||||
asin = np.arcsin
|
||||
asinh = np.arcsinh
|
||||
atan = np.arctan
|
||||
atan2 = np.arctan2
|
||||
atanh = np.arctanh
|
||||
bitwise_left_shift = np.left_shift
|
||||
bitwise_invert = np.invert
|
||||
bitwise_right_shift = np.right_shift
|
||||
concat = np.concatenate
|
||||
pow = np.power
|
||||
|
||||
arange = get_xp(np)(_aliases.arange)
|
||||
empty = get_xp(np)(_aliases.empty)
|
||||
empty_like = get_xp(np)(_aliases.empty_like)
|
||||
eye = get_xp(np)(_aliases.eye)
|
||||
full = get_xp(np)(_aliases.full)
|
||||
full_like = get_xp(np)(_aliases.full_like)
|
||||
linspace = get_xp(np)(_aliases.linspace)
|
||||
ones = get_xp(np)(_aliases.ones)
|
||||
ones_like = get_xp(np)(_aliases.ones_like)
|
||||
zeros = get_xp(np)(_aliases.zeros)
|
||||
zeros_like = get_xp(np)(_aliases.zeros_like)
|
||||
UniqueAllResult = get_xp(np)(_aliases.UniqueAllResult)
|
||||
UniqueCountsResult = get_xp(np)(_aliases.UniqueCountsResult)
|
||||
UniqueInverseResult = get_xp(np)(_aliases.UniqueInverseResult)
|
||||
unique_all = get_xp(np)(_aliases.unique_all)
|
||||
unique_counts = get_xp(np)(_aliases.unique_counts)
|
||||
unique_inverse = get_xp(np)(_aliases.unique_inverse)
|
||||
unique_values = get_xp(np)(_aliases.unique_values)
|
||||
std = get_xp(np)(_aliases.std)
|
||||
var = get_xp(np)(_aliases.var)
|
||||
cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
|
||||
cumulative_prod = get_xp(np)(_aliases.cumulative_prod)
|
||||
clip = get_xp(np)(_aliases.clip)
|
||||
permute_dims = get_xp(np)(_aliases.permute_dims)
|
||||
reshape = get_xp(np)(_aliases.reshape)
|
||||
argsort = get_xp(np)(_aliases.argsort)
|
||||
sort = get_xp(np)(_aliases.sort)
|
||||
nonzero = get_xp(np)(_aliases.nonzero)
|
||||
matmul = get_xp(np)(_aliases.matmul)
|
||||
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
|
||||
tensordot = get_xp(np)(_aliases.tensordot)
|
||||
sign = get_xp(np)(_aliases.sign)
|
||||
finfo = get_xp(np)(_aliases.finfo)
|
||||
iinfo = get_xp(np)(_aliases.iinfo)
|
||||
|
||||
|
||||
# asarray also adds the copy keyword, which is not present in numpy 1.0.
|
||||
# asarray() is different enough between numpy, cupy, and dask, the logic
|
||||
# complicated enough that it's easier to define it separately for each module
|
||||
# rather than trying to combine everything into one function in common/
|
||||
def asarray(
|
||||
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
|
||||
/,
|
||||
*,
|
||||
dtype: DType | None = None,
|
||||
device: Device | None = None,
|
||||
copy: py_bool | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Array:
|
||||
"""
|
||||
Array API compatibility wrapper for asarray().
|
||||
|
||||
See the corresponding documentation in the array library and/or the array API
|
||||
specification for more details.
|
||||
"""
|
||||
_helpers._check_device(np, device)
|
||||
|
||||
# None is unsupported in NumPy 1.0, but we can use an internal enum
|
||||
# False in NumPy 1.0 means None in NumPy 2.0 and in the Array API
|
||||
if copy is None:
|
||||
copy = np._CopyMode.IF_NEEDED # type: ignore[assignment,attr-defined]
|
||||
elif copy is False:
|
||||
copy = np._CopyMode.NEVER # type: ignore[assignment,attr-defined]
|
||||
|
||||
return np.array(obj, copy=copy, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
def astype(
|
||||
x: Array,
|
||||
dtype: DType,
|
||||
/,
|
||||
*,
|
||||
copy: py_bool = True,
|
||||
device: Device | None = None,
|
||||
) -> Array:
|
||||
_helpers._check_device(np, device)
|
||||
return x.astype(dtype=dtype, copy=copy)
|
||||
|
||||
|
||||
# count_nonzero returns a python int for axis=None and keepdims=False
|
||||
# https://github.com/numpy/numpy/issues/17562
|
||||
def count_nonzero(
|
||||
x: Array,
|
||||
axis: int | tuple[int, ...] | None = None,
|
||||
keepdims: py_bool = False,
|
||||
) -> Array:
|
||||
# NOTE: this is currently incorrectly typed in numpy, but will be fixed in
|
||||
# numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750
|
||||
result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue]
|
||||
if axis is None and not keepdims:
|
||||
return np.asarray(result)
|
||||
return result
|
||||
|
||||
|
||||
# take_along_axis: axis defaults to -1 but in numpy axis is a required arg
|
||||
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
|
||||
return np.take_along_axis(x, indices, axis=axis)
|
||||
|
||||
|
||||
# ceil, floor, and trunc return integers for integer inputs in NumPy < 2
|
||||
|
||||
def ceil(x: Array, /) -> Array:
|
||||
if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
|
||||
return x.copy()
|
||||
return np.ceil(x)
|
||||
|
||||
|
||||
def floor(x: Array, /) -> Array:
|
||||
if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
|
||||
return x.copy()
|
||||
return np.floor(x)
|
||||
|
||||
|
||||
def trunc(x: Array, /) -> Array:
|
||||
if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
|
||||
return x.copy()
|
||||
return np.trunc(x)
|
||||
|
||||
|
||||
# These functions are completely new here. If the library already has them
|
||||
# (i.e., numpy 2.0), use the library version instead of our wrapper.
|
||||
if hasattr(np, "vecdot"):
|
||||
vecdot = np.vecdot
|
||||
else:
|
||||
vecdot = get_xp(np)(_aliases.vecdot) # type: ignore[assignment]
|
||||
|
||||
if hasattr(np, "isdtype"):
|
||||
isdtype = np.isdtype
|
||||
else:
|
||||
isdtype = get_xp(np)(_aliases.isdtype)
|
||||
|
||||
if hasattr(np, "unstack"):
|
||||
unstack = np.unstack
|
||||
else:
|
||||
unstack = get_xp(np)(_aliases.unstack)
|
||||
|
||||
__all__ = _aliases.__all__ + [
|
||||
"asarray",
|
||||
"astype",
|
||||
"acos",
|
||||
"acosh",
|
||||
"asin",
|
||||
"asinh",
|
||||
"atan",
|
||||
"atan2",
|
||||
"atanh",
|
||||
"ceil",
|
||||
"floor",
|
||||
"trunc",
|
||||
"bitwise_left_shift",
|
||||
"bitwise_invert",
|
||||
"bitwise_right_shift",
|
||||
"bool",
|
||||
"concat",
|
||||
"count_nonzero",
|
||||
"pow",
|
||||
"take_along_axis"
|
||||
]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Array API Inspection namespace
|
||||
|
||||
This is the namespace for inspection functions as defined by the array API
|
||||
standard. See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html for
|
||||
more details.
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from numpy import bool_ as bool
|
||||
from numpy import (
|
||||
complex64,
|
||||
complex128,
|
||||
dtype,
|
||||
float32,
|
||||
float64,
|
||||
int8,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
intp,
|
||||
uint8,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
)
|
||||
|
||||
from ..common._typing import DefaultDTypes
|
||||
from ._typing import Device, DType
|
||||
|
||||
|
||||
class __array_namespace_info__:
|
||||
"""
|
||||
Get the array API inspection namespace for NumPy.
|
||||
|
||||
The array API inspection namespace defines the following functions:
|
||||
|
||||
- capabilities()
|
||||
- default_device()
|
||||
- default_dtypes()
|
||||
- dtypes()
|
||||
- devices()
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/inspection.html
|
||||
for more details.
|
||||
|
||||
Returns
|
||||
-------
|
||||
info : ModuleType
|
||||
The array API inspection namespace for NumPy.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = np.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': numpy.float64,
|
||||
'complex floating': numpy.complex128,
|
||||
'integral': numpy.int64,
|
||||
'indexing': numpy.int64}
|
||||
|
||||
"""
|
||||
|
||||
__module__ = 'numpy'
|
||||
|
||||
def capabilities(self):
|
||||
"""
|
||||
Return a dictionary of array API library capabilities.
|
||||
|
||||
The resulting dictionary has the following keys:
|
||||
|
||||
- **"boolean indexing"**: boolean indicating whether an array library
|
||||
supports boolean indexing. Always ``True`` for NumPy.
|
||||
|
||||
- **"data-dependent shapes"**: boolean indicating whether an array
|
||||
library supports data-dependent output shapes. Always ``True`` for
|
||||
NumPy.
|
||||
|
||||
See
|
||||
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
|
||||
for more details.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
capabilities : dict
|
||||
A dictionary of array API library capabilities.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = np.__array_namespace_info__()
|
||||
>>> info.capabilities()
|
||||
{'boolean indexing': True,
|
||||
'data-dependent shapes': True,
|
||||
'max dimensions': 64}
|
||||
|
||||
"""
|
||||
return {
|
||||
"boolean indexing": True,
|
||||
"data-dependent shapes": True,
|
||||
"max dimensions": 64,
|
||||
}
|
||||
|
||||
def default_device(self):
|
||||
"""
|
||||
The default device used for new NumPy arrays.
|
||||
|
||||
For NumPy, this always returns ``'cpu'``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Returns
|
||||
-------
|
||||
device : Device
|
||||
The default device used for new NumPy arrays.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = np.__array_namespace_info__()
|
||||
>>> info.default_device()
|
||||
'cpu'
|
||||
|
||||
"""
|
||||
return "cpu"
|
||||
|
||||
def default_dtypes(
|
||||
self,
|
||||
*,
|
||||
device: Device | None = None,
|
||||
) -> DefaultDTypes:
|
||||
"""
|
||||
The default data types used for new NumPy arrays.
|
||||
|
||||
For NumPy, this always returns the following dictionary:
|
||||
|
||||
- **"real floating"**: ``numpy.float64``
|
||||
- **"complex floating"**: ``numpy.complex128``
|
||||
- **"integral"**: ``numpy.intp``
|
||||
- **"indexing"**: ``numpy.intp``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : str, optional
|
||||
The device to get the default data types for. For NumPy, only
|
||||
``'cpu'`` is allowed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary describing the default data types used for new NumPy
|
||||
arrays.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = np.__array_namespace_info__()
|
||||
>>> info.default_dtypes()
|
||||
{'real floating': numpy.float64,
|
||||
'complex floating': numpy.complex128,
|
||||
'integral': numpy.int64,
|
||||
'indexing': numpy.int64}
|
||||
|
||||
"""
|
||||
if device not in ["cpu", None]:
|
||||
raise ValueError(
|
||||
'Device not understood. Only "cpu" is allowed, but received:'
|
||||
f' {device}'
|
||||
)
|
||||
return {
|
||||
"real floating": dtype(float64),
|
||||
"complex floating": dtype(complex128),
|
||||
"integral": dtype(intp),
|
||||
"indexing": dtype(intp),
|
||||
}
|
||||
|
||||
def dtypes(
|
||||
self,
|
||||
*,
|
||||
device: Device | None = None,
|
||||
kind: str | tuple[str, ...] | None = None,
|
||||
) -> dict[str, DType]:
|
||||
"""
|
||||
The array API data types supported by NumPy.
|
||||
|
||||
Note that this function only returns data types that are defined by
|
||||
the array API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
device : str, optional
|
||||
The device to get the data types for. For NumPy, only ``'cpu'`` is
|
||||
allowed.
|
||||
kind : str or tuple of str, optional
|
||||
The kind of data types to return. If ``None``, all data types are
|
||||
returned. If a string, only data types of that kind are returned.
|
||||
If a tuple, a dictionary containing the union of the given kinds
|
||||
is returned. The following kinds are supported:
|
||||
|
||||
- ``'bool'``: boolean data types (i.e., ``bool``).
|
||||
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
|
||||
``int16``, ``int32``, ``int64``).
|
||||
- ``'unsigned integer'``: unsigned integer data types (i.e.,
|
||||
``uint8``, ``uint16``, ``uint32``, ``uint64``).
|
||||
- ``'integral'``: integer data types. Shorthand for ``('signed
|
||||
integer', 'unsigned integer')``.
|
||||
- ``'real floating'``: real-valued floating-point data types
|
||||
(i.e., ``float32``, ``float64``).
|
||||
- ``'complex floating'``: complex floating-point data types (i.e.,
|
||||
``complex64``, ``complex128``).
|
||||
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
|
||||
'real floating', 'complex floating')``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dtypes : dict
|
||||
A dictionary mapping the names of data types to the corresponding
|
||||
NumPy data types.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.devices
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = np.__array_namespace_info__()
|
||||
>>> info.dtypes(kind='signed integer')
|
||||
{'int8': numpy.int8,
|
||||
'int16': numpy.int16,
|
||||
'int32': numpy.int32,
|
||||
'int64': numpy.int64}
|
||||
|
||||
"""
|
||||
if device not in ["cpu", None]:
|
||||
raise ValueError(
|
||||
'Device not understood. Only "cpu" is allowed, but received:'
|
||||
f' {device}'
|
||||
)
|
||||
if kind is None:
|
||||
return {
|
||||
"bool": dtype(bool),
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if kind == "bool":
|
||||
return {"bool": dtype(bool)}
|
||||
if kind == "signed integer":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
}
|
||||
if kind == "unsigned integer":
|
||||
return {
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
}
|
||||
if kind == "integral":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
}
|
||||
if kind == "real floating":
|
||||
return {
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
}
|
||||
if kind == "complex floating":
|
||||
return {
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if kind == "numeric":
|
||||
return {
|
||||
"int8": dtype(int8),
|
||||
"int16": dtype(int16),
|
||||
"int32": dtype(int32),
|
||||
"int64": dtype(int64),
|
||||
"uint8": dtype(uint8),
|
||||
"uint16": dtype(uint16),
|
||||
"uint32": dtype(uint32),
|
||||
"uint64": dtype(uint64),
|
||||
"float32": dtype(float32),
|
||||
"float64": dtype(float64),
|
||||
"complex64": dtype(complex64),
|
||||
"complex128": dtype(complex128),
|
||||
}
|
||||
if isinstance(kind, tuple):
|
||||
res: dict[str, DType] = {}
|
||||
for k in kind:
|
||||
res.update(self.dtypes(kind=k))
|
||||
return res
|
||||
raise ValueError(f"unsupported kind: {kind!r}")
|
||||
|
||||
def devices(self) -> list[Device]:
|
||||
"""
|
||||
The devices supported by NumPy.
|
||||
|
||||
For NumPy, this always returns ``['cpu']``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
devices : list[Device]
|
||||
The devices supported by NumPy.
|
||||
|
||||
See Also
|
||||
--------
|
||||
__array_namespace_info__.capabilities,
|
||||
__array_namespace_info__.default_device,
|
||||
__array_namespace_info__.default_dtypes,
|
||||
__array_namespace_info__.dtypes
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> info = np.__array_namespace_info__()
|
||||
>>> info.devices()
|
||||
['cpu']
|
||||
|
||||
"""
|
||||
return ["cpu"]
|
||||
|
||||
|
||||
__all__ = ["__array_namespace_info__"]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
|
||||
|
||||
import numpy as np
|
||||
|
||||
Device: TypeAlias = Literal["cpu"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
# NumPy 1.x on Python 3.10 fails to parse np.dtype[]
|
||||
DType: TypeAlias = np.dtype[
|
||||
np.bool_
|
||||
| np.integer[Any]
|
||||
| np.float32
|
||||
| np.float64
|
||||
| np.complex64
|
||||
| np.complex128
|
||||
]
|
||||
Array: TypeAlias = np.ndarray[Any, DType]
|
||||
else:
|
||||
DType: TypeAlias = np.dtype
|
||||
Array: TypeAlias = np.ndarray
|
||||
|
||||
__all__ = ["Array", "DType", "Device"]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return __all__
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user