Videre
This commit is contained in:
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
Contains tests and a prototype implementation for the fanout algorithm in
|
||||
the LLVM refprune pass.
|
||||
"""
|
||||
|
||||
try:
|
||||
from graphviz import Digraph
|
||||
except ImportError:
|
||||
pass
|
||||
from collections import defaultdict
|
||||
|
||||
# The entry block. It's always the same.
|
||||
ENTRY = "A"
|
||||
|
||||
|
||||
# The following caseNN() functions returns a 3-tuple of
|
||||
# (nodes, edges, expected).
|
||||
# `nodes` maps BB nodes to incref/decref inside the block.
|
||||
# `edges` maps BB nodes to their successor BB.
|
||||
# `expected` maps BB-node with incref to a set of BB-nodes with the decrefs, or
|
||||
# the value can be None, indicating invalid prune.
|
||||
|
||||
def case1():
|
||||
edges = {
|
||||
"A": ["B"],
|
||||
"B": ["C", "D"],
|
||||
"C": [],
|
||||
"D": ["E", "F"],
|
||||
"E": ["G"],
|
||||
"F": [],
|
||||
"G": ["H", "I"],
|
||||
"I": ["G", "F"],
|
||||
"H": ["J", "K"],
|
||||
"J": ["L", "M"],
|
||||
"K": [],
|
||||
"L": ["Z"],
|
||||
"M": ["Z", "O", "P"],
|
||||
"O": ["Z"],
|
||||
"P": ["Z"],
|
||||
"Z": [],
|
||||
}
|
||||
nodes = defaultdict(list)
|
||||
nodes["D"] = ["incref"]
|
||||
nodes["H"] = ["decref"]
|
||||
nodes["F"] = ["decref", "decref"]
|
||||
expected = {"D": {"H", "F"}}
|
||||
return nodes, edges, expected
|
||||
|
||||
|
||||
def case2():
|
||||
edges = {
|
||||
"A": ["B", "C"],
|
||||
"B": ["C"],
|
||||
"C": [],
|
||||
}
|
||||
nodes = defaultdict(list)
|
||||
nodes["A"] = ["incref"]
|
||||
nodes["B"] = ["decref"]
|
||||
nodes["C"] = ["decref"]
|
||||
expected = {"A": None}
|
||||
return nodes, edges, expected
|
||||
|
||||
|
||||
def case3():
|
||||
nodes, edges, _ = case1()
|
||||
# adds an invalid edge
|
||||
edges["H"].append("F")
|
||||
expected = {"D": None}
|
||||
return nodes, edges, expected
|
||||
|
||||
|
||||
def case4():
|
||||
nodes, edges, _ = case1()
|
||||
# adds an invalid edge
|
||||
edges["H"].append("E")
|
||||
expected = {"D": None}
|
||||
return nodes, edges, expected
|
||||
|
||||
|
||||
def case5():
|
||||
nodes, edges, _ = case1()
|
||||
# adds backedge to go before incref
|
||||
edges["B"].append("I")
|
||||
expected = {"D": None}
|
||||
return nodes, edges, expected
|
||||
|
||||
|
||||
def case6():
|
||||
nodes, edges, _ = case1()
|
||||
# adds backedge to go before incref
|
||||
edges["I"].append("B")
|
||||
expected = {"D": None}
|
||||
return nodes, edges, expected
|
||||
|
||||
|
||||
def case7():
|
||||
nodes, edges, _ = case1()
|
||||
# adds forward jump outside
|
||||
edges["I"].append("M")
|
||||
expected = {"D": None}
|
||||
return nodes, edges, expected
|
||||
|
||||
|
||||
def case8():
|
||||
edges = {
|
||||
"entry:": ["A"],
|
||||
"A": ["B", "C"],
|
||||
"B": ["C"],
|
||||
"C": [],
|
||||
}
|
||||
nodes = defaultdict(list)
|
||||
nodes["A"] = ["incref"]
|
||||
nodes["C"] = ["decref"]
|
||||
expected = {"A": {"C"}}
|
||||
return nodes, edges, expected
|
||||
|
||||
|
||||
def case9():
|
||||
nodes, edges, _ = case8()
|
||||
# adds back edge
|
||||
edges["C"].append("B")
|
||||
expected = {"A": None}
|
||||
return nodes, edges, expected
|
||||
|
||||
|
||||
def case10():
|
||||
nodes, edges, _ = case8()
|
||||
# adds back edge to A
|
||||
edges["C"].append("A")
|
||||
expected = {"A": {"C"}}
|
||||
return nodes, edges, expected
|
||||
|
||||
|
||||
def case11():
|
||||
nodes, edges, _ = case8()
|
||||
edges["C"].append("D")
|
||||
edges["D"] = []
|
||||
expected = {"A": {"C"}}
|
||||
return nodes, edges, expected
|
||||
|
||||
|
||||
def case12():
|
||||
nodes, edges, _ = case8()
|
||||
edges["C"].append("D")
|
||||
edges["D"] = ["A"]
|
||||
expected = {"A": {"C"}}
|
||||
return nodes, edges, expected
|
||||
|
||||
|
||||
def case13():
|
||||
nodes, edges, _ = case8()
|
||||
edges["C"].append("D")
|
||||
edges["D"] = ["B"]
|
||||
expected = {"A": None}
|
||||
return nodes, edges, expected
|
||||
|
||||
|
||||
def make_predecessor_map(edges):
|
||||
d = defaultdict(set)
|
||||
for src, outgoings in edges.items():
|
||||
for dst in outgoings:
|
||||
d[dst].add(src)
|
||||
return d
|
||||
|
||||
|
||||
class FanoutAlgorithm:
|
||||
def __init__(self, nodes, edges, verbose=False):
|
||||
self.nodes = nodes
|
||||
self.edges = edges
|
||||
self.rev_edges = make_predecessor_map(edges)
|
||||
self.print = print if verbose else self._null_print
|
||||
|
||||
def run(self):
|
||||
return self.find_fanout_in_function()
|
||||
|
||||
def _null_print(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def find_fanout_in_function(self):
|
||||
got = {}
|
||||
for cur_node in self.edges:
|
||||
for incref in (x for x in self.nodes[cur_node] if x == "incref"):
|
||||
decref_blocks = self.find_fanout(cur_node)
|
||||
self.print(">>", cur_node, "===", decref_blocks)
|
||||
got[cur_node] = decref_blocks
|
||||
return got
|
||||
|
||||
def find_fanout(self, head_node):
|
||||
decref_blocks = self.find_decref_candidates(head_node)
|
||||
self.print("candidates", decref_blocks)
|
||||
if not decref_blocks:
|
||||
return None
|
||||
if not self.verify_non_overlapping(
|
||||
head_node, decref_blocks, entry=ENTRY
|
||||
):
|
||||
return None
|
||||
return set(decref_blocks)
|
||||
|
||||
def verify_non_overlapping(self, head_node, decref_blocks, entry):
|
||||
self.print("verify_non_overlapping".center(80, "-"))
|
||||
# reverse walk for each decref_blocks
|
||||
# they should end at head_node
|
||||
todo = list(decref_blocks)
|
||||
while todo:
|
||||
cur_node = todo.pop()
|
||||
visited = set()
|
||||
|
||||
workstack = [cur_node]
|
||||
del cur_node
|
||||
while workstack:
|
||||
cur_node = workstack.pop()
|
||||
self.print("cur_node", cur_node, "|", workstack)
|
||||
if cur_node in visited:
|
||||
continue # skip
|
||||
if cur_node == entry:
|
||||
# Entry node
|
||||
self.print(
|
||||
"!! failed because we arrived at entry", cur_node
|
||||
)
|
||||
return False
|
||||
visited.add(cur_node)
|
||||
# check all predecessors
|
||||
self.print(
|
||||
f" {cur_node} preds {self.get_predecessors(cur_node)}"
|
||||
)
|
||||
for pred in self.get_predecessors(cur_node):
|
||||
if pred in decref_blocks:
|
||||
# reject because there's a predecessor in decref_blocks
|
||||
self.print(
|
||||
"!! reject because predecessor in decref_blocks"
|
||||
)
|
||||
return False
|
||||
if pred != head_node:
|
||||
|
||||
workstack.append(pred)
|
||||
|
||||
return True
|
||||
|
||||
def get_successors(self, node):
|
||||
return tuple(self.edges[node])
|
||||
|
||||
def get_predecessors(self, node):
|
||||
return tuple(self.rev_edges[node])
|
||||
|
||||
def has_decref(self, node):
|
||||
return "decref" in self.nodes[node]
|
||||
|
||||
def walk_child_for_decref(
|
||||
self, cur_node, path_stack, decref_blocks, depth=10
|
||||
):
|
||||
indent = " " * len(path_stack)
|
||||
self.print(indent, "walk", path_stack, cur_node)
|
||||
if depth <= 0:
|
||||
return False # missing
|
||||
if cur_node in path_stack:
|
||||
if cur_node == path_stack[0]:
|
||||
return False # reject interior node backedge
|
||||
return True # skip
|
||||
if self.has_decref(cur_node):
|
||||
decref_blocks.add(cur_node)
|
||||
self.print(indent, "found decref")
|
||||
return True
|
||||
|
||||
depth -= 1
|
||||
path_stack += (cur_node,)
|
||||
found = False
|
||||
for child in self.get_successors(cur_node):
|
||||
if not self.walk_child_for_decref(
|
||||
child, path_stack, decref_blocks
|
||||
):
|
||||
found = False
|
||||
break
|
||||
else:
|
||||
found = True
|
||||
|
||||
self.print(indent, f"ret {found}")
|
||||
return found
|
||||
|
||||
def find_decref_candidates(self, cur_node):
|
||||
# Forward pass
|
||||
self.print("find_decref_candidates".center(80, "-"))
|
||||
path_stack = (cur_node,)
|
||||
found = False
|
||||
decref_blocks = set()
|
||||
for child in self.get_successors(cur_node):
|
||||
if not self.walk_child_for_decref(
|
||||
child, path_stack, decref_blocks
|
||||
):
|
||||
found = False
|
||||
break
|
||||
else:
|
||||
found = True
|
||||
if not found:
|
||||
return set()
|
||||
else:
|
||||
return decref_blocks
|
||||
|
||||
|
||||
def check_once():
|
||||
nodes, edges, expected = case13()
|
||||
|
||||
# Render graph
|
||||
G = Digraph()
|
||||
for node in edges:
|
||||
G.node(node, shape="rect", label=f"{node}\n" + r"\l".join(nodes[node]))
|
||||
for node, children in edges.items():
|
||||
for child in children:
|
||||
G.edge(node, child)
|
||||
|
||||
G.view()
|
||||
|
||||
algo = FanoutAlgorithm(nodes, edges, verbose=True)
|
||||
got = algo.run()
|
||||
assert expected == got
|
||||
|
||||
|
||||
def check_all():
|
||||
for k, fn in list(globals().items()):
|
||||
if k.startswith("case"):
|
||||
print(f"{fn}".center(80, "-"))
|
||||
nodes, edges, expected = fn()
|
||||
algo = FanoutAlgorithm(nodes, edges)
|
||||
got = algo.run()
|
||||
assert expected == got
|
||||
print("ALL PASSED")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# check_once()
|
||||
check_all()
|
||||
Reference in New Issue
Block a user