2026-04-10 15:06:59 +02:00
parent 3031b7153b
commit e5a4711004
7806 changed files with 1918528 additions and 335 deletions

View File

@@ -0,0 +1,34 @@
""" Numba's POWER ON SELF TEST script. Used by CI to check:
0. That Numba imports ok!
1. That Numba can find an appropriate number of its own tests to run.
2. That Numba can manage to correctly compile and execute at least one thing.
"""
from numba.tests import test_runtests
from numba import njit
def _check_runtests():
test_inst = test_runtests.TestCase()
test_inst.test_default() # will raise an exception if there is a problem
def _check_cpu_compilation():
@njit
def foo(x):
return x + 1
result = foo(1)
if result != 2:
msg = ("Unexpected result from trial compilation. "
f"Expected: 2, Got: {result}.")
raise AssertionError(msg)
def check():
_check_runtests()
_check_cpu_compilation()
if __name__ == "__main__":
check()

View File

@@ -0,0 +1,551 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2005-2010 ActiveState Software Inc.
# Copyright (c) 2013 Eddy Petrișor
"""Utilities for determining application-specific dirs.
See <http://github.com/ActiveState/appdirs> for details and usage.
"""
# Dev Notes:
# - MSDN on where to store app data files:
# http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
# - XDG spec for Un*x: http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
__version_info__ = (1, 4, 1)
__version__ = '.'.join(map(str, __version_info__))
import sys
import os
unicode = str
if sys.platform.startswith('java'):
import platform
os_name = platform.java_ver()[3][0]
if os_name.startswith('Windows'): # "Windows XP", "Windows 7", etc.
system = 'win32'
elif os_name.startswith('Mac'): # "Mac OS X", etc.
system = 'darwin'
else: # "Linux", "SunOS", "FreeBSD", etc.
# Setting this to "linux2" is not ideal, but only Windows or Mac
# are actually checked for and the rest of the module expects
# *sys.platform* style strings.
system = 'linux2'
else:
system = sys.platform
def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
r"""Return full path to the user-specific data dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"roaming" (boolean, default False) can be set True to use the Windows
roaming appdata directory. That means that for users on a Windows
network setup for roaming profiles, this user data will be
sync'd on login. See
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
for a discussion of issues.
Typical user data directories are:
Mac OS X: ~/Library/Application Support/<AppName>
Unix: ~/.local/share/<AppName> # or in $XDG_DATA_HOME, if defined
Win XP (not roaming): C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
Win XP (roaming): C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>
Win 7 (not roaming): C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
Win 7 (roaming): C:\Users\<username>\AppData\Roaming\<AppAuthor>\<AppName>
For Unix, we follow the XDG spec and support $XDG_DATA_HOME.
That means, by default "~/.local/share/<AppName>".
"""
if system == "win32":
if appauthor is None:
appauthor = appname
        const = "CSIDL_APPDATA" if roaming else "CSIDL_LOCAL_APPDATA"
path = os.path.normpath(_get_win_folder(const))
if appname:
if appauthor is not False:
path = os.path.join(path, appauthor, appname)
else:
path = os.path.join(path, appname)
elif system == 'darwin':
path = os.path.expanduser('~/Library/Application Support/')
if appname:
path = os.path.join(path, appname)
else:
path = os.getenv('XDG_DATA_HOME', os.path.expanduser("~/.local/share"))
if appname:
path = os.path.join(path, appname)
if appname and version:
path = os.path.join(path, version)
return path
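# Hedged usage sketch (values illustrative): resolving a per-user data
# directory for an app named "SuperApp" by "Acme".
#
#   user_data_dir("SuperApp", "Acme", version="1.0")
#   # Linux (XDG default): /home/<user>/.local/share/SuperApp/1.0
#   # Win 7 (not roaming): C:\Users\<user>\AppData\Local\Acme\SuperApp\1.0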
def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
r"""Return full path to the user-shared data dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"multipath" is an optional parameter only applicable to *nix
which indicates that the entire list of data dirs should be
returned. By default, the first item from XDG_DATA_DIRS is
returned, or '/usr/local/share/<AppName>',
if XDG_DATA_DIRS is not set
    Typical site data directories are:
Mac OS X: /Library/Application Support/<AppName>
Unix: /usr/local/share/<AppName> or /usr/share/<AppName>
Win XP: C:\Documents and Settings\All Users\Application Data\<AppAuthor>\<AppName>
Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
Win 7: C:\ProgramData\<AppAuthor>\<AppName> # Hidden, but writeable on Win 7.
For Unix, this is using the $XDG_DATA_DIRS[0] default.
WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
"""
if system == "win32":
if appauthor is None:
appauthor = appname
path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA"))
if appname:
if appauthor is not False:
path = os.path.join(path, appauthor, appname)
else:
path = os.path.join(path, appname)
elif system == 'darwin':
path = os.path.expanduser('/Library/Application Support')
if appname:
path = os.path.join(path, appname)
else:
# XDG default for $XDG_DATA_DIRS
# only first, if multipath is False
path = os.getenv('XDG_DATA_DIRS',
os.pathsep.join(['/usr/local/share', '/usr/share']))
pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)]
if appname:
if version:
appname = os.path.join(appname, version)
pathlist = [os.sep.join([x, appname]) for x in pathlist]
if multipath:
path = os.pathsep.join(pathlist)
else:
path = pathlist[0]
return path
if appname and version:
path = os.path.join(path, version)
return path
def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
r"""Return full path to the user-specific config dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"roaming" (boolean, default False) can be set True to use the Windows
roaming appdata directory. That means that for users on a Windows
network setup for roaming profiles, this user data will be
sync'd on login. See
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
for a discussion of issues.
    Typical user config directories are:
Mac OS X: same as user_data_dir
Unix: ~/.config/<AppName> # or in $XDG_CONFIG_HOME, if defined
Win *: same as user_data_dir
For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME.
That means, by default "~/.config/<AppName>".
"""
if system in ["win32", "darwin"]:
path = user_data_dir(appname, appauthor, None, roaming)
else:
path = os.getenv('XDG_CONFIG_HOME', os.path.expanduser("~/.config"))
if appname:
path = os.path.join(path, appname)
if appname and version:
path = os.path.join(path, version)
return path
def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
r"""Return full path to the user-shared data dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"multipath" is an optional parameter only applicable to *nix
which indicates that the entire list of config dirs should be
returned. By default, the first item from XDG_CONFIG_DIRS is
returned, or '/etc/xdg/<AppName>', if XDG_CONFIG_DIRS is not set
    Typical site config directories are:
Mac OS X: same as site_data_dir
Unix: /etc/xdg/<AppName> or $XDG_CONFIG_DIRS[i]/<AppName> for each value in
$XDG_CONFIG_DIRS
Win *: same as site_data_dir
Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False
WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
"""
if system in ["win32", "darwin"]:
path = site_data_dir(appname, appauthor)
if appname and version:
path = os.path.join(path, version)
else:
# XDG default for $XDG_CONFIG_DIRS
# only first, if multipath is False
path = os.getenv('XDG_CONFIG_DIRS', '/etc/xdg')
pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)]
if appname:
if version:
appname = os.path.join(appname, version)
pathlist = [os.sep.join([x, appname]) for x in pathlist]
if multipath:
path = os.pathsep.join(pathlist)
else:
path = pathlist[0]
return path
def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
r"""Return full path to the user-specific cache dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"opinion" (boolean) can be False to disable the appending of
"Cache" to the base app data dir for Windows. See
discussion below.
Typical user cache directories are:
Mac OS X: ~/Library/Caches/<AppName>
Unix: ~/.cache/<AppName> (XDG default)
Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Cache
Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache
On Windows the only suggestion in the MSDN docs is that local settings go in
the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming
app data dir (the default returned by `user_data_dir` above). Apps typically
put cache data somewhere *under* the given dir here. Some examples:
...\Mozilla\Firefox\Profiles\<ProfileName>\Cache
...\Acme\SuperApp\Cache\1.0
OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value.
This can be disabled with the `opinion=False` option.
"""
if system == "win32":
if appauthor is None:
appauthor = appname
path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
if appname:
if appauthor is not False:
path = os.path.join(path, appauthor, appname)
else:
path = os.path.join(path, appname)
if opinion:
path = os.path.join(path, "Cache")
elif system == 'darwin':
path = os.path.expanduser('~/Library/Caches')
if appname:
path = os.path.join(path, appname)
else:
path = os.getenv('XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
if appname:
path = os.path.join(path, appname)
if appname and version:
path = os.path.join(path, version)
return path
def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
r"""Return full path to the user-specific log dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"opinion" (boolean) can be False to disable the appending of
"Logs" to the base app data dir for Windows, and "log" to the
base cache dir for Unix. See discussion below.
    Typical user log directories are:
Mac OS X: ~/Library/Logs/<AppName>
Unix: ~/.cache/<AppName>/log # or under $XDG_CACHE_HOME if defined
Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Logs
Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Logs
On Windows the only suggestion in the MSDN docs is that local settings
go in the `CSIDL_LOCAL_APPDATA` directory. (Note: I'm interested in
examples of what some windows apps use for a logs dir.)
OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA`
value for Windows and appends "log" to the user cache dir for Unix.
This can be disabled with the `opinion=False` option.
"""
if system == "darwin":
path = os.path.join(
os.path.expanduser('~/Library/Logs'),
appname)
elif system == "win32":
path = user_data_dir(appname, appauthor, version)
version = False
if opinion:
path = os.path.join(path, "Logs")
else:
path = user_cache_dir(appname, appauthor, version)
version = False
if opinion:
path = os.path.join(path, "log")
if appname and version:
path = os.path.join(path, version)
return path
class AppDirs(object):
"""Convenience wrapper for getting application dirs."""
def __init__(self, appname, appauthor=None, version=None, roaming=False,
multipath=False):
self.appname = appname
self.appauthor = appauthor
self.version = version
self.roaming = roaming
self.multipath = multipath
@property
def user_data_dir(self):
return user_data_dir(self.appname, self.appauthor,
version=self.version, roaming=self.roaming)
@property
def site_data_dir(self):
return site_data_dir(self.appname, self.appauthor,
version=self.version, multipath=self.multipath)
@property
def user_config_dir(self):
return user_config_dir(self.appname, self.appauthor,
version=self.version, roaming=self.roaming)
@property
def site_config_dir(self):
return site_config_dir(self.appname, self.appauthor,
version=self.version, multipath=self.multipath)
@property
def user_cache_dir(self):
return user_cache_dir(self.appname, self.appauthor,
version=self.version)
@property
def user_log_dir(self):
return user_log_dir(self.appname, self.appauthor,
version=self.version)
#---- internal support stuff
def _get_win_folder_from_registry(csidl_name):
"""This is a fallback technique at best. I'm not sure if using the
registry for this guarantees us the correct answer for all CSIDL_*
names.
"""
import winreg as _winreg
shell_folder_name = {
"CSIDL_APPDATA": "AppData",
"CSIDL_COMMON_APPDATA": "Common AppData",
"CSIDL_LOCAL_APPDATA": "Local AppData",
}[csidl_name]
key = _winreg.OpenKey(
_winreg.HKEY_CURRENT_USER,
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
)
dir, type = _winreg.QueryValueEx(key, shell_folder_name)
return dir
def _get_win_folder_with_pywin32(csidl_name):
from win32com.shell import shellcon, shell
dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0)
# Try to make this a unicode path because SHGetFolderPath does
# not return unicode strings when there is unicode data in the
# path.
try:
dir = unicode(dir)
# Downgrade to short path name if have highbit chars. See
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
has_high_char = False
for c in dir:
if ord(c) > 255:
has_high_char = True
break
if has_high_char:
try:
import win32api
dir = win32api.GetShortPathName(dir)
except ImportError:
pass
except UnicodeError:
pass
return dir
def _get_win_folder_with_ctypes(csidl_name):
import ctypes
csidl_const = {
"CSIDL_APPDATA": 26,
"CSIDL_COMMON_APPDATA": 35,
"CSIDL_LOCAL_APPDATA": 28,
}[csidl_name]
buf = ctypes.create_unicode_buffer(1024)
ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)
# Downgrade to short path name if have highbit chars. See
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
has_high_char = False
for c in buf:
if ord(c) > 255:
has_high_char = True
break
if has_high_char:
buf2 = ctypes.create_unicode_buffer(1024)
if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024):
buf = buf2
return buf.value
def _get_win_folder_with_jna(csidl_name):
import array
from com.sun import jna
from com.sun.jna.platform import win32
buf_size = win32.WinDef.MAX_PATH * 2
buf = array.zeros('c', buf_size)
shell = win32.Shell32.INSTANCE
shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf)
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
# Downgrade to short path name if have highbit chars. See
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
has_high_char = False
for c in dir:
if ord(c) > 255:
has_high_char = True
break
if has_high_char:
buf = array.zeros('c', buf_size)
kernel = win32.Kernel32.INSTANCE
if kernel.GetShortPathName(dir, buf, buf_size):
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
return dir
if system == "win32":
try:
import win32com.shell
_get_win_folder = _get_win_folder_with_pywin32
except ImportError:
try:
from ctypes import windll
_get_win_folder = _get_win_folder_with_ctypes
except ImportError:
try:
import com.sun.jna
_get_win_folder = _get_win_folder_with_jna
except ImportError:
_get_win_folder = _get_win_folder_from_registry
#---- self test code
if __name__ == "__main__":
appname = "MyApp"
appauthor = "MyCompany"
props = ("user_data_dir", "site_data_dir",
"user_config_dir", "site_config_dir",
"user_cache_dir", "user_log_dir")
print("-- app dirs %s --" % __version__)
print("-- app dirs (with optional 'version')")
dirs = AppDirs(appname, appauthor, version="1.0")
for prop in props:
print("%s: %s" % (prop, getattr(dirs, prop)))
print("\n-- app dirs (without optional 'version')")
dirs = AppDirs(appname, appauthor)
for prop in props:
print("%s: %s" % (prop, getattr(dirs, prop)))
print("\n-- app dirs (without optional 'appauthor')")
dirs = AppDirs(appname)
for prop in props:
print("%s: %s" % (prop, getattr(dirs, prop)))
print("\n-- app dirs (with disabled 'appauthor')")
dirs = AppDirs(appname, appauthor=False)
for prop in props:
print("%s: %s" % (prop, getattr(dirs, prop)))

View File

@@ -0,0 +1,22 @@
"""
Implementation of some CFFI functions
"""
from numba.core.imputils import Registry
from numba.core import types
from numba.np import arrayobj
registry = Registry('cffiimpl')
@registry.lower('ffi.from_buffer', types.Buffer)
def from_buffer(context, builder, sig, args):
assert len(sig.args) == 1
assert len(args) == 1
[fromty] = sig.args
[val] = args
# Type inference should have prevented passing a buffer from an
# array to a pointer of the wrong type
assert fromty.dtype == sig.return_type.dtype
ary = arrayobj.make_array(fromty)(context, builder, val)
return ary.data
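# Hedged usage sketch of the lowering above: in jitted code,
# ``ffi.from_buffer(arr)`` yields a raw pointer to the array data with a
# matching dtype. ``ffi``, ``lib`` and ``vector_sum`` are hypothetical
# CFFI artifacts, not part of this module.
#
#   import numpy as np
#   from numba import njit
#
#   @njit
#   def call_c(arr):
#       ptr = ffi.from_buffer(arr)   # handled by from_buffer() above
#       return lib.vector_sum(ptr, arr.size)
#
#   call_c(np.arange(10.0))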

View File

@@ -0,0 +1,5 @@
# this file is used with the numba.gdb* functionality
break numba_gdb_breakpoint
commands
return
end

View File

@@ -0,0 +1,201 @@
"""
Implement code coverage support.
Currently contains logic to extend ``coverage`` with lines covered by the
compiler.
"""
from typing import Optional, Sequence, Callable, no_type_check
from collections.abc import Mapping
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from numba.core import ir, config
try:
import coverage
except ImportError:
coverage_available = False
else:
coverage_available = True
@no_type_check
def get_active_coverage():
"""Get active coverage instance or return None if not found.
"""
cov = None
if coverage_available:
cov = coverage.Coverage.current()
return cov
_the_registry: list[Callable[[], Optional["NotifyLocBase"]]] = []
def get_registered_loc_notify() -> Sequence["NotifyLocBase"]:
"""
Returns a list of the registered NotifyLocBase instances.
"""
if not config.JIT_COVERAGE:
# Coverage disabled.
return []
return list(filter(lambda x: x is not None,
(factory() for factory in _the_registry)))
class NotifyLocBase(ABC):
"""Interface for notifying visiting of a ``numba.core.ir.Loc``."""
@abstractmethod
def notify(self, loc: ir.Loc) -> None:
pass
@abstractmethod
def close(self) -> None:
pass
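# Hedged sketch of the registration protocol: a zero-argument factory
# appended to ``_the_registry`` is called by get_registered_loc_notify()
# and may return None to opt out. This debug printer is illustrative only.
#
#   @_the_registry.append
#   def _register_debug_printer():
#       class _Printer(NotifyLocBase):
#           def notify(self, loc):
#               print("compiled:", loc.filename, loc.line)
#           def close(self):
#               pass
#       return _Printer()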
class NotifyCompilerCoverage(NotifyLocBase):
"""
    Used to notify ``coverage`` about compiled lines.
The compiled lines are under the "numba_compiled" context in the coverage
data.
"""
def __init__(self, collector):
self._collector = collector
# see https://github.com/nedbat/coveragepy/blob/e7c05fe91ee36c0c94e144bb88d25db4fc3d02fd/coverage/collector.py#L261 # noqa E501
tracer_kwargs = collector.core.tracer_kwargs.copy()
tracer_kwargs.update(
dict(
data=collector.data,
lock_data=collector.lock_data,
unlock_data=collector.unlock_data,
trace_arcs=collector.branch,
should_trace=collector.should_trace,
should_trace_cache=collector.should_trace_cache,
warn=collector.warn,
should_start_context=collector.should_start_context,
switch_context=collector.switch_context,
packed_arcs=collector.core.packed_arcs,
)
)
self._tracer = NumbaTracer(**tracer_kwargs)
collector.tracers.append(self._tracer)
def notify(self, loc: ir.Loc):
tracer = self._tracer
if loc.filename.endswith(".py"):
tracer.switch_context("numba_compiled")
tracer.trace(loc)
tracer.switch_context(None)
def close(self):
pass
@_the_registry.append
def _register_coverage_notifier():
cov = get_active_coverage()
if cov is not None:
col = cov._collector
# Is coverage started?
if col.tracers:
return NotifyCompilerCoverage(col)
if coverage_available:
@dataclass(kw_only=True)
class NumbaTracer(coverage.types.Tracer):
"""
Not actually a tracer as in the coverage implementation, which will
        set up a Python trace function. This implementation pretends to trace
but instead receives fake trace events for each line the compiler has
visited.
See coverage.PyTracer
"""
data: coverage.types.TTraceData
trace_arcs: bool
should_trace: coverage.types.TShouldTraceFn
should_trace_cache: Mapping[
str, coverage.types.TFileDisposition | None
]
should_start_context: coverage.types.TShouldStartContextFn | None
switch_context: Callable[[str | None], None] | None
lock_data: Callable[[], None]
unlock_data: Callable[[], None]
warn: coverage.types.TWarnFn
packed_arcs: bool
_activity: bool = field(default=False)
def start(self) -> coverage.types.TTraceFn | None:
"""Start this tracer, return a trace function if based on
sys.settrace."""
return None
def stop(self) -> None:
"""Stop this tracer."""
return None
def activity(self) -> bool:
"""Has there been any activity?"""
return self._activity
def reset_activity(self) -> None:
"""Reset the activity() flag."""
self._activity = False
def get_stats(self) -> dict[str, int] | None:
"""Return a dictionary of statistics, or None."""
return None
def trace(self, loc: ir.Loc) -> None:
"""Insert coverage data given source location.
"""
# Check whether the file should be traced
disp = self.should_trace_cache.get(loc.filename)
if disp is None:
disp = self.should_trace(loc.filename, None)
self.should_trace_cache[loc.filename] = disp
if not disp.trace:
# Bail if not tracing the file
return
# Insert trace data
tracename = disp.source_filename
self.lock_data()
cur_file_data = self.data.setdefault(tracename, set())
if self.trace_arcs:
if self.packed_arcs:
cur_file_data.add(_pack_arcs(loc.line, loc.line))
else:
cur_file_data.add((loc.line, loc.line))
else:
cur_file_data.add(loc.line)
self.unlock_data()
# Mark activity for this tracer
self._activity = True
def _pack_arcs(l1: int, l2: int) -> int:
"""Pack arcs into a single integer for compatibility with .packed_arcs
option.
See
https://github.com/nedbat/coveragepy/blob/e7c05fe91ee36c0c94e144bb88d25db4fc3d02fd/coverage/ctracer/tracer.c#L171
"""
packed = 0
if l1 < 0:
packed |= 1 << 40
l1 = -l1
if l2 < 0:
packed |= 1 << 41
l2 = -l2
packed |= (l2 << 20) + l1
return packed
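# Layout check for _pack_arcs (the unpacker is a hypothetical helper, not
# part of this module): bit 40 flags a negative l1, bit 41 a negative l2,
# l2 occupies bits 20-39 and l1 bits 0-19, mirroring coveragepy's C tracer.
#
#   def _unpack_arcs(packed):
#       l1 = packed & 0xFFFFF
#       l2 = (packed >> 20) & 0xFFFFF
#       return (-l1 if packed & (1 << 40) else l1,
#               -l2 if packed & (1 << 41) else l2)
#
#   assert _unpack_arcs(_pack_arcs(-3, 7)) == (-3, 7)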

View File

@@ -0,0 +1,84 @@
try:
from pygments.styles.default import DefaultStyle
except ImportError:
msg = "Please install pygments to see highlighted dumps"
raise ImportError(msg)
import numba.core.config
from pygments.styles.manni import ManniStyle
from pygments.styles.monokai import MonokaiStyle
from pygments.styles.native import NativeStyle
from pygments.lexer import RegexLexer, include, bygroups, words
from pygments.token import Text, Name, String, Punctuation, Keyword, \
Operator, Number
from pygments.style import Style
class NumbaIRLexer(RegexLexer):
"""
Pygments style lexer for Numba IR (for use with highlighting etc).
"""
name = 'Numba_IR'
aliases = ['numba_ir']
filenames = ['*.numba_ir']
identifier = r'\$[a-zA-Z0-9._]+'
fun_or_var = r'([a-zA-Z_]+[a-zA-Z0-9]*)'
tokens = {
'root' : [
(r'(label)(\ [0-9]+)(:)$',
bygroups(Keyword, Name.Label, Punctuation)),
(r' = ', Operator),
include('whitespace'),
include('keyword'),
(identifier, Name.Variable),
(fun_or_var + r'(\()',
bygroups(Name.Function, Punctuation)),
(fun_or_var + r'(\=)',
bygroups(Name.Attribute, Punctuation)),
(fun_or_var, Name.Constant),
(r'[0-9]+', Number),
# <built-in function some>
(r'<[^>\n]*>', String),
(r'[=<>{}\[\]()*.,!\':]|x\b', Punctuation)
],
'keyword':[
(words((
'del', 'jump', 'call', 'branch',
), suffix=' '), Keyword),
],
'whitespace': [
(r'(\n|\s)', Text),
],
}
def by_colorscheme():
"""
Get appropriate style for highlighting according to
NUMBA_COLOR_SCHEME setting
"""
styles = DefaultStyle.styles.copy()
styles.update({
Name.Variable: "#888888",
})
custom_default = type('CustomDefaultStyle', (Style, ), {'styles': styles})
style_map = {
'no_color' : custom_default,
'dark_bg' : MonokaiStyle,
'light_bg' : ManniStyle,
'blue_bg' : NativeStyle,
'jupyter_nb' : DefaultStyle,
}
return style_map[numba.core.config.COLOR_SCHEME]
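# Hedged usage sketch: highlighting a Numba IR dump with the lexer and
# style above (the IR snippet is illustrative).
#
#   from pygments import highlight
#   from pygments.formatters import Terminal256Formatter
#
#   ir_text = 'label 0:\n    del x\n'
#   print(highlight(ir_text, NumbaIRLexer(),
#                   Terminal256Formatter(style=by_colorscheme())))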

View File

@@ -0,0 +1,63 @@
import sys
import os
import re
def get_lib_dirs():
"""
Anaconda specific
"""
if sys.platform == 'win32':
# on windows, historically `DLLs` has been used for CUDA libraries,
# since approximately CUDA 9.2, `Library\bin` has been used.
dirnames = ['DLLs', os.path.join('Library', 'bin')]
else:
dirnames = ['lib', ]
libdirs = [os.path.join(sys.prefix, x) for x in dirnames]
return libdirs
DLLNAMEMAP = {
'linux': r'lib%(name)s\.so\.%(ver)s$',
'linux2': r'lib%(name)s\.so\.%(ver)s$',
'linux-static': r'lib%(name)s\.a$',
'darwin': r'lib%(name)s\.%(ver)s\.dylib$',
'win32': r'%(name)s%(ver)s\.dll$',
'win32-static': r'%(name)s\.lib$',
'bsd': r'lib%(name)s\.so\.%(ver)s$',
}
RE_VER = r'[0-9]*([_\.][0-9]+)*'
def find_lib(libname, libdir=None, platform=None, static=False):
platform = platform or sys.platform
platform = 'bsd' if 'bsd' in platform else platform
if static:
platform = f"{platform}-static"
if platform not in DLLNAMEMAP:
# Return empty list if platform name is undefined.
# Not all platforms define their static library paths.
return []
pat = DLLNAMEMAP[platform] % {"name": libname, "ver": RE_VER}
regex = re.compile(pat)
return find_file(regex, libdir)
def find_file(pat, libdir=None):
if libdir is None:
libdirs = get_lib_dirs()
elif isinstance(libdir, str):
libdirs = [libdir,]
else:
libdirs = list(libdir)
files = []
for ldir in libdirs:
try:
entries = os.listdir(ldir)
except FileNotFoundError:
continue
candidates = [os.path.join(ldir, ent)
for ent in entries if pat.match(ent)]
files.extend([c for c in candidates if os.path.isfile(c)])
return files
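# Hedged usage sketch ('tbb' is an illustrative library name; results
# depend on what the current environment ships):
#
#   find_lib('tbb')
#   # e.g. ['/opt/conda/lib/libtbb.so.12'] on Linux, [] if absent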

View File

@@ -0,0 +1,104 @@
"""
This module provides helper functions to find the first line of a function
body.
"""
import ast
import inspect
import textwrap
class FindDefFirstLine(ast.NodeVisitor):
"""
Attributes
----------
first_stmt_line : int or None
        The first statement's line number if the definition is found,
        or ``None`` otherwise.
"""
def __init__(self, name, firstlineno):
"""
Parameters
----------
code :
The function's code object.
"""
self._co_name = name
self._co_firstlineno = firstlineno
self.first_stmt_line = None
def _visit_children(self, node):
for child in ast.iter_child_nodes(node):
super().visit(child)
def visit_FunctionDef(self, node: ast.FunctionDef):
if node.name == self._co_name:
# Name of function matches.
# The `def` line may match co_firstlineno.
possible_start_lines = set([node.lineno])
if node.decorator_list:
# Has decorators.
# The first decorator line may match co_firstlineno.
first_decor = node.decorator_list[0]
possible_start_lines.add(first_decor.lineno)
# Does the first lineno match?
if self._co_firstlineno in possible_start_lines:
# Yes, we found the function.
# So, use the first statement line as the first line.
if node.body:
first_stmt = node.body[0]
                    if _is_docstring(first_stmt) and len(node.body) > 1:
                        # Skip the docstring
                        first_stmt = node.body[1]
self.first_stmt_line = first_stmt.lineno
return
else:
                    # This is probably unreachable: a function body cannot
                    # be bare, it must contain at least a docstring
                    # constant or a `pass`.
pass
self._visit_children(node)
def _is_docstring(node):
if isinstance(node, ast.Expr):
if (isinstance(node.value, ast.Constant)
and isinstance(node.value.value, str)):
return True
return False
def get_func_body_first_lineno(pyfunc):
"""
Look up the first line of function body using the file in
``pyfunc.__code__.co_filename``.
Returns
-------
lineno : int; or None
The first line number of the function body; or ``None`` if the first
line cannot be determined.
"""
co = pyfunc.__code__
try:
with open(co.co_filename) as fin:
source = fin.read()
offset = 0
    except OSError:
try:
lines, offset = inspect.getsourcelines(pyfunc)
source = "".join(lines)
offset = offset - 1
except (OSError, TypeError):
return None
tree = ast.parse(textwrap.dedent(source))
finder = FindDefFirstLine(co.co_name, co.co_firstlineno - offset)
finder.visit(tree)
if finder.first_stmt_line:
return finder.first_stmt_line + offset
else:
# No first line found.
return None
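# Hedged example (line numbers refer to a hypothetical defining file):
#
#   def demo():            # say `def` sits at line 10
#       """docstring"""    # line 11, skipped by FindDefFirstLine
#       x = 1              # line 12
#       return x
#
#   get_func_body_first_lineno(demo)   # -> 12 in this layout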

View File

@@ -0,0 +1,228 @@
import os
import sys
from llvmlite import ir
from numba.core import types, utils, config, cgutils, errors
from numba import gdb, gdb_init, gdb_breakpoint
from numba.core.extending import overload, intrinsic
_path = os.path.dirname(__file__)
_platform = sys.platform
_unix_like = (_platform.startswith('linux') or
_platform.startswith('darwin') or
('bsd' in _platform))
def _confirm_gdb(need_ptrace_attach=True):
"""
Set need_ptrace_attach to True/False to indicate whether the ptrace attach
permission is needed for this gdb use case. Mode 0 (classic) or 1
(restricted ptrace) is required if need_ptrace_attach is True. See:
https://www.kernel.org/doc/Documentation/admin-guide/LSM/Yama.rst
for details on the modes.
"""
if not _unix_like:
msg = 'gdb support is only available on unix-like systems'
raise errors.NumbaRuntimeError(msg)
gdbloc = config.GDB_BINARY
if not (os.path.exists(gdbloc) and os.path.isfile(gdbloc)):
msg = ('Is gdb present? Location specified (%s) does not exist. The gdb'
' binary location can be set using Numba configuration, see: '
'https://numba.readthedocs.io/en/stable/reference/envvars.html' # noqa: E501
)
raise RuntimeError(msg % config.GDB_BINARY)
# Is Yama being used as a kernel security module and if so is ptrace_scope
# limited? In this case ptracing non-child processes requires special
# permission so raise an exception.
ptrace_scope_file = os.path.join(os.sep, 'proc', 'sys', 'kernel', 'yama',
'ptrace_scope')
has_ptrace_scope = os.path.exists(ptrace_scope_file)
if has_ptrace_scope:
with open(ptrace_scope_file, 'rt') as f:
value = f.readline().strip()
if need_ptrace_attach and value not in ("0", "1"):
msg = ("gdb can launch but cannot attach to the executing program"
" because ptrace permissions have been restricted at the "
"system level by the Linux security module 'Yama'.\n\n"
"Documentation for this module and the security "
"implications of making changes to its behaviour can be "
"found in the Linux Kernel documentation "
"https://www.kernel.org/doc/Documentation/admin-guide/LSM/Yama.rst" # noqa: E501
"\n\nDocumentation on how to adjust the behaviour of Yama "
"on Ubuntu Linux with regards to 'ptrace_scope' can be "
"found here "
"https://wiki.ubuntu.com/Security/Features#ptrace.")
raise RuntimeError(msg)
@overload(gdb)
def hook_gdb(*args):
_confirm_gdb()
gdbimpl = gen_gdb_impl(args, True)
def impl(*args):
gdbimpl()
return impl
@overload(gdb_init)
def hook_gdb_init(*args):
_confirm_gdb()
gdbimpl = gen_gdb_impl(args, False)
def impl(*args):
gdbimpl()
return impl
def init_gdb_codegen(cgctx, builder, signature, args,
const_args, do_break=False):
int8_t = ir.IntType(8)
int32_t = ir.IntType(32)
intp_t = ir.IntType(utils.MACHINE_BITS)
char_ptr = ir.PointerType(ir.IntType(8))
zero_i32t = int32_t(0)
mod = builder.module
pid = cgutils.alloca_once(builder, int32_t, size=1)
# 32bit pid, 11 char max + terminator
pidstr = cgutils.alloca_once(builder, int8_t, size=12)
# str consts
intfmt = cgctx.insert_const_string(mod, '%d')
gdb_str = cgctx.insert_const_string(mod, config.GDB_BINARY)
attach_str = cgctx.insert_const_string(mod, 'attach')
new_args = []
# add break point command to known location
# this command file thing is due to commands attached to a breakpoint
# requiring an interactive prompt
# https://sourceware.org/bugzilla/show_bug.cgi?id=10079
new_args.extend(['-x', os.path.join(_path, 'cmdlang.gdb')])
# issue command to continue execution from sleep function
new_args.extend(['-ex', 'c'])
# then run the user defined args if any
if any([not isinstance(x, types.StringLiteral) for x in const_args]):
raise errors.RequireLiteralValue(const_args)
new_args.extend([x.literal_value for x in const_args])
cmdlang = [cgctx.insert_const_string(mod, x) for x in new_args]
# insert getpid, getpid is always successful, call without concern!
fnty = ir.FunctionType(int32_t, tuple())
getpid = cgutils.get_or_insert_function(mod, fnty, "getpid")
# insert snprintf
# int snprintf(char *str, size_t size, const char *format, ...);
fnty = ir.FunctionType(
int32_t, (char_ptr, intp_t, char_ptr), var_arg=True)
snprintf = cgutils.get_or_insert_function(mod, fnty, "snprintf")
# insert fork
fnty = ir.FunctionType(int32_t, tuple())
fork = cgutils.get_or_insert_function(mod, fnty, "fork")
# insert execl
fnty = ir.FunctionType(int32_t, (char_ptr, char_ptr), var_arg=True)
execl = cgutils.get_or_insert_function(mod, fnty, "execl")
# insert sleep
fnty = ir.FunctionType(int32_t, (int32_t,))
sleep = cgutils.get_or_insert_function(mod, fnty, "sleep")
# insert break point
fnty = ir.FunctionType(ir.VoidType(), tuple())
breakpoint = cgutils.get_or_insert_function(mod, fnty,
"numba_gdb_breakpoint")
# do the work
parent_pid = builder.call(getpid, tuple())
builder.store(parent_pid, pid)
pidstr_ptr = builder.gep(pidstr, [zero_i32t], inbounds=True)
pid_val = builder.load(pid)
# call snprintf to write the pid into a char *
stat = builder.call(
snprintf, (pidstr_ptr, intp_t(12), intfmt, pid_val))
invalid_write = builder.icmp_signed('>', stat, int32_t(12))
with builder.if_then(invalid_write, likely=False):
msg = "Internal error: `snprintf` buffer would have overflowed."
cgctx.call_conv.return_user_exc(builder, RuntimeError, (msg,))
# fork, check pids etc
child_pid = builder.call(fork, tuple())
fork_failed = builder.icmp_signed('==', child_pid, int32_t(-1))
with builder.if_then(fork_failed, likely=False):
msg = "Internal error: `fork` failed."
cgctx.call_conv.return_user_exc(builder, RuntimeError, (msg,))
is_child = builder.icmp_signed('==', child_pid, zero_i32t)
with builder.if_else(is_child) as (then, orelse):
with then:
# is child
nullptr = ir.Constant(char_ptr, None)
gdb_str_ptr = builder.gep(
gdb_str, [zero_i32t], inbounds=True)
attach_str_ptr = builder.gep(
attach_str, [zero_i32t], inbounds=True)
cgutils.printf(
builder, "Attaching to PID: %s\n", pidstr)
buf = (
gdb_str_ptr,
gdb_str_ptr,
attach_str_ptr,
pidstr_ptr)
buf = buf + tuple(cmdlang) + (nullptr,)
builder.call(execl, buf)
with orelse:
# is parent
builder.call(sleep, (int32_t(10),))
# if breaking is desired, break now
if do_break is True:
builder.call(breakpoint, tuple())
def gen_gdb_impl(const_args, do_break):
@intrinsic
def gdb_internal(tyctx):
function_sig = types.void()
def codegen(cgctx, builder, signature, args):
init_gdb_codegen(cgctx, builder, signature, args, const_args,
do_break=do_break)
return cgctx.get_constant(types.none, None)
return function_sig, codegen
return gdb_internal
@overload(gdb_breakpoint)
def hook_gdb_breakpoint():
"""
Adds the Numba break point into the source
"""
if not sys.platform.startswith('linux'):
raise RuntimeError('gdb is only available on linux')
bp_impl = gen_bp_impl()
def impl():
bp_impl()
return impl
def gen_bp_impl():
@intrinsic
def bp_internal(tyctx):
function_sig = types.void()
def codegen(cgctx, builder, signature, args):
mod = builder.module
fnty = ir.FunctionType(ir.VoidType(), tuple())
breakpoint = cgutils.get_or_insert_function(mod, fnty,
"numba_gdb_breakpoint")
builder.call(breakpoint, tuple())
return cgctx.get_constant(types.none, None)
return function_sig, codegen
return bp_internal
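# Hedged usage sketch: the hooks above are reached from jitted code on
# unix-like systems with gdb installed (`debug=True` aids source-level
# debugging but is not enforced here).
#
#   from numba import njit, gdb_init, gdb_breakpoint
#
#   @njit(debug=True)
#   def foo(x):
#       gdb_init()        # fork/execl a gdb attached to this process
#       y = x + 1
#       gdb_breakpoint()  # lands on numba_gdb_breakpoint under gdb
#       return y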

View File

@@ -0,0 +1,204 @@
"""gdb printing extension for Numba types.
"""
import re
try:
import gdb.printing
import gdb
except ImportError:
raise ImportError("GDB python support is not available.")
class NumbaArrayPrinter:
def __init__(self, val):
self.val = val
def to_string(self):
try:
import numpy as np
HAVE_NUMPY = True
except ImportError:
HAVE_NUMPY = False
try:
NULL = 0x0
# Raw data references, these need unpacking/interpreting.
# Member "data" is...
# DW_TAG_member of DIDerivedType, tag of DW_TAG_pointer_type
# encoding e.g. DW_ATE_float
data = self.val["data"]
# Member "itemsize" is...
# DW_TAG_member of DIBasicType encoding DW_ATE_signed
itemsize = self.val["itemsize"]
# Members "shape" and "strides" are...
# DW_TAG_member of DIDerivedType, the type is a DICompositeType
# (it's a Numba UniTuple) with tag: DW_TAG_array_type, i.e. it's an
# array repr, it has a basetype of e.g. DW_ATE_unsigned and also
# "elements" which are referenced with a DISubrange(count: <const>)
# to say how many elements are in the array.
rshp = self.val["shape"]
rstrides = self.val["strides"]
# bool on whether the data is aligned.
is_aligned = False
# type information decode, simple type:
ty_str = str(self.val.type)
if HAVE_NUMPY and ('aligned' in ty_str or 'Record' in ty_str):
ty_str = ty_str.replace('unaligned ','').strip()
matcher = re.compile(r"array\((Record.*), (.*), (.*)\)\ \(.*")
# NOTE: need to deal with "Alignment" else dtype size is wrong
arr_info = [x.strip() for x in matcher.match(ty_str).groups()]
dtype_str, ndim_str, order_str = arr_info
rstr = r'Record\((.*\[.*\]);([0-9]+);(True|False)'
rstr_match = re.match(rstr, dtype_str)
# balign is unused, it's the alignment
fields, balign, is_aligned_str = rstr_match.groups()
is_aligned = is_aligned_str == 'True'
field_dts = fields.split(',')
struct_entries = []
for f in field_dts:
splitted = f.split('[')
name = splitted[0]
dt_part = splitted[1:]
if len(dt_part) > 1:
raise TypeError('Unsupported sub-type: %s' % f)
else:
dt_part = dt_part[0]
if "nestedarray" in dt_part:
raise TypeError('Unsupported sub-type: %s' % f)
dt_as_str = dt_part.split(';')[0].split('=')[1]
dtype = np.dtype(dt_as_str)
struct_entries.append((name, dtype))
# The dtype is actually a record of some sort
dtype_str = struct_entries
else: # simple type
matcher = re.compile(r"array\((.*),(.*),(.*)\)\ \(.*")
arr_info = [x.strip() for x in matcher.match(ty_str).groups()]
dtype_str, ndim_str, order_str = arr_info
# fix up unichr dtype
if 'unichr x ' in dtype_str:
dtype_str = dtype_str[1:-1].replace('unichr x ', '<U')
def dwarr2inttuple(dwarr):
# Converts a gdb handle to a dwarf array to a tuple of ints
fields = dwarr.type.fields()
lo, hi = fields[0].type.range()
return tuple([int(dwarr[x]) for x in range(lo, hi + 1)])
# shape/strides extraction
shape = dwarr2inttuple(rshp)
strides = dwarr2inttuple(rstrides)
# if data is not NULL
if data != NULL:
if HAVE_NUMPY:
# The data extent in bytes is:
# sum(shape * strides)
# get the data, then wire to as_strided
shp_arr = np.array([max(0, x - 1) for x in shape])
strd_arr = np.array(strides)
extent = np.sum(shp_arr * strd_arr)
extent += int(itemsize)
dtype_clazz = np.dtype(dtype_str, align=is_aligned)
dtype = dtype_clazz
this_proc = gdb.selected_inferior()
mem = this_proc.read_memory(int(data), extent)
arr_data = np.frombuffer(mem, dtype=dtype)
new_arr = np.lib.stride_tricks.as_strided(arr_data,
shape=shape,
strides=strides,)
return '\n' + str(new_arr)
# Catch all for no NumPy
return "array([...], dtype=%s, shape=%s)" % (dtype_str, shape)
else:
# Not yet initialized or NULLed out data
buf = list(["NULL/Uninitialized"])
return "array([" + ', '.join(buf) + "]" + ")"
except Exception as e:
return 'array[Exception: Failed to parse. %s]' % e
class NumbaComplexPrinter:
def __init__(self, val):
self.val = val
def to_string(self):
return "%s+%sj" % (self.val['real'], self.val['imag'])
class NumbaTuplePrinter:
def __init__(self, val):
self.val = val
def to_string(self):
buf = []
fields = self.val.type.fields()
for f in fields:
buf.append(str(self.val[f.name]))
return "(%s)" % ', '.join(buf)
class NumbaUniTuplePrinter:
def __init__(self, val):
self.val = val
def to_string(self):
# unituples are arrays
fields = self.val.type.fields()
lo, hi = fields[0].type.range()
buf = []
for i in range(lo, hi + 1):
buf.append(str(self.val[i]))
return "(%s)" % ', '.join(buf)
class NumbaUnicodeTypePrinter:
def __init__(self, val):
self.val = val
def to_string(self):
NULL = 0x0
data = self.val["data"]
nitems = self.val["length"]
kind = self.val["kind"]
if data != NULL:
# This needs sorting out, encoding is wrong
this_proc = gdb.selected_inferior()
mem = this_proc.read_memory(int(data), nitems * kind)
if isinstance(mem, memoryview):
buf = bytes(mem).decode()
else:
buf = mem.decode('utf-8')
else:
buf = str(data)
return "'%s'" % buf
def _create_printers():
printer = gdb.printing.RegexpCollectionPrettyPrinter("Numba")
printer.add_printer('Numba unaligned array printer', '^unaligned array\\(',
NumbaArrayPrinter)
printer.add_printer('Numba array printer', '^array\\(', NumbaArrayPrinter)
printer.add_printer('Numba complex printer', '^complex[0-9]+\\ ',
NumbaComplexPrinter)
printer.add_printer('Numba Tuple printer', '^Tuple\\(',
NumbaTuplePrinter)
printer.add_printer('Numba UniTuple printer', '^UniTuple\\(',
NumbaUniTuplePrinter)
printer.add_printer('Numba unicode_type printer', '^unicode_type\\s+\\(',
NumbaUnicodeTypePrinter)
return printer
# register the Numba pretty printers for the current object
gdb.printing.register_pretty_printer(gdb.current_objfile(), _create_printers())
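# Hedged usage sketch: the printers register on import, so a gdb session
# only needs to source this file (path is illustrative):
#
#   (gdb) source /path/to/numba/misc/gdb_print_extension.py
#   (gdb) print my_array   # Numba arrays now render via NumPy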

View File

@@ -0,0 +1,433 @@
"""
This module can be run as a command-line tool (it defines ``__main__``) and
provides functions to inspect Numba's support for a given Python module or
package.
"""
import argparse
import pkgutil
import warnings
import types as pytypes
from numba.core import errors
from numba._version import get_versions
from numba.core.registry import cpu_target
from numba.tests.support import captured_stdout
def _get_commit():
full = get_versions()['full-revisionid']
if not full:
warnings.warn(
"Cannot find git commit hash. Source links could be inaccurate.",
category=errors.NumbaWarning,
)
return 'main'
return full
commit = _get_commit()
github_url = 'https://github.com/numba/numba/blob/{commit}/{path}#L{firstline}-L{lastline}' # noqa: E501
def inspect_function(function, target=None):
"""Return information about the support of a function.
Returns
-------
info : dict
Defined keys:
- "numba_type": str or None
The numba type object of the function if supported.
- "explained": str
A textual description of the support.
- "source_infos": dict
A dictionary containing the source location of each definition.
"""
target = target or cpu_target
tyct = target.typing_context
# Make sure we have loaded all extensions
tyct.refresh()
target.target_context.refresh()
info = {}
# Try getting the function type
source_infos = {}
try:
nbty = tyct.resolve_value_type(function)
except ValueError:
nbty = None
explained = 'not supported'
else:
# Make a longer explanation of the type
explained = tyct.explain_function_type(nbty)
for temp in nbty.templates:
try:
source_infos[temp] = temp.get_source_info()
except AttributeError:
source_infos[temp] = None
info['numba_type'] = nbty
info['explained'] = explained
info['source_infos'] = source_infos
return info
def inspect_module(module, target=None, alias=None):
"""Inspect a module object and yielding results from `inspect_function()`
for each function object in the module.
"""
alias = {} if alias is None else alias
# Walk the module
for name in dir(module):
if name.startswith('_'):
# Skip
continue
obj = getattr(module, name)
supported_types = (pytypes.FunctionType, pytypes.BuiltinFunctionType)
if not isinstance(obj, supported_types):
# Skip if it's not a function
continue
info = dict(module=module, name=name, obj=obj)
if obj in alias:
info['alias'] = alias[obj]
else:
alias[obj] = "{module}.{name}".format(module=module.__name__,
name=name)
info.update(inspect_function(obj, target=target))
yield info
class _Stat(object):
"""For gathering simple statistic of (un)supported functions"""
def __init__(self):
self.supported = 0
self.unsupported = 0
@property
def total(self):
total = self.supported + self.unsupported
return total
@property
def ratio(self):
ratio = self.supported / self.total * 100
return ratio
def describe(self):
if self.total == 0:
return "empty"
return "supported = {supported} / {total} = {ratio:.2f}%".format(
supported=self.supported,
total=self.total,
ratio=self.ratio,
)
def __repr__(self):
return "{clsname}({describe})".format(
clsname=self.__class__.__name__,
describe=self.describe(),
)
def filter_private_module(module_components):
return not any(x.startswith('_') for x in module_components)
def filter_tests_module(module_components):
return not any(x == 'tests' for x in module_components)
_default_module_filters = (
filter_private_module,
filter_tests_module,
)
def list_modules_in_package(package, module_filters=_default_module_filters):
"""Yield all modules in a given package.
Recursively walks the package tree.
"""
onerror_ignore = lambda _: None
prefix = package.__name__ + "."
package_walker = pkgutil.walk_packages(
package.__path__,
prefix,
onerror=onerror_ignore,
)
def check_filter(modname):
module_components = modname.split('.')
return any(not filter_fn(module_components)
for filter_fn in module_filters)
modname = package.__name__
if not check_filter(modname):
yield package
for pkginfo in package_walker:
modname = pkginfo[1]
if check_filter(modname):
continue
        # In case importing the module prints to stdout
with captured_stdout():
try:
# Import the module
mod = __import__(modname)
except Exception:
continue
# Extract the module
for part in modname.split('.')[1:]:
try:
mod = getattr(mod, part)
except AttributeError:
# Suppress error in getting the attribute
mod = None
break
# Ignore if mod is not a module
if not isinstance(mod, pytypes.ModuleType):
# Skip non-module
continue
yield mod
class Formatter(object):
"""Base class for formatters.
"""
def __init__(self, fileobj):
self._fileobj = fileobj
def print(self, *args, **kwargs):
kwargs.setdefault('file', self._fileobj)
print(*args, **kwargs)
class HTMLFormatter(Formatter):
"""Formatter that outputs HTML
"""
def escape(self, text):
import html
return html.escape(text)
def title(self, text):
        self.print('<h1>', text, '</h1>')
def begin_module_section(self, modname):
self.print('<h2>', modname, '</h2>')
self.print('<ul>')
def end_module_section(self):
self.print('</ul>')
def write_supported_item(self, modname, itemname, typename, explained,
sources, alias):
self.print('<li>')
self.print('{}.<b>{}</b>'.format(
modname,
itemname,
))
self.print(': <b>{}</b>'.format(typename))
self.print('<div><pre>', explained, '</pre></div>')
self.print("<ul>")
for tcls, source in sources.items():
if source:
self.print("<li>")
impl = source['name']
sig = source['sig']
filename = source['filename']
lines = source['lines']
self.print(
"<p>defined by <b>{}</b>{} at {}:{}-{}</p>".format(
self.escape(impl), self.escape(sig),
self.escape(filename), lines[0], lines[1],
),
)
self.print('<p>{}</p>'.format(
self.escape(source['docstring'] or '')
))
else:
self.print("<li>{}".format(self.escape(str(tcls))))
self.print("</li>")
self.print("</ul>")
self.print('</li>')
def write_unsupported_item(self, modname, itemname):
self.print('<li>')
self.print('{}.<b>{}</b>: UNSUPPORTED'.format(
modname,
itemname,
))
self.print('</li>')
def write_statistic(self, stats):
self.print('<p>{}</p>'.format(stats.describe()))
class ReSTFormatter(Formatter):
"""Formatter that output ReSTructured text format for Sphinx docs.
"""
def escape(self, text):
return text
def title(self, text):
self.print(text)
self.print('=' * len(text))
self.print()
def begin_module_section(self, modname):
self.print(modname)
self.print('-' * len(modname))
self.print()
def end_module_section(self):
self.print()
def write_supported_item(self, modname, itemname, typename, explained,
sources, alias):
self.print('.. function:: {}.{}'.format(modname, itemname))
self.print(' :noindex:')
self.print()
if alias:
self.print(" Alias to: ``{}``".format(alias))
self.print()
for tcls, source in sources.items():
if source:
impl = source['name']
sig = source['sig']
filename = source['filename']
lines = source['lines']
source_link = github_url.format(
commit=commit,
path=filename,
firstline=lines[0],
lastline=lines[1],
)
self.print(
" - defined by ``{}{}`` at `{}:{}-{} <{}>`_".format(
impl, sig, filename, lines[0], lines[1], source_link,
),
)
else:
self.print(" - defined by ``{}``".format(str(tcls)))
self.print()
def write_unsupported_item(self, modname, itemname):
pass
def write_statistic(self, stat):
if stat.supported == 0:
self.print("This module is not supported.")
else:
msg = "Not showing {} unsupported functions."
self.print(msg.format(stat.unsupported))
self.print()
self.print(stat.describe())
self.print()
def _format_module_infos(formatter, package_name, mod_sequence, target=None):
"""Format modules.
"""
formatter.title('Listings for {}'.format(package_name))
alias_map = {} # remember object seen to track alias
for mod in mod_sequence:
stat = _Stat()
modname = mod.__name__
formatter.begin_module_section(formatter.escape(modname))
for info in inspect_module(mod, target=target, alias=alias_map):
nbtype = info['numba_type']
if nbtype is not None:
stat.supported += 1
formatter.write_supported_item(
modname=formatter.escape(info['module'].__name__),
itemname=formatter.escape(info['name']),
typename=formatter.escape(str(nbtype)),
explained=formatter.escape(info['explained']),
sources=info['source_infos'],
alias=info.get('alias'),
)
else:
stat.unsupported += 1
formatter.write_unsupported_item(
modname=formatter.escape(info['module'].__name__),
itemname=formatter.escape(info['name']),
)
formatter.write_statistic(stat)
formatter.end_module_section()
def write_listings(package_name, filename, output_format):
"""Write listing information into a file.
Parameters
----------
package_name : str
Name of the package to inspect.
filename : str
Output filename. Always overwrite.
output_format : str
        Supported formats are "html" and "rst".
"""
package = __import__(package_name)
if hasattr(package, '__path__'):
mods = list_modules_in_package(package)
else:
mods = [package]
if output_format == 'html':
with open(filename + '.html', 'w') as fout:
fmtr = HTMLFormatter(fileobj=fout)
_format_module_infos(fmtr, package_name, mods)
elif output_format == 'rst':
with open(filename + '.rst', 'w') as fout:
fmtr = ReSTFormatter(fileobj=fout)
_format_module_infos(fmtr, package_name, mods)
else:
raise ValueError(
"Output format '{}' is not supported".format(output_format))
program_description = """
Inspect Numba support for a given top-level package.
""".strip()
def main():
parser = argparse.ArgumentParser(description=program_description)
parser.add_argument(
'package', metavar='package', type=str,
help='Package to inspect',
)
parser.add_argument(
'--format', dest='format', default='html',
        help='Output format, e.g. "html" or "rst"',
)
parser.add_argument(
'--file', dest='file', default='inspector_output',
help='Output filename. Defaults to "inspector_output.<format>"',
)
args = parser.parse_args()
package_name = args.package
output_format = args.format
filename = args.file
write_listings(package_name, filename, output_format)
if __name__ == '__main__':
main()
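# Hedged CLI sketch (the module path is illustrative; any entry point that
# runs this file's main() behaves the same):
#
#   python -m numba.misc.help.inspector numpy --format rst --file numpy_support
#
# writes "numpy_support.rst" with per-module listings of supported functions.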

View File

@@ -0,0 +1,44 @@
"""Collection of miscellaneous initialization utilities."""
from collections import namedtuple
version_info = namedtuple('version_info',
('major minor patch short full '
'string tuple git_revision'))
def generate_version_info(version):
"""Process a version string into a structured version_info object.
Parameters
----------
version: str
a string describing the current version
Returns
-------
version_info: tuple
structured version information
See also
--------
Look at the definition of 'version_info' in this module for details.
"""
parts = version.split('.')
def try_int(x):
try:
return int(x)
except ValueError:
return None
major = try_int(parts[0]) if len(parts) >= 1 else None
minor = try_int(parts[1]) if len(parts) >= 2 else None
patch = try_int(parts[2]) if len(parts) >= 3 else None
short = (major, minor)
full = (major, minor, patch)
string = version
tup = tuple(parts)
git_revision = tup[3] if len(tup) >= 4 else None
return version_info(major, minor, patch, short, full, string, tup,
git_revision)
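# Worked example (values illustrative):
#
#   generate_version_info("0.59.1")
#   # -> version_info(major=0, minor=59, patch=1, short=(0, 59),
#   #                 full=(0, 59, 1), string='0.59.1',
#   #                 tuple=('0', '59', '1'), git_revision=None)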

View File

@@ -0,0 +1,103 @@
"""Miscellaneous inspection tools
"""
from tempfile import NamedTemporaryFile, TemporaryDirectory
import os
import warnings
from numba.core.errors import NumbaWarning
def disassemble_elf_to_cfg(elf, mangled_symbol):
"""
Gets the CFG of the disassembly of an ELF object, elf, at mangled name,
mangled_symbol, and renders it appropriately depending on the execution
environment (terminal/notebook).
"""
try:
import r2pipe
except ImportError:
raise RuntimeError("r2pipe package needed for disasm CFG")
def get_rendering(cmd=None):
from numba.pycc.platform import Toolchain # import local, circular ref
if cmd is None:
raise ValueError("No command given")
with TemporaryDirectory() as tmpdir:
# Write ELF as a temporary file in the temporary dir, do not delete!
with NamedTemporaryFile(delete=False, dir=tmpdir) as f:
f.write(elf)
f.flush() # force write, radare2 needs a binary blob on disk
# Now try and link the ELF, this helps radare2 _a lot_
linked = False
try:
raw_dso_name = f'{os.path.basename(f.name)}.so'
linked_dso = os.path.join(tmpdir, raw_dso_name)
tc = Toolchain()
tc.link_shared(linked_dso, (f.name,))
obj_to_analyse = linked_dso
linked = True
except Exception as e:
# link failed, mention it to user, radare2 will still be able to
# analyse the object, but things like dwarf won't appear in the
# asm as comments.
msg = ('Linking the ELF object with the distutils toolchain '
f'failed with: {e}. Disassembly will still work but '
'might be less accurate and will not use DWARF '
'information.')
warnings.warn(NumbaWarning(msg))
obj_to_analyse = f.name
# catch if r2pipe can actually talk to radare2
try:
flags = ['-2', # close stderr to hide warnings
'-e io.cache=true', # fix relocations in disassembly
'-e scr.color=1', # 16 bit ANSI colour terminal
'-e asm.dwarf=true', # DWARF decode
'-e scr.utf8=true', # UTF8 output looks better
]
r = r2pipe.open(obj_to_analyse, flags=flags)
r.cmd('aaaaaa') # analyse as much as possible
# If the elf is linked then it's necessary to seek as the
# DSO ctor/dtor is at the default position
if linked:
# r2 only matches up to 61 chars?! found this by experiment!
mangled_symbol_61char = mangled_symbol[:61]
# switch off demangle, the seek is on a mangled symbol
r.cmd('e bin.demangle=false')
# seek to the mangled symbol address
r.cmd(f's `is~ {mangled_symbol_61char}[1]`')
# switch demangling back on for output purposes
r.cmd('e bin.demangle=true')
data = r.cmd('%s' % cmd) # print graph
r.quit()
except Exception as e:
if "radare2 in PATH" in str(e):
msg = ("This feature requires 'radare2' to be "
"installed and available on the system see: "
"https://github.com/radareorg/radare2. "
"Cannot find 'radare2' in $PATH.")
raise RuntimeError(msg)
else:
raise e
return data
class DisasmCFG(object):
def _repr_svg_(self):
try:
import graphviz
except ImportError:
raise RuntimeError("graphviz package needed for disasm CFG")
jupyter_rendering = get_rendering(cmd='agfd')
# this just makes it read slightly better in jupyter notebooks
            jupyter_rendering = jupyter_rendering.replace(
                'fontname="Courier",', 'fontname="Courier",fontsize=6,')
src = graphviz.Source(jupyter_rendering)
return src.pipe('svg').decode('UTF-8')
def __repr__(self):
return get_rendering(cmd='agf')
return DisasmCFG()
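# Illustrative usage sketch (assumes radare2 and r2pipe are installed, and
# that `elf_bytes` and `mangled` have been obtained elsewhere, e.g. from a
# compiled function's code library):
#
#   cfg = disassemble_elf_to_cfg(elf_bytes, mangled)
#   print(cfg)   # ASCII CFG in a terminal
#   cfg          # SVG rendering in a Jupyter notebook (via _repr_svg_)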

View File

@@ -0,0 +1,24 @@
from numba.core.extending import overload
from numba.core import types
from numba.misc.special import literally, literal_unroll
from numba.core.errors import TypingError
@overload(literally)
def _ov_literally(obj):
if isinstance(obj, (types.Literal, types.InitialValue)):
return lambda obj: obj
else:
m = "Invalid use of non-Literal type in literally({})".format(obj)
raise TypingError(m)
@overload(literal_unroll)
def literal_unroll_impl(container):
if isinstance(container, types.Poison):
m = f"Invalid use of non-Literal type in literal_unroll({container})"
raise TypingError(m)
def impl(container):
return container
return impl
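# Illustrative sketch of how these entry points are used from jitted code
# (assumes a working Numba install):
#
#   from numba import njit, literally, literal_unroll
#
#   @njit
#   def power(x, n):
#       # Force `n` to be treated as a compile-time literal.
#       return x ** literally(n)
#
#   @njit
#   def total(tup):
#       acc = 0
#       # Unroll the loop so each heterogeneous element gets its own type.
#       for v in literal_unroll(tup):
#           acc += v
#       return acc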

View File

@@ -0,0 +1,471 @@
import re
import operator
import heapq
from collections import namedtuple
from collections.abc import Sequence
from contextlib import contextmanager
from functools import cached_property
from numba.core import config
import llvmlite.binding as llvm
class RecordLLVMPassTimingsLegacy:
"""A helper context manager to track LLVM pass timings.
"""
__slots__ = ["_data"]
def __enter__(self):
"""Enables the pass timing in LLVM.
"""
llvm.set_time_passes(True)
return self
    def __exit__(self, exc_type, exc_val, exc_tb):
"""Reset timings and save report internally.
"""
self._data = llvm.report_and_reset_timings()
llvm.set_time_passes(False)
return
def get(self):
"""Retrieve timing data for processing.
Returns
-------
timings: ProcessedPassTimings
"""
return ProcessedPassTimings(self._data)
class RecordLLVMPassTimings:
"""A helper context manager to track LLVM pass timings.
"""
__slots__ = ["_data", "_pb"]
def __init__(self, pb):
self._pb = pb
self._data = None
def __enter__(self):
"""Enables the pass timing in LLVM.
"""
self._pb.start_pass_timing()
return self
    def __exit__(self, exc_type, exc_val, exc_tb):
"""Reset timings and save report internally.
"""
self._data = self._pb.finish_pass_timing()
return
def get(self):
"""Retrieve timing data for processing.
Returns
-------
timings: ProcessedPassTimings
"""
return ProcessedPassTimings(self._data)
PassTimingRecord = namedtuple(
"PassTimingRecord",
[
"user_time",
"user_percent",
"system_time",
"system_percent",
"user_system_time",
"user_system_percent",
"wall_time",
"wall_percent",
"pass_name",
"instruction",
],
)
def _adjust_timings(records):
"""Adjust timing records because of truncated information.
Details: The percent information can be used to improve the timing
information.
Returns
-------
res: List[PassTimingRecord]
"""
total_rec = records[-1]
assert total_rec.pass_name == "Total" # guard for implementation error
def make_adjuster(attr):
time_attr = f"{attr}_time"
percent_attr = f"{attr}_percent"
time_getter = operator.attrgetter(time_attr)
def adjust(d):
"""Compute percent x total_time = adjusted"""
total = time_getter(total_rec)
adjusted = total * d[percent_attr] * 0.01
d[time_attr] = adjusted
return d
return adjust
# Make adjustment functions for each field
adj_fns = [
make_adjuster(x) for x in ["user", "system", "user_system", "wall"]
]
# Extract dictionaries from the namedtuples
dicts = map(lambda x: x._asdict(), records)
def chained(d):
# Chain the adjustment functions
for fn in adj_fns:
d = fn(d)
# Reconstruct the namedtuple
return PassTimingRecord(**d)
return list(map(chained, dicts))
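# Worked example of the adjustment (illustrative): if the "Total" record
# reports wall_time=10.0s and a pass row was truncated to wall_time=2.3
# with wall_percent=23.5, the recomputed value is 10.0 * 23.5 * 0.01 =
# 2.35s, recovering digits lost to truncation in LLVM's report.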
class ProcessedPassTimings:
"""A class for processing raw timing report from LLVM.
The processing is done lazily so we don't waste time processing unused
timing information.
"""
def __init__(self, raw_data):
self._raw_data = raw_data
def __bool__(self):
return bool(self._raw_data)
def get_raw_data(self):
"""Returns the raw string data.
Returns
-------
res: str
"""
return self._raw_data
def get_total_time(self):
"""Compute the total time spend in all passes.
Returns
-------
res: float
"""
return self.list_records()[-1].wall_time
def list_records(self):
"""Get the processed data for the timing report.
Returns
-------
res: List[PassTimingRecord]
"""
return self._processed
def list_top(self, n):
"""Returns the top(n) most time-consuming (by wall-time) passes.
Parameters
----------
n: int
This limits the maximum number of items to show.
This function will show the ``n`` most time-consuming passes.
Returns
-------
res: List[PassTimingRecord]
Returns the top(n) most time-consuming passes in descending order.
"""
records = self.list_records()
key = operator.attrgetter("wall_time")
return heapq.nlargest(n, records[:-1], key)
def summary(self, topn=5, indent=0):
"""Return a string summarizing the timing information.
Parameters
----------
topn: int; optional
This limits the maximum number of items to show.
This function will show the ``topn`` most time-consuming passes.
indent: int; optional
Set the indentation level. Defaults to 0 for no indentation.
Returns
-------
res: str
"""
buf = []
prefix = " " * indent
def ap(arg):
buf.append(f"{prefix}{arg}")
ap(f"Total {self.get_total_time():.4f}s")
ap("Top timings:")
for p in self.list_top(topn):
ap(f" {p.wall_time:.4f}s ({p.wall_percent:5}%) {p.pass_name}")
return "\n".join(buf)
@cached_property
def _processed(self):
"""A cached property for lazily processing the data and returning it.
See ``_process()`` for details.
"""
return self._process()
def _process(self):
"""Parses the raw string data from LLVM timing report and attempts
to improve the data by recomputing the times
(See `_adjust_timings()``).
"""
def parse(raw_data):
"""A generator that parses the raw_data line-by-line to extract
timing information for each pass.
"""
lines = raw_data.splitlines()
colheader = r"[a-zA-Z+ ]+"
# Take at least one column header.
multicolheaders = fr"(?:\s*-+{colheader}-+)+"
line_iter = iter(lines)
# find column headers
header_map = {
"User Time": "user",
"System Time": "system",
"User+System": "user_system",
"Wall Time": "wall",
"Instr": "instruction",
"Name": "pass_name",
}
for ln in line_iter:
m = re.match(multicolheaders, ln)
if m:
# Get all the column headers
raw_headers = re.findall(r"[a-zA-Z][a-zA-Z+ ]+", ln)
headers = [header_map[k.strip()] for k in raw_headers]
break
assert headers[-1] == 'pass_name'
# compute the list of available attributes from the column headers
attrs = []
n = r"\s*((?:[0-9]+\.)?[0-9]+)"
pat = ""
for k in headers[:-1]:
if k == "instruction":
pat += n
else:
attrs.append(f"{k}_time")
attrs.append(f"{k}_percent")
pat += rf"\s+(?:{n}\s*\({n}%\)|-+)"
# put default value 0.0 to all missing attributes
missing = {}
for k in PassTimingRecord._fields:
if k not in attrs and k != 'pass_name':
missing[k] = 0.0
# parse timings
pat += r"\s*(.*)"
for ln in line_iter:
m = re.match(pat, ln)
if m is not None:
raw_data = list(m.groups())
data = {k: float(v) if v is not None else 0.0
for k, v in zip(attrs, raw_data)}
data.update(missing)
pass_name = raw_data[-1]
rec = PassTimingRecord(
pass_name=pass_name, **data,
)
yield rec
if rec.pass_name == "Total":
# "Total" means the report has ended
break
            # Check that we have reached the end of the report
remaining = '\n'.join(line_iter)
# FIXME: Need to handle parsing of Analysis execution timing report
if "Analysis execution timing report" in remaining:
return
if remaining:
raise ValueError(
f"unexpected text after parser finished:\n{remaining}"
)
# Parse raw data
records = list(parse(self._raw_data))
return _adjust_timings(records)
NamedTimings = namedtuple("NamedTimings", ["name", "timings"])
class PassTimingsCollection(Sequence):
"""A collection of pass timings.
This class implements the ``Sequence`` protocol for accessing the
individual timing records.
"""
def __init__(self, name):
self._name = name
self._records = []
@contextmanager
def record_legacy(self, name):
"""Record new timings and append to this collection.
Note: this is mainly for internal use inside the compiler pipeline.
See also ``RecordLLVMPassTimingsLegacy``
Parameters
----------
name: str
Name for the records.
"""
if config.LLVM_PASS_TIMINGS:
# Recording of pass timings is enabled
with RecordLLVMPassTimingsLegacy() as timings:
yield
rec = timings.get()
# Only keep non-empty records
if rec:
self._append(name, rec)
else:
# Do nothing. Recording of pass timings is disabled.
yield
@contextmanager
def record(self, name, pb):
"""Record new timings and append to this collection.
Note: this is mainly for internal use inside the compiler pipeline.
See also ``RecordLLVMPassTimings``
Parameters
----------
        name: str
            Name for the records.
        pb:
            The pass builder on which pass timing is started and stopped
            (see ``RecordLLVMPassTimings``).
"""
if config.LLVM_PASS_TIMINGS:
# Recording of pass timings is enabled
with RecordLLVMPassTimings(pb) as timings:
yield
rec = timings.get()
# Only keep non-empty records
if rec:
self._append(name, rec)
else:
# Do nothing. Recording of pass timings is disabled.
yield
def _append(self, name, timings):
"""Append timing records
Parameters
----------
name: str
Name for the records.
timings: ProcessedPassTimings
the timing records.
"""
self._records.append(NamedTimings(name, timings))
def get_total_time(self):
"""Computes the sum of the total time across all contained timings.
Returns
-------
res: float or None
Returns the total number of seconds or None if no timings were
recorded
"""
if self._records:
return sum(r.timings.get_total_time() for r in self._records)
else:
return None
def list_longest_first(self):
"""Returns the timings in descending order of total time duration.
Returns
-------
res: List[ProcessedPassTimings]
"""
return sorted(self._records,
key=lambda x: x.timings.get_total_time(),
reverse=True)
@property
def is_empty(self):
"""
"""
return not self._records
def summary(self, topn=5):
"""Return a string representing the summary of the timings.
Parameters
----------
topn: int; optional, default=5.
This limits the maximum number of items to show.
This function will show the ``topn`` most time-consuming passes.
Returns
-------
res: str
See also ``ProcessedPassTimings.summary()``
"""
if self.is_empty:
return "No pass timings were recorded"
else:
buf = []
ap = buf.append
ap(f"Printing pass timings for {self._name}")
overall_time = self.get_total_time()
ap(f"Total time: {overall_time:.4f}")
for i, r in enumerate(self._records):
ap(f"== #{i} {r.name}")
percent = r.timings.get_total_time() / overall_time * 100
ap(f" Percent: {percent:.1f}%")
ap(r.timings.summary(topn=topn, indent=1))
return "\n".join(buf)
def __getitem__(self, i):
"""Get the i-th timing record.
Returns
-------
res: (name, timings)
A named tuple with two fields:
- name: str
- timings: ProcessedPassTimings
"""
return self._records[i]
def __len__(self):
"""Length of this collection.
"""
return len(self._records)
def __str__(self):
return self.summary()
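# Minimal usage sketch (illustrative; nothing is recorded unless
# `config.LLVM_PASS_TIMINGS` is enabled, e.g. via NUMBA_LLVM_PASS_TIMINGS):
#
#   timings = PassTimingsCollection("my pipeline")
#   with timings.record_legacy("module passes"):
#       ...  # run work under the LLVM legacy pass manager here
#   print(timings.summary(topn=3))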

View File

@@ -0,0 +1,196 @@
"""
Memory monitoring utilities for measuring memory usage.
Example usage:
tracker = MemoryTracker("my_function")
with tracker.monitor():
my_function()
# Access data: tracker.rss_delta, tracker.duration, etc.
# Get formatted string: tracker.get_summary()
"""
from __future__ import annotations
import os
import contextlib
import time
from typing import Dict, Optional
try:
import psutil
_HAS_PSUTIL = True
except ImportError:
_HAS_PSUTIL = False
IS_SUPPORTED = _HAS_PSUTIL
def get_available_memory() -> Optional[int]:
"""
Get current available system memory in bytes.
Used for memory threshold checking in parallel test execution.
Returns:
int or None: Available memory in bytes, or None if unavailable
"""
if _HAS_PSUTIL:
try:
sys_mem = psutil.virtual_memory()
return sys_mem.available
except Exception:
pass
return None
def get_memory_usage() -> Dict[str, Optional[int]]:
"""
Get memory usage information needed for monitoring.
Returns only RSS and available memory which are the fields
actually used by the MemoryTracker.
Returns:
dict: Memory usage information including:
- rss: Current process RSS (physical memory currently used)
- available: Available system memory
"""
memory_info = {}
if _HAS_PSUTIL:
try:
# Get current process RSS
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
memory_info["rss"] = mem_info.rss
# Get system available memory
sys_mem = psutil.virtual_memory()
memory_info["available"] = sys_mem.available
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass
# Set defaults if unavailable
if "rss" not in memory_info:
memory_info["rss"] = None
if "available" not in memory_info:
memory_info["available"] = None
return memory_info
class MemoryTracker:
"""
A simple memory monitor that tracks RSS delta and timing.
Stores monitoring data in instance attributes for later access.
Each instance is typically used for monitoring a single operation.
"""
pid: int
name: str
start_time: float | None
end_time: float | None
start_memory: Dict[str, int | None] | None
end_memory: Dict[str, int | None] | None
duration: float | None
rss_delta: int | None
def __init__(self, name: str):
"""Initialize a MemoryTracker with empty monitoring data."""
self.pid = os.getpid()
self.name = name
self.start_time = None
self.end_time = None
self.start_memory = None
self.end_memory = None
self.duration = None
self.rss_delta = None
@contextlib.contextmanager
def monitor(self):
"""
Context manager to monitor memory usage during function execution.
Records start/end memory usage and timing, calculates RSS delta,
and stores all data in instance attributes.
        The name/identifier for the monitored operation is the ``name``
        given when this tracker was constructed.
Yields:
self: The MemoryTracker instance for accessing stored data
"""
# Store data in self and record start time and memory usage
self.start_time = time.time()
self.start_memory = get_memory_usage()
try:
yield self
finally:
# Record end time and memory usage
self.end_time = time.time()
self.end_memory = get_memory_usage()
self.duration = self.end_time - self.start_time
# Calculate RSS delta
start_rss = self.start_memory.get("rss", 0)
end_rss = self.end_memory.get("rss", 0)
self.rss_delta = ((end_rss - start_rss)
if start_rss and end_rss else 0)
def get_summary(self) -> str:
"""
Return a formatted summary of the memory monitoring data.
Formats the stored monitoring data into a human-readable string
containing name, PID, RSS delta, available memory, duration,
and start time.
Returns:
str: Formatted summary string with monitoring results
Note:
Should be called after monitor() context has completed
to ensure all data is available.
"""
if self.start_memory is None or self.end_memory is None:
raise ValueError("Memory monitoring data not available")
current_available = self.end_memory.get("available")
def format_bytes(bytes_val, show_sign=False):
"""Convert bytes to human readable format"""
if bytes_val is None:
return "N/A"
if bytes_val == 0:
return "0 B"
sign = ""
if show_sign:
sign = "-" if bytes_val < 0 else "+"
bytes_val = abs(bytes_val)
for unit in ["B", "KB", "MB", "GB"]:
if bytes_val < 1024.0:
return f"{sign}{bytes_val:.2f} {unit}"
bytes_val /= 1024.0
return f"{sign}{bytes_val:.2f} TB"
start_ts = time.strftime("%H:%M:%S", time.localtime(self.start_time))
start_rss = self.start_memory.get("rss", 0)
end_rss = self.end_memory.get("rss", 0)
buf = [
f"Name: {self.name}",
f"PID: {self.pid}",
f"Start: {start_ts}",
f"Duration: {self.duration:.3f}s",
f"Start RSS: {format_bytes(start_rss)}",
f"End RSS: {format_bytes(end_rss)}",
f"RSS delta: {format_bytes(self.rss_delta, show_sign=True)}",
f"Avail memory: {format_bytes(current_available)}",
]
return ' | '.join(buf)
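# Minimal end-to-end sketch (illustrative; deltas are only meaningful when
# psutil is installed, otherwise they fall back to 0):
#
#   tracker = MemoryTracker("allocate")
#   with tracker.monitor():
#       data = [0] * 10_000_000
#   print(tracker.get_summary())
#   # e.g. Name: allocate | PID: 1234 | ... | RSS delta: +76.29 MB | ...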

View File

@@ -0,0 +1,126 @@
"""
The same algorithm as numpy's, translated from C.
See numpy/core/src/npysort/mergesort.c.src.
The high-level numba code adds a little overhead compared to
the pure-C implementation in numpy.
"""
import numpy as np
from collections import namedtuple
# Array size smaller than this will be sorted by insertion sort
SMALL_MERGESORT = 20
MergesortImplementation = namedtuple('MergesortImplementation', [
'run_mergesort',
])
def make_mergesort_impl(wrap, lt=None, is_argsort=False):
kwargs_lite = dict(no_cpython_wrapper=True, _nrt=False)
# The less than
if lt is None:
@wrap(**kwargs_lite)
def lt(a, b):
return a < b
else:
lt = wrap(**kwargs_lite)(lt)
if is_argsort:
@wrap(**kwargs_lite)
def lessthan(a, b, vals):
return lt(vals[a], vals[b])
else:
@wrap(**kwargs_lite)
def lessthan(a, b, vals):
return lt(a, b)
@wrap(**kwargs_lite)
def argmergesort_inner(arr, vals, ws):
"""The actual mergesort function
Parameters
----------
arr : array [read+write]
The values being sorted inplace. For argsort, this is the
indices.
vals : array [readonly]
``None`` for normal sort. In argsort, this is the actual array values.
ws : array [write]
The workspace. Must be of size ``arr.size // 2``
"""
if arr.size > SMALL_MERGESORT:
# Merge sort
mid = arr.size // 2
argmergesort_inner(arr[:mid], vals, ws)
argmergesort_inner(arr[mid:], vals, ws)
# Copy left half into workspace so we don't overwrite it
for i in range(mid):
ws[i] = arr[i]
# Merge
left = ws[:mid]
right = arr[mid:]
out = arr
i = j = k = 0
while i < left.size and j < right.size:
if not lessthan(right[j], left[i], vals):
out[k] = left[i]
i += 1
else:
out[k] = right[j]
j += 1
k += 1
# Leftovers
while i < left.size:
out[k] = left[i]
i += 1
k += 1
while j < right.size:
out[k] = right[j]
j += 1
k += 1
else:
# Insertion sort
i = 1
while i < arr.size:
j = i
while j > 0 and lessthan(arr[j], arr[j - 1], vals):
arr[j - 1], arr[j] = arr[j], arr[j - 1]
j -= 1
i += 1
# The top-level entry points
@wrap(no_cpython_wrapper=True)
def mergesort(arr):
"Inplace"
ws = np.empty(arr.size // 2, dtype=arr.dtype)
argmergesort_inner(arr, None, ws)
return arr
@wrap(no_cpython_wrapper=True)
def argmergesort(arr):
"Out-of-place"
idxs = np.arange(arr.size)
ws = np.empty(arr.size // 2, dtype=idxs.dtype)
argmergesort_inner(idxs, arr, ws)
return idxs
return MergesortImplementation(
run_mergesort=(argmergesort if is_argsort else mergesort)
)
def make_jit_mergesort(*args, **kwargs):
from numba import njit
# NOTE: wrap with njit to allow recursion
# because @register_jitable => @overload doesn't support recursion
return make_mergesort_impl(njit, *args, **kwargs)
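# Minimal usage sketch (illustrative; the wrapped routines are compiled with
# `no_cpython_wrapper=True`, so call them from jitted code):
#
#   import numpy as np
#   from numba import njit
#
#   sorter = make_jit_mergesort()
#
#   @njit
#   def sort_values(arr):
#       return sorter.run_mergesort(arr)   # in-place value sort
#
#   arr = np.array([3, 1, 2])
#   sort_values(arr)                       # arr is now [1, 2, 3]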

View File

@@ -0,0 +1,72 @@
import sys
import argparse
import os
import subprocess
import json
from .numba_sysinfo import display_sysinfo, get_sysinfo
from .numba_gdbinfo import display_gdbinfo
def make_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--annotate', help='Annotate source',
action='store_true')
parser.add_argument('--dump-llvm', action="store_true",
help='Print generated llvm assembly')
parser.add_argument('--dump-optimized', action='store_true',
help='Dump the optimized llvm assembly')
parser.add_argument('--dump-assembly', action='store_true',
help='Dump the LLVM generated assembly')
parser.add_argument('--annotate-html', nargs=1,
help='Output source annotation as html')
parser.add_argument('-s', '--sysinfo', action="store_true",
help='Output system information for bug reporting')
parser.add_argument('-g', '--gdbinfo', action="store_true",
help='Output system information about gdb')
parser.add_argument('--sys-json', nargs=1,
help='Saves the system info dict as a json file')
parser.add_argument('filename', nargs='?', help='Python source filename')
return parser
def main():
parser = make_parser()
args = parser.parse_args()
if args.sysinfo:
print("System info:")
display_sysinfo()
if args.gdbinfo:
print("GDB info:")
display_gdbinfo()
if args.sysinfo or args.gdbinfo:
sys.exit(0)
if args.sys_json:
info = get_sysinfo()
info.update({'Start': info['Start'].isoformat()})
info.update({'Start UTC': info['Start UTC'].isoformat()})
with open(args.sys_json[0], 'w') as f:
json.dump(info, f, indent=4)
sys.exit(0)
os.environ['NUMBA_DUMP_ANNOTATION'] = str(int(args.annotate))
if args.annotate_html is not None:
try:
            # Availability check only; the annotation code needs jinja2.
            from jinja2 import Template  # noqa: F401
except ImportError:
raise ImportError("Please install the 'jinja2' package")
os.environ['NUMBA_DUMP_HTML'] = str(args.annotate_html[0])
os.environ['NUMBA_DUMP_LLVM'] = str(int(args.dump_llvm))
os.environ['NUMBA_DUMP_OPTIMIZED'] = str(int(args.dump_optimized))
os.environ['NUMBA_DUMP_ASSEMBLY'] = str(int(args.dump_assembly))
if args.filename:
cmd = [sys.executable, args.filename]
subprocess.call(cmd)
else:
print("numba: error: the following arguments are required: filename")
sys.exit(1)
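# Example invocations (illustrative), assuming the `numba` console script is
# installed:
#
#   numba --sysinfo                 # print system information
#   numba -g                        # print information about gdb support
#   numba --sys-json info.json      # save the system info dict as JSON
#   numba --annotate myscript.py    # run a script with annotation enabled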

View File

@@ -0,0 +1,161 @@
"""Module for displaying information about Numba's gdb set up"""
from collections import namedtuple
import os
import re
import subprocess
from textwrap import dedent
from numba import config
# Container for the output of the gdb info data collection
_fields = ('binary_loc, extension_loc, py_ver, np_ver, supported')
_gdb_info = namedtuple('_gdb_info', _fields)
class _GDBTestWrapper():
"""Wraps the gdb binary and has methods for checking what the gdb binary
has support for (Python and NumPy)."""
    def __init__(self):
gdb_binary = config.GDB_BINARY
if gdb_binary is None:
msg = ("No valid binary could be found for gdb named: "
f"{config.GDB_BINARY}")
raise ValueError(msg)
self._gdb_binary = gdb_binary
def _run_cmd(self, cmd=()):
gdb_call = [self.gdb_binary, '-q',]
for x in cmd:
gdb_call.append('-ex')
gdb_call.append(x)
gdb_call.extend(['-ex', 'q'])
return subprocess.run(gdb_call, capture_output=True, timeout=10,
text=True)
@property
def gdb_binary(self):
return self._gdb_binary
@classmethod
def success(cls, status):
return status.returncode == 0
def check_launch(self):
"""Checks that gdb will launch ok"""
return self._run_cmd()
def check_python(self):
cmd = ("python from __future__ import print_function; "
"import sys; print(sys.version_info[:2])")
return self._run_cmd((cmd,))
def check_numpy(self):
cmd = ("python from __future__ import print_function; "
"import types; import numpy; "
"print(isinstance(numpy, types.ModuleType))")
return self._run_cmd((cmd,))
def check_numpy_version(self):
cmd = ("python from __future__ import print_function; "
"import types; import numpy;"
"print(numpy.__version__)")
return self._run_cmd((cmd,))
def collect_gdbinfo():
"""Prints information to stdout about the gdb setup that Numba has found"""
# State flags:
gdb_state = None
gdb_has_python = False
gdb_has_numpy = False
gdb_python_version = 'No Python support'
gdb_python_numpy_version = "No NumPy support"
# There are so many ways for gdb to not be working as expected. Surround
# the "is it working" tests with try/except and if there's an exception
# store it for processing later.
try:
# Check gdb exists
gdb_wrapper = _GDBTestWrapper()
# Check gdb works
status = gdb_wrapper.check_launch()
if not gdb_wrapper.success(status):
msg = (f"gdb at '{gdb_wrapper.gdb_binary}' does not appear to work."
f"\nstdout: {status.stdout}\nstderr: {status.stderr}")
raise ValueError(msg)
gdb_state = gdb_wrapper.gdb_binary
except Exception as e:
gdb_state = f"Testing gdb binary failed. Reported Error: {e}"
else:
# Got this far, so gdb works, start checking what it supports
status = gdb_wrapper.check_python()
if gdb_wrapper.success(status):
version_match = re.match(r'\((\d+),\s+(\d+)\)',
status.stdout.strip())
if version_match is not None:
pymajor, pyminor = version_match.groups()
gdb_python_version = f"{pymajor}.{pyminor}"
gdb_has_python = True
status = gdb_wrapper.check_numpy()
if gdb_wrapper.success(status):
if "Traceback" not in status.stderr.strip():
if status.stdout.strip() == 'True':
gdb_has_numpy = True
gdb_python_numpy_version = "Unknown"
                    # NumPy is present, find the version
status = gdb_wrapper.check_numpy_version()
if gdb_wrapper.success(status):
if "Traceback" not in status.stderr.strip():
gdb_python_numpy_version = \
status.stdout.strip()
# Work out what level of print-extension support is present in this gdb
if gdb_has_python:
if gdb_has_numpy:
print_ext_supported = "Full (Python and NumPy supported)"
else:
print_ext_supported = "Partial (Python only, no NumPy support)"
else:
print_ext_supported = "None"
# Work out print ext location
print_ext_file = "gdb_print_extension.py"
print_ext_path = os.path.join(os.path.dirname(__file__), print_ext_file)
# return!
return _gdb_info(gdb_state, print_ext_path, gdb_python_version,
gdb_python_numpy_version, print_ext_supported)
def display_gdbinfo(sep_pos=45):
"""Displays the information collected by collect_gdbinfo.
"""
gdb_info = collect_gdbinfo()
print('-' * 80)
fmt = f'%-{sep_pos}s : %-s'
# Display the information
print(fmt % ("Binary location", gdb_info.binary_loc))
print(fmt % ("Print extension location", gdb_info.extension_loc))
print(fmt % ("Python version", gdb_info.py_ver))
print(fmt % ("NumPy version", gdb_info.np_ver))
print(fmt % ("Numba printing extension support", gdb_info.supported))
print("")
print("To load the Numba gdb printing extension, execute the following "
"from the gdb prompt:")
print(f"\nsource {gdb_info.extension_loc}\n")
print('-' * 80)
warn = """
=============================================================
IMPORTANT: Before sharing you should remove any information
in the above that you wish to keep private e.g. paths.
=============================================================
"""
print(dedent(warn))
if __name__ == '__main__':
display_gdbinfo()

View File

@@ -0,0 +1,706 @@
import json
import locale
import multiprocessing
import os
import platform
import textwrap
import sys
from contextlib import redirect_stdout
from datetime import datetime
from io import StringIO
from subprocess import check_output, PIPE, CalledProcessError
import numpy as np
import llvmlite.binding as llvmbind
from llvmlite import __version__ as llvmlite_version
from numba import cuda as cu, __version__ as version_number
from numba.cuda import cudadrv
from numba.cuda.cudadrv.driver import driver as cudriver
from numba.cuda.cudadrv.runtime import runtime as curuntime
from numba.core import config
_psutil_import = False
try:
import psutil
except ImportError:
pass
else:
_psutil_import = True
__all__ = ['get_sysinfo', 'display_sysinfo']
# Keys of a `sysinfo` dictionary
# Time info
_start, _start_utc, _runtime = 'Start', 'Start UTC', 'Runtime'
_numba_version = 'Numba Version'
# Hardware info
_machine = 'Machine'
_cpu_name, _cpu_count = 'CPU Name', 'CPU Count'
_cpus_allowed, _cpus_list = 'CPUs Allowed', 'List CPUs Allowed'
_cpu_features = 'CPU Features'
_cfs_quota, _cfs_period = 'CFS Quota', 'CFS Period'
_cfs_restrict = 'CFS Restriction'
_mem_total, _mem_available = 'Mem Total', 'Mem Available'
# OS info
_platform_name, _platform_release = 'Platform Name', 'Platform Release'
_os_name, _os_version = 'OS Name', 'OS Version'
_os_spec_version = 'OS Specific Version'
_libc_version = 'Libc Version'
# Python info
_python_comp = 'Python Compiler'
_python_impl = 'Python Implementation'
_python_version = 'Python Version'
_python_locale = 'Python Locale'
# LLVM info
_llvmlite_version = 'llvmlite Version'
_llvm_version = 'LLVM Version'
# CUDA info
_cu_target_impl = 'CUDA Target Impl'
_cu_dev_init = 'CUDA Device Init'
_cu_drv_ver = 'CUDA Driver Version'
_cu_rt_ver = 'CUDA Runtime Version'
_cu_nvidia_bindings = 'NVIDIA CUDA Bindings'
_cu_nvidia_bindings_used = 'NVIDIA CUDA Bindings In Use'
_cu_detect_out, _cu_lib_test = 'CUDA Detect Output', 'CUDA Lib Test'
_cu_mvc_available = 'NVIDIA CUDA Minor Version Compatibility Available'
_cu_mvc_needed = 'NVIDIA CUDA Minor Version Compatibility Needed'
_cu_mvc_in_use = 'NVIDIA CUDA Minor Version Compatibility In Use'
# NumPy info
_numpy_version = 'NumPy Version'
_numpy_supported_simd_features = 'NumPy Supported SIMD features'
_numpy_supported_simd_dispatch = 'NumPy Supported SIMD dispatch'
_numpy_supported_simd_baseline = 'NumPy Supported SIMD baseline'
_numpy_AVX512_SKX_detected = 'NumPy AVX512_SKX detected'
# SVML info
_svml_state, _svml_loaded = 'SVML State', 'SVML Lib Loaded'
_llvm_svml_patched = 'LLVM SVML Patched'
_svml_operational = 'SVML Operational'
# Threading layer info
_tbb_thread, _tbb_error = 'TBB Threading', 'TBB Threading Error'
_openmp_thread, _openmp_error = 'OpenMP Threading', 'OpenMP Threading Error'
_openmp_vendor = 'OpenMP vendor'
_wkq_thread, _wkq_error = 'Workqueue Threading', 'Workqueue Threading Error'
# Numba info
_numba_env_vars = 'Numba Env Vars'
# Conda info
_conda_build_ver, _conda_env_ver = 'Conda Build', 'Conda Env'
_conda_platform, _conda_python_ver = 'Conda Platform', 'Conda Python Version'
_conda_root_writable = 'Conda Root Writable'
# Packages info
_inst_pkg = 'Installed Packages'
# Psutil info
_psutil = 'Psutil Available'
# Errors and warnings
_errors = 'Errors'
_warnings = 'Warnings'
# Error and warning log
_error_log = []
_warning_log = []
def get_os_spec_info(os_name):
# Linux man page for `/proc`:
# http://man7.org/linux/man-pages/man5/proc.5.html
# WMIC is deprecated in windows-2025, using powershell instead
# https://techcommunity.microsoft.com/t5/windows-it-pro-blog/wmi-command-line-wmic-utility-deprecation-next-steps/ba-p/4039242
# Windows documentation for `powershell`:
# https://learn.microsoft.com/en-us/powershell/
# MacOS man page for `sysctl`:
# https://www.unix.com/man-page/osx/3/sysctl/
# MacOS man page for `vm_stat`:
# https://www.unix.com/man-page/osx/1/vm_stat/
class CmdBufferOut(tuple):
buffer_output_flag = True
class CmdReadFile(tuple):
read_file_flag = True
shell_params = {
'Linux': {
'cmd': (
CmdReadFile(('/sys/fs/cgroup/cpuacct/cpu.cfs_quota_us',)),
CmdReadFile(('/sys/fs/cgroup/cpuacct/cpu.cfs_period_us',)),
),
'cmd_optional': (
CmdReadFile(('/proc/meminfo',)),
CmdReadFile(('/proc/self/status',)),
),
'kwds': {
# output string fragment -> result dict key
'MemTotal:': _mem_total,
'MemAvailable:': _mem_available,
'Cpus_allowed:': _cpus_allowed,
'Cpus_allowed_list:': _cpus_list,
'/sys/fs/cgroup/cpuacct/cpu.cfs_quota_us': _cfs_quota,
'/sys/fs/cgroup/cpuacct/cpu.cfs_period_us': _cfs_period,
},
},
'Windows': {
'cmd': (),
'cmd_optional': (
CmdBufferOut(('powershell', '-NoProfile', '-Command',
"'TotalVirtualMemorySize ' + "
"(Get-CimInstance -ClassName "
"Win32_OperatingSystem).TotalVirtualMemorySize")),
CmdBufferOut(('powershell', '-NoProfile', '-Command',
"'FreeVirtualMemory ' + "
"(Get-CimInstance -ClassName "
"Win32_OperatingSystem).FreeVirtualMemory")),
),
'kwds': {
# output string fragment -> result dict key
'TotalVirtualMemorySize': _mem_total,
'FreeVirtualMemory': _mem_available,
},
},
'Darwin': {
'cmd': (),
'cmd_optional': (
('sysctl', 'hw.memsize'),
                ('vm_stat',),
),
'kwds': {
# output string fragment -> result dict key
'hw.memsize:': _mem_total,
'free:': _mem_available,
},
'units': {
_mem_total: 1, # Size is given in bytes.
_mem_available: 4096, # Size is given in 4kB pages.
},
},
}
os_spec_info = {}
params = shell_params.get(os_name, {})
cmd_selected = params.get('cmd', ())
if _psutil_import:
vm = psutil.virtual_memory()
os_spec_info.update({
_mem_total: vm.total,
_mem_available: vm.available,
})
p = psutil.Process()
cpus_allowed = p.cpu_affinity() if hasattr(p, 'cpu_affinity') else []
if cpus_allowed:
os_spec_info[_cpus_allowed] = len(cpus_allowed)
os_spec_info[_cpus_list] = ' '.join(str(n) for n in cpus_allowed)
else:
_warning_log.append(
"Warning (psutil): psutil cannot be imported. "
"For more accuracy, consider installing it.")
# Fallback to internal heuristics
cmd_selected += params.get('cmd_optional', ())
# Assuming the shell cmd returns a unique (k, v) pair per line
# or a unique (k, v) pair spread over several lines:
# Gather output in a list of strings containing a keyword and some value.
output = []
for cmd in cmd_selected:
if hasattr(cmd, 'read_file_flag'):
# Open file within Python
if os.path.exists(cmd[0]):
try:
with open(cmd[0], 'r') as f:
out = f.readlines()
if out:
out[0] = ' '.join((cmd[0], out[0]))
output.extend(out)
except OSError as e:
_error_log.append(f'Error (file read): {e}')
continue
else:
_warning_log.append('Warning (no file): {}'.format(cmd[0]))
continue
else:
# Spawn a subprocess
try:
out = check_output(cmd, stderr=PIPE)
except (OSError, CalledProcessError) as e:
_error_log.append(f'Error (subprocess): {e}')
continue
if hasattr(cmd, 'buffer_output_flag'):
out = b' '.join(line for line in out.splitlines()) + b'\n'
output.extend(out.decode().splitlines())
# Extract (k, output) pairs by searching for keywords in output
kwds = params.get('kwds', {})
for line in output:
match = kwds.keys() & line.split()
if match and len(match) == 1:
k = kwds[match.pop()]
os_spec_info[k] = line
elif len(match) > 1:
print(f'Ambiguous output: {line}')
# Try to extract something meaningful from output string
def format():
# CFS restrictions
split = os_spec_info.get(_cfs_quota, '').split()
if split:
os_spec_info[_cfs_quota] = float(split[-1])
split = os_spec_info.get(_cfs_period, '').split()
if split:
os_spec_info[_cfs_period] = float(split[-1])
if os_spec_info.get(_cfs_quota, -1) != -1:
cfs_quota = os_spec_info.get(_cfs_quota, '')
cfs_period = os_spec_info.get(_cfs_period, '')
runtime_amount = cfs_quota / cfs_period
os_spec_info[_cfs_restrict] = runtime_amount
def format_optional():
# Memory
units = {_mem_total: 1024, _mem_available: 1024}
units.update(params.get('units', {}))
for k in (_mem_total, _mem_available):
digits = ''.join(d for d in os_spec_info.get(k, '') if d.isdigit())
os_spec_info[k] = int(digits or 0) * units[k]
# Accessible CPUs
split = os_spec_info.get(_cpus_allowed, '').split()
if split:
n = split[-1]
n = n.split(',')[-1]
            os_spec_info[_cpus_allowed] = bin(int(n or '0', 16)).count('1')
split = os_spec_info.get(_cpus_list, '').split()
if split:
os_spec_info[_cpus_list] = split[-1]
try:
format()
if not _psutil_import:
format_optional()
except Exception as e:
_error_log.append(f'Error (format shell output): {e}')
# Call OS specific functions
os_specific_funcs = {
'Linux': {
_libc_version: lambda: ' '.join(platform.libc_ver())
},
'Windows': {
_os_spec_version: lambda: ' '.join(
s for s in platform.win32_ver()),
},
'Darwin': {
_os_spec_version: lambda: ''.join(
i or ' ' for s in tuple(platform.mac_ver()) for i in s),
},
}
key_func = os_specific_funcs.get(os_name, {})
os_spec_info.update({k: f() for k, f in key_func.items()})
return os_spec_info
def get_sysinfo():
# Gather the information that shouldn't raise exceptions
sys_info = {
_start: datetime.now(),
_start_utc: datetime.utcnow(),
_machine: platform.machine(),
_cpu_name: llvmbind.get_host_cpu_name(),
_cpu_count: multiprocessing.cpu_count(),
_platform_name: platform.platform(aliased=True),
_platform_release: platform.release(),
_os_name: platform.system(),
_os_version: platform.version(),
_python_comp: platform.python_compiler(),
_python_impl: platform.python_implementation(),
_python_version: platform.python_version(),
_numba_env_vars: {k: v for (k, v) in os.environ.items()
if k.startswith('NUMBA_')},
_numba_version: version_number,
_llvm_version: '.'.join(str(i) for i in llvmbind.llvm_version_info),
_llvmlite_version: llvmlite_version,
_psutil: _psutil_import,
}
# CPU features
try:
feature_map = llvmbind.get_host_cpu_features()
except RuntimeError as e:
_error_log.append(f'Error (CPU features): {e}')
else:
features = sorted([key for key, value in feature_map.items() if value])
sys_info[_cpu_features] = ' '.join(features)
# Python locale
# On MacOSX, getdefaultlocale can raise. Check again if Py > 3.7.5
try:
# If $LANG is unset, getdefaultlocale() can return (None, None), make
# sure we can encode this as strings by casting explicitly.
sys_info[_python_locale] = '.'.join([str(i) for i in
locale.getdefaultlocale()])
except Exception as e:
_error_log.append(f'Error (locale): {e}')
# CUDA information
try:
sys_info[_cu_target_impl] = cu.implementation
except AttributeError:
# On the offchance an out-of-tree target did not set the
# implementation, we can try to continue
pass
try:
cu.list_devices()[0] # will a device initialise?
except Exception as e:
sys_info[_cu_dev_init] = False
msg_not_found = "CUDA driver library cannot be found"
msg_disabled_by_user = "CUDA is disabled"
msg_end = " or no CUDA enabled devices are present."
msg_generic_problem = "CUDA device initialisation problem."
msg = getattr(e, 'msg', None)
if msg is not None:
if msg_not_found in msg:
err_msg = msg_not_found + msg_end
elif msg_disabled_by_user in msg:
err_msg = msg_disabled_by_user + msg_end
else:
err_msg = msg_generic_problem + " Message:" + msg
else:
err_msg = msg_generic_problem + " " + str(e)
# Best effort error report
_warning_log.append("Warning (cuda): %s\nException class: %s" %
(err_msg, str(type(e))))
else:
try:
sys_info[_cu_dev_init] = True
output = StringIO()
with redirect_stdout(output):
cu.detect()
sys_info[_cu_detect_out] = output.getvalue()
output.close()
cu_drv_ver = cudriver.get_version()
cu_rt_ver = curuntime.get_version()
sys_info[_cu_drv_ver] = '%s.%s' % cu_drv_ver
sys_info[_cu_rt_ver] = '%s.%s' % cu_rt_ver
output = StringIO()
with redirect_stdout(output):
cudadrv.libs.test()
sys_info[_cu_lib_test] = output.getvalue()
output.close()
try:
from cuda import cuda # noqa: F401
nvidia_bindings_available = True
except ImportError:
nvidia_bindings_available = False
sys_info[_cu_nvidia_bindings] = nvidia_bindings_available
nv_binding_used = bool(cudadrv.driver.USE_NV_BINDING)
sys_info[_cu_nvidia_bindings_used] = nv_binding_used
try:
from ptxcompiler import compile_ptx # noqa: F401
from cubinlinker import CubinLinker # noqa: F401
sys_info[_cu_mvc_available] = True
except ImportError:
sys_info[_cu_mvc_available] = False
sys_info[_cu_mvc_needed] = cu_rt_ver > cu_drv_ver
sys_info[_cu_mvc_in_use] = bool(
config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY)
except Exception as e:
_warning_log.append(
"Warning (cuda): Probing CUDA failed "
"(device and driver present, runtime problem?)\n"
f"(cuda) {type(e)}: {e}")
# NumPy information
sys_info[_numpy_version] = np.version.full_version
try:
# NOTE: These consts were added in NumPy 1.20
from numpy.core._multiarray_umath import (__cpu_features__,
__cpu_dispatch__,
__cpu_baseline__,)
except ImportError:
sys_info[_numpy_AVX512_SKX_detected] = False
else:
feat_filtered = [k for k, v in __cpu_features__.items() if v]
sys_info[_numpy_supported_simd_features] = feat_filtered
sys_info[_numpy_supported_simd_dispatch] = __cpu_dispatch__
sys_info[_numpy_supported_simd_baseline] = __cpu_baseline__
sys_info[_numpy_AVX512_SKX_detected] = \
__cpu_features__.get("AVX512_SKX", False)
# SVML information
# Replicate some SVML detection logic from numba.__init__ here.
# If SVML load fails in numba.__init__ the splitting of the logic
# here will help diagnosing the underlying issue.
svml_lib_loaded = True
try:
if sys.platform.startswith('linux'):
llvmbind.load_library_permanently("libsvml.so")
elif sys.platform.startswith('darwin'):
llvmbind.load_library_permanently("libsvml.dylib")
elif sys.platform.startswith('win'):
llvmbind.load_library_permanently("svml_dispmd")
else:
svml_lib_loaded = False
except Exception:
svml_lib_loaded = False
func = getattr(llvmbind.targets, "has_svml", None)
sys_info[_llvm_svml_patched] = func() if func else False
sys_info[_svml_state] = config.USING_SVML
sys_info[_svml_loaded] = svml_lib_loaded
sys_info[_svml_operational] = all((
sys_info[_svml_state],
sys_info[_svml_loaded],
sys_info[_llvm_svml_patched],
))
# Check which threading backends are available.
def parse_error(e, backend):
        # Parses a Linux-style import error message; this provides feedback
        # while hiding user paths etc.
try:
path, problem, symbol = [x.strip() for x in e.msg.split(':')]
extn_dso = os.path.split(path)[1]
if backend in extn_dso:
return "%s: %s" % (problem, symbol)
except Exception:
pass
return "Unknown import problem."
try:
# check import is ok, this means the DSO linkage is working
from numba.np.ufunc import tbbpool # NOQA
        # check that the version is compatible, this is a check performed at
        # runtime (well, compile time); it will also raise ImportError if
        # there's a problem.
from numba.np.ufunc.parallel import _check_tbb_version_compatible
_check_tbb_version_compatible()
sys_info[_tbb_thread] = True
except ImportError as e:
# might be a missing symbol due to e.g. tbb libraries missing
sys_info[_tbb_thread] = False
sys_info[_tbb_error] = parse_error(e, 'tbbpool')
try:
from numba.np.ufunc import omppool
sys_info[_openmp_thread] = True
sys_info[_openmp_vendor] = omppool.openmp_vendor
except ImportError as e:
sys_info[_openmp_thread] = False
sys_info[_openmp_error] = parse_error(e, 'omppool')
try:
from numba.np.ufunc import workqueue # NOQA
sys_info[_wkq_thread] = True
except ImportError as e:
        sys_info[_wkq_thread] = False
sys_info[_wkq_error] = parse_error(e, 'workqueue')
# Look for conda and installed packages information
cmd = ('conda', 'info', '--json')
try:
conda_out = check_output(cmd)
except Exception as e:
_warning_log.append(f'Warning: Conda not available.\n Error was {e}\n')
# Conda is not available, try pip list to list installed packages
cmd = (sys.executable, '-m', 'pip', 'list')
try:
reqs = check_output(cmd)
except Exception as e:
_error_log.append(f'Error (pip): {e}')
else:
sys_info[_inst_pkg] = reqs.decode().splitlines()
else:
jsond = json.loads(conda_out.decode())
keys = {
'conda_build_version': _conda_build_ver,
'conda_env_version': _conda_env_ver,
'platform': _conda_platform,
'python_version': _conda_python_ver,
'root_writable': _conda_root_writable,
}
for conda_k, sysinfo_k in keys.items():
sys_info[sysinfo_k] = jsond.get(conda_k, 'N/A')
# Get info about packages in current environment
cmd = ('conda', 'list')
try:
conda_out = check_output(cmd)
except CalledProcessError as e:
_error_log.append(f'Error (conda): {e}')
else:
data = conda_out.decode().splitlines()
sys_info[_inst_pkg] = [l for l in data if not l.startswith('#')]
sys_info.update(get_os_spec_info(sys_info[_os_name]))
sys_info[_errors] = _error_log
sys_info[_warnings] = _warning_log
sys_info[_runtime] = (datetime.now() - sys_info[_start]).total_seconds()
return sys_info
def display_sysinfo(info=None, sep_pos=45):
class DisplayMap(dict):
display_map_flag = True
class DisplaySeq(tuple):
display_seq_flag = True
class DisplaySeqMaps(tuple):
display_seqmaps_flag = True
if info is None:
info = get_sysinfo()
fmt = f'%-{sep_pos}s : %-s'
MB = 1024**2
template = (
("-" * 80,),
("__Time Stamp__",),
("Report started (local time)", info.get(_start, '?')),
("UTC start time", info.get(_start_utc, '?')),
("Running time (s)", info.get(_runtime, '?')),
("",),
("__Hardware Information__",),
("Machine", info.get(_machine, '?')),
("CPU Name", info.get(_cpu_name, '?')),
("CPU Count", info.get(_cpu_count, '?')),
("Number of accessible CPUs", info.get(_cpus_allowed, '?')),
("List of accessible CPUs cores", info.get(_cpus_list, '?')),
("CFS Restrictions (CPUs worth of runtime)",
info.get(_cfs_restrict, 'None')),
("",),
("CPU Features", '\n'.join(
' ' * (sep_pos + 3) + l if i else l
for i, l in enumerate(
textwrap.wrap(
info.get(_cpu_features, '?'),
width=79 - sep_pos
)
)
)),
("",),
("Memory Total (MB)", info.get(_mem_total, 0) // MB or '?'),
("Memory Available (MB)"
if info.get(_os_name, '') != 'Darwin' or info.get(_psutil, False)
else "Free Memory (MB)", info.get(_mem_available, 0) // MB or '?'),
("",),
("__OS Information__",),
("Platform Name", info.get(_platform_name, '?')),
("Platform Release", info.get(_platform_release, '?')),
("OS Name", info.get(_os_name, '?')),
("OS Version", info.get(_os_version, '?')),
("OS Specific Version", info.get(_os_spec_version, '?')),
("Libc Version", info.get(_libc_version, '?')),
("",),
("__Python Information__",),
DisplayMap({k: v for k, v in info.items() if k.startswith('Python')}),
("",),
("__Numba Toolchain Versions__",),
("Numba Version", info.get(_numba_version, '?')),
("llvmlite Version", info.get(_llvmlite_version, '?')),
("",),
("__LLVM Information__",),
("LLVM Version", info.get(_llvm_version, '?')),
("",),
("__CUDA Information__",),
("CUDA Target Implementation", info.get(_cu_target_impl, '?')),
("CUDA Device Initialized", info.get(_cu_dev_init, '?')),
("CUDA Driver Version", info.get(_cu_drv_ver, '?')),
("CUDA Runtime Version", info.get(_cu_rt_ver, '?')),
("CUDA NVIDIA Bindings Available", info.get(_cu_nvidia_bindings, '?')),
("CUDA NVIDIA Bindings In Use",
info.get(_cu_nvidia_bindings_used, '?')),
("CUDA Minor Version Compatibility Available",
info.get(_cu_mvc_available, '?')),
("CUDA Minor Version Compatibility Needed",
info.get(_cu_mvc_needed, '?')),
("CUDA Minor Version Compatibility In Use",
info.get(_cu_mvc_in_use, '?')),
("CUDA Detect Output:",),
(info.get(_cu_detect_out, "None"),),
("CUDA Libraries Test Output:",),
(info.get(_cu_lib_test, "None"),),
("",),
("__NumPy Information__",),
("NumPy Version", info.get(_numpy_version, '?')),
("NumPy Supported SIMD features",
DisplaySeq(info.get(_numpy_supported_simd_features, [])
or ('None found.',))),
("NumPy Supported SIMD dispatch",
DisplaySeq(info.get(_numpy_supported_simd_dispatch, [])
or ('None found.',))),
("NumPy Supported SIMD baseline",
DisplaySeq(info.get(_numpy_supported_simd_baseline, [])
or ('None found.',))),
("NumPy AVX512_SKX support detected",
info.get(_numpy_AVX512_SKX_detected, '?')),
("",),
("__SVML Information__",),
("SVML State, config.USING_SVML", info.get(_svml_state, '?')),
("SVML Library Loaded", info.get(_svml_loaded, '?')),
("llvmlite Using SVML Patched LLVM", info.get(_llvm_svml_patched, '?')),
("SVML Operational", info.get(_svml_operational, '?')),
("",),
("__Threading Layer Information__",),
("TBB Threading Layer Available", info.get(_tbb_thread, '?')),
("+-->TBB imported successfully." if info.get(_tbb_thread, '?')
else f"+--> Disabled due to {info.get(_tbb_error, '?')}",),
("OpenMP Threading Layer Available", info.get(_openmp_thread, '?')),
(f"+-->Vendor: {info.get(_openmp_vendor, '?')}"
if info.get(_openmp_thread, False)
else f"+--> Disabled due to {info.get(_openmp_error, '?')}",),
("Workqueue Threading Layer Available", info.get(_wkq_thread, '?')),
("+-->Workqueue imported successfully." if info.get(_wkq_thread, False)
else f"+--> Disabled due to {info.get(_wkq_error, '?')}",),
("",),
("__Numba Environment Variable Information__",),
(DisplayMap(info.get(_numba_env_vars, {})) or ('None found.',)),
("",),
("__Conda Information__",),
(DisplayMap({k: v for k, v in info.items()
if k.startswith('Conda')}) or ("Conda not available.",)),
("",),
("__Installed Packages__",),
DisplaySeq(info.get(_inst_pkg, ("Couldn't retrieve packages info.",))),
("",),
("__Error log__" if info.get(_errors, [])
else "No errors reported.",),
DisplaySeq(info.get(_errors, [])),
("",),
("__Warning log__" if info.get(_warnings, [])
else "No warnings reported.",),
DisplaySeq(info.get(_warnings, [])),
("-" * 80,),
("If requested, please copy and paste the information between\n"
"the dashed (----) lines, or from a given specific section as\n"
"appropriate.\n\n"
"=============================================================\n"
"IMPORTANT: Please ensure that you are happy with sharing the\n"
"contents of the information present, any information that you\n"
"wish to keep private you should remove before sharing.\n"
"=============================================================\n",),
)
for t in template:
if hasattr(t, 'display_seq_flag'):
print(*t, sep='\n')
elif hasattr(t, 'display_map_flag'):
print(*tuple(fmt % (k, v) for (k, v) in t.items()), sep='\n')
elif hasattr(t, 'display_seqmaps_flag'):
for d in t:
print(*tuple(fmt % ('\t' + k, v) for (k, v) in d.items()),
sep='\n', end='\n')
elif len(t) == 2:
print(fmt % t)
else:
print(*t)
if __name__ == '__main__':
display_sysinfo()

View File

@@ -0,0 +1,258 @@
import collections
import numpy as np
from numba.core import types, config
QuicksortImplementation = collections.namedtuple(
'QuicksortImplementation',
(# The compile function itself
'compile',
# All subroutines exercised by test_sort
'partition', 'partition3', 'insertion_sort',
# The top-level function
'run_quicksort',
))
Partition = collections.namedtuple('Partition', ('start', 'stop'))
# Under this size, switch to a simple insertion sort
SMALL_QUICKSORT = 15
MAX_STACK = 100
def make_quicksort_impl(wrap, lt=None, is_argsort=False, is_list=False,
                        is_np_array=False):
intp = types.intp
zero = intp(0)
# Two subroutines to make the core algorithm generic wrt. argsort
# or normal sorting. Note the genericity may make basic sort()
# slightly slower (~5%)
if is_argsort:
if is_list:
@wrap
def make_res(A):
return [x for x in range(len(A))]
else:
@wrap
def make_res(A):
return np.arange(A.size)
@wrap
def GET(A, idx_or_val):
return A[idx_or_val]
else:
@wrap
def make_res(A):
return A
@wrap
def GET(A, idx_or_val):
return idx_or_val
def default_lt(a, b):
"""
Trivial comparison function between two keys.
"""
return a < b
LT = wrap(lt if lt is not None else default_lt)
@wrap
def insertion_sort(A, R, low, high):
"""
Insertion sort A[low:high + 1]. Note the inclusive bounds.
"""
assert low >= 0
if high <= low:
return
for i in range(low + 1, high + 1):
k = R[i]
v = GET(A, k)
# Insert v into A[low:i]
j = i
while j > low and LT(v, GET(A, R[j - 1])):
# Make place for moving A[i] downwards
R[j] = R[j - 1]
j -= 1
R[j] = k
@wrap
def partition(A, R, low, high):
"""
Partition A[low:high + 1] around a chosen pivot. The pivot's index
is returned.
"""
assert low >= 0
assert high > low
mid = (low + high) >> 1
# NOTE: the pattern of swaps below for the pivot choice and the
# partitioning gives good results (i.e. regular O(n log n))
# on sorted, reverse-sorted, and uniform arrays. Subtle changes
# risk breaking this property.
# median of three {low, middle, high}
if LT(GET(A, R[mid]), GET(A, R[low])):
R[low], R[mid] = R[mid], R[low]
if LT(GET(A, R[high]), GET(A, R[mid])):
R[high], R[mid] = R[mid], R[high]
if LT(GET(A, R[mid]), GET(A, R[low])):
R[low], R[mid] = R[mid], R[low]
pivot = GET(A, R[mid])
# Temporarily stash the pivot at the end
R[high], R[mid] = R[mid], R[high]
i = low
j = high - 1
while True:
while i < high and LT(GET(A, R[i]), pivot):
i += 1
while j >= low and LT(pivot, GET(A, R[j])):
j -= 1
if i >= j:
break
R[i], R[j] = R[j], R[i]
i += 1
j -= 1
# Put the pivot back in its final place (all items before `i`
# are smaller than the pivot, all items at/after `i` are larger)
R[i], R[high] = R[high], R[i]
return i
@wrap
def partition3(A, low, high):
"""
Three-way partition [low, high) around a chosen pivot.
A tuple (lt, gt) is returned such that:
- all elements in [low, lt) are < pivot
- all elements in [lt, gt] are == pivot
- all elements in (gt, high] are > pivot
"""
mid = (low + high) >> 1
# median of three {low, middle, high}
if LT(A[mid], A[low]):
A[low], A[mid] = A[mid], A[low]
if LT(A[high], A[mid]):
A[high], A[mid] = A[mid], A[high]
if LT(A[mid], A[low]):
A[low], A[mid] = A[mid], A[low]
pivot = A[mid]
A[low], A[mid] = A[mid], A[low]
lt = low
gt = high
i = low + 1
while i <= gt:
if LT(A[i], pivot):
A[lt], A[i] = A[i], A[lt]
lt += 1
i += 1
elif LT(pivot, A[i]):
A[gt], A[i] = A[i], A[gt]
gt -= 1
else:
i += 1
return lt, gt
@wrap
def run_quicksort1(A):
R = make_res(A)
if len(A) < 2:
return R
stack = [Partition(zero, zero)] * MAX_STACK
stack[0] = Partition(zero, len(A) - 1)
n = 1
while n > 0:
n -= 1
low, high = stack[n]
# Partition until it becomes more efficient to do an insertion sort
while high - low >= SMALL_QUICKSORT:
assert n < MAX_STACK
i = partition(A, R, low, high)
# Push largest partition on the stack
if high - i > i - low:
# Right is larger
if high > i:
stack[n] = Partition(i + 1, high)
n += 1
high = i - 1
else:
if i > low:
stack[n] = Partition(low, i - 1)
n += 1
low = i + 1
insertion_sort(A, R, low, high)
return R
if is_np_array:
@wrap
def run_quicksort(A):
if A.ndim == 1:
return run_quicksort1(A)
else:
for idx in np.ndindex(A.shape[:-1]):
run_quicksort1(A[idx])
return A
else:
@wrap
def run_quicksort(A):
return run_quicksort1(A)
    # Unused quicksort implementation based on 3-way partitioning; the
    # partitioning scheme turns out to exhibit bad behaviour on sorted
    # arrays.
@wrap
def _run_quicksort(A):
        stack = [Partition(zero, zero)] * MAX_STACK
stack[0] = Partition(zero, len(A) - 1)
n = 1
while n > 0:
n -= 1
low, high = stack[n]
# Partition until it becomes more efficient to do an insertion sort
while high - low >= SMALL_QUICKSORT:
assert n < MAX_STACK
l, r = partition3(A, low, high)
# One trivial (empty) partition => iterate on the other
if r == high:
high = l - 1
elif l == low:
low = r + 1
# Push largest partition on the stack
elif high - r > l - low:
# Right is larger
stack[n] = Partition(r + 1, high)
n += 1
high = l - 1
else:
stack[n] = Partition(low, l - 1)
n += 1
low = r + 1
            insertion_sort(A, A, low, high)  # R is A: sorts values in place
return QuicksortImplementation(wrap,
partition, partition3, insertion_sort,
run_quicksort)
def make_py_quicksort(*args, **kwargs):
return make_quicksort_impl((lambda f: f), *args, **kwargs)
def make_jit_quicksort(*args, **kwargs):
from numba.core.extending import register_jitable
return make_quicksort_impl((lambda f: register_jitable(f)),
*args, **kwargs)
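# Minimal usage sketch (illustrative; `register_jitable` leaves the wrapped
# routines callable from the interpreter as well as from jitted code):
#
#   import numpy as np
#   qs = make_jit_quicksort(is_np_array=True)
#   arr = np.array([5, 2, 9, 1])
#   qs.run_quicksort(arr)              # in-place; arr is now [1, 2, 5, 9]
#
#   argqs = make_jit_quicksort(is_argsort=True)
#   idx = argqs.run_quicksort(np.array([5, 2, 9, 1]))  # idx == [3, 1, 0, 2]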

View File

@@ -0,0 +1,104 @@
import numpy as np
from numba.core.typing.typeof import typeof
from numba.core.typing.asnumbatype import as_numba_type
def pndindex(*args):
""" Provides an n-dimensional parallel iterator that generates index tuples
for each iteration point. Sequentially, pndindex is identical to np.ndindex.
"""
return np.ndindex(*args)
class prange(object):
""" Provides a 1D parallel iterator that generates a sequence of integers.
In non-parallel contexts, prange is identical to range.
"""
def __new__(cls, *args):
return range(*args)
def _gdb_python_call_gen(func_name, *args):
# generates a call to a function containing a compiled in gdb command,
# this is to make `numba.gdb*` work in the interpreter.
import numba
fn = getattr(numba, func_name)
argstr = ','.join(['"%s"' for _ in args]) % args
defn = """def _gdb_func_injection():\n\t%s(%s)\n
""" % (func_name, argstr)
l = {}
exec(defn, {func_name: fn}, l)
return numba.njit(l['_gdb_func_injection'])
def gdb(*args):
"""
Calling this function will invoke gdb and attach it to the current process
at the call site. Arguments are strings in the gdb command language syntax
which will be executed by gdb once initialisation has occurred.
"""
_gdb_python_call_gen('gdb', *args)()
def gdb_breakpoint():
"""
Calling this function will inject a breakpoint at the call site that is
recognised by both `gdb` and `gdb_init`, this is to allow breaking at
multiple points. gdb will stop in the user defined code just after the frame
employed by the breakpoint returns.
"""
_gdb_python_call_gen('gdb_breakpoint')()
def gdb_init(*args):
"""
Calling this function will invoke gdb and attach it to the current process
at the call site, then continue executing the process under gdb's control.
Arguments are strings in the gdb command language syntax which will be
executed by gdb once initialisation has occurred.
"""
_gdb_python_call_gen('gdb_init', *args)()
def literally(obj):
"""Forces Numba to interpret *obj* as an Literal value.
*obj* must be either a literal or an argument of the caller function, where
the argument must be bound to a literal. The literal requirement
propagates up the call stack.
This function is intercepted by the compiler to alter the compilation
behavior to wrap the corresponding function parameters as ``Literal``.
It has **no effect** outside of nopython-mode (interpreter, and objectmode).
The current implementation detects literal arguments in two ways:
1. Scans for uses of ``literally`` via a compiler pass.
2. ``literally`` is overloaded to raise ``numba.errors.ForceLiteralArg``
to signal the dispatcher to treat the corresponding parameter
differently. This mode is to support indirect use (via a function call).
The execution semantic of this function is equivalent to an identity
function.
See :ghfile:`numba/tests/test_literal_dispatch.py` for examples.
"""
return obj
def literal_unroll(container):
    """Forces the compiler to unroll loops over *container*, typing the
    loop body once per item. This permits iteration over heterogeneous
    containers such as mixed-type tuples. Outside of nopython-mode it is
    an identity function.
    """
    return container
__all__ = [
'typeof',
'as_numba_type',
'prange',
'pndindex',
'gdb',
'gdb_breakpoint',
'gdb_init',
'literally',
'literal_unroll',
]
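# Illustrative example (assumes a working Numba install): inside jitted code
# `prange` becomes a parallel loop when parallel=True; in plain Python it
# degrades to `range`, and `pndindex` to `np.ndindex`.
#
#   from numba import njit, prange
#
#   @njit(parallel=True)
#   def sum2d(arr):
#       total = 0.0
#       for i in prange(arr.shape[0]):
#           for j in range(arr.shape[1]):
#               total += arr[i, j]
#       return total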

View File

@@ -0,0 +1,943 @@
"""
Timsort implementation. Mostly adapted from CPython's listobject.c.
For more information, see listsort.txt in CPython's source tree.
"""
import collections
from numba.core import types
TimsortImplementation = collections.namedtuple(
'TimsortImplementation',
(# The compile function itself
'compile',
# All subroutines exercised by test_sort
'count_run', 'binarysort', 'gallop_left', 'gallop_right',
'merge_init', 'merge_append', 'merge_pop',
'merge_compute_minrun', 'merge_lo', 'merge_hi', 'merge_at',
'merge_force_collapse', 'merge_collapse',
# The top-level functions
'run_timsort', 'run_timsort_with_values'
))
# The maximum number of entries in a MergeState's pending-runs stack.
# This is enough to sort arrays of size up to about
# 32 * phi ** MAX_MERGE_PENDING
# where phi ~= 1.618. 85 is ridiculously large enough, good for an array
# with 2**64 elements.
# NOTE this implementation preallocates the pending-runs stack at this
# fixed size, and merge_append() asserts the bound as an invariant.
MAX_MERGE_PENDING = 85
# When we get into galloping mode, we stay there until both runs win less
# often than MIN_GALLOP consecutive times. See listsort.txt for more info.
MIN_GALLOP = 7
# Start size for temp arrays.
MERGESTATE_TEMP_SIZE = 256
# A mergestate is a named tuple with the following members:
# - *min_gallop* is an integer controlling when we get into galloping mode
# - *keys* is a temp list for merging keys
# - *values* is a temp list for merging values, if needed
# - *pending* is a stack of pending runs to be merged
# - *n* is the current stack length of *pending*
MergeState = collections.namedtuple(
'MergeState', ('min_gallop', 'keys', 'values', 'pending', 'n'))
MergeRun = collections.namedtuple('MergeRun', ('start', 'size'))
def make_timsort_impl(wrap, make_temp_area):
make_temp_area = wrap(make_temp_area)
intp = types.intp
zero = intp(0)
@wrap
def has_values(keys, values):
return values is not keys
@wrap
def merge_init(keys):
"""
        Initialize a MergeState for a plain sort (keys only, no separate values).
"""
temp_size = min(len(keys) // 2 + 1, MERGESTATE_TEMP_SIZE)
temp_keys = make_temp_area(keys, temp_size)
temp_values = temp_keys
pending = [MergeRun(zero, zero)] * MAX_MERGE_PENDING
return MergeState(intp(MIN_GALLOP), temp_keys, temp_values, pending, zero)
@wrap
def merge_init_with_values(keys, values):
"""
        Initialize a MergeState for a key/value sort.
"""
temp_size = min(len(keys) // 2 + 1, MERGESTATE_TEMP_SIZE)
temp_keys = make_temp_area(keys, temp_size)
temp_values = make_temp_area(values, temp_size)
pending = [MergeRun(zero, zero)] * MAX_MERGE_PENDING
return MergeState(intp(MIN_GALLOP), temp_keys, temp_values, pending, zero)
@wrap
def merge_append(ms, run):
"""
Append a run on the merge stack.
"""
n = ms.n
assert n < MAX_MERGE_PENDING
ms.pending[n] = run
return MergeState(ms.min_gallop, ms.keys, ms.values, ms.pending, n + 1)
@wrap
def merge_pop(ms):
"""
Pop the top run from the merge stack.
"""
return MergeState(ms.min_gallop, ms.keys, ms.values, ms.pending, ms.n - 1)
@wrap
def merge_getmem(ms, need):
"""
Ensure enough temp memory for 'need' items is available.
"""
alloced = len(ms.keys)
if need <= alloced:
return ms
# Over-allocate
while alloced < need:
alloced = alloced << 1
# Don't realloc! That can cost cycles to copy the old data, but
# we don't care what's in the block.
temp_keys = make_temp_area(ms.keys, alloced)
if has_values(ms.keys, ms.values):
temp_values = make_temp_area(ms.values, alloced)
else:
temp_values = temp_keys
return MergeState(ms.min_gallop, temp_keys, temp_values, ms.pending, ms.n)
@wrap
def merge_adjust_gallop(ms, new_gallop):
"""
Modify the MergeState's min_gallop.
"""
return MergeState(intp(new_gallop), ms.keys, ms.values, ms.pending, ms.n)
@wrap
def LT(a, b):
"""
Trivial comparison function between two keys. This is factored out to
make it clear where comparisons occur.
"""
return a < b
@wrap
def binarysort(keys, values, lo, hi, start):
"""
binarysort is the best method for sorting small arrays: it does
few compares, but can do data movement quadratic in the number of
elements.
[lo, hi) is a contiguous slice of a list, and is sorted via
binary insertion. This sort is stable.
On entry, must have lo <= start <= hi, and that [lo, start) is already
sorted (pass start == lo if you don't know!).
"""
assert lo <= start and start <= hi
_has_values = has_values(keys, values)
if lo == start:
start += 1
while start < hi:
pivot = keys[start]
# Bisect to find where to insert `pivot`
# NOTE: bisection only wins over linear search if the comparison
# function is much more expensive than simply moving data.
l = lo
r = start
# Invariants:
# pivot >= all in [lo, l).
# pivot < all in [r, start).
# The second is vacuously true at the start.
while l < r:
p = l + ((r - l) >> 1)
if LT(pivot, keys[p]):
r = p
else:
l = p+1
# The invariants still hold, so pivot >= all in [lo, l) and
# pivot < all in [l, start), so pivot belongs at l. Note
# that if there are elements equal to pivot, l points to the
# first slot after them -- that's why this sort is stable.
# Slide over to make room (aka memmove()).
for p in range(start, l, -1):
keys[p] = keys[p - 1]
keys[l] = pivot
if _has_values:
pivot_val = values[start]
for p in range(start, l, -1):
values[p] = values[p - 1]
values[l] = pivot_val
start += 1
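    # Worked example (illustrative): with keys = [1, 3, 2, 4] and [0, 2)
    # already sorted, binarysort(keys, keys, 0, 4, 2) first inserts the 2
    # between 1 and 3, then leaves the 4 in place, giving [1, 2, 3, 4];
    # passing keys as values means "no separate values array".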
@wrap
def count_run(keys, lo, hi):
"""
Return the length of the run beginning at lo, in the slice [lo, hi).
        lo < hi is required on entry. "A run" is the longest ascending sequence, with
        keys[lo] <= keys[lo + 1] <= keys[lo + 2] <= ...
        or the longest descending sequence, with
        keys[lo] > keys[lo + 1] > keys[lo + 2] > ...
        A tuple (length, descending) is returned, where boolean *descending*
        is set to False in the former case, or to True in the latter.
For its intended use in a stable mergesort, the strictness of the defn of
"descending" is needed so that the caller can safely reverse a descending
sequence without violating stability (strict > ensures there are no equal
elements to get out of order).
"""
assert lo < hi
if lo + 1 == hi:
# Trivial 1-long run
return 1, False
if LT(keys[lo + 1], keys[lo]):
# Descending run
for k in range(lo + 2, hi):
if not LT(keys[k], keys[k - 1]):
return k - lo, True
return hi - lo, True
else:
# Ascending run
for k in range(lo + 2, hi):
if LT(keys[k], keys[k - 1]):
return k - lo, False
return hi - lo, False
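    # Worked example (illustrative): for keys = [3, 2, 1, 5],
    # count_run(keys, 0, 4) sees the strictly descending prefix [3, 2, 1]
    # and returns (3, True); for keys = [1, 2, 2, 0] it returns (3, False),
    # since 1 <= 2 <= 2 is the longest ascending run.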
@wrap
def gallop_left(key, a, start, stop, hint):
"""
Locate the proper position of key in a sorted vector; if the vector contains
an element equal to key, return the position immediately to the left of
the leftmost equal element. [gallop_right() does the same except returns
the position to the right of the rightmost equal element (if any).]
"a" is a sorted vector with stop elements, starting at a[start].
stop must be > start.
"hint" is an index at which to begin the search, start <= hint < stop.
The closer hint is to the final result, the faster this runs.
The return value is the int k in start..stop such that
a[k-1] < key <= a[k]
pretending that a[start-1] is minus infinity and a[stop] is plus infinity.
IOW, key belongs at index k; or, IOW, the first k elements of a should
precede key, and the last stop-start-k should follow key.
See listsort.txt for info on the method.
"""
assert stop > start
assert hint >= start and hint < stop
n = stop - start
# First, gallop from the hint to find a "good" subinterval for bisecting
lastofs = 0
ofs = 1
if LT(a[hint], key):
# a[hint] < key => gallop right, until
# a[hint + lastofs] < key <= a[hint + ofs]
maxofs = stop - hint
while ofs < maxofs:
if LT(a[hint + ofs], key):
lastofs = ofs
ofs = (ofs << 1) + 1
if ofs <= 0:
# Int overflow
ofs = maxofs
else:
# key <= a[hint + ofs]
break
if ofs > maxofs:
ofs = maxofs
# Translate back to offsets relative to a[0]
lastofs += hint
ofs += hint
else:
# key <= a[hint] => gallop left, until
# a[hint - ofs] < key <= a[hint - lastofs]
maxofs = hint - start + 1
while ofs < maxofs:
if LT(a[hint - ofs], key):
break
else:
# key <= a[hint - ofs]
lastofs = ofs
ofs = (ofs << 1) + 1
if ofs <= 0:
# Int overflow
ofs = maxofs
if ofs > maxofs:
ofs = maxofs
# Translate back to positive offsets relative to a[0]
lastofs, ofs = hint - ofs, hint - lastofs
assert start - 1 <= lastofs and lastofs < ofs and ofs <= stop
# Now a[lastofs] < key <= a[ofs], so key belongs somewhere to the
# right of lastofs but no farther right than ofs. Do a binary
# search, with invariant a[lastofs-1] < key <= a[ofs].
lastofs += 1
while lastofs < ofs:
m = lastofs + ((ofs - lastofs) >> 1)
if LT(a[m], key):
# a[m] < key
lastofs = m + 1
else:
# key <= a[m]
ofs = m
# Now lastofs == ofs, so a[ofs - 1] < key <= a[ofs]
return ofs
@wrap
def gallop_right(key, a, start, stop, hint):
"""
Exactly like gallop_left(), except that if key already exists in a[start:stop],
finds the position immediately to the right of the rightmost equal value.
The return value is the int k in start..stop such that
a[k-1] <= key < a[k]
        The code duplication is massive, but the two routines differ enough,
        given that we stick to "<" comparisons, that it would be much harder
        to follow if written as one routine with yet another "left or right?"
        flag.
"""
assert stop > start
assert hint >= start and hint < stop
n = stop - start
# First, gallop from the hint to find a "good" subinterval for bisecting
lastofs = 0
ofs = 1
if LT(key, a[hint]):
# key < a[hint] => gallop left, until
# a[hint - ofs] <= key < a[hint - lastofs]
maxofs = hint - start + 1
while ofs < maxofs:
if LT(key, a[hint - ofs]):
lastofs = ofs
ofs = (ofs << 1) + 1
if ofs <= 0:
# Int overflow
ofs = maxofs
else:
# a[hint - ofs] <= key
break
if ofs > maxofs:
ofs = maxofs
# Translate back to positive offsets relative to a[0]
lastofs, ofs = hint - ofs, hint - lastofs
else:
# a[hint] <= key -- gallop right, until
# a[hint + lastofs] <= key < a[hint + ofs]
maxofs = stop - hint
while ofs < maxofs:
if LT(key, a[hint + ofs]):
break
else:
# a[hint + ofs] <= key
lastofs = ofs
ofs = (ofs << 1) + 1
if ofs <= 0:
# Int overflow
ofs = maxofs
if ofs > maxofs:
ofs = maxofs
# Translate back to offsets relative to a[0]
lastofs += hint
ofs += hint
assert start - 1 <= lastofs and lastofs < ofs and ofs <= stop
# Now a[lastofs] <= key < a[ofs], so key belongs somewhere to the
# right of lastofs but no farther right than ofs. Do a binary
# search, with invariant a[lastofs-1] <= key < a[ofs].
lastofs += 1
while lastofs < ofs:
m = lastofs + ((ofs - lastofs) >> 1)
if LT(key, a[m]):
# key < a[m]
ofs = m
else:
# a[m] <= key
lastofs = m + 1
# Now lastofs == ofs, so a[ofs - 1] <= key < a[ofs]
return ofs
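    # Worked example (illustrative): with a = [10, 20, 20, 30],
    # gallop_left(20, a, 0, 4, 0) returns 1, the index of the leftmost 20
    # (a[k-1] < key <= a[k]), while gallop_right(20, a, 0, 4, 0) returns 3,
    # one past the rightmost 20 (a[k-1] <= key < a[k]).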
@wrap
def merge_compute_minrun(n):
"""
Compute a good value for the minimum run length; natural runs shorter
than this are boosted artificially via binary insertion.
If n < 64, return n (it's too small to bother with fancy stuff).
Else if n is an exact power of 2, return 32.
Else return an int k, 32 <= k <= 64, such that n/k is close to, but
strictly less than, an exact power of 2.
See listsort.txt for more info.
"""
r = 0
assert n >= 0
while n >= 64:
r |= n & 1
n >>= 1
return n + r
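    # Worked examples (illustrative): merge_compute_minrun(63) == 63
    # (n < 64 is returned unchanged); merge_compute_minrun(64) == 32
    # (an exact power of two); merge_compute_minrun(65) == 33, since the
    # dropped low bit is carried in r, making 65/33 just under a power of
    # two.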
@wrap
def sortslice_copy(dest_keys, dest_values, dest_start,
src_keys, src_values, src_start,
nitems):
"""
Upwards memcpy().
"""
assert src_start >= 0
assert dest_start >= 0
for i in range(nitems):
dest_keys[dest_start + i] = src_keys[src_start + i]
if has_values(src_keys, src_values):
for i in range(nitems):
dest_values[dest_start + i] = src_values[src_start + i]
@wrap
def sortslice_copy_down(dest_keys, dest_values, dest_start,
src_keys, src_values, src_start,
nitems):
"""
Downwards memcpy().
"""
assert src_start >= 0
assert dest_start >= 0
for i in range(nitems):
dest_keys[dest_start - i] = src_keys[src_start - i]
if has_values(src_keys, src_values):
for i in range(nitems):
dest_values[dest_start - i] = src_values[src_start - i]
# Disable this for debug or perf comparison
DO_GALLOP = 1
@wrap
def merge_lo(ms, keys, values, ssa, na, ssb, nb):
"""
Merge the na elements starting at ssa with the nb elements starting at
ssb = ssa + na in a stable way, in-place. na and nb must be > 0,
and should have na <= nb. See listsort.txt for more info.
An updated MergeState is returned (with possibly a different min_gallop
or larger temp arrays).
NOTE: compared to CPython's timsort, the requirement that
"Must also have that keys[ssa + na - 1] belongs at the end of the merge"
is removed. This makes the code a bit simpler and easier to reason about.
"""
assert na > 0 and nb > 0 and na <= nb
assert ssb == ssa + na
# First copy [ssa, ssa + na) into the temp space
ms = merge_getmem(ms, na)
sortslice_copy(ms.keys, ms.values, 0,
keys, values, ssa,
na)
a_keys = ms.keys
a_values = ms.values
b_keys = keys
b_values = values
dest = ssa
ssa = 0
_has_values = has_values(a_keys, a_values)
min_gallop = ms.min_gallop
# Now start merging into the space left from [ssa, ...)
while nb > 0 and na > 0:
# Do the straightforward thing until (if ever) one run
# appears to win consistently.
acount = 0
bcount = 0
while True:
if LT(b_keys[ssb], a_keys[ssa]):
keys[dest] = b_keys[ssb]
if _has_values:
values[dest] = b_values[ssb]
dest += 1
ssb += 1
nb -= 1
if nb == 0:
break
# It's a B run
bcount += 1
acount = 0
if bcount >= min_gallop:
break
else:
keys[dest] = a_keys[ssa]
if _has_values:
values[dest] = a_values[ssa]
dest += 1
ssa += 1
na -= 1
if na == 0:
break
                # It's an A run
acount += 1
bcount = 0
if acount >= min_gallop:
break
# One run is winning so consistently that galloping may
# be a huge win. So try that, and continue galloping until
# (if ever) neither run appears to be winning consistently
# anymore.
if DO_GALLOP and na > 0 and nb > 0:
min_gallop += 1
while acount >= MIN_GALLOP or bcount >= MIN_GALLOP:
# As long as we gallop without leaving this loop, make
# the heuristic more likely
min_gallop -= min_gallop > 1
# Gallop in A to find where keys[ssb] should end up
k = gallop_right(b_keys[ssb], a_keys, ssa, ssa + na, ssa)
# k is an index, make it a size
k -= ssa
acount = k
if k > 0:
# Copy everything from A before k
sortslice_copy(keys, values, dest,
a_keys, a_values, ssa,
k)
dest += k
ssa += k
na -= k
if na == 0:
# Finished merging
break
# Copy keys[ssb]
keys[dest] = b_keys[ssb]
if _has_values:
values[dest] = b_values[ssb]
dest += 1
ssb += 1
nb -= 1
if nb == 0:
# Finished merging
break
# Gallop in B to find where keys[ssa] should end up
k = gallop_left(a_keys[ssa], b_keys, ssb, ssb + nb, ssb)
# k is an index, make it a size
k -= ssb
bcount = k
if k > 0:
# Copy everything from B before k
# NOTE: source and dest are the same buffer, but the
# destination index is below the source index
sortslice_copy(keys, values, dest,
b_keys, b_values, ssb,
k)
dest += k
ssb += k
nb -= k
if nb == 0:
# Finished merging
break
# Copy keys[ssa]
keys[dest] = a_keys[ssa]
if _has_values:
values[dest] = a_values[ssa]
dest += 1
ssa += 1
na -= 1
if na == 0:
# Finished merging
break
# Penalize it for leaving galloping mode
min_gallop += 1
# Merge finished, now handle the remaining areas
if nb == 0:
# Only A remaining to copy at the end of the destination area
sortslice_copy(keys, values, dest,
a_keys, a_values, ssa,
na)
else:
assert na == 0
assert dest == ssb
# B's tail is already at the right place, do nothing
return merge_adjust_gallop(ms, min_gallop)
@wrap
def merge_hi(ms, keys, values, ssa, na, ssb, nb):
"""
Merge the na elements starting at ssa with the nb elements starting at
ssb = ssa + na in a stable way, in-place. na and nb must be > 0,
and should have na >= nb. See listsort.txt for more info.
An updated MergeState is returned (with possibly a different min_gallop
or larger temp arrays).
NOTE: compared to CPython's timsort, the requirement that
"Must also have that keys[ssa + na - 1] belongs at the end of the merge"
is removed. This makes the code a bit simpler and easier to reason about.
"""
assert na > 0 and nb > 0 and na >= nb
assert ssb == ssa + na
# First copy [ssb, ssb + nb) into the temp space
ms = merge_getmem(ms, nb)
sortslice_copy(ms.keys, ms.values, 0,
keys, values, ssb,
nb)
a_keys = keys
a_values = values
b_keys = ms.keys
b_values = ms.values
# Now start merging *in descending order* into the space left
# from [..., ssb + nb).
dest = ssb + nb - 1
ssb = nb - 1
ssa = ssa + na - 1
_has_values = has_values(b_keys, b_values)
min_gallop = ms.min_gallop
while nb > 0 and na > 0:
# Do the straightforward thing until (if ever) one run
# appears to win consistently.
acount = 0
bcount = 0
while True:
if LT(b_keys[ssb], a_keys[ssa]):
# We merge in descending order, so copy the larger value
keys[dest] = a_keys[ssa]
if _has_values:
values[dest] = a_values[ssa]
dest -= 1
ssa -= 1
na -= 1
if na == 0:
break
                # It's an A run
acount += 1
bcount = 0
if acount >= min_gallop:
break
else:
keys[dest] = b_keys[ssb]
if _has_values:
values[dest] = b_values[ssb]
dest -= 1
ssb -= 1
nb -= 1
if nb == 0:
break
# It's a B run
bcount += 1
acount = 0
if bcount >= min_gallop:
break
# One run is winning so consistently that galloping may
# be a huge win. So try that, and continue galloping until
# (if ever) neither run appears to be winning consistently
# anymore.
if DO_GALLOP and na > 0 and nb > 0:
min_gallop += 1
while acount >= MIN_GALLOP or bcount >= MIN_GALLOP:
# As long as we gallop without leaving this loop, make
# the heuristic more likely
min_gallop -= min_gallop > 1
# Gallop in A to find where keys[ssb] should end up
k = gallop_right(b_keys[ssb], a_keys, ssa - na + 1, ssa + 1, ssa)
# k is an index, make it a size from the end
k = ssa + 1 - k
acount = k
if k > 0:
# Copy everything from A after k.
# Destination and source are the same buffer, and destination
# index is greater, so copy from the end to the start.
sortslice_copy_down(keys, values, dest,
a_keys, a_values, ssa,
k)
dest -= k
ssa -= k
na -= k
if na == 0:
# Finished merging
break
# Copy keys[ssb]
keys[dest] = b_keys[ssb]
if _has_values:
values[dest] = b_values[ssb]
dest -= 1
ssb -= 1
nb -= 1
if nb == 0:
# Finished merging
break
# Gallop in B to find where keys[ssa] should end up
k = gallop_left(a_keys[ssa], b_keys, ssb - nb + 1, ssb + 1, ssb)
# k is an index, make it a size from the end
k = ssb + 1 - k
bcount = k
if k > 0:
                # Copy everything from B after k
sortslice_copy_down(keys, values, dest,
b_keys, b_values, ssb,
k)
dest -= k
ssb -= k
nb -= k
if nb == 0:
# Finished merging
break
# Copy keys[ssa]
keys[dest] = a_keys[ssa]
if _has_values:
values[dest] = a_values[ssa]
dest -= 1
ssa -= 1
na -= 1
if na == 0:
# Finished merging
break
# Penalize it for leaving galloping mode
min_gallop += 1
# Merge finished, now handle the remaining areas
if na == 0:
# Only B remaining to copy at the front of the destination area
sortslice_copy(keys, values, dest - nb + 1,
b_keys, b_values, ssb - nb + 1,
nb)
else:
assert nb == 0
assert dest == ssa
# A's front is already at the right place, do nothing
return merge_adjust_gallop(ms, min_gallop)
@wrap
def merge_at(ms, keys, values, i):
"""
Merge the two runs at stack indices i and i+1.
An updated MergeState is returned.
"""
n = ms.n
assert n >= 2
assert i >= 0
assert i == n - 2 or i == n - 3
ssa, na = ms.pending[i]
ssb, nb = ms.pending[i + 1]
assert na > 0 and nb > 0
assert ssa + na == ssb
# Record the length of the combined runs; if i is the 3rd-last
# run now, also slide over the last run (which isn't involved
# in this merge). The current run i+1 goes away in any case.
ms.pending[i] = MergeRun(ssa, na + nb)
if i == n - 3:
ms.pending[i + 1] = ms.pending[i + 2]
ms = merge_pop(ms)
# Where does b start in a? Elements in a before that can be
# ignored (already in place).
k = gallop_right(keys[ssb], keys, ssa, ssa + na, ssa)
# [k, ssa + na) remains to be merged
na -= k - ssa
ssa = k
if na == 0:
return ms
# Where does a end in b? Elements in b after that can be
# ignored (already in place).
k = gallop_left(keys[ssa + na - 1], keys, ssb, ssb + nb, ssb + nb - 1)
# [ssb, k) remains to be merged
nb = k - ssb
# Merge what remains of the runs, using a temp array with
# min(na, nb) elements.
if na <= nb:
return merge_lo(ms, keys, values, ssa, na, ssb, nb)
else:
return merge_hi(ms, keys, values, ssa, na, ssb, nb)
@wrap
def merge_collapse(ms, keys, values):
"""
Examine the stack of runs waiting to be merged, merging adjacent runs
until the stack invariants are re-established:
1. len[-3] > len[-2] + len[-1]
2. len[-2] > len[-1]
An updated MergeState is returned.
See listsort.txt for more info.
"""
while ms.n > 1:
pending = ms.pending
n = ms.n - 2
if ((n > 0 and pending[n-1].size <= pending[n].size + pending[n+1].size) or
(n > 1 and pending[n-2].size <= pending[n-1].size + pending[n].size)):
if pending[n - 1].size < pending[n + 1].size:
# Merge smaller one first
n -= 1
ms = merge_at(ms, keys, values, n)
            elif pending[n].size <= pending[n + 1].size:
ms = merge_at(ms, keys, values, n)
else:
break
return ms
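    # Worked example (illustrative): with pending run sizes [25, 20, 10],
    # invariant 1 fails (25 <= 20 + 10), so the top two runs (20 and 10)
    # are merged, leaving [25, 30]; invariant 2 then fails (25 <= 30), so
    # those merge as well, leaving a single run of size 55.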
@wrap
def merge_force_collapse(ms, keys, values):
"""
Regardless of invariants, merge all runs on the stack until only one
remains. This is used at the end of the mergesort.
An updated MergeState is returned.
"""
while ms.n > 1:
pending = ms.pending
n = ms.n - 2
if n > 0:
if pending[n - 1].size < pending[n + 1].size:
# Merge the smaller one first
n -= 1
ms = merge_at(ms, keys, values, n)
return ms
@wrap
def reverse_slice(keys, values, start, stop):
"""
Reverse a slice, in-place.
"""
i = start
j = stop - 1
while i < j:
keys[i], keys[j] = keys[j], keys[i]
i += 1
j -= 1
if has_values(keys, values):
i = start
j = stop - 1
while i < j:
values[i], values[j] = values[j], values[i]
i += 1
j -= 1
@wrap
def run_timsort_with_mergestate(ms, keys, values):
"""
Run timsort with the mergestate.
"""
nremaining = len(keys)
if nremaining < 2:
return
# March over the array once, left to right, finding natural runs,
# and extending short natural runs to minrun elements.
minrun = merge_compute_minrun(nremaining)
lo = zero
while nremaining > 0:
n, desc = count_run(keys, lo, lo + nremaining)
if desc:
# Descending run => reverse
reverse_slice(keys, values, lo, lo + n)
# If short, extend to min(minrun, nremaining)
if n < minrun:
force = min(minrun, nremaining)
binarysort(keys, values, lo, lo + force, lo + n)
n = force
# Push run onto stack, and maybe merge.
ms = merge_append(ms, MergeRun(lo, n))
ms = merge_collapse(ms, keys, values)
# Advance to find next run.
lo += n
nremaining -= n
# All initial runs have been discovered, now finish merging.
ms = merge_force_collapse(ms, keys, values)
assert ms.n == 1
assert ms.pending[0] == (0, len(keys))
@wrap
def run_timsort(keys):
"""
Run timsort over the given keys.
"""
values = keys
run_timsort_with_mergestate(merge_init(keys), keys, values)
@wrap
def run_timsort_with_values(keys, values):
"""
Run timsort over the given keys and values.
"""
run_timsort_with_mergestate(merge_init_with_values(keys, values),
keys, values)
return TimsortImplementation(
wrap,
count_run, binarysort, gallop_left, gallop_right,
merge_init, merge_append, merge_pop,
merge_compute_minrun, merge_lo, merge_hi, merge_at,
merge_force_collapse, merge_collapse,
run_timsort, run_timsort_with_values)
def make_py_timsort(*args):
return make_timsort_impl((lambda f: f), *args)
def make_jit_timsort(*args):
from numba import jit
return make_timsort_impl((lambda f: jit(nopython=True)(f)),
*args)
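# A minimal sketch of driving the pure-Python variant (hedged): the
# temp-area allocator ``make_temp`` below is a hypothetical stand-in for
# whatever the caller supplies; it only needs to return scratch storage
# of the requested length with the same element type as the input.
#
# def make_temp(keys, n):
#     return [keys[0]] * n if len(keys) else [None] * n
#
# timsort = make_py_timsort(make_temp)
# data = [3, 1, 4, 1, 5, 9, 2, 6]
# timsort.run_timsort(data)      # sorts in place
# assert data == [1, 1, 2, 3, 4, 5, 6, 9]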