This commit is contained in:
2026-04-10 15:06:59 +02:00
parent 3031b7153b
commit e5a4711004
7806 changed files with 1918528 additions and 335 deletions

View File

@@ -0,0 +1,7 @@
This directory contains bundled external dependencies that are updated
every once in a while.
Note for distribution packagers: if you want to remove the duplicated
code and depend on a packaged version, we suggest that you simply do a
symbolic link in this directory.

View File

@@ -0,0 +1,5 @@
"""
External, bundled dependencies.
"""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,5 @@
# DO NOT RENAME THIS FILE
# This is a hook for array_api_extra/_lib/_compat.py
# to co-vendor array_api_compat and potentially override its functions.
from .array_api_compat import * # noqa: F403

View File

@@ -0,0 +1,759 @@
"""Extract reference documentation from the NumPy source tree."""
import copy
import inspect
import pydoc
import re
import sys
import textwrap
from collections import namedtuple
from collections.abc import Callable, Mapping
from functools import cached_property
from warnings import warn
def strip_blank_lines(l):
"Remove leading and trailing blank lines from a list of lines"
while l and not l[0].strip():
del l[0]
while l and not l[-1].strip():
del l[-1]
return l
class Reader:
"""A line-based string reader."""
def __init__(self, data):
"""
Parameters
----------
data : str
String with lines separated by '\\n'.
"""
if isinstance(data, list):
self._str = data
else:
self._str = data.split("\n") # store string as list of lines
self.reset()
def __getitem__(self, n):
return self._str[n]
def reset(self):
self._l = 0 # current line nr
def read(self):
if not self.eof():
out = self[self._l]
self._l += 1
return out
else:
return ""
def seek_next_non_empty_line(self):
for l in self[self._l :]:
if l.strip():
break
else:
self._l += 1
def eof(self):
return self._l >= len(self._str)
def read_to_condition(self, condition_func):
start = self._l
for line in self[start:]:
if condition_func(line):
return self[start : self._l]
self._l += 1
if self.eof():
return self[start : self._l + 1]
return []
def read_to_next_empty_line(self):
self.seek_next_non_empty_line()
def is_empty(line):
return not line.strip()
return self.read_to_condition(is_empty)
def read_to_next_unindented_line(self):
def is_unindented(line):
return line.strip() and (len(line.lstrip()) == len(line))
return self.read_to_condition(is_unindented)
def peek(self, n=0):
if self._l + n < len(self._str):
return self[self._l + n]
else:
return ""
def is_empty(self):
return not "".join(self._str).strip()
class ParseError(Exception):
def __str__(self):
message = self.args[0]
if hasattr(self, "docstring"):
message = f"{message} in {self.docstring!r}"
return message
Parameter = namedtuple("Parameter", ["name", "type", "desc"])
class NumpyDocString(Mapping):
"""Parses a numpydoc string to an abstract representation
Instances define a mapping from section title to structured data.
"""
sections = {
"Signature": "",
"Summary": [""],
"Extended Summary": [],
"Parameters": [],
"Attributes": [],
"Methods": [],
"Returns": [],
"Yields": [],
"Receives": [],
"Other Parameters": [],
"Raises": [],
"Warns": [],
"Warnings": [],
"See Also": [],
"Notes": [],
"References": "",
"Examples": "",
"index": {},
}
def __init__(self, docstring, config=None):
orig_docstring = docstring
docstring = textwrap.dedent(docstring).split("\n")
self._doc = Reader(docstring)
self._parsed_data = copy.deepcopy(self.sections)
try:
self._parse()
except ParseError as e:
e.docstring = orig_docstring
raise
def __getitem__(self, key):
return self._parsed_data[key]
def __setitem__(self, key, val):
if key not in self._parsed_data:
self._error_location(f"Unknown section {key}", error=False)
else:
self._parsed_data[key] = val
def __iter__(self):
return iter(self._parsed_data)
def __len__(self):
return len(self._parsed_data)
def _is_at_section(self):
self._doc.seek_next_non_empty_line()
if self._doc.eof():
return False
l1 = self._doc.peek().strip() # e.g. Parameters
if l1.startswith(".. index::"):
return True
l2 = self._doc.peek(1).strip() # ---------- or ==========
if len(l2) >= 3 and (set(l2) in ({"-"}, {"="})) and len(l2) != len(l1):
snip = "\n".join(self._doc._str[:2]) + "..."
self._error_location(
f"potentially wrong underline length... \n{l1} \n{l2} in \n{snip}",
error=False,
)
return l2.startswith("-" * len(l1)) or l2.startswith("=" * len(l1))
def _strip(self, doc):
i = 0
j = 0
for i, line in enumerate(doc):
if line.strip():
break
for j, line in enumerate(doc[::-1]):
if line.strip():
break
return doc[i : len(doc) - j]
def _read_to_next_section(self):
section = self._doc.read_to_next_empty_line()
while not self._is_at_section() and not self._doc.eof():
if not self._doc.peek(-1).strip(): # previous line was empty
section += [""]
section += self._doc.read_to_next_empty_line()
return section
def _read_sections(self):
while not self._doc.eof():
data = self._read_to_next_section()
name = data[0].strip()
if name.startswith(".."): # index section
yield name, data[1:]
elif len(data) < 2:
yield StopIteration
else:
yield name, self._strip(data[2:])
def _parse_param_list(self, content, single_element_is_type=False):
content = dedent_lines(content)
r = Reader(content)
params = []
while not r.eof():
header = r.read().strip()
if " : " in header:
arg_name, arg_type = header.split(" : ", maxsplit=1)
else:
# NOTE: param line with single element should never have a
# a " :" before the description line, so this should probably
# warn.
header = header.removesuffix(" :")
if single_element_is_type:
arg_name, arg_type = "", header
else:
arg_name, arg_type = header, ""
desc = r.read_to_next_unindented_line()
desc = dedent_lines(desc)
desc = strip_blank_lines(desc)
params.append(Parameter(arg_name, arg_type, desc))
return params
# See also supports the following formats.
#
# <FUNCNAME>
# <FUNCNAME> SPACE* COLON SPACE+ <DESC> SPACE*
# <FUNCNAME> ( COMMA SPACE+ <FUNCNAME>)+ (COMMA | PERIOD)? SPACE*
# <FUNCNAME> ( COMMA SPACE+ <FUNCNAME>)* SPACE* COLON SPACE+ <DESC> SPACE*
# <FUNCNAME> is one of
# <PLAIN_FUNCNAME>
# COLON <ROLE> COLON BACKTICK <PLAIN_FUNCNAME> BACKTICK
# where
# <PLAIN_FUNCNAME> is a legal function name, and
# <ROLE> is any nonempty sequence of word characters.
# Examples: func_f1 :meth:`func_h1` :obj:`~baz.obj_r` :class:`class_j`
# <DESC> is a string describing the function.
_role = r":(?P<role>(py:)?\w+):"
_funcbacktick = r"`(?P<name>(?:~\w+\.)?[a-zA-Z0-9_\.-]+)`"
_funcplain = r"(?P<name2>[a-zA-Z0-9_\.-]+)"
_funcname = r"(" + _role + _funcbacktick + r"|" + _funcplain + r")"
_funcnamenext = _funcname.replace("role", "rolenext")
_funcnamenext = _funcnamenext.replace("name", "namenext")
_description = r"(?P<description>\s*:(\s+(?P<desc>\S+.*))?)?\s*$"
_func_rgx = re.compile(r"^\s*" + _funcname + r"\s*")
_line_rgx = re.compile(
r"^\s*"
+ r"(?P<allfuncs>"
+ _funcname # group for all function names
+ r"(?P<morefuncs>([,]\s+"
+ _funcnamenext
+ r")*)"
+ r")"
+ r"(?P<trailing>[,\.])?" # end of "allfuncs"
+ _description # Some function lists have a trailing comma (or period) '\s*'
)
# Empty <DESC> elements are replaced with '..'
empty_description = ".."
def _parse_see_also(self, content):
"""
func_name : Descriptive text
continued text
another_func_name : Descriptive text
func_name1, func_name2, :meth:`func_name`, func_name3
"""
content = dedent_lines(content)
items = []
def parse_item_name(text):
"""Match ':role:`name`' or 'name'."""
m = self._func_rgx.match(text)
if not m:
self._error_location(f"Error parsing See Also entry {line!r}")
role = m.group("role")
name = m.group("name") if role else m.group("name2")
return name, role, m.end()
rest = []
for line in content:
if not line.strip():
continue
line_match = self._line_rgx.match(line)
description = None
if line_match:
description = line_match.group("desc")
if line_match.group("trailing") and description:
self._error_location(
"Unexpected comma or period after function list at index %d of "
'line "%s"' % (line_match.end("trailing"), line),
error=False,
)
if not description and line.startswith(" "):
rest.append(line.strip())
elif line_match:
funcs = []
text = line_match.group("allfuncs")
while True:
if not text.strip():
break
name, role, match_end = parse_item_name(text)
funcs.append((name, role))
text = text[match_end:].strip()
if text and text[0] == ",":
text = text[1:].strip()
rest = list(filter(None, [description]))
items.append((funcs, rest))
else:
self._error_location(f"Error parsing See Also entry {line!r}")
return items
def _parse_index(self, section, content):
"""
.. index:: default
:refguide: something, else, and more
"""
def strip_each_in(lst):
return [s.strip() for s in lst]
out = {}
section = section.split("::")
if len(section) > 1:
out["default"] = strip_each_in(section[1].split(","))[0]
for line in content:
line = line.split(":")
if len(line) > 2:
out[line[1]] = strip_each_in(line[2].split(","))
return out
def _parse_summary(self):
"""Grab signature (if given) and summary"""
if self._is_at_section():
return
# If several signatures present, take the last one
while True:
summary = self._doc.read_to_next_empty_line()
summary_str = " ".join([s.strip() for s in summary]).strip()
compiled = re.compile(r"^([\w., ]+=)?\s*[\w\.]+\(.*\)$")
if compiled.match(summary_str):
self["Signature"] = summary_str
if not self._is_at_section():
continue
break
if summary is not None:
self["Summary"] = summary
if not self._is_at_section():
self["Extended Summary"] = self._read_to_next_section()
def _parse(self):
self._doc.reset()
self._parse_summary()
sections = list(self._read_sections())
section_names = {section for section, content in sections}
has_yields = "Yields" in section_names
# We could do more tests, but we are not. Arbitrarily.
if not has_yields and "Receives" in section_names:
msg = "Docstring contains a Receives section but not Yields."
raise ValueError(msg)
for section, content in sections:
if not section.startswith(".."):
section = (s.capitalize() for s in section.split(" "))
section = " ".join(section)
if self.get(section):
self._error_location(
"The section %s appears twice in %s"
% (section, "\n".join(self._doc._str))
)
if section in ("Parameters", "Other Parameters", "Attributes", "Methods"):
self[section] = self._parse_param_list(content)
elif section in ("Returns", "Yields", "Raises", "Warns", "Receives"):
self[section] = self._parse_param_list(
content, single_element_is_type=True
)
elif section.startswith(".. index::"):
self["index"] = self._parse_index(section, content)
elif section == "See Also":
self["See Also"] = self._parse_see_also(content)
else:
self[section] = content
@property
def _obj(self):
if hasattr(self, "_cls"):
return self._cls
elif hasattr(self, "_f"):
return self._f
return None
def _error_location(self, msg, error=True):
if self._obj is not None:
# we know where the docs came from:
try:
filename = inspect.getsourcefile(self._obj)
except TypeError:
filename = None
# Make UserWarning more descriptive via object introspection.
# Skip if introspection fails
name = getattr(self._obj, "__name__", None)
if name is None:
name = getattr(getattr(self._obj, "__class__", None), "__name__", None)
if name is not None:
msg += f" in the docstring of {name}"
msg += f" in {filename}." if filename else ""
if error:
raise ValueError(msg)
else:
warn(msg, stacklevel=3)
# string conversion routines
def _str_header(self, name, symbol="-"):
return [name, len(name) * symbol]
def _str_indent(self, doc, indent=4):
return [" " * indent + line for line in doc]
def _str_signature(self):
if self["Signature"]:
return [self["Signature"].replace("*", r"\*")] + [""]
return [""]
def _str_summary(self):
if self["Summary"]:
return self["Summary"] + [""]
return []
def _str_extended_summary(self):
if self["Extended Summary"]:
return self["Extended Summary"] + [""]
return []
def _str_param_list(self, name):
out = []
if self[name]:
out += self._str_header(name)
for param in self[name]:
parts = []
if param.name:
parts.append(param.name)
if param.type:
parts.append(param.type)
out += [" : ".join(parts)]
if param.desc and "".join(param.desc).strip():
out += self._str_indent(param.desc)
out += [""]
return out
def _str_section(self, name):
out = []
if self[name]:
out += self._str_header(name)
out += self[name]
out += [""]
return out
def _str_see_also(self, func_role):
if not self["See Also"]:
return []
out = []
out += self._str_header("See Also")
out += [""]
last_had_desc = True
for funcs, desc in self["See Also"]:
assert isinstance(funcs, list)
links = []
for func, role in funcs:
if role:
link = f":{role}:`{func}`"
elif func_role:
link = f":{func_role}:`{func}`"
else:
link = f"`{func}`_"
links.append(link)
link = ", ".join(links)
out += [link]
if desc:
out += self._str_indent([" ".join(desc)])
last_had_desc = True
else:
last_had_desc = False
out += self._str_indent([self.empty_description])
if last_had_desc:
out += [""]
out += [""]
return out
def _str_index(self):
idx = self["index"]
out = []
output_index = False
default_index = idx.get("default", "")
if default_index:
output_index = True
out += [f".. index:: {default_index}"]
for section, references in idx.items():
if section == "default":
continue
output_index = True
out += [f" :{section}: {', '.join(references)}"]
if output_index:
return out
return ""
def __str__(self, func_role=""):
out = []
out += self._str_signature()
out += self._str_summary()
out += self._str_extended_summary()
out += self._str_param_list("Parameters")
for param_list in ("Attributes", "Methods"):
out += self._str_param_list(param_list)
for param_list in (
"Returns",
"Yields",
"Receives",
"Other Parameters",
"Raises",
"Warns",
):
out += self._str_param_list(param_list)
out += self._str_section("Warnings")
out += self._str_see_also(func_role)
for s in ("Notes", "References", "Examples"):
out += self._str_section(s)
out += self._str_index()
return "\n".join(out)
def dedent_lines(lines):
"""Deindent a list of lines maximally"""
return textwrap.dedent("\n".join(lines)).split("\n")
class FunctionDoc(NumpyDocString):
def __init__(self, func, role="func", doc=None, config=None):
self._f = func
self._role = role # e.g. "func" or "meth"
if doc is None:
if func is None:
raise ValueError("No function or docstring given")
doc = inspect.getdoc(func) or ""
if config is None:
config = {}
NumpyDocString.__init__(self, doc, config)
def get_func(self):
func_name = getattr(self._f, "__name__", self.__class__.__name__)
if inspect.isclass(self._f):
func = getattr(self._f, "__call__", self._f.__init__)
else:
func = self._f
return func, func_name
def __str__(self):
out = ""
func, func_name = self.get_func()
roles = {"func": "function", "meth": "method"}
if self._role:
if self._role not in roles:
print(f"Warning: invalid role {self._role}")
out += f".. {roles.get(self._role, '')}:: {func_name}\n \n\n"
out += super().__str__(func_role=self._role)
return out
class ObjDoc(NumpyDocString):
def __init__(self, obj, doc=None, config=None):
self._f = obj
if config is None:
config = {}
NumpyDocString.__init__(self, doc, config=config)
class ClassDoc(NumpyDocString):
extra_public_methods = ["__call__"]
def __init__(self, cls, doc=None, modulename="", func_doc=FunctionDoc, config=None):
if not inspect.isclass(cls) and cls is not None:
raise ValueError(f"Expected a class or None, but got {cls!r}")
self._cls = cls
if "sphinx" in sys.modules:
from sphinx.ext.autodoc import ALL
else:
ALL = object()
if config is None:
config = {}
self.show_inherited_members = config.get("show_inherited_class_members", True)
if modulename and not modulename.endswith("."):
modulename += "."
self._mod = modulename
if doc is None:
if cls is None:
raise ValueError("No class or documentation string given")
doc = pydoc.getdoc(cls)
NumpyDocString.__init__(self, doc)
_members = config.get("members", [])
if _members is ALL:
_members = None
_exclude = config.get("exclude-members", [])
if config.get("show_class_members", True) and _exclude is not ALL:
def splitlines_x(s):
if not s:
return []
else:
return s.splitlines()
for field, items in [
("Methods", self.methods),
("Attributes", self.properties),
]:
if not self[field]:
doc_list = []
for name in sorted(items):
if name in _exclude or (_members and name not in _members):
continue
try:
doc_item = pydoc.getdoc(getattr(self._cls, name))
doc_list.append(Parameter(name, "", splitlines_x(doc_item)))
except AttributeError:
pass # method doesn't exist
self[field] = doc_list
@property
def methods(self):
if self._cls is None:
return []
return [
name
for name, func in inspect.getmembers(self._cls)
if (
(not name.startswith("_") or name in self.extra_public_methods)
and isinstance(func, Callable)
and self._is_show_member(name)
)
]
@property
def properties(self):
if self._cls is None:
return []
return [
name
for name, func in inspect.getmembers(self._cls)
if (
not name.startswith("_")
and not self._should_skip_member(name, self._cls)
and (
func is None
or isinstance(func, (property, cached_property))
or inspect.isdatadescriptor(func)
)
and self._is_show_member(name)
)
]
@staticmethod
def _should_skip_member(name, klass):
return (
# Namedtuples should skip everything in their ._fields as the
# docstrings for each of the members is: "Alias for field number X"
issubclass(klass, tuple)
and hasattr(klass, "_asdict")
and hasattr(klass, "_fields")
and name in klass._fields
)
def _is_show_member(self, name):
return (
# show all class members
self.show_inherited_members
# or class member is not inherited
or name in self._cls.__dict__
)
def get_doc_object(
obj,
what=None,
doc=None,
config=None,
class_doc=ClassDoc,
func_doc=FunctionDoc,
obj_doc=ObjDoc,
):
if what is None:
if inspect.isclass(obj):
what = "class"
elif inspect.ismodule(obj):
what = "module"
elif isinstance(obj, Callable):
what = "function"
else:
what = "object"
if config is None:
config = {}
if what == "class":
return class_doc(obj, func_doc=func_doc, doc=doc, config=config)
elif what in ("function", "method"):
return func_doc(obj, doc=doc, config=config)
else:
if doc is None:
doc = pydoc.getdoc(obj)
return obj_doc(obj, doc, config=config)

View File

@@ -0,0 +1,90 @@
"""Vendoered from
https://github.com/pypa/packaging/blob/main/packaging/_structures.py
"""
# Copyright (c) Donald Stufft and individual contributors.
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
class InfinityType:
def __repr__(self) -> str:
return "Infinity"
def __hash__(self) -> int:
return hash(repr(self))
def __lt__(self, other: object) -> bool:
return False
def __le__(self, other: object) -> bool:
return False
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__)
def __ne__(self, other: object) -> bool:
return not isinstance(other, self.__class__)
def __gt__(self, other: object) -> bool:
return True
def __ge__(self, other: object) -> bool:
return True
def __neg__(self: object) -> "NegativeInfinityType":
return NegativeInfinity
Infinity = InfinityType()
class NegativeInfinityType:
def __repr__(self) -> str:
return "-Infinity"
def __hash__(self) -> int:
return hash(repr(self))
def __lt__(self, other: object) -> bool:
return True
def __le__(self, other: object) -> bool:
return True
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__)
def __ne__(self, other: object) -> bool:
return not isinstance(other, self.__class__)
def __gt__(self, other: object) -> bool:
return False
def __ge__(self, other: object) -> bool:
return False
def __neg__(self: object) -> InfinityType:
return Infinity
NegativeInfinity = NegativeInfinityType()

View File

@@ -0,0 +1,535 @@
"""Vendored from
https://github.com/pypa/packaging/blob/main/packaging/version.py
"""
# Copyright (c) Donald Stufft and individual contributors.
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import collections
import itertools
import re
import warnings
from typing import Callable, Iterator, List, Optional, SupportsInt, Tuple, Union
from ._structures import Infinity, InfinityType, NegativeInfinity, NegativeInfinityType
__all__ = ["parse", "Version", "LegacyVersion", "InvalidVersion", "VERSION_PATTERN"]
InfiniteTypes = Union[InfinityType, NegativeInfinityType]
PrePostDevType = Union[InfiniteTypes, Tuple[str, int]]
SubLocalType = Union[InfiniteTypes, int, str]
LocalType = Union[
NegativeInfinityType,
Tuple[
Union[
SubLocalType,
Tuple[SubLocalType, str],
Tuple[NegativeInfinityType, SubLocalType],
],
...,
],
]
CmpKey = Tuple[
int, Tuple[int, ...], PrePostDevType, PrePostDevType, PrePostDevType, LocalType
]
LegacyCmpKey = Tuple[int, Tuple[str, ...]]
VersionComparisonMethod = Callable[
[Union[CmpKey, LegacyCmpKey], Union[CmpKey, LegacyCmpKey]], bool
]
_Version = collections.namedtuple(
"_Version", ["epoch", "release", "dev", "pre", "post", "local"]
)
def parse(version: str) -> Union["LegacyVersion", "Version"]:
"""Parse the given version from a string to an appropriate class.
Parameters
----------
version : str
Version in a string format, eg. "0.9.1" or "1.2.dev0".
Returns
-------
version : :class:`Version` object or a :class:`LegacyVersion` object
Returned class depends on the given version: if is a valid
PEP 440 version or a legacy version.
"""
try:
return Version(version)
except InvalidVersion:
return LegacyVersion(version)
class InvalidVersion(ValueError):
"""
An invalid version was found, users should refer to PEP 440.
"""
class _BaseVersion:
_key: Union[CmpKey, LegacyCmpKey]
def __hash__(self) -> int:
return hash(self._key)
# Please keep the duplicated `isinstance` check
# in the six comparisons hereunder
# unless you find a way to avoid adding overhead function calls.
def __lt__(self, other: "_BaseVersion") -> bool:
if not isinstance(other, _BaseVersion):
return NotImplemented
return self._key < other._key
def __le__(self, other: "_BaseVersion") -> bool:
if not isinstance(other, _BaseVersion):
return NotImplemented
return self._key <= other._key
def __eq__(self, other: object) -> bool:
if not isinstance(other, _BaseVersion):
return NotImplemented
return self._key == other._key
def __ge__(self, other: "_BaseVersion") -> bool:
if not isinstance(other, _BaseVersion):
return NotImplemented
return self._key >= other._key
def __gt__(self, other: "_BaseVersion") -> bool:
if not isinstance(other, _BaseVersion):
return NotImplemented
return self._key > other._key
def __ne__(self, other: object) -> bool:
if not isinstance(other, _BaseVersion):
return NotImplemented
return self._key != other._key
class LegacyVersion(_BaseVersion):
def __init__(self, version: str) -> None:
self._version = str(version)
self._key = _legacy_cmpkey(self._version)
warnings.warn(
"Creating a LegacyVersion has been deprecated and will be "
"removed in the next major release",
DeprecationWarning,
)
def __str__(self) -> str:
return self._version
def __repr__(self) -> str:
return f"<LegacyVersion('{self}')>"
@property
def public(self) -> str:
return self._version
@property
def base_version(self) -> str:
return self._version
@property
def epoch(self) -> int:
return -1
@property
def release(self) -> None:
return None
@property
def pre(self) -> None:
return None
@property
def post(self) -> None:
return None
@property
def dev(self) -> None:
return None
@property
def local(self) -> None:
return None
@property
def is_prerelease(self) -> bool:
return False
@property
def is_postrelease(self) -> bool:
return False
@property
def is_devrelease(self) -> bool:
return False
_legacy_version_component_re = re.compile(r"(\d+ | [a-z]+ | \.| -)", re.VERBOSE)
_legacy_version_replacement_map = {
"pre": "c",
"preview": "c",
"-": "final-",
"rc": "c",
"dev": "@",
}
def _parse_version_parts(s: str) -> Iterator[str]:
for part in _legacy_version_component_re.split(s):
part = _legacy_version_replacement_map.get(part, part)
if not part or part == ".":
continue
if part[:1] in "0123456789":
# pad for numeric comparison
yield part.zfill(8)
else:
yield "*" + part
# ensure that alpha/beta/candidate are before final
yield "*final"
def _legacy_cmpkey(version: str) -> LegacyCmpKey:
# We hardcode an epoch of -1 here. A PEP 440 version can only have a epoch
# greater than or equal to 0. This will effectively put the LegacyVersion,
# which uses the defacto standard originally implemented by setuptools,
# as before all PEP 440 versions.
epoch = -1
# This scheme is taken from pkg_resources.parse_version setuptools prior to
# it's adoption of the packaging library.
parts: List[str] = []
for part in _parse_version_parts(version.lower()):
if part.startswith("*"):
# remove "-" before a prerelease tag
if part < "*final":
while parts and parts[-1] == "*final-":
parts.pop()
# remove trailing zeros from each series of numeric parts
while parts and parts[-1] == "00000000":
parts.pop()
parts.append(part)
return epoch, tuple(parts)
# Deliberately not anchored to the start and end of the string, to make it
# easier for 3rd party code to reuse
VERSION_PATTERN = r"""
v?
(?:
(?:(?P<epoch>[0-9]+)!)? # epoch
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
(?P<pre> # pre-release
[-_\.]?
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
[-_\.]?
(?P<pre_n>[0-9]+)?
)?
(?P<post> # post release
(?:-(?P<post_n1>[0-9]+))
|
(?:
[-_\.]?
(?P<post_l>post|rev|r)
[-_\.]?
(?P<post_n2>[0-9]+)?
)
)?
(?P<dev> # dev release
[-_\.]?
(?P<dev_l>dev)
[-_\.]?
(?P<dev_n>[0-9]+)?
)?
)
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
"""
class Version(_BaseVersion):
_regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
def __init__(self, version: str) -> None:
# Validate the version and parse it into pieces
match = self._regex.search(version)
if not match:
raise InvalidVersion(f"Invalid version: '{version}'")
# Store the parsed out pieces of the version
self._version = _Version(
epoch=int(match.group("epoch")) if match.group("epoch") else 0,
release=tuple(int(i) for i in match.group("release").split(".")),
pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
post=_parse_letter_version(
match.group("post_l"), match.group("post_n1") or match.group("post_n2")
),
dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
local=_parse_local_version(match.group("local")),
)
# Generate a key which will be used for sorting
self._key = _cmpkey(
self._version.epoch,
self._version.release,
self._version.pre,
self._version.post,
self._version.dev,
self._version.local,
)
def __repr__(self) -> str:
return f"<Version('{self}')>"
def __str__(self) -> str:
parts = []
# Epoch
if self.epoch != 0:
parts.append(f"{self.epoch}!")
# Release segment
parts.append(".".join(str(x) for x in self.release))
# Pre-release
if self.pre is not None:
parts.append("".join(str(x) for x in self.pre))
# Post-release
if self.post is not None:
parts.append(f".post{self.post}")
# Development release
if self.dev is not None:
parts.append(f".dev{self.dev}")
# Local version segment
if self.local is not None:
parts.append(f"+{self.local}")
return "".join(parts)
@property
def epoch(self) -> int:
_epoch: int = self._version.epoch
return _epoch
@property
def release(self) -> Tuple[int, ...]:
_release: Tuple[int, ...] = self._version.release
return _release
@property
def pre(self) -> Optional[Tuple[str, int]]:
_pre: Optional[Tuple[str, int]] = self._version.pre
return _pre
@property
def post(self) -> Optional[int]:
return self._version.post[1] if self._version.post else None
@property
def dev(self) -> Optional[int]:
return self._version.dev[1] if self._version.dev else None
@property
def local(self) -> Optional[str]:
if self._version.local:
return ".".join(str(x) for x in self._version.local)
else:
return None
@property
def public(self) -> str:
return str(self).split("+", 1)[0]
@property
def base_version(self) -> str:
parts = []
# Epoch
if self.epoch != 0:
parts.append(f"{self.epoch}!")
# Release segment
parts.append(".".join(str(x) for x in self.release))
return "".join(parts)
@property
def is_prerelease(self) -> bool:
return self.dev is not None or self.pre is not None
@property
def is_postrelease(self) -> bool:
return self.post is not None
@property
def is_devrelease(self) -> bool:
return self.dev is not None
@property
def major(self) -> int:
return self.release[0] if len(self.release) >= 1 else 0
@property
def minor(self) -> int:
return self.release[1] if len(self.release) >= 2 else 0
@property
def micro(self) -> int:
return self.release[2] if len(self.release) >= 3 else 0
def _parse_letter_version(
letter: str, number: Union[str, bytes, SupportsInt]
) -> Optional[Tuple[str, int]]:
if letter:
# We consider there to be an implicit 0 in a pre-release if there is
# not a numeral associated with it.
if number is None:
number = 0
# We normalize any letters to their lower case form
letter = letter.lower()
# We consider some words to be alternate spellings of other words and
# in those cases we want to normalize the spellings to our preferred
# spelling.
if letter == "alpha":
letter = "a"
elif letter == "beta":
letter = "b"
elif letter in ["c", "pre", "preview"]:
letter = "rc"
elif letter in ["rev", "r"]:
letter = "post"
return letter, int(number)
if not letter and number:
# We assume if we are given a number, but we are not given a letter
# then this is using the implicit post release syntax (e.g. 1.0-1)
letter = "post"
return letter, int(number)
return None
_local_version_separators = re.compile(r"[\._-]")
def _parse_local_version(local: str) -> Optional[LocalType]:
"""
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
"""
if local is not None:
return tuple(
part.lower() if not part.isdigit() else int(part)
for part in _local_version_separators.split(local)
)
return None
def _cmpkey(
epoch: int,
release: Tuple[int, ...],
pre: Optional[Tuple[str, int]],
post: Optional[Tuple[str, int]],
dev: Optional[Tuple[str, int]],
local: Optional[Tuple[SubLocalType]],
) -> CmpKey:
# When we compare a release version, we want to compare it with all of the
# trailing zeros removed. So we'll use a reverse the list, drop all the now
# leading zeros until we come to something non zero, then take the rest
# re-reverse it back into the correct order and make it a tuple and use
# that for our sorting key.
_release = tuple(
reversed(list(itertools.dropwhile(lambda x: x == 0, reversed(release))))
)
# We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
# We'll do this by abusing the pre segment, but we _only_ want to do this
# if there is not a pre or a post segment. If we have one of those then
# the normal sorting rules will handle this case correctly.
if pre is None and post is None and dev is not None:
_pre: PrePostDevType = NegativeInfinity
# Versions without a pre-release (except as noted above) should sort after
# those with one.
elif pre is None:
_pre = Infinity
else:
_pre = pre
# Versions without a post segment should sort before those with one.
if post is None:
_post: PrePostDevType = NegativeInfinity
else:
_post = post
# Versions without a development segment should sort after those with one.
if dev is None:
_dev: PrePostDevType = Infinity
else:
_dev = dev
if local is None:
# Versions without a local segment should sort before those with one.
_local: LocalType = NegativeInfinity
else:
# Versions with a local segment need that segment parsed to implement
# the sorting rules in PEP440.
# - Alpha numeric segments sort before numeric segments
# - Alpha numeric segments sort lexicographically
# - Numeric segments sort numerically
# - Shorter versions sort before longer versions when the prefixes
# match exactly
_local = tuple(
(i, "") if isinstance(i, int) else (NegativeInfinity, i) for i in local
)
return epoch, _release, _pre, _post, _dev, _local

View File

@@ -0,0 +1 @@
from ._laplacian import laplacian

View File

@@ -0,0 +1,557 @@
"""
This file is a copy of the scipy.sparse.csgraph._laplacian module from SciPy 1.12
scipy.sparse.csgraph.laplacian supports sparse arrays only starting from Scipy 1.12,
see https://github.com/scipy/scipy/pull/19156. This vendored file can be removed as
soon as Scipy 1.12 becomes the minimum supported version.
Laplacian of a compressed-sparse graph
"""
# SPDX-License-Identifier: BSD-3-Clause
import numpy as np
from scipy.sparse import issparse
from scipy.sparse.linalg import LinearOperator
###############################################################################
# Graph laplacian
def laplacian(
csgraph,
normed=False,
return_diag=False,
use_out_degree=False,
*,
copy=True,
form="array",
dtype=None,
symmetrized=False,
):
"""
Return the Laplacian of a directed graph.
Parameters
----------
csgraph : array_like or sparse matrix, 2 dimensions
Compressed-sparse graph, with shape (N, N).
normed : bool, optional
If True, then compute symmetrically normalized Laplacian.
Default: False.
return_diag : bool, optional
If True, then also return an array related to vertex degrees.
Default: False.
use_out_degree : bool, optional
If True, then use out-degree instead of in-degree.
This distinction matters only if the graph is asymmetric.
Default: False.
copy : bool, optional
If False, then change `csgraph` in place if possible,
avoiding doubling the memory use.
Default: True, for backward compatibility.
form : 'array', or 'function', or 'lo'
Determines the format of the output Laplacian:
* 'array' is a numpy array;
* 'function' is a pointer to evaluating the Laplacian-vector
or Laplacian-matrix product;
* 'lo' results in the format of the `LinearOperator`.
Choosing 'function' or 'lo' always avoids doubling
the memory use, ignoring `copy` value.
Default: 'array', for backward compatibility.
dtype : None or one of numeric numpy dtypes, optional
The dtype of the output. If ``dtype=None``, the dtype of the
output matches the dtype of the input csgraph, except for
the case ``normed=True`` and integer-like csgraph, where
the output dtype is 'float' allowing accurate normalization,
but dramatically increasing the memory use.
Default: None, for backward compatibility.
symmetrized : bool, optional
If True, then the output Laplacian is symmetric/Hermitian.
The symmetrization is done by ``csgraph + csgraph.T.conj``
without dividing by 2 to preserve integer dtypes if possible
prior to the construction of the Laplacian.
The symmetrization will increase the memory footprint of
sparse matrices unless the sparsity pattern is symmetric or
`form` is 'function' or 'lo'.
Default: False, for backward compatibility.
Returns
-------
lap : ndarray, or sparse matrix, or `LinearOperator`
The N x N Laplacian of csgraph. It will be a NumPy array (dense)
if the input was dense, or a sparse matrix otherwise, or
the format of a function or `LinearOperator` if
`form` equals 'function' or 'lo', respectively.
diag : ndarray, optional
The length-N main diagonal of the Laplacian matrix.
For the normalized Laplacian, this is the array of square roots
of vertex degrees or 1 if the degree is zero.
Notes
-----
The Laplacian matrix of a graph is sometimes referred to as the
"Kirchhoff matrix" or just the "Laplacian", and is useful in many
parts of spectral graph theory.
In particular, the eigen-decomposition of the Laplacian can give
insight into many properties of the graph, e.g.,
is commonly used for spectral data embedding and clustering.
The constructed Laplacian doubles the memory use if ``copy=True`` and
``form="array"`` which is the default.
Choosing ``copy=False`` has no effect unless ``form="array"``
or the matrix is sparse in the ``coo`` format, or dense array, except
for the integer input with ``normed=True`` that forces the float output.
Sparse input is reformatted into ``coo`` if ``form="array"``,
which is the default.
If the input adjacency matrix is not symmetric, the Laplacian is
also non-symmetric unless ``symmetrized=True`` is used.
Diagonal entries of the input adjacency matrix are ignored and
replaced with zeros for the purpose of normalization where ``normed=True``.
The normalization uses the inverse square roots of row-sums of the input
adjacency matrix, and thus may fail if the row-sums contain
negative or complex with a non-zero imaginary part values.
The normalization is symmetric, making the normalized Laplacian also
symmetric if the input csgraph was symmetric.
References
----------
.. [1] Laplacian matrix. https://en.wikipedia.org/wiki/Laplacian_matrix
Examples
--------
>>> import numpy as np
>>> from scipy.sparse import csgraph
Our first illustration is the symmetric graph
>>> G = np.arange(4) * np.arange(4)[:, np.newaxis]
>>> G
array([[0, 0, 0, 0],
[0, 1, 2, 3],
[0, 2, 4, 6],
[0, 3, 6, 9]])
and its symmetric Laplacian matrix
>>> csgraph.laplacian(G)
array([[ 0, 0, 0, 0],
[ 0, 5, -2, -3],
[ 0, -2, 8, -6],
[ 0, -3, -6, 9]])
The non-symmetric graph
>>> G = np.arange(9).reshape(3, 3)
>>> G
array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
has different row- and column sums, resulting in two varieties
of the Laplacian matrix, using an in-degree, which is the default
>>> L_in_degree = csgraph.laplacian(G)
>>> L_in_degree
array([[ 9, -1, -2],
[-3, 8, -5],
[-6, -7, 7]])
or alternatively an out-degree
>>> L_out_degree = csgraph.laplacian(G, use_out_degree=True)
>>> L_out_degree
array([[ 3, -1, -2],
[-3, 8, -5],
[-6, -7, 13]])
Constructing a symmetric Laplacian matrix, one can add the two as
>>> L_in_degree + L_out_degree.T
array([[ 12, -4, -8],
[ -4, 16, -12],
[ -8, -12, 20]])
or use the ``symmetrized=True`` option
>>> csgraph.laplacian(G, symmetrized=True)
array([[ 12, -4, -8],
[ -4, 16, -12],
[ -8, -12, 20]])
that is equivalent to symmetrizing the original graph
>>> csgraph.laplacian(G + G.T)
array([[ 12, -4, -8],
[ -4, 16, -12],
[ -8, -12, 20]])
The goal of normalization is to make the non-zero diagonal entries
of the Laplacian matrix to be all unit, also scaling off-diagonal
entries correspondingly. The normalization can be done manually, e.g.,
>>> G = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
>>> L, d = csgraph.laplacian(G, return_diag=True)
>>> L
array([[ 2, -1, -1],
[-1, 2, -1],
[-1, -1, 2]])
>>> d
array([2, 2, 2])
>>> scaling = np.sqrt(d)
>>> scaling
array([1.41421356, 1.41421356, 1.41421356])
>>> (1/scaling)*L*(1/scaling)
array([[ 1. , -0.5, -0.5],
[-0.5, 1. , -0.5],
[-0.5, -0.5, 1. ]])
Or using ``normed=True`` option
>>> L, d = csgraph.laplacian(G, return_diag=True, normed=True)
>>> L
array([[ 1. , -0.5, -0.5],
[-0.5, 1. , -0.5],
[-0.5, -0.5, 1. ]])
which now instead of the diagonal returns the scaling coefficients
>>> d
array([1.41421356, 1.41421356, 1.41421356])
Zero scaling coefficients are substituted with 1s, where scaling
has thus no effect, e.g.,
>>> G = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0]])
>>> G
array([[0, 0, 0],
[0, 0, 1],
[0, 1, 0]])
>>> L, d = csgraph.laplacian(G, return_diag=True, normed=True)
>>> L
array([[ 0., -0., -0.],
[-0., 1., -1.],
[-0., -1., 1.]])
>>> d
array([1., 1., 1.])
Only the symmetric normalization is implemented, resulting
in a symmetric Laplacian matrix if and only if its graph is symmetric
and has all non-negative degrees, like in the examples above.
The output Laplacian matrix is by default a dense array or a sparse matrix
inferring its shape, format, and dtype from the input graph matrix:
>>> G = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]).astype(np.float32)
>>> G
array([[0., 1., 1.],
[1., 0., 1.],
[1., 1., 0.]], dtype=float32)
>>> csgraph.laplacian(G)
array([[ 2., -1., -1.],
[-1., 2., -1.],
[-1., -1., 2.]], dtype=float32)
but can alternatively be generated matrix-free as a LinearOperator:
>>> L = csgraph.laplacian(G, form="lo")
>>> L
<3x3 _CustomLinearOperator with dtype=float32>
>>> L(np.eye(3))
array([[ 2., -1., -1.],
[-1., 2., -1.],
[-1., -1., 2.]])
or as a lambda-function:
>>> L = csgraph.laplacian(G, form="function")
>>> L
<function _laplace.<locals>.<lambda> at 0x0000012AE6F5A598>
>>> L(np.eye(3))
array([[ 2., -1., -1.],
[-1., 2., -1.],
[-1., -1., 2.]])
The Laplacian matrix is used for
spectral data clustering and embedding
as well as for spectral graph partitioning.
Our final example illustrates the latter
for a noisy directed linear graph.
>>> from scipy.sparse import diags, random
>>> from scipy.sparse.linalg import lobpcg
Create a directed linear graph with ``N=35`` vertices
using a sparse adjacency matrix ``G``:
>>> N = 35
>>> G = diags(np.ones(N-1), 1, format="csr")
Fix a random seed ``rng`` and add a random sparse noise to the graph ``G``:
>>> rng = np.random.default_rng()
>>> G += 1e-2 * random(N, N, density=0.1, random_state=rng)
Set initial approximations for eigenvectors:
>>> X = rng.random((N, 2))
The constant vector of ones is always a trivial eigenvector
of the non-normalized Laplacian to be filtered out:
>>> Y = np.ones((N, 1))
Alternating (1) the sign of the graph weights allows determining
labels for spectral max- and min- cuts in a single loop.
Since the graph is undirected, the option ``symmetrized=True``
must be used in the construction of the Laplacian.
The option ``normed=True`` cannot be used in (2) for the negative weights
here as the symmetric normalization evaluates square roots.
The option ``form="lo"`` in (2) is matrix-free, i.e., guarantees
a fixed memory footprint and read-only access to the graph.
Calling the eigenvalue solver ``lobpcg`` (3) computes the Fiedler vector
that determines the labels as the signs of its components in (5).
Since the sign in an eigenvector is not deterministic and can flip,
we fix the sign of the first component to be always +1 in (4).
>>> for cut in ["max", "min"]:
... G = -G # 1.
... L = csgraph.laplacian(G, symmetrized=True, form="lo") # 2.
... _, eves = lobpcg(L, X, Y=Y, largest=False, tol=1e-3) # 3.
... eves *= np.sign(eves[0, 0]) # 4.
... print(cut + "-cut labels:\\n", 1 * (eves[:, 0]>0)) # 5.
max-cut labels:
[1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1]
min-cut labels:
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
As anticipated for a (slightly noisy) linear graph,
the max-cut strips all the edges of the graph coloring all
odd vertices into one color and all even vertices into another one,
while the balanced min-cut partitions the graph
in the middle by deleting a single edge.
Both determined partitions are optimal.
"""
if csgraph.ndim != 2 or csgraph.shape[0] != csgraph.shape[1]:
raise ValueError("csgraph must be a square matrix or array")
if normed and (
np.issubdtype(csgraph.dtype, np.signedinteger)
or np.issubdtype(csgraph.dtype, np.uint)
):
csgraph = csgraph.astype(np.float64)
if form == "array":
create_lap = _laplacian_sparse if issparse(csgraph) else _laplacian_dense
else:
create_lap = (
_laplacian_sparse_flo if issparse(csgraph) else _laplacian_dense_flo
)
degree_axis = 1 if use_out_degree else 0
lap, d = create_lap(
csgraph,
normed=normed,
axis=degree_axis,
copy=copy,
form=form,
dtype=dtype,
symmetrized=symmetrized,
)
if return_diag:
return lap, d
return lap
def _setdiag_dense(m, d):
step = len(d) + 1
m.flat[::step] = d
def _laplace(m, d):
return lambda v: v * d[:, np.newaxis] - m @ v
def _laplace_normed(m, d, nd):
laplace = _laplace(m, d)
return lambda v: nd[:, np.newaxis] * laplace(v * nd[:, np.newaxis])
def _laplace_sym(m, d):
return (
lambda v: v * d[:, np.newaxis]
- m @ v
- np.transpose(np.conjugate(np.transpose(np.conjugate(v)) @ m))
)
def _laplace_normed_sym(m, d, nd):
laplace_sym = _laplace_sym(m, d)
return lambda v: nd[:, np.newaxis] * laplace_sym(v * nd[:, np.newaxis])
def _linearoperator(mv, shape, dtype):
return LinearOperator(matvec=mv, matmat=mv, shape=shape, dtype=dtype)
def _laplacian_sparse_flo(graph, normed, axis, copy, form, dtype, symmetrized):
# The keyword argument `copy` is unused and has no effect here.
del copy
if dtype is None:
dtype = graph.dtype
graph_sum = np.asarray(graph.sum(axis=axis)).ravel()
graph_diagonal = graph.diagonal()
diag = graph_sum - graph_diagonal
if symmetrized:
graph_sum += np.asarray(graph.sum(axis=1 - axis)).ravel()
diag = graph_sum - graph_diagonal - graph_diagonal
if normed:
isolated_node_mask = diag == 0
w = np.where(isolated_node_mask, 1, np.sqrt(diag))
if symmetrized:
md = _laplace_normed_sym(graph, graph_sum, 1.0 / w)
else:
md = _laplace_normed(graph, graph_sum, 1.0 / w)
if form == "function":
return md, w.astype(dtype, copy=False)
elif form == "lo":
m = _linearoperator(md, shape=graph.shape, dtype=dtype)
return m, w.astype(dtype, copy=False)
else:
raise ValueError(f"Invalid form: {form!r}")
else:
if symmetrized:
md = _laplace_sym(graph, graph_sum)
else:
md = _laplace(graph, graph_sum)
if form == "function":
return md, diag.astype(dtype, copy=False)
elif form == "lo":
m = _linearoperator(md, shape=graph.shape, dtype=dtype)
return m, diag.astype(dtype, copy=False)
else:
raise ValueError(f"Invalid form: {form!r}")
def _laplacian_sparse(graph, normed, axis, copy, form, dtype, symmetrized):
# The keyword argument `form` is unused and has no effect here.
del form
if dtype is None:
dtype = graph.dtype
needs_copy = False
if graph.format in ("lil", "dok"):
m = graph.tocoo()
else:
m = graph
if copy:
needs_copy = True
if symmetrized:
m += m.T.conj()
w = np.asarray(m.sum(axis=axis)).ravel() - m.diagonal()
if normed:
m = m.tocoo(copy=needs_copy)
isolated_node_mask = w == 0
w = np.where(isolated_node_mask, 1, np.sqrt(w))
m.data /= w[m.row]
m.data /= w[m.col]
m.data *= -1
m.setdiag(1 - isolated_node_mask)
else:
if m.format == "dia":
m = m.copy()
else:
m = m.tocoo(copy=needs_copy)
m.data *= -1
m.setdiag(w)
return m.astype(dtype, copy=False), w.astype(dtype)
def _laplacian_dense_flo(graph, normed, axis, copy, form, dtype, symmetrized):
if copy:
m = np.array(graph)
else:
m = np.asarray(graph)
if dtype is None:
dtype = m.dtype
graph_sum = m.sum(axis=axis)
graph_diagonal = m.diagonal()
diag = graph_sum - graph_diagonal
if symmetrized:
graph_sum += m.sum(axis=1 - axis)
diag = graph_sum - graph_diagonal - graph_diagonal
if normed:
isolated_node_mask = diag == 0
w = np.where(isolated_node_mask, 1, np.sqrt(diag))
if symmetrized:
md = _laplace_normed_sym(m, graph_sum, 1.0 / w)
else:
md = _laplace_normed(m, graph_sum, 1.0 / w)
if form == "function":
return md, w.astype(dtype, copy=False)
elif form == "lo":
m = _linearoperator(md, shape=graph.shape, dtype=dtype)
return m, w.astype(dtype, copy=False)
else:
raise ValueError(f"Invalid form: {form!r}")
else:
if symmetrized:
md = _laplace_sym(m, graph_sum)
else:
md = _laplace(m, graph_sum)
if form == "function":
return md, diag.astype(dtype, copy=False)
elif form == "lo":
m = _linearoperator(md, shape=graph.shape, dtype=dtype)
return m, diag.astype(dtype, copy=False)
else:
raise ValueError(f"Invalid form: {form!r}")
def _laplacian_dense(graph, normed, axis, copy, form, dtype, symmetrized):
if form != "array":
raise ValueError(f'{form!r} must be "array"')
if dtype is None:
dtype = graph.dtype
if copy:
m = np.array(graph)
else:
m = np.asarray(graph)
if dtype is None:
dtype = m.dtype
if symmetrized:
m += m.T.conj()
np.fill_diagonal(m, 0)
w = m.sum(axis=axis)
if normed:
isolated_node_mask = w == 0
w = np.where(isolated_node_mask, 1, np.sqrt(w))
m /= w
m /= w[:, np.newaxis]
m *= -1
_setdiag_dense(m, 1 - isolated_node_mask)
else:
m *= -1
_setdiag_dense(m, w)
return m.astype(dtype, copy=False), w.astype(dtype, copy=False)

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 Consortium for Python Data API Standards
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1 @@
Update this directory using maint_tools/vendor_array_api_compat.sh

View File

@@ -0,0 +1,22 @@
"""
NumPy Array API compatibility library
This is a small wrapper around NumPy, CuPy, JAX, sparse and others that are
compatible with the Array API standard https://data-apis.org/array-api/latest/.
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
Unlike array_api_strict, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
changes needed to be compliant with the Array API. See
https://numpy.org/doc/stable/reference/array_api.html for a full list of
changes. In particular, unlike array_api_strict, this package does not use a
separate Array object, but rather just uses numpy.ndarray directly.
Library authors using the Array API may wish to test against array_api_strict
to ensure they are not using functionality outside of the standard, but prefer
this implementation for the default when working with NumPy arrays.
"""
__version__ = '1.12.0'
from .common import * # noqa: F401, F403

View File

@@ -0,0 +1,59 @@
"""
Internal helpers
"""
from collections.abc import Callable
from functools import wraps
from inspect import signature
from types import ModuleType
from typing import TypeVar
_T = TypeVar("_T")
def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
"""
Decorator to automatically replace xp with the corresponding array module.
Use like
import numpy as np
@get_xp(np)
def func(x, /, xp, kwarg=None):
return xp.func(x, kwarg=kwarg)
Note that xp must be a keyword argument and come after all non-keyword
arguments.
"""
def inner(f: Callable[..., _T], /) -> Callable[..., _T]:
@wraps(f)
def wrapped_f(*args: object, **kwargs: object) -> object:
return f(*args, xp=xp, **kwargs)
sig = signature(f)
new_sig = sig.replace(
parameters=[par for i, par in sig.parameters.items() if i != "xp"]
)
if wrapped_f.__doc__ is None:
wrapped_f.__doc__ = f"""\
Array API compatibility wrapper for {f.__name__}.
See the corresponding documentation in NumPy/CuPy and/or the array API
specification for more details.
"""
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
return wrapped_f # pyright: ignore[reportReturnType]
return inner
__all__ = ["get_xp"]
def __dir__() -> list[str]:
return __all__

View File

@@ -0,0 +1 @@
from ._helpers import * # noqa: F403

View File

@@ -0,0 +1,727 @@
"""
These are functions that are just aliases of existing functions in NumPy.
"""
from __future__ import annotations
import inspect
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
from ._helpers import _check_device, array_namespace
from ._helpers import device as _get_device
from ._helpers import is_cupy_namespace as _is_cupy_namespace
from ._typing import Array, Device, DType, Namespace
if TYPE_CHECKING:
# TODO: import from typing (requires Python >=3.13)
from typing_extensions import TypeIs
# These functions are modified from the NumPy versions.
# Creation functions add the device keyword (which does nothing for NumPy and Dask)
def arange(
start: float,
/,
stop: float | None = None,
step: float = 1,
*,
xp: Namespace,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
def empty(
shape: int | tuple[int, ...],
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.empty(shape, dtype=dtype, **kwargs)
def empty_like(
x: Array,
/,
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.empty_like(x, dtype=dtype, **kwargs)
def eye(
n_rows: int,
n_cols: int | None = None,
/,
*,
xp: Namespace,
k: int = 0,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
def full(
shape: int | tuple[int, ...],
fill_value: complex,
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
def full_like(
x: Array,
/,
fill_value: complex,
*,
xp: Namespace,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
def linspace(
start: float,
stop: float,
/,
num: int,
*,
xp: Namespace,
dtype: DType | None = None,
device: Device | None = None,
endpoint: bool = True,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
def ones(
shape: int | tuple[int, ...],
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.ones(shape, dtype=dtype, **kwargs)
def ones_like(
x: Array,
/,
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.ones_like(x, dtype=dtype, **kwargs)
def zeros(
shape: int | tuple[int, ...],
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.zeros(shape, dtype=dtype, **kwargs)
def zeros_like(
x: Array,
/,
xp: Namespace,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.zeros_like(x, dtype=dtype, **kwargs)
# np.unique() is split into four functions in the array API:
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
# to remove polymorphic return types).
# The functions here return namedtuples (np.unique() returns a normal
# tuple).
# Note that these named tuples aren't actually part of the standard namespace,
# but I don't see any issue with exporting the names here regardless.
class UniqueAllResult(NamedTuple):
values: Array
indices: Array
inverse_indices: Array
counts: Array
class UniqueCountsResult(NamedTuple):
values: Array
counts: Array
class UniqueInverseResult(NamedTuple):
values: Array
inverse_indices: Array
def _unique_kwargs(xp: Namespace) -> dict[str, bool]:
# Older versions of NumPy and CuPy do not have equal_nan. Rather than
# trying to parse version numbers, just check if equal_nan is in the
# signature.
s = inspect.signature(xp.unique)
if "equal_nan" in s.parameters:
return {"equal_nan": False}
return {}
def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult:
kwargs = _unique_kwargs(xp)
values, indices, inverse_indices, counts = xp.unique(
x,
return_counts=True,
return_index=True,
return_inverse=True,
**kwargs,
)
# np.unique() flattens inverse indices, but they need to share x's shape
# See https://github.com/numpy/numpy/issues/20638
inverse_indices = inverse_indices.reshape(x.shape)
return UniqueAllResult(
values,
indices,
inverse_indices,
counts,
)
def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult:
kwargs = _unique_kwargs(xp)
res = xp.unique(
x, return_counts=True, return_index=False, return_inverse=False, **kwargs
)
return UniqueCountsResult(*res)
def unique_inverse(x: Array, /, xp: Namespace) -> UniqueInverseResult:
kwargs = _unique_kwargs(xp)
values, inverse_indices = xp.unique(
x,
return_counts=False,
return_index=False,
return_inverse=True,
**kwargs,
)
# xp.unique() flattens inverse indices, but they need to share x's shape
# See https://github.com/numpy/numpy/issues/20638
inverse_indices = inverse_indices.reshape(x.shape)
return UniqueInverseResult(values, inverse_indices)
def unique_values(x: Array, /, xp: Namespace) -> Array:
kwargs = _unique_kwargs(xp)
return xp.unique(
x,
return_counts=False,
return_index=False,
return_inverse=False,
**kwargs,
)
# These functions have different keyword argument names
def std(
x: Array,
/,
xp: Namespace,
*,
axis: int | tuple[int, ...] | None = None,
correction: float = 0.0, # correction instead of ddof
keepdims: bool = False,
**kwargs: object,
) -> Array:
return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
def var(
x: Array,
/,
xp: Namespace,
*,
axis: int | tuple[int, ...] | None = None,
correction: float = 0.0, # correction instead of ddof
keepdims: bool = False,
**kwargs: object,
) -> Array:
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
# cumulative_sum is renamed from cumsum, and adds the include_initial keyword
# argument
def cumulative_sum(
x: Array,
/,
xp: Namespace,
*,
axis: int | None = None,
dtype: DType | None = None,
include_initial: bool = False,
**kwargs: object,
) -> Array:
wrapped_xp = array_namespace(x)
# TODO: The standard is not clear about what should happen when x.ndim == 0.
if axis is None:
if x.ndim > 1:
raise ValueError(
"axis must be specified in cumulative_sum for more than one dimension"
)
axis = 0
res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
# np.cumsum does not support include_initial
if include_initial:
initial_shape = list(x.shape)
initial_shape[axis] = 1
res = xp.concatenate(
[
wrapped_xp.zeros(
shape=initial_shape, dtype=res.dtype, device=_get_device(res)
),
res,
],
axis=axis,
)
return res
def cumulative_prod(
x: Array,
/,
xp: Namespace,
*,
axis: int | None = None,
dtype: DType | None = None,
include_initial: bool = False,
**kwargs: object,
) -> Array:
wrapped_xp = array_namespace(x)
if axis is None:
if x.ndim > 1:
raise ValueError(
"axis must be specified in cumulative_prod for more than one dimension"
)
axis = 0
res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
# np.cumprod does not support include_initial
if include_initial:
initial_shape = list(x.shape)
initial_shape[axis] = 1
res = xp.concatenate(
[
wrapped_xp.ones(
shape=initial_shape, dtype=res.dtype, device=_get_device(res)
),
res,
],
axis=axis,
)
return res
# The min and max argument names in clip are different and not optional in numpy, and type
# promotion behavior is different.
def clip(
x: Array,
/,
min: float | Array | None = None,
max: float | Array | None = None,
*,
xp: Namespace,
# TODO: np.clip has other ufunc kwargs
out: Array | None = None,
) -> Array:
def _isscalar(a: object) -> TypeIs[int | float | None]:
return isinstance(a, (int, float, type(None)))
min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape
wrapped_xp = array_namespace(x)
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
# np.clip does type promotion but the array API clip requires that the
# output have the same dtype as x. We do this instead of just downcasting
# the result of xp.clip() to handle some corner cases better (e.g.,
# avoiding uint64 -> float64 promotion).
# Note: cases where min or max overflow (integer) or round (float) in the
# wrong direction when downcasting to x.dtype are unspecified. This code
# just does whatever NumPy does when it downcasts in the assignment, but
# other behavior could be preferred, especially for integers. For example,
# this code produces:
# >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
# -128
# but an answer of 0 might be preferred. See
# https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
# At least handle the case of Python integers correctly (see
# https://github.com/numpy/numpy/pull/26892).
if wrapped_xp.isdtype(x.dtype, "integral"):
if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
min = None
if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
max = None
dev = _get_device(x)
if out is None:
out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
assert out is not None # workaround for a type-narrowing issue in pyright
out[()] = x
if min is not None:
a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev)
a = xp.broadcast_to(a, result_shape)
ia = (out < a) | xp.isnan(a)
out[ia] = a[ia]
if max is not None:
b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev)
b = xp.broadcast_to(b, result_shape)
ib = (out > b) | xp.isnan(b)
out[ib] = b[ib]
# Return a scalar for 0-D
return out[()]
# Unlike transpose(), the axes argument to permute_dims() is required.
def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array:
return xp.transpose(x, axes)
# np.reshape calls the keyword argument 'newshape' instead of 'shape'
def reshape(
x: Array,
/,
shape: tuple[int, ...],
xp: Namespace,
*,
copy: Optional[bool] = None,
**kwargs: object,
) -> Array:
if copy is True:
x = x.copy()
elif copy is False:
y = x.view()
y.shape = shape
return y
return xp.reshape(x, shape, **kwargs)
# The descending keyword is new in sort and argsort, and 'kind' replaced with
# 'stable'
def argsort(
x: Array,
/,
xp: Namespace,
*,
axis: int = -1,
descending: bool = False,
stable: bool = True,
**kwargs: object,
) -> Array:
# Note: this keyword argument is different, and the default is different.
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
# as the default whereas cupy.sort uses kind=None.
if stable:
kwargs["kind"] = "stable"
if not descending:
res = xp.argsort(x, axis=axis, **kwargs)
else:
# As NumPy has no native descending sort, we imitate it here. Note that
# simply flipping the results of xp.argsort(x, ...) would not
# respect the relative order like it would in native descending sorts.
res = xp.flip(
xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs),
axis=axis,
)
# Rely on flip()/argsort() to validate axis
normalised_axis = axis if axis >= 0 else x.ndim + axis
max_i = x.shape[normalised_axis] - 1
res = max_i - res
return res
def sort(
x: Array,
/,
xp: Namespace,
*,
axis: int = -1,
descending: bool = False,
stable: bool = True,
**kwargs: object,
) -> Array:
# Note: this keyword argument is different, and the default is different.
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
# as the default whereas cupy.sort uses kind=None.
if stable:
kwargs["kind"] = "stable"
res = xp.sort(x, axis=axis, **kwargs)
if descending:
res = xp.flip(res, axis=axis)
return res
# nonzero should error for zero-dimensional arrays
def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
if x.ndim == 0:
raise ValueError("nonzero() does not support zero-dimensional arrays")
return xp.nonzero(x, **kwargs)
# ceil, floor, and trunc return integers for integer inputs
def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.ceil(x, **kwargs)
def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.floor(x, **kwargs)
def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.trunc(x, **kwargs)
# linear algebra functions
def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
return xp.matmul(x1, x2, **kwargs)
# Unlike transpose, matrix_transpose only transposes the last two axes.
def matrix_transpose(x: Array, /, xp: Namespace) -> Array:
if x.ndim < 2:
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
return xp.swapaxes(x, -1, -2)
def tensordot(
x1: Array,
x2: Array,
/,
xp: Namespace,
*,
axes: int | tuple[Sequence[int], Sequence[int]] = 2,
**kwargs: object,
) -> Array:
return xp.tensordot(x1, x2, axes=axes, **kwargs)
def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array:
if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")
if hasattr(xp, "broadcast_tensors"):
_broadcast = xp.broadcast_tensors
else:
_broadcast = xp.broadcast_arrays
x1_ = xp.moveaxis(x1, axis, -1)
x2_ = xp.moveaxis(x2, axis, -1)
x1_, x2_ = _broadcast(x1_, x2_)
res = xp.conj(x1_[..., None, :]) @ x2_[..., None]
return res[..., 0, 0]
# isdtype is a new function in the 2022.12 array API specification.
def isdtype(
dtype: DType,
kind: DType | str | tuple[DType | str, ...],
xp: Namespace,
*,
_tuple: bool = True, # Disallow nested tuples
) -> bool:
"""
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
Note that outside of this function, this compat library does not yet fully
support complex numbers.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
for more details
"""
if isinstance(kind, tuple) and _tuple:
return any(
isdtype(dtype, k, xp, _tuple=False)
for k in cast("tuple[DType | str, ...]", kind)
)
elif isinstance(kind, str):
if kind == "bool":
return dtype == xp.bool_
elif kind == "signed integer":
return xp.issubdtype(dtype, xp.signedinteger)
elif kind == "unsigned integer":
return xp.issubdtype(dtype, xp.unsignedinteger)
elif kind == "integral":
return xp.issubdtype(dtype, xp.integer)
elif kind == "real floating":
return xp.issubdtype(dtype, xp.floating)
elif kind == "complex floating":
return xp.issubdtype(dtype, xp.complexfloating)
elif kind == "numeric":
return xp.issubdtype(dtype, xp.number)
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
else:
# This will allow things that aren't required by the spec, like
# isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be
# more strict here to match the type annotation? Note that the
# array_api_strict implementation will be very strict.
return dtype == kind
# unstack is a new function in the 2023.12 array API standard
def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]:
if x.ndim == 0:
raise ValueError("Input array must be at least 1-d.")
return tuple(xp.moveaxis(x, axis, 0))
# numpy 1.26 does not use the standard definition for sign on complex numbers
def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
if isdtype(x.dtype, "complex floating", xp=xp):
out = (x / xp.abs(x, **kwargs))[...]
# sign(0) = 0 but the above formula would give nan
out[x == 0j] = 0j
else:
out = xp.sign(x, **kwargs)
# CuPy sign() does not propagate nans. See
# https://github.com/data-apis/array-api-compat/issues/136
if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
out[xp.isnan(x)] = xp.nan
return out[()]
def finfo(type_: DType | Array, /, xp: Namespace) -> Any:
# It is surprisingly difficult to recognize a dtype apart from an array.
# np.int64 is not the same as np.asarray(1).dtype!
try:
return xp.finfo(type_)
except (ValueError, TypeError):
return xp.finfo(type_.dtype)
def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
try:
return xp.iinfo(type_)
except (ValueError, TypeError):
return xp.iinfo(type_.dtype)
__all__ = [
"arange",
"empty",
"empty_like",
"eye",
"full",
"full_like",
"linspace",
"ones",
"ones_like",
"zeros",
"zeros_like",
"UniqueAllResult",
"UniqueCountsResult",
"UniqueInverseResult",
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
"std",
"var",
"cumulative_sum",
"cumulative_prod",
"clip",
"permute_dims",
"reshape",
"argsort",
"sort",
"nonzero",
"ceil",
"floor",
"trunc",
"matmul",
"matrix_transpose",
"tensordot",
"vecdot",
"isdtype",
"unstack",
"sign",
"finfo",
"iinfo",
]
_all_ignore = ["inspect", "array_namespace", "NamedTuple"]
def __dir__() -> list[str]:
return __all__

View File

@@ -0,0 +1,213 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Literal, TypeAlias
from ._typing import Array, Device, DType, Namespace
_Norm: TypeAlias = Literal["backward", "ortho", "forward"]
# Note: NumPy fft functions improperly upcast float32 and complex64 to
# complex128, which is why we require wrapping them all here.
def fft(
x: Array,
/,
xp: Namespace,
*,
n: int | None = None,
axis: int = -1,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def ifft(
x: Array,
/,
xp: Namespace,
*,
n: int | None = None,
axis: int = -1,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def fftn(
x: Array,
/,
xp: Namespace,
*,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def ifftn(
x: Array,
/,
xp: Namespace,
*,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def rfft(
x: Array,
/,
xp: Namespace,
*,
n: int | None = None,
axis: int = -1,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res
def irfft(
x: Array,
/,
xp: Namespace,
*,
n: int | None = None,
axis: int = -1,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res
def rfftn(
x: Array,
/,
xp: Namespace,
*,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res
def irfftn(
x: Array,
/,
xp: Namespace,
*,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res
def hfft(
x: Array,
/,
xp: Namespace,
*,
n: int | None = None,
axis: int = -1,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.float32)
return res
def ihfft(
x: Array,
/,
xp: Namespace,
*,
n: int | None = None,
axis: int = -1,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def fftfreq(
n: int,
/,
xp: Namespace,
*,
d: float = 1.0,
dtype: DType | None = None,
device: Device | None = None,
) -> Array:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
res = xp.fft.fftfreq(n, d=d)
if dtype is not None:
return res.astype(dtype)
return res
def rfftfreq(
n: int,
/,
xp: Namespace,
*,
d: float = 1.0,
dtype: DType | None = None,
device: Device | None = None,
) -> Array:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
res = xp.fft.rfftfreq(n, d=d)
if dtype is not None:
return res.astype(dtype)
return res
def fftshift(
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
) -> Array:
return xp.fft.fftshift(x, axes=axes)
def ifftshift(
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
) -> Array:
return xp.fft.ifftshift(x, axes=axes)
__all__ = [
"fft",
"ifft",
"fftn",
"ifftn",
"rfft",
"irfft",
"rfftn",
"irfftn",
"hfft",
"ihfft",
"fftfreq",
"rfftfreq",
"fftshift",
"ifftshift",
]
def __dir__() -> list[str]:
return __all__

View File

@@ -0,0 +1,232 @@
from __future__ import annotations
import math
from typing import Literal, NamedTuple, cast
import numpy as np
if np.__version__[0] == "2":
from numpy.lib.array_utils import normalize_axis_tuple
else:
from numpy.core.numeric import normalize_axis_tuple
from .._internal import get_xp
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
from ._typing import Array, DType, JustFloat, JustInt, Namespace
# These are in the main NumPy namespace but not in numpy.linalg
def cross(
x1: Array,
x2: Array,
/,
xp: Namespace,
*,
axis: int = -1,
**kwargs: object,
) -> Array:
return xp.cross(x1, x2, axis=axis, **kwargs)
def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
return xp.outer(x1, x2, **kwargs)
class EighResult(NamedTuple):
eigenvalues: Array
eigenvectors: Array
class QRResult(NamedTuple):
Q: Array
R: Array
class SlogdetResult(NamedTuple):
sign: Array
logabsdet: Array
class SVDResult(NamedTuple):
U: Array
S: Array
Vh: Array
# These functions are the same as their NumPy counterparts except they return
# a namedtuple.
def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult:
return EighResult(*xp.linalg.eigh(x, **kwargs))
def qr(
x: Array,
/,
xp: Namespace,
*,
mode: Literal["reduced", "complete"] = "reduced",
**kwargs: object,
) -> QRResult:
return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs))
def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult:
return SlogdetResult(*xp.linalg.slogdet(x, **kwargs))
def svd(
x: Array,
/,
xp: Namespace,
*,
full_matrices: bool = True,
**kwargs: object,
) -> SVDResult:
return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs))
# These functions have additional keyword arguments
# The upper keyword argument is new from NumPy
def cholesky(
x: Array,
/,
xp: Namespace,
*,
upper: bool = False,
**kwargs: object,
) -> Array:
L = xp.linalg.cholesky(x, **kwargs)
if upper:
U = get_xp(xp)(matrix_transpose)(L)
if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
U = xp.conj(U) # pyright: ignore[reportConstantRedefinition]
return U
return L
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
# Note that it has a different semantic meaning from tol and rcond.
def matrix_rank(
x: Array,
/,
xp: Namespace,
*,
rtol: float | Array | None = None,
**kwargs: object,
) -> Array:
# this is different from xp.linalg.matrix_rank, which supports 1
# dimensional arrays.
if x.ndim < 2:
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
S: Array = get_xp(xp)(svdvals)(x, **kwargs)
if rtol is None:
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
else:
# this is different from xp.linalg.matrix_rank, which does not
# multiply the tolerance by the largest singular value.
tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis]
return xp.count_nonzero(S > tol, axis=-1)
def pinv(
x: Array,
/,
xp: Namespace,
*,
rtol: float | Array | None = None,
**kwargs: object,
) -> Array:
# this is different from xp.linalg.pinv, which does not multiply the
# default tolerance by max(M, N).
if rtol is None:
rtol = max(x.shape[-2:]) * xp.finfo(x.dtype).eps
return xp.linalg.pinv(x, rcond=rtol, **kwargs)
# These functions are new in the array API spec
def matrix_norm(
x: Array,
/,
xp: Namespace,
*,
keepdims: bool = False,
ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro",
) -> Array:
return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
# svdvals is not in NumPy (but it is in SciPy). It is equivalent to
# xp.linalg.svd(compute_uv=False).
def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]:
return xp.linalg.svd(x, compute_uv=False)
def vector_norm(
x: Array,
/,
xp: Namespace,
*,
axis: int | tuple[int, ...] | None = None,
keepdims: bool = False,
ord: JustInt | JustFloat = 2,
) -> Array:
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
# when axis=None and the input is 2-D, so to force a vector norm, we make
# it so the input is 1-D (for axis=None), or reshape so that norm is done
# on a single dimension.
if axis is None:
# Note: xp.linalg.norm() doesn't handle 0-D arrays
_x = x.ravel()
_axis = 0
elif isinstance(axis, tuple):
# Note: The axis argument supports any number of axes, whereas
# xp.linalg.norm() only supports a single axis for vector norm.
normalized_axis = cast(
"tuple[int, ...]",
normalize_axis_tuple(axis, x.ndim), # pyright: ignore[reportCallIssue]
)
rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
newshape = axis + rest
_x = xp.transpose(x, newshape).reshape(
(math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest]))
_axis = 0
else:
_x = x
_axis = axis
res = xp.linalg.norm(_x, axis=_axis, ord=ord)
if keepdims:
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
# above to avoid matrix norm logic.
shape = list(x.shape)
_axis = cast(
"tuple[int, ...]",
normalize_axis_tuple( # pyright: ignore[reportCallIssue]
range(x.ndim) if axis is None else axis,
x.ndim,
),
)
for i in _axis:
shape[i] = 1
res = xp.reshape(res, tuple(shape))
return res
# xp.diagonal and xp.trace operate on the first two axes whereas these
# operates on the last two
def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array:
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
def trace(
x: Array,
/,
xp: Namespace,
*,
offset: int = 0,
dtype: DType | None = None,
**kwargs: object,
) -> Array:
return xp.asarray(
xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)
)
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
'trace']
_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']
def __dir__() -> list[str]:
return __all__

View File

@@ -0,0 +1,192 @@
from __future__ import annotations
from collections.abc import Mapping
from types import ModuleType as Namespace
from typing import (
TYPE_CHECKING,
Literal,
Protocol,
TypeAlias,
TypedDict,
TypeVar,
final,
)
if TYPE_CHECKING:
from _typeshed import Incomplete
SupportsBufferProtocol: TypeAlias = Incomplete
Array: TypeAlias = Incomplete
Device: TypeAlias = Incomplete
DType: TypeAlias = Incomplete
else:
SupportsBufferProtocol = object
Array = object
Device = object
DType = object
_T_co = TypeVar("_T_co", covariant=True)
# These "Just" types are equivalent to the `Just` type from the `optype` library,
# apart from them not being `@runtime_checkable`.
# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
@final
class JustInt(Protocol):
@property
def __class__(self, /) -> type[int]: ...
@__class__.setter
def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
@final
class JustFloat(Protocol):
@property
def __class__(self, /) -> type[float]: ...
@__class__.setter
def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
@final
class JustComplex(Protocol):
@property
def __class__(self, /) -> type[complex]: ...
@__class__.setter
def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
#
class NestedSequence(Protocol[_T_co]):
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
def __len__(self, /) -> int: ...
class SupportsArrayNamespace(Protocol[_T_co]):
def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
class HasShape(Protocol[_T_co]):
@property
def shape(self, /) -> _T_co: ...
# Return type of `__array_namespace_info__.default_dtypes`
Capabilities = TypedDict(
"Capabilities",
{
"boolean indexing": bool,
"data-dependent shapes": bool,
"max dimensions": int,
},
)
# Return type of `__array_namespace_info__.default_dtypes`
DefaultDTypes = TypedDict(
"DefaultDTypes",
{
"real floating": DType,
"complex floating": DType,
"integral": DType,
"indexing": DType,
},
)
_DTypeKind: TypeAlias = Literal[
"bool",
"signed integer",
"unsigned integer",
"integral",
"real floating",
"complex floating",
"numeric",
]
# Type of the `kind` parameter in `__array_namespace_info__.dtypes`
DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
# `__array_namespace_info__.dtypes(kind="bool")`
class DTypesBool(TypedDict):
bool: DType
# `__array_namespace_info__.dtypes(kind="signed integer")`
class DTypesSigned(TypedDict):
int8: DType
int16: DType
int32: DType
int64: DType
# `__array_namespace_info__.dtypes(kind="unsigned integer")`
class DTypesUnsigned(TypedDict):
uint8: DType
uint16: DType
uint32: DType
uint64: DType
# `__array_namespace_info__.dtypes(kind="integral")`
class DTypesIntegral(DTypesSigned, DTypesUnsigned):
pass
# `__array_namespace_info__.dtypes(kind="real floating")`
class DTypesReal(TypedDict):
float32: DType
float64: DType
# `__array_namespace_info__.dtypes(kind="complex floating")`
class DTypesComplex(TypedDict):
complex64: DType
complex128: DType
# `__array_namespace_info__.dtypes(kind="numeric")`
class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex):
pass
# `__array_namespace_info__.dtypes(kind=None)` (default)
class DTypesAll(DTypesBool, DTypesNumeric):
pass
# `__array_namespace_info__.dtypes(kind=?)` (fallback)
DTypesAny: TypeAlias = Mapping[str, DType]
__all__ = [
"Array",
"Capabilities",
"DType",
"DTypeKind",
"DTypesAny",
"DTypesAll",
"DTypesBool",
"DTypesNumeric",
"DTypesIntegral",
"DTypesSigned",
"DTypesUnsigned",
"DTypesReal",
"DTypesComplex",
"DefaultDTypes",
"Device",
"HasShape",
"Namespace",
"JustInt",
"JustFloat",
"JustComplex",
"NestedSequence",
"SupportsArrayNamespace",
"SupportsBufferProtocol",
]
def __dir__() -> list[str]:
return __all__

View File

@@ -0,0 +1,13 @@
from cupy import * # noqa: F403
# from cupy import * doesn't overwrite these builtin names
from cupy import abs, max, min, round # noqa: F401
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')
__array_api_version__ = '2024.12'

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
from typing import Optional
import cupy as cp
from ..common import _aliases, _helpers
from ..common._typing import NestedSequence, SupportsBufferProtocol
from .._internal import get_xp
from ._info import __array_namespace_info__
from ._typing import Array, Device, DType
bool = cp.bool_
# Basic renames
acos = cp.arccos
acosh = cp.arccosh
asin = cp.arcsin
asinh = cp.arcsinh
atan = cp.arctan
atan2 = cp.arctan2
atanh = cp.arctanh
bitwise_left_shift = cp.left_shift
bitwise_invert = cp.invert
bitwise_right_shift = cp.right_shift
concat = cp.concatenate
pow = cp.power
arange = get_xp(cp)(_aliases.arange)
empty = get_xp(cp)(_aliases.empty)
empty_like = get_xp(cp)(_aliases.empty_like)
eye = get_xp(cp)(_aliases.eye)
full = get_xp(cp)(_aliases.full)
full_like = get_xp(cp)(_aliases.full_like)
linspace = get_xp(cp)(_aliases.linspace)
ones = get_xp(cp)(_aliases.ones)
ones_like = get_xp(cp)(_aliases.ones_like)
zeros = get_xp(cp)(_aliases.zeros)
zeros_like = get_xp(cp)(_aliases.zeros_like)
UniqueAllResult = get_xp(cp)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(cp)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(cp)(_aliases.UniqueInverseResult)
unique_all = get_xp(cp)(_aliases.unique_all)
unique_counts = get_xp(cp)(_aliases.unique_counts)
unique_inverse = get_xp(cp)(_aliases.unique_inverse)
unique_values = get_xp(cp)(_aliases.unique_values)
std = get_xp(cp)(_aliases.std)
var = get_xp(cp)(_aliases.var)
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
cumulative_prod = get_xp(cp)(_aliases.cumulative_prod)
clip = get_xp(cp)(_aliases.clip)
permute_dims = get_xp(cp)(_aliases.permute_dims)
reshape = get_xp(cp)(_aliases.reshape)
argsort = get_xp(cp)(_aliases.argsort)
sort = get_xp(cp)(_aliases.sort)
nonzero = get_xp(cp)(_aliases.nonzero)
ceil = get_xp(cp)(_aliases.ceil)
floor = get_xp(cp)(_aliases.floor)
trunc = get_xp(cp)(_aliases.trunc)
matmul = get_xp(cp)(_aliases.matmul)
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)
sign = get_xp(cp)(_aliases.sign)
finfo = get_xp(cp)(_aliases.finfo)
iinfo = get_xp(cp)(_aliases.iinfo)
# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
obj: (
Array
| bool | int | float | complex
| NestedSequence[bool | int | float | complex]
| SupportsBufferProtocol
),
/,
*,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
copy: Optional[bool] = None,
**kwargs,
) -> Array:
"""
Array API compatibility wrapper for asarray().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
with cp.cuda.Device(device):
if copy is None:
return cp.asarray(obj, dtype=dtype, **kwargs)
else:
res = cp.array(obj, dtype=dtype, copy=copy, **kwargs)
if not copy and res is not obj:
raise ValueError("Unable to avoid copy while creating an array as requested")
return res
def astype(
x: Array,
dtype: DType,
/,
*,
copy: bool = True,
device: Optional[Device] = None,
) -> Array:
if device is None:
return x.astype(dtype=dtype, copy=copy)
out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device)
return out.copy() if copy and out is x else out
# cupy.count_nonzero does not have keepdims
def count_nonzero(
x: Array,
axis=None,
keepdims=False
) -> Array:
result = cp.count_nonzero(x, axis)
if keepdims:
if axis is None:
return cp.reshape(result, [1]*x.ndim)
return cp.expand_dims(result, axis)
return result
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
return cp.take_along_axis(x, indices, axis=axis)
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
vecdot = cp.vecdot
else:
vecdot = get_xp(cp)(_aliases.vecdot)
if hasattr(cp, 'isdtype'):
isdtype = cp.isdtype
else:
isdtype = get_xp(cp)(_aliases.isdtype)
if hasattr(cp, 'unstack'):
unstack = cp.unstack
else:
unstack = get_xp(cp)(_aliases.unstack)
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
'acos', 'acosh', 'asin', 'asinh', 'atan',
'atan2', 'atanh', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_right_shift',
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
'take_along_axis']
_all_ignore = ['cp', 'get_xp']

View File

@@ -0,0 +1,336 @@
"""
Array API Inspection namespace
This is the namespace for inspection functions as defined by the array API
standard. See
https://data-apis.org/array-api/latest/API_specification/inspection.html for
more details.
"""
from cupy import (
dtype,
cuda,
bool_ as bool,
intp,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
float32,
float64,
complex64,
complex128,
)
class __array_namespace_info__:
"""
Get the array API inspection namespace for CuPy.
The array API inspection namespace defines the following functions:
- capabilities()
- default_device()
- default_dtypes()
- dtypes()
- devices()
See
https://data-apis.org/array-api/latest/API_specification/inspection.html
for more details.
Returns
-------
info : ModuleType
The array API inspection namespace for CuPy.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': cupy.float64,
'complex floating': cupy.complex128,
'integral': cupy.int64,
'indexing': cupy.int64}
"""
__module__ = 'cupy'
def capabilities(self):
"""
Return a dictionary of array API library capabilities.
The resulting dictionary has the following keys:
- **"boolean indexing"**: boolean indicating whether an array library
supports boolean indexing. Always ``True`` for CuPy.
- **"data-dependent shapes"**: boolean indicating whether an array
library supports data-dependent output shapes. Always ``True`` for
CuPy.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
for more details.
See Also
--------
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
capabilities : dict
A dictionary of array API library capabilities.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.capabilities()
{'boolean indexing': True,
'data-dependent shapes': True,
'max dimensions': 64}
"""
return {
"boolean indexing": True,
"data-dependent shapes": True,
"max dimensions": 64,
}
def default_device(self):
"""
The default device used for new CuPy arrays.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
device : Device
The default device used for new CuPy arrays.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_device()
Device(0)
Notes
-----
This method returns the static default device when CuPy is initialized.
However, the *current* device used by creation functions (``empty`` etc.)
can be changed globally or with a context manager.
See Also
--------
https://github.com/data-apis/array-api/issues/835
"""
return cuda.Device(0)
def default_dtypes(self, *, device=None):
"""
The default data types used for new CuPy arrays.
For CuPy, this always returns the following dictionary:
- **"real floating"**: ``cupy.float64``
- **"complex floating"**: ``cupy.complex128``
- **"integral"**: ``cupy.intp``
- **"indexing"**: ``cupy.intp``
Parameters
----------
device : str, optional
The device to get the default data types for.
Returns
-------
dtypes : dict
A dictionary describing the default data types used for new CuPy
arrays.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': cupy.float64,
'complex floating': cupy.complex128,
'integral': cupy.int64,
'indexing': cupy.int64}
"""
# TODO: Does this depend on device?
return {
"real floating": dtype(float64),
"complex floating": dtype(complex128),
"integral": dtype(intp),
"indexing": dtype(intp),
}
def dtypes(self, *, device=None, kind=None):
"""
The array API data types supported by CuPy.
Note that this function only returns data types that are defined by
the array API.
Parameters
----------
device : str, optional
The device to get the data types for.
kind : str or tuple of str, optional
The kind of data types to return. If ``None``, all data types are
returned. If a string, only data types of that kind are returned.
If a tuple, a dictionary containing the union of the given kinds
is returned. The following kinds are supported:
- ``'bool'``: boolean data types (i.e., ``bool``).
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
``int16``, ``int32``, ``int64``).
- ``'unsigned integer'``: unsigned integer data types (i.e.,
``uint8``, ``uint16``, ``uint32``, ``uint64``).
- ``'integral'``: integer data types. Shorthand for ``('signed
integer', 'unsigned integer')``.
- ``'real floating'``: real-valued floating-point data types
(i.e., ``float32``, ``float64``).
- ``'complex floating'``: complex floating-point data types (i.e.,
``complex64``, ``complex128``).
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
'real floating', 'complex floating')``.
Returns
-------
dtypes : dict
A dictionary mapping the names of data types to the corresponding
CuPy data types.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.dtypes(kind='signed integer')
{'int8': cupy.int8,
'int16': cupy.int16,
'int32': cupy.int32,
'int64': cupy.int64}
"""
# TODO: Does this depend on device?
if kind is None:
return {
"bool": dtype(bool),
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
"float32": dtype(float32),
"float64": dtype(float64),
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if kind == "bool":
return {"bool": bool}
if kind == "signed integer":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
}
if kind == "unsigned integer":
return {
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
}
if kind == "integral":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
}
if kind == "real floating":
return {
"float32": dtype(float32),
"float64": dtype(float64),
}
if kind == "complex floating":
return {
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if kind == "numeric":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
"float32": dtype(float32),
"float64": dtype(float64),
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if isinstance(kind, tuple):
res = {}
for k in kind:
res.update(self.dtypes(kind=k))
return res
raise ValueError(f"unsupported kind: {kind!r}")
def devices(self):
"""
The devices supported by CuPy.
Returns
-------
devices : list[Device]
The devices supported by CuPy.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes
"""
return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())]

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
__all__ = ["Array", "DType", "Device"]
_all_ignore = ["cp"]
from typing import TYPE_CHECKING
import cupy as cp
from cupy import ndarray as Array
from cupy.cuda.device import Device
if TYPE_CHECKING:
# NumPy 1.x on Python 3.10 fails to parse np.dtype[]
DType = cp.dtype[
cp.intp
| cp.int8
| cp.int16
| cp.int32
| cp.int64
| cp.uint8
| cp.uint16
| cp.uint32
| cp.uint64
| cp.float32
| cp.float64
| cp.complex64
| cp.complex128
| cp.bool_
]
else:
DType = cp.dtype

View File

@@ -0,0 +1,36 @@
from cupy.fft import * # noqa: F403
# cupy.fft doesn't have __all__. If it is added, replace this with
#
# from cupy.fft import __all__ as linalg_all
_n = {}
exec('from cupy.fft import *', _n)
del _n['__builtins__']
fft_all = list(_n)
del _n
from ..common import _fft
from .._internal import get_xp
import cupy as cp
fft = get_xp(cp)(_fft.fft)
ifft = get_xp(cp)(_fft.ifft)
fftn = get_xp(cp)(_fft.fftn)
ifftn = get_xp(cp)(_fft.ifftn)
rfft = get_xp(cp)(_fft.rfft)
irfft = get_xp(cp)(_fft.irfft)
rfftn = get_xp(cp)(_fft.rfftn)
irfftn = get_xp(cp)(_fft.irfftn)
hfft = get_xp(cp)(_fft.hfft)
ihfft = get_xp(cp)(_fft.ihfft)
fftfreq = get_xp(cp)(_fft.fftfreq)
rfftfreq = get_xp(cp)(_fft.rfftfreq)
fftshift = get_xp(cp)(_fft.fftshift)
ifftshift = get_xp(cp)(_fft.ifftshift)
__all__ = fft_all + _fft.__all__
del get_xp
del cp
del fft_all
del _fft

View File

@@ -0,0 +1,49 @@
from cupy.linalg import * # noqa: F403
# cupy.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
_n = {}
exec('from cupy.linalg import *', _n)
del _n['__builtins__']
linalg_all = list(_n)
del _n
from ..common import _linalg
from .._internal import get_xp
import cupy as cp
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
cross = get_xp(cp)(_linalg.cross)
outer = get_xp(cp)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(cp)(_linalg.eigh)
qr = get_xp(cp)(_linalg.qr)
slogdet = get_xp(cp)(_linalg.slogdet)
svd = get_xp(cp)(_linalg.svd)
cholesky = get_xp(cp)(_linalg.cholesky)
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
pinv = get_xp(cp)(_linalg.pinv)
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
svdvals = get_xp(cp)(_linalg.svdvals)
diagonal = get_xp(cp)(_linalg.diagonal)
trace = get_xp(cp)(_linalg.trace)
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp.linalg, 'vector_norm'):
vector_norm = cp.linalg.vector_norm
else:
vector_norm = get_xp(cp)(_linalg.vector_norm)
__all__ = linalg_all + _linalg.__all__
del get_xp
del cp
del linalg_all
del _linalg

View File

@@ -0,0 +1,12 @@
from typing import Final
from dask.array import * # noqa: F403
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
__array_api_version__: Final = "2024.12"
# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')

View File

@@ -0,0 +1,376 @@
# pyright: reportPrivateUsage=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
from __future__ import annotations
from builtins import bool as py_bool
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from typing_extensions import TypeIs
import dask.array as da
import numpy as np
from numpy import bool_ as bool
from numpy import (
can_cast,
complex64,
complex128,
float32,
float64,
int8,
int16,
int32,
int64,
result_type,
uint8,
uint16,
uint32,
uint64,
)
from ..._internal import get_xp
from ...common import _aliases, _helpers, array_namespace
from ...common._typing import (
Array,
Device,
DType,
NestedSequence,
SupportsBufferProtocol,
)
from ._info import __array_namespace_info__
isdtype = get_xp(np)(_aliases.isdtype)
unstack = get_xp(da)(_aliases.unstack)
# da.astype doesn't respect copy=True
def astype(
x: Array,
dtype: DType,
/,
*,
copy: py_bool = True,
device: Device | None = None,
) -> Array:
"""
Array API compatibility wrapper for astype().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
# TODO: respect device keyword?
_helpers._check_device(da, device)
if not copy and dtype == x.dtype:
return x
x = x.astype(dtype)
return x.copy() if copy else x
# Common aliases
# This arange func is modified from the common one to
# not pass stop/step as keyword arguments, which will cause
# an error with dask
def arange(
start: float,
/,
stop: float | None = None,
step: float = 1,
*,
dtype: DType | None = None,
device: Device | None = None,
**kwargs: object,
) -> Array:
"""
Array API compatibility wrapper for arange().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
# TODO: respect device keyword?
_helpers._check_device(da, device)
args: list[Any] = [start]
if stop is not None:
args.append(stop)
else:
# stop is None, so start is actually stop
# prepend the default value for start which is 0
args.insert(0, 0)
args.append(step)
return da.arange(*args, dtype=dtype, **kwargs)
eye = get_xp(da)(_aliases.eye)
linspace = get_xp(da)(_aliases.linspace)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
unique_all = get_xp(da)(_aliases.unique_all)
unique_counts = get_xp(da)(_aliases.unique_counts)
unique_inverse = get_xp(da)(_aliases.unique_inverse)
unique_values = get_xp(da)(_aliases.unique_values)
permute_dims = get_xp(da)(_aliases.permute_dims)
std = get_xp(da)(_aliases.std)
var = get_xp(da)(_aliases.var)
cumulative_sum = get_xp(da)(_aliases.cumulative_sum)
cumulative_prod = get_xp(da)(_aliases.cumulative_prod)
empty = get_xp(da)(_aliases.empty)
empty_like = get_xp(da)(_aliases.empty_like)
full = get_xp(da)(_aliases.full)
full_like = get_xp(da)(_aliases.full_like)
ones = get_xp(da)(_aliases.ones)
ones_like = get_xp(da)(_aliases.ones_like)
zeros = get_xp(da)(_aliases.zeros)
zeros_like = get_xp(da)(_aliases.zeros_like)
reshape = get_xp(da)(_aliases.reshape)
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
vecdot = get_xp(da)(_aliases.vecdot)
nonzero = get_xp(da)(_aliases.nonzero)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)
sign = get_xp(np)(_aliases.sign)
finfo = get_xp(np)(_aliases.finfo)
iinfo = get_xp(np)(_aliases.iinfo)
# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol,
/,
*,
dtype: DType | None = None,
device: Device | None = None,
copy: py_bool | None = None,
**kwargs: object,
) -> Array:
"""
Array API compatibility wrapper for asarray().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
# TODO: respect device keyword?
_helpers._check_device(da, device)
if isinstance(obj, da.Array):
if dtype is not None and dtype != obj.dtype:
if copy is False:
raise ValueError("Unable to avoid copy when changing dtype")
obj = obj.astype(dtype)
return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue]
if copy is False:
raise ValueError(
"Unable to avoid copy when converting a non-dask object to dask"
)
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
# see https://github.com/dask/dask/pull/11524/
obj = np.array(obj, dtype=dtype, copy=True)
return da.from_array(obj)
# Element wise aliases
from dask.array import arccos as acos
from dask.array import arccosh as acosh
from dask.array import arcsin as asin
from dask.array import arcsinh as asinh
from dask.array import arctan as atan
from dask.array import arctan2 as atan2
from dask.array import arctanh as atanh
# Other
from dask.array import concatenate as concat
from dask.array import invert as bitwise_invert
from dask.array import left_shift as bitwise_left_shift
from dask.array import power as pow
from dask.array import right_shift as bitwise_right_shift
# dask.array.clip does not work unless all three arguments are provided.
# Furthermore, the masking workaround in common._aliases.clip cannot work with
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
# now).
def clip(
x: Array,
/,
min: float | Array | None = None,
max: float | Array | None = None,
) -> Array:
"""
Array API compatibility wrapper for clip().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]:
return a is None or isinstance(a, (int, float))
min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape
# TODO: This won't handle dask unknown shapes
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)
if min is not None:
min = da.broadcast_to(da.asarray(min), result_shape)
if max is not None:
max = da.broadcast_to(da.asarray(max), result_shape)
if min is None and max is None:
return da.positive(x)
if min is None:
return astype(da.minimum(x, max), x.dtype)
if max is None:
return astype(da.maximum(x, min), x.dtype)
return astype(da.minimum(da.maximum(x, min), max), x.dtype)
def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]:
"""
Make sure that Array is not broken into multiple chunks along axis.
Returns
-------
x : Array
The input Array with a single chunk along axis.
restore : Callable[Array, Array]
function to apply to the output to rechunk it back into reasonable chunks
"""
if axis < 0:
axis += x.ndim
if x.numblocks[axis] < 2:
return x, lambda x: x
# Break chunks on other axes in an attempt to keep chunk size low
x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})
# Rather than reconstructing the original chunks, which can be a
# very expensive affair, just break down oversized chunks without
# incurring in any transfers over the network.
# This has the downside of a risk of overchunking if the array is
# then used in operations against other arrays that match the
# original chunking pattern.
return x, lambda x: x.rechunk()
def sort(
x: Array,
/,
*,
axis: int = -1,
descending: py_bool = False,
stable: py_bool = True,
) -> Array:
"""
Array API compatibility layer around the lack of sort() in Dask.
Warnings
--------
This function temporarily rechunks the array along `axis` to a single chunk.
This can be extremely inefficient and can lead to out-of-memory errors.
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
x, restore = _ensure_single_chunk(x, axis)
meta_xp = array_namespace(x._meta)
x = da.map_blocks(
meta_xp.sort,
x,
axis=axis,
meta=x._meta,
dtype=x.dtype,
descending=descending,
stable=stable,
)
return restore(x)
def argsort(
x: Array,
/,
*,
axis: int = -1,
descending: py_bool = False,
stable: py_bool = True,
) -> Array:
"""
Array API compatibility layer around the lack of argsort() in Dask.
See the corresponding documentation in the array library and/or the array API
specification for more details.
Warnings
--------
This function temporarily rechunks the array along `axis` into a single chunk.
This can be extremely inefficient and can lead to out-of-memory errors.
"""
x, restore = _ensure_single_chunk(x, axis)
meta_xp = array_namespace(x._meta)
dtype = meta_xp.argsort(x._meta).dtype
meta = meta_xp.astype(x._meta, dtype)
x = da.map_blocks(
meta_xp.argsort,
x,
axis=axis,
meta=meta,
dtype=dtype,
descending=descending,
stable=stable,
)
return restore(x)
# dask.array.count_nonzero does not have keepdims
def count_nonzero(
x: Array,
axis: int | None = None,
keepdims: py_bool = False,
) -> Array:
result = da.count_nonzero(x, axis)
if keepdims:
if axis is None:
return da.reshape(result, [1] * x.ndim)
return da.expand_dims(result, axis)
return result
__all__ = [
"__array_namespace_info__",
"count_nonzero",
"bool",
"int8", "int16", "int32", "int64",
"uint8", "uint16", "uint32", "uint64",
"float32", "float64",
"complex64", "complex128",
"asarray", "astype", "can_cast", "result_type",
"pow",
"concat",
"acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh",
"bitwise_left_shift", "bitwise_right_shift", "bitwise_invert",
] # fmt: skip
__all__ += _aliases.__all__
_all_ignore = ["array_namespace", "get_xp", "da", "np"]
def __dir__() -> list[str]:
return __all__

View File

@@ -0,0 +1,416 @@
"""
Array API Inspection namespace
This is the namespace for inspection functions as defined by the array API
standard. See
https://data-apis.org/array-api/latest/API_specification/inspection.html for
more details.
"""
# pyright: reportPrivateUsage=false
from __future__ import annotations
from typing import Literal as L
from typing import TypeAlias, overload
from numpy import bool_ as bool
from numpy import (
complex64,
complex128,
dtype,
float32,
float64,
int8,
int16,
int32,
int64,
intp,
uint8,
uint16,
uint32,
uint64,
)
from ...common._helpers import _DASK_DEVICE, _dask_device
from ...common._typing import (
Capabilities,
DefaultDTypes,
DType,
DTypeKind,
DTypesAll,
DTypesAny,
DTypesBool,
DTypesComplex,
DTypesIntegral,
DTypesNumeric,
DTypesReal,
DTypesSigned,
DTypesUnsigned,
)
_Device: TypeAlias = L["cpu"] | _dask_device
class __array_namespace_info__:
"""
Get the array API inspection namespace for Dask.
The array API inspection namespace defines the following functions:
- capabilities()
- default_device()
- default_dtypes()
- dtypes()
- devices()
See
https://data-apis.org/array-api/latest/API_specification/inspection.html
for more details.
Returns
-------
info : ModuleType
The array API inspection namespace for Dask.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': dask.float64,
'complex floating': dask.complex128,
'integral': dask.int64,
'indexing': dask.int64}
"""
__module__ = "dask.array"
def capabilities(self) -> Capabilities:
"""
Return a dictionary of array API library capabilities.
The resulting dictionary has the following keys:
- **"boolean indexing"**: boolean indicating whether an array library
supports boolean indexing.
Dask support boolean indexing as long as both the index
and the indexed arrays have known shapes.
Note however that the output .shape and .size properties
will contain a non-compliant math.nan instead of None.
- **"data-dependent shapes"**: boolean indicating whether an array
library supports data-dependent output shapes.
Dask implements unique_values et.al.
Note however that the output .shape and .size properties
will contain a non-compliant math.nan instead of None.
- **"max dimensions"**: integer indicating the maximum number of
dimensions supported by the array library.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
for more details.
See Also
--------
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
capabilities : dict
A dictionary of array API library capabilities.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.capabilities()
{'boolean indexing': True,
'data-dependent shapes': True,
'max dimensions': 64}
"""
return {
"boolean indexing": True,
"data-dependent shapes": True,
"max dimensions": 64,
}
def default_device(self) -> L["cpu"]:
"""
The default device used for new Dask arrays.
For Dask, this always returns ``'cpu'``.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
device : Device
The default device used for new Dask arrays.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_device()
'cpu'
"""
return "cpu"
def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes:
"""
The default data types used for new Dask arrays.
For Dask, this always returns the following dictionary:
- **"real floating"**: ``numpy.float64``
- **"complex floating"**: ``numpy.complex128``
- **"integral"**: ``numpy.intp``
- **"indexing"**: ``numpy.intp``
Parameters
----------
device : str, optional
The device to get the default data types for.
Returns
-------
dtypes : dict
A dictionary describing the default data types used for new Dask
arrays.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': dask.float64,
'complex floating': dask.complex128,
'integral': dask.int64,
'indexing': dask.int64}
"""
if device not in ["cpu", _DASK_DEVICE, None]:
raise ValueError(
f'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, '
f"but received: {device!r}"
)
return {
"real floating": dtype(float64),
"complex floating": dtype(complex128),
"integral": dtype(intp),
"indexing": dtype(intp),
}
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: None = None
) -> DTypesAll: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["bool"]
) -> DTypesBool: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["signed integer"]
) -> DTypesSigned: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["unsigned integer"]
) -> DTypesUnsigned: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["integral"]
) -> DTypesIntegral: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["real floating"]
) -> DTypesReal: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["complex floating"]
) -> DTypesComplex: ...
@overload
def dtypes(
self, /, *, device: _Device | None = None, kind: L["numeric"]
) -> DTypesNumeric: ...
def dtypes(
self, /, *, device: _Device | None = None, kind: DTypeKind | None = None
) -> DTypesAny:
"""
The array API data types supported by Dask.
Note that this function only returns data types that are defined by
the array API.
Parameters
----------
device : str, optional
The device to get the data types for.
kind : str or tuple of str, optional
The kind of data types to return. If ``None``, all data types are
returned. If a string, only data types of that kind are returned.
If a tuple, a dictionary containing the union of the given kinds
is returned. The following kinds are supported:
- ``'bool'``: boolean data types (i.e., ``bool``).
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
``int16``, ``int32``, ``int64``).
- ``'unsigned integer'``: unsigned integer data types (i.e.,
``uint8``, ``uint16``, ``uint32``, ``uint64``).
- ``'integral'``: integer data types. Shorthand for ``('signed
integer', 'unsigned integer')``.
- ``'real floating'``: real-valued floating-point data types
(i.e., ``float32``, ``float64``).
- ``'complex floating'``: complex floating-point data types (i.e.,
``complex64``, ``complex128``).
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
'real floating', 'complex floating')``.
Returns
-------
dtypes : dict
A dictionary mapping the names of data types to the corresponding
Dask data types.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.dtypes(kind='signed integer')
{'int8': dask.int8,
'int16': dask.int16,
'int32': dask.int32,
'int64': dask.int64}
"""
if device not in ["cpu", _DASK_DEVICE, None]:
raise ValueError(
'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:'
f" {device}"
)
if kind is None:
return {
"bool": dtype(bool),
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
"float32": dtype(float32),
"float64": dtype(float64),
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if kind == "bool":
return {"bool": bool}
if kind == "signed integer":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
}
if kind == "unsigned integer":
return {
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
}
if kind == "integral":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
}
if kind == "real floating":
return {
"float32": dtype(float32),
"float64": dtype(float64),
}
if kind == "complex floating":
return {
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if kind == "numeric":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
"float32": dtype(float32),
"float64": dtype(float64),
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if isinstance(kind, tuple): # type: ignore[reportUnnecessaryIsinstanceCall]
res: dict[str, DType] = {}
for k in kind:
res.update(self.dtypes(kind=k))
return res
raise ValueError(f"unsupported kind: {kind!r}")
def devices(self) -> list[_Device]:
"""
The devices supported by Dask.
For Dask, this always returns ``['cpu', DASK_DEVICE]``.
Returns
-------
devices : list[Device]
The devices supported by Dask.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.devices()
['cpu', DASK_DEVICE]
"""
return ["cpu", _DASK_DEVICE]

View File

@@ -0,0 +1,21 @@
from dask.array.fft import * # noqa: F403
# dask.array.fft doesn't have __all__. If it is added, replace this with
#
# from dask.array.fft import __all__ as linalg_all
_n = {}
exec('from dask.array.fft import *', _n)
for k in ("__builtins__", "Sequence", "annotations", "warnings"):
_n.pop(k, None)
fft_all = list(_n)
del _n, k
from ...common import _fft
from ..._internal import get_xp
import dask.array as da
fftfreq = get_xp(da)(_fft.fftfreq)
rfftfreq = get_xp(da)(_fft.rfftfreq)
__all__ = fft_all + ["fftfreq", "rfftfreq"]
_all_ignore = ["da", "fft_all", "get_xp", "warnings"]

View File

@@ -0,0 +1,72 @@
from __future__ import annotations
from typing import Literal
import dask.array as da
# The `matmul` and `tensordot` functions are in both the main and linalg namespaces
from dask.array import matmul, outer, tensordot
# Exports
from dask.array.linalg import * # noqa: F403
from ..._internal import get_xp
from ...common import _linalg
from ...common._typing import Array as _Array
from ._aliases import matrix_transpose, vecdot
# dask.array.linalg doesn't have __all__. If it is added, replace this with
#
# from dask.array.linalg import __all__ as linalg_all
_n = {}
exec('from dask.array.linalg import *', _n)
for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'):
_n.pop(k, None)
linalg_all = list(_n)
del _n, k
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
# TODO: use the QR wrapper once dask
# supports the mode keyword on QR
# https://github.com/dask/dask/issues/10388
#qr = get_xp(da)(_linalg.qr)
def qr(
x: _Array,
mode: Literal["reduced", "complete"] = "reduced",
**kwargs: object,
) -> QRResult:
if mode != "reduced":
raise ValueError("dask arrays only support using mode='reduced'")
return QRResult(*da.linalg.qr(x, **kwargs))
trace = get_xp(da)(_linalg.trace)
cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
matrix_norm = get_xp(da)(_linalg.matrix_norm)
# Wrap the svd functions to not pass full_matrices to dask
# when full_matrices=False (as that is the default behavior for dask),
# and dask doesn't have the full_matrices keyword
def svd(x: _Array, full_matrices: bool = True, **kwargs) -> SVDResult:
if full_matrices:
raise ValueError("full_matrics=True is not supported by dask.")
return da.linalg.svd(x, coerce_signs=False, **kwargs)
def svdvals(x: _Array) -> _Array:
# TODO: can't avoid computing U or V for dask
_, s, _ = svd(x)
return s
vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)
__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot",
"matrix_transpose", "vecdot", "EighResult",
"QRResult", "SlogdetResult", "SVDResult", "qr",
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
"vector_norm", "diagonal"]
_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings']

View File

@@ -0,0 +1,28 @@
# ruff: noqa: PLC0414
from typing import Final
from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary]
# from numpy import * doesn't overwrite these builtin names
from numpy import abs as abs
from numpy import max as max
from numpy import min as min
from numpy import round as round
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
# Don't know why, but we have to do an absolute import to import linalg. If we
# instead do
#
# from . import linalg
#
# It doesn't overwrite np.linalg from above. The import is generated
# dynamically so that the library can be vendored.
__import__(__package__ + ".linalg")
__import__(__package__ + ".fft")
from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401
__array_api_version__: Final = "2024.12"

View File

@@ -0,0 +1,190 @@
# pyright: reportPrivateUsage=false
from __future__ import annotations
from builtins import bool as py_bool
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast
import numpy as np
from .._internal import get_xp
from ..common import _aliases, _helpers
from ..common._typing import NestedSequence, SupportsBufferProtocol
from ._info import __array_namespace_info__
from ._typing import Array, Device, DType
if TYPE_CHECKING:
from typing_extensions import Buffer, TypeIs
# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`:
# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10
_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
bool = np.bool_
# Basic renames
acos = np.arccos
acosh = np.arccosh
asin = np.arcsin
asinh = np.arcsinh
atan = np.arctan
atan2 = np.arctan2
atanh = np.arctanh
bitwise_left_shift = np.left_shift
bitwise_invert = np.invert
bitwise_right_shift = np.right_shift
concat = np.concatenate
pow = np.power
arange = get_xp(np)(_aliases.arange)
empty = get_xp(np)(_aliases.empty)
empty_like = get_xp(np)(_aliases.empty_like)
eye = get_xp(np)(_aliases.eye)
full = get_xp(np)(_aliases.full)
full_like = get_xp(np)(_aliases.full_like)
linspace = get_xp(np)(_aliases.linspace)
ones = get_xp(np)(_aliases.ones)
ones_like = get_xp(np)(_aliases.ones_like)
zeros = get_xp(np)(_aliases.zeros)
zeros_like = get_xp(np)(_aliases.zeros_like)
UniqueAllResult = get_xp(np)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(np)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(np)(_aliases.UniqueInverseResult)
unique_all = get_xp(np)(_aliases.unique_all)
unique_counts = get_xp(np)(_aliases.unique_counts)
unique_inverse = get_xp(np)(_aliases.unique_inverse)
unique_values = get_xp(np)(_aliases.unique_values)
std = get_xp(np)(_aliases.std)
var = get_xp(np)(_aliases.var)
cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
cumulative_prod = get_xp(np)(_aliases.cumulative_prod)
clip = get_xp(np)(_aliases.clip)
permute_dims = get_xp(np)(_aliases.permute_dims)
reshape = get_xp(np)(_aliases.reshape)
argsort = get_xp(np)(_aliases.argsort)
sort = get_xp(np)(_aliases.sort)
nonzero = get_xp(np)(_aliases.nonzero)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
tensordot = get_xp(np)(_aliases.tensordot)
sign = get_xp(np)(_aliases.sign)
finfo = get_xp(np)(_aliases.finfo)
iinfo = get_xp(np)(_aliases.iinfo)
def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction]
try:
memoryview(obj) # pyright: ignore[reportArgumentType]
except TypeError:
return False
return True
# asarray also adds the copy keyword, which is not present in numpy 1.0.
# asarray() is different enough between numpy, cupy, and dask, the logic
# complicated enough that it's easier to define it separately for each module
# rather than trying to combine everything into one function in common/
def asarray(
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
/,
*,
dtype: DType | None = None,
device: Device | None = None,
copy: _Copy | None = None,
**kwargs: Any,
) -> Array:
"""
Array API compatibility wrapper for asarray().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
_helpers._check_device(np, device)
if copy is None:
copy = np._CopyMode.IF_NEEDED
elif copy is False:
copy = np._CopyMode.NEVER
elif copy is True:
copy = np._CopyMode.ALWAYS
return np.array(obj, copy=copy, dtype=dtype, **kwargs) # pyright: ignore
def astype(
x: Array,
dtype: DType,
/,
*,
copy: py_bool = True,
device: Device | None = None,
) -> Array:
_helpers._check_device(np, device)
return x.astype(dtype=dtype, copy=copy)
# count_nonzero returns a python int for axis=None and keepdims=False
# https://github.com/numpy/numpy/issues/17562
def count_nonzero(
x: Array,
axis: int | tuple[int, ...] | None = None,
keepdims: py_bool = False,
) -> Array:
# NOTE: this is currently incorrectly typed in numpy, but will be fixed in
# numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750
result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue]
if axis is None and not keepdims:
return np.asarray(result)
return result
# take_along_axis: axis defaults to -1 but in numpy axis is a required arg
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
return np.take_along_axis(x, indices, axis=axis)
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np, "vecdot"):
vecdot = np.vecdot
else:
vecdot = get_xp(np)(_aliases.vecdot)
if hasattr(np, "isdtype"):
isdtype = np.isdtype
else:
isdtype = get_xp(np)(_aliases.isdtype)
if hasattr(np, "unstack"):
unstack = np.unstack
else:
unstack = get_xp(np)(_aliases.unstack)
__all__ = [
"__array_namespace_info__",
"asarray",
"astype",
"acos",
"acosh",
"asin",
"asinh",
"atan",
"atan2",
"atanh",
"bitwise_left_shift",
"bitwise_invert",
"bitwise_right_shift",
"bool",
"concat",
"count_nonzero",
"pow",
"take_along_axis"
]
__all__ += _aliases.__all__
_all_ignore = ["np", "get_xp"]
def __dir__() -> list[str]:
return __all__

View File

@@ -0,0 +1,366 @@
"""
Array API Inspection namespace
This is the namespace for inspection functions as defined by the array API
standard. See
https://data-apis.org/array-api/latest/API_specification/inspection.html for
more details.
"""
from __future__ import annotations
from numpy import bool_ as bool
from numpy import (
complex64,
complex128,
dtype,
float32,
float64,
int8,
int16,
int32,
int64,
intp,
uint8,
uint16,
uint32,
uint64,
)
from ._typing import Device, DType
class __array_namespace_info__:
"""
Get the array API inspection namespace for NumPy.
The array API inspection namespace defines the following functions:
- capabilities()
- default_device()
- default_dtypes()
- dtypes()
- devices()
See
https://data-apis.org/array-api/latest/API_specification/inspection.html
for more details.
Returns
-------
info : ModuleType
The array API inspection namespace for NumPy.
Examples
--------
>>> info = np.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': numpy.float64,
'complex floating': numpy.complex128,
'integral': numpy.int64,
'indexing': numpy.int64}
"""
__module__ = 'numpy'
def capabilities(self):
"""
Return a dictionary of array API library capabilities.
The resulting dictionary has the following keys:
- **"boolean indexing"**: boolean indicating whether an array library
supports boolean indexing. Always ``True`` for NumPy.
- **"data-dependent shapes"**: boolean indicating whether an array
library supports data-dependent output shapes. Always ``True`` for
NumPy.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
for more details.
See Also
--------
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
capabilities : dict
A dictionary of array API library capabilities.
Examples
--------
>>> info = np.__array_namespace_info__()
>>> info.capabilities()
{'boolean indexing': True,
'data-dependent shapes': True,
'max dimensions': 64}
"""
return {
"boolean indexing": True,
"data-dependent shapes": True,
"max dimensions": 64,
}
def default_device(self):
"""
The default device used for new NumPy arrays.
For NumPy, this always returns ``'cpu'``.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
device : Device
The default device used for new NumPy arrays.
Examples
--------
>>> info = np.__array_namespace_info__()
>>> info.default_device()
'cpu'
"""
return "cpu"
def default_dtypes(
self,
*,
device: Device | None = None,
) -> dict[str, dtype[intp | float64 | complex128]]:
"""
The default data types used for new NumPy arrays.
For NumPy, this always returns the following dictionary:
- **"real floating"**: ``numpy.float64``
- **"complex floating"**: ``numpy.complex128``
- **"integral"**: ``numpy.intp``
- **"indexing"**: ``numpy.intp``
Parameters
----------
device : str, optional
The device to get the default data types for. For NumPy, only
``'cpu'`` is allowed.
Returns
-------
dtypes : dict
A dictionary describing the default data types used for new NumPy
arrays.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = np.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': numpy.float64,
'complex floating': numpy.complex128,
'integral': numpy.int64,
'indexing': numpy.int64}
"""
if device not in ["cpu", None]:
raise ValueError(
'Device not understood. Only "cpu" is allowed, but received:'
f' {device}'
)
return {
"real floating": dtype(float64),
"complex floating": dtype(complex128),
"integral": dtype(intp),
"indexing": dtype(intp),
}
def dtypes(
self,
*,
device: Device | None = None,
kind: str | tuple[str, ...] | None = None,
) -> dict[str, DType]:
"""
The array API data types supported by NumPy.
Note that this function only returns data types that are defined by
the array API.
Parameters
----------
device : str, optional
The device to get the data types for. For NumPy, only ``'cpu'`` is
allowed.
kind : str or tuple of str, optional
The kind of data types to return. If ``None``, all data types are
returned. If a string, only data types of that kind are returned.
If a tuple, a dictionary containing the union of the given kinds
is returned. The following kinds are supported:
- ``'bool'``: boolean data types (i.e., ``bool``).
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
``int16``, ``int32``, ``int64``).
- ``'unsigned integer'``: unsigned integer data types (i.e.,
``uint8``, ``uint16``, ``uint32``, ``uint64``).
- ``'integral'``: integer data types. Shorthand for ``('signed
integer', 'unsigned integer')``.
- ``'real floating'``: real-valued floating-point data types
(i.e., ``float32``, ``float64``).
- ``'complex floating'``: complex floating-point data types (i.e.,
``complex64``, ``complex128``).
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
'real floating', 'complex floating')``.
Returns
-------
dtypes : dict
A dictionary mapping the names of data types to the corresponding
NumPy data types.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = np.__array_namespace_info__()
>>> info.dtypes(kind='signed integer')
{'int8': numpy.int8,
'int16': numpy.int16,
'int32': numpy.int32,
'int64': numpy.int64}
"""
if device not in ["cpu", None]:
raise ValueError(
'Device not understood. Only "cpu" is allowed, but received:'
f' {device}'
)
if kind is None:
return {
"bool": dtype(bool),
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
"float32": dtype(float32),
"float64": dtype(float64),
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if kind == "bool":
return {"bool": dtype(bool)}
if kind == "signed integer":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
}
if kind == "unsigned integer":
return {
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
}
if kind == "integral":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
}
if kind == "real floating":
return {
"float32": dtype(float32),
"float64": dtype(float64),
}
if kind == "complex floating":
return {
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if kind == "numeric":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
"float32": dtype(float32),
"float64": dtype(float64),
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if isinstance(kind, tuple):
res: dict[str, DType] = {}
for k in kind:
res.update(self.dtypes(kind=k))
return res
raise ValueError(f"unsupported kind: {kind!r}")
def devices(self) -> list[Device]:
"""
The devices supported by NumPy.
For NumPy, this always returns ``['cpu']``.
Returns
-------
devices : list[Device]
The devices supported by NumPy.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes
Examples
--------
>>> info = np.__array_namespace_info__()
>>> info.devices()
['cpu']
"""
return ["cpu"]
__all__ = ["__array_namespace_info__"]
def __dir__() -> list[str]:
return __all__

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
import numpy as np
Device: TypeAlias = Literal["cpu"]
if TYPE_CHECKING:
# NumPy 1.x on Python 3.10 fails to parse np.dtype[]
DType: TypeAlias = np.dtype[
np.bool_
| np.integer[Any]
| np.float32
| np.float64
| np.complex64
| np.complex128
]
Array: TypeAlias = np.ndarray[Any, DType]
else:
DType: TypeAlias = np.dtype
Array: TypeAlias = np.ndarray
__all__ = ["Array", "DType", "Device"]
_all_ignore = ["np"]
def __dir__() -> list[str]:
return __all__

View File

@@ -0,0 +1,35 @@
import numpy as np
from numpy.fft import __all__ as fft_all
from numpy.fft import fft2, ifft2, irfft2, rfft2
from .._internal import get_xp
from ..common import _fft
fft = get_xp(np)(_fft.fft)
ifft = get_xp(np)(_fft.ifft)
fftn = get_xp(np)(_fft.fftn)
ifftn = get_xp(np)(_fft.ifftn)
rfft = get_xp(np)(_fft.rfft)
irfft = get_xp(np)(_fft.irfft)
rfftn = get_xp(np)(_fft.rfftn)
irfftn = get_xp(np)(_fft.irfftn)
hfft = get_xp(np)(_fft.hfft)
ihfft = get_xp(np)(_fft.ihfft)
fftfreq = get_xp(np)(_fft.fftfreq)
rfftfreq = get_xp(np)(_fft.rfftfreq)
fftshift = get_xp(np)(_fft.fftshift)
ifftshift = get_xp(np)(_fft.ifftshift)
__all__ = ["rfft2", "irfft2", "fft2", "ifft2"]
__all__ += _fft.__all__
def __dir__() -> list[str]:
return __all__
del get_xp
del np
del fft_all
del _fft

View File

@@ -0,0 +1,143 @@
# pyright: reportAttributeAccessIssue=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
from __future__ import annotations
import numpy as np
# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__`
from numpy.linalg import (
LinAlgError,
cond,
det,
eig,
eigvals,
eigvalsh,
inv,
lstsq,
matrix_power,
multi_dot,
norm,
tensorinv,
tensorsolve,
)
from .._internal import get_xp
from ..common import _linalg
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
from ._typing import Array
cross = get_xp(np)(_linalg.cross)
outer = get_xp(np)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(np)(_linalg.eigh)
qr = get_xp(np)(_linalg.qr)
slogdet = get_xp(np)(_linalg.slogdet)
svd = get_xp(np)(_linalg.svd)
cholesky = get_xp(np)(_linalg.cholesky)
matrix_rank = get_xp(np)(_linalg.matrix_rank)
pinv = get_xp(np)(_linalg.pinv)
matrix_norm = get_xp(np)(_linalg.matrix_norm)
svdvals = get_xp(np)(_linalg.svdvals)
diagonal = get_xp(np)(_linalg.diagonal)
trace = get_xp(np)(_linalg.trace)
# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
# of matrices. The np.linalg.solve behavior of allowing stacks of both
# matrices and vectors is ambiguous c.f.
# https://github.com/numpy/numpy/issues/15349 and
# https://github.com/data-apis/array-api/issues/285.
# To workaround this, the below is the code from np.linalg.solve except
# only calling solve1 in the exactly 1D case.
# This code is here instead of in common because it is numpy specific. Also
# note that CuPy's solve() does not currently support broadcasting (see
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
def solve(x1: Array, x2: Array, /) -> Array:
try:
from numpy.linalg._linalg import (
_assert_stacked_2d,
_assert_stacked_square,
_commonType,
_makearray,
_raise_linalgerror_singular,
isComplexType,
)
except ImportError:
from numpy.linalg.linalg import (
_assert_stacked_2d,
_assert_stacked_square,
_commonType,
_makearray,
_raise_linalgerror_singular,
isComplexType,
)
from numpy.linalg import _umath_linalg
x1, _ = _makearray(x1)
_assert_stacked_2d(x1)
_assert_stacked_square(x1)
x2, wrap = _makearray(x2)
t, result_t = _commonType(x1, x2)
# This part is different from np.linalg.solve
gufunc: np.ufunc
if x2.ndim == 1:
gufunc = _umath_linalg.solve1
else:
gufunc = _umath_linalg.solve
# This does nothing currently but is left in because it will be relevant
# when complex dtype support is added to the spec in 2022.
signature = "DD->D" if isComplexType(t) else "dd->d"
with np.errstate(
call=_raise_linalgerror_singular,
invalid="call",
over="ignore",
divide="ignore",
under="ignore",
):
r: Array = gufunc(x1, x2, signature=signature)
return wrap(r.astype(result_t, copy=False))
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np.linalg, "vector_norm"):
vector_norm = np.linalg.vector_norm
else:
vector_norm = get_xp(np)(_linalg.vector_norm)
__all__ = [
"LinAlgError",
"cond",
"det",
"eig",
"eigvals",
"eigvalsh",
"inv",
"lstsq",
"matrix_power",
"multi_dot",
"norm",
"tensorinv",
"tensorsolve",
]
__all__ += _linalg.__all__
__all__ += ["solve", "vector_norm"]
def __dir__() -> list[str]:
return __all__

View File

@@ -0,0 +1,22 @@
from torch import * # noqa: F403
# Several names are not included in the above import *
import torch
for n in dir(torch):
if (n.startswith('_')
or n.endswith('_')
or 'cuda' in n
or 'cpu' in n
or 'backward' in n):
continue
exec(f"{n} = torch.{n}")
del n
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')
__array_api_version__ = '2024.12'

View File

@@ -0,0 +1,855 @@
from __future__ import annotations
from functools import reduce as _reduce, wraps as _wraps
from builtins import all as _builtin_all, any as _builtin_any
from typing import Any, List, Optional, Sequence, Tuple, Union, Literal
import torch
from .._internal import get_xp
from ..common import _aliases
from ..common._typing import NestedSequence, SupportsBufferProtocol
from ._info import __array_namespace_info__
from ._typing import Array, Device, DType
_int_dtypes = {
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
}
try:
# torch >=2.3
_int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
except AttributeError:
pass
_array_api_dtypes = {
torch.bool,
*_int_dtypes,
torch.float32,
torch.float64,
torch.complex64,
torch.complex128,
}
_promotion_table = {
# ints
(torch.int8, torch.int16): torch.int16,
(torch.int8, torch.int32): torch.int32,
(torch.int8, torch.int64): torch.int64,
(torch.int16, torch.int32): torch.int32,
(torch.int16, torch.int64): torch.int64,
(torch.int32, torch.int64): torch.int64,
# ints and uints (mixed sign)
(torch.uint8, torch.int8): torch.int16,
(torch.uint8, torch.int16): torch.int16,
(torch.uint8, torch.int32): torch.int32,
(torch.uint8, torch.int64): torch.int64,
# floats
(torch.float32, torch.float64): torch.float64,
# complexes
(torch.complex64, torch.complex128): torch.complex128,
# Mixed float and complex
(torch.float32, torch.complex64): torch.complex64,
(torch.float32, torch.complex128): torch.complex128,
(torch.float64, torch.complex64): torch.complex128,
(torch.float64, torch.complex128): torch.complex128,
}
_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
_promotion_table.update({(a, a): a for a in _array_api_dtypes})
def _two_arg(f):
@_wraps(f)
def _f(x1, x2, /, **kwargs):
x1, x2 = _fix_promotion(x1, x2)
return f(x1, x2, **kwargs)
if _f.__doc__ is None:
_f.__doc__ = f"""\
Array API compatibility wrapper for torch.{f.__name__}.
See the corresponding PyTorch documentation and/or the array API specification
for more details.
"""
return _f
def _fix_promotion(x1, x2, only_scalar=True):
if not isinstance(x1, torch.Tensor) or not isinstance(x2, torch.Tensor):
return x1, x2
if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
return x1, x2
# If an argument is 0-D pytorch downcasts the other argument
if not only_scalar or x1.shape == ():
dtype = result_type(x1, x2)
x2 = x2.to(dtype)
if not only_scalar or x2.shape == ():
dtype = result_type(x1, x2)
x1 = x1.to(dtype)
return x1, x2
_py_scalars = (bool, int, float, complex)
def result_type(
*arrays_and_dtypes: Array | DType | bool | int | float | complex
) -> DType:
num = len(arrays_and_dtypes)
if num == 0:
raise ValueError("At least one array or dtype must be provided")
elif num == 1:
x = arrays_and_dtypes[0]
if isinstance(x, torch.dtype):
return x
return x.dtype
if num == 2:
x, y = arrays_and_dtypes
return _result_type(x, y)
else:
# sort scalars so that they are treated last
scalars, others = [], []
for x in arrays_and_dtypes:
if isinstance(x, _py_scalars):
scalars.append(x)
else:
others.append(x)
if not others:
raise ValueError("At least one array or dtype must be provided")
# combine left-to-right
return _reduce(_result_type, others + scalars)
def _result_type(
x: Array | DType | bool | int | float | complex,
y: Array | DType | bool | int | float | complex,
) -> DType:
if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
xdt = x if isinstance(x, torch.dtype) else x.dtype
ydt = y if isinstance(y, torch.dtype) else y.dtype
try:
return _promotion_table[xdt, ydt]
except KeyError:
pass
# This doesn't result_type(dtype, dtype) for non-array API dtypes
# because torch.result_type only accepts tensors. This does however, allow
# cross-kind promotion.
x = torch.tensor([], dtype=x) if isinstance(x, torch.dtype) else x
y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
return torch.result_type(x, y)
def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
if not isinstance(from_, torch.dtype):
from_ = from_.dtype
return torch.can_cast(from_, to)
# Basic renames
bitwise_invert = torch.bitwise_not
newaxis = None
# torch.conj sets the conjugation bit, which breaks conversion to other
# libraries. See https://github.com/data-apis/array-api-compat/issues/173
conj = torch.conj_physical
# Two-arg elementwise functions
# These require a wrapper to do the correct type promotion on 0-D tensors
add = _two_arg(torch.add)
atan2 = _two_arg(torch.atan2)
bitwise_and = _two_arg(torch.bitwise_and)
bitwise_left_shift = _two_arg(torch.bitwise_left_shift)
bitwise_or = _two_arg(torch.bitwise_or)
bitwise_right_shift = _two_arg(torch.bitwise_right_shift)
bitwise_xor = _two_arg(torch.bitwise_xor)
copysign = _two_arg(torch.copysign)
divide = _two_arg(torch.divide)
# Also a rename. torch.equal does not broadcast
equal = _two_arg(torch.eq)
floor_divide = _two_arg(torch.floor_divide)
greater = _two_arg(torch.greater)
greater_equal = _two_arg(torch.greater_equal)
hypot = _two_arg(torch.hypot)
less = _two_arg(torch.less)
less_equal = _two_arg(torch.less_equal)
logaddexp = _two_arg(torch.logaddexp)
# logical functions are not included here because they only accept bool in the
# spec, so type promotion is irrelevant.
maximum = _two_arg(torch.maximum)
minimum = _two_arg(torch.minimum)
multiply = _two_arg(torch.multiply)
not_equal = _two_arg(torch.not_equal)
pow = _two_arg(torch.pow)
remainder = _two_arg(torch.remainder)
subtract = _two_arg(torch.subtract)
def asarray(
obj: (
Array
| bool | int | float | complex
| NestedSequence[bool | int | float | complex]
| SupportsBufferProtocol
),
/,
*,
dtype: DType | None = None,
device: Device | None = None,
copy: bool | None = None,
**kwargs: Any,
) -> Array:
# torch.asarray does not respect input->output device propagation
# https://github.com/pytorch/pytorch/issues/150199
if device is None and isinstance(obj, torch.Tensor):
device = obj.device
return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
# of 'axis'.
# torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745
def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return torch.clone(x)
return torch.amax(x, axis, keepdims=keepdims)
def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return torch.clone(x)
return torch.amin(x, axis, keepdims=keepdims)
clip = get_xp(torch)(_aliases.clip)
unstack = get_xp(torch)(_aliases.unstack)
cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
finfo = get_xp(torch)(_aliases.finfo)
iinfo = get_xp(torch)(_aliases.iinfo)
# torch.sort also returns a tuple
# https://github.com/pytorch/pytorch/issues/70921
def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> Array:
return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values
def _normalize_axes(axis, ndim):
axes = []
if ndim == 0 and axis:
# Better error message in this case
raise IndexError(f"Dimension out of range: {axis[0]}")
lower, upper = -ndim, ndim - 1
for a in axis:
if a < lower or a > upper:
# Match torch error message (e.g., from sum())
raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}")
if a < 0:
a = a + ndim
if a in axes:
# Use IndexError instead of RuntimeError, and "axis" instead of "dim"
raise IndexError(f"Axis {a} appears multiple times in the list of axes")
axes.append(a)
return sorted(axes)
def _axis_none_keepdims(x, ndim, keepdims):
# Apply keepdims when axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
# Note that this is only valid for the axis=None case.
if keepdims:
for i in range(ndim):
x = torch.unsqueeze(x, 0)
return x
def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
# Some reductions don't support multiple axes
# (https://github.com/pytorch/pytorch/issues/56586).
axes = _normalize_axes(axis, x.ndim)
for a in reversed(axes):
x = torch.movedim(x, a, -1)
x = torch.flatten(x, -len(axes))
out = f(x, -1, **kwargs)
if keepdims:
for a in axes:
out = torch.unsqueeze(out, a)
return out
def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
"""
Implements `sum(..., axis=())` and `prod(..., axis=())`.
Works around https://github.com/pytorch/pytorch/issues/29137
"""
if dtype is not None:
return x.clone() if dtype == x.dtype else x.to(dtype)
# We can't upcast uint8 according to the spec because there is no
# torch.uint64, so at least upcast to int64 which is what prod does
# when axis=None.
if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
return x.to(torch.int64)
return x.clone()
def prod(x: Array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype: Optional[DType] = None,
keepdims: bool = False,
**kwargs) -> Array:
if axis == ():
return _sum_prod_no_axis(x, dtype)
# torch.prod doesn't support multiple axes
# (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple):
return _reduce_multiple_axes(torch.prod, x, axis, keepdims=keepdims, dtype=dtype, **kwargs)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.prod(x, dtype=dtype, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
def sum(x: Array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype: Optional[DType] = None,
keepdims: bool = False,
**kwargs) -> Array:
if axis == ():
return _sum_prod_no_axis(x, dtype)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.sum(x, dtype=dtype, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
def any(x: Array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
**kwargs) -> Array:
if axis == ():
return x.to(torch.bool)
# torch.any doesn't support multiple axes
# (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple):
res = _reduce_multiple_axes(torch.any, x, axis, keepdims=keepdims, **kwargs)
return res.to(torch.bool)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.any(x, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res.to(torch.bool)
# torch.any doesn't return bool for uint8
return torch.any(x, axis, keepdims=keepdims).to(torch.bool)
def all(x: Array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
**kwargs) -> Array:
if axis == ():
return x.to(torch.bool)
# torch.all doesn't support multiple axes
# (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple):
res = _reduce_multiple_axes(torch.all, x, axis, keepdims=keepdims, **kwargs)
return res.to(torch.bool)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.all(x, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res.to(torch.bool)
# torch.all doesn't return bool for uint8
return torch.all(x, axis, keepdims=keepdims).to(torch.bool)
def mean(x: Array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
**kwargs) -> Array:
# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return torch.clone(x)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.mean(x, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
return torch.mean(x, axis, keepdims=keepdims, **kwargs)
def std(x: Array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
correction: Union[int, float] = 0.0,
keepdims: bool = False,
**kwargs) -> Array:
# Note, float correction is not supported
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
# implement it here for now.
if isinstance(correction, float):
_correction = int(correction)
if correction != _correction:
raise NotImplementedError("float correction in torch std() is not yet supported")
else:
_correction = correction
# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return torch.zeros_like(x)
if isinstance(axis, int):
axis = (axis,)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.std(x, tuple(range(x.ndim)), correction=_correction, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs)
def var(x: Array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
correction: Union[int, float] = 0.0,
keepdims: bool = False,
**kwargs) -> Array:
# Note, float correction is not supported
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
# implement it here for now.
# if isinstance(correction, float):
# correction = int(correction)
# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return torch.zeros_like(x)
if isinstance(axis, int):
axis = (axis,)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.var(x, tuple(range(x.ndim)), correction=correction, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
return torch.var(x, axis, correction=correction, keepdims=keepdims, **kwargs)
# torch.concat doesn't support dim=None
# https://github.com/pytorch/pytorch/issues/70925
def concat(arrays: Union[Tuple[Array, ...], List[Array]],
/,
*,
axis: Optional[int] = 0,
**kwargs) -> Array:
if axis is None:
arrays = tuple(ar.flatten() for ar in arrays)
axis = 0
return torch.concat(arrays, axis, **kwargs)
# torch.squeeze only accepts int dim and doesn't require it
# https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was
# added at https://github.com/pytorch/pytorch/pull/89017.
def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
if isinstance(axis, int):
axis = (axis,)
for a in axis:
if x.shape[a] != 1:
raise ValueError("squeezed dimensions must be equal to 1")
axes = _normalize_axes(axis, x.ndim)
# Remove this once pytorch 1.14 is released with the above PR #89017.
sequence = [a - i for i, a in enumerate(axes)]
for a in sequence:
x = torch.squeeze(x, a)
return x
# torch.broadcast_to uses size instead of shape
def broadcast_to(x: Array, /, shape: Tuple[int, ...], **kwargs) -> Array:
return torch.broadcast_to(x, shape, **kwargs)
# torch.permute uses dims instead of axes
def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
return torch.permute(x, axes)
# The axis parameter doesn't work for flip() and roll()
# https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't
# accept axis=None
def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array:
if axis is None:
axis = tuple(range(x.ndim))
# torch.flip doesn't accept dim as an int but the method does
# https://github.com/pytorch/pytorch/issues/18095
return x.flip(axis, **kwargs)
def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array:
return torch.roll(x, shift, axis, **kwargs)
def nonzero(x: Array, /, **kwargs) -> Tuple[Array, ...]:
if x.ndim == 0:
raise ValueError("nonzero() does not support zero-dimensional arrays")
return torch.nonzero(x, as_tuple=True, **kwargs)
# torch uses `dim` instead of `axis`
def diff(
x: Array,
/,
*,
axis: int = -1,
n: int = 1,
prepend: Optional[Array] = None,
append: Optional[Array] = None,
) -> Array:
return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
# torch uses `dim` instead of `axis`, does not have keepdims
def count_nonzero(
x: Array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
) -> Array:
result = torch.count_nonzero(x, dim=axis)
if keepdims:
if isinstance(axis, int):
return result.unsqueeze(axis)
elif isinstance(axis, tuple):
n_axis = [x.ndim + ax if ax < 0 else ax for ax in axis]
sh = [1 if i in n_axis else x.shape[i] for i in range(x.ndim)]
return torch.reshape(result, sh)
return _axis_none_keepdims(result, x.ndim, keepdims)
else:
return result
# "repeat" is torch.repeat_interleave; also the dim argument
def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array:
return torch.repeat_interleave(x, repeats, axis)
def where(
condition: Array,
x1: Array | bool | int | float | complex,
x2: Array | bool | int | float | complex,
/,
) -> Array:
x1, x2 = _fix_promotion(x1, x2)
return torch.where(condition, x1, x2)
# torch.reshape doesn't have the copy keyword
def reshape(x: Array,
/,
shape: Tuple[int, ...],
*,
copy: Optional[bool] = None,
**kwargs) -> Array:
if copy is not None:
raise NotImplementedError("torch.reshape doesn't yet support the copy keyword")
return torch.reshape(x, shape, **kwargs)
# torch.arange doesn't support returning empty arrays
# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
# keyword argument combinations
# (https://github.com/pytorch/pytorch/issues/70914)
def arange(start: Union[int, float],
/,
stop: Optional[Union[int, float]] = None,
step: Union[int, float] = 1,
*,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
**kwargs) -> Array:
if stop is None:
start, stop = 0, start
if step > 0 and stop <= start or step < 0 and stop >= start:
if dtype is None:
if _builtin_all(isinstance(i, int) for i in [start, stop, step]):
dtype = torch.int64
else:
dtype = torch.float32
return torch.empty(0, dtype=dtype, device=device, **kwargs)
return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs)
# torch.eye does not accept None as a default for the second argument and
# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910)
def eye(n_rows: int,
n_cols: Optional[int] = None,
/,
*,
k: int = 0,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
**kwargs) -> Array:
if n_cols is None:
n_cols = n_rows
z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs)
if abs(k) <= n_rows + n_cols:
z.diagonal(k).fill_(1)
return z
# torch.linspace doesn't have the endpoint parameter
def linspace(start: Union[int, float],
stop: Union[int, float],
/,
num: int,
*,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
endpoint: bool = True,
**kwargs) -> Array:
if not endpoint:
return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1]
return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs)
# torch.full does not accept an int size
# https://github.com/pytorch/pytorch/issues/70906
def full(shape: Union[int, Tuple[int, ...]],
fill_value: bool | int | float | complex,
*,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
**kwargs) -> Array:
if isinstance(shape, int):
shape = (shape,)
return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs)
# ones, zeros, and empty do not accept shape as a keyword argument
def ones(shape: Union[int, Tuple[int, ...]],
*,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
**kwargs) -> Array:
return torch.ones(shape, dtype=dtype, device=device, **kwargs)
def zeros(shape: Union[int, Tuple[int, ...]],
*,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
**kwargs) -> Array:
return torch.zeros(shape, dtype=dtype, device=device, **kwargs)
def empty(shape: Union[int, Tuple[int, ...]],
*,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
**kwargs) -> Array:
return torch.empty(shape, dtype=dtype, device=device, **kwargs)
# tril and triu do not call the keyword argument k
def tril(x: Array, /, *, k: int = 0) -> Array:
return torch.tril(x, k)
def triu(x: Array, /, *, k: int = 0) -> Array:
return torch.triu(x, k)
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
return torch.unsqueeze(x, axis)
def astype(
x: Array,
dtype: DType,
/,
*,
copy: bool = True,
device: Optional[Device] = None,
) -> Array:
if device is not None:
return x.to(device, dtype=dtype, copy=copy)
return x.to(dtype=dtype, copy=copy)
def broadcast_arrays(*arrays: Array) -> List[Array]:
shape = torch.broadcast_shapes(*[a.shape for a in arrays])
return [torch.broadcast_to(a, shape) for a in arrays]
# Note that these named tuples aren't actually part of the standard namespace,
# but I don't see any issue with exporting the names here regardless.
from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
UniqueInverseResult)
# https://github.com/pytorch/pytorch/issues/70920
def unique_all(x: Array) -> UniqueAllResult:
# torch.unique doesn't support returning indices.
# https://github.com/pytorch/pytorch/issues/36748. The workaround
# suggested in that issue doesn't actually function correctly (it relies
# on non-deterministic behavior of scatter()).
raise NotImplementedError("unique_all() not yet implemented for pytorch (see https://github.com/pytorch/pytorch/issues/36748)")
# values, inverse_indices, counts = torch.unique(x, return_counts=True, return_inverse=True)
# # torch.unique incorrectly gives a 0 count for nan values.
# # https://github.com/pytorch/pytorch/issues/94106
# counts[torch.isnan(values)] = 1
# return UniqueAllResult(values, indices, inverse_indices, counts)
def unique_counts(x: Array) -> UniqueCountsResult:
values, counts = torch.unique(x, return_counts=True)
# torch.unique incorrectly gives a 0 count for nan values.
# https://github.com/pytorch/pytorch/issues/94106
counts[torch.isnan(values)] = 1
return UniqueCountsResult(values, counts)
def unique_inverse(x: Array) -> UniqueInverseResult:
values, inverse = torch.unique(x, return_inverse=True)
return UniqueInverseResult(values, inverse)
def unique_values(x: Array) -> Array:
return torch.unique(x)
def matmul(x1: Array, x2: Array, /, **kwargs) -> Array:
# torch.matmul doesn't type promote (but differently from _fix_promotion)
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch.matmul(x1, x2, **kwargs)
matrix_transpose = get_xp(torch)(_aliases.matrix_transpose)
_vecdot = get_xp(torch)(_aliases.vecdot)
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return _vecdot(x1, x2, axis=axis)
# torch.tensordot uses dims instead of axes
def tensordot(
x1: Array,
x2: Array,
/,
*,
axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
**kwargs,
) -> Array:
# Note: torch.tensordot fails with integer dtypes when there is only 1
# element in the axis (https://github.com/pytorch/pytorch/issues/84530).
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch.tensordot(x1, x2, dims=axes, **kwargs)
def isdtype(
dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]],
*, _tuple=True, # Disallow nested tuples
) -> bool:
"""
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
Note that outside of this function, this compat library does not yet fully
support complex numbers.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
for more details
"""
if isinstance(kind, tuple) and _tuple:
return _builtin_any(isdtype(dtype, k, _tuple=False) for k in kind)
elif isinstance(kind, str):
if kind == 'bool':
return dtype == torch.bool
elif kind == 'signed integer':
return dtype in _int_dtypes and dtype.is_signed
elif kind == 'unsigned integer':
return dtype in _int_dtypes and not dtype.is_signed
elif kind == 'integral':
return dtype in _int_dtypes
elif kind == 'real floating':
return dtype.is_floating_point
elif kind == 'complex floating':
return dtype.is_complex
elif kind == 'numeric':
return isdtype(dtype, ('integral', 'real floating', 'complex floating'))
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
else:
return dtype == kind
def take(x: Array, indices: Array, /, *, axis: Optional[int] = None, **kwargs) -> Array:
if axis is None:
if x.ndim != 1:
raise ValueError("axis must be specified when ndim > 1")
axis = 0
return torch.index_select(x, axis, indices, **kwargs)
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
return torch.take_along_dim(x, indices, dim=axis)
def sign(x: Array, /) -> Array:
# torch sign() does not support complex numbers and does not propagate
# nans. See https://github.com/data-apis/array-api-compat/issues/136
if x.dtype.is_complex:
out = x/torch.abs(x)
# sign(0) = 0 but the above formula would give nan
out[x == 0+0j] = 0+0j
return out
else:
out = torch.sign(x)
if x.dtype.is_floating_point:
out[torch.isnan(x)] = torch.nan
return out
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array]:
# enforce the default of 'xy'
# TODO: is the return type a list or a tuple
return list(torch.meshgrid(*arrays, indexing='xy'))
__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
'diff', 'divide',
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',
'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum',
'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid']
_all_ignore = ['torch', 'get_xp']

View File

@@ -0,0 +1,369 @@
"""
Array API Inspection namespace
This is the namespace for inspection functions as defined by the array API
standard. See
https://data-apis.org/array-api/latest/API_specification/inspection.html for
more details.
"""
import torch
from functools import cache
class __array_namespace_info__:
"""
Get the array API inspection namespace for PyTorch.
The array API inspection namespace defines the following functions:
- capabilities()
- default_device()
- default_dtypes()
- dtypes()
- devices()
See
https://data-apis.org/array-api/latest/API_specification/inspection.html
for more details.
Returns
-------
info : ModuleType
The array API inspection namespace for PyTorch.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': numpy.float64,
'complex floating': numpy.complex128,
'integral': numpy.int64,
'indexing': numpy.int64}
"""
__module__ = 'torch'
def capabilities(self):
"""
Return a dictionary of array API library capabilities.
The resulting dictionary has the following keys:
- **"boolean indexing"**: boolean indicating whether an array library
supports boolean indexing. Always ``True`` for PyTorch.
- **"data-dependent shapes"**: boolean indicating whether an array
library supports data-dependent output shapes. Always ``True`` for
PyTorch.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
for more details.
See Also
--------
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
capabilities : dict
A dictionary of array API library capabilities.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.capabilities()
{'boolean indexing': True,
'data-dependent shapes': True,
'max dimensions': 64}
"""
return {
"boolean indexing": True,
"data-dependent shapes": True,
"max dimensions": 64,
}
def default_device(self):
"""
The default device used for new PyTorch arrays.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
device : Device
The default device used for new PyTorch arrays.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_device()
device(type='cpu')
Notes
-----
This method returns the static default device when PyTorch is initialized.
However, the *current* device used by creation functions (``empty`` etc.)
can be changed at runtime.
See Also
--------
https://github.com/data-apis/array-api/issues/835
"""
return torch.device("cpu")
def default_dtypes(self, *, device=None):
"""
The default data types used for new PyTorch arrays.
Parameters
----------
device : Device, optional
The device to get the default data types for.
Unused for PyTorch, as all devices use the same default dtypes.
Returns
-------
dtypes : dict
A dictionary describing the default data types used for new PyTorch
arrays.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': torch.float32,
'complex floating': torch.complex64,
'integral': torch.int64,
'indexing': torch.int64}
"""
# Note: if the default is set to float64, the devices like MPS that
# don't support float64 will error. We still return the default_dtype
# value here because this error doesn't represent a different default
# per-device.
default_floating = torch.get_default_dtype()
default_complex = torch.complex64 if default_floating == torch.float32 else torch.complex128
default_integral = torch.int64
return {
"real floating": default_floating,
"complex floating": default_complex,
"integral": default_integral,
"indexing": default_integral,
}
def _dtypes(self, kind):
bool = torch.bool
int8 = torch.int8
int16 = torch.int16
int32 = torch.int32
int64 = torch.int64
uint8 = torch.uint8
# uint16, uint32, and uint64 are present in newer versions of pytorch,
# but they aren't generally supported by the array API functions, so
# we omit them from this function.
float32 = torch.float32
float64 = torch.float64
complex64 = torch.complex64
complex128 = torch.complex128
if kind is None:
return {
"bool": bool,
"int8": int8,
"int16": int16,
"int32": int32,
"int64": int64,
"uint8": uint8,
"float32": float32,
"float64": float64,
"complex64": complex64,
"complex128": complex128,
}
if kind == "bool":
return {"bool": bool}
if kind == "signed integer":
return {
"int8": int8,
"int16": int16,
"int32": int32,
"int64": int64,
}
if kind == "unsigned integer":
return {
"uint8": uint8,
}
if kind == "integral":
return {
"int8": int8,
"int16": int16,
"int32": int32,
"int64": int64,
"uint8": uint8,
}
if kind == "real floating":
return {
"float32": float32,
"float64": float64,
}
if kind == "complex floating":
return {
"complex64": complex64,
"complex128": complex128,
}
if kind == "numeric":
return {
"int8": int8,
"int16": int16,
"int32": int32,
"int64": int64,
"uint8": uint8,
"float32": float32,
"float64": float64,
"complex64": complex64,
"complex128": complex128,
}
if isinstance(kind, tuple):
res = {}
for k in kind:
res.update(self.dtypes(kind=k))
return res
raise ValueError(f"unsupported kind: {kind!r}")
@cache
def dtypes(self, *, device=None, kind=None):
"""
The array API data types supported by PyTorch.
Note that this function only returns data types that are defined by
the array API.
Parameters
----------
device : Device, optional
The device to get the data types for.
Unused for PyTorch, as all devices use the same dtypes.
kind : str or tuple of str, optional
The kind of data types to return. If ``None``, all data types are
returned. If a string, only data types of that kind are returned.
If a tuple, a dictionary containing the union of the given kinds
is returned. The following kinds are supported:
- ``'bool'``: boolean data types (i.e., ``bool``).
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
``int16``, ``int32``, ``int64``).
- ``'unsigned integer'``: unsigned integer data types (i.e.,
``uint8``, ``uint16``, ``uint32``, ``uint64``).
- ``'integral'``: integer data types. Shorthand for ``('signed
integer', 'unsigned integer')``.
- ``'real floating'``: real-valued floating-point data types
(i.e., ``float32``, ``float64``).
- ``'complex floating'``: complex floating-point data types (i.e.,
``complex64``, ``complex128``).
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
'real floating', 'complex floating')``.
Returns
-------
dtypes : dict
A dictionary mapping the names of data types to the corresponding
PyTorch data types.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.dtypes(kind='signed integer')
{'int8': numpy.int8,
'int16': numpy.int16,
'int32': numpy.int32,
'int64': numpy.int64}
"""
res = self._dtypes(kind)
for k, v in res.copy().items():
try:
torch.empty((0,), dtype=v, device=device)
except:
del res[k]
return res
@cache
def devices(self):
"""
The devices supported by PyTorch.
Returns
-------
devices : list[Device]
The devices supported by PyTorch.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.devices()
[device(type='cpu'), device(type='mps', index=0), device(type='meta')]
"""
# Torch doesn't have a straightforward way to get the list of all
# currently supported devices. To do this, we first parse the error
# message of torch.device to get the list of all possible types of
# device:
try:
torch.device('notadevice')
raise AssertionError("unreachable") # pragma: nocover
except RuntimeError as e:
# The error message is something like:
# "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice"
devices_names = e.args[0].split('Expected one of ')[1].split(' device type')[0].split(', ')
# Next we need to check for different indices for different devices.
# device(device_name, index=index) doesn't actually check if the
# device name or index is valid. We have to try to create a tensor
# with it (which is why this function is cached).
devices = []
for device_name in devices_names:
i = 0
while True:
try:
a = torch.empty((0,), device=torch.device(device_name, index=i))
if a.device in devices:
break
devices.append(a.device)
except:
break
i += 1
return devices

View File

@@ -0,0 +1,3 @@
__all__ = ["Array", "Device", "DType"]
from torch import device as Device, dtype as DType, Tensor as Array

View File

@@ -0,0 +1,85 @@
from __future__ import annotations
from typing import Union, Sequence, Literal
import torch
import torch.fft
from torch.fft import * # noqa: F403
from ._typing import Array
# Several torch fft functions do not map axes to dim
def fftn(
x: Array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> Array:
return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
def ifftn(
x: Array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> Array:
return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
def rfftn(
x: Array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> Array:
return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
def irfftn(
x: Array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> Array:
return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
def fftshift(
x: Array,
/,
*,
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> Array:
return torch.fft.fftshift(x, dim=axes, **kwargs)
def ifftshift(
x: Array,
/,
*,
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> Array:
return torch.fft.ifftshift(x, dim=axes, **kwargs)
__all__ = torch.fft.__all__ + [
"fftn",
"ifftn",
"rfftn",
"irfftn",
"fftshift",
"ifftshift",
]
_all_ignore = ['torch']

View File

@@ -0,0 +1,121 @@
from __future__ import annotations
import torch
from typing import Optional, Union, Tuple
from torch.linalg import * # noqa: F403
# torch.linalg doesn't define __all__
# from torch.linalg import __all__ as linalg_all
from torch import linalg as torch_linalg
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
# outer is implemented in torch but aren't in the linalg namespace
from torch import outer
from ._aliases import _fix_promotion, sum
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot
from ._typing import Array, DType
from ..common._typing import JustInt, JustFloat
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
# torch.cross also does not support broadcasting when it would add new
# dimensions https://github.com/pytorch/pytorch/issues/39656
def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
if not (x1.shape[axis] == x2.shape[axis] == 3):
raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
x1, x2 = torch.broadcast_tensors(x1, x2)
return torch_linalg.cross(x1, x2, dim=axis)
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array:
from ._aliases import isdtype
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
# torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")
# torch.linalg.vecdot doesn't support integer dtypes
if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
if kwargs:
raise RuntimeError("vecdot kwargs not supported for integral dtypes")
x1_ = torch.moveaxis(x1, axis, -1)
x2_ = torch.moveaxis(x2, axis, -1)
x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
def solve(x1: Array, x2: Array, /, **kwargs) -> Array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
# Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
# whenever
# 1. x1.ndim - 1 == x2.ndim
# 2. x1.shape[:-1] == x2.shape
#
# See linalg_solve_is_vector_rhs in
# aten/src/ATen/native/LinearAlgebraUtils.h and
# TORCH_META_FUNC(_linalg_solve_ex) in
# aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
#
# The easiest way to work around this is to prepend a size 1 dimension to
# x2, since x2 is already one dimension less than x1.
#
# See https://github.com/pytorch/pytorch/issues/52915
if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
x2 = x2[None]
return torch.linalg.solve(x1, x2, **kwargs)
# torch.trace doesn't support the offset argument and doesn't support stacking
def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array:
# Use our wrapped sum to make sure it does upcasting correctly
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
def vector_norm(
x: Array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
# JustFloat stands for inf | -inf, which are not valid for Literal
ord: JustInt | JustFloat = 2,
**kwargs,
) -> Array:
# torch.vector_norm incorrectly treats axis=() the same as axis=None
if axis == ():
out = kwargs.get('out')
if out is None:
dtype = None
if x.dtype == torch.complex64:
dtype = torch.float32
elif x.dtype == torch.complex128:
dtype = torch.float64
out = torch.zeros_like(x, dtype=dtype)
# The norm of a single scalar works out to abs(x) in every case except
# for ord=0, which is x != 0.
if ord == 0:
out[:] = (x != 0)
else:
out[:] = torch.abs(x)
return out
return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
_all_ignore = ['torch_linalg', 'sum']
del linalg_all
def __dir__() -> list[str]:
return __all__

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Consortium for Python Data API Standards
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1 @@
Update this directory using maint_tools/vendor_array_api_extra.sh

View File

@@ -0,0 +1,42 @@
"""Extra array functions built on top of the array API standard."""
from ._delegation import isclose, nan_to_num, one_hot, pad
from ._lib._at import at
from ._lib._funcs import (
apply_where,
atleast_nd,
broadcast_shapes,
cov,
create_diagonal,
default_dtype,
expand_dims,
kron,
nunique,
setdiff1d,
sinc,
)
from ._lib._lazy import lazy_apply
__version__ = "0.8.2"
# pylint: disable=duplicate-code
__all__ = [
"__version__",
"apply_where",
"at",
"atleast_nd",
"broadcast_shapes",
"cov",
"create_diagonal",
"default_dtype",
"expand_dims",
"isclose",
"kron",
"lazy_apply",
"nan_to_num",
"nunique",
"one_hot",
"pad",
"setdiff1d",
"sinc",
]

View File

@@ -0,0 +1,328 @@
"""Delegation to existing implementations for Public API Functions."""
from collections.abc import Sequence
from types import ModuleType
from typing import Literal
from ._lib import _funcs
from ._lib._utils._compat import (
array_namespace,
is_cupy_namespace,
is_dask_namespace,
is_jax_namespace,
is_numpy_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
)
from ._lib._utils._compat import device as get_device
from ._lib._utils._helpers import asarrays
from ._lib._utils._typing import Array, DType
__all__ = ["isclose", "nan_to_num", "one_hot", "pad"]
def isclose(
a: Array | complex,
b: Array | complex,
*,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
xp: ModuleType | None = None,
) -> Array:
"""
Return a boolean array where two arrays are element-wise equal within a tolerance.
The tolerance values are positive, typically very small numbers. The relative
difference ``(rtol * abs(b))`` and the absolute difference `atol` are added together
to compare against the absolute difference between `a` and `b`.
NaNs are treated as equal if they are in the same place and if ``equal_nan=True``.
Infs are treated as equal if they are in the same place and of the same sign in both
arrays.
Parameters
----------
a, b : Array | int | float | complex | bool
Input objects to compare. At least one must be an array.
rtol : array_like, optional
The relative tolerance parameter (see Notes).
atol : array_like, optional
The absolute tolerance parameter (see Notes).
equal_nan : bool, optional
Whether to compare NaN's as equal. If True, NaN's in `a` will be considered
equal to NaN's in `b` in the output array.
xp : array_namespace, optional
The standard-compatible namespace for `a` and `b`. Default: infer.
Returns
-------
Array
A boolean array of shape broadcasted from `a` and `b`, containing ``True`` where
`a` is close to `b`, and ``False`` otherwise.
Warnings
--------
The default `atol` is not appropriate for comparing numbers with magnitudes much
smaller than one (see notes).
See Also
--------
math.isclose : Similar function in stdlib for Python scalars.
Notes
-----
For finite values, `isclose` uses the following equation to test whether two
floating point values are equivalent::
absolute(a - b) <= (atol + rtol * absolute(b))
Unlike the built-in `math.isclose`,
the above equation is not symmetric in `a` and `b`,
so that ``isclose(a, b)`` might be different from ``isclose(b, a)`` in some rare
cases.
The default value of `atol` is not appropriate when the reference value `b` has
magnitude smaller than one. For example, it is unlikely that ``a = 1e-9`` and
``b = 2e-9`` should be considered "close", yet ``isclose(1e-9, 2e-9)`` is ``True``
with default settings. Be sure to select `atol` for the use case at hand, especially
for defining the threshold below which a non-zero value in `a` will be considered
"close" to a very small or zero value in `b`.
The comparison of `a` and `b` uses standard broadcasting, which means that `a` and
`b` need not have the same shape in order for ``isclose(a, b)`` to evaluate to
``True``.
`isclose` is not defined for non-numeric data types.
``bool`` is considered a numeric data-type for this purpose.
"""
xp = array_namespace(a, b) if xp is None else xp
if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_dask_namespace(xp)
or is_jax_namespace(xp)
):
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
if is_torch_namespace(xp):
a, b = asarrays(a, b, xp=xp) # Array API 2024.12 support
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
def nan_to_num(
x: Array | float | complex,
/,
*,
fill_value: int | float = 0.0,
xp: ModuleType | None = None,
) -> Array:
"""
Replace NaN with zero and infinity with large finite numbers (default behaviour).
If `x` is inexact, NaN is replaced by zero or by the user defined value in the
`fill_value` keyword, infinity is replaced by the largest finite floating
point value representable by ``x.dtype``, and -infinity is replaced by the
most negative finite floating point value representable by ``x.dtype``.
For complex dtypes, the above is applied to each of the real and
imaginary components of `x` separately.
Parameters
----------
x : array | float | complex
Input data.
fill_value : int | float, optional
Value to be used to fill NaN values. If no value is passed
then NaN values will be replaced with 0.0.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
Returns
-------
array
`x`, with the non-finite values replaced.
See Also
--------
array_api.isnan : Shows which elements are Not a Number (NaN).
Examples
--------
>>> import array_api_extra as xpx
>>> import array_api_strict as xp
>>> xpx.nan_to_num(xp.inf)
1.7976931348623157e+308
>>> xpx.nan_to_num(-xp.inf)
-1.7976931348623157e+308
>>> xpx.nan_to_num(xp.nan)
0.0
>>> x = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128])
>>> xpx.nan_to_num(x)
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary
-1.28000000e+002, 1.28000000e+002])
>>> y = xp.asarray([complex(xp.inf, xp.nan), xp.nan, complex(xp.nan, xp.inf)])
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary
-1.28000000e+002, 1.28000000e+002])
>>> xpx.nan_to_num(y)
array([ 1.79769313e+308 +0.00000000e+000j, # may vary
0.00000000e+000 +0.00000000e+000j,
0.00000000e+000 +1.79769313e+308j])
"""
if isinstance(fill_value, complex):
msg = "Complex fill values are not supported."
raise TypeError(msg)
xp = array_namespace(x) if xp is None else xp
# for scalars we want to output an array
y = xp.asarray(x)
if (
is_cupy_namespace(xp)
or is_jax_namespace(xp)
or is_numpy_namespace(xp)
or is_torch_namespace(xp)
):
return xp.nan_to_num(y, nan=fill_value)
return _funcs.nan_to_num(y, fill_value=fill_value, xp=xp)
def one_hot(
x: Array,
/,
num_classes: int,
*,
dtype: DType | None = None,
axis: int = -1,
xp: ModuleType | None = None,
) -> Array:
"""
One-hot encode the given indices.
Each index in the input `x` is encoded as a vector of zeros of length `num_classes`
with the element at the given index set to one.
Parameters
----------
x : array
An array with integral dtype whose values are between `0` and `num_classes - 1`.
num_classes : int
Number of classes in the one-hot dimension.
dtype : DType, optional
The dtype of the return value. Defaults to the default float dtype (usually
float64).
axis : int, optional
Position in the expanded axes where the new axis is placed. Default: -1.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
Returns
-------
array
An array having the same shape as `x` except for a new axis at the position
given by `axis` having size `num_classes`. If `axis` is unspecified, it
defaults to -1, which appends a new axis.
If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
an exception, or may even cause a bad state. `x` is not checked.
Examples
--------
>>> import array_api_extra as xpx
>>> import array_api_strict as xp
>>> xpx.one_hot(xp.asarray([1, 2, 0]), 3)
Array([[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.]], dtype=array_api_strict.float64)
"""
# Validate inputs.
if xp is None:
xp = array_namespace(x)
if not xp.isdtype(x.dtype, "integral"):
msg = "x must have an integral dtype."
raise TypeError(msg)
if dtype is None:
dtype = _funcs.default_dtype(xp, device=get_device(x))
# Delegate where possible.
if is_jax_namespace(xp):
from jax.nn import one_hot as jax_one_hot
return jax_one_hot(x, num_classes, dtype=dtype, axis=axis)
if is_torch_namespace(xp):
from torch.nn.functional import one_hot as torch_one_hot
x = xp.astype(x, xp.int64) # PyTorch only supports int64 here.
try:
out = torch_one_hot(x, num_classes)
except RuntimeError as e:
raise IndexError from e
else:
out = _funcs.one_hot(x, num_classes, xp=xp)
out = xp.astype(out, dtype, copy=False)
if axis != -1:
out = xp.moveaxis(out, -1, axis)
return out
def pad(
x: Array,
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
mode: Literal["constant"] = "constant",
*,
constant_values: complex = 0,
xp: ModuleType | None = None,
) -> Array:
"""
Pad the input array.
Parameters
----------
x : array
Input array.
pad_width : int or tuple of ints or sequence of pairs of ints
Pad the input array with this many elements from each side.
If a sequence of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
each pair applies to the corresponding axis of ``x``.
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
copies of this tuple.
mode : str, optional
Only "constant" mode is currently supported, which pads with
the value passed to `constant_values`.
constant_values : python scalar, optional
Use this value to pad the input. Default is zero.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
Returns
-------
array
The input array,
padded with ``pad_width`` elements equal to ``constant_values``.
"""
xp = array_namespace(x) if xp is None else xp
if mode != "constant":
msg = "Only `'constant'` mode is currently supported"
raise NotImplementedError(msg)
if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_jax_namespace(xp)
or is_pydata_sparse_namespace(xp)
):
return xp.pad(x, pad_width, mode, constant_values=constant_values)
# https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
if is_torch_namespace(xp):
pad_width = xp.asarray(pad_width)
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)

View File

@@ -0,0 +1 @@
"""Internals of array-api-extra."""

Some files were not shown because too many files have changed in this diff Show More