mirror of https://github.com/zeldaret/mm.git
506 lines
16 KiB
Python
506 lines
16 KiB
Python
from base64 import b64decode
|
|
from collections import defaultdict
|
|
import copy
|
|
from dataclasses import dataclass
|
|
from random import Random
|
|
import re
|
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
|
|
|
|
from pycparser import CParser, c_ast as ca, c_generator
|
|
from pycparser.plyparser import ParseError
|
|
|
|
from .error import CandidateConstructionFailure
|
|
from .ast_types import SimpleType, set_decl_name
|
|
|
|
|
|
@dataclass
class Indices:
    """Entry and exit numbering of AST nodes, as built by compute_node_indices.

    Each node gets one index when the traversal reaches it and another once
    all of its children have been traversed.
    """

    # Index recorded for a node before its children are visited.
    starts: Dict[ca.Node, int]
    # Index recorded for a node after its children have been visited.
    ends: Dict[ca.Node, int]
|
|
|
|
|
|
# Node kinds that get_block_stmts knows how to pull a statement list out of.
Block = Union[ca.Compound, ca.Case, ca.Default]

if TYPE_CHECKING:
    # pycparser has no real ca.Expression / ca.Statement classes; the names
    # are declared in the stubs file only, so they can be referenced solely
    # while type checking.
    Expression = ca.Expression
    Statement = ca.Statement
else:
    # Runtime placeholders so that annotations mentioning these names still
    # evaluate.
    Expression = None
    Statement = None
|
|
|
|
|
|
def to_c_raw(node: ca.Node) -> str:
    """Render `node` as C source with pycparser's unmodified CGenerator."""
    generator = c_generator.CGenerator()
    rendered: str = generator.visit(node)
    return rendered
|
|
|
|
|
|
def to_c(node: ca.Node, *, from_import: bool = False) -> str:
    """Render `node` as C source and run the result through process_pragmas.

    When `from_import` is set the plain generator (to_c_raw) produces the
    text; otherwise PatchedCGenerator does.
    """
    if from_import:
        rendered = to_c_raw(node)
    else:
        rendered = PatchedCGenerator().visit(node)
    return process_pragmas(rendered)
|
|
|
|
|
|
def process_pragmas(source: str) -> str:
    """Expand `#pragma _permuter ...` directives found in C source text.

    Text containing no "#pragma" at all is returned unchanged. Otherwise the
    text is processed line by line and these directives are recognized:

    - ``sameline start`` / ``sameline end``: while active, lines are emitted
      without a trailing newline; a non-empty line that follows an entry
      lacking a newline is left-stripped and prefixed with one space.
    - ``latedefine start`` / ``latedefine end``: while active, lines that
      are not permuter pragmas are replaced by empty lines.
    - ``define ...``: emitted as ``#define ...``; only allowed inside a
      latedefine section.
    - ``b64literal <data>``: replaced by the base64-decoded data, read as
      UTF-8.

    Every other permuter pragma expands to an empty line. The joined result
    has trailing whitespace stripped and ends with exactly one newline.

    Raises:
        AssertionError: if start/end directives are mismatched or
            unbalanced, or if ``define`` appears outside a latedefine
            section.
    """
    if "#pragma" not in source:
        return source

    prefix = "#pragma _permuter "
    out: List[str] = []
    same_line = 0  # nesting depth of "sameline" sections
    ignore = 0  # nesting depth of "latedefine" sections
    for line in source.split("\n"):
        stripped = line.strip()
        if stripped.startswith(prefix):
            # Expand permuter pragmas to nothing, by default. Still, keep one
            # output line per input line to preserve line numbers for import.py
            # error messages.
            line = ""

            directive = stripped[len(prefix) :]
            if directive == "sameline start":
                same_line += 1
            elif directive == "sameline end":
                # Same guard as for "latedefine end". Without it, an "end"
                # placed before its "start" balances out to zero, passes the
                # final assertion, and the lines in between get joined.
                assert same_line > 0, "mismatched sameline pragmas"
                same_line -= 1
            elif directive == "latedefine start":
                ignore += 1
            elif directive == "latedefine end":
                assert ignore > 0, "mismatched ignore pragmas"
                ignore -= 1
            elif directive.startswith("define "):
                assert ignore > 0, "define pragma must be within latedefine block"
                line = "#" + directive
            elif directive.startswith("b64literal "):
                line = b64decode(directive.split(" ", 1)[1]).decode("utf-8")
        elif ignore > 0:
            # Ignore non-pragma lines within latedefine section
            line = ""

        if not same_line:
            line += "\n"
        elif line and out and not out[-1].endswith("\n"):
            line = " " + line.lstrip()
        out.append(line)
    assert same_line == 0, "unbalanced sameline pragmas"
    assert ignore == 0, "unbalanced ignore pragmas"
    return "".join(out).rstrip() + "\n"
|
|
|
|
|
|
class PatchedCGenerator(c_generator.CGenerator):
    """CGenerator that renders a braced lone `if` in an else-branch as the
    bare `if`, so else-if chains stay readable after normalize_ast has
    wrapped the branches in blocks."""

    def visit_If(self, n: ca.If) -> str:
        target = n
        else_branch = n.iffalse
        if else_branch and isinstance(else_branch, ca.Compound):
            items = else_branch.block_items
            if items and len(items) == 1 and isinstance(items[0], ca.If):
                # Render through a fresh If node so the AST passed in is left
                # untouched.
                target = ca.If(cond=n.cond, iftrue=n.iftrue, iffalse=items[0])
        return super().visit_If(target)  # type: ignore
|
|
|
|
|
|
def extract_fn(ast: ca.FileAST, fn_name: str) -> Tuple[ca.FuncDef, int]:
    """Find the definition of `fn_name` among the top-level entries of `ast`.

    As a side effect, every other function definition in `ast.ext` is
    replaced by its declaration, and "static" is filtered out of the
    `funcspec` list of every top-level function declaration.

    Returns:
        The matching FuncDef together with its index in `ast.ext`.

    Raises:
        CandidateConstructionFailure: if the function is defined zero times
            or more than once.
    """
    matches = []
    for index, ext_node in enumerate(ast.ext):
        if isinstance(ext_node, ca.FuncDef):
            if ext_node.decl.name == fn_name:
                matches.append((ext_node, index))
            else:
                # Keep only the declaration of functions we are not after.
                ext_node = ext_node.decl
                ast.ext[index] = ext_node
        is_fn_decl = isinstance(ext_node, ca.Decl) and isinstance(
            ext_node.type, ca.FuncDecl
        )
        if is_fn_decl:
            ext_node.funcspec = [
                specifier for specifier in ext_node.funcspec if specifier != "static"
            ]
    if not matches:
        raise CandidateConstructionFailure(f"Function {fn_name} not found in base.c.")
    if len(matches) > 1:
        raise CandidateConstructionFailure(
            f"Found multiple copies of function {fn_name} in base.c."
        )
    return matches[0]
|
|
|
|
|
|
def parse_c(source: str, *, from_import: bool = False) -> ca.FileAST:
    """Parse C source text into a pycparser AST.

    A pycparser ParseError is converted into a CandidateConstructionFailure
    whose message carries the parser's complaint and, when the error has a
    position, the approximate line (and column) plus the offending source
    line. Unless `from_import` is set, the position is annotated as
    referring to the text after PERM expansion.
    """
    try:
        return CParser().parse(source, "<source>")
    except ParseError as e:
        # The error text has the shape "<position>: <message>", where the
        # position is made of ":"-separated parts.
        position, msg = str(e).split(": ", 1)
        parts = position.split(":")
        posstr = ""
        if len(parts) >= 2:
            lineno = int(parts[1])
            posstr = f" at approximately line {lineno}"
            if len(parts) >= 3:
                posstr += f", column {parts[2]}"
            if not from_import:
                posstr += " (after PERM expansion)"
            source_lines = source.split("\n")
            try:
                posstr += "\n\n" + source_lines[lineno - 1].rstrip()
            except IndexError:
                posstr += "(out of bounds?)"
        raise CandidateConstructionFailure(
            f"Syntax error in base.c.\n{msg}{posstr}"
        ) from None
|
|
|
|
|
|
def compute_node_indices(top_node: ca.Node) -> Indices:
    """Number every node reachable from `top_node` on entry and on exit.

    A single counter starts at 1 and advances by 2 each time it is used, so
    all assigned indices are distinct odd numbers. A node's start index is
    taken before its children are visited and its end index afterwards.

    Raises:
        AssertionError: if the same node object is encountered twice.
    """
    entered: Dict[ca.Node, int] = {}
    left: Dict[ca.Node, int] = {}
    counter = 1

    def take_index() -> int:
        # Hand out the current counter value and advance it.
        nonlocal counter
        value = counter
        counter += 2
        return value

    class Numberer(ca.NodeVisitor):
        def generic_visit(self, node: ca.Node) -> None:
            assert node not in entered, "nodes should only appear once in AST"
            entered[node] = take_index()
            super().generic_visit(node)
            left[node] = take_index()

    Numberer().visit(top_node)
    return Indices(entered, left)
|
|
|
|
|
|
def equal_ast(a: ca.Node, b: ca.Node) -> bool:
    """Structurally compare two ASTs.

    Values are equal when they have exactly the same type and are both
    None, equal ints/strs, lists of pairwise-equal elements, or nodes whose
    compared attributes are all equal.
    """

    def same(x: Any, y: Any) -> bool:
        if type(x) != type(y):
            return False
        if x is None:
            return y is None
        if isinstance(x, list):
            assert isinstance(y, list)
            if len(x) != len(y):
                return False
            return all(same(left, right) for left, right in zip(x, y))
        if isinstance(x, (int, str)):
            return bool(x == y)
        assert isinstance(x, ca.Node)
        # The last two slots are left out of the comparison (presumably
        # pycparser's `coord` and `__weakref__` -- TODO confirm).
        field_names = x.__slots__[:-2]  # type: ignore
        return all(
            same(getattr(x, field), getattr(y, field)) for field in field_names
        )

    return same(a, b)
|
|
|
|
|
|
def is_lvalue(expr: Expression) -> bool:
    """Return True for identifiers, struct/array references and pointer
    dereferences; False for every other expression."""
    if isinstance(expr, ca.UnaryOp):
        # Of the unary operators only a dereference counts.
        return expr.op == "*"
    return isinstance(expr, (ca.ID, ca.StructRef, ca.ArrayRef))
|
|
|
|
|
|
def is_effectful(expr: Expression) -> bool:
    """Report whether `expr` contains something with a side effect.

    An expression counts as effectful when a function call, an assignment,
    or a prefix/postfix increment or decrement occurs anywhere inside it.
    """
    found = False

    class Visitor(ca.NodeVisitor):
        def visit_UnaryOp(self, node: ca.UnaryOp) -> None:
            nonlocal found
            if node.op in ["p++", "p--", "++", "--"]:
                found = True
            else:
                # Dispatch on the operand itself. generic_visit(node.expr)
                # would look only at the operand's children, so an operand
                # that is directly a call, assignment or increment (as in
                # `-f()`, `!(a = b)` or `*p++`) would go unnoticed.
                self.visit(node.expr)

        def visit_FuncCall(self, _: ca.Node) -> None:
            nonlocal found
            found = True

        def visit_Assignment(self, _: ca.Node) -> None:
            nonlocal found
            found = True

    Visitor().visit(expr)
    return found
|
|
|
|
|
|
def get_block_stmts(block: Block, force: bool) -> List[Statement]:
    """Return the statement list of `block`.

    The list lives in `block_items` for a Compound and in `stmts` for the
    other block kinds. If that attribute is empty or None a fresh empty list
    is returned; with `force` set, that fresh list is also stored on the
    block so that later mutations of it show up in the AST.
    """
    attr = "block_items" if isinstance(block, ca.Compound) else "stmts"
    current = getattr(block, attr)
    stmts: List[Statement] = current or []
    if force and not current:
        setattr(block, attr, stmts)
    return stmts
|
|
|
|
|
|
def insert_decl(
    fn: ca.FuncDef, var: str, type: SimpleType, random: Optional[Random] = None
) -> None:
    """Declare a new variable `var` of (a deep copy of) `type` inside `fn`.

    Without `random` the declaration goes right after the leading run of
    declarations in the function body. With `random` it goes to a position
    picked by `random.randint` between 0 and that point, inclusive.

    Raises:
        AssertionError: if the function body has no statements.
    """
    decl = ca.Decl(
        name=var,
        quals=[],
        storage=[],
        funcspec=[],
        type=copy.deepcopy(type),
        init=None,
        bitsize=None,
    )
    set_decl_name(decl)

    body_items = fn.body.block_items
    assert body_items, "Non-empty function"

    # Position of the first statement that is not a declaration, or the end
    # of the body when there is none.
    index = next(
        (
            position
            for position, stmt in enumerate(body_items)
            if not isinstance(stmt, ca.Decl)
        ),
        len(body_items),
    )
    if random:
        index = random.randint(0, index)
    body_items.insert(index, decl)
|
|
|
|
|
|
def insert_statement(block: Block, index: int, stmt: Statement) -> None:
    """Insert `stmt` at position `index` of `block`'s statement list,
    creating that list on the block first if there is none yet."""
    get_block_stmts(block, True).insert(index, stmt)
|
|
|
|
|
|
def brace_nested_blocks(stmt: Statement) -> None:
    """Wrap the bodies of a loop, if or switch statement in Compound blocks.

    Bodies that are already a Compound, Case or Default are left as they
    are. Labels are looked through, so the statement that a label (or chain
    of labels) is attached to gets processed. Other statements are left
    unchanged.
    """

    def as_block(body: Statement) -> Block:
        already_block = isinstance(body, (ca.Compound, ca.Case, ca.Default))
        return body if already_block else ca.Compound([body])

    # Skip over any labels to reach the statement they are attached to.
    while isinstance(stmt, ca.Label):
        stmt = stmt.stmt

    if isinstance(stmt, (ca.For, ca.While, ca.DoWhile, ca.Switch)):
        stmt.stmt = as_block(stmt.stmt)
    elif isinstance(stmt, ca.If):
        stmt.iftrue = as_block(stmt.iftrue)
        if stmt.iffalse:
            stmt.iffalse = as_block(stmt.iffalse)
|
|
|
|
|
|
def has_nested_block(node: ca.Node) -> bool:
    """Return True if `node` is a compound block, a loop, an if, a switch,
    or a case/default."""
    block_bearing_types = (
        ca.Compound,
        ca.For,
        ca.While,
        ca.DoWhile,
        ca.If,
        ca.Switch,
        ca.Case,
        ca.Default,
    )
    return isinstance(node, block_bearing_types)
|
|
|
|
|
|
def for_nested_blocks(stmt: Statement, callback: Callable[[Block], None]) -> None:
    """Call `callback` on each block directly belonging to `stmt`.

    That is `stmt` itself when it is a Compound, Case or Default; the body
    of a loop or switch; and the present branches of an if. Labels are
    looked through. Other statements produce no calls.

    Raises:
        AssertionError: if a body or branch is not a block, i.e. the
            statement was not processed by brace_nested_blocks first.
    """
    block_types = (ca.Compound, ca.Case, ca.Default)

    def invoke(target: Statement) -> None:
        assert isinstance(
            target, block_types
        ), "brace_nested_blocks should have turned nested statements into blocks"
        callback(target)

    # Skip over any labels to reach the statement they are attached to.
    while isinstance(stmt, ca.Label):
        stmt = stmt.stmt

    if isinstance(stmt, block_types):
        invoke(stmt)
    elif isinstance(stmt, (ca.For, ca.While, ca.DoWhile, ca.Switch)):
        invoke(stmt.stmt)
    elif isinstance(stmt, ca.If):
        if stmt.iftrue:
            invoke(stmt.iftrue)
        if stmt.iffalse:
            invoke(stmt.iffalse)
|
|
|
|
|
|
def normalize_ast(fn: ca.FuncDef, ast: ca.FileAST) -> None:
    """Add braces to all ifs/fors/etc., to make it easier to insert statements.

    Works in place on the body of `fn`; `ast` is not used.
    """

    def visit_block(block: Block) -> None:
        for statement in get_block_stmts(block, False):
            # Brace first, so that every nested body handed to the recursive
            # call below is already a block.
            brace_nested_blocks(statement)
            for_nested_blocks(statement, visit_block)

    visit_block(fn.body)
|
|
|
|
|
|
def prune_ast(fn: ca.FuncDef, ast: ca.FileAST) -> int:
    """Prune away unnecessary parts of the AST, to reduce overhead from serialization
    and from the compiler's C parser.

    Top-level entries of `ast` that are not reachable from the GC roots are
    dropped. Struct/union/enum definitions that are only needed behind
    pointers whose declared name is never mentioned are kept but emptied
    into forward declarations. `ast.ext` is replaced by the pruned list and
    the emptied type nodes are modified in place.

    Returns:
        The index of `fn` in the pruned `ast.ext`.
    """

    # Create a GC graph that maps names of declarations and enumerators to indices
    # in ast.ext, as well as an initial list of GC roots, consisting of everything
    # that isn't a Decl or Typedef.
    edges: Dict[str, List[int]] = defaultdict(list)
    gc_roots: List[int] = []
    can_fwd_declare_typedef: Set[str] = set()
    can_fwd_declare_tagged: Set[str] = set()

    def add_type_edges(
        tp: Union["ca.Type", ca.Struct, ca.Union, ca.Enum], i: int
    ) -> None:
        # Strip pointer/array layers to reach the underlying type.
        while isinstance(tp, (ca.PtrDecl, ca.ArrayDecl)):
            tp = tp.type
        if isinstance(tp, ca.FuncDecl):
            return
        inner_type = tp.type if isinstance(tp, ca.TypeDecl) else tp
        if isinstance(inner_type, ca.IdentifierType):
            return
        if inner_type.name:
            edges[inner_type.name].append(i)
        if isinstance(inner_type, ca.Enum) and inner_type.values:
            for value in inner_type.values.enumerators:
                edges[value.name].append(i)
        if isinstance(inner_type, (ca.Struct, ca.Union)) and inner_type.decls:
            for decl in inner_type.decls:
                if isinstance(decl, ca.Decl):
                    add_type_edges(decl.type, i)

    for i, item in enumerate(ast.ext):
        if isinstance(item, ca.Decl) and not item.init:
            # (Exclude declarations with initializers, since taking function
            # pointers can affect regalloc on IDO.)
            if item.name:
                edges[item.name].append(i)
            if isinstance(item.type, (ca.Struct, ca.Union, ca.Enum)) and item.type.name:
                can_fwd_declare_tagged.add(item.type.name)
            add_type_edges(item.type, i)
        elif isinstance(item, ca.Typedef):
            edges[item.name].append(i)
            if isinstance(item.type, ca.TypeDecl) and isinstance(
                item.type.type, (ca.Struct, ca.Union, ca.Enum)
            ):
                can_fwd_declare_typedef.add(item.name)
            add_type_edges(item.type, i)
        elif isinstance(item, ca.Pragma) and "GLOBAL_ASM" in item.string:
            # Neither a root nor reachable by name, so it is always dropped.
            pass
        else:
            gc_roots.append(i)

    # Every identifier mentioned anywhere in the file, in ID nodes or pragma
    # text.
    mentioned_ids: Set[str] = set()

    class IdVisitor(ca.NodeVisitor):
        def visit_Pragma(self, node: ca.Pragma) -> None:
            for token in re.findall(r"[a-zA-Z0-9_$]+", node.string):
                mentioned_ids.add(token)

        def visit_ID(self, node: ca.ID) -> None:
            mentioned_ids.add(node.name)

    IdVisitor().visit(ast)

    # Do the GC as a DFS traversal of the graph. Visiting a node searches its
    # AST for all kinds of mentioned IDs, and adds more nodes to the stack
    # using the edges we found before.
    gc_todo: List[int] = gc_roots
    # Sets, since these are only ever used for membership tests.
    need_fwd_decl_typedef: Set[str] = set()
    need_fwd_decl_tagged: Set[str] = set()

    def add_name(name: str) -> None:
        if name in edges:
            gc_todo.extend(edges[name])
            # Each name needs to be expanded only once.
            del edges[name]

    class Visitor(ca.NodeVisitor):
        def visit_Pragma(self, node: ca.Pragma) -> None:
            for token in re.findall(r"[a-zA-Z0-9_$]+", node.string):
                add_name(token)

        def visit_ID(self, node: ca.ID) -> None:
            add_name(node.name)

        def visit_IdentifierType(self, node: ca.IdentifierType) -> None:
            for name in node.names:
                add_name(name)

        def visit_Enum(self, node: ca.Enum) -> None:
            if node.name and not node.values:
                add_name(node.name)
            self.generic_visit(node)

        def visit_Struct(self, node: ca.Struct) -> None:
            if node.name and not node.decls:
                add_name(node.name)
            self.generic_visit(node)

        def visit_Union(self, node: ca.Union) -> None:
            if node.name and not node.decls:
                add_name(node.name)
            self.generic_visit(node)

        def visit_PtrDecl(self, node: ca.PtrDecl) -> None:
            # For pointer declarations which haven't been accessed, forward
            # declarations suffice.
            if (
                isinstance(node.type, ca.TypeDecl)
                and node.type.declname
                and node.type.declname not in mentioned_ids
            ):
                tp = node.type.type
                if isinstance(tp, ca.IdentifierType):
                    if all(name in can_fwd_declare_typedef for name in tp.names):
                        need_fwd_decl_typedef.update(tp.names)
                        return
                elif tp.name and tp.name in can_fwd_declare_tagged:
                    if not (tp.values if isinstance(tp, ca.Enum) else tp.decls):
                        need_fwd_decl_tagged.add(tp.name)
                        return
            self.generic_visit(node)

        def visit_TypeDecl(self, node: ca.TypeDecl) -> None:
            if node.declname:
                add_name(node.declname)
            self.generic_visit(node)

    keep_exts: Set[int] = set()
    while gc_todo:
        i = gc_todo.pop()
        if i not in keep_exts:
            keep_exts.add(i)
            Visitor().visit(ast.ext[i])

    temp_id = 0

    def fwd_declare(tp: Union[ca.Struct, ca.Union, ca.Enum]) -> None:
        # Empty a definition into a forward declaration, giving it a
        # generated name if it has none.
        nonlocal temp_id
        if not tp.name:
            temp_id += 1
            tp.name = f"_PermuterTemp{temp_id}"
        if isinstance(tp, (ca.Struct, ca.Union)):
            tp.decls = None
        elif isinstance(tp, ca.Enum):
            tp.values = None
        else:
            assert False

    new_ext = []

    for i, item in enumerate(ast.ext):
        if i in keep_exts:
            pass
        elif isinstance(item, ca.Typedef) and item.name in need_fwd_decl_typedef:
            assert item.name in can_fwd_declare_typedef
            assert isinstance(item.type, ca.TypeDecl)
            assert isinstance(item.type.type, (ca.Struct, ca.Union, ca.Enum))
            fwd_declare(item.type.type)
        elif (
            isinstance(item, ca.Decl)
            and isinstance(item.type, (ca.Struct, ca.Union, ca.Enum))
            and item.type.name
            and item.type.name in need_fwd_decl_tagged
        ):
            assert item.type.name in can_fwd_declare_tagged
            fwd_declare(item.type)
        else:
            continue
        new_ext.append(item)

    ast.ext = new_ext
    return ast.ext.index(fn)
|