@@ -0,0 +1,34 @@
""" Numba's POWER ON SELF TEST script. Used by CI to check:

0. That Numba imports ok!

1. That Numba can find an appropriate number of its own tests to run.

2. That Numba can manage to correctly compile and execute at least one thing.

"""

from numba.tests import test_runtests
from numba import njit


def _check_runtests():
    test_inst = test_runtests.TestCase()
    test_inst.test_default()  # will raise an exception if there is a problem


def _check_cpu_compilation():
    @njit
    def foo(x):
        return x + 1

    result = foo(1)

    if result != 2:
        msg = ("Unexpected result from trial compilation. "
               f"Expected: 2, Got: {result}.")
        raise AssertionError(msg)


def check():
    _check_runtests()
    _check_cpu_compilation()


if __name__ == "__main__":
    check()
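
As context for reviewers, a minimal sketch of how a CI job might invoke this self test; the module path "numba.misc.post" is an assumption for illustration only, the real invocation depends on where the script lands in the package:

# Hypothetical CI invocation sketch; "numba.misc.post" is an assumed path.
import subprocess
import sys

ret = subprocess.run([sys.executable, "-m", "numba.misc.post"])
if ret.returncode != 0:
    raise SystemExit("Numba power-on self test failed")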
[22 binary files in this commit are not shown]
@@ -0,0 +1,551 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2005-2010 ActiveState Software Inc.
# Copyright (c) 2013 Eddy Petrișor

"""Utilities for determining application-specific dirs.

See <http://github.com/ActiveState/appdirs> for details and usage.
"""
# Dev Notes:
# - MSDN on where to store app data files:
#   http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
# - XDG spec for Un*x: http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html

__version_info__ = (1, 4, 1)
__version__ = '.'.join(map(str, __version_info__))


import sys
import os

unicode = str

if sys.platform.startswith('java'):
    import platform
    os_name = platform.java_ver()[3][0]
    if os_name.startswith('Windows'):  # "Windows XP", "Windows 7", etc.
        system = 'win32'
    elif os_name.startswith('Mac'):  # "Mac OS X", etc.
        system = 'darwin'
    else:  # "Linux", "SunOS", "FreeBSD", etc.
        # Setting this to "linux2" is not ideal, but only Windows or Mac
        # are actually checked for and the rest of the module expects
        # *sys.platform* style strings.
        system = 'linux2'
else:
    system = sys.platform


def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
    r"""Return full path to the user-specific data dir for this application.

        "appname" is the name of the application.
            If None, just the system directory is returned.
        "appauthor" (only used on Windows) is the name of the
            appauthor or distributing body for this application. Typically
            it is the owning company name. This falls back to appname. You may
            pass False to disable it.
        "version" is an optional version path element to append to the
            path. You might want to use this if you want multiple versions
            of your app to be able to run independently. If used, this
            would typically be "<major>.<minor>".
            Only applied when appname is present.
        "roaming" (boolean, default False) can be set True to use the Windows
            roaming appdata directory. That means that for users on a Windows
            network setup for roaming profiles, this user data will be
            sync'd on login. See
            <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
            for a discussion of issues.

    Typical user data directories are:
        Mac OS X:               ~/Library/Application Support/<AppName>
        Unix:                   ~/.local/share/<AppName>    # or in $XDG_DATA_HOME, if defined
        Win XP (not roaming):   C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>
        Win XP (roaming):       C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
        Win 7 (not roaming):    C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
        Win 7 (roaming):        C:\Users\<username>\AppData\Roaming\<AppAuthor>\<AppName>

    For Unix, we follow the XDG spec and support $XDG_DATA_HOME.
    That means, by default "~/.local/share/<AppName>".
    """
    if system == "win32":
        if appauthor is None:
            appauthor = appname
        const = roaming and "CSIDL_APPDATA" or "CSIDL_LOCAL_APPDATA"
        path = os.path.normpath(_get_win_folder(const))
        if appname:
            if appauthor is not False:
                path = os.path.join(path, appauthor, appname)
            else:
                path = os.path.join(path, appname)
    elif system == 'darwin':
        path = os.path.expanduser('~/Library/Application Support/')
        if appname:
            path = os.path.join(path, appname)
    else:
        path = os.getenv('XDG_DATA_HOME', os.path.expanduser("~/.local/share"))
        if appname:
            path = os.path.join(path, appname)
    if appname and version:
        path = os.path.join(path, version)
    return path

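A brief usage sketch; output paths vary by platform and environment, the values shown are illustrative:

# Illustration: results vary by platform/environment.
path = user_data_dir("SuperApp", "Acme", version="1.0")
# Linux:   ~/.local/share/SuperApp/1.0           (or under $XDG_DATA_HOME)
# macOS:   ~/Library/Application Support/SuperApp/1.0
# Windows: C:\Users\<user>\AppData\Local\Acme\SuperApp\1.0   (not roaming)
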
def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
    r"""Return full path to the user-shared data dir for this application.

        "appname" is the name of the application.
            If None, just the system directory is returned.
        "appauthor" (only used on Windows) is the name of the
            appauthor or distributing body for this application. Typically
            it is the owning company name. This falls back to appname. You may
            pass False to disable it.
        "version" is an optional version path element to append to the
            path. You might want to use this if you want multiple versions
            of your app to be able to run independently. If used, this
            would typically be "<major>.<minor>".
            Only applied when appname is present.
        "multipath" is an optional parameter only applicable to *nix
            which indicates that the entire list of data dirs should be
            returned. By default, the first item from XDG_DATA_DIRS is
            returned, or '/usr/local/share/<AppName>',
            if XDG_DATA_DIRS is not set.

    Typical site data directories are:
        Mac OS X:   /Library/Application Support/<AppName>
        Unix:       /usr/local/share/<AppName> or /usr/share/<AppName>
        Win XP:     C:\Documents and Settings\All Users\Application Data\<AppAuthor>\<AppName>
        Vista:      (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
        Win 7:      C:\ProgramData\<AppAuthor>\<AppName>   # Hidden, but writeable on Win 7.

    For Unix, this is using the $XDG_DATA_DIRS[0] default.

    WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
    """
    if system == "win32":
        if appauthor is None:
            appauthor = appname
        path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA"))
        if appname:
            if appauthor is not False:
                path = os.path.join(path, appauthor, appname)
            else:
                path = os.path.join(path, appname)
    elif system == 'darwin':
        path = os.path.expanduser('/Library/Application Support')
        if appname:
            path = os.path.join(path, appname)
    else:
        # XDG default for $XDG_DATA_DIRS
        # only first, if multipath is False
        path = os.getenv('XDG_DATA_DIRS',
                         os.pathsep.join(['/usr/local/share', '/usr/share']))
        pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)]
        if appname:
            if version:
                appname = os.path.join(appname, version)
            pathlist = [os.sep.join([x, appname]) for x in pathlist]

        if multipath:
            path = os.pathsep.join(pathlist)
        else:
            path = pathlist[0]
        return path

    if appname and version:
        path = os.path.join(path, version)
    return path


def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
    r"""Return full path to the user-specific config dir for this application.

        "appname" is the name of the application.
            If None, just the system directory is returned.
        "appauthor" (only used on Windows) is the name of the
            appauthor or distributing body for this application. Typically
            it is the owning company name. This falls back to appname. You may
            pass False to disable it.
        "version" is an optional version path element to append to the
            path. You might want to use this if you want multiple versions
            of your app to be able to run independently. If used, this
            would typically be "<major>.<minor>".
            Only applied when appname is present.
        "roaming" (boolean, default False) can be set True to use the Windows
            roaming appdata directory. That means that for users on a Windows
            network setup for roaming profiles, this user data will be
            sync'd on login. See
            <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
            for a discussion of issues.

    Typical user config directories are:
        Mac OS X:   same as user_data_dir
        Unix:       ~/.config/<AppName>    # or in $XDG_CONFIG_HOME, if defined
        Win *:      same as user_data_dir

    For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME.
    That means, by default "~/.config/<AppName>".
    """
    if system in ["win32", "darwin"]:
        path = user_data_dir(appname, appauthor, None, roaming)
    else:
        path = os.getenv('XDG_CONFIG_HOME', os.path.expanduser("~/.config"))
        if appname:
            path = os.path.join(path, appname)
    if appname and version:
        path = os.path.join(path, version)
    return path


def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
    r"""Return full path to the user-shared config dir for this application.

        "appname" is the name of the application.
            If None, just the system directory is returned.
        "appauthor" (only used on Windows) is the name of the
            appauthor or distributing body for this application. Typically
            it is the owning company name. This falls back to appname. You may
            pass False to disable it.
        "version" is an optional version path element to append to the
            path. You might want to use this if you want multiple versions
            of your app to be able to run independently. If used, this
            would typically be "<major>.<minor>".
            Only applied when appname is present.
        "multipath" is an optional parameter only applicable to *nix
            which indicates that the entire list of config dirs should be
            returned. By default, the first item from XDG_CONFIG_DIRS is
            returned, or '/etc/xdg/<AppName>', if XDG_CONFIG_DIRS is not set.

    Typical site config directories are:
        Mac OS X:   same as site_data_dir
        Unix:       /etc/xdg/<AppName> or $XDG_CONFIG_DIRS[i]/<AppName> for each value in
                    $XDG_CONFIG_DIRS
        Win *:      same as site_data_dir
        Vista:      (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)

    For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False.

    WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
    """
    if system in ["win32", "darwin"]:
        path = site_data_dir(appname, appauthor)
        if appname and version:
            path = os.path.join(path, version)
    else:
        # XDG default for $XDG_CONFIG_DIRS
        # only first, if multipath is False
        path = os.getenv('XDG_CONFIG_DIRS', '/etc/xdg')
        pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)]
        if appname:
            if version:
                appname = os.path.join(appname, version)
            pathlist = [os.sep.join([x, appname]) for x in pathlist]

        if multipath:
            path = os.pathsep.join(pathlist)
        else:
            path = pathlist[0]
    return path


def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
    r"""Return full path to the user-specific cache dir for this application.

        "appname" is the name of the application.
            If None, just the system directory is returned.
        "appauthor" (only used on Windows) is the name of the
            appauthor or distributing body for this application. Typically
            it is the owning company name. This falls back to appname. You may
            pass False to disable it.
        "version" is an optional version path element to append to the
            path. You might want to use this if you want multiple versions
            of your app to be able to run independently. If used, this
            would typically be "<major>.<minor>".
            Only applied when appname is present.
        "opinion" (boolean) can be False to disable the appending of
            "Cache" to the base app data dir for Windows. See
            discussion below.

    Typical user cache directories are:
        Mac OS X:   ~/Library/Caches/<AppName>
        Unix:       ~/.cache/<AppName>    (XDG default)
        Win XP:     C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Cache
        Vista:      C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache

    On Windows the only suggestion in the MSDN docs is that local settings go in
    the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming
    app data dir (the default returned by `user_data_dir` above). Apps typically
    put cache data somewhere *under* the given dir here. Some examples:
        ...\Mozilla\Firefox\Profiles\<ProfileName>\Cache
        ...\Acme\SuperApp\Cache\1.0
    OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value.
    This can be disabled with the `opinion=False` option.
    """
    if system == "win32":
        if appauthor is None:
            appauthor = appname
        path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
        if appname:
            if appauthor is not False:
                path = os.path.join(path, appauthor, appname)
            else:
                path = os.path.join(path, appname)
            if opinion:
                path = os.path.join(path, "Cache")
    elif system == 'darwin':
        path = os.path.expanduser('~/Library/Caches')
        if appname:
            path = os.path.join(path, appname)
    else:
        path = os.getenv('XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
        if appname:
            path = os.path.join(path, appname)
    if appname and version:
        path = os.path.join(path, version)
    return path


def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
    r"""Return full path to the user-specific log dir for this application.

        "appname" is the name of the application.
            If None, just the system directory is returned.
        "appauthor" (only used on Windows) is the name of the
            appauthor or distributing body for this application. Typically
            it is the owning company name. This falls back to appname. You may
            pass False to disable it.
        "version" is an optional version path element to append to the
            path. You might want to use this if you want multiple versions
            of your app to be able to run independently. If used, this
            would typically be "<major>.<minor>".
            Only applied when appname is present.
        "opinion" (boolean) can be False to disable the appending of
            "Logs" to the base app data dir for Windows, and "log" to the
            base cache dir for Unix. See discussion below.

    Typical user log directories are:
        Mac OS X:   ~/Library/Logs/<AppName>
        Unix:       ~/.cache/<AppName>/log    # or under $XDG_CACHE_HOME if defined
        Win XP:     C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Logs
        Vista:      C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Logs

    On Windows the only suggestion in the MSDN docs is that local settings
    go in the `CSIDL_LOCAL_APPDATA` directory. (Note: I'm interested in
    examples of what some windows apps use for a logs dir.)

    OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA`
    value for Windows and appends "log" to the user cache dir for Unix.
    This can be disabled with the `opinion=False` option.
    """
    if system == "darwin":
        path = os.path.join(
            os.path.expanduser('~/Library/Logs'),
            appname)
    elif system == "win32":
        path = user_data_dir(appname, appauthor, version)
        version = False
        if opinion:
            path = os.path.join(path, "Logs")
    else:
        path = user_cache_dir(appname, appauthor, version)
        version = False
        if opinion:
            path = os.path.join(path, "log")
    if appname and version:
        path = os.path.join(path, version)
    return path


class AppDirs(object):
    """Convenience wrapper for getting application dirs."""
    def __init__(self, appname, appauthor=None, version=None, roaming=False,
                 multipath=False):
        self.appname = appname
        self.appauthor = appauthor
        self.version = version
        self.roaming = roaming
        self.multipath = multipath

    @property
    def user_data_dir(self):
        return user_data_dir(self.appname, self.appauthor,
                             version=self.version, roaming=self.roaming)

    @property
    def site_data_dir(self):
        return site_data_dir(self.appname, self.appauthor,
                             version=self.version, multipath=self.multipath)

    @property
    def user_config_dir(self):
        return user_config_dir(self.appname, self.appauthor,
                               version=self.version, roaming=self.roaming)

    @property
    def site_config_dir(self):
        return site_config_dir(self.appname, self.appauthor,
                               version=self.version, multipath=self.multipath)

    @property
    def user_cache_dir(self):
        return user_cache_dir(self.appname, self.appauthor,
                              version=self.version)

    @property
    def user_log_dir(self):
        return user_log_dir(self.appname, self.appauthor,
                            version=self.version)

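In practice the `AppDirs` class is the usual entry point; a brief sketch (paths shown are for Linux and illustrative only):

# Sketch: one AppDirs object yields all the per-platform locations.
dirs = AppDirs("SuperApp", "Acme", version="1.0")
print(dirs.user_data_dir)    # e.g. ~/.local/share/SuperApp/1.0
print(dirs.user_cache_dir)   # e.g. ~/.cache/SuperApp/1.0
print(dirs.user_log_dir)     # e.g. ~/.cache/SuperApp/1.0/log
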
#---- internal support stuff

def _get_win_folder_from_registry(csidl_name):
    """This is a fallback technique at best. I'm not sure if using the
    registry for this guarantees us the correct answer for all CSIDL_*
    names.
    """
    import winreg as _winreg

    shell_folder_name = {
        "CSIDL_APPDATA": "AppData",
        "CSIDL_COMMON_APPDATA": "Common AppData",
        "CSIDL_LOCAL_APPDATA": "Local AppData",
    }[csidl_name]

    key = _winreg.OpenKey(
        _winreg.HKEY_CURRENT_USER,
        r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
    )
    dir, type = _winreg.QueryValueEx(key, shell_folder_name)
    return dir


def _get_win_folder_with_pywin32(csidl_name):
    from win32com.shell import shellcon, shell
    dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0)
    # Try to make this a unicode path because SHGetFolderPath does
    # not return unicode strings when there is unicode data in the
    # path.
    try:
        dir = unicode(dir)

        # Downgrade to short path name if have highbit chars. See
        # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
        has_high_char = False
        for c in dir:
            if ord(c) > 255:
                has_high_char = True
                break
        if has_high_char:
            try:
                import win32api
                dir = win32api.GetShortPathName(dir)
            except ImportError:
                pass
    except UnicodeError:
        pass
    return dir


def _get_win_folder_with_ctypes(csidl_name):
    import ctypes

    csidl_const = {
        "CSIDL_APPDATA": 26,
        "CSIDL_COMMON_APPDATA": 35,
        "CSIDL_LOCAL_APPDATA": 28,
    }[csidl_name]

    buf = ctypes.create_unicode_buffer(1024)
    ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)

    # Downgrade to short path name if have highbit chars. See
    # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
    has_high_char = False
    for c in buf:
        if ord(c) > 255:
            has_high_char = True
            break
    if has_high_char:
        buf2 = ctypes.create_unicode_buffer(1024)
        if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024):
            buf = buf2

    return buf.value


def _get_win_folder_with_jna(csidl_name):
    import array
    from com.sun import jna
    from com.sun.jna.platform import win32

    buf_size = win32.WinDef.MAX_PATH * 2
    buf = array.zeros('c', buf_size)
    shell = win32.Shell32.INSTANCE
    shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf)
    dir = jna.Native.toString(buf.tostring()).rstrip("\0")

    # Downgrade to short path name if have highbit chars. See
    # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
    has_high_char = False
    for c in dir:
        if ord(c) > 255:
            has_high_char = True
            break
    if has_high_char:
        buf = array.zeros('c', buf_size)
        kernel = win32.Kernel32.INSTANCE
        if kernel.GetShortPathName(dir, buf, buf_size):
            dir = jna.Native.toString(buf.tostring()).rstrip("\0")

    return dir


if system == "win32":
    try:
        import win32com.shell
        _get_win_folder = _get_win_folder_with_pywin32
    except ImportError:
        try:
            from ctypes import windll
            _get_win_folder = _get_win_folder_with_ctypes
        except ImportError:
            try:
                import com.sun.jna
                _get_win_folder = _get_win_folder_with_jna
            except ImportError:
                _get_win_folder = _get_win_folder_from_registry


#---- self test code

if __name__ == "__main__":
    appname = "MyApp"
    appauthor = "MyCompany"

    props = ("user_data_dir", "site_data_dir",
             "user_config_dir", "site_config_dir",
             "user_cache_dir", "user_log_dir")

    print("-- app dirs %s --" % __version__)

    print("-- app dirs (with optional 'version')")
    dirs = AppDirs(appname, appauthor, version="1.0")
    for prop in props:
        print("%s: %s" % (prop, getattr(dirs, prop)))

    print("\n-- app dirs (without optional 'version')")
    dirs = AppDirs(appname, appauthor)
    for prop in props:
        print("%s: %s" % (prop, getattr(dirs, prop)))

    print("\n-- app dirs (without optional 'appauthor')")
    dirs = AppDirs(appname)
    for prop in props:
        print("%s: %s" % (prop, getattr(dirs, prop)))

    print("\n-- app dirs (with disabled 'appauthor')")
    dirs = AppDirs(appname, appauthor=False)
    for prop in props:
        print("%s: %s" % (prop, getattr(dirs, prop)))
@@ -0,0 +1,22 @@
"""
Implementation of some CFFI functions
"""


from numba.core.imputils import Registry
from numba.core import types
from numba.np import arrayobj

registry = Registry('cffiimpl')


@registry.lower('ffi.from_buffer', types.Buffer)
def from_buffer(context, builder, sig, args):
    assert len(sig.args) == 1
    assert len(args) == 1
    [fromty] = sig.args
    [val] = args
    # Type inference should have prevented passing a buffer from an
    # array to a pointer of the wrong type.
    assert fromty.dtype == sig.return_type.dtype
    ary = arrayobj.make_array(fromty)(context, builder, val)
    return ary.data
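
This lowering is what backs ``ffi.from_buffer(array)`` calls inside jitted code. A rough usage sketch follows; ``lib.csum`` is a hypothetical cffi-wrapped C function taking ``(double *, int)``, and the cffi module registration steps are elided:

# Rough sketch only: `lib`, `ffi`, and `csum` are hypothetical; cffi module
# registration with Numba is elided.
import numpy as np
from numba import njit

@njit
def total(arr):
    # ffi.from_buffer(arr) lowers to a raw pointer to the array data.
    return lib.csum(ffi.from_buffer(arr), arr.size)

total(np.arange(10.0))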
@@ -0,0 +1,5 @@
# this file is used with the numba.gdb* functionality
break numba_gdb_breakpoint
commands
return
end
@@ -0,0 +1,201 @@
"""
Implement code coverage support.

Currently contains logic to extend ``coverage`` with lines covered by the
compiler.
"""

from typing import Optional, Sequence, Callable, no_type_check
from collections.abc import Mapping
from dataclasses import dataclass, field
from abc import ABC, abstractmethod

from numba.core import ir, config


try:
    import coverage
except ImportError:
    coverage_available = False
else:
    coverage_available = True


@no_type_check
def get_active_coverage():
    """Get active coverage instance or return None if not found.
    """
    cov = None
    if coverage_available:
        cov = coverage.Coverage.current()
    return cov


# A list of factories; each returns a NotifyLocBase instance or None.
_the_registry: list[Callable[[], Optional["NotifyLocBase"]]] = []


def get_registered_loc_notify() -> Sequence["NotifyLocBase"]:
    """
    Returns a list of the registered NotifyLocBase instances.
    """
    if not config.JIT_COVERAGE:
        # Coverage disabled.
        return []
    return list(filter(lambda x: x is not None,
                       (factory() for factory in _the_registry)))


class NotifyLocBase(ABC):
    """Interface for notifying visiting of a ``numba.core.ir.Loc``."""

    @abstractmethod
    def notify(self, loc: ir.Loc) -> None:
        pass

    @abstractmethod
    def close(self) -> None:
        pass

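As a sketch of this extension point, a minimal hypothetical ``NotifyLocBase`` implementation (not part of Numba) that records every visited source location:

# Hypothetical illustration of the NotifyLocBase interface.
class NotifyLocLogger(NotifyLocBase):
    def __init__(self):
        self._seen = []

    def notify(self, loc: ir.Loc) -> None:
        # Record (filename, line) for each location the compiler visits.
        self._seen.append((loc.filename, loc.line))

    def close(self) -> None:
        print(f"visited {len(self._seen)} locations")
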
class NotifyCompilerCoverage(NotifyLocBase):
    """
    Use to notify ``coverage`` about compiled lines.

    The compiled lines are under the "numba_compiled" context in the coverage
    data.
    """

    def __init__(self, collector):
        self._collector = collector

        # see https://github.com/nedbat/coveragepy/blob/e7c05fe91ee36c0c94e144bb88d25db4fc3d02fd/coverage/collector.py#L261 # noqa E501
        tracer_kwargs = collector.core.tracer_kwargs.copy()
        tracer_kwargs.update(
            dict(
                data=collector.data,
                lock_data=collector.lock_data,
                unlock_data=collector.unlock_data,
                trace_arcs=collector.branch,
                should_trace=collector.should_trace,
                should_trace_cache=collector.should_trace_cache,
                warn=collector.warn,
                should_start_context=collector.should_start_context,
                switch_context=collector.switch_context,
                packed_arcs=collector.core.packed_arcs,
            )
        )
        self._tracer = NumbaTracer(**tracer_kwargs)
        collector.tracers.append(self._tracer)

    def notify(self, loc: ir.Loc):
        tracer = self._tracer
        if loc.filename.endswith(".py"):
            tracer.switch_context("numba_compiled")
            tracer.trace(loc)
            tracer.switch_context(None)

    def close(self):
        pass


@_the_registry.append
def _register_coverage_notifier():
    cov = get_active_coverage()
    if cov is not None:
        col = cov._collector
        # Is coverage started?
        if col.tracers:
            return NotifyCompilerCoverage(col)


if coverage_available:

    @dataclass(kw_only=True)
    class NumbaTracer(coverage.types.Tracer):
        """
        Not actually a tracer as in the coverage implementation, which will
        setup a Python trace function. This implementation pretends to trace
        but instead receives fake trace events for each line the compiler has
        visited.

        See coverage.PyTracer
        """

        data: coverage.types.TTraceData
        trace_arcs: bool
        should_trace: coverage.types.TShouldTraceFn
        should_trace_cache: Mapping[
            str, coverage.types.TFileDisposition | None
        ]
        should_start_context: coverage.types.TShouldStartContextFn | None
        switch_context: Callable[[str | None], None] | None
        lock_data: Callable[[], None]
        unlock_data: Callable[[], None]
        warn: coverage.types.TWarnFn
        packed_arcs: bool

        _activity: bool = field(default=False)

        def start(self) -> coverage.types.TTraceFn | None:
            """Start this tracer, return a trace function if based on
            sys.settrace."""
            return None

        def stop(self) -> None:
            """Stop this tracer."""
            return None

        def activity(self) -> bool:
            """Has there been any activity?"""
            return self._activity

        def reset_activity(self) -> None:
            """Reset the activity() flag."""
            self._activity = False

        def get_stats(self) -> dict[str, int] | None:
            """Return a dictionary of statistics, or None."""
            return None

        def trace(self, loc: ir.Loc) -> None:
            """Insert coverage data given source location.
            """
            # Check whether the file should be traced
            disp = self.should_trace_cache.get(loc.filename)
            if disp is None:
                disp = self.should_trace(loc.filename, None)
                self.should_trace_cache[loc.filename] = disp
            if not disp.trace:
                # Bail if not tracing the file
                return
            # Insert trace data
            tracename = disp.source_filename
            self.lock_data()
            cur_file_data = self.data.setdefault(tracename, set())
            if self.trace_arcs:
                if self.packed_arcs:
                    cur_file_data.add(_pack_arcs(loc.line, loc.line))
                else:
                    cur_file_data.add((loc.line, loc.line))
            else:
                cur_file_data.add(loc.line)
            self.unlock_data()
            # Mark activity for this tracer
            self._activity = True

    def _pack_arcs(l1: int, l2: int) -> int:
        """Pack arcs into a single integer for compatibility with .packed_arcs
        option.

        See
        https://github.com/nedbat/coveragepy/blob/e7c05fe91ee36c0c94e144bb88d25db4fc3d02fd/coverage/ctracer/tracer.c#L171
        """
        packed = 0
        if l1 < 0:
            packed |= 1 << 40
            l1 = -l1
        if l2 < 0:
            packed |= 1 << 41
            l2 = -l2
        packed |= (l2 << 20) + l1
        return packed
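
To make the packed-arc encoding concrete, a small round-trip sketch; the decode helper is illustrative and not part of the module:

# Round-trip sketch for the packed-arc encoding above.
def _unpack_arcs(packed: int) -> tuple[int, int]:
    l1 = packed & ((1 << 20) - 1)          # low 20 bits
    l2 = (packed >> 20) & ((1 << 20) - 1)  # next 20 bits
    if packed & (1 << 40):                 # sign flag for l1
        l1 = -l1
    if packed & (1 << 41):                 # sign flag for l2
        l2 = -l2
    return l1, l2

assert _unpack_arcs(_pack_arcs(12, 12)) == (12, 12)
assert _unpack_arcs(_pack_arcs(-1, 7)) == (-1, 7)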
@@ -0,0 +1,84 @@
try:
    from pygments.styles.default import DefaultStyle
except ImportError:
    msg = "Please install pygments to see highlighted dumps"
    raise ImportError(msg)

import numba.core.config
from pygments.styles.manni import ManniStyle
from pygments.styles.monokai import MonokaiStyle
from pygments.styles.native import NativeStyle

from pygments.lexer import RegexLexer, include, bygroups, words
from pygments.token import Text, Name, String, Punctuation, Keyword, \
    Operator, Number

from pygments.style import Style


class NumbaIRLexer(RegexLexer):
    """
    Pygments style lexer for Numba IR (for use with highlighting etc).
    """
    name = 'Numba_IR'
    aliases = ['numba_ir']
    filenames = ['*.numba_ir']

    identifier = r'\$[a-zA-Z0-9._]+'
    fun_or_var = r'([a-zA-Z_]+[a-zA-Z0-9]*)'

    tokens = {
        'root': [
            (r'(label)(\ [0-9]+)(:)$',
                bygroups(Keyword, Name.Label, Punctuation)),

            (r' = ', Operator),
            include('whitespace'),
            include('keyword'),

            (identifier, Name.Variable),
            (fun_or_var + r'(\()',
                bygroups(Name.Function, Punctuation)),
            (fun_or_var + r'(\=)',
                bygroups(Name.Attribute, Punctuation)),
            (fun_or_var, Name.Constant),
            (r'[0-9]+', Number),

            # <built-in function some>
            (r'<[^>\n]*>', String),

            (r'[=<>{}\[\]()*.,!\':]|x\b', Punctuation)
        ],

        'keyword': [
            (words((
                'del', 'jump', 'call', 'branch',
            ), suffix=' '), Keyword),
        ],

        'whitespace': [
            (r'(\n|\s)', Text),
        ],
    }


def by_colorscheme():
    """
    Get appropriate style for highlighting according to the
    NUMBA_COLOR_SCHEME setting.
    """
    styles = DefaultStyle.styles.copy()
    styles.update({
        Name.Variable: "#888888",
    })
    custom_default = type('CustomDefaultStyle', (Style, ), {'styles': styles})

    style_map = {
        'no_color': custom_default,
        'dark_bg': MonokaiStyle,
        'light_bg': ManniStyle,
        'blue_bg': NativeStyle,
        'jupyter_nb': DefaultStyle,
    }

    return style_map[numba.core.config.COLOR_SCHEME]
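
A short sketch of how this lexer plugs into the standard pygments pipeline; the IR fragment is made up for illustration:

# Sketch: highlight a (made-up) fragment of Numba IR for terminal output.
from pygments import highlight
from pygments.formatters import Terminal256Formatter

ir_text = 'label 0:\n    $0.1 = call foo($x)\n    jump 12\n'
formatter = Terminal256Formatter(style=by_colorscheme())
print(highlight(ir_text, NumbaIRLexer(), formatter))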
@@ -0,0 +1,63 @@
import sys
import os
import re


def get_lib_dirs():
    """
    Anaconda specific
    """
    if sys.platform == 'win32':
        # on windows, historically `DLLs` has been used for CUDA libraries,
        # since approximately CUDA 9.2, `Library\bin` has been used.
        dirnames = ['DLLs', os.path.join('Library', 'bin')]
    else:
        dirnames = ['lib', ]
    libdirs = [os.path.join(sys.prefix, x) for x in dirnames]
    return libdirs


DLLNAMEMAP = {
    'linux': r'lib%(name)s\.so\.%(ver)s$',
    'linux2': r'lib%(name)s\.so\.%(ver)s$',
    'linux-static': r'lib%(name)s\.a$',
    'darwin': r'lib%(name)s\.%(ver)s\.dylib$',
    'win32': r'%(name)s%(ver)s\.dll$',
    'win32-static': r'%(name)s\.lib$',
    'bsd': r'lib%(name)s\.so\.%(ver)s$',
}

RE_VER = r'[0-9]*([_\.][0-9]+)*'


def find_lib(libname, libdir=None, platform=None, static=False):
    platform = platform or sys.platform
    platform = 'bsd' if 'bsd' in platform else platform
    if static:
        platform = f"{platform}-static"
    if platform not in DLLNAMEMAP:
        # Return empty list if platform name is undefined.
        # Not all platforms define their static library paths.
        return []
    pat = DLLNAMEMAP[platform] % {"name": libname, "ver": RE_VER}
    regex = re.compile(pat)
    return find_file(regex, libdir)


def find_file(pat, libdir=None):
    if libdir is None:
        libdirs = get_lib_dirs()
    elif isinstance(libdir, str):
        libdirs = [libdir, ]
    else:
        libdirs = list(libdir)
    files = []
    for ldir in libdirs:
        try:
            entries = os.listdir(ldir)
        except FileNotFoundError:
            continue
        candidates = [os.path.join(ldir, ent)
                      for ent in entries if pat.match(ent)]
        files.extend([c for c in candidates if os.path.isfile(c)])
    return files
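
Usage sketch; the library name is hypothetical and results depend on what is installed in the active environment:

# Sketch: look for a versioned shared library in the environment's lib dirs.
# "foo" is a hypothetical name; on Linux this matches e.g. libfoo.so.1.2
# under $PREFIX/lib.
for path in find_lib('foo'):
    print(path)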
@@ -0,0 +1,104 @@
"""
This module provides helper functions to find the first line of a function
body.
"""

import ast
import inspect
import textwrap


class FindDefFirstLine(ast.NodeVisitor):
    """
    Attributes
    ----------
    first_stmt_line : int or None
        This stores the first statement line number if the definition is found.
        Or, ``None`` if the definition is not found.
    """

    def __init__(self, name, firstlineno):
        """
        Parameters
        ----------
        name : str
            The function's name (``co_name`` of its code object).
        firstlineno : int
            The function's first line number (``co_firstlineno``).
        """
        self._co_name = name
        self._co_firstlineno = firstlineno
        self.first_stmt_line = None

    def _visit_children(self, node):
        for child in ast.iter_child_nodes(node):
            super().visit(child)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        if node.name == self._co_name:
            # Name of function matches.

            # The `def` line may match co_firstlineno.
            possible_start_lines = set([node.lineno])
            if node.decorator_list:
                # Has decorators.
                # The first decorator line may match co_firstlineno.
                first_decor = node.decorator_list[0]
                possible_start_lines.add(first_decor.lineno)
            # Does the first lineno match?
            if self._co_firstlineno in possible_start_lines:
                # Yes, we found the function.
                # So, use the first statement line as the first line.
                if node.body:
                    first_stmt = node.body[0]
                    if _is_docstring(first_stmt) and len(node.body) > 1:
                        # Skip the docstring (unless it is the only statement,
                        # in which case fall back to its line).
                        first_stmt = node.body[1]
                    self.first_stmt_line = first_stmt.lineno
                    return
                else:
                    # This is probably unreachable.
                    # Function body cannot be bare. It must at least have
                    # a const string for docstring or a `pass`.
                    pass
        self._visit_children(node)


def _is_docstring(node):
    if isinstance(node, ast.Expr):
        if (isinstance(node.value, ast.Constant)
                and isinstance(node.value.value, str)):
            return True
    return False


def get_func_body_first_lineno(pyfunc):
    """
    Look up the first line of the function body using the file in
    ``pyfunc.__code__.co_filename``.

    Returns
    -------
    lineno : int; or None
        The first line number of the function body; or ``None`` if the first
        line cannot be determined.
    """
    co = pyfunc.__code__
    try:
        with open(co.co_filename) as fin:
            source = fin.read()
        offset = 0
    except OSError:  # includes FileNotFoundError
        try:
            lines, offset = inspect.getsourcelines(pyfunc)
            source = "".join(lines)
            offset = offset - 1
        except (OSError, TypeError):
            return None

    tree = ast.parse(textwrap.dedent(source))
    finder = FindDefFirstLine(co.co_name, co.co_firstlineno - offset)
    finder.visit(tree)
    if finder.first_stmt_line:
        return finder.first_stmt_line + offset
    else:
        # No first line found.
        return None
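
A quick demonstration of the helper on a decorated function; when run as a script, the reported line is that of ``x = 1``, not the decorator or ``def`` line:

# Demo: the first body statement after the docstring is reported.
def passthrough(func):
    return func

@passthrough
def sample():
    """docstring"""
    x = 1
    return x

print(get_func_body_first_lineno(sample))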
@@ -0,0 +1,228 @@
import os
import sys

from llvmlite import ir

from numba.core import types, utils, config, cgutils, errors
from numba import gdb, gdb_init, gdb_breakpoint
from numba.core.extending import overload, intrinsic

_path = os.path.dirname(__file__)

_platform = sys.platform
_unix_like = (_platform.startswith('linux') or
              _platform.startswith('darwin') or
              ('bsd' in _platform))


def _confirm_gdb(need_ptrace_attach=True):
    """
    Set need_ptrace_attach to True/False to indicate whether the ptrace attach
    permission is needed for this gdb use case. Mode 0 (classic) or 1
    (restricted ptrace) is required if need_ptrace_attach is True. See:
    https://www.kernel.org/doc/Documentation/admin-guide/LSM/Yama.rst
    for details on the modes.
    """
    if not _unix_like:
        msg = 'gdb support is only available on unix-like systems'
        raise errors.NumbaRuntimeError(msg)
    gdbloc = config.GDB_BINARY
    if not (os.path.exists(gdbloc) and os.path.isfile(gdbloc)):
        msg = ('Is gdb present? Location specified (%s) does not exist. The gdb'
               ' binary location can be set using Numba configuration, see: '
               'https://numba.readthedocs.io/en/stable/reference/envvars.html' # noqa: E501
               )
        raise RuntimeError(msg % config.GDB_BINARY)
    # Is Yama being used as a kernel security module and if so is ptrace_scope
    # limited? In this case ptracing non-child processes requires special
    # permission so raise an exception.
    ptrace_scope_file = os.path.join(os.sep, 'proc', 'sys', 'kernel', 'yama',
                                     'ptrace_scope')
    has_ptrace_scope = os.path.exists(ptrace_scope_file)
    if has_ptrace_scope:
        with open(ptrace_scope_file, 'rt') as f:
            value = f.readline().strip()
        if need_ptrace_attach and value not in ("0", "1"):
            msg = ("gdb can launch but cannot attach to the executing program"
                   " because ptrace permissions have been restricted at the "
                   "system level by the Linux security module 'Yama'.\n\n"
                   "Documentation for this module and the security "
                   "implications of making changes to its behaviour can be "
                   "found in the Linux Kernel documentation "
                   "https://www.kernel.org/doc/Documentation/admin-guide/LSM/Yama.rst" # noqa: E501
                   "\n\nDocumentation on how to adjust the behaviour of Yama "
                   "on Ubuntu Linux with regards to 'ptrace_scope' can be "
                   "found here "
                   "https://wiki.ubuntu.com/Security/Features#ptrace.")
            raise RuntimeError(msg)


@overload(gdb)
def hook_gdb(*args):
    _confirm_gdb()
    gdbimpl = gen_gdb_impl(args, True)

    def impl(*args):
        gdbimpl()
    return impl


@overload(gdb_init)
def hook_gdb_init(*args):
    _confirm_gdb()
    gdbimpl = gen_gdb_impl(args, False)

    def impl(*args):
        gdbimpl()
    return impl


def init_gdb_codegen(cgctx, builder, signature, args,
                     const_args, do_break=False):

    int8_t = ir.IntType(8)
    int32_t = ir.IntType(32)
    intp_t = ir.IntType(utils.MACHINE_BITS)
    char_ptr = ir.PointerType(ir.IntType(8))
    zero_i32t = int32_t(0)

    mod = builder.module
    pid = cgutils.alloca_once(builder, int32_t, size=1)

    # 32bit pid, 11 char max + terminator
    pidstr = cgutils.alloca_once(builder, int8_t, size=12)

    # str consts
    intfmt = cgctx.insert_const_string(mod, '%d')
    gdb_str = cgctx.insert_const_string(mod, config.GDB_BINARY)
    attach_str = cgctx.insert_const_string(mod, 'attach')

    new_args = []
    # add break point command to known location
    # this command file thing is due to commands attached to a breakpoint
    # requiring an interactive prompt
    # https://sourceware.org/bugzilla/show_bug.cgi?id=10079
    new_args.extend(['-x', os.path.join(_path, 'cmdlang.gdb')])
    # issue command to continue execution from sleep function
    new_args.extend(['-ex', 'c'])
    # then run the user defined args if any
    if any([not isinstance(x, types.StringLiteral) for x in const_args]):
        raise errors.RequireLiteralValue(const_args)
    new_args.extend([x.literal_value for x in const_args])
    cmdlang = [cgctx.insert_const_string(mod, x) for x in new_args]

    # insert getpid, getpid is always successful, call without concern!
    fnty = ir.FunctionType(int32_t, tuple())
    getpid = cgutils.get_or_insert_function(mod, fnty, "getpid")

    # insert snprintf
    # int snprintf(char *str, size_t size, const char *format, ...);
    fnty = ir.FunctionType(
        int32_t, (char_ptr, intp_t, char_ptr), var_arg=True)
    snprintf = cgutils.get_or_insert_function(mod, fnty, "snprintf")

    # insert fork
    fnty = ir.FunctionType(int32_t, tuple())
    fork = cgutils.get_or_insert_function(mod, fnty, "fork")

    # insert execl
    fnty = ir.FunctionType(int32_t, (char_ptr, char_ptr), var_arg=True)
    execl = cgutils.get_or_insert_function(mod, fnty, "execl")

    # insert sleep
    fnty = ir.FunctionType(int32_t, (int32_t,))
    sleep = cgutils.get_or_insert_function(mod, fnty, "sleep")

    # insert break point
    fnty = ir.FunctionType(ir.VoidType(), tuple())
    breakpoint = cgutils.get_or_insert_function(mod, fnty,
                                                "numba_gdb_breakpoint")

    # do the work
    parent_pid = builder.call(getpid, tuple())
    builder.store(parent_pid, pid)
    pidstr_ptr = builder.gep(pidstr, [zero_i32t], inbounds=True)
    pid_val = builder.load(pid)

    # call snprintf to write the pid into a char *
    stat = builder.call(
        snprintf, (pidstr_ptr, intp_t(12), intfmt, pid_val))
    invalid_write = builder.icmp_signed('>', stat, int32_t(12))
    with builder.if_then(invalid_write, likely=False):
        msg = "Internal error: `snprintf` buffer would have overflowed."
        cgctx.call_conv.return_user_exc(builder, RuntimeError, (msg,))

    # fork, check pids etc
    child_pid = builder.call(fork, tuple())
    fork_failed = builder.icmp_signed('==', child_pid, int32_t(-1))
    with builder.if_then(fork_failed, likely=False):
        msg = "Internal error: `fork` failed."
        cgctx.call_conv.return_user_exc(builder, RuntimeError, (msg,))

    is_child = builder.icmp_signed('==', child_pid, zero_i32t)
    with builder.if_else(is_child) as (then, orelse):
        with then:
            # is child
            nullptr = ir.Constant(char_ptr, None)
            gdb_str_ptr = builder.gep(
                gdb_str, [zero_i32t], inbounds=True)
            attach_str_ptr = builder.gep(
                attach_str, [zero_i32t], inbounds=True)
            cgutils.printf(
                builder, "Attaching to PID: %s\n", pidstr)
            buf = (
                gdb_str_ptr,
                gdb_str_ptr,
                attach_str_ptr,
                pidstr_ptr)
            buf = buf + tuple(cmdlang) + (nullptr,)
            builder.call(execl, buf)
        with orelse:
            # is parent
            builder.call(sleep, (int32_t(10),))
            # if breaking is desired, break now
            if do_break is True:
                builder.call(breakpoint, tuple())


def gen_gdb_impl(const_args, do_break):
    @intrinsic
    def gdb_internal(tyctx):
        function_sig = types.void()

        def codegen(cgctx, builder, signature, args):
            init_gdb_codegen(cgctx, builder, signature, args, const_args,
                             do_break=do_break)
            return cgctx.get_constant(types.none, None)
        return function_sig, codegen
    return gdb_internal


@overload(gdb_breakpoint)
def hook_gdb_breakpoint():
    """
    Adds the Numba break point into the source
    """
    if not sys.platform.startswith('linux'):
        raise RuntimeError('gdb is only available on linux')
    bp_impl = gen_bp_impl()

    def impl():
        bp_impl()
    return impl


def gen_bp_impl():
    @intrinsic
    def bp_internal(tyctx):
        function_sig = types.void()

        def codegen(cgctx, builder, signature, args):
            mod = builder.module
            fnty = ir.FunctionType(ir.VoidType(), tuple())
            breakpoint = cgutils.get_or_insert_function(mod, fnty,
                                                        "numba_gdb_breakpoint")
            builder.call(breakpoint, tuple())
            return cgctx.get_constant(types.none, None)
        return function_sig, codegen
    return bp_internal
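
For orientation, the user-facing pattern these hooks implement looks roughly like the following sketch; it requires a Linux system with gdb available and configured:

# Sketch of end-user usage of the gdb hooks (Linux + gdb required).
from numba import njit, gdb

@njit(debug=True)
def foo(a):
    b = a + 1
    gdb()  # forks, attaches gdb to this process, then breaks here
    return a * b

foo(10)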
@@ -0,0 +1,204 @@
"""gdb printing extension for Numba types.
"""
import re

try:
    import gdb.printing
    import gdb
except ImportError:
    raise ImportError("GDB python support is not available.")


class NumbaArrayPrinter:

    def __init__(self, val):
        self.val = val

    def to_string(self):
        try:
            import numpy as np
            HAVE_NUMPY = True
        except ImportError:
            HAVE_NUMPY = False

        try:
            NULL = 0x0

            # Raw data references, these need unpacking/interpreting.

            # Member "data" is...
            # DW_TAG_member of DIDerivedType, tag of DW_TAG_pointer_type
            # encoding e.g. DW_ATE_float
            data = self.val["data"]

            # Member "itemsize" is...
            # DW_TAG_member of DIBasicType encoding DW_ATE_signed
            itemsize = self.val["itemsize"]

            # Members "shape" and "strides" are...
            # DW_TAG_member of DIDerivedType, the type is a DICompositeType
            # (it's a Numba UniTuple) with tag: DW_TAG_array_type, i.e. it's an
            # array repr, it has a basetype of e.g. DW_ATE_unsigned and also
            # "elements" which are referenced with a DISubrange(count: <const>)
            # to say how many elements are in the array.
            rshp = self.val["shape"]
            rstrides = self.val["strides"]

            # bool on whether the data is aligned.
            is_aligned = False

            # type information decode, simple type:
            ty_str = str(self.val.type)
            if HAVE_NUMPY and ('aligned' in ty_str or 'Record' in ty_str):
                ty_str = ty_str.replace('unaligned ', '').strip()
                matcher = re.compile(r"array\((Record.*), (.*), (.*)\)\ \(.*")
                # NOTE: need to deal with "Alignment" else dtype size is wrong
                arr_info = [x.strip() for x in matcher.match(ty_str).groups()]
                dtype_str, ndim_str, order_str = arr_info
                rstr = r'Record\((.*\[.*\]);([0-9]+);(True|False)'
                rstr_match = re.match(rstr, dtype_str)
                # balign is unused, it's the alignment
                fields, balign, is_aligned_str = rstr_match.groups()
                is_aligned = is_aligned_str == 'True'
                field_dts = fields.split(',')
                struct_entries = []
                for f in field_dts:
                    splitted = f.split('[')
                    name = splitted[0]
                    dt_part = splitted[1:]
                    if len(dt_part) > 1:
                        raise TypeError('Unsupported sub-type: %s' % f)
                    else:
                        dt_part = dt_part[0]
                        if "nestedarray" in dt_part:
                            raise TypeError('Unsupported sub-type: %s' % f)
                        dt_as_str = dt_part.split(';')[0].split('=')[1]
                        dtype = np.dtype(dt_as_str)
                        struct_entries.append((name, dtype))
                # The dtype is actually a record of some sort
                dtype_str = struct_entries
            else:  # simple type
                matcher = re.compile(r"array\((.*),(.*),(.*)\)\ \(.*")
                arr_info = [x.strip() for x in matcher.match(ty_str).groups()]
                dtype_str, ndim_str, order_str = arr_info
                # fix up unichr dtype
                if 'unichr x ' in dtype_str:
                    dtype_str = dtype_str[1:-1].replace('unichr x ', '<U')

            def dwarr2inttuple(dwarr):
                # Converts a gdb handle to a dwarf array to a tuple of ints
                fields = dwarr.type.fields()
                lo, hi = fields[0].type.range()
                return tuple([int(dwarr[x]) for x in range(lo, hi + 1)])

            # shape/strides extraction
            shape = dwarr2inttuple(rshp)
            strides = dwarr2inttuple(rstrides)

            # if data is not NULL
            if data != NULL:
                if HAVE_NUMPY:
                    # The data extent in bytes is:
                    # sum(shape * strides)
                    # get the data, then wire to as_strided
                    shp_arr = np.array([max(0, x - 1) for x in shape])
                    strd_arr = np.array(strides)
                    extent = np.sum(shp_arr * strd_arr)
                    extent += int(itemsize)
                    dtype_clazz = np.dtype(dtype_str, align=is_aligned)
                    dtype = dtype_clazz
                    this_proc = gdb.selected_inferior()
                    mem = this_proc.read_memory(int(data), extent)
                    arr_data = np.frombuffer(mem, dtype=dtype)
                    new_arr = np.lib.stride_tricks.as_strided(arr_data,
                                                              shape=shape,
                                                              strides=strides,)
                    return '\n' + str(new_arr)
                # Catch all for no NumPy
                return "array([...], dtype=%s, shape=%s)" % (dtype_str, shape)
            else:
                # Not yet initialized or NULLed out data
                buf = ["NULL/Uninitialized"]
                return "array([" + ', '.join(buf) + "]" + ")"
        except Exception as e:
            return 'array[Exception: Failed to parse. %s]' % e


class NumbaComplexPrinter:

    def __init__(self, val):
        self.val = val

    def to_string(self):
        return "%s+%sj" % (self.val['real'], self.val['imag'])


class NumbaTuplePrinter:

    def __init__(self, val):
        self.val = val

    def to_string(self):
        buf = []
        fields = self.val.type.fields()
        for f in fields:
            buf.append(str(self.val[f.name]))
        return "(%s)" % ', '.join(buf)


class NumbaUniTuplePrinter:

    def __init__(self, val):
        self.val = val

    def to_string(self):
        # unituples are arrays
        fields = self.val.type.fields()
        lo, hi = fields[0].type.range()
        buf = []
        for i in range(lo, hi + 1):
            buf.append(str(self.val[i]))
        return "(%s)" % ', '.join(buf)


class NumbaUnicodeTypePrinter:

    def __init__(self, val):
        self.val = val

    def to_string(self):
        NULL = 0x0
        data = self.val["data"]
        nitems = self.val["length"]
        kind = self.val["kind"]
        if data != NULL:
            # This needs sorting out, encoding is wrong
            this_proc = gdb.selected_inferior()
            mem = this_proc.read_memory(int(data), nitems * kind)
            if isinstance(mem, memoryview):
                buf = bytes(mem).decode()
            else:
                buf = mem.decode('utf-8')
        else:
            buf = str(data)
        return "'%s'" % buf


def _create_printers():
    printer = gdb.printing.RegexpCollectionPrettyPrinter("Numba")
    printer.add_printer('Numba unaligned array printer', '^unaligned array\\(',
                        NumbaArrayPrinter)
    printer.add_printer('Numba array printer', '^array\\(', NumbaArrayPrinter)
    printer.add_printer('Numba complex printer', '^complex[0-9]+\\ ',
                        NumbaComplexPrinter)
    printer.add_printer('Numba Tuple printer', '^Tuple\\(',
                        NumbaTuplePrinter)
    printer.add_printer('Numba UniTuple printer', '^UniTuple\\(',
                        NumbaUniTuplePrinter)
    printer.add_printer('Numba unicode_type printer', '^unicode_type\\s+\\(',
                        NumbaUnicodeTypePrinter)
    return printer


# register the Numba pretty printers for the current object
gdb.printing.register_pretty_printer(gdb.current_objfile(), _create_printers())
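
These printers follow gdb's standard Python pretty-printer pattern; for comparison, a minimal hypothetical example of the same shape, loadable in a gdb session via ``source myprinter.py``:

# Minimal sketch of gdb's pretty-printer pattern (only runs inside gdb's
# embedded Python; "MyType" is a hypothetical target type).
import gdb.printing

class MyTypePrinter:
    def __init__(self, val):
        self.val = val

    def to_string(self):
        return "MyType(%s)" % self.val["value"]

pp = gdb.printing.RegexpCollectionPrettyPrinter("my_ext")
pp.add_printer('MyType printer', '^MyType$', MyTypePrinter)
gdb.printing.register_pretty_printer(gdb.current_objfile(), pp)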
[2 binary files in this commit are not shown]
@@ -0,0 +1,433 @@
"""
This file contains `__main__` so that it can be run as a command-line tool.

This file contains functions to inspect Numba's support for a given Python
module or a Python package.
"""

import argparse
import pkgutil
import warnings
import types as pytypes

from numba.core import errors
from numba._version import get_versions
from numba.core.registry import cpu_target
from numba.tests.support import captured_stdout


def _get_commit():
    full = get_versions()['full-revisionid']
    if not full:
        warnings.warn(
            "Cannot find git commit hash. Source links could be inaccurate.",
            category=errors.NumbaWarning,
        )
        return 'main'
    return full


commit = _get_commit()
github_url = 'https://github.com/numba/numba/blob/{commit}/{path}#L{firstline}-L{lastline}' # noqa: E501


def inspect_function(function, target=None):
    """Return information about the support of a function.

    Returns
    -------
    info : dict
        Defined keys:
        - "numba_type": numba type or None
            The numba type object of the function if supported.
        - "explained": str
            A textual description of the support.
        - "source_infos": dict
            A dictionary containing the source location of each definition.
    """
    target = target or cpu_target
    tyct = target.typing_context
    # Make sure we have loaded all extensions
    tyct.refresh()
    target.target_context.refresh()

    info = {}
    # Try getting the function type
    source_infos = {}
    try:
        nbty = tyct.resolve_value_type(function)
    except ValueError:
        nbty = None
        explained = 'not supported'
    else:
        # Make a longer explanation of the type
        explained = tyct.explain_function_type(nbty)
        for temp in nbty.templates:
            try:
                source_infos[temp] = temp.get_source_info()
            except AttributeError:
                source_infos[temp] = None

    info['numba_type'] = nbty
    info['explained'] = explained
    info['source_infos'] = source_infos
    return info

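For example, inspecting a single builtin; the output shape depends on the Numba version and loaded extensions:

# Sketch: query Numba's support for a builtin function.
info = inspect_function(len)
print(info['numba_type'])   # a numba type object, or None if unsupported
print(info['explained'])
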
def inspect_module(module, target=None, alias=None):
|
||||
"""Inspect a module object and yielding results from `inspect_function()`
|
||||
for each function object in the module.
|
||||
"""
|
||||
alias = {} if alias is None else alias
|
||||
# Walk the module
|
||||
for name in dir(module):
|
||||
if name.startswith('_'):
|
||||
# Skip
|
||||
continue
|
||||
obj = getattr(module, name)
|
||||
supported_types = (pytypes.FunctionType, pytypes.BuiltinFunctionType)
|
||||
|
||||
if not isinstance(obj, supported_types):
|
||||
# Skip if it's not a function
|
||||
continue
|
||||
|
||||
info = dict(module=module, name=name, obj=obj)
|
||||
if obj in alias:
|
||||
info['alias'] = alias[obj]
|
||||
else:
|
||||
alias[obj] = "{module}.{name}".format(module=module.__name__,
|
||||
name=name)
|
||||
info.update(inspect_function(obj, target=target))
|
||||
yield info
|
||||
|
||||
|
||||
class _Stat(object):
|
||||
"""For gathering simple statistic of (un)supported functions"""
|
||||
def __init__(self):
|
||||
self.supported = 0
|
||||
self.unsupported = 0
|
||||
|
||||
@property
|
||||
def total(self):
|
||||
total = self.supported + self.unsupported
|
||||
return total
|
||||
|
||||
@property
|
||||
def ratio(self):
|
||||
ratio = self.supported / self.total * 100
|
||||
return ratio
|
||||
|
||||
def describe(self):
|
||||
if self.total == 0:
|
||||
return "empty"
|
||||
return "supported = {supported} / {total} = {ratio:.2f}%".format(
|
||||
supported=self.supported,
|
||||
total=self.total,
|
||||
ratio=self.ratio,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "{clsname}({describe})".format(
|
||||
clsname=self.__class__.__name__,
|
||||
describe=self.describe(),
|
||||
)
|
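# For example (values illustrative): with supported=3 and unsupported=1 the
# repr reads ``_Stat(supported = 3 / 4 = 75.00%)``.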
||||
|
||||
|
||||
def filter_private_module(module_components):
|
||||
return not any(x.startswith('_') for x in module_components)
|
||||
|
||||
|
||||
def filter_tests_module(module_components):
|
||||
return not any(x == 'tests' for x in module_components)
|
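# Both filters answer "keep this module?" given its dotted-name components,
# e.g. (illustrative):
#
#   >>> filter_private_module('numba._helperlib'.split('.'))
#   False
#   >>> filter_tests_module('numba.tests.foo'.split('.'))
#   False
#   >>> filter_private_module('numba.typed'.split('.'))
#   True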
||||
|
||||
|
||||
_default_module_filters = (
|
||||
filter_private_module,
|
||||
filter_tests_module,
|
||||
)
|
||||
|
||||
|
||||
def list_modules_in_package(package, module_filters=_default_module_filters):
|
||||
"""Yield all modules in a given package.
|
||||
|
||||
Recursively walks the package tree.
|
||||
"""
|
||||
onerror_ignore = lambda _: None
|
||||
|
||||
prefix = package.__name__ + "."
|
||||
package_walker = pkgutil.walk_packages(
|
||||
package.__path__,
|
||||
prefix,
|
||||
onerror=onerror_ignore,
|
||||
)
|
||||
|
||||
def check_filter(modname):
|
||||
module_components = modname.split('.')
|
||||
return any(not filter_fn(module_components)
|
||||
for filter_fn in module_filters)
|
||||
|
||||
modname = package.__name__
|
||||
if not check_filter(modname):
|
||||
yield package
|
||||
|
||||
for pkginfo in package_walker:
|
||||
modname = pkginfo[1]
|
||||
if check_filter(modname):
|
||||
continue
|
||||
# In case importing the module prints to stdout
|
||||
with captured_stdout():
|
||||
try:
|
||||
# Import the module
|
||||
mod = __import__(modname)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Extract the module
|
||||
for part in modname.split('.')[1:]:
|
||||
try:
|
||||
mod = getattr(mod, part)
|
||||
except AttributeError:
|
||||
# Suppress error in getting the attribute
|
||||
mod = None
|
||||
break
|
||||
|
||||
# Ignore if mod is not a module
|
||||
if not isinstance(mod, pytypes.ModuleType):
|
||||
# Skip non-module
|
||||
continue
|
||||
|
||||
yield mod
|
||||
|
||||
|
||||
class Formatter(object):
|
||||
"""Base class for formatters.
|
||||
"""
|
||||
def __init__(self, fileobj):
|
||||
self._fileobj = fileobj
|
||||
|
||||
def print(self, *args, **kwargs):
|
||||
kwargs.setdefault('file', self._fileobj)
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
class HTMLFormatter(Formatter):
|
||||
"""Formatter that outputs HTML
|
||||
"""
|
||||
|
||||
def escape(self, text):
|
||||
import html
|
||||
return html.escape(text)
|
||||
|
||||
def title(self, text):
|
||||
self.print('<h1>', text, '</h1>')
|
||||
|
||||
def begin_module_section(self, modname):
|
||||
self.print('<h2>', modname, '</h2>')
|
||||
self.print('<ul>')
|
||||
|
||||
def end_module_section(self):
|
||||
self.print('</ul>')
|
||||
|
||||
def write_supported_item(self, modname, itemname, typename, explained,
|
||||
sources, alias):
|
||||
self.print('<li>')
|
||||
self.print('{}.<b>{}</b>'.format(
|
||||
modname,
|
||||
itemname,
|
||||
))
|
||||
self.print(': <b>{}</b>'.format(typename))
|
||||
self.print('<div><pre>', explained, '</pre></div>')
|
||||
|
||||
self.print("<ul>")
|
||||
for tcls, source in sources.items():
|
||||
if source:
|
||||
self.print("<li>")
|
||||
impl = source['name']
|
||||
sig = source['sig']
|
||||
filename = source['filename']
|
||||
lines = source['lines']
|
||||
self.print(
|
||||
"<p>defined by <b>{}</b>{} at {}:{}-{}</p>".format(
|
||||
self.escape(impl), self.escape(sig),
|
||||
self.escape(filename), lines[0], lines[1],
|
||||
),
|
||||
)
|
||||
self.print('<p>{}</p>'.format(
|
||||
self.escape(source['docstring'] or '')
|
||||
))
|
||||
else:
|
||||
self.print("<li>{}".format(self.escape(str(tcls))))
|
||||
self.print("</li>")
|
||||
self.print("</ul>")
|
||||
self.print('</li>')
|
||||
|
||||
def write_unsupported_item(self, modname, itemname):
|
||||
self.print('<li>')
|
||||
self.print('{}.<b>{}</b>: UNSUPPORTED'.format(
|
||||
modname,
|
||||
itemname,
|
||||
))
|
||||
self.print('</li>')
|
||||
|
||||
def write_statistic(self, stats):
|
||||
self.print('<p>{}</p>'.format(stats.describe()))
|
||||
|
||||
|
||||
class ReSTFormatter(Formatter):
|
||||
"""Formatter that output ReSTructured text format for Sphinx docs.
|
||||
"""
|
||||
def escape(self, text):
|
||||
return text
|
||||
|
||||
def title(self, text):
|
||||
self.print(text)
|
||||
self.print('=' * len(text))
|
||||
self.print()
|
||||
|
||||
def begin_module_section(self, modname):
|
||||
self.print(modname)
|
||||
self.print('-' * len(modname))
|
||||
self.print()
|
||||
|
||||
def end_module_section(self):
|
||||
self.print()
|
||||
|
||||
def write_supported_item(self, modname, itemname, typename, explained,
|
||||
sources, alias):
|
||||
self.print('.. function:: {}.{}'.format(modname, itemname))
|
||||
self.print(' :noindex:')
|
||||
self.print()
|
||||
|
||||
if alias:
|
||||
self.print(" Alias to: ``{}``".format(alias))
|
||||
self.print()
|
||||
|
||||
for tcls, source in sources.items():
|
||||
if source:
|
||||
impl = source['name']
|
||||
sig = source['sig']
|
||||
filename = source['filename']
|
||||
lines = source['lines']
|
||||
source_link = github_url.format(
|
||||
commit=commit,
|
||||
path=filename,
|
||||
firstline=lines[0],
|
||||
lastline=lines[1],
|
||||
)
|
||||
self.print(
|
||||
" - defined by ``{}{}`` at `{}:{}-{} <{}>`_".format(
|
||||
impl, sig, filename, lines[0], lines[1], source_link,
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
self.print(" - defined by ``{}``".format(str(tcls)))
|
||||
self.print()
|
||||
|
||||
def write_unsupported_item(self, modname, itemname):
|
||||
pass
|
||||
|
||||
def write_statistic(self, stat):
|
||||
if stat.supported == 0:
|
||||
self.print("This module is not supported.")
|
||||
else:
|
||||
msg = "Not showing {} unsupported functions."
|
||||
self.print(msg.format(stat.unsupported))
|
||||
self.print()
|
||||
self.print(stat.describe())
|
||||
self.print()
|
||||
|
||||
|
||||
def _format_module_infos(formatter, package_name, mod_sequence, target=None):
|
||||
"""Format modules.
|
||||
"""
|
||||
formatter.title('Listings for {}'.format(package_name))
|
||||
alias_map = {} # remember object seen to track alias
|
||||
for mod in mod_sequence:
|
||||
stat = _Stat()
|
||||
modname = mod.__name__
|
||||
formatter.begin_module_section(formatter.escape(modname))
|
||||
for info in inspect_module(mod, target=target, alias=alias_map):
|
||||
nbtype = info['numba_type']
|
||||
if nbtype is not None:
|
||||
stat.supported += 1
|
||||
formatter.write_supported_item(
|
||||
modname=formatter.escape(info['module'].__name__),
|
||||
itemname=formatter.escape(info['name']),
|
||||
typename=formatter.escape(str(nbtype)),
|
||||
explained=formatter.escape(info['explained']),
|
||||
sources=info['source_infos'],
|
||||
alias=info.get('alias'),
|
||||
)
|
||||
|
||||
else:
|
||||
stat.unsupported += 1
|
||||
formatter.write_unsupported_item(
|
||||
modname=formatter.escape(info['module'].__name__),
|
||||
itemname=formatter.escape(info['name']),
|
||||
)
|
||||
|
||||
formatter.write_statistic(stat)
|
||||
formatter.end_module_section()
|
||||
|
||||
|
||||
def write_listings(package_name, filename, output_format):
|
||||
"""Write listing information into a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
package_name : str
|
||||
Name of the package to inspect.
|
||||
filename : str
|
||||
Output filename. The file is always overwritten.
|
||||
output_format : str
|
||||
Supported formats are "html" and "rst".
|
||||
"""
|
||||
package = __import__(package_name)
|
||||
if hasattr(package, '__path__'):
|
||||
mods = list_modules_in_package(package)
|
||||
else:
|
||||
mods = [package]
|
||||
|
||||
if output_format == 'html':
|
||||
with open(filename + '.html', 'w') as fout:
|
||||
fmtr = HTMLFormatter(fileobj=fout)
|
||||
_format_module_infos(fmtr, package_name, mods)
|
||||
elif output_format == 'rst':
|
||||
with open(filename + '.rst', 'w') as fout:
|
||||
fmtr = ReSTFormatter(fileobj=fout)
|
||||
_format_module_infos(fmtr, package_name, mods)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Output format '{}' is not supported".format(output_format))
|
||||
|
||||
|
||||
program_description = """
|
||||
Inspect Numba support for a given top-level package.
|
||||
""".strip()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description=program_description)
|
||||
parser.add_argument(
|
||||
'package', metavar='package', type=str,
|
||||
help='Package to inspect',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--format', dest='format', default='html',
|
||||
help='Output format; i.e. "html", "rst"',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--file', dest='file', default='inspector_output',
|
||||
help='Output filename. Defaults to "inspector_output.<format>"',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
package_name = args.package
|
||||
output_format = args.format
|
||||
filename = args.file
|
||||
write_listings(package_name, filename, output_format)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
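# Command-line usage sketch (the module path is an assumption; adjust it to
# wherever this file is importable from):
#
#   $ python -m numba.misc.help.inspector --format=rst --file=numpy_support numpy
#
# which writes the listing for the ``numpy`` package to ``numpy_support.rst``.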
||||
@@ -0,0 +1,44 @@
|
||||
"""Collection of miscellaneous initialization utilities."""
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
version_info = namedtuple('version_info',
|
||||
('major minor patch short full '
|
||||
'string tuple git_revision'))
|
||||
|
||||
|
||||
def generate_version_info(version):
|
||||
"""Process a version string into a structured version_info object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
version: str
|
||||
a string describing the current version
|
||||
|
||||
Returns
|
||||
-------
|
||||
version_info: tuple
|
||||
structured version information
|
||||
|
||||
See also
|
||||
--------
|
||||
Look at the definition of 'version_info' in this module for details.
|
||||
|
||||
"""
|
||||
parts = version.split('.')
|
||||
|
||||
def try_int(x):
|
||||
try:
|
||||
return int(x)
|
||||
except ValueError:
|
||||
return None
|
||||
major = try_int(parts[0]) if len(parts) >= 1 else None
|
||||
minor = try_int(parts[1]) if len(parts) >= 2 else None
|
||||
patch = try_int(parts[2]) if len(parts) >= 3 else None
|
||||
short = (major, minor)
|
||||
full = (major, minor, patch)
|
||||
string = version
|
||||
tup = tuple(parts)
|
||||
git_revision = tup[3] if len(tup) >= 4 else None
|
||||
return version_info(major, minor, patch, short, full, string, tup,
|
||||
git_revision)
|
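# A quick illustration of the expected result:
#
#   >>> vi = generate_version_info('0.57.1')
#   >>> (vi.major, vi.minor, vi.patch, vi.git_revision)
#   (0, 57, 1, None)
#   >>> (vi.short, vi.full)
#   ((0, 57), (0, 57, 1))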
||||
@@ -0,0 +1,103 @@
|
||||
"""Miscellaneous inspection tools
|
||||
"""
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from numba.core.errors import NumbaWarning
|
||||
|
||||
|
||||
def disassemble_elf_to_cfg(elf, mangled_symbol):
|
||||
"""
|
||||
Gets the CFG of the disassembly of an ELF object, elf, at mangled name,
|
||||
mangled_symbol, and renders it appropriately depending on the execution
|
||||
environment (terminal/notebook).
|
||||
"""
|
||||
try:
|
||||
import r2pipe
|
||||
except ImportError:
|
||||
raise RuntimeError("r2pipe package needed for disasm CFG")
|
||||
|
||||
def get_rendering(cmd=None):
|
||||
from numba.pycc.platform import Toolchain  # local import avoids circular ref
|
||||
if cmd is None:
|
||||
raise ValueError("No command given")
|
||||
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Write ELF as a temporary file in the temporary dir, do not delete!
|
||||
with NamedTemporaryFile(delete=False, dir=tmpdir) as f:
|
||||
f.write(elf)
|
||||
f.flush() # force write, radare2 needs a binary blob on disk
|
||||
|
||||
# Now try and link the ELF, this helps radare2 _a lot_
|
||||
linked = False
|
||||
try:
|
||||
raw_dso_name = f'{os.path.basename(f.name)}.so'
|
||||
linked_dso = os.path.join(tmpdir, raw_dso_name)
|
||||
tc = Toolchain()
|
||||
tc.link_shared(linked_dso, (f.name,))
|
||||
obj_to_analyse = linked_dso
|
||||
linked = True
|
||||
except Exception as e:
|
||||
# link failed, mention it to user, radare2 will still be able to
|
||||
# analyse the object, but things like dwarf won't appear in the
|
||||
# asm as comments.
|
||||
msg = ('Linking the ELF object with the distutils toolchain '
|
||||
f'failed with: {e}. Disassembly will still work but '
|
||||
'might be less accurate and will not use DWARF '
|
||||
'information.')
|
||||
warnings.warn(NumbaWarning(msg))
|
||||
obj_to_analyse = f.name
|
||||
|
||||
# catch if r2pipe can actually talk to radare2
|
||||
try:
|
||||
flags = ['-2', # close stderr to hide warnings
|
||||
'-e io.cache=true', # fix relocations in disassembly
|
||||
'-e scr.color=1', # 16-colour ANSI terminal output
|
||||
'-e asm.dwarf=true', # DWARF decode
|
||||
'-e scr.utf8=true', # UTF8 output looks better
|
||||
]
|
||||
r = r2pipe.open(obj_to_analyse, flags=flags)
|
||||
r.cmd('aaaaaa') # analyse as much as possible
|
||||
# If the elf is linked then it's necessary to seek as the
|
||||
# DSO ctor/dtor is at the default position
|
||||
if linked:
|
||||
# r2 only matches up to 61 chars?! found this by experiment!
|
||||
mangled_symbol_61char = mangled_symbol[:61]
|
||||
# switch off demangle, the seek is on a mangled symbol
|
||||
r.cmd('e bin.demangle=false')
|
||||
# seek to the mangled symbol address
|
||||
r.cmd(f's `is~ {mangled_symbol_61char}[1]`')
|
||||
# switch demangling back on for output purposes
|
||||
r.cmd('e bin.demangle=true')
|
||||
data = r.cmd('%s' % cmd) # print graph
|
||||
r.quit()
|
||||
except Exception as e:
|
||||
if "radare2 in PATH" in str(e):
|
||||
msg = ("This feature requires 'radare2' to be "
|
||||
"installed and available on the system see: "
|
||||
"https://github.com/radareorg/radare2. "
|
||||
"Cannot find 'radare2' in $PATH.")
|
||||
raise RuntimeError(msg)
|
||||
else:
|
||||
raise e
|
||||
return data
|
||||
|
||||
class DisasmCFG(object):
|
||||
|
||||
def _repr_svg_(self):
|
||||
try:
|
||||
import graphviz
|
||||
except ImportError:
|
||||
raise RuntimeError("graphviz package needed for disasm CFG")
|
||||
jupyter_rendering = get_rendering(cmd='agfd')
|
||||
# this just makes it read slightly better in jupyter notebooks
|
||||
jupyter_rendering = jupyter_rendering.replace('fontname="Courier",',
|
||||
'fontname="Courier",fontsize=6,')
|
||||
src = graphviz.Source(jupyter_rendering)
|
||||
return src.pipe('svg').decode('UTF-8')
|
||||
|
||||
def __repr__(self):
|
||||
return get_rendering(cmd='agf')
|
||||
|
||||
return DisasmCFG()
|
||||
@@ -0,0 +1,24 @@
|
||||
from numba.core.extending import overload
|
||||
from numba.core import types
|
||||
from numba.misc.special import literally, literal_unroll
|
||||
from numba.core.errors import TypingError
|
||||
|
||||
|
||||
@overload(literally)
|
||||
def _ov_literally(obj):
|
||||
if isinstance(obj, (types.Literal, types.InitialValue)):
|
||||
return lambda obj: obj
|
||||
else:
|
||||
m = "Invalid use of non-Literal type in literally({})".format(obj)
|
||||
raise TypingError(m)
|
||||
|
||||
|
||||
@overload(literal_unroll)
|
||||
def literal_unroll_impl(container):
|
||||
if isinstance(container, types.Poison):
|
||||
m = f"Invalid use of non-Literal type in literal_unroll({container})"
|
||||
raise TypingError(m)
|
||||
|
||||
def impl(container):
|
||||
return container
|
||||
return impl
|
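# Sketch of a typical call site for these overloads (requires a working
# Numba install; ``literally`` forces literal dispatch on an argument):
#
#   >>> from numba import njit
#   >>> from numba.misc.special import literally
#   >>> @njit
#   ... def power(x, n):
#   ...     return x ** literally(n)
#   >>> power(2, 3)
#   8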
||||
@@ -0,0 +1,471 @@
|
||||
import re
|
||||
import operator
|
||||
import heapq
|
||||
from collections import namedtuple
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from functools import cached_property
|
||||
|
||||
from numba.core import config
|
||||
|
||||
import llvmlite.binding as llvm
|
||||
|
||||
|
||||
class RecordLLVMPassTimingsLegacy:
|
||||
"""A helper context manager to track LLVM pass timings.
|
||||
"""
|
||||
|
||||
__slots__ = ["_data"]
|
||||
|
||||
def __enter__(self):
|
||||
"""Enables the pass timing in LLVM.
|
||||
"""
|
||||
llvm.set_time_passes(True)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Reset timings and save report internally.
|
||||
"""
|
||||
self._data = llvm.report_and_reset_timings()
|
||||
llvm.set_time_passes(False)
|
||||
return
|
||||
|
||||
def get(self):
|
||||
"""Retrieve timing data for processing.
|
||||
|
||||
Returns
|
||||
-------
|
||||
timings: ProcessedPassTimings
|
||||
"""
|
||||
return ProcessedPassTimings(self._data)
|
||||
|
||||
|
||||
class RecordLLVMPassTimings:
|
||||
"""A helper context manager to track LLVM pass timings.
|
||||
"""
|
||||
|
||||
__slots__ = ["_data", "_pb"]
|
||||
|
||||
def __init__(self, pb):
|
||||
self._pb = pb
|
||||
self._data = None
|
||||
|
||||
def __enter__(self):
|
||||
"""Enables the pass timing in LLVM.
|
||||
"""
|
||||
self._pb.start_pass_timing()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Reset timings and save report internally.
|
||||
"""
|
||||
self._data = self._pb.finish_pass_timing()
|
||||
return
|
||||
|
||||
def get(self):
|
||||
"""Retrieve timing data for processing.
|
||||
|
||||
Returns
|
||||
-------
|
||||
timings: ProcessedPassTimings
|
||||
"""
|
||||
return ProcessedPassTimings(self._data)
|
||||
|
||||
|
||||
PassTimingRecord = namedtuple(
|
||||
"PassTimingRecord",
|
||||
[
|
||||
"user_time",
|
||||
"user_percent",
|
||||
"system_time",
|
||||
"system_percent",
|
||||
"user_system_time",
|
||||
"user_system_percent",
|
||||
"wall_time",
|
||||
"wall_percent",
|
||||
"pass_name",
|
||||
"instruction",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _adjust_timings(records):
|
||||
"""Adjust timing records because of truncated information.
|
||||
|
||||
Details: The percent information can be used to improve the timing
|
||||
information.
|
||||
|
||||
Returns
|
||||
-------
|
||||
res: List[PassTimingRecord]
|
||||
"""
|
||||
total_rec = records[-1]
|
||||
assert total_rec.pass_name == "Total" # guard for implementation error
|
||||
|
||||
def make_adjuster(attr):
|
||||
time_attr = f"{attr}_time"
|
||||
percent_attr = f"{attr}_percent"
|
||||
time_getter = operator.attrgetter(time_attr)
|
||||
|
||||
def adjust(d):
|
||||
"""Compute percent x total_time = adjusted"""
|
||||
total = time_getter(total_rec)
|
||||
adjusted = total * d[percent_attr] * 0.01
|
||||
d[time_attr] = adjusted
|
||||
return d
|
||||
|
||||
return adjust
|
||||
|
||||
# Make adjustment functions for each field
|
||||
adj_fns = [
|
||||
make_adjuster(x) for x in ["user", "system", "user_system", "wall"]
|
||||
]
|
||||
|
||||
# Extract dictionaries from the namedtuples
|
||||
dicts = map(lambda x: x._asdict(), records)
|
||||
|
||||
def chained(d):
|
||||
# Chain the adjustment functions
|
||||
for fn in adj_fns:
|
||||
d = fn(d)
|
||||
# Reconstruct the namedtuple
|
||||
return PassTimingRecord(**d)
|
||||
|
||||
return list(map(chained, dicts))
|
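# Worked example of the adjustment: if the "Total" record reports 2.0s of
# wall time and a pass row shows wall_percent == 25.0, its wall_time is
# recomputed as 2.0 * 25.0 * 0.01 == 0.5s, recovering precision that the
# fixed-width report truncated.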
||||
|
||||
|
||||
class ProcessedPassTimings:
|
||||
"""A class for processing raw timing report from LLVM.
|
||||
|
||||
The processing is done lazily so we don't waste time processing unused
|
||||
timing information.
|
||||
"""
|
||||
|
||||
def __init__(self, raw_data):
|
||||
self._raw_data = raw_data
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self._raw_data)
|
||||
|
||||
def get_raw_data(self):
|
||||
"""Returns the raw string data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
res: str
|
||||
"""
|
||||
return self._raw_data
|
||||
|
||||
def get_total_time(self):
|
||||
"""Compute the total time spend in all passes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
res: float
|
||||
"""
|
||||
return self.list_records()[-1].wall_time
|
||||
|
||||
def list_records(self):
|
||||
"""Get the processed data for the timing report.
|
||||
|
||||
Returns
|
||||
-------
|
||||
res: List[PassTimingRecord]
|
||||
"""
|
||||
return self._processed
|
||||
|
||||
def list_top(self, n):
|
||||
"""Returns the top(n) most time-consuming (by wall-time) passes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n: int
|
||||
This limits the maximum number of items to show.
|
||||
This function will show the ``n`` most time-consuming passes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
res: List[PassTimingRecord]
|
||||
Returns the top(n) most time-consuming passes in descending order.
|
||||
"""
|
||||
records = self.list_records()
|
||||
key = operator.attrgetter("wall_time")
|
||||
return heapq.nlargest(n, records[:-1], key)
|
||||
|
||||
def summary(self, topn=5, indent=0):
|
||||
"""Return a string summarizing the timing information.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
topn: int; optional
|
||||
This limits the maximum number of items to show.
|
||||
This function will show the ``topn`` most time-consuming passes.
|
||||
indent: int; optional
|
||||
Set the indentation level. Defaults to 0 for no indentation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
res: str
|
||||
"""
|
||||
buf = []
|
||||
prefix = " " * indent
|
||||
|
||||
def ap(arg):
|
||||
buf.append(f"{prefix}{arg}")
|
||||
|
||||
ap(f"Total {self.get_total_time():.4f}s")
|
||||
ap("Top timings:")
|
||||
for p in self.list_top(topn):
|
||||
ap(f" {p.wall_time:.4f}s ({p.wall_percent:5}%) {p.pass_name}")
|
||||
return "\n".join(buf)
|
||||
|
||||
@cached_property
|
||||
def _processed(self):
|
||||
"""A cached property for lazily processing the data and returning it.
|
||||
|
||||
See ``_process()`` for details.
|
||||
"""
|
||||
return self._process()
|
||||
|
||||
def _process(self):
|
||||
"""Parses the raw string data from LLVM timing report and attempts
|
||||
to improve the data by recomputing the times
|
||||
(see ``_adjust_timings()``).
|
||||
"""
|
||||
|
||||
def parse(raw_data):
|
||||
"""A generator that parses the raw_data line-by-line to extract
|
||||
timing information for each pass.
|
||||
"""
|
||||
lines = raw_data.splitlines()
|
||||
colheader = r"[a-zA-Z+ ]+"
|
||||
# Take at least one column header.
|
||||
multicolheaders = fr"(?:\s*-+{colheader}-+)+"
|
||||
|
||||
line_iter = iter(lines)
|
||||
# find column headers
|
||||
header_map = {
|
||||
"User Time": "user",
|
||||
"System Time": "system",
|
||||
"User+System": "user_system",
|
||||
"Wall Time": "wall",
|
||||
"Instr": "instruction",
|
||||
"Name": "pass_name",
|
||||
}
|
||||
for ln in line_iter:
|
||||
m = re.match(multicolheaders, ln)
|
||||
if m:
|
||||
# Get all the column headers
|
||||
raw_headers = re.findall(r"[a-zA-Z][a-zA-Z+ ]+", ln)
|
||||
headers = [header_map[k.strip()] for k in raw_headers]
|
||||
break
|
||||
|
||||
assert headers[-1] == 'pass_name'
|
||||
# compute the list of available attributes from the column headers
|
||||
attrs = []
|
||||
n = r"\s*((?:[0-9]+\.)?[0-9]+)"
|
||||
pat = ""
|
||||
for k in headers[:-1]:
|
||||
if k == "instruction":
|
||||
pat += n
|
||||
else:
|
||||
attrs.append(f"{k}_time")
|
||||
attrs.append(f"{k}_percent")
|
||||
pat += rf"\s+(?:{n}\s*\({n}%\)|-+)"
|
||||
|
||||
# put default value 0.0 to all missing attributes
|
||||
missing = {}
|
||||
for k in PassTimingRecord._fields:
|
||||
if k not in attrs and k != 'pass_name':
|
||||
missing[k] = 0.0
|
||||
# parse timings
|
||||
pat += r"\s*(.*)"
|
||||
for ln in line_iter:
|
||||
m = re.match(pat, ln)
|
||||
if m is not None:
|
||||
raw_data = list(m.groups())
|
||||
data = {k: float(v) if v is not None else 0.0
|
||||
for k, v in zip(attrs, raw_data)}
|
||||
data.update(missing)
|
||||
pass_name = raw_data[-1]
|
||||
rec = PassTimingRecord(
|
||||
pass_name=pass_name, **data,
|
||||
)
|
||||
yield rec
|
||||
if rec.pass_name == "Total":
|
||||
# "Total" means the report has ended
|
||||
break
|
||||
# Check that we have reached the end of the report
|
||||
remaining = '\n'.join(line_iter)
|
||||
|
||||
# FIXME: Need to handle parsing of Analysis execution timing report
|
||||
if "Analysis execution timing report" in remaining:
|
||||
return
|
||||
|
||||
if remaining:
|
||||
raise ValueError(
|
||||
f"unexpected text after parser finished:\n{remaining}"
|
||||
)
|
||||
|
||||
# Parse raw data
|
||||
records = list(parse(self._raw_data))
|
||||
return _adjust_timings(records)
|
||||
|
||||
|
||||
NamedTimings = namedtuple("NamedTimings", ["name", "timings"])
|
||||
|
||||
|
||||
class PassTimingsCollection(Sequence):
|
||||
"""A collection of pass timings.
|
||||
|
||||
This class implements the ``Sequence`` protocol for accessing the
|
||||
individual timing records.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
self._name = name
|
||||
self._records = []
|
||||
|
||||
@contextmanager
|
||||
def record_legacy(self, name):
|
||||
"""Record new timings and append to this collection.
|
||||
|
||||
Note: this is mainly for internal use inside the compiler pipeline.
|
||||
|
||||
See also ``RecordLLVMPassTimingsLegacy``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
Name for the records.
|
||||
"""
|
||||
if config.LLVM_PASS_TIMINGS:
|
||||
# Recording of pass timings is enabled
|
||||
with RecordLLVMPassTimingsLegacy() as timings:
|
||||
yield
|
||||
rec = timings.get()
|
||||
# Only keep non-empty records
|
||||
if rec:
|
||||
self._append(name, rec)
|
||||
else:
|
||||
# Do nothing. Recording of pass timings is disabled.
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def record(self, name, pb):
|
||||
"""Record new timings and append to this collection.
|
||||
|
||||
Note: this is mainly for internal use inside the compiler pipeline.
|
||||
|
||||
See also ``RecordLLVMPassTimings``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
Name for the records.
|
||||
"""
|
||||
if config.LLVM_PASS_TIMINGS:
|
||||
# Recording of pass timings is enabled
|
||||
with RecordLLVMPassTimings(pb) as timings:
|
||||
yield
|
||||
rec = timings.get()
|
||||
# Only keep non-empty records
|
||||
if rec:
|
||||
self._append(name, rec)
|
||||
else:
|
||||
# Do nothing. Recording of pass timings is disabled.
|
||||
yield
|
||||
|
||||
def _append(self, name, timings):
|
||||
"""Append timing records
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
Name for the records.
|
||||
timings: ProcessedPassTimings
|
||||
the timing records.
|
||||
"""
|
||||
self._records.append(NamedTimings(name, timings))
|
||||
|
||||
def get_total_time(self):
|
||||
"""Computes the sum of the total time across all contained timings.
|
||||
|
||||
Returns
|
||||
-------
|
||||
res: float or None
|
||||
Returns the total number of seconds or None if no timings were
|
||||
recorded
|
||||
"""
|
||||
if self._records:
|
||||
return sum(r.timings.get_total_time() for r in self._records)
|
||||
else:
|
||||
return None
|
||||
|
||||
def list_longest_first(self):
|
||||
"""Returns the timings in descending order of total time duration.
|
||||
|
||||
Returns
|
||||
-------
|
||||
res: List[ProcessedPassTimings]
|
||||
"""
|
||||
return sorted(self._records,
|
||||
key=lambda x: x.timings.get_total_time(),
|
||||
reverse=True)
|
||||
|
||||
@property
|
||||
def is_empty(self):
|
||||
"""
|
||||
"""
|
||||
return not self._records
|
||||
|
||||
def summary(self, topn=5):
|
||||
"""Return a string representing the summary of the timings.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
topn: int; optional, default=5.
|
||||
This limits the maximum number of items to show.
|
||||
This function will show the ``topn`` most time-consuming passes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
res: str
|
||||
|
||||
See also ``ProcessedPassTimings.summary()``
|
||||
"""
|
||||
if self.is_empty:
|
||||
return "No pass timings were recorded"
|
||||
else:
|
||||
buf = []
|
||||
ap = buf.append
|
||||
ap(f"Printing pass timings for {self._name}")
|
||||
overall_time = self.get_total_time()
|
||||
ap(f"Total time: {overall_time:.4f}")
|
||||
for i, r in enumerate(self._records):
|
||||
ap(f"== #{i} {r.name}")
|
||||
percent = r.timings.get_total_time() / overall_time * 100
|
||||
ap(f" Percent: {percent:.1f}%")
|
||||
ap(r.timings.summary(topn=topn, indent=1))
|
||||
return "\n".join(buf)
|
||||
|
||||
def __getitem__(self, i):
|
||||
"""Get the i-th timing record.
|
||||
|
||||
Returns
|
||||
-------
|
||||
res: (name, timings)
|
||||
A named tuple with two fields:
|
||||
|
||||
- name: str
|
||||
- timings: ProcessedPassTimings
|
||||
"""
|
||||
return self._records[i]
|
||||
|
||||
def __len__(self):
|
||||
"""Length of this collection.
|
||||
"""
|
||||
return len(self._records)
|
||||
|
||||
def __str__(self):
|
||||
return self.summary()
|
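# Usage sketch (assumes pass timings are enabled so config.LLVM_PASS_TIMINGS
# is truthy; ``run_llvm_passes`` is a hypothetical stand-in):
#
#   ptc = PassTimingsCollection("my function")
#   with ptc.record_legacy("module passes"):
#       run_llvm_passes()
#   print(ptc.summary(topn=3))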
||||
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Memory monitoring utilities for measuring memory usage.
|
||||
|
||||
Example usage:
|
||||
tracker = MemoryTracker("my_function")
|
||||
with tracker.monitor():
|
||||
my_function()
|
||||
# Access data: tracker.rss_delta, tracker.duration, etc.
|
||||
# Get formatted string: tracker.get_summary()
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import contextlib
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
try:
|
||||
import psutil
|
||||
_HAS_PSUTIL = True
|
||||
except ImportError:
|
||||
_HAS_PSUTIL = False
|
||||
|
||||
|
||||
IS_SUPPORTED = _HAS_PSUTIL
|
||||
|
||||
|
||||
def get_available_memory() -> Optional[int]:
|
||||
"""
|
||||
Get current available system memory in bytes.
|
||||
|
||||
Used for memory threshold checking in parallel test execution.
|
||||
|
||||
Returns:
|
||||
int or None: Available memory in bytes, or None if unavailable
|
||||
"""
|
||||
if _HAS_PSUTIL:
|
||||
try:
|
||||
sys_mem = psutil.virtual_memory()
|
||||
return sys_mem.available
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_memory_usage() -> Dict[str, Optional[int]]:
|
||||
"""
|
||||
Get memory usage information needed for monitoring.
|
||||
|
||||
Returns only RSS and available memory which are the fields
|
||||
actually used by the MemoryTracker.
|
||||
|
||||
Returns:
|
||||
dict: Memory usage information including:
|
||||
- rss: Current process RSS (physical memory currently used)
|
||||
- available: Available system memory
|
||||
"""
|
||||
memory_info = {}
|
||||
|
||||
if _HAS_PSUTIL:
|
||||
try:
|
||||
# Get current process RSS
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
memory_info["rss"] = mem_info.rss
|
||||
|
||||
# Get system available memory
|
||||
sys_mem = psutil.virtual_memory()
|
||||
memory_info["available"] = sys_mem.available
|
||||
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
pass
|
||||
|
||||
# Set defaults if unavailable
|
||||
if "rss" not in memory_info:
|
||||
memory_info["rss"] = None
|
||||
if "available" not in memory_info:
|
||||
memory_info["available"] = None
|
||||
|
||||
return memory_info
|
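# Example return values (illustrative): with psutil available, something
# like {'rss': 157286400, 'available': 8053063680}; without psutil,
# {'rss': None, 'available': None}.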
||||
|
||||
|
||||
class MemoryTracker:
|
||||
"""
|
||||
A simple memory monitor that tracks RSS delta and timing.
|
||||
|
||||
Stores monitoring data in instance attributes for later access.
|
||||
Each instance is typically used for monitoring a single operation.
|
||||
"""
|
||||
pid: int
|
||||
name: str
|
||||
start_time: float | None
|
||||
end_time: float | None
|
||||
start_memory: Dict[str, int | None] | None
|
||||
end_memory: Dict[str, int | None] | None
|
||||
duration: float | None
|
||||
rss_delta: int | None
|
||||
|
||||
def __init__(self, name: str):
|
||||
"""Initialize a MemoryTracker with empty monitoring data."""
|
||||
self.pid = os.getpid()
|
||||
self.name = name
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
self.start_memory = None
|
||||
self.end_memory = None
|
||||
self.duration = None
|
||||
self.rss_delta = None
|
||||
|
||||
@contextlib.contextmanager
|
||||
def monitor(self):
|
||||
"""
|
||||
Context manager to monitor memory usage during function execution.
|
||||
|
||||
Records start/end memory usage and timing, calculates RSS delta,
|
||||
and stores all data in instance attributes.
|
||||
|
||||
The operation being monitored is identified by the name passed when
constructing this MemoryTracker; monitor() itself takes no arguments.
|
||||
|
||||
Yields:
|
||||
self: The MemoryTracker instance for accessing stored data
|
||||
"""
|
||||
# Store data in self and record start time and memory usage
|
||||
self.start_time = time.time()
|
||||
self.start_memory = get_memory_usage()
|
||||
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
# Record end time and memory usage
|
||||
self.end_time = time.time()
|
||||
self.end_memory = get_memory_usage()
|
||||
self.duration = self.end_time - self.start_time
|
||||
|
||||
# Calculate RSS delta
|
||||
start_rss = self.start_memory.get("rss", 0)
|
||||
end_rss = self.end_memory.get("rss", 0)
|
||||
self.rss_delta = ((end_rss - start_rss)
|
||||
if start_rss and end_rss else 0)
|
||||
|
||||
def get_summary(self) -> str:
|
||||
"""
|
||||
Return a formatted summary of the memory monitoring data.
|
||||
|
||||
Formats the stored monitoring data into a human-readable string
|
||||
containing name, PID, RSS delta, available memory, duration,
|
||||
and start time.
|
||||
|
||||
Returns:
|
||||
str: Formatted summary string with monitoring results
|
||||
|
||||
Note:
|
||||
Should be called after monitor() context has completed
|
||||
to ensure all data is available.
|
||||
"""
|
||||
if self.start_memory is None or self.end_memory is None:
|
||||
raise ValueError("Memory monitoring data not available")
|
||||
|
||||
current_available = self.end_memory.get("available")
|
||||
|
||||
def format_bytes(bytes_val, show_sign=False):
|
||||
"""Convert bytes to human readable format"""
|
||||
if bytes_val is None:
|
||||
return "N/A"
|
||||
if bytes_val == 0:
|
||||
return "0 B"
|
||||
|
||||
sign = ""
|
||||
if show_sign:
|
||||
sign = "-" if bytes_val < 0 else "+"
|
||||
bytes_val = abs(bytes_val)
|
||||
|
||||
for unit in ["B", "KB", "MB", "GB"]:
|
||||
if bytes_val < 1024.0:
|
||||
return f"{sign}{bytes_val:.2f} {unit}"
|
||||
bytes_val /= 1024.0
|
||||
return f"{sign}{bytes_val:.2f} TB"
|
||||
|
||||
start_ts = time.strftime("%H:%M:%S", time.localtime(self.start_time))
|
||||
start_rss = self.start_memory.get("rss", 0)
|
||||
end_rss = self.end_memory.get("rss", 0)
|
||||
|
||||
buf = [
|
||||
f"Name: {self.name}",
|
||||
f"PID: {self.pid}",
|
||||
f"Start: {start_ts}",
|
||||
f"Duration: {self.duration:.3f}s",
|
||||
f"Start RSS: {format_bytes(start_rss)}",
|
||||
f"End RSS: {format_bytes(end_rss)}",
|
||||
f"RSS delta: {format_bytes(self.rss_delta, show_sign=True)}",
|
||||
f"Avail memory: {format_bytes(current_available)}",
|
||||
]
|
||||
return ' | '.join(buf)
|
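# Example of the summary line this produces (values illustrative):
#
#   Name: my_function | PID: 12345 | Start: 10:42:07 | Duration: 1.234s |
#   Start RSS: 120.00 MB | End RSS: 150.00 MB | RSS delta: +30.00 MB |
#   Avail memory: 7.50 GB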
||||
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
The same algorithm, translated from numpy.
|
||||
See numpy/core/src/npysort/mergesort.c.src.
|
||||
The high-level numba code adds a little overhead compared to
the pure-C implementation in numpy.
|
||||
"""
|
||||
import numpy as np
|
||||
from collections import namedtuple
|
||||
|
||||
# Array size smaller than this will be sorted by insertion sort
|
||||
SMALL_MERGESORT = 20
|
||||
|
||||
|
||||
MergesortImplementation = namedtuple('MergesortImplementation', [
|
||||
'run_mergesort',
|
||||
])
|
||||
|
||||
|
||||
def make_mergesort_impl(wrap, lt=None, is_argsort=False):
|
||||
kwargs_lite = dict(no_cpython_wrapper=True, _nrt=False)
|
||||
|
||||
# The less than
|
||||
if lt is None:
|
||||
@wrap(**kwargs_lite)
|
||||
def lt(a, b):
|
||||
return a < b
|
||||
else:
|
||||
lt = wrap(**kwargs_lite)(lt)
|
||||
|
||||
if is_argsort:
|
||||
@wrap(**kwargs_lite)
|
||||
def lessthan(a, b, vals):
|
||||
return lt(vals[a], vals[b])
|
||||
else:
|
||||
@wrap(**kwargs_lite)
|
||||
def lessthan(a, b, vals):
|
||||
return lt(a, b)
|
||||
|
||||
@wrap(**kwargs_lite)
|
||||
def argmergesort_inner(arr, vals, ws):
|
||||
"""The actual mergesort function
|
||||
|
||||
Parameters
|
||||
----------
|
||||
arr : array [read+write]
|
||||
The values being sorted inplace. For argsort, this is the
|
||||
indices.
|
||||
vals : array [readonly]
|
||||
``None`` for normal sort. In argsort, this is the actual array values.
|
||||
ws : array [write]
|
||||
The workspace. Must be of size ``arr.size // 2``
|
||||
"""
|
||||
if arr.size > SMALL_MERGESORT:
|
||||
# Merge sort
|
||||
mid = arr.size // 2
|
||||
|
||||
argmergesort_inner(arr[:mid], vals, ws)
|
||||
argmergesort_inner(arr[mid:], vals, ws)
|
||||
|
||||
# Copy left half into workspace so we don't overwrite it
|
||||
for i in range(mid):
|
||||
ws[i] = arr[i]
|
||||
|
||||
# Merge
|
||||
left = ws[:mid]
|
||||
right = arr[mid:]
|
||||
out = arr
|
||||
|
||||
i = j = k = 0
|
||||
while i < left.size and j < right.size:
|
||||
if not lessthan(right[j], left[i], vals):
|
||||
out[k] = left[i]
|
||||
i += 1
|
||||
else:
|
||||
out[k] = right[j]
|
||||
j += 1
|
||||
k += 1
|
||||
|
||||
# Leftovers
|
||||
while i < left.size:
|
||||
out[k] = left[i]
|
||||
i += 1
|
||||
k += 1
|
||||
|
||||
while j < right.size:
|
||||
out[k] = right[j]
|
||||
j += 1
|
||||
k += 1
|
||||
else:
|
||||
# Insertion sort
|
||||
i = 1
|
||||
while i < arr.size:
|
||||
j = i
|
||||
while j > 0 and lessthan(arr[j], arr[j - 1], vals):
|
||||
arr[j - 1], arr[j] = arr[j], arr[j - 1]
|
||||
j -= 1
|
||||
i += 1
|
||||
|
||||
# The top-level entry points
|
||||
|
||||
@wrap(no_cpython_wrapper=True)
|
||||
def mergesort(arr):
|
||||
"Inplace"
|
||||
ws = np.empty(arr.size // 2, dtype=arr.dtype)
|
||||
argmergesort_inner(arr, None, ws)
|
||||
return arr
|
||||
|
||||
|
||||
@wrap(no_cpython_wrapper=True)
|
||||
def argmergesort(arr):
|
||||
"Out-of-place"
|
||||
idxs = np.arange(arr.size)
|
||||
ws = np.empty(arr.size // 2, dtype=idxs.dtype)
|
||||
argmergesort_inner(idxs, arr, ws)
|
||||
return idxs
|
||||
|
||||
return MergesortImplementation(
|
||||
run_mergesort=(argmergesort if is_argsort else mergesort)
|
||||
)
|
||||
|
||||
|
||||
def make_jit_mergesort(*args, **kwargs):
|
||||
from numba import njit
|
||||
# NOTE: wrap with njit to allow recursion
|
||||
# because @register_jitable => @overload doesn't support recursion
|
||||
return make_mergesort_impl(njit, *args, **kwargs)
|
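# Usage sketch (needs a working Numba install, since make_jit_mergesort
# wraps everything in @njit):
#
#   >>> import numpy as np
#   >>> ms = make_jit_mergesort()
#   >>> ms.run_mergesort(np.array([3, 1, 2]))
#   array([1, 2, 3])
#
# With is_argsort=True the returned run_mergesort yields sorting indices
# instead of sorting in place.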
||||
@@ -0,0 +1,72 @@
|
||||
import sys
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import json
|
||||
|
||||
from .numba_sysinfo import display_sysinfo, get_sysinfo
|
||||
from .numba_gdbinfo import display_gdbinfo
|
||||
|
||||
|
||||
def make_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--annotate', help='Annotate source',
|
||||
action='store_true')
|
||||
parser.add_argument('--dump-llvm', action="store_true",
|
||||
help='Print generated llvm assembly')
|
||||
parser.add_argument('--dump-optimized', action='store_true',
|
||||
help='Dump the optimized llvm assembly')
|
||||
parser.add_argument('--dump-assembly', action='store_true',
|
||||
help='Dump the LLVM generated assembly')
|
||||
parser.add_argument('--annotate-html', nargs=1,
|
||||
help='Output source annotation as html')
|
||||
parser.add_argument('-s', '--sysinfo', action="store_true",
|
||||
help='Output system information for bug reporting')
|
||||
parser.add_argument('-g', '--gdbinfo', action="store_true",
|
||||
help='Output system information about gdb')
|
||||
parser.add_argument('--sys-json', nargs=1,
|
||||
help='Saves the system info dict as a json file')
|
||||
parser.add_argument('filename', nargs='?', help='Python source filename')
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = make_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.sysinfo:
|
||||
print("System info:")
|
||||
display_sysinfo()
|
||||
|
||||
if args.gdbinfo:
|
||||
print("GDB info:")
|
||||
display_gdbinfo()
|
||||
|
||||
if args.sysinfo or args.gdbinfo:
|
||||
sys.exit(0)
|
||||
|
||||
if args.sys_json:
|
||||
info = get_sysinfo()
|
||||
info.update({'Start': info['Start'].isoformat()})
|
||||
info.update({'Start UTC': info['Start UTC'].isoformat()})
|
||||
with open(args.sys_json[0], 'w') as f:
|
||||
json.dump(info, f, indent=4)
|
||||
sys.exit(0)
|
||||
|
||||
os.environ['NUMBA_DUMP_ANNOTATION'] = str(int(args.annotate))
|
||||
if args.annotate_html is not None:
|
||||
try:
|
||||
from jinja2 import Template
|
||||
except ImportError:
|
||||
raise ImportError("Please install the 'jinja2' package")
|
||||
os.environ['NUMBA_DUMP_HTML'] = str(args.annotate_html[0])
|
||||
os.environ['NUMBA_DUMP_LLVM'] = str(int(args.dump_llvm))
|
||||
os.environ['NUMBA_DUMP_OPTIMIZED'] = str(int(args.dump_optimized))
|
||||
os.environ['NUMBA_DUMP_ASSEMBLY'] = str(int(args.dump_assembly))
|
||||
|
||||
if args.filename:
|
||||
cmd = [sys.executable, args.filename]
|
||||
subprocess.call(cmd)
|
||||
else:
|
||||
print("numba: error: the following arguments are required: filename")
|
||||
sys.exit(1)
|
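# Example invocations (based on the flags defined above; assumes the
# installed entry point exposes this module as the ``numba`` command):
#
#   $ numba --sysinfo
#   $ numba --sys-json sysinfo.json
#   $ numba --annotate --dump-llvm my_script.py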
||||
@@ -0,0 +1,161 @@
|
||||
"""Module for displaying information about Numba's gdb set up"""
|
||||
from collections import namedtuple
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from textwrap import dedent
|
||||
from numba import config
|
||||
|
||||
# Container for the output of the gdb info data collection
|
||||
_fields = ('binary_loc, extension_loc, py_ver, np_ver, supported')
|
||||
_gdb_info = namedtuple('_gdb_info', _fields)
|
||||
|
||||
|
||||
class _GDBTestWrapper():
|
||||
"""Wraps the gdb binary and has methods for checking what the gdb binary
|
||||
has support for (Python and NumPy)."""
|
||||
|
||||
def __init__(self):
|
||||
gdb_binary = config.GDB_BINARY
|
||||
if gdb_binary is None:
|
||||
msg = ("No valid binary could be found for gdb named: "
|
||||
f"{config.GDB_BINARY}")
|
||||
raise ValueError(msg)
|
||||
self._gdb_binary = gdb_binary
|
||||
|
||||
def _run_cmd(self, cmd=()):
|
||||
gdb_call = [self.gdb_binary, '-q',]
|
||||
for x in cmd:
|
||||
gdb_call.append('-ex')
|
||||
gdb_call.append(x)
|
||||
gdb_call.extend(['-ex', 'q'])
|
||||
return subprocess.run(gdb_call, capture_output=True, timeout=10,
|
||||
text=True)
|
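# For cmd=('python print(1)',) the call assembled above is, illustratively:
#
#   gdb -q -ex 'python print(1)' -ex q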
||||
|
||||
@property
|
||||
def gdb_binary(self):
|
||||
return self._gdb_binary
|
||||
|
||||
@classmethod
|
||||
def success(cls, status):
|
||||
return status.returncode == 0
|
||||
|
||||
def check_launch(self):
|
||||
"""Checks that gdb will launch ok"""
|
||||
return self._run_cmd()
|
||||
|
||||
def check_python(self):
|
||||
cmd = ("python from __future__ import print_function; "
|
||||
"import sys; print(sys.version_info[:2])")
|
||||
return self._run_cmd((cmd,))
|
||||
|
||||
def check_numpy(self):
|
||||
cmd = ("python from __future__ import print_function; "
|
||||
"import types; import numpy; "
|
||||
"print(isinstance(numpy, types.ModuleType))")
|
||||
return self._run_cmd((cmd,))
|
||||
|
||||
def check_numpy_version(self):
|
||||
cmd = ("python from __future__ import print_function; "
|
||||
"import types; import numpy;"
|
||||
"print(numpy.__version__)")
|
||||
return self._run_cmd((cmd,))
|
||||
|
||||
|
||||
def collect_gdbinfo():
|
||||
"""Prints information to stdout about the gdb setup that Numba has found"""
|
||||
|
||||
# State flags:
|
||||
gdb_state = None
|
||||
gdb_has_python = False
|
||||
gdb_has_numpy = False
|
||||
gdb_python_version = 'No Python support'
|
||||
gdb_python_numpy_version = "No NumPy support"
|
||||
|
||||
# There are so many ways for gdb to not be working as expected. Surround
|
||||
# the "is it working" tests with try/except and if there's an exception
|
||||
# store it for processing later.
|
||||
try:
|
||||
# Check gdb exists
|
||||
gdb_wrapper = _GDBTestWrapper()
|
||||
|
||||
# Check gdb works
|
||||
status = gdb_wrapper.check_launch()
|
||||
if not gdb_wrapper.success(status):
|
||||
msg = (f"gdb at '{gdb_wrapper.gdb_binary}' does not appear to work."
|
||||
f"\nstdout: {status.stdout}\nstderr: {status.stderr}")
|
||||
raise ValueError(msg)
|
||||
gdb_state = gdb_wrapper.gdb_binary
|
||||
except Exception as e:
|
||||
gdb_state = f"Testing gdb binary failed. Reported Error: {e}"
|
||||
else:
|
||||
# Got this far, so gdb works, start checking what it supports
|
||||
status = gdb_wrapper.check_python()
|
||||
if gdb_wrapper.success(status):
|
||||
version_match = re.match(r'\((\d+),\s+(\d+)\)',
|
||||
status.stdout.strip())
|
||||
if version_match is not None:
|
||||
pymajor, pyminor = version_match.groups()
|
||||
gdb_python_version = f"{pymajor}.{pyminor}"
|
||||
gdb_has_python = True
|
||||
|
||||
status = gdb_wrapper.check_numpy()
|
||||
if gdb_wrapper.success(status):
|
||||
if "Traceback" not in status.stderr.strip():
|
||||
if status.stdout.strip() == 'True':
|
||||
gdb_has_numpy = True
|
||||
gdb_python_numpy_version = "Unknown"
|
||||
# NumPy is present, find the version
|
||||
status = gdb_wrapper.check_numpy_version()
|
||||
if gdb_wrapper.success(status):
|
||||
if "Traceback" not in status.stderr.strip():
|
||||
gdb_python_numpy_version = \
|
||||
status.stdout.strip()
|
||||
|
||||
# Work out what level of print-extension support is present in this gdb
|
||||
if gdb_has_python:
|
||||
if gdb_has_numpy:
|
||||
print_ext_supported = "Full (Python and NumPy supported)"
|
||||
else:
|
||||
print_ext_supported = "Partial (Python only, no NumPy support)"
|
||||
else:
|
||||
print_ext_supported = "None"
|
||||
|
||||
# Work out print ext location
|
||||
print_ext_file = "gdb_print_extension.py"
|
||||
print_ext_path = os.path.join(os.path.dirname(__file__), print_ext_file)
|
||||
|
||||
# return!
|
||||
return _gdb_info(gdb_state, print_ext_path, gdb_python_version,
|
||||
gdb_python_numpy_version, print_ext_supported)
|
||||
|
||||
|
||||
def display_gdbinfo(sep_pos=45):
|
||||
"""Displays the information collected by collect_gdbinfo.
|
||||
"""
|
||||
gdb_info = collect_gdbinfo()
|
||||
print('-' * 80)
|
||||
fmt = f'%-{sep_pos}s : %-s'
|
||||
# Display the information
|
||||
print(fmt % ("Binary location", gdb_info.binary_loc))
|
||||
print(fmt % ("Print extension location", gdb_info.extension_loc))
|
||||
print(fmt % ("Python version", gdb_info.py_ver))
|
||||
print(fmt % ("NumPy version", gdb_info.np_ver))
|
||||
print(fmt % ("Numba printing extension support", gdb_info.supported))
|
||||
|
||||
print("")
|
||||
print("To load the Numba gdb printing extension, execute the following "
|
||||
"from the gdb prompt:")
|
||||
print(f"\nsource {gdb_info.extension_loc}\n")
|
||||
print('-' * 80)
|
||||
warn = """
|
||||
=============================================================
|
||||
IMPORTANT: Before sharing, you should remove any information
|
||||
in the above that you wish to keep private, e.g. paths.
|
||||
=============================================================
|
||||
"""
|
||||
print(dedent(warn))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
display_gdbinfo()
|
||||
@@ -0,0 +1,706 @@
|
||||
import json
|
||||
import locale
|
||||
import multiprocessing
|
||||
import os
|
||||
import platform
|
||||
import textwrap
|
||||
import sys
|
||||
from contextlib import redirect_stdout
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
from subprocess import check_output, PIPE, CalledProcessError
|
||||
import numpy as np
|
||||
import llvmlite.binding as llvmbind
|
||||
from llvmlite import __version__ as llvmlite_version
|
||||
from numba import cuda as cu, __version__ as version_number
|
||||
from numba.cuda import cudadrv
|
||||
from numba.cuda.cudadrv.driver import driver as cudriver
|
||||
from numba.cuda.cudadrv.runtime import runtime as curuntime
|
||||
from numba.core import config
|
||||
|
||||
_psutil_import = False
|
||||
try:
|
||||
import psutil
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
_psutil_import = True
|
||||
|
||||
__all__ = ['get_sysinfo', 'display_sysinfo']
|
||||
|
||||
# Keys of a `sysinfo` dictionary
|
||||
|
||||
# Time info
|
||||
_start, _start_utc, _runtime = 'Start', 'Start UTC', 'Runtime'
|
||||
_numba_version = 'Numba Version'
|
||||
# Hardware info
|
||||
_machine = 'Machine'
|
||||
_cpu_name, _cpu_count = 'CPU Name', 'CPU Count'
|
||||
_cpus_allowed, _cpus_list = 'CPUs Allowed', 'List CPUs Allowed'
|
||||
_cpu_features = 'CPU Features'
|
||||
_cfs_quota, _cfs_period = 'CFS Quota', 'CFS Period',
|
||||
_cfs_restrict = 'CFS Restriction'
|
||||
_mem_total, _mem_available = 'Mem Total', 'Mem Available'
|
||||
# OS info
|
||||
_platform_name, _platform_release = 'Platform Name', 'Platform Release'
|
||||
_os_name, _os_version = 'OS Name', 'OS Version'
|
||||
_os_spec_version = 'OS Specific Version'
|
||||
_libc_version = 'Libc Version'
|
||||
# Python info
|
||||
_python_comp = 'Python Compiler'
|
||||
_python_impl = 'Python Implementation'
|
||||
_python_version = 'Python Version'
|
||||
_python_locale = 'Python Locale'
|
||||
# LLVM info
|
||||
_llvmlite_version = 'llvmlite Version'
|
||||
_llvm_version = 'LLVM Version'
|
||||
# CUDA info
|
||||
_cu_target_impl = 'CUDA Target Impl'
|
||||
_cu_dev_init = 'CUDA Device Init'
|
||||
_cu_drv_ver = 'CUDA Driver Version'
|
||||
_cu_rt_ver = 'CUDA Runtime Version'
|
||||
_cu_nvidia_bindings = 'NVIDIA CUDA Bindings'
|
||||
_cu_nvidia_bindings_used = 'NVIDIA CUDA Bindings In Use'
|
||||
_cu_detect_out, _cu_lib_test = 'CUDA Detect Output', 'CUDA Lib Test'
|
||||
_cu_mvc_available = 'NVIDIA CUDA Minor Version Compatibility Available'
|
||||
_cu_mvc_needed = 'NVIDIA CUDA Minor Version Compatibility Needed'
|
||||
_cu_mvc_in_use = 'NVIDIA CUDA Minor Version Compatibility In Use'
|
||||
# NumPy info
|
||||
_numpy_version = 'NumPy Version'
|
||||
_numpy_supported_simd_features = 'NumPy Supported SIMD features'
|
||||
_numpy_supported_simd_dispatch = 'NumPy Supported SIMD dispatch'
|
||||
_numpy_supported_simd_baseline = 'NumPy Supported SIMD baseline'
|
||||
_numpy_AVX512_SKX_detected = 'NumPy AVX512_SKX detected'
|
||||
# SVML info
|
||||
_svml_state, _svml_loaded = 'SVML State', 'SVML Lib Loaded'
|
||||
_llvm_svml_patched = 'LLVM SVML Patched'
|
||||
_svml_operational = 'SVML Operational'
|
||||
# Threading layer info
|
||||
_tbb_thread, _tbb_error = 'TBB Threading', 'TBB Threading Error'
|
||||
_openmp_thread, _openmp_error = 'OpenMP Threading', 'OpenMP Threading Error'
|
||||
_openmp_vendor = 'OpenMP vendor'
|
||||
_wkq_thread, _wkq_error = 'Workqueue Threading', 'Workqueue Threading Error'
|
||||
# Numba info
|
||||
_numba_env_vars = 'Numba Env Vars'
|
||||
# Conda info
|
||||
_conda_build_ver, _conda_env_ver = 'Conda Build', 'Conda Env'
|
||||
_conda_platform, _conda_python_ver = 'Conda Platform', 'Conda Python Version'
|
||||
_conda_root_writable = 'Conda Root Writable'
|
||||
# Packages info
|
||||
_inst_pkg = 'Installed Packages'
|
||||
# Psutil info
|
||||
_psutil = 'Psutil Available'
|
||||
# Errors and warnings
|
||||
_errors = 'Errors'
|
||||
_warnings = 'Warnings'
|
||||
|
||||
# Error and warning log
|
||||
_error_log = []
|
||||
_warning_log = []
|
||||
|
||||
|
||||
def get_os_spec_info(os_name):
|
||||
# Linux man page for `/proc`:
|
||||
# http://man7.org/linux/man-pages/man5/proc.5.html
|
||||
|
||||
# WMIC is deprecated in windows-2025, using powershell instead
|
||||
# https://techcommunity.microsoft.com/t5/windows-it-pro-blog/wmi-command-line-wmic-utility-deprecation-next-steps/ba-p/4039242
|
||||
# Windows documentation for `powershell`:
|
||||
# https://learn.microsoft.com/en-us/powershell/
|
||||
|
||||
# MacOS man page for `sysctl`:
|
||||
# https://www.unix.com/man-page/osx/3/sysctl/
|
||||
# MacOS man page for `vm_stat`:
|
||||
# https://www.unix.com/man-page/osx/1/vm_stat/
|
||||
|
||||
class CmdBufferOut(tuple):
|
||||
buffer_output_flag = True
|
||||
|
||||
class CmdReadFile(tuple):
|
||||
read_file_flag = True
|
||||
|
||||
shell_params = {
|
||||
'Linux': {
|
||||
'cmd': (
|
||||
CmdReadFile(('/sys/fs/cgroup/cpuacct/cpu.cfs_quota_us',)),
|
||||
CmdReadFile(('/sys/fs/cgroup/cpuacct/cpu.cfs_period_us',)),
|
||||
),
|
||||
'cmd_optional': (
|
||||
CmdReadFile(('/proc/meminfo',)),
|
||||
CmdReadFile(('/proc/self/status',)),
|
||||
),
|
||||
'kwds': {
|
||||
# output string fragment -> result dict key
|
||||
'MemTotal:': _mem_total,
|
||||
'MemAvailable:': _mem_available,
|
||||
'Cpus_allowed:': _cpus_allowed,
|
||||
'Cpus_allowed_list:': _cpus_list,
|
||||
'/sys/fs/cgroup/cpuacct/cpu.cfs_quota_us': _cfs_quota,
|
||||
'/sys/fs/cgroup/cpuacct/cpu.cfs_period_us': _cfs_period,
|
||||
},
|
||||
},
|
||||
'Windows': {
|
||||
'cmd': (),
|
||||
'cmd_optional': (
|
||||
CmdBufferOut(('powershell', '-NoProfile', '-Command',
|
||||
"'TotalVirtualMemorySize ' + "
|
||||
"(Get-CimInstance -ClassName "
|
||||
"Win32_OperatingSystem).TotalVirtualMemorySize")),
|
||||
CmdBufferOut(('powershell', '-NoProfile', '-Command',
|
||||
"'FreeVirtualMemory ' + "
|
||||
"(Get-CimInstance -ClassName "
|
||||
"Win32_OperatingSystem).FreeVirtualMemory")),
|
||||
),
|
||||
'kwds': {
|
||||
# output string fragment -> result dict key
|
||||
'TotalVirtualMemorySize': _mem_total,
|
||||
'FreeVirtualMemory': _mem_available,
|
||||
},
|
||||
},
|
||||
'Darwin': {
|
||||
'cmd': (),
|
||||
'cmd_optional': (
|
||||
('sysctl', 'hw.memsize'),
|
||||
('vm_stat',),
|
||||
            ),
            'kwds': {
                # output string fragment -> result dict key
                'hw.memsize:': _mem_total,
                'free:': _mem_available,
            },
            'units': {
                _mem_total: 1,  # Size is given in bytes.
                _mem_available: 4096,  # Size is given in 4kB pages.
            },
        },
    }

    os_spec_info = {}
    params = shell_params.get(os_name, {})
    cmd_selected = params.get('cmd', ())

    if _psutil_import:
        vm = psutil.virtual_memory()
        os_spec_info.update({
            _mem_total: vm.total,
            _mem_available: vm.available,
        })
        p = psutil.Process()
        cpus_allowed = p.cpu_affinity() if hasattr(p, 'cpu_affinity') else []
        if cpus_allowed:
            os_spec_info[_cpus_allowed] = len(cpus_allowed)
            os_spec_info[_cpus_list] = ' '.join(str(n) for n in cpus_allowed)

    else:
        _warning_log.append(
            "Warning (psutil): psutil cannot be imported. "
            "For more accuracy, consider installing it.")
        # Fallback to internal heuristics
        cmd_selected += params.get('cmd_optional', ())

    # Assuming the shell cmd returns a unique (k, v) pair per line
    # or a unique (k, v) pair spread over several lines:
    # Gather output in a list of strings containing a keyword and some value.
    output = []
    for cmd in cmd_selected:
        if hasattr(cmd, 'read_file_flag'):
            # Open file within Python
            if os.path.exists(cmd[0]):
                try:
                    with open(cmd[0], 'r') as f:
                        out = f.readlines()
                        if out:
                            out[0] = ' '.join((cmd[0], out[0]))
                        output.extend(out)
                except OSError as e:
                    _error_log.append(f'Error (file read): {e}')
                    continue
            else:
                _warning_log.append('Warning (no file): {}'.format(cmd[0]))
                continue
        else:
            # Spawn a subprocess
            try:
                out = check_output(cmd, stderr=PIPE)
            except (OSError, CalledProcessError) as e:
                _error_log.append(f'Error (subprocess): {e}')
                continue
        if hasattr(cmd, 'buffer_output_flag'):
            out = b' '.join(line for line in out.splitlines()) + b'\n'
        output.extend(out.decode().splitlines())

    # Extract (k, output) pairs by searching for keywords in output
    kwds = params.get('kwds', {})
    for line in output:
        match = kwds.keys() & line.split()
        if match and len(match) == 1:
            k = kwds[match.pop()]
            os_spec_info[k] = line
        elif len(match) > 1:
            print(f'Ambiguous output: {line}')

    # Try to extract something meaningful from output string
    def format():
        # CFS restrictions
        split = os_spec_info.get(_cfs_quota, '').split()
        if split:
            os_spec_info[_cfs_quota] = float(split[-1])
        split = os_spec_info.get(_cfs_period, '').split()
        if split:
            os_spec_info[_cfs_period] = float(split[-1])
        if os_spec_info.get(_cfs_quota, -1) != -1:
            cfs_quota = os_spec_info.get(_cfs_quota, '')
            cfs_period = os_spec_info.get(_cfs_period, '')
            runtime_amount = cfs_quota / cfs_period
            os_spec_info[_cfs_restrict] = runtime_amount

    def format_optional():
        # Memory
        units = {_mem_total: 1024, _mem_available: 1024}
        units.update(params.get('units', {}))
        for k in (_mem_total, _mem_available):
            digits = ''.join(d for d in os_spec_info.get(k, '') if d.isdigit())
            os_spec_info[k] = int(digits or 0) * units[k]
        # Accessible CPUs
        split = os_spec_info.get(_cpus_allowed, '').split()
        if split:
            n = split[-1]
            n = n.split(',')[-1]
            os_spec_info[_cpus_allowed] = str(bin(int(n or 0, 16))).count('1')
        split = os_spec_info.get(_cpus_list, '').split()
        if split:
            os_spec_info[_cpus_list] = split[-1]

    try:
        format()
        if not _psutil_import:
            format_optional()
    except Exception as e:
        _error_log.append(f'Error (format shell output): {e}')

    # Call OS specific functions
    os_specific_funcs = {
        'Linux': {
            _libc_version: lambda: ' '.join(platform.libc_ver())
        },
        'Windows': {
            _os_spec_version: lambda: ' '.join(
                s for s in platform.win32_ver()),
        },
        'Darwin': {
            _os_spec_version: lambda: ''.join(
                i or ' ' for s in tuple(platform.mac_ver()) for i in s),
        },
    }
    key_func = os_specific_funcs.get(os_name, {})
    os_spec_info.update({k: f() for k, f in key_func.items()})
    return os_spec_info


def get_sysinfo():

    # Gather the information that shouldn't raise exceptions
    sys_info = {
        _start: datetime.now(),
        _start_utc: datetime.utcnow(),
        _machine: platform.machine(),
        _cpu_name: llvmbind.get_host_cpu_name(),
        _cpu_count: multiprocessing.cpu_count(),
        _platform_name: platform.platform(aliased=True),
        _platform_release: platform.release(),
        _os_name: platform.system(),
        _os_version: platform.version(),
        _python_comp: platform.python_compiler(),
        _python_impl: platform.python_implementation(),
        _python_version: platform.python_version(),
        _numba_env_vars: {k: v for (k, v) in os.environ.items()
                          if k.startswith('NUMBA_')},
        _numba_version: version_number,
        _llvm_version: '.'.join(str(i) for i in llvmbind.llvm_version_info),
        _llvmlite_version: llvmlite_version,
        _psutil: _psutil_import,
    }

    # CPU features
    try:
        feature_map = llvmbind.get_host_cpu_features()
    except RuntimeError as e:
        _error_log.append(f'Error (CPU features): {e}')
    else:
        features = sorted([key for key, value in feature_map.items() if value])
        sys_info[_cpu_features] = ' '.join(features)

    # Python locale
    # On MacOSX, getdefaultlocale can raise. Check again if Py > 3.7.5
    try:
        # If $LANG is unset, getdefaultlocale() can return (None, None), make
        # sure we can encode this as strings by casting explicitly.
        sys_info[_python_locale] = '.'.join([str(i) for i in
                                             locale.getdefaultlocale()])
    except Exception as e:
        _error_log.append(f'Error (locale): {e}')

    # CUDA information
    try:
        sys_info[_cu_target_impl] = cu.implementation
    except AttributeError:
        # On the offchance an out-of-tree target did not set the
        # implementation, we can try to continue
        pass

    try:
        cu.list_devices()[0]  # will a device initialise?
    except Exception as e:
        sys_info[_cu_dev_init] = False
        msg_not_found = "CUDA driver library cannot be found"
        msg_disabled_by_user = "CUDA is disabled"
        msg_end = " or no CUDA enabled devices are present."
        msg_generic_problem = "CUDA device initialisation problem."
        msg = getattr(e, 'msg', None)
        if msg is not None:
            if msg_not_found in msg:
                err_msg = msg_not_found + msg_end
            elif msg_disabled_by_user in msg:
                err_msg = msg_disabled_by_user + msg_end
            else:
                err_msg = msg_generic_problem + " Message:" + msg
        else:
            err_msg = msg_generic_problem + " " + str(e)
        # Best effort error report
        _warning_log.append("Warning (cuda): %s\nException class: %s" %
                            (err_msg, str(type(e))))
    else:
        try:
            sys_info[_cu_dev_init] = True

            output = StringIO()
            with redirect_stdout(output):
                cu.detect()
            sys_info[_cu_detect_out] = output.getvalue()
            output.close()

            cu_drv_ver = cudriver.get_version()
            cu_rt_ver = curuntime.get_version()
            sys_info[_cu_drv_ver] = '%s.%s' % cu_drv_ver
            sys_info[_cu_rt_ver] = '%s.%s' % cu_rt_ver

            output = StringIO()
            with redirect_stdout(output):
                cudadrv.libs.test()
            sys_info[_cu_lib_test] = output.getvalue()
            output.close()

            try:
                from cuda import cuda  # noqa: F401
                nvidia_bindings_available = True
            except ImportError:
                nvidia_bindings_available = False
            sys_info[_cu_nvidia_bindings] = nvidia_bindings_available

            nv_binding_used = bool(cudadrv.driver.USE_NV_BINDING)
            sys_info[_cu_nvidia_bindings_used] = nv_binding_used

            try:
                from ptxcompiler import compile_ptx  # noqa: F401
                from cubinlinker import CubinLinker  # noqa: F401
                sys_info[_cu_mvc_available] = True
            except ImportError:
                sys_info[_cu_mvc_available] = False

            sys_info[_cu_mvc_needed] = cu_rt_ver > cu_drv_ver
            sys_info[_cu_mvc_in_use] = bool(
                config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY)
        except Exception as e:
            _warning_log.append(
                "Warning (cuda): Probing CUDA failed "
                "(device and driver present, runtime problem?)\n"
                f"(cuda) {type(e)}: {e}")

    # NumPy information
    sys_info[_numpy_version] = np.version.full_version
    try:
        # NOTE: These consts were added in NumPy 1.20
        from numpy.core._multiarray_umath import (__cpu_features__,
                                                  __cpu_dispatch__,
                                                  __cpu_baseline__,)
    except ImportError:
        sys_info[_numpy_AVX512_SKX_detected] = False
    else:
        feat_filtered = [k for k, v in __cpu_features__.items() if v]
        sys_info[_numpy_supported_simd_features] = feat_filtered
        sys_info[_numpy_supported_simd_dispatch] = __cpu_dispatch__
        sys_info[_numpy_supported_simd_baseline] = __cpu_baseline__
        sys_info[_numpy_AVX512_SKX_detected] = \
            __cpu_features__.get("AVX512_SKX", False)

    # SVML information
    # Replicate some SVML detection logic from numba.__init__ here.
    # If SVML load fails in numba.__init__ the splitting of the logic
    # here will help diagnosing the underlying issue.
    svml_lib_loaded = True
    try:
        if sys.platform.startswith('linux'):
            llvmbind.load_library_permanently("libsvml.so")
        elif sys.platform.startswith('darwin'):
            llvmbind.load_library_permanently("libsvml.dylib")
        elif sys.platform.startswith('win'):
            llvmbind.load_library_permanently("svml_dispmd")
        else:
            svml_lib_loaded = False
    except Exception:
        svml_lib_loaded = False
    func = getattr(llvmbind.targets, "has_svml", None)
    sys_info[_llvm_svml_patched] = func() if func else False
    sys_info[_svml_state] = config.USING_SVML
    sys_info[_svml_loaded] = svml_lib_loaded
    sys_info[_svml_operational] = all((
        sys_info[_svml_state],
        sys_info[_svml_loaded],
        sys_info[_llvm_svml_patched],
    ))

    # Check which threading backends are available.
    def parse_error(e, backend):
        # parses a linux based error message, this is to provide feedback
        # and hide user paths etc
        try:
            path, problem, symbol = [x.strip() for x in e.msg.split(':')]
            extn_dso = os.path.split(path)[1]
            if backend in extn_dso:
                return "%s: %s" % (problem, symbol)
        except Exception:
            pass
        return "Unknown import problem."

    try:
        # check import is ok, this means the DSO linkage is working
        from numba.np.ufunc import tbbpool  # NOQA
        # check that the version is compatible, this is a check performed at
        # runtime (well, compile time), it will also ImportError if there's
        # a problem.
        from numba.np.ufunc.parallel import _check_tbb_version_compatible
        _check_tbb_version_compatible()
        sys_info[_tbb_thread] = True
    except ImportError as e:
        # might be a missing symbol due to e.g. tbb libraries missing
        sys_info[_tbb_thread] = False
        sys_info[_tbb_error] = parse_error(e, 'tbbpool')

    try:
        from numba.np.ufunc import omppool
        sys_info[_openmp_thread] = True
        sys_info[_openmp_vendor] = omppool.openmp_vendor
    except ImportError as e:
        sys_info[_openmp_thread] = False
        sys_info[_openmp_error] = parse_error(e, 'omppool')

    try:
        from numba.np.ufunc import workqueue  # NOQA
        sys_info[_wkq_thread] = True
    except ImportError as e:
        sys_info[_wkq_thread] = False
        sys_info[_wkq_error] = parse_error(e, 'workqueue')

    # Look for conda and installed packages information
    cmd = ('conda', 'info', '--json')
    try:
        conda_out = check_output(cmd)
    except Exception as e:
        _warning_log.append(f'Warning: Conda not available.\n Error was {e}\n')
        # Conda is not available, try pip list to list installed packages
        cmd = (sys.executable, '-m', 'pip', 'list')
        try:
            reqs = check_output(cmd)
        except Exception as e:
            _error_log.append(f'Error (pip): {e}')
        else:
            sys_info[_inst_pkg] = reqs.decode().splitlines()

    else:
        jsond = json.loads(conda_out.decode())
        keys = {
            'conda_build_version': _conda_build_ver,
            'conda_env_version': _conda_env_ver,
            'platform': _conda_platform,
            'python_version': _conda_python_ver,
            'root_writable': _conda_root_writable,
        }
        for conda_k, sysinfo_k in keys.items():
            sys_info[sysinfo_k] = jsond.get(conda_k, 'N/A')

        # Get info about packages in current environment
        cmd = ('conda', 'list')
        try:
            conda_out = check_output(cmd)
        except CalledProcessError as e:
            _error_log.append(f'Error (conda): {e}')
        else:
            data = conda_out.decode().splitlines()
            sys_info[_inst_pkg] = [l for l in data if not l.startswith('#')]

    sys_info.update(get_os_spec_info(sys_info[_os_name]))
    sys_info[_errors] = _error_log
    sys_info[_warnings] = _warning_log
    sys_info[_runtime] = (datetime.now() - sys_info[_start]).total_seconds()
    return sys_info


def display_sysinfo(info=None, sep_pos=45):
    class DisplayMap(dict):
        display_map_flag = True

    class DisplaySeq(tuple):
        display_seq_flag = True

    class DisplaySeqMaps(tuple):
        display_seqmaps_flag = True

    if info is None:
        info = get_sysinfo()

    fmt = f'%-{sep_pos}s : %-s'
    MB = 1024**2
    template = (
        ("-" * 80,),
        ("__Time Stamp__",),
        ("Report started (local time)", info.get(_start, '?')),
        ("UTC start time", info.get(_start_utc, '?')),
        ("Running time (s)", info.get(_runtime, '?')),
        ("",),
        ("__Hardware Information__",),
        ("Machine", info.get(_machine, '?')),
        ("CPU Name", info.get(_cpu_name, '?')),
        ("CPU Count", info.get(_cpu_count, '?')),
        ("Number of accessible CPUs", info.get(_cpus_allowed, '?')),
        ("List of accessible CPUs cores", info.get(_cpus_list, '?')),
        ("CFS Restrictions (CPUs worth of runtime)",
         info.get(_cfs_restrict, 'None')),
        ("",),
        ("CPU Features", '\n'.join(
            ' ' * (sep_pos + 3) + l if i else l
            for i, l in enumerate(
                textwrap.wrap(
                    info.get(_cpu_features, '?'),
                    width=79 - sep_pos
                )
            )
        )),
        ("",),
        ("Memory Total (MB)", info.get(_mem_total, 0) // MB or '?'),
        ("Memory Available (MB)"
         if info.get(_os_name, '') != 'Darwin' or info.get(_psutil, False)
         else "Free Memory (MB)", info.get(_mem_available, 0) // MB or '?'),
        ("",),
        ("__OS Information__",),
        ("Platform Name", info.get(_platform_name, '?')),
        ("Platform Release", info.get(_platform_release, '?')),
        ("OS Name", info.get(_os_name, '?')),
        ("OS Version", info.get(_os_version, '?')),
        ("OS Specific Version", info.get(_os_spec_version, '?')),
        ("Libc Version", info.get(_libc_version, '?')),
        ("",),
        ("__Python Information__",),
        DisplayMap({k: v for k, v in info.items() if k.startswith('Python')}),
        ("",),
        ("__Numba Toolchain Versions__",),
        ("Numba Version", info.get(_numba_version, '?')),
        ("llvmlite Version", info.get(_llvmlite_version, '?')),
        ("",),
        ("__LLVM Information__",),
        ("LLVM Version", info.get(_llvm_version, '?')),
        ("",),
        ("__CUDA Information__",),
        ("CUDA Target Implementation", info.get(_cu_target_impl, '?')),
        ("CUDA Device Initialized", info.get(_cu_dev_init, '?')),
        ("CUDA Driver Version", info.get(_cu_drv_ver, '?')),
        ("CUDA Runtime Version", info.get(_cu_rt_ver, '?')),
        ("CUDA NVIDIA Bindings Available", info.get(_cu_nvidia_bindings, '?')),
        ("CUDA NVIDIA Bindings In Use",
         info.get(_cu_nvidia_bindings_used, '?')),
        ("CUDA Minor Version Compatibility Available",
         info.get(_cu_mvc_available, '?')),
        ("CUDA Minor Version Compatibility Needed",
         info.get(_cu_mvc_needed, '?')),
        ("CUDA Minor Version Compatibility In Use",
         info.get(_cu_mvc_in_use, '?')),
        ("CUDA Detect Output:",),
        (info.get(_cu_detect_out, "None"),),
        ("CUDA Libraries Test Output:",),
        (info.get(_cu_lib_test, "None"),),
        ("",),
        ("__NumPy Information__",),
        ("NumPy Version", info.get(_numpy_version, '?')),
        ("NumPy Supported SIMD features",
         DisplaySeq(info.get(_numpy_supported_simd_features, [])
                    or ('None found.',))),
        ("NumPy Supported SIMD dispatch",
         DisplaySeq(info.get(_numpy_supported_simd_dispatch, [])
                    or ('None found.',))),
        ("NumPy Supported SIMD baseline",
         DisplaySeq(info.get(_numpy_supported_simd_baseline, [])
                    or ('None found.',))),
        ("NumPy AVX512_SKX support detected",
         info.get(_numpy_AVX512_SKX_detected, '?')),
        ("",),
        ("__SVML Information__",),
        ("SVML State, config.USING_SVML", info.get(_svml_state, '?')),
        ("SVML Library Loaded", info.get(_svml_loaded, '?')),
        ("llvmlite Using SVML Patched LLVM", info.get(_llvm_svml_patched, '?')),
        ("SVML Operational", info.get(_svml_operational, '?')),
        ("",),
        ("__Threading Layer Information__",),
        ("TBB Threading Layer Available", info.get(_tbb_thread, '?')),
        ("+-->TBB imported successfully." if info.get(_tbb_thread, '?')
         else f"+--> Disabled due to {info.get(_tbb_error, '?')}",),
        ("OpenMP Threading Layer Available", info.get(_openmp_thread, '?')),
        (f"+-->Vendor: {info.get(_openmp_vendor, '?')}"
         if info.get(_openmp_thread, False)
         else f"+--> Disabled due to {info.get(_openmp_error, '?')}",),
        ("Workqueue Threading Layer Available", info.get(_wkq_thread, '?')),
        ("+-->Workqueue imported successfully." if info.get(_wkq_thread, False)
         else f"+--> Disabled due to {info.get(_wkq_error, '?')}",),
        ("",),
        ("__Numba Environment Variable Information__",),
        (DisplayMap(info.get(_numba_env_vars, {})) or ('None found.',)),
        ("",),
        ("__Conda Information__",),
        (DisplayMap({k: v for k, v in info.items()
                     if k.startswith('Conda')}) or ("Conda not available.",)),
        ("",),
        ("__Installed Packages__",),
        DisplaySeq(info.get(_inst_pkg, ("Couldn't retrieve packages info.",))),
        ("",),
        ("__Error log__" if info.get(_errors, [])
         else "No errors reported.",),
        DisplaySeq(info.get(_errors, [])),
        ("",),
        ("__Warning log__" if info.get(_warnings, [])
         else "No warnings reported.",),
        DisplaySeq(info.get(_warnings, [])),
        ("-" * 80,),
        ("If requested, please copy and paste the information between\n"
         "the dashed (----) lines, or from a given specific section as\n"
         "appropriate.\n\n"
         "=============================================================\n"
         "IMPORTANT: Please ensure that you are happy with sharing the\n"
         "contents of the information present, any information that you\n"
         "wish to keep private you should remove before sharing.\n"
         "=============================================================\n",),
    )
    for t in template:
        if hasattr(t, 'display_seq_flag'):
            print(*t, sep='\n')
        elif hasattr(t, 'display_map_flag'):
            print(*tuple(fmt % (k, v) for (k, v) in t.items()), sep='\n')
        elif hasattr(t, 'display_seqmaps_flag'):
            for d in t:
                print(*tuple(fmt % ('\t' + k, v) for (k, v) in d.items()),
                      sep='\n', end='\n')
        elif len(t) == 2:
            print(fmt % t)
        else:
            print(*t)


if __name__ == '__main__':
    display_sysinfo()
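
# End users typically reach this report through Numba's command line
# (``numba -s``), which is expected to call into display_sysinfo() above.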
@@ -0,0 +1,258 @@
import collections

import numpy as np

from numba.core import types, config


QuicksortImplementation = collections.namedtuple(
    'QuicksortImplementation',
    (# The compile function itself
     'compile',
     # All subroutines exercised by test_sort
     'partition', 'partition3', 'insertion_sort',
     # The top-level function
     'run_quicksort',
     ))


Partition = collections.namedtuple('Partition', ('start', 'stop'))

# Under this size, switch to a simple insertion sort
SMALL_QUICKSORT = 15

MAX_STACK = 100


def make_quicksort_impl(wrap, lt=None, is_argsort=False, is_list=False,
                        is_np_array=False):

    intp = types.intp
    zero = intp(0)

    # Two subroutines to make the core algorithm generic wrt. argsort
    # or normal sorting. Note the genericity may make basic sort()
    # slightly slower (~5%)
    if is_argsort:
        if is_list:
            @wrap
            def make_res(A):
                return [x for x in range(len(A))]
        else:
            @wrap
            def make_res(A):
                return np.arange(A.size)

        @wrap
        def GET(A, idx_or_val):
            return A[idx_or_val]

    else:
        @wrap
        def make_res(A):
            return A

        @wrap
        def GET(A, idx_or_val):
            return idx_or_val

    def default_lt(a, b):
        """
        Trivial comparison function between two keys.
        """
        return a < b

    LT = wrap(lt if lt is not None else default_lt)

    @wrap
    def insertion_sort(A, R, low, high):
        """
        Insertion sort A[low:high + 1]. Note the inclusive bounds.
        """
        assert low >= 0
        if high <= low:
            return

        for i in range(low + 1, high + 1):
            k = R[i]
            v = GET(A, k)
            # Insert v into A[low:i]
            j = i
            while j > low and LT(v, GET(A, R[j - 1])):
                # Make place for moving A[i] downwards
                R[j] = R[j - 1]
                j -= 1
            R[j] = k

    @wrap
    def partition(A, R, low, high):
        """
        Partition A[low:high + 1] around a chosen pivot.  The pivot's index
        is returned.
        """
        assert low >= 0
        assert high > low

        mid = (low + high) >> 1
        # NOTE: the pattern of swaps below for the pivot choice and the
        # partitioning gives good results (i.e. regular O(n log n))
        # on sorted, reverse-sorted, and uniform arrays.  Subtle changes
        # risk breaking this property.

        # median of three {low, middle, high}
        if LT(GET(A, R[mid]), GET(A, R[low])):
            R[low], R[mid] = R[mid], R[low]
        if LT(GET(A, R[high]), GET(A, R[mid])):
            R[high], R[mid] = R[mid], R[high]
        if LT(GET(A, R[mid]), GET(A, R[low])):
            R[low], R[mid] = R[mid], R[low]
        pivot = GET(A, R[mid])

        # Temporarily stash the pivot at the end
        R[high], R[mid] = R[mid], R[high]
        i = low
        j = high - 1
        while True:
            while i < high and LT(GET(A, R[i]), pivot):
                i += 1
            while j >= low and LT(pivot, GET(A, R[j])):
                j -= 1
            if i >= j:
                break
            R[i], R[j] = R[j], R[i]
            i += 1
            j -= 1
        # Put the pivot back in its final place (all items before `i`
        # are smaller than the pivot, all items at/after `i` are larger)
        R[i], R[high] = R[high], R[i]
        return i

    @wrap
    def partition3(A, low, high):
        """
        Three-way partition [low, high) around a chosen pivot.
        A tuple (lt, gt) is returned such that:
            - all elements in [low, lt) are < pivot
            - all elements in [lt, gt] are == pivot
            - all elements in (gt, high] are > pivot
        """
        mid = (low + high) >> 1
        # median of three {low, middle, high}
        if LT(A[mid], A[low]):
            A[low], A[mid] = A[mid], A[low]
        if LT(A[high], A[mid]):
            A[high], A[mid] = A[mid], A[high]
        if LT(A[mid], A[low]):
            A[low], A[mid] = A[mid], A[low]
        pivot = A[mid]

        A[low], A[mid] = A[mid], A[low]
        lt = low
        gt = high
        i = low + 1
        while i <= gt:
            if LT(A[i], pivot):
                A[lt], A[i] = A[i], A[lt]
                lt += 1
                i += 1
            elif LT(pivot, A[i]):
                A[gt], A[i] = A[i], A[gt]
                gt -= 1
            else:
                i += 1
        return lt, gt

    @wrap
    def run_quicksort1(A):
        R = make_res(A)

        if len(A) < 2:
            return R

        stack = [Partition(zero, zero)] * MAX_STACK
        stack[0] = Partition(zero, len(A) - 1)
        n = 1

        while n > 0:
            n -= 1
            low, high = stack[n]
            # Partition until it becomes more efficient to do an insertion sort
            while high - low >= SMALL_QUICKSORT:
                assert n < MAX_STACK
                i = partition(A, R, low, high)
                # Push largest partition on the stack
                if high - i > i - low:
                    # Right is larger
                    if high > i:
                        stack[n] = Partition(i + 1, high)
                        n += 1
                    high = i - 1
                else:
                    if i > low:
                        stack[n] = Partition(low, i - 1)
                        n += 1
                    low = i + 1

            insertion_sort(A, R, low, high)

        return R

    if is_np_array:
        @wrap
        def run_quicksort(A):
            if A.ndim == 1:
                return run_quicksort1(A)
            else:
                for idx in np.ndindex(A.shape[:-1]):
                    run_quicksort1(A[idx])
                return A
    else:
        @wrap
        def run_quicksort(A):
            return run_quicksort1(A)

    # Unused quicksort implementation based on 3-way partitioning; the
    # partitioning scheme turns out to exhibit bad behaviour on sorted arrays.
    @wrap
    def _run_quicksort(A):
        stack = [Partition(zero, zero)] * 100
        stack[0] = Partition(zero, len(A) - 1)
        n = 1

        while n > 0:
            n -= 1
            low, high = stack[n]
            # Partition until it becomes more efficient to do an insertion sort
            while high - low >= SMALL_QUICKSORT:
                assert n < MAX_STACK
                l, r = partition3(A, low, high)
                # One trivial (empty) partition => iterate on the other
                if r == high:
                    high = l - 1
                elif l == low:
                    low = r + 1
                # Push largest partition on the stack
                elif high - r > l - low:
                    # Right is larger
                    stack[n] = Partition(r + 1, high)
                    n += 1
                    high = l - 1
                else:
                    stack[n] = Partition(low, l - 1)
                    n += 1
                    low = r + 1

            # R is A here: this variant only supports sorting values in place
            insertion_sort(A, A, low, high)


    return QuicksortImplementation(wrap,
                                   partition, partition3, insertion_sort,
                                   run_quicksort)


def make_py_quicksort(*args, **kwargs):
    return make_quicksort_impl((lambda f: f), *args, **kwargs)

def make_jit_quicksort(*args, **kwargs):
    from numba.core.extending import register_jitable
    return make_quicksort_impl((lambda f: register_jitable(f)),
                               *args, **kwargs)
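

# A minimal usage sketch (illustrative, not part of the API surface): the
# pure-Python variant can be exercised directly, e.g.
#
#     impl = make_py_quicksort()
#     data = [3, 1, 2]
#     impl.run_quicksort(data)    # sorts in place and returns `data`
#     assert data == [1, 2, 3]
#
# make_jit_quicksort() builds the same functions wrapped with
# register_jitable, so they can be called from jitted code.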
@@ -0,0 +1,104 @@
import numpy as np

from numba.core.typing.typeof import typeof
from numba.core.typing.asnumbatype import as_numba_type


def pndindex(*args):
    """ Provides an n-dimensional parallel iterator that generates index tuples
    for each iteration point. Sequentially, pndindex is identical to np.ndindex.
    """
    return np.ndindex(*args)


class prange(object):
    """ Provides a 1D parallel iterator that generates a sequence of integers.
    In non-parallel contexts, prange is identical to range.
    """
    def __new__(cls, *args):
        return range(*args)
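

# A usage sketch (illustrative): inside an @njit(parallel=True) function the
# loop below may be executed in parallel; in plain Python, prange simply
# degrades to range.
#
#     from numba import njit, prange
#
#     @njit(parallel=True)
#     def total(A):
#         acc = 0.0
#         for i in prange(A.shape[0]):
#             acc += A[i]
#         return acc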


def _gdb_python_call_gen(func_name, *args):
    # generates a call to a function containing a compiled in gdb command,
    # this is to make `numba.gdb*` work in the interpreter.
    import numba
    fn = getattr(numba, func_name)
    argstr = ','.join(['"%s"' for _ in args]) % args
    defn = """def _gdb_func_injection():\n\t%s(%s)\n
    """ % (func_name, argstr)
    l = {}
    exec(defn, {func_name: fn}, l)
    return numba.njit(l['_gdb_func_injection'])


def gdb(*args):
    """
    Calling this function will invoke gdb and attach it to the current process
    at the call site. Arguments are strings in the gdb command language syntax
    which will be executed by gdb once initialisation has occurred.
    """
    _gdb_python_call_gen('gdb', *args)()


def gdb_breakpoint():
    """
    Calling this function will inject a breakpoint at the call site that is
    recognised by both `gdb` and `gdb_init`, this is to allow breaking at
    multiple points. gdb will stop in the user defined code just after the frame
    employed by the breakpoint returns.
    """
    _gdb_python_call_gen('gdb_breakpoint')()


def gdb_init(*args):
    """
    Calling this function will invoke gdb and attach it to the current process
    at the call site, then continue executing the process under gdb's control.
    Arguments are strings in the gdb command language syntax which will be
    executed by gdb once initialisation has occurred.
    """
    _gdb_python_call_gen('gdb_init', *args)()
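

# A usage sketch (illustrative; requires a gdb binary on the PATH and is
# most useful with debug info enabled, see the Numba troubleshooting docs):
#
#     from numba import njit, gdb
#
#     @njit(debug=True)
#     def foo(a):
#         b = a + 1
#         gdb()          # gdb attaches here and execution stops
#         return b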


def literally(obj):
"""Forces Numba to interpret *obj* as an Literal value.
|
||||

    *obj* must be either a literal or an argument of the caller function, where
    the argument must be bound to a literal. The literal requirement
    propagates up the call stack.

    This function is intercepted by the compiler to alter the compilation
    behavior to wrap the corresponding function parameters as ``Literal``.
    It has **no effect** outside of nopython-mode (interpreter and objectmode).

    The current implementation detects literal arguments in two ways:

    1. Scans for uses of ``literally`` via a compiler pass.
    2. ``literally`` is overloaded to raise ``numba.errors.ForceLiteralArg``
       to signal the dispatcher to treat the corresponding parameter
       differently. This mode is to support indirect use (via a function call).

    The execution semantic of this function is equivalent to an identity
    function.

    See :ghfile:`numba/tests/test_literal_dispatch.py` for examples.
    """
    return obj
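

# A minimal sketch of the intended use (illustrative; see the test file
# referenced in the docstring for authoritative examples):
#
#     import numba
#
#     @numba.njit
#     def foo(x):
#         # request that `x` be treated as a compile-time literal; a
#         # specialization is compiled per distinct literal value
#         return numba.literally(x) + 1
#
#     foo(2)   # x is typed as Literal[int](2) inside the compiled function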


def literal_unroll(container):
    return container


__all__ = [
    'typeof',
    'as_numba_type',
    'prange',
    'pndindex',
    'gdb',
    'gdb_breakpoint',
    'gdb_init',
    'literally',
    'literal_unroll',
]
@@ -0,0 +1,943 @@
"""
Timsort implementation.  Mostly adapted from CPython's listobject.c.

For more information, see listsort.txt in CPython's source tree.
"""


import collections

from numba.core import types


TimsortImplementation = collections.namedtuple(
    'TimsortImplementation',
    (# The compile function itself
     'compile',
     # All subroutines exercised by test_sort
     'count_run', 'binarysort', 'gallop_left', 'gallop_right',
     'merge_init', 'merge_append', 'merge_pop',
     'merge_compute_minrun', 'merge_lo', 'merge_hi', 'merge_at',
     'merge_force_collapse', 'merge_collapse',
     # The top-level functions
     'run_timsort', 'run_timsort_with_values'
     ))


# The maximum number of entries in a MergeState's pending-runs stack.
# This is enough to sort arrays of size up to about
#     32 * phi ** MAX_MERGE_PENDING
# where phi ~= 1.618.  85 is ridiculously large enough, good for an array
# with 2**64 elements.
# NOTE this implementation doesn't depend on it (the stack is dynamically
# allocated), but it's still good to check as an invariant.
MAX_MERGE_PENDING = 85

# When we get into galloping mode, we stay there until both runs win less
# often than MIN_GALLOP consecutive times.  See listsort.txt for more info.
MIN_GALLOP = 7

# Start size for temp arrays.
MERGESTATE_TEMP_SIZE = 256

# A mergestate is a named tuple with the following members:
#  - *min_gallop* is an integer controlling when we get into galloping mode
#  - *keys* is a temp list for merging keys
#  - *values* is a temp list for merging values, if needed
#  - *pending* is a stack of pending runs to be merged
#  - *n* is the current stack length of *pending*

MergeState = collections.namedtuple(
    'MergeState', ('min_gallop', 'keys', 'values', 'pending', 'n'))


MergeRun = collections.namedtuple('MergeRun', ('start', 'size'))


def make_timsort_impl(wrap, make_temp_area):

    make_temp_area = wrap(make_temp_area)
    intp = types.intp
    zero = intp(0)

    @wrap
    def has_values(keys, values):
        return values is not keys

    @wrap
    def merge_init(keys):
        """
        Initialize a MergeState for a non-keyed sort.
        """
        temp_size = min(len(keys) // 2 + 1, MERGESTATE_TEMP_SIZE)
        temp_keys = make_temp_area(keys, temp_size)
        temp_values = temp_keys
        pending = [MergeRun(zero, zero)] * MAX_MERGE_PENDING
        return MergeState(intp(MIN_GALLOP), temp_keys, temp_values, pending, zero)

    @wrap
    def merge_init_with_values(keys, values):
        """
        Initialize a MergeState for a keyed sort.
        """
        temp_size = min(len(keys) // 2 + 1, MERGESTATE_TEMP_SIZE)
        temp_keys = make_temp_area(keys, temp_size)
        temp_values = make_temp_area(values, temp_size)
        pending = [MergeRun(zero, zero)] * MAX_MERGE_PENDING
        return MergeState(intp(MIN_GALLOP), temp_keys, temp_values, pending, zero)

    @wrap
    def merge_append(ms, run):
        """
        Append a run on the merge stack.
        """
        n = ms.n
        assert n < MAX_MERGE_PENDING
        ms.pending[n] = run
        return MergeState(ms.min_gallop, ms.keys, ms.values, ms.pending, n + 1)

    @wrap
    def merge_pop(ms):
        """
        Pop the top run from the merge stack.
        """
        return MergeState(ms.min_gallop, ms.keys, ms.values, ms.pending, ms.n - 1)

    @wrap
    def merge_getmem(ms, need):
        """
        Ensure enough temp memory for 'need' items is available.
        """
        alloced = len(ms.keys)
        if need <= alloced:
            return ms
        # Over-allocate
        while alloced < need:
            alloced = alloced << 1
        # Don't realloc!  That can cost cycles to copy the old data, but
        # we don't care what's in the block.
        temp_keys = make_temp_area(ms.keys, alloced)
        if has_values(ms.keys, ms.values):
            temp_values = make_temp_area(ms.values, alloced)
        else:
            temp_values = temp_keys
        return MergeState(ms.min_gallop, temp_keys, temp_values, ms.pending, ms.n)

    @wrap
    def merge_adjust_gallop(ms, new_gallop):
        """
        Modify the MergeState's min_gallop.
        """
        return MergeState(intp(new_gallop), ms.keys, ms.values, ms.pending, ms.n)


    @wrap
    def LT(a, b):
        """
        Trivial comparison function between two keys.  This is factored out to
        make it clear where comparisons occur.
        """
        return a < b

    @wrap
    def binarysort(keys, values, lo, hi, start):
        """
        binarysort is the best method for sorting small arrays: it does
        few compares, but can do data movement quadratic in the number of
        elements.
        [lo, hi) is a contiguous slice of a list, and is sorted via
        binary insertion.  This sort is stable.
        On entry, must have lo <= start <= hi, and that [lo, start) is already
        sorted (pass start == lo if you don't know!).
        """
        assert lo <= start and start <= hi
        _has_values = has_values(keys, values)
        if lo == start:
            start += 1
        while start < hi:
            pivot = keys[start]
            # Bisect to find where to insert `pivot`
            # NOTE: bisection only wins over linear search if the comparison
            # function is much more expensive than simply moving data.
            l = lo
            r = start
            # Invariants:
            # pivot >= all in [lo, l).
            # pivot  < all in [r, start).
            # The second is vacuously true at the start.
            while l < r:
                p = l + ((r - l) >> 1)
                if LT(pivot, keys[p]):
                    r = p
                else:
                    l = p+1

            # The invariants still hold, so pivot >= all in [lo, l) and
            # pivot < all in [l, start), so pivot belongs at l.  Note
            # that if there are elements equal to pivot, l points to the
            # first slot after them -- that's why this sort is stable.
            # Slide over to make room (aka memmove()).
            for p in range(start, l, -1):
                keys[p] = keys[p - 1]
            keys[l] = pivot
            if _has_values:
                pivot_val = values[start]
                for p in range(start, l, -1):
                    values[p] = values[p - 1]
                values[l] = pivot_val

            start += 1


    @wrap
    def count_run(keys, lo, hi):
        """
        Return the length of the run beginning at lo, in the slice [lo, hi).
        lo < hi is required on entry.  "A run" is the longest ascending sequence, with

          lo[0] <= lo[1] <= lo[2] <= ...

        or the longest descending sequence, with

          lo[0] > lo[1] > lo[2] > ...

        A tuple (length, descending) is returned, where boolean *descending*
        is set to 0 in the former case, or to 1 in the latter.
        For its intended use in a stable mergesort, the strictness of the defn of
        "descending" is needed so that the caller can safely reverse a descending
        sequence without violating stability (strict > ensures there are no equal
        elements to get out of order).
        """
        assert lo < hi
        if lo + 1 == hi:
            # Trivial 1-long run
            return 1, False
        if LT(keys[lo + 1], keys[lo]):
            # Descending run
            for k in range(lo + 2, hi):
                if not LT(keys[k], keys[k - 1]):
                    return k - lo, True
            return hi - lo, True
        else:
            # Ascending run
            for k in range(lo + 2, hi):
                if LT(keys[k], keys[k - 1]):
                    return k - lo, False
            return hi - lo, False
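
    # Illustrative, hand-checked examples:
    #   count_run([3, 2, 1, 4], 0, 4) -> (3, True)   strictly descending 3, 2, 1
    #   count_run([1, 1, 2, 0], 0, 4) -> (3, False)  non-strict ascending 1, 1, 2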


    @wrap
    def gallop_left(key, a, start, stop, hint):
        """
        Locate the proper position of key in a sorted vector; if the vector contains
        an element equal to key, return the position immediately to the left of
        the leftmost equal element.  [gallop_right() does the same except returns
        the position to the right of the rightmost equal element (if any).]

        "a" is a sorted vector with stop elements, starting at a[start].
        stop must be > start.

        "hint" is an index at which to begin the search, start <= hint < stop.
        The closer hint is to the final result, the faster this runs.

        The return value is the int k in start..stop such that

            a[k-1] < key <= a[k]

        pretending that a[start-1] is minus infinity and a[stop] is plus infinity.
        IOW, key belongs at index k; or, IOW, the first k elements of a should
        precede key, and the last stop-start-k should follow key.

        See listsort.txt for info on the method.
        """
        assert stop > start
        assert hint >= start and hint < stop
        n = stop - start

        # First, gallop from the hint to find a "good" subinterval for bisecting
        lastofs = 0
        ofs = 1
        if LT(a[hint], key):
            # a[hint] < key => gallop right, until
            #                  a[hint + lastofs] < key <= a[hint + ofs]
            maxofs = stop - hint
            while ofs < maxofs:
                if LT(a[hint + ofs], key):
                    lastofs = ofs
                    ofs = (ofs << 1) + 1
                    if ofs <= 0:
                        # Int overflow
                        ofs = maxofs
                else:
                    # key <= a[hint + ofs]
                    break
            if ofs > maxofs:
                ofs = maxofs
            # Translate back to offsets relative to a[0]
            lastofs += hint
            ofs += hint
        else:
            # key <= a[hint] => gallop left, until
            #                   a[hint - ofs] < key <= a[hint - lastofs]
            maxofs = hint - start + 1
            while ofs < maxofs:
                if LT(a[hint - ofs], key):
                    break
                else:
                    # key <= a[hint - ofs]
                    lastofs = ofs
                    ofs = (ofs << 1) + 1
                    if ofs <= 0:
                        # Int overflow
                        ofs = maxofs
            if ofs > maxofs:
                ofs = maxofs
            # Translate back to positive offsets relative to a[0]
            lastofs, ofs = hint - ofs, hint - lastofs

        assert start - 1 <= lastofs and lastofs < ofs and ofs <= stop
        # Now a[lastofs] < key <= a[ofs], so key belongs somewhere to the
        # right of lastofs but no farther right than ofs.  Do a binary
        # search, with invariant a[lastofs-1] < key <= a[ofs].
        lastofs += 1
        while lastofs < ofs:
            m = lastofs + ((ofs - lastofs) >> 1)
            if LT(a[m], key):
                # a[m] < key
                lastofs = m + 1
            else:
                # key <= a[m]
                ofs = m
        # Now lastofs == ofs, so a[ofs - 1] < key <= a[ofs]
        return ofs


    @wrap
    def gallop_right(key, a, start, stop, hint):
        """
        Exactly like gallop_left(), except that if key already exists in a[start:stop],
        finds the position immediately to the right of the rightmost equal value.

        The return value is the int k in start..stop such that

            a[k-1] <= key < a[k]

        The code duplication is massive, but this is enough different given that
        we're sticking to "<" comparisons that it's much harder to follow if
        written as one routine with yet another "left or right?" flag.
        """
        assert stop > start
        assert hint >= start and hint < stop
        n = stop - start

        # First, gallop from the hint to find a "good" subinterval for bisecting
        lastofs = 0
        ofs = 1
        if LT(key, a[hint]):
            # key < a[hint] => gallop left, until
            #                  a[hint - ofs] <= key < a[hint - lastofs]
            maxofs = hint - start + 1
            while ofs < maxofs:
                if LT(key, a[hint - ofs]):
                    lastofs = ofs
                    ofs = (ofs << 1) + 1
                    if ofs <= 0:
                        # Int overflow
                        ofs = maxofs
                else:
                    # a[hint - ofs] <= key
                    break
            if ofs > maxofs:
                ofs = maxofs
            # Translate back to positive offsets relative to a[0]
            lastofs, ofs = hint - ofs, hint - lastofs
        else:
            # a[hint] <= key -- gallop right, until
            #                   a[hint + lastofs] <= key < a[hint + ofs]
            maxofs = stop - hint
            while ofs < maxofs:
                if LT(key, a[hint + ofs]):
                    break
                else:
                    # a[hint + ofs] <= key
                    lastofs = ofs
                    ofs = (ofs << 1) + 1
                    if ofs <= 0:
                        # Int overflow
                        ofs = maxofs
            if ofs > maxofs:
                ofs = maxofs
            # Translate back to offsets relative to a[0]
            lastofs += hint
            ofs += hint

        assert start - 1 <= lastofs and lastofs < ofs and ofs <= stop
        # Now a[lastofs] <= key < a[ofs], so key belongs somewhere to the
        # right of lastofs but no farther right than ofs.  Do a binary
        # search, with invariant a[lastofs-1] <= key < a[ofs].
        lastofs += 1
        while lastofs < ofs:
            m = lastofs + ((ofs - lastofs) >> 1)
            if LT(key, a[m]):
                # key < a[m]
                ofs = m
            else:
                # a[m] <= key
                lastofs = m + 1
        # Now lastofs == ofs, so a[ofs - 1] <= key < a[ofs]
        return ofs


    @wrap
    def merge_compute_minrun(n):
        """
        Compute a good value for the minimum run length; natural runs shorter
        than this are boosted artificially via binary insertion.

        If n < 64, return n (it's too small to bother with fancy stuff).
        Else if n is an exact power of 2, return 32.
        Else return an int k, 32 <= k <= 64, such that n/k is close to, but
        strictly less than, an exact power of 2.

        See listsort.txt for more info.
        """
        r = 0
        assert n >= 0
        while n >= 64:
            r |= n & 1
            n >>= 1
        return n + r
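
    # Worked examples (hand-checked): merge_compute_minrun(63) == 63 (small n
    # is returned unchanged), merge_compute_minrun(64) == 32 (exact power of
    # two), and merge_compute_minrun(2112) == 33, so 2112 elements split into
    # exactly 64 runs of length 33 rather than 66 uneven ones (the motivating
    # example in CPython's listsort.txt).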


    @wrap
    def sortslice_copy(dest_keys, dest_values, dest_start,
                       src_keys, src_values, src_start,
                       nitems):
        """
        Upwards memcpy().
        """
        assert src_start >= 0
        assert dest_start >= 0
        for i in range(nitems):
            dest_keys[dest_start + i] = src_keys[src_start + i]
        if has_values(src_keys, src_values):
            for i in range(nitems):
                dest_values[dest_start + i] = src_values[src_start + i]

    @wrap
    def sortslice_copy_down(dest_keys, dest_values, dest_start,
                            src_keys, src_values, src_start,
                            nitems):
        """
        Downwards memcpy().
        """
        assert src_start >= 0
        assert dest_start >= 0
        for i in range(nitems):
            dest_keys[dest_start - i] = src_keys[src_start - i]
        if has_values(src_keys, src_values):
            for i in range(nitems):
                dest_values[dest_start - i] = src_values[src_start - i]


    # Disable this for debug or perf comparison
    DO_GALLOP = 1

    @wrap
    def merge_lo(ms, keys, values, ssa, na, ssb, nb):
        """
        Merge the na elements starting at ssa with the nb elements starting at
        ssb = ssa + na in a stable way, in-place.  na and nb must be > 0,
        and should have na <= nb.  See listsort.txt for more info.

        An updated MergeState is returned (with possibly a different min_gallop
        or larger temp arrays).

        NOTE: compared to CPython's timsort, the requirement that
        "Must also have that keys[ssa + na - 1] belongs at the end of the merge"

        is removed. This makes the code a bit simpler and easier to reason about.
        """
        assert na > 0 and nb > 0 and na <= nb
        assert ssb == ssa + na
        # First copy [ssa, ssa + na) into the temp space
        ms = merge_getmem(ms, na)
        sortslice_copy(ms.keys, ms.values, 0,
                       keys, values, ssa,
                       na)
        a_keys = ms.keys
        a_values = ms.values
        b_keys = keys
        b_values = values
        dest = ssa
        ssa = 0

        _has_values = has_values(a_keys, a_values)
        min_gallop = ms.min_gallop

        # Now start merging into the space left from [ssa, ...)

        while nb > 0 and na > 0:
            # Do the straightforward thing until (if ever) one run
            # appears to win consistently.
            acount = 0
            bcount = 0

            while True:
                if LT(b_keys[ssb], a_keys[ssa]):
                    keys[dest] = b_keys[ssb]
                    if _has_values:
                        values[dest] = b_values[ssb]
                    dest += 1
                    ssb += 1
                    nb -= 1
                    if nb == 0:
                        break
                    # It's a B run
                    bcount += 1
                    acount = 0
                    if bcount >= min_gallop:
                        break
                else:
                    keys[dest] = a_keys[ssa]
                    if _has_values:
                        values[dest] = a_values[ssa]
                    dest += 1
                    ssa += 1
                    na -= 1
                    if na == 0:
                        break
                    # It's a A run
                    acount += 1
                    bcount = 0
                    if acount >= min_gallop:
                        break

            # One run is winning so consistently that galloping may
            # be a huge win.  So try that, and continue galloping until
            # (if ever) neither run appears to be winning consistently
            # anymore.
            if DO_GALLOP and na > 0 and nb > 0:
                min_gallop += 1

                while acount >= MIN_GALLOP or bcount >= MIN_GALLOP:
                    # As long as we gallop without leaving this loop, make
                    # the heuristic more likely
                    min_gallop -= min_gallop > 1

                    # Gallop in A to find where keys[ssb] should end up
                    k = gallop_right(b_keys[ssb], a_keys, ssa, ssa + na, ssa)
                    # k is an index, make it a size
                    k -= ssa
                    acount = k
                    if k > 0:
                        # Copy everything from A before k
                        sortslice_copy(keys, values, dest,
                                       a_keys, a_values, ssa,
                                       k)
                        dest += k
                        ssa += k
                        na -= k
                        if na == 0:
                            # Finished merging
                            break
                    # Copy keys[ssb]
                    keys[dest] = b_keys[ssb]
                    if _has_values:
                        values[dest] = b_values[ssb]
                    dest += 1
                    ssb += 1
                    nb -= 1
                    if nb == 0:
                        # Finished merging
                        break

                    # Gallop in B to find where keys[ssa] should end up
                    k = gallop_left(a_keys[ssa], b_keys, ssb, ssb + nb, ssb)
                    # k is an index, make it a size
                    k -= ssb
                    bcount = k
                    if k > 0:
                        # Copy everything from B before k
                        # NOTE: source and dest are the same buffer, but the
                        # destination index is below the source index
                        sortslice_copy(keys, values, dest,
                                       b_keys, b_values, ssb,
                                       k)
                        dest += k
                        ssb += k
                        nb -= k
                        if nb == 0:
                            # Finished merging
                            break
                    # Copy keys[ssa]
                    keys[dest] = a_keys[ssa]
                    if _has_values:
                        values[dest] = a_values[ssa]
                    dest += 1
                    ssa += 1
                    na -= 1
                    if na == 0:
                        # Finished merging
                        break

                # Penalize it for leaving galloping mode
                min_gallop += 1

        # Merge finished, now handle the remaining areas
        if nb == 0:
            # Only A remaining to copy at the end of the destination area
            sortslice_copy(keys, values, dest,
                           a_keys, a_values, ssa,
                           na)
        else:
            assert na == 0
            assert dest == ssb
            # B's tail is already at the right place, do nothing

        return merge_adjust_gallop(ms, min_gallop)


    @wrap
    def merge_hi(ms, keys, values, ssa, na, ssb, nb):
        """
        Merge the na elements starting at ssa with the nb elements starting at
        ssb = ssa + na in a stable way, in-place.  na and nb must be > 0,
        and should have na >= nb.  See listsort.txt for more info.

        An updated MergeState is returned (with possibly a different min_gallop
        or larger temp arrays).

        NOTE: compared to CPython's timsort, the requirement that
        "Must also have that keys[ssa + na - 1] belongs at the end of the merge"

        is removed. This makes the code a bit simpler and easier to reason about.
        """
        assert na > 0 and nb > 0 and na >= nb
        assert ssb == ssa + na
        # First copy [ssb, ssb + nb) into the temp space
        ms = merge_getmem(ms, nb)
        sortslice_copy(ms.keys, ms.values, 0,
                       keys, values, ssb,
                       nb)
        a_keys = keys
        a_values = values
        b_keys = ms.keys
        b_values = ms.values

        # Now start merging *in descending order* into the space left
        # from [..., ssb + nb).
        dest = ssb + nb - 1
        ssb = nb - 1
        ssa = ssa + na - 1

        _has_values = has_values(b_keys, b_values)
        min_gallop = ms.min_gallop

        while nb > 0 and na > 0:
            # Do the straightforward thing until (if ever) one run
            # appears to win consistently.
            acount = 0
            bcount = 0

            while True:
                if LT(b_keys[ssb], a_keys[ssa]):
                    # We merge in descending order, so copy the larger value
                    keys[dest] = a_keys[ssa]
                    if _has_values:
                        values[dest] = a_values[ssa]
                    dest -= 1
                    ssa -= 1
                    na -= 1
                    if na == 0:
                        break
                    # It's a A run
                    acount += 1
                    bcount = 0
                    if acount >= min_gallop:
                        break
                else:
                    keys[dest] = b_keys[ssb]
                    if _has_values:
                        values[dest] = b_values[ssb]
                    dest -= 1
                    ssb -= 1
                    nb -= 1
                    if nb == 0:
                        break
                    # It's a B run
                    bcount += 1
                    acount = 0
                    if bcount >= min_gallop:
                        break

            # One run is winning so consistently that galloping may
            # be a huge win.  So try that, and continue galloping until
            # (if ever) neither run appears to be winning consistently
            # anymore.
            if DO_GALLOP and na > 0 and nb > 0:
                min_gallop += 1

                while acount >= MIN_GALLOP or bcount >= MIN_GALLOP:
                    # As long as we gallop without leaving this loop, make
                    # the heuristic more likely
                    min_gallop -= min_gallop > 1

                    # Gallop in A to find where keys[ssb] should end up
                    k = gallop_right(b_keys[ssb], a_keys, ssa - na + 1, ssa + 1, ssa)
                    # k is an index, make it a size from the end
                    k = ssa + 1 - k
                    acount = k
                    if k > 0:
                        # Copy everything from A after k.
                        # Destination and source are the same buffer, and destination
                        # index is greater, so copy from the end to the start.
                        sortslice_copy_down(keys, values, dest,
                                            a_keys, a_values, ssa,
                                            k)
                        dest -= k
                        ssa -= k
                        na -= k
                        if na == 0:
                            # Finished merging
                            break
                    # Copy keys[ssb]
                    keys[dest] = b_keys[ssb]
                    if _has_values:
                        values[dest] = b_values[ssb]
                    dest -= 1
                    ssb -= 1
                    nb -= 1
                    if nb == 0:
                        # Finished merging
                        break

                    # Gallop in B to find where keys[ssa] should end up
                    k = gallop_left(a_keys[ssa], b_keys, ssb - nb + 1, ssb + 1, ssb)
                    # k is an index, make it a size from the end
                    k = ssb + 1 - k
                    bcount = k
                    if k > 0:
                        # Copy everything from B before k
                        sortslice_copy_down(keys, values, dest,
                                            b_keys, b_values, ssb,
                                            k)
                        dest -= k
                        ssb -= k
                        nb -= k
                        if nb == 0:
                            # Finished merging
                            break
                    # Copy keys[ssa]
                    keys[dest] = a_keys[ssa]
                    if _has_values:
                        values[dest] = a_values[ssa]
                    dest -= 1
                    ssa -= 1
                    na -= 1
                    if na == 0:
                        # Finished merging
                        break

                # Penalize it for leaving galloping mode
                min_gallop += 1

        # Merge finished, now handle the remaining areas
        if na == 0:
            # Only B remaining to copy at the front of the destination area
            sortslice_copy(keys, values, dest - nb + 1,
                           b_keys, b_values, ssb - nb + 1,
                           nb)
        else:
            assert nb == 0
            assert dest == ssa
            # A's front is already at the right place, do nothing

        return merge_adjust_gallop(ms, min_gallop)


    @wrap
    def merge_at(ms, keys, values, i):
        """
        Merge the two runs at stack indices i and i+1.

        An updated MergeState is returned.
        """
        n = ms.n
        assert n >= 2
        assert i >= 0
        assert i == n - 2 or i == n - 3

        ssa, na = ms.pending[i]
        ssb, nb = ms.pending[i + 1]
        assert na > 0 and nb > 0
        assert ssa + na == ssb

        # Record the length of the combined runs; if i is the 3rd-last
        # run now, also slide over the last run (which isn't involved
        # in this merge).  The current run i+1 goes away in any case.
        ms.pending[i] = MergeRun(ssa, na + nb)
        if i == n - 3:
            ms.pending[i + 1] = ms.pending[i + 2]
        ms = merge_pop(ms)

        # Where does b start in a?  Elements in a before that can be
        # ignored (already in place).
        k = gallop_right(keys[ssb], keys, ssa, ssa + na, ssa)
        # [k, ssa + na) remains to be merged
        na -= k - ssa
        ssa = k
        if na == 0:
            return ms

        # Where does a end in b?  Elements in b after that can be
        # ignored (already in place).
        k = gallop_left(keys[ssa + na - 1], keys, ssb, ssb + nb, ssb + nb - 1)
        # [ssb, k) remains to be merged
        nb = k - ssb

        # Merge what remains of the runs, using a temp array with
        # min(na, nb) elements.
        if na <= nb:
            return merge_lo(ms, keys, values, ssa, na, ssb, nb)
        else:
            return merge_hi(ms, keys, values, ssa, na, ssb, nb)


    @wrap
    def merge_collapse(ms, keys, values):
        """
        Examine the stack of runs waiting to be merged, merging adjacent runs
        until the stack invariants are re-established:

        1. len[-3] > len[-2] + len[-1]
        2. len[-2] > len[-1]

        An updated MergeState is returned.

        See listsort.txt for more info.
        """
        while ms.n > 1:
            pending = ms.pending
            n = ms.n - 2
            if ((n > 0 and pending[n-1].size <= pending[n].size + pending[n+1].size) or
                (n > 1 and pending[n-2].size <= pending[n-1].size + pending[n].size)):
                if pending[n - 1].size < pending[n + 1].size:
                    # Merge smaller one first
                    n -= 1
                ms = merge_at(ms, keys, values, n)
            elif pending[n].size < pending[n + 1].size:
                ms = merge_at(ms, keys, values, n)
            else:
                break
        return ms

    @wrap
    def merge_force_collapse(ms, keys, values):
        """
        Regardless of invariants, merge all runs on the stack until only one
        remains.  This is used at the end of the mergesort.

        An updated MergeState is returned.
        """
        while ms.n > 1:
            pending = ms.pending
            n = ms.n - 2
            if n > 0:
                if pending[n - 1].size < pending[n + 1].size:
                    # Merge the smaller one first
                    n -= 1
            ms = merge_at(ms, keys, values, n)
        return ms


    @wrap
    def reverse_slice(keys, values, start, stop):
        """
        Reverse a slice, in-place.
        """
        i = start
        j = stop - 1
        while i < j:
            keys[i], keys[j] = keys[j], keys[i]
            i += 1
            j -= 1
        if has_values(keys, values):
            i = start
            j = stop - 1
            while i < j:
                values[i], values[j] = values[j], values[i]
                i += 1
                j -= 1


    @wrap
    def run_timsort_with_mergestate(ms, keys, values):
        """
        Run timsort with the mergestate.
        """
        nremaining = len(keys)
        if nremaining < 2:
            return

        # March over the array once, left to right, finding natural runs,
        # and extending short natural runs to minrun elements.
        minrun = merge_compute_minrun(nremaining)

        lo = zero
        while nremaining > 0:
            n, desc = count_run(keys, lo, lo + nremaining)
            if desc:
                # Descending run => reverse
                reverse_slice(keys, values, lo, lo + n)
            # If short, extend to min(minrun, nremaining)
            if n < minrun:
                force = min(minrun, nremaining)
                binarysort(keys, values, lo, lo + force, lo + n)
                n = force
            # Push run onto stack, and maybe merge.
            ms = merge_append(ms, MergeRun(lo, n))
            ms = merge_collapse(ms, keys, values)
            # Advance to find next run.
            lo += n
            nremaining -= n

        # All initial runs have been discovered, now finish merging.
        ms = merge_force_collapse(ms, keys, values)
        assert ms.n == 1
        assert ms.pending[0] == (0, len(keys))


    @wrap
    def run_timsort(keys):
        """
        Run timsort over the given keys.
        """
        values = keys
        run_timsort_with_mergestate(merge_init(keys), keys, values)


    @wrap
    def run_timsort_with_values(keys, values):
        """
        Run timsort over the given keys and values.
        """
        run_timsort_with_mergestate(merge_init_with_values(keys, values),
                                    keys, values)

    return TimsortImplementation(
        wrap,
        count_run, binarysort, gallop_left, gallop_right,
        merge_init, merge_append, merge_pop,
        merge_compute_minrun, merge_lo, merge_hi, merge_at,
        merge_force_collapse, merge_collapse,
        run_timsort, run_timsort_with_values)


def make_py_timsort(*args):
    return make_timsort_impl((lambda f: f), *args)

def make_jit_timsort(*args):
    from numba import jit
    return make_timsort_impl((lambda f: jit(nopython=True)(f)),
                             *args)
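

# A minimal usage sketch (illustrative): the pure-Python variant needs a
# make_temp_area allocator for the merge scratch space; for plain lists a
# helper like the one below (modelled on the one in the test suite) will do.
#
#     def make_temp_list(keys, n):
#         return [keys[0]] * n
#
#     py_list_timsort = make_py_timsort(make_temp_list)
#     keys = [3, 1, 2]
#     py_list_timsort.run_timsort(keys)   # sorts `keys` in place
#     assert keys == [1, 2, 3]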