Videre
This commit is contained in:
@@ -0,0 +1,252 @@
|
||||
import re
|
||||
from io import StringIO
|
||||
|
||||
import numba
|
||||
from numba.core import types
|
||||
from numba import jit, njit
|
||||
from numba.tests.support import override_config, TestCase
|
||||
import unittest
|
||||
|
||||
try:
|
||||
import jinja2
|
||||
except ImportError:
|
||||
jinja2 = None
|
||||
|
||||
try:
|
||||
import pygments
|
||||
except ImportError:
|
||||
pygments = None
|
||||
|
||||
|
||||
@unittest.skipIf(jinja2 is None, "please install the 'jinja2' package")
|
||||
class TestAnnotation(TestCase):
|
||||
|
||||
@TestCase.run_test_in_subprocess # annotations compound per module
|
||||
def test_exercise_code_path(self):
|
||||
"""
|
||||
Ensures template.html is available
|
||||
"""
|
||||
|
||||
def foo(n, a):
|
||||
s = a
|
||||
for i in range(n):
|
||||
s += i
|
||||
return s
|
||||
|
||||
cfunc = njit((types.int32, types.int32))(foo)
|
||||
cres = cfunc.overloads[cfunc.signatures[0]]
|
||||
ta = cres.type_annotation
|
||||
|
||||
buf = StringIO()
|
||||
ta.html_annotate(buf)
|
||||
output = buf.getvalue()
|
||||
buf.close()
|
||||
self.assertIn("foo", output)
|
||||
|
||||
@TestCase.run_test_in_subprocess # annotations compound per module
|
||||
def test_exercise_code_path_with_lifted_loop(self):
|
||||
"""
|
||||
Ensures that lifted loops are handled correctly in obj mode
|
||||
"""
|
||||
# the functions to jit
|
||||
def bar(x):
|
||||
return x
|
||||
|
||||
def foo(x):
|
||||
h = 0.
|
||||
for i in range(x): # py 38 needs two loops for one to lift?!
|
||||
h = h + i
|
||||
for k in range(x):
|
||||
h = h + k
|
||||
if x:
|
||||
h = h - bar(x)
|
||||
return h
|
||||
|
||||
# compile into an isolated context
|
||||
cfunc = jit((types.intp,), forceobj=True, looplift=True)(foo)
|
||||
cres = cfunc.overloads[cfunc.signatures[0]]
|
||||
|
||||
ta = cres.type_annotation
|
||||
|
||||
buf = StringIO()
|
||||
ta.html_annotate(buf)
|
||||
output = buf.getvalue()
|
||||
buf.close()
|
||||
self.assertIn("bar", output)
|
||||
self.assertIn("foo", output)
|
||||
self.assertIn("LiftedLoop", output)
|
||||
|
||||
@TestCase.run_test_in_subprocess # annotations compound per module
|
||||
def test_html_output_with_lifted_loop(self):
|
||||
"""
|
||||
Test some format and behavior of the html annotation with lifted loop
|
||||
"""
|
||||
@numba.jit(forceobj=True)
|
||||
def udt(x):
|
||||
object() # to force object mode
|
||||
z = 0
|
||||
for i in range(x): # this line is tagged
|
||||
z += i
|
||||
return z
|
||||
|
||||
# Regex pattern to check for the "lifted_tag" in the line of the loop
|
||||
re_lifted_tag = re.compile(
|
||||
r'<td class="lifted_tag">\s*'
|
||||
r'\s*<details>'
|
||||
r'\s*<summary>'
|
||||
r'\s*<code>'
|
||||
r'\s*[0-9]+:'
|
||||
r'\s*[ ]+for i in range\(x\): # this line is tagged\s*',
|
||||
re.MULTILINE)
|
||||
|
||||
# Compile int64 version
|
||||
sig_i64 = (types.int64,)
|
||||
udt.compile(sig_i64) # compile with lifted loop
|
||||
cres = udt.overloads[sig_i64]
|
||||
|
||||
# Make html output
|
||||
buf = StringIO()
|
||||
cres.type_annotation.html_annotate(buf)
|
||||
output = buf.getvalue()
|
||||
buf.close()
|
||||
|
||||
# There should be only one function output.
|
||||
self.assertEqual(output.count("Function name: udt"), 1)
|
||||
|
||||
sigfmt = "with signature: {} -> pyobject"
|
||||
self.assertEqual(output.count(sigfmt.format(sig_i64)), 1)
|
||||
# Ensure the loop is tagged
|
||||
self.assertEqual(len(re.findall(re_lifted_tag, output)), 1,
|
||||
msg='%s not found in %s' % (re_lifted_tag, output))
|
||||
|
||||
# Compile float64 version
|
||||
sig_f64 = (types.float64,)
|
||||
udt.compile(sig_f64)
|
||||
cres = udt.overloads[sig_f64]
|
||||
|
||||
# Make html output
|
||||
buf = StringIO()
|
||||
cres.type_annotation.html_annotate(buf)
|
||||
output = buf.getvalue()
|
||||
buf.close()
|
||||
|
||||
# There should be two function output
|
||||
self.assertEqual(output.count("Function name: udt"), 2)
|
||||
self.assertEqual(output.count(sigfmt.format(sig_i64)), 1)
|
||||
self.assertEqual(output.count(sigfmt.format(sig_f64)), 1)
|
||||
# Ensure the loop is tagged in both output
|
||||
self.assertEqual(len(re.findall(re_lifted_tag, output)), 2)
|
||||
|
||||
@unittest.skipIf(pygments is None, "please install the 'pygments' package")
|
||||
def test_pretty_print(self):
|
||||
|
||||
@numba.njit
|
||||
def foo(x, y):
|
||||
return x, y
|
||||
|
||||
foo(1, 2)
|
||||
# Exercise the method
|
||||
foo.inspect_types(pretty=True)
|
||||
|
||||
# Exercise but supply a not None file kwarg, this is invalid
|
||||
with self.assertRaises(ValueError) as raises:
|
||||
foo.inspect_types(pretty=True, file='should be None')
|
||||
self.assertIn('`file` must be None if `pretty=True`',
|
||||
str(raises.exception))
|
||||
|
||||
|
||||
class TestTypeAnnotation(unittest.TestCase):
|
||||
|
||||
def findpatloc(self, lines, pat):
|
||||
for i, ln in enumerate(lines):
|
||||
if pat in ln:
|
||||
return i
|
||||
raise ValueError("can't find {!r}".format(pat))
|
||||
|
||||
def getlines(self, func):
|
||||
strbuf = StringIO()
|
||||
func.inspect_types(strbuf)
|
||||
return strbuf.getvalue().splitlines()
|
||||
|
||||
def test_delete(self):
|
||||
@numba.njit
|
||||
def foo(appleorange, berrycherry):
|
||||
return appleorange + berrycherry
|
||||
|
||||
foo(1, 2)
|
||||
|
||||
lines = self.getlines(foo)
|
||||
|
||||
# Ensure deletion show up after their use
|
||||
sa = self.findpatloc(lines, 'appleorange = arg(0, name=appleorange)')
|
||||
sb = self.findpatloc(lines, 'berrycherry = arg(1, name=berrycherry)')
|
||||
|
||||
ea = self.findpatloc(lines, 'del appleorange')
|
||||
eb = self.findpatloc(lines, 'del berrycherry')
|
||||
|
||||
self.assertLess(sa, ea)
|
||||
self.assertLess(sb, eb)
|
||||
|
||||
def _lifetimes_impl(self, extend):
|
||||
with override_config('EXTEND_VARIABLE_LIFETIMES', extend):
|
||||
@njit
|
||||
def foo(a):
|
||||
b = a
|
||||
return b
|
||||
x = 10
|
||||
b = foo(x)
|
||||
self.assertEqual(b, x)
|
||||
|
||||
lines = self.getlines(foo)
|
||||
|
||||
sa = self.findpatloc(lines, 'a = arg(0, name=a)')
|
||||
sb = self.findpatloc(lines, 'b = a')
|
||||
|
||||
cast_ret = self.findpatloc(lines, 'cast(value=b)')
|
||||
|
||||
dela = self.findpatloc(lines, 'del a')
|
||||
delb = self.findpatloc(lines, 'del b')
|
||||
|
||||
return sa, sb, cast_ret, dela, delb
|
||||
|
||||
def test_delete_standard_lifetimes(self):
|
||||
# without extended lifetimes, dels occur as soon as dead
|
||||
#
|
||||
# label 0
|
||||
# a = arg(0, name=a) :: int64
|
||||
# b = a :: int64
|
||||
# del a
|
||||
# $8return_value.2 = cast(value=b) :: int64
|
||||
# del b
|
||||
# return $8return_value.2
|
||||
|
||||
sa, sb, cast_ret, dela, delb = self._lifetimes_impl(extend=0)
|
||||
|
||||
self.assertLess(sa, dela)
|
||||
self.assertLess(sb, delb)
|
||||
# del a is before cast and del b is after
|
||||
self.assertLess(dela, cast_ret)
|
||||
self.assertGreater(delb, cast_ret)
|
||||
|
||||
def test_delete_extended_lifetimes(self):
|
||||
# with extended lifetimes, dels are last in block:
|
||||
#
|
||||
# label 0
|
||||
# a = arg(0, name=a) :: int64
|
||||
# b = a :: int64
|
||||
# $8return_value.2 = cast(value=b) :: int64
|
||||
# del a
|
||||
# del b
|
||||
# return $8return_value.2
|
||||
|
||||
sa, sb, cast_ret, dela, delb = self._lifetimes_impl(extend=1)
|
||||
|
||||
self.assertLess(sa, dela)
|
||||
self.assertLess(sb, delb)
|
||||
# dels are after the cast
|
||||
self.assertGreater(dela, cast_ret)
|
||||
self.assertGreater(delb, cast_ret)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user