Videre
This commit is contained in:
@@ -0,0 +1,266 @@
|
||||
# A port of https://github.com/python/cpython/blob/e42b7051/Lib/heapq.py
|
||||
|
||||
|
||||
import heapq as hq
|
||||
|
||||
from numba.core import types
|
||||
from numba.core.errors import TypingError
|
||||
from numba.core.extending import overload, register_jitable
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _siftdown(heap, startpos, pos):
|
||||
newitem = heap[pos]
|
||||
|
||||
while pos > startpos:
|
||||
parentpos = (pos - 1) >> 1
|
||||
parent = heap[parentpos]
|
||||
if newitem < parent:
|
||||
heap[pos] = parent
|
||||
pos = parentpos
|
||||
continue
|
||||
break
|
||||
|
||||
heap[pos] = newitem
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _siftup(heap, pos):
|
||||
endpos = len(heap)
|
||||
startpos = pos
|
||||
newitem = heap[pos]
|
||||
|
||||
childpos = 2 * pos + 1
|
||||
while childpos < endpos:
|
||||
|
||||
rightpos = childpos + 1
|
||||
if rightpos < endpos and not heap[childpos] < heap[rightpos]:
|
||||
childpos = rightpos
|
||||
|
||||
heap[pos] = heap[childpos]
|
||||
pos = childpos
|
||||
childpos = 2 * pos + 1
|
||||
|
||||
heap[pos] = newitem
|
||||
_siftdown(heap, startpos, pos)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _siftdown_max(heap, startpos, pos):
|
||||
newitem = heap[pos]
|
||||
|
||||
while pos > startpos:
|
||||
parentpos = (pos - 1) >> 1
|
||||
parent = heap[parentpos]
|
||||
if parent < newitem:
|
||||
heap[pos] = parent
|
||||
pos = parentpos
|
||||
continue
|
||||
break
|
||||
heap[pos] = newitem
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _siftup_max(heap, pos):
|
||||
endpos = len(heap)
|
||||
startpos = pos
|
||||
newitem = heap[pos]
|
||||
|
||||
childpos = 2 * pos + 1
|
||||
while childpos < endpos:
|
||||
|
||||
rightpos = childpos + 1
|
||||
if rightpos < endpos and not heap[rightpos] < heap[childpos]:
|
||||
childpos = rightpos
|
||||
|
||||
heap[pos] = heap[childpos]
|
||||
pos = childpos
|
||||
childpos = 2 * pos + 1
|
||||
|
||||
heap[pos] = newitem
|
||||
_siftdown_max(heap, startpos, pos)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def reversed_range(x):
|
||||
# analogous to reversed(range(x))
|
||||
return range(x - 1, -1, -1)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _heapify_max(x):
|
||||
n = len(x)
|
||||
|
||||
for i in reversed_range(n // 2):
|
||||
_siftup_max(x, i)
|
||||
|
||||
|
||||
@register_jitable
|
||||
def _heapreplace_max(heap, item):
|
||||
returnitem = heap[0]
|
||||
heap[0] = item
|
||||
_siftup_max(heap, 0)
|
||||
return returnitem
|
||||
|
||||
|
||||
def assert_heap_type(heap):
|
||||
if not isinstance(heap, (types.List, types.ListType)):
|
||||
raise TypingError('heap argument must be a list')
|
||||
|
||||
dt = heap.dtype
|
||||
if isinstance(dt, types.Complex):
|
||||
msg = ("'<' not supported between instances "
|
||||
"of 'complex' and 'complex'")
|
||||
raise TypingError(msg)
|
||||
|
||||
|
||||
def assert_item_type_consistent_with_heap_type(heap, item):
|
||||
if not heap.dtype == item:
|
||||
raise TypingError('heap type must be the same as item type')
|
||||
|
||||
|
||||
@overload(hq.heapify)
|
||||
def hq_heapify(x):
|
||||
assert_heap_type(x)
|
||||
|
||||
def hq_heapify_impl(x):
|
||||
n = len(x)
|
||||
for i in reversed_range(n // 2):
|
||||
_siftup(x, i)
|
||||
|
||||
return hq_heapify_impl
|
||||
|
||||
|
||||
@overload(hq.heappop)
|
||||
def hq_heappop(heap):
|
||||
assert_heap_type(heap)
|
||||
|
||||
def hq_heappop_impl(heap):
|
||||
lastelt = heap.pop()
|
||||
if heap:
|
||||
returnitem = heap[0]
|
||||
heap[0] = lastelt
|
||||
_siftup(heap, 0)
|
||||
return returnitem
|
||||
return lastelt
|
||||
|
||||
return hq_heappop_impl
|
||||
|
||||
|
||||
@overload(hq.heappush)
|
||||
def heappush(heap, item):
|
||||
assert_heap_type(heap)
|
||||
assert_item_type_consistent_with_heap_type(heap, item)
|
||||
|
||||
def hq_heappush_impl(heap, item):
|
||||
heap.append(item)
|
||||
_siftdown(heap, 0, len(heap) - 1)
|
||||
|
||||
return hq_heappush_impl
|
||||
|
||||
|
||||
@overload(hq.heapreplace)
|
||||
def heapreplace(heap, item):
|
||||
assert_heap_type(heap)
|
||||
assert_item_type_consistent_with_heap_type(heap, item)
|
||||
|
||||
def hq_heapreplace(heap, item):
|
||||
returnitem = heap[0]
|
||||
heap[0] = item
|
||||
_siftup(heap, 0)
|
||||
return returnitem
|
||||
|
||||
return hq_heapreplace
|
||||
|
||||
|
||||
@overload(hq.heappushpop)
|
||||
def heappushpop(heap, item):
|
||||
assert_heap_type(heap)
|
||||
assert_item_type_consistent_with_heap_type(heap, item)
|
||||
|
||||
def hq_heappushpop_impl(heap, item):
|
||||
if heap and heap[0] < item:
|
||||
item, heap[0] = heap[0], item
|
||||
_siftup(heap, 0)
|
||||
return item
|
||||
|
||||
return hq_heappushpop_impl
|
||||
|
||||
|
||||
def check_input_types(n, iterable):
|
||||
|
||||
if not isinstance(n, (types.Integer, types.Boolean)):
|
||||
raise TypingError("First argument 'n' must be an integer")
|
||||
# heapq also accepts 1.0 (but not 0.0, 2.0, 3.0...) but
|
||||
# this isn't replicated
|
||||
|
||||
if not isinstance(iterable, (types.Sequence, types.Array, types.ListType)):
|
||||
raise TypingError("Second argument 'iterable' must be iterable")
|
||||
|
||||
|
||||
@overload(hq.nsmallest)
|
||||
def nsmallest(n, iterable):
|
||||
check_input_types(n, iterable)
|
||||
|
||||
def hq_nsmallest_impl(n, iterable):
|
||||
|
||||
if n == 0:
|
||||
return [iterable[0] for _ in range(0)]
|
||||
elif n == 1:
|
||||
out = min(iterable)
|
||||
return [out]
|
||||
|
||||
size = len(iterable)
|
||||
if n >= size:
|
||||
return sorted(iterable)[:n]
|
||||
|
||||
it = iter(iterable)
|
||||
result = [(elem, i) for i, elem in zip(range(n), it)]
|
||||
|
||||
_heapify_max(result)
|
||||
top = result[0][0]
|
||||
order = n
|
||||
|
||||
for elem in it:
|
||||
if elem < top:
|
||||
_heapreplace_max(result, (elem, order))
|
||||
top, _order = result[0]
|
||||
order += 1
|
||||
result.sort()
|
||||
return [elem for (elem, order) in result]
|
||||
|
||||
return hq_nsmallest_impl
|
||||
|
||||
|
||||
@overload(hq.nlargest)
|
||||
def nlargest(n, iterable):
|
||||
check_input_types(n, iterable)
|
||||
|
||||
def hq_nlargest_impl(n, iterable):
|
||||
|
||||
if n == 0:
|
||||
return [iterable[0] for _ in range(0)]
|
||||
elif n == 1:
|
||||
out = max(iterable)
|
||||
return [out]
|
||||
|
||||
size = len(iterable)
|
||||
if n >= size:
|
||||
return sorted(iterable)[::-1][:n]
|
||||
|
||||
it = iter(iterable)
|
||||
result = [(elem, i) for i, elem in zip(range(0, -n, -1), it)]
|
||||
|
||||
hq.heapify(result)
|
||||
top = result[0][0]
|
||||
order = -n
|
||||
|
||||
for elem in it:
|
||||
if top < elem:
|
||||
hq.heapreplace(result, (elem, order))
|
||||
top, _order = result[0]
|
||||
order -= 1
|
||||
result.sort(reverse=True)
|
||||
return [elem for (elem, order) in result]
|
||||
|
||||
return hq_nlargest_impl
|
||||
Reference in New Issue
Block a user