# tp/tools/splitter/split.py
#
# 339 lines
# 11 KiB
# Python

"""
split.py - 202x erin moon for zeldaret
"""
from typing import Iterable, List
from dataclasses import dataclass
from pathlib import Path, PosixPath
from textwrap import dedent
from loguru import logger
from datetime import datetime
import re
import click
from asm_parser import asm, Emittable, Global, Label, Line, BlockComment, Instruction
from demangle import parse_framework_map, demangle
from util import PathPath, pairwise
from pprint import pprint
import pickle
import IPython
# Two fixed addresses; neither constant is referenced in the visible part of
# this file.
# NOTE(review): presumably the values of the _SDA_BASE_ / _SDA2_BASE_ symbols
# for this binary -- confirm before relying on them.
SDA_BASE = 0x80458580
SDA2_BASE = 0x80459A00

# Tool version, embedded in the "autogenerated by split.py" banners.
__version__ = 'v0.4'
def function_global_search(lines: List[Line]) -> Iterable[Line]:
    """Yield each ``Global`` line that is directly followed by its own label.

    A line qualifies when its body is a ``Global`` and the very next line's
    body is a ``Label`` whose ``symbol`` equals the global's ``symbol``.

    Args:
        lines: Parsed assembly lines to scan, in order.

    Yields:
        The ``Global`` line of every matching global/label pair. The label
        line itself is stepped over and never yielded.
    """
    total = len(lines)
    i = 0
    while i < total:
        body = lines[i].body
        # Only look ahead when a following line exists, so a ``Global`` on the
        # last line cannot index past the end of the list.
        if isinstance(body, Global) and i + 1 < total:
            following = lines[i + 1].body
            if isinstance(following, Label) and following.symbol == body.symbol:
                yield lines[i]
                i += 2  # step over the label that belongs to this global
                continue
        # Every non-matching line (including a Global without its label)
        # advances the cursor, so the scan always terminates.
        i += 1
def emit_lines(lines: List[Line]) -> str:
    """Return the ``emit()`` text of every line, joined with newlines.

    An empty input produces an empty string; no trailing newline is added.
    """
    emitted = (entry.emit() for entry in lines)
    return '\n'.join(emitted)
def comment_out(line: Line) -> Line:
    """Return a new ``Line`` that carries *line*'s text as a block comment.

    The new line keeps the original ``index``, holds a single ``BlockComment``
    built from ``line.emit()``, and passes ``None`` as the third ``Line``
    argument (presumably the body -- confirm against ``asm_parser``).
    """
    commented_text = BlockComment(line.emit())
    return Line(line.index, [commented_text], None)
def fix_sda_base_add(line: Line) -> Line:
    """Rewrite ``addi rD, rA, label-_SDA*_BASE_`` as ``la rD, label(rN)``.

    The third operand is split on ``-``; the part before the first dash is the
    label and the part after it selects the register: ``_SDA_BASE_`` maps to
    ``r13`` and ``_SDA2_BASE_`` maps to ``r2``. The line is modified in place.

    Args:
        line: A line whose body exposes ``opcode`` and ``operands``.

    Returns:
        The same ``line`` object. It is left unchanged when it has fewer than
        three operands, when the third operand does not contain ``SDA``, or
        when the base symbol is not recognised (an error is logged for the
        last case).
    """
    operands = line.body.operands
    # Guard against short operand lists instead of raising IndexError.
    if len(operands) < 3 or 'SDA' not in operands[2]:
        return line

    pieces = operands[2].split('-')
    lbl_name = pieces[0]
    # An operand mentioning SDA without any '-' has no base symbol at all;
    # treat it like an unknown base rather than indexing out of range.
    base_symbol = pieces[1] if len(pieces) > 1 else None

    sda_registers = {'_SDA_BASE_': 'r13', '_SDA2_BASE_': 'r2'}
    sda_reg = sda_registers.get(base_symbol)
    if sda_reg is None:
        logger.error('Unknown SDABASE!')
        return line

    line.body.opcode = 'la'
    line.body.operands = [operands[0], f'{lbl_name}({sda_reg})']
    return line
# Matches ``qr`` followed by digits and captures the digits.
QUANT_REG_RE = re.compile(r'qr(\d+)')


def patch_gqrs(line: Line) -> Line:
    """Replace every ``qr<N>`` in the line's operands with the bare ``<N>``.

    The operand list on ``line.body`` is replaced with a new list; the same
    ``line`` object is returned.
    """
    strip_qr = QUANT_REG_RE.sub
    body = line.body
    body.operands = [strip_qr(r'\1', operand) for operand in body.operands]
    return line
@dataclass
class Function:
    """One function carved out of the parsed assembly."""

    # Symbol name; used as the name of the emitted C++ asm stub.
    name: str
    # Address of the function; determines ``filename``.
    addr: int
    # Parsed assembly lines that make up the function.
    lines: List[Line]

    @property
    def line_count(self) -> int:
        """Number of entries in ``lines``."""
        return len(self.lines)

    @property
    def filename(self) -> str:
        """File name for this function: ``func_<ADDR>.s``, address in uppercase hex."""
        return f'func_{self.addr:X}.s'

    def include_path(self, base: Path) -> str:
        """Return ``base``/``filename`` as a forward-slash path string.

        Args:
            base: Directory prefix to place in front of ``filename``.

        Returns:
            The joined path, always using ``/`` as separator.
        """
        # PurePosixPath yields the same string as PosixPath but, unlike the
        # concrete class, can be instantiated on non-POSIX hosts as well.
        # Imported locally because it is not among the module-level imports.
        from pathlib import PurePosixPath

        return str(PurePosixPath(base) / self.filename)
def find_functions(lines: List[Line], framework_map) -> Iterable[Function]:
    """Yield a ``Function`` for every mapped start address between function globals.

    For each pair of consecutive globals reported by
    ``function_global_search``, the lines in between are scanned for
    instructions whose address (lowercase hex, no prefix) is a key of
    *framework_map*; each such instruction starts a new ``Function``.

    NOTE(review): relies on ``util.pairwise`` emitting a final pair whose
    second element is ``None`` -- the ``None`` handling below assumes this;
    confirm against ``util``.

    Args:
        lines: Parsed assembly lines.
        framework_map: Container keyed by lowercase hex address strings.

    Yields:
        ``Function`` objects. The first one keeps the global's symbol name
        when it starts at the very first line after the label; all others are
        named ``func_<ADDR>``.
    """

    def address_of(asm_line):
        # First whitespace-separated token of the line's leading content
        # item, parsed as hexadecimal.
        return int(asm_line.content[0].text.strip().split()[0], 16)

    global_pairs = pairwise(function_global_search(lines))
    for func_global_line, next_func_global_line in global_pairs:
        # some blocks weren't properly split, use the map to find missing functions
        # NOTE(review): the +2 presumably skips the global and its label and
        # assumes Line.index lines up with positions in `lines` -- confirm.
        first = func_global_line.index + 2
        # if no next global, the stop is None => slice runs to end of file
        last = next_func_global_line and next_func_global_line.index
        func_lines = lines[first:last]

        # positions (within func_lines) of instructions the map knows about
        func_idx = [
            position
            for position, candidate in enumerate(func_lines)
            if isinstance(candidate.body, Instruction)
            and f'{address_of(candidate):x}' in framework_map
        ]

        for start_idx, end_idx in pairwise(func_idx):
            stop = len(func_lines) if end_idx is None else end_idx
            sub_func_lines = func_lines[start_idx:stop]
            addr = address_of(sub_func_lines[0])
            if start_idx == 0:
                func_name = func_global_line.body.symbol
            else:
                func_name = f'func_{addr:X}'
            yield Function(name=func_name, addr=addr, lines=sub_func_lines)
def emit_cxx_asmfn(inc_base: Path, func: Function) -> str:
    """Return the C++ ``asm void`` stub that ``#include``s *func*'s assembly.

    Args:
        inc_base: Base passed to ``func.include_path`` to build the include path.
        func: Supplies the stub name and the include path.

    Returns:
        A four-line stub: signature, ``nofralloc``, the ``#include`` line and
        the closing brace (no trailing newline).
    """
    # NOTE(review): the template lines carry no leading indentation as
    # written here; dedent is kept because the original applies it.
    template = (
        'asm void {name}(void) {{\n'
        'nofralloc\n'
        '#include "{inc}"\n'
        '}}'
    )
    filled = template.format(name=func.name, inc=func.include_path(inc_base))
    return dedent(filled)
def emit_cxx_extern_fns(tu_file: str, labels: Iterable[str]) -> str:
    """Return an ``extern "C"`` block declaring each label as ``void f(void);``.

    Args:
        tu_file: Name mentioned in the leading comment.
        labels: Function names to declare; they are emitted in sorted order.

    Returns:
        Two comment lines (the second carries the tool version and the
        current UTC time) followed by the ``extern "C" { ... }`` block.
    """
    prototypes = [f'void {label}(void);' for label in sorted(labels)]
    defs = '\n '.join(prototypes)
    banner = (
        f'// additional symbols needed for {tu_file}\n'
        f'// autogenerated by split.py {__version__} at {datetime.utcnow()}\n'
    )
    return banner + 'extern "C" {\n' + ' ' + defs + '\n}'
def emit_cxx_extern_vars(tu_file: str, labels: Iterable[str]) -> str:
    """Return ``extern u8 <label>;`` declarations for every label.

    Args:
        tu_file: Name mentioned in the leading comment.
        labels: Variable names to declare; they are emitted in sorted order.

    Returns:
        Two comment lines (the second carries the tool version and the
        current UTC time), one declaration per line, and a trailing newline.
    """
    banner = (
        f'// additional symbols needed for {tu_file}\n'
        f'// autogenerated by split.py {__version__} at {datetime.utcnow()}\n'
    )
    declarations = [f'extern u8 {label};' for label in sorted(labels)]
    return banner + '\n'.join(declarations) + '\n'
def _append_to_forceactive(ldscript: str, names: List[str]) -> str:
    """Return *ldscript* with *names* added at the end of its FORCEACTIVE block.

    The block is located as the first occurrence of the text ``FORCEACTIVE``
    and the first ``}`` after it. Each name is written on its own line, in the
    given order, immediately before that brace.

    Raises:
        ValueError: If ``FORCEACTIVE`` or a following ``}`` cannot be found.
    """
    start = ldscript.find('FORCEACTIVE')
    end = ldscript.find('}', start) if start != -1 else -1
    if end == -1:
        raise ValueError('no FORCEACTIVE { ... } block found')
    addition = ''.join(f'{name}\n' for name in names)
    return ldscript[:end] + addition + ldscript[end:]


@click.command()
@click.argument('src', type=PathPath(file_okay=True, dir_okay=False, exists=True))
@click.argument('cxx_out', type=PathPath(file_okay=True, dir_okay=False))
@click.option(
    '--funcs-out',
    type=PathPath(file_okay=False, dir_okay=True),
    default='include/funcs',
)
@click.option('--s-include-base', type=str, default='funcs')
@click.option(
    '--framework-map-file',
    type=PathPath(file_okay=True, dir_okay=False),
    default='frameworkF.map',
)
@click.option(
    '--ldscript-file',
    type=PathPath(file_okay=True, dir_okay=False),
    default='ldscript.lcf',
)
@click.option('--from-line', type=int)
@click.option('--to-line', type=int)
@click.option('--preparsed', is_flag=True)
@click.option('--forceactive',
    type=click.Choice(['all', 'none', 'missingfunc']), default='missingfunc')
def split(
    src,
    cxx_out,
    funcs_out,
    s_include_base,
    framework_map_file,
    ldscript_file,
    from_line,
    to_line,
    preparsed,
    forceactive
):
    """Split SRC into per-function .s files and write C++ asm stubs to CXX_OUT.
    \f
    (click shows only the text above the form feed as ``--help`` output.)

    Args:
        src: Assembly source, or a pickle of parsed lines with ``--preparsed``.
        cxx_out: C++ file that receives the extern declarations and asm stubs.
        funcs_out: Directory for the per-function ``.s`` files (created if missing).
        s_include_base: Prefix used in the stubs' ``#include`` paths.
        framework_map_file: Map file handed to ``parse_framework_map``.
        ldscript_file: Linker script whose FORCEACTIVE block is extended.
        from_line: Optional 1-based first line of the range to process.
        to_line: Optional 1-based end of the range (that line is excluded).
        preparsed: Load ``src`` with pickle instead of parsing it.
        forceactive: ``none`` adds nothing to the linker script, ``all`` adds
            every function not already mentioned in it, ``missingfunc`` adds
            only those whose name starts with ``func_``.
    """
    funcs_out.mkdir(exist_ok=True, parents=True)

    if preparsed:
        logger.info('loading preparsed assembly')
        # SECURITY: pickle.load can execute arbitrary code from the file;
        # only use --preparsed on files you produced yourself.
        with src.open('rb') as f:
            lines = pickle.load(f)
    else:
        logger.info('parsing assembly')
        lines = asm.parse(src.read_text())

    # NOTE(review): without --to-line the stop is -1, which drops the final
    # parsed line. Left as-is because it may be deliberate -- confirm.
    lines = lines[
        (from_line - 1 if from_line else 0) : (to_line - 1 if to_line else -1)
    ]

    logger.info('parsing map file')
    framework_map = parse_framework_map(framework_map_file)
    logger.debug(f'loaded {len(framework_map)} symbols from map')

    logger.info('reading ldscript')
    ldscript_file_content = ldscript_file.read_text()
    new_ldfuncs = []

    # -- get all defined labels and jump targets
    jumped_labels = set()
    defined_labels = set()
    logger.info('scanning for branch targets')
    for line in lines:
        if isinstance(line.body, Label):
            defined_labels.add(line.body.symbol)
        if isinstance(line.body, Instruction):
            if line.body.opcode[0] == 'b' and line.body.operands != []:  # branch
                jumped_labels.add(line.body.operands[0])  # jump target

    # -- find everything of the form lbl_[hex] that's in an operand on the RHS of a l* instruction
    # this is a relatively okay assumption given that any var that's *not* of the form lbl_ has
    # probably already been renamed and thus is exported in variables.h
    LBL_RE = re.compile(r'lbl_[0-9A-F]+')

    def find_labels_in_operands(operands):
        for operand in operands:
            if match := LBL_RE.search(operand):
                yield match.group()

    logger.info('scanning for load/store target labels')
    loaded_labels = set()
    for line in lines:
        if isinstance(line.body, Instruction):
            # load and store instructions, ish. Also addi for loading SDA addresses
            if line.body.opcode[0] in {'l', 's'} or line.body.opcode == 'addi':
                loaded_labels |= set(find_labels_in_operands(line.body.operands))

    # -- find all defined functions and split them
    functions = list(find_functions(lines, framework_map))
    logger.info('splitting functions')
    for func in functions:
        logger.debug(
            f'working on function {func.name} @ {func.addr:X} with {func.line_count} lines'
        )

        # comment out .globals
        func.lines = [
            comment_out(line) if isinstance(line.body, Global) else line
            for line in func.lines
        ]

        # fix SDA_BASE addi
        func.lines = [
            fix_sda_base_add(line)
            if isinstance(line.body, Instruction)
            and line.body.opcode == 'addi'
            else line
            for line in func.lines
        ]

        # remove GQR mnemonics
        func.lines = [
            patch_gqrs(line)
            if isinstance(line.body, Instruction)
            and line.body.opcode.startswith('psq_')
            else line
            for line in func.lines
        ]

        # check if needs to be defined in ldscript
        # (plain substring test against the whole linker script text)
        if forceactive != 'none':
            if func.name not in ldscript_file_content and (
                forceactive == 'all' or func.name.startswith('func_')
            ):
                new_ldfuncs.append(func.name)

        with open(out_path := funcs_out / func.filename, 'w') as f:
            logger.debug(f'emitting {out_path}')
            f.write(emit_lines(func.lines))

    # -- dump new labels to functions.h
    logger.info('dumping labels to extern functions header')
    func_labels = jumped_labels - defined_labels
    # add in everything we def to make sure asm can get backrefs
    for func in functions:
        func_labels.add(func.name)

    # -- write asm stubs to cxx_out (could've done this as part of previous loop but imo this is cleaner)
    logger.info(f'emitting c++ asm stubs to {cxx_out}')
    with open(cxx_out, 'w') as f:
        f.write(
            f'/* {cxx_out.name} autogenerated by split.py {__version__} at {datetime.utcnow()} */\n\n'
        )
        f.write('#include "global.h"\n\n')

        # extern functions
        f.write(emit_cxx_extern_fns(cxx_out.name, func_labels))
        f.write('\n\n')

        # extern variables
        f.write(emit_cxx_extern_vars(cxx_out.name, loaded_labels))
        f.write('\n\n')

        f.write('extern "C" {\n')
        for func in functions:
            logger.debug(f'emitting asm stub for {func.name}')
            mangled_func_name = framework_map[f'{func.addr:x}']
            f.write(f'// {mangled_func_name}\n')
            # demangling is best-effort: a failure only costs the comment line
            try:
                demangled_func_name = demangle(mangled_func_name)
                f.write(f'// {demangled_func_name}\n')
            except Exception as e:
                logger.warning(f"could not demangle symbol '{mangled_func_name}': {e}")
            f.write(emit_cxx_asmfn(s_include_base, func))
            f.write('\n\n')
        f.write('};\n')  # extern C end

    # -- make defined functions FORCEACTIVE in ldscript.lcf
    if new_ldfuncs:
        logger.info('writing to FORCEACTIVE in linker script')
        try:
            updated_ldscript = _append_to_forceactive(
                ldscript_file_content, new_ldfuncs
            )
        except ValueError as e:
            # Leave the linker script untouched rather than splicing names
            # into an arbitrary position.
            logger.error(f'not updating {ldscript_file}: {e}')
        else:
            ldscript_file.write_text(updated_ldscript)
# Invoke the click command only when this file is run as a script.
if __name__ == '__main__':
    split()