@@ -0,0 +1,61 @@
import os
import sys
import functools
import unittest
import traceback
from fnmatch import fnmatch
from os.path import join, isfile, relpath, normpath, splitext

from .main import NumbaTestProgram, SerialSuite, make_tag_decorator
from numba.core import config


def load_testsuite(loader, dir):
    """Find tests in 'dir'."""
    try:
        suite = unittest.TestSuite()
        files = []
        for f in os.listdir(dir):
            path = join(dir, f)
            if isfile(path) and fnmatch(f, 'test_*.py'):
                files.append(f)
            elif isfile(join(path, '__init__.py')):
                suite.addTests(loader.discover(path))
        for f in files:
            # turn 'f' into a filename relative to the toplevel dir...
            f = relpath(join(dir, f), loader._top_level_dir)
            # ...and translate it to a module name.
            f = splitext(normpath(f.replace(os.path.sep, '.')))[0]
            suite.addTests(loader.loadTestsFromName(f))
        return suite
    except Exception:
        traceback.print_exc(file=sys.stderr)
        sys.exit(-1)


def run_tests(argv=None, defaultTest=None, topleveldir=None,
              xmloutput=None, verbosity=1, nomultiproc=False):
    """
    args
    ----
    - xmloutput [str or None]
        Path of XML output directory (optional)
    - verbosity [int]
        Verbosity level of tests output

    Returns the TestResult object after running the test *suite*.
    """
    if xmloutput is not None:
        import xmlrunner
        runner = xmlrunner.XMLTestRunner(output=xmloutput)
    else:
        runner = None
    prog = NumbaTestProgram(argv=argv,
                            module=None,
                            defaultTest=defaultTest,
                            topleveldir=topleveldir,
                            testRunner=runner, exit=False,
                            verbosity=verbosity,
                            nomultiproc=nomultiproc)
    return prog.result
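As a usage sketch, test packages typically hook `load_testsuite` into unittest's `load_tests` protocol from their `__init__.py`; the exact package location is illustrative:

def load_tests(loader, tests, pattern):
    from os.path import dirname
    from numba.testing import load_testsuite
    return load_testsuite(loader, dirname(__file__))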
@@ -0,0 +1,4 @@
import sys
from numba.testing import run_tests

sys.exit(0 if run_tests(sys.argv).wasSuccessful() else 1)
(6 binary files changed; contents not shown)
@@ -0,0 +1,114 @@
import json
import re
import logging


def _main(argv, **kwds):
    from numba.testing import run_tests
    # This helper function assumes the first element of argv
    # is the name of the calling program. The 'main' API function
    # is invoked in-process, and thus will synthesize that name.

    if '--log' in argv:
        logging.basicConfig(level=logging.DEBUG)
        argv.remove('--log')

    if '--failed-first' in argv:
        # Run the previously failed tests first, then the rest
        argv.remove('--failed-first')
        return _FailedFirstRunner().main(argv, kwds)
    elif '--last-failed' in argv:
        # Run only the previously failed tests
        argv.remove('--last-failed')
        return _FailedFirstRunner(last_failed=True).main(argv, kwds)
    else:
        return run_tests(argv, defaultTest='numba.tests',
                         **kwds).wasSuccessful()


def main(*argv, **kwds):
    """keyword arguments are accepted for backward compatibility only.
    See `numba.testing.run_tests()` documentation for details."""
    return _main(['<main>'] + list(argv), **kwds)


class _FailedFirstRunner(object):
    """
    Test runner to handle the failed-first (--failed-first) and
    last-failed (--last-failed) options.
    """
    cache_filename = '.runtests_lastfailed'

    def __init__(self, last_failed=False):
        self.last_failed = last_failed

    def main(self, argv, kwds):
        from numba.testing import run_tests
        prog = argv[0]
        argv = argv[1:]
        flags = [a for a in argv if a.startswith('-')]

        all_tests, failed_tests = self.find_last_failed(argv)
        # Prepare tests to run
        if failed_tests:
            ft = "There were {} previously failed tests"
            print(ft.format(len(failed_tests)))
            remaining_tests = [t for t in all_tests
                               if t not in failed_tests]
            if self.last_failed:
                tests = list(failed_tests)
            else:
                tests = failed_tests + remaining_tests
        else:
            if self.last_failed:
                tests = []
            else:
                tests = list(all_tests)

        if not tests:
            print("No tests to run")
            return True
        # Run the testsuite
        print("Running {} tests".format(len(tests)))
        print('Flags', flags)
        result = run_tests([prog] + flags + tests, **kwds)
        # Update the failed-test records only if all of the selected tests
        # were actually run.
        if len(tests) == result.testsRun:
            self.save_failed_tests(result, all_tests)
        return result.wasSuccessful()

    def save_failed_tests(self, result, all_tests):
        print("Saving failed tests to {}".format(self.cache_filename))
        cache = []
        # Find failed tests
        failed = set()
        for case in result.errors + result.failures:
            failed.add(case[0].id())
        # Build the cache, preserving the suite order of the failures
        for t in all_tests:
            if t in failed:
                cache.append(t)
        # Write the cache
        with open(self.cache_filename, 'w') as fout:
            json.dump(cache, fout)

    def find_last_failed(self, argv):
        from numba.tests.support import captured_output

        # Find all tests
        listargv = ['-l'] + [a for a in argv if not a.startswith('-')]
        with captured_output("stdout") as stream:
            main(*listargv)

        pat = re.compile(r"^(\w+\.)+\w+$")
        lines = stream.getvalue().splitlines()
        all_tests = [x for x in lines if pat.match(x) is not None]

        try:
            fobj = open(self.cache_filename)
        except OSError:
            failed_tests = []
        else:
            with fobj as fin:
                failed_tests = json.load(fin)
        return all_tests, failed_tests
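For reference, the cache is a plain JSON list of fully qualified test IDs, so it is easy to inspect; a minimal sketch (the test ID shown is hypothetical):

import json

with open('.runtests_lastfailed') as fin:
    failed = json.load(fin)
# e.g. ["numba.tests.test_usecases.TestUsecases.test_sum1d"]
print("%d tests failed in the last run" % len(failed))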
@@ -0,0 +1,26 @@
from unittest import loader, case
from os.path import isdir, isfile, join, dirname, basename


class TestLoader(loader.TestLoader):

    def __init__(self, topleveldir=None):
        super(TestLoader, self).__init__()
        self._top_level_dir = topleveldir or dirname(dirname(dirname(__file__)))

    def _find_tests(self, start_dir, pattern, namespace=False):
        # Upstream doesn't look for 'load_tests' in start_dir.
        if isdir(start_dir) and not namespace and isfile(join(start_dir, '__init__.py')):
            name = self._get_name_from_path(start_dir)
            package = self._get_module_from_name(name)
            load_tests = getattr(package, 'load_tests', None)
            tests = self.loadTestsFromModule(package)
            if load_tests is not None:
                try:
                    yield load_tests(self, tests, pattern)
                except Exception as e:
                    yield loader._make_failed_load_tests(package.__name__, e,
                                                         self.suiteClass)
        else:
            for t in super(TestLoader, self)._find_tests(start_dir, pattern):
                yield t
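Assuming this module lives at numba/testing/loader.py (consistent with the relative imports in this commit), the default `_top_level_dir` walks three directory levels up to the repository root:

from os.path import dirname

path = '/repo/numba/testing/loader.py'  # hypothetical checkout location
assert dirname(dirname(dirname(path))) == '/repo'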
@@ -0,0 +1,871 @@
import collections
import contextlib
import cProfile
import inspect
import gc
import multiprocessing
import os
import random
import sys
import time
import unittest
import warnings
import zlib
import traceback

from functools import lru_cache
from io import StringIO
from unittest import result, runner, signals, suite, loader, case

from .loader import TestLoader
from numba.core import config
from numba.misc import memoryutils

try:
    from multiprocessing import TimeoutError
except ImportError:
    from Queue import Empty as TimeoutError


def make_tag_decorator(known_tags):
    """
    Create a decorator allowing tests to be tagged with the *known_tags*.
    """

    def tag(*tags):
        """
        Tag a test method with the given tags.
        Can be used in conjunction with the --tags command-line argument
        for runtests.py.
        """
        for t in tags:
            if t not in known_tags:
                raise ValueError("unknown tag: %r" % (t,))

        def decorate(func):
            if (not callable(func) or isinstance(func, type)
                    or not func.__name__.startswith('test_')):
                raise TypeError("@tag(...) should be used on test methods")
            try:
                s = func.tags
            except AttributeError:
                s = func.tags = set()
            s.update(tags)
            return func
        return decorate

    return tag
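A brief usage sketch; the known-tag list is illustrative, though "always_test" is also consulted by `parse_slice` further down:

tag = make_tag_decorator(['long_running', 'always_test'])

class MyTests(unittest.TestCase):

    @tag('long_running')
    def test_expensive_case(self):
        pass

# @tag('nightly') would raise ValueError("unknown tag: 'nightly'")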
# Chances are the next queried class is the same as the previous one; locally,
# a cache of 128 entries seems to be fastest.
# The current number of test classes can be found with:
# $ ./runtests.py -l|sed -e 's/\(.*\)\..*/\1/'|grep ^numba|sort|uniq|wc -l
# as of writing it's 658.
@lru_cache(maxsize=128)
def _get_mtime(cls):
    """
    Gets the mtime of the file in which a test class is defined.
    """
    return str(os.path.getmtime(inspect.getfile(cls)))


def cuda_sensitive_mtime(x):
    """
    Return a key for sorting tests based on mtime and test name. For CUDA
    tests, interleaving tests from different classes is dangerous as the CUDA
    context might get reset unexpectedly between methods of a class, so for
    CUDA tests the key prioritises the test module and class ahead of the
    mtime.
    """
    cls = x.__class__
    key = _get_mtime(cls) + str(x)

    from numba.cuda.testing import CUDATestCase
    if CUDATestCase in cls.mro():
        key = "%s.%s %s" % (str(cls.__module__), str(cls.__name__), key)

    return key
def parse_slice(useslice):
    """Parses the argument string "useslice" as a shard index and count, and
    returns a function that filters on those arguments, i.e. input
    useslice="1:3" leads to output something like `lambda x: zlib.crc32(x) % 3
    == 1`.
    """
    if callable(useslice):
        return useslice
    if not useslice:
        return lambda x: True
    try:
        (index, count) = useslice.split(":")
        index = int(index)
        count = int(count)
    except Exception:
        msg = (
            "Expected arguments shard index and count to follow "
            "option `-j i:t`, where i is the shard number and t "
            "is the total number of shards, found '%s'" % useslice)
        raise ValueError(msg)
    if count == 0:
        return lambda x: True
    elif count < 0 or index < 0 or index >= count:
        raise ValueError("Sharding out of range")
    else:
        def decide(test):
            func = getattr(test, test._testMethodName)
            # Tests tagged "always_test" run in every shard
            if "always_test" in getattr(func, 'tags', {}):
                return True
            return abs(zlib.crc32(test.id().encode('utf-8'))) % count == index
        return decide
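For example, sharding a flattened suite three ways (a sketch; `suite_tests` stands for any list of TestCase instances):

decide = parse_slice("1:3")  # shard 1 of 3
shard = [t for t in suite_tests if decide(t)]
# decide() hashes test.id() with crc32, so the three shards are stable and
# disjoint, except that "always_test"-tagged tests land in every shard.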
class TestLister(object):
    """Simply list available tests rather than running them."""
    def __init__(self, useslice):
        self.useslice = parse_slice(useslice)

    def run(self, test):
        result = runner.TextTestResult(sys.stderr, descriptions=True,
                                       verbosity=1)
        self._test_list = _flatten_suite(test)
        masked_list = list(filter(self.useslice, self._test_list))
        self._test_list.sort(key=cuda_sensitive_mtime)
        for t in masked_list:
            print(t.id())
        print('%d tests found. %d selected' % (len(self._test_list),
                                               len(masked_list)))
        return result
class SerialSuite(unittest.TestSuite):
    """A simple marker to make sure tests in this suite are run serially.

    Note: As the suite is going through internals of unittest,
          it may get unpacked and stuffed into a plain TestSuite.
          We need to set an attribute on the TestCase objects to
          remember they should not be run in parallel.
    """

    def __init__(self, tests=()):
        super(SerialSuite, self).__init__(tests)
        self.resource_infos = []

    def addTest(self, test):
        if not isinstance(test, unittest.TestCase):
            # It's a sub-suite, recurse
            for t in test:
                self.addTest(t)
        else:
            # It's a test case, mark it serial
            test._numba_parallel_test_ = False
            super(SerialSuite, self).addTest(test)

    def run(self, result):
        # Run each test with memory tracking
        for test in self:
            if result.shouldStop:
                break
            memtrack = memoryutils.MemoryTracker(test.id())
            with memtrack.monitor():
                test(result)
            self.resource_infos.append(memtrack.get_summary())
        return result


class BasicTestRunner(runner.TextTestRunner):
    def __init__(self, useslice, **kwargs):
        runner.TextTestRunner.__init__(self, **kwargs)
        self.useslice = parse_slice(useslice)

    def run(self, test):
        run = list(filter(self.useslice, _flatten_suite(test)))
        run.sort(key=cuda_sensitive_mtime)
        wrapped = unittest.TestSuite(run)
        return super(BasicTestRunner, self).run(wrapped)
# "unittest.main" is really the TestProgram class!
|
||||
# (defined in a module named itself "unittest.main"...)
|
||||
|
||||
class NumbaTestProgram(unittest.main):
|
||||
"""
|
||||
A TestProgram subclass adding the following options:
|
||||
* a -R option to enable reference leak detection
|
||||
* a --profile option to enable profiling of the test run
|
||||
* a -m option for parallel execution
|
||||
* a -l option to (only) list tests
|
||||
|
||||
Currently the options are only added in 3.4+.
|
||||
"""
|
||||
|
||||
refleak = False
|
||||
profile = False
|
||||
multiprocess = False
|
||||
useslice = None
|
||||
list = False
|
||||
tags = None
|
||||
exclude_tags = None
|
||||
random_select = None
|
||||
random_seed = 42
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
topleveldir = kwargs.pop('topleveldir', None)
|
||||
kwargs['testLoader'] = TestLoader(topleveldir)
|
||||
|
||||
# HACK to force unittest not to change warning display options
|
||||
# (so that NumbaWarnings don't appear all over the place)
|
||||
sys.warnoptions.append(':x')
|
||||
self.nomultiproc = kwargs.pop('nomultiproc', False)
|
||||
super(NumbaTestProgram, self).__init__(*args, **kwargs)
|
||||
|
||||
def _getParentArgParser(self):
|
||||
# NOTE: this hook only exists on Python 3.4+. The options won't be
|
||||
# added in earlier versions (which use optparse - 3.3 - or getopt()
|
||||
# - 2.x).
|
||||
parser = super(NumbaTestProgram, self)._getParentArgParser()
|
||||
if self.testRunner is None:
|
||||
parser.add_argument('-R', '--refleak', dest='refleak',
|
||||
action='store_true',
|
||||
help='Detect reference / memory leaks')
|
||||
parser.add_argument('-m', '--multiprocess', dest='multiprocess',
|
||||
nargs='?',
|
||||
type=int,
|
||||
const=multiprocessing.cpu_count(),
|
||||
help='Parallelize tests')
|
||||
parser.add_argument('-l', '--list', dest='list',
|
||||
action='store_true',
|
||||
help='List tests without running them')
|
||||
parser.add_argument('--tags', dest='tags', type=str,
|
||||
help='Comma-separated list of tags to select '
|
||||
'a subset of the test suite')
|
||||
parser.add_argument('--exclude-tags', dest='exclude_tags', type=str,
|
||||
help='Comma-separated list of tags to de-select '
|
||||
'a subset of the test suite')
|
||||
parser.add_argument('--random', dest='random_select', type=float,
|
||||
help='Random proportion of tests to select')
|
||||
parser.add_argument('--profile', dest='profile',
|
||||
action='store_true',
|
||||
help='Profile the test run')
|
||||
parser.add_argument('-j', '--slice', dest='useslice', nargs='?',
|
||||
type=str, const="None",
|
||||
help='Shard the test sequence')
|
||||
|
||||
def git_diff_str(x):
|
||||
if x != 'ancestor':
|
||||
raise ValueError("invalid option for --gitdiff")
|
||||
return x
|
||||
|
||||
parser.add_argument('-g', '--gitdiff', dest='gitdiff', type=git_diff_str,
|
||||
default=False, nargs='?',
|
||||
help=('Run tests from changes made against '
|
||||
'origin/release0.65 as identified by `git diff`. '
|
||||
'If set to "ancestor", the diff compares '
|
||||
'against the common ancestor.'))
|
||||
return parser
|
||||
|
||||
def _handle_tags(self, argv, tagstr):
|
||||
found = None
|
||||
for x in argv:
|
||||
if tagstr in x:
|
||||
if found is None:
|
||||
found = x
|
||||
else:
|
||||
raise ValueError("argument %s supplied repeatedly" % tagstr)
|
||||
|
||||
if found is not None:
|
||||
posn = argv.index(found)
|
||||
try:
|
||||
if found == tagstr: # --tagstr <arg>
|
||||
tag_args = argv[posn + 1].strip()
|
||||
argv.remove(tag_args)
|
||||
else: # --tagstr=<arg>
|
||||
if '=' in found:
|
||||
tag_args = found.split('=')[1].strip()
|
||||
else:
|
||||
raise AssertionError('unreachable')
|
||||
except IndexError:
|
||||
# at end of arg list, raise
|
||||
msg = "%s requires at least one tag to be specified"
|
||||
raise ValueError(msg % tagstr)
|
||||
# see if next arg is "end options" or some other flag
|
||||
if tag_args.startswith('-'):
|
||||
raise ValueError("tag starts with '-', probably a syntax error")
|
||||
# see if tag is something like "=<tagname>" which is likely a syntax
|
||||
# error of form `--tags =<tagname>`, note the space prior to `=`.
|
||||
if '=' in tag_args:
|
||||
msg = "%s argument contains '=', probably a syntax error"
|
||||
raise ValueError(msg % tagstr)
|
||||
attr = tagstr[2:].replace('-', '_')
|
||||
setattr(self, attr, tag_args)
|
||||
argv.remove(found)
|
||||
|
||||
|
||||
def parseArgs(self, argv):
|
||||
if '-l' in argv:
|
||||
argv.remove('-l')
|
||||
self.list = True
|
||||
|
||||
super(NumbaTestProgram, self).parseArgs(argv)
|
||||
|
||||
# If at this point self.test doesn't exist, it is because
|
||||
# no test ID was given in argv. Use the default instead.
|
||||
if not hasattr(self, 'test') or not self.test.countTestCases():
|
||||
self.testNames = (self.defaultTest,)
|
||||
self.createTests()
|
||||
|
||||
if self.tags:
|
||||
tags = [s.strip() for s in self.tags.split(',')]
|
||||
self.test = _choose_tagged_tests(self.test, tags, mode='include')
|
||||
|
||||
if self.exclude_tags:
|
||||
tags = [s.strip() for s in self.exclude_tags.split(',')]
|
||||
self.test = _choose_tagged_tests(self.test, tags, mode='exclude')
|
||||
|
||||
if self.random_select:
|
||||
self.test = _choose_random_tests(self.test, self.random_select,
|
||||
self.random_seed)
|
||||
|
||||
if self.gitdiff is not False:
|
||||
self.test = _choose_gitdiff_tests(
|
||||
self.test,
|
||||
use_common_ancestor=(self.gitdiff == 'ancestor'),
|
||||
)
|
||||
|
||||
if self.verbosity <= 0:
|
||||
# We aren't interested in informational messages / warnings when
|
||||
# running with '-q'.
|
||||
self.buffer = True
|
||||
|
||||
def _do_discovery(self, argv, Loader=None):
|
||||
# Disable unittest's implicit test discovery when parsing
|
||||
# CLI arguments, as it can select other tests than Numba's
|
||||
# (e.g. some test_xxx module that may happen to be directly
|
||||
# reachable from sys.path)
|
||||
return
|
||||
|
||||
def runTests(self):
|
||||
if self.refleak:
|
||||
self.testRunner = RefleakTestRunner
|
||||
|
||||
if not hasattr(sys, "gettotalrefcount"):
|
||||
warnings.warn("detecting reference leaks requires a debug build "
|
||||
"of Python, only memory leaks will be detected")
|
||||
|
||||
elif self.list:
|
||||
self.testRunner = TestLister(self.useslice)
|
||||
|
||||
elif self.testRunner is None:
|
||||
self.testRunner = BasicTestRunner(self.useslice,
|
||||
verbosity=self.verbosity,
|
||||
failfast=self.failfast,
|
||||
buffer=self.buffer)
|
||||
|
||||
if self.multiprocess and not self.nomultiproc:
|
||||
if self.multiprocess < 1:
|
||||
msg = ("Value specified for the number of processes to use in "
|
||||
"running the suite must be > 0")
|
||||
raise ValueError(msg)
|
||||
self.testRunner = ParallelTestRunner(runner.TextTestRunner,
|
||||
self.multiprocess,
|
||||
self.useslice,
|
||||
verbosity=self.verbosity,
|
||||
failfast=self.failfast,
|
||||
buffer=self.buffer)
|
||||
|
||||
def run_tests_real():
|
||||
super(NumbaTestProgram, self).runTests()
|
||||
|
||||
if self.profile:
|
||||
filename = os.path.splitext(
|
||||
os.path.basename(sys.modules['__main__'].__file__)
|
||||
)[0] + '.prof'
|
||||
p = cProfile.Profile(timer=time.perf_counter) # 3.3+
|
||||
p.enable()
|
||||
try:
|
||||
p.runcall(run_tests_real)
|
||||
finally:
|
||||
p.disable()
|
||||
print("Writing test profile data into %r" % (filename,))
|
||||
p.dump_stats(filename)
|
||||
else:
|
||||
run_tests_real()
|
||||
|
||||
|
||||
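Putting the new options together, an illustrative invocation in the style of the shell one-liner quoted above (flag values are arbitrary):

# Run shard 1 of 3 in 4 worker processes, profiling the run:
$ ./runtests.py -m 4 -j 1:3 --profile numba.tests
# Only list what would be run:
$ ./runtests.py -l numba.tests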
# These are tests which are generated and injected into the test suite; what
# gets injected depends on features of the test environment, e.g. TBB
# presence. It's important for doing the CI "slice tests" that these are run
# at the end. See the notes in `_flatten_suite` for why. Simple substring
# matching is used to determine a match.
_GENERATED = (
    "numba.cuda.tests.cudapy.test_libdevice.TestLibdeviceCompilation",
    "numba.tests.test_num_threads",
    "numba.tests.test_parallel_backend",
    "numba.tests.test_svml",
    "numba.tests.test_ufuncs",
)


def _flatten_suite_inner(test):
    """
    Workhorse for _flatten_suite.
    """
    tests = []
    if isinstance(test, (unittest.TestSuite, list, tuple)):
        for x in test:
            tests.extend(_flatten_suite_inner(x))
    else:
        tests.append(test)
    return tests


def _flatten_suite(test):
    """
    Expand a nested suite into a list of test cases.
    """
    tests = _flatten_suite_inner(test)
    # Strip out generated tests and stick them at the end; this is to make
    # sure that tests appear in a consistent order regardless of the features
    # available, so that a slice through the test suite, e.g. (1::N), is
    # likely to be consistent up to the point of the generated tests, which
    # rely on specific features.
    generated = set()
    for t in tests:
        for g in _GENERATED:
            if g in str(t):
                generated.add(t)
    normal = set(tests) - generated

    def key(x):
        return x.__module__, type(x).__name__, x._testMethodName

    tests = sorted(normal, key=key)
    tests.extend(sorted(list(generated), key=key))
    return tests
def _choose_gitdiff_tests(tests, *, use_common_ancestor=False):
    try:
        from git import Repo
    except ImportError:
        raise ValueError("gitpython needed for git functionality")
    repo = Repo('.')
    path = os.path.join('numba', 'tests')
    if use_common_ancestor:
        print("Git diff by common ancestor")
        # Three-dot form: diff against the merge base of the two commits
        target = 'origin/release0.65...HEAD'
    else:
        # Two-dot form: diff against the branch tip itself
        target = 'origin/release0.65..HEAD'
    gdiff_paths = repo.git.diff(target, path, name_only=True).split()
    # normalise the paths as they are unix style from repo.git.diff
    gdiff_paths = [os.path.normpath(x) for x in gdiff_paths]
    selected = []
    gdiff_paths = [os.path.join(repo.working_dir, x) for x in gdiff_paths]
    for test in _flatten_suite(tests):
        assert isinstance(test, unittest.TestCase)
        fname = inspect.getsourcefile(test.__class__)
        if fname in gdiff_paths:
            selected.append(test)
    print("Git diff identified %s tests" % len(selected))
    return unittest.TestSuite(selected)
def _choose_tagged_tests(tests, tags, mode='include'):
    """
    Select tests that are tagged/not tagged with at least one of the given
    tags. Set mode to 'include' to include the tests with tags, or 'exclude'
    to exclude the tests with the tags.
    """
    selected = []
    tags = set(tags)
    for test in _flatten_suite(tests):
        assert isinstance(test, unittest.TestCase)
        func = getattr(test, test._testMethodName)
        try:
            # Look up the method's underlying function (Python 2)
            func = func.im_func
        except AttributeError:
            pass

        found_tags = getattr(func, 'tags', None)
        # only include the test if the tags *are* present
        if mode == 'include':
            if found_tags is not None and found_tags & tags:
                selected.append(test)
        elif mode == 'exclude':
            # only include the test if the tags *are not* present
            if found_tags is None or not (found_tags & tags):
                selected.append(test)
        else:
            raise ValueError("Invalid 'mode' supplied: %s." % mode)
    return unittest.TestSuite(selected)
def _choose_random_tests(tests, ratio, seed):
    """
    Choose a given proportion of tests at random.
    """
    rnd = random.Random()
    rnd.seed(seed)
    if isinstance(tests, unittest.TestSuite):
        tests = _flatten_suite(tests)
    tests = rnd.sample(tests, int(len(tests) * ratio))
    tests = sorted(tests, key=lambda case: case.id())
    return unittest.TestSuite(tests)
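For instance, selecting a reproducible 10% sample (a sketch; `suite` is any TestSuite):

subset = _choose_random_tests(suite, ratio=0.1, seed=42)
# The same seed always yields the same sample, so repeated CI runs
# exercise an identical subset.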
# The reference leak detection code is liberally taken and adapted from
# Python's own Lib/test/regrtest.py.

def _refleak_cleanup():
    # Collect cyclic trash and read memory statistics immediately after.
    func1 = sys.getallocatedblocks
    try:
        func2 = sys.gettotalrefcount
    except AttributeError:
        # Not a debug build: use a constant so the refcount deltas are zero
        func2 = lambda: 42

    # Flush standard output, so that buffered data is sent to the OS and
    # associated Python objects are reclaimed.
    for stream in (sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__):
        if stream is not None:
            stream.flush()

    sys._clear_type_cache()
    # This also clears the various internal CPython freelists.
    gc.collect()
    return func1(), func2()


class ReferenceLeakError(RuntimeError):
    pass


class IntPool(collections.defaultdict):

    def __missing__(self, key):
        return key
class RefleakTestResult(runner.TextTestResult):

    warmup = 3
    repetitions = 6

    def _huntLeaks(self, test):
        self.stream.flush()

        repcount = self.repetitions
        nwarmup = self.warmup
        rc_deltas = [0] * (repcount - nwarmup)
        alloc_deltas = [0] * (repcount - nwarmup)
        # Preallocate ints likely to be stored in rc_deltas and alloc_deltas,
        # to make sys.getallocatedblocks() less flaky.
        _int_pool = IntPool()
        for i in range(-200, 200):
            _int_pool[i]

        for i in range(repcount):
            # Use a pristine, silent result object to avoid recursion
            res = result.TestResult()
            test.run(res)
            # Poorly-written tests may fail when run several times.
            # In this case, abort the refleak run and report the failure.
            if not res.wasSuccessful():
                self.failures.extend(res.failures)
                self.errors.extend(res.errors)
                raise AssertionError
            del res
            alloc_after, rc_after = _refleak_cleanup()
            if i >= nwarmup:
                rc_deltas[i - nwarmup] = _int_pool[rc_after - rc_before]
                alloc_deltas[i - nwarmup] = _int_pool[alloc_after -
                                                      alloc_before]
            alloc_before, rc_before = alloc_after, rc_after
        return rc_deltas, alloc_deltas

    def addSuccess(self, test):
        try:
            rc_deltas, alloc_deltas = self._huntLeaks(test)
        except AssertionError:
            # Test failed when repeated
            assert not self.wasSuccessful()
            return

        # These checkers return False on success, True on failure
        def check_rc_deltas(deltas):
            return any(deltas)

        def check_alloc_deltas(deltas):
            # At least 1/3rd of 0s
            if 3 * deltas.count(0) < len(deltas):
                return True
            # Nothing else than 1s, 0s and -1s
            if not set(deltas) <= set((1, 0, -1)):
                return True
            return False

        failed = False

        for deltas, item_name, checker in [
                (rc_deltas, 'references', check_rc_deltas),
                (alloc_deltas, 'memory blocks', check_alloc_deltas)]:
            if checker(deltas):
                msg = '%s leaked %s %s, sum=%s' % (
                    test, deltas, item_name, sum(deltas))
                failed = True
                try:
                    raise ReferenceLeakError(msg)
                except Exception:
                    exc_info = sys.exc_info()
                if self.showAll:
                    self.stream.write("%s = %r " % (item_name, deltas))
                self.addFailure(test, exc_info)

        if not failed:
            super(RefleakTestResult, self).addSuccess(test)


class RefleakTestRunner(runner.TextTestRunner):
    resultclass = RefleakTestResult
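To make the pass/fail criteria concrete, here is how the inner checkers from `addSuccess` judge some hypothetical measured deltas (one entry per post-warmup repetition):

# check_rc_deltas([0, 0, 0])     -> False: no references leaked
# check_rc_deltas([2, 2, 2])     -> True:  the refcount grew on every run
# check_alloc_deltas([0, 1, 0])  -> False: >= 1/3 zeros, values in {-1, 0, 1}
# check_alloc_deltas([1, 1, 1])  -> True:  too few zeros, likely a block leak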
class ParallelTestResult(runner.TextTestResult):
    """
    A TestResult able to inject results from other results.
    """

    def add_results(self, result):
        """
        Add the results from the other *result* to this result.
        """
        self.stream.write(result.stream.getvalue())
        self.stream.flush()
        self.testsRun += result.testsRun
        self.failures.extend(result.failures)
        self.errors.extend(result.errors)
        self.skipped.extend(result.skipped)
        self.expectedFailures.extend(result.expectedFailures)
        self.unexpectedSuccesses.extend(result.unexpectedSuccesses)


class _MinimalResult(object):
    """
    A minimal, picklable TestResult-alike object.
    """

    __slots__ = (
        'failures', 'errors', 'skipped', 'expectedFailures',
        'unexpectedSuccesses', 'stream', 'shouldStop', 'testsRun',
        'test_id', 'resource_info')

    def fixup_case(self, case):
        """
        Remove any unpicklable attributes from TestCase instance *case*.
        """
        # Python 3.3 doesn't reset this one.
        case._outcomeForDoCleanups = None

    def __init__(self, original_result, test_id=None, resource_info=None):
        for attr in self.__slots__:
            setattr(self, attr, getattr(original_result, attr, None))
        for case, _ in self.expectedFailures:
            self.fixup_case(case)
        for case, _ in self.errors:
            self.fixup_case(case)
        for case, _ in self.failures:
            self.fixup_case(case)
        self.test_id = test_id
        self.resource_info = resource_info


class _FakeStringIO(object):
    """
    A trivial picklable StringIO-alike for Python 2.
    """

    def __init__(self, value):
        self._value = value

    def getvalue(self):
        return self._value
class _MinimalRunner(object):
    """
    A minimal picklable object able to instantiate a runner in a
    child process and run a test case with it.
    """

    def __init__(self, runner_cls, runner_args):
        self.runner_cls = runner_cls
        self.runner_args = runner_args

    # Python 2 doesn't know how to pickle instance methods, so we use
    # __call__ instead.

    def __call__(self, test):
        # Executed in child process
        kwargs = self.runner_args
        # Force recording of output in a buffer (it will be printed out
        # by the parent).
        kwargs['stream'] = StringIO()
        runner = self.runner_cls(**kwargs)
        result = runner._makeResult()
        # Avoid child tracebacks when Ctrl-C is pressed.
        signals.installHandler()
        signals.registerResult(result)
        result.failfast = runner.failfast
        result.buffer = runner.buffer
        # Create a per-process memory tracker to avoid global state issues
        memtrack = memoryutils.MemoryTracker(test.id())
        with memtrack.monitor():
            with self.cleanup_object(test):
                test(result)
        # HACK as cStringIO.StringIO isn't picklable in 2.x
        result.stream = _FakeStringIO(result.stream.getvalue())
        return _MinimalResult(result, test.id(),
                              resource_info=memtrack.get_summary())

    @contextlib.contextmanager
    def cleanup_object(self, test):
        """
        A context manager which cleans up unwanted attributes on a test case
        (or any other object).
        """
        vanilla_attrs = set(test.__dict__)
        try:
            yield test
        finally:
            spurious_attrs = set(test.__dict__) - vanilla_attrs
            for name in spurious_attrs:
                del test.__dict__[name]
def _split_nonparallel_tests(test, sliced):
    """
    Split a test suite into parallel and serial tests.
    """
    ptests = []
    stests = []

    flat = [*filter(sliced, _flatten_suite(test))]

    def is_parallelizable_test_case(test):
        # Guard for the fake test case created by unittest when test
        # discovery fails, as it isn't picklable (e.g. "LoadTestsFailure")
        method_name = test._testMethodName
        method = getattr(test, method_name)
        if method.__name__ != method_name and method.__name__ == "testFailure":
            return False
        # Was parallel execution explicitly disabled?
        return getattr(test, "_numba_parallel_test_", True)

    for t in flat:
        if is_parallelizable_test_case(t):
            ptests.append(t)
        else:
            stests.append(t)

    return ptests, stests


# A test can't run longer than 20 minutes
_TIMEOUT = 1200
class ParallelTestRunner(runner.TextTestRunner):
    """
    A test runner which delegates the actual running to a pool of child
    processes.
    """

    resultclass = ParallelTestResult
    timeout = _TIMEOUT

    def __init__(self, runner_cls, nprocs, useslice, **kwargs):
        runner.TextTestRunner.__init__(self, **kwargs)
        self.runner_cls = runner_cls
        self.nprocs = nprocs
        self.useslice = parse_slice(useslice)
        self.runner_args = kwargs
        self.resource_infos = []

    def _run_inner(self, result):
        # We hijack TextTestRunner.run()'s inner logic by passing this
        # method as if it were a test case.
        child_runner = _MinimalRunner(self.runner_cls, self.runner_args)

        # Split the tests into chunks and recycle the worker processes
        # between chunks to tame memory usage.
        chunk_size = 100
        split_tests = [self._ptests[i:i + chunk_size]
                       for i in range(0, len(self._ptests), chunk_size)]

        spawnctx = multiprocessing.get_context("spawn")
        try:
            for tests in split_tests:
                pool = spawnctx.Pool(self.nprocs)
                try:
                    self._run_parallel_tests(result, pool, child_runner,
                                             tests)
                except Exception:
                    # On exception, kill still active workers immediately
                    pool.terminate()
                    # Make sure the exception is reported and not ignored
                    raise
                else:
                    # Close the pool cleanly unless asked to early out
                    if result.shouldStop:
                        pool.terminate()
                        break
                    else:
                        pool.close()
                finally:
                    # Always join the pool (this is necessary for coverage.py)
                    pool.join()
            if not result.shouldStop:
                # Run serial tests with memory tracking
                stests = SerialSuite(self._stests)
                stests.run(result)
                # Add serial test resource infos to the main collection
                self.resource_infos.extend(stests.resource_infos)
            return result
        finally:
            # Always display the resource infos
            if memoryutils.IS_SUPPORTED:
                try:
                    print("=== Resource Infos ===")
                    for ri in self.resource_infos:
                        print(ri)
                except Exception:
                    print("ERROR: Ignored exception in printing resource "
                          "infos")
                    traceback.print_exc()
                finally:
                    print("=== End Resource Infos ===")

    def _run_parallel_tests(self, result, pool, child_runner, tests):
        remaining_ids = set(t.id() for t in tests)
        tests.sort(key=cuda_sensitive_mtime)
        it = pool.imap_unordered(child_runner, tests)
        while True:
            try:
                child_result = it.__next__(self.timeout)
            except StopIteration:
                return
            except TimeoutError as e:
                # Diagnose the names of unfinished tests
                msg = ("Tests didn't finish before timeout (or crashed):\n%s"
                       % "".join("- %r\n" % tid
                                 for tid in sorted(remaining_ids))
                       )
                e.args = (msg,) + e.args[1:]
                raise e
            else:
                result.add_results(child_result)
                self.resource_infos.append(child_result.resource_info)
                remaining_ids.discard(child_result.test_id)
                if child_result.shouldStop:
                    result.shouldStop = True
                    return

    def run(self, test):
        self._ptests, self._stests = _split_nonparallel_tests(test,
                                                              self.useslice)
        print("Parallel: %s. Serial: %s" % (len(self._ptests),
                                            len(self._stests)))
        # This will call self._run_inner() on the created result object,
        # and print out the detailed test results at the end.
        return super(ParallelTestRunner, self).run(self._run_inner)
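The `run()` hijack works because `TextTestRunner.run(test)` only requires *test* to be callable with the freshly created result object; a minimal standalone sketch of the same trick (class name hypothetical):

import unittest

class HijackRunner(unittest.TextTestRunner):

    def _run_inner(self, result):
        result.testsRun += 1  # stand-in for real work
        return result

    def run(self):
        # A bound method is passed where a suite is expected;
        # the base class simply invokes it as test(result).
        return super().run(self._run_inner)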
@@ -0,0 +1,171 @@
from unittest import TestCase

from ipykernel.tests import utils
from nbformat.converter import convert
from nbformat.reader import reads

import re
import json
from copy import copy
import unittest

try:
    # py3
    from queue import Empty

    def isstr(s):
        return isinstance(s, str)
except ImportError:
    # py2
    from Queue import Empty

    def isstr(s):
        return isinstance(s, basestring)  # noqa


class NotebookTest(TestCase):
    """Validate a notebook. All code cells are executed in order. The output
    is either checked for errors (if no reference output is present), or is
    compared against expected output.

    Useful references:
    http://nbformat.readthedocs.org/en/latest/format_description.html
    http://jupyter-client.readthedocs.org/en/latest/messaging.html
    """

    IGNORE_TYPES = ["execute_request", "execute_input", "status", "pyin"]
    STRIP_KEYS = ["execution_count", "traceback", "prompt_number", "source"]
    NBFORMAT_VERSION = 4

    def _test_notebook(self, notebook, test):
        with open(notebook) as f:
            nb = convert(reads(f.read()), self.NBFORMAT_VERSION)
        _, kernel = utils.start_new_kernel()
        for i, c in enumerate([c for c in nb.cells if c.cell_type == 'code']):
            self._test_notebook_cell(self.sanitize_cell(c), i, kernel, test)

    def _test_notebook_cell(self, cell, i, kernel, test):
        if hasattr(cell, 'source'):  # nbformat 4.0 and later
            code = cell.source
        else:
            code = cell.input
        iopub = kernel.iopub_channel
        kernel.execute(code)
        outputs = []
        msg = None
        no_error = True
        first_error = -1
        error_msg = ''
        while self.should_continue(msg):
            try:
                msg = iopub.get_msg(block=True, timeout=1)
            except Empty:
                continue
            if msg['msg_type'] not in self.IGNORE_TYPES:
                if msg['msg_type'] == 'error':
                    error_msg = (' ' + msg['content']['ename'] + '\n ' +
                                 msg['content']['evalue'])
                    no_error = False
                    if first_error == -1:
                        first_error = i
                i = len(outputs)
                expected = cell.outputs[i] if i < len(cell.outputs) else []
                o = self.transform_message(msg, expected)
                outputs.append(o)

        if test == 'check_error':
            self.assertTrue(no_error,
                            'Executing cell %d resulted in an error:\n%s'
                            % (first_error, error_msg))
        else:
            # Compare computed output against stored output.
            # TODO: This doesn't work right now as the generated output is
            # too diverse to be verifiable.
            scrub = lambda x: self.dump_canonical(list(self.scrub_outputs(x)))
            scrubbed = scrub(outputs)
            expected = scrub(cell.outputs)
            #print('output=%s' % outputs)
            #print('expected=%s' % expected)
            #self.assertEqual(scrubbed, expected,
            #                 "\n{}\n\n{}".format(scrubbed, expected))

    def dump_canonical(self, obj):
        return json.dumps(obj, indent=2, sort_keys=True)

    def scrub_outputs(self, outputs):
        """
        Remove all scrubs from output data and text.
        """
        for output in outputs:
            out = copy(output)

            for scrub, sub in []:  # self.scrubs.items():
                def _scrubLines(lines):
                    if isstr(lines):
                        return re.sub(scrub, sub, lines)
                    else:
                        return [re.sub(scrub, sub, line) for line in lines]

                if "text" in out:
                    out["text"] = _scrubLines(out["text"])

                if "data" in out:
                    if isinstance(out["data"], dict):
                        for mime, data in out["data"].items():
                            out["data"][mime] = _scrubLines(data)
                    else:
                        out["data"] = _scrubLines(out["data"])
            yield out

    def strip_keys(self, d):
        """
        Remove keys from STRIP_KEYS to ensure comparability.
        """
        for key in self.STRIP_KEYS:
            d.pop(key, None)
        return d

    def sanitize_cell(self, cell):
        """
        Remove non-reproducible things.
        """
        for output in cell.outputs:
            self.strip_keys(output)
        return cell

    def transform_message(self, msg, expected):
        """
        Transform a message into something like the notebook.
        """
        SWAP_KEYS = {
            "output_type": {
                "pyout": "execute_result",
                "pyerr": "error"
            }
        }

        output = {
            u"output_type": msg["msg_type"]
        }
        output.update(msg["content"])

        output = self.strip_keys(output)
        for key, swaps in SWAP_KEYS.items():
            if key in output and output[key] in swaps:
                output[key] = swaps[output[key]]

        if "data" in output and "data" not in expected:
            output["text"] = output["data"]
            del output["data"]

        return output

    def should_continue(self, msg):
        """
        Determine whether the current message is the last for this cell.
        """
        if msg is None:
            return True

        return not (msg["msg_type"] == "status" and
                    msg["content"]["execution_state"] == "idle")
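To show how the harness is meant to be driven, a hypothetical concrete test (the notebook path and class name are made up):

class ExampleNotebookTest(NotebookTest):

    def test_example_notebook(self):
        # 'check_error' only asserts that no cell raised; any other value
        # falls through to the (currently disabled) output comparison.
        self._test_notebook('examples/notebooks/example.ipynb', 'check_error')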