a new tool: splitter (#39)

* splitter: v0.1

* basic demangle stuff

* splitter: v0.2

- add from, to options to select line range in .s to process
- infer referred labels from both loads and stores
- dump our own function labels in addition to externs into functions.h,
  to provide forward declarations for labels that other functions might
  use.
- fix off-by-one which was eating the last instruction of some functions

* splitter: v0.3

merged a bunch of work lepelog did, including:
- demangling support
- better function identification
- automatic FORCEACTIVE

and did a little bit of cleanup

* splitter: improve sda hack and format

* splitter: fix comment_out(), patch GQR references

* splitter: some speed optimizations

* remove debug print

* splitter: forceactive options

* refactor demangler, add support for more operators and more mangling symbols

* array and member (still one non working case)

* fix some operands in demangler

* make parents for funcs_out

* splitter: fix off-by-one in last line of last function in some .s files

Co-authored-by: lepelog <lepelog@users.noreply.github.com>
Co-authored-by: Pheenoh <pheenoh@gmail.com>
This commit is contained in:
Erin Moon 2021-01-02 01:22:57 -06:00 committed by GitHub
parent c5b3cbaa08
commit 052119c7a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 849 additions and 0 deletions

300
tools/splitter/demangle.py Normal file
View File

@ -0,0 +1,300 @@
from typing import List, Optional
from pathlib import Path
from dataclasses import dataclass, field
import re
operator_func_re = re.compile(r'^__([a-z]+)')
types = {
'i': 'int',
'l': 'long',
's': 'short',
'c': 'char',
'f': 'float',
'd': 'double',
'v': 'void',
'x': 'long long',
'b': 'bool',
'e': 'varargs...',
}
# {'defctor', 'ops',}
special_funcs = {
'eq': 'operator==',
'as': 'operator=',
'ne': 'operator!=',
'dv': 'operator/',
'pl': 'operator+',
'mi': 'operator-',
'ml': 'operator*',
'adv': 'operator/=',
'apl': 'operator+=',
'ami': 'operator-=',
'amu': 'operator*=',
'lt': 'operator<',
'gt': 'operator>',
'cl': 'operator()',
'dla': 'operator delete[]',
'nwa': 'operator new[]',
'dl': 'operator delete',
'nw': 'operator new',
}
@dataclass
class Param:
name: str = ''
pointer_lvl: int = 0
is_const: bool = False
is_ref: bool = False
is_unsigned: bool = False
is_signed: bool = False
def to_str(self) -> str:
ret = ''
if self.is_const:
ret += 'const '
if self.is_unsigned:
ret += 'unsigned '
ret += self.name
for _ in range(self.pointer_lvl):
ret += '*'
if self.is_ref:
ret += '&'
return ret
@dataclass
class FuncParam:
ret_type: Optional[str] = None
params: List[str] = field(default_factory=list)
def to_str(self) -> str:
ret = ''
if self.ret_type is None:
ret += 'void'
else:
ret += self.ret_type
ret += ' (*)('
ret += ', '.join(self.params)
ret += ')'
return ret
class ParseError(Exception):
...
class ParseCtx:
def __init__(self, mangled: str):
self.mangled = mangled
self.index = 0
self.demangled = []
self.cur_type = None
self.class_name = None
self.is_const = False
self.func_name = None
def demangle(self):
# this split is still not accurate, but good enough for most cases
last_f = self.mangled.rfind('F')
if last_f == -1:
return
split_pos = self.mangled.rfind('__', 0, last_f)
if split_pos == -1 or split_pos == 0:
return
self.func_name = self.mangled[:split_pos]
self.mangled = self.mangled[split_pos+2:]
if self.func_name.startswith('__'):
match = operator_func_re.match(self.func_name)
if match:
special_func_name = match.group(1)
if special_func_name in special_funcs:
self.func_name = special_funcs[special_func_name]
else:
if special_func_name == 'ct':
self.func_name = '.ctor'
elif special_func_name == 'dt':
self.func_name = '.dtor'
self.demangle_first_class()
while self.index < len(self.mangled):
self.demangled.append(self.demangle_next_type())
if self.func_name == '.ctor':
self.func_name = self.class_name
if self.func_name == '.dtor':
self.func_name = '~' + self.class_name
def demangle_first_class(self):
if self.peek_next_char().isdecimal():
self.class_name = self.demangle_class()
if self.peek_next_char() == 'C':
self.is_const = True
self.index += 1
assert self.consume_next_char() == 'F', 'next char should be F!'
elif self.peek_next_char() == 'Q':
self.index += 1
self.class_name = self.demangle_qualified_name()
if self.peek_next_char() == 'C':
self.is_const = True
self.index += 1
assert self.consume_next_char() == 'F', 'next char should be F!'
else:
assert self.consume_next_char() == 'F', 'next char should be F!'
def demangle_next_type(self) -> str:
cur_type = Param()
while True:
cur_char = self.peek_next_char()
if cur_char.isdecimal():
class_name = self.demangle_class()
cur_type.name = class_name
return cur_type.to_str()
elif cur_char in types:
type_name = self.demangle_prim_type()
cur_type.name = type_name
return cur_type.to_str()
elif cur_char == 'U':
cur_type.is_unsigned = True
self.index += 1
elif cur_char == 'S':
cur_type.is_signed = True
self.index += 1
elif cur_char == 'C':
cur_type.is_const = True
self.index += 1
elif cur_char == 'P':
cur_type.pointer_lvl += 1
self.index += 1
elif cur_char == 'R':
cur_type.is_ref = True
self.index += 1
elif cur_char == 'F':
self.index += 1
func = self.demangle_function()
return func.to_str()
elif cur_char == 'Q':
self.index += 1
return self.demangle_qualified_name()
elif cur_char == 'A':
if cur_type.pointer_lvl < 1 and not cur_type.is_ref:
raise ParseError("pointer level for array is wrong!")
# decrease pointer level by one, cause one is already handled in the array demangle
if not cur_type.is_ref:
cur_type.pointer_lvl -= 1
cur_type.name = self.demangle_array()
return cur_type.to_str()
else:
raise ParseError(f'unexpected character {cur_char}')
def demangle_array(self) -> str:
sizes = []
while self.peek_next_char() == 'A':
self.index += 1
sizes.append(self.read_next_int())
if self.consume_next_char() != '_':
raise ParseError("Need to have '_' after Array size!")
array_type = self.demangle_next_type()
return f'{array_type} []' + ''.join(f'[{i}]' for i in sizes)
def demangle_function(self) -> FuncParam:
func_param = FuncParam()
while True:
cur_char = self.peek_next_char()
if cur_char == '_':
self.index += 1
func_param.ret_type = self.demangle_next_type()
return func_param
func_param.params.append(self.demangle_next_type())
def demangle_qualified_name(self) -> str:
part_count = int(self.consume_next_char())
parts = []
for _ in range(part_count):
parts.append(self.demangle_class())
return '::'.join(parts)
def read_next_int(self) -> int:
class_len_str = ''
cur_char = self.peek_next_char()
while cur_char.isdecimal():
class_len_str += cur_char
self.index += 1
cur_char = self.peek_next_char()
return int(class_len_str)
def demangle_class(self) -> str:
if not self.peek_next_char().isdecimal():
raise ParseError(f'class mangling must start with number')
class_len = self.read_next_int()
class_name = self.mangled[self.index : self.index + class_len]
self.index += class_len
if self.peek_next_char() == 'M':
self.index += 1
class_name += '::' + self.demangle_class()
return class_name
def demangle_prim_type(self) -> str:
ret = types[self.consume_next_char()]
return ret
def consume_next_char(self) -> str:
next_char = self.mangled[self.index]
self.index += 1
return next_char
def peek_next_char(self) -> str:
if self.index >= len(self.mangled):
return None
return self.mangled[self.index]
def to_str(self) -> str:
if self.func_name is None:
return ''
elif self.class_name is None:
return self.func_name + '(' + ', '.join(self.demangled) + ')'
else:
return self.class_name + '::' + self.func_name + '(' + ', '.join(self.demangled) + ')' + (' const' if self.is_const else '')
def demangle(s):
p = ParseCtx(s)
p.demangle()
return p.to_str()
def parse_framework_map(path: Path):
address_funcname = {}
with path.open() as f:
for line in f.readlines():
if line.startswith('.ctors'):
return address_funcname
if not line.startswith(' '):
continue
funcname = line[30:].split(' ', 1)[0]
address = line[18:26]
address_funcname[address] = funcname
return address_funcname
# def try_demangle_all():
# with open('frameworkF.map') as f:
# for line in f.readlines():
# if line.startswith('.ctors'):
# return
# if not line.startswith(' '):
# continue
# line = line[30:]
# line_spl = line.split(' ',1)[0]
# try:
# d = demangle(line_spl)
# if d:
# print(d)
# # except NotImplementedError:
# # pass
# except Exception as e:
# # print(f'could not demangle {line_spl}: {repr(e)}')
# # raise e
# pass
# try_demangle_all()

159
tools/splitter/parser.py Normal file
View File

@ -0,0 +1,159 @@
from dataclasses import dataclass
from parsy import string, regex, seq, generate, line_info
from typing import Optional, List, Union, Protocol
class Emittable(Protocol):
def emit(self) -> str:
...
@dataclass
class BlockComment:
text: str
def emit(self) -> str:
return f'/*{self.text}*/'
@dataclass
class TrailingComment:
text: str
def emit(self) -> str:
return f'# {self.text}'
@dataclass
class Include:
file: str
def emit(self) -> str:
return f'.include "{self.file}"'
@dataclass
class Section:
name: str
flags: Optional[str]
def emit(self) -> str:
directive = f'.section .{self.name}'
if self.flags is not None:
directive += f', "{self.flags}"'
return directive
@dataclass
class Global:
symbol: str
def emit(self) -> str:
return f'.global {self.symbol}'
@dataclass
class Label:
symbol: str
def emit(self) -> str:
return f'{self.symbol}:'
@dataclass
class Instruction:
opcode: str
operands: List[str]
def emit(self) -> str:
instr = self.opcode
if len(self.operands) > 0:
instr += ' ' + ', '.join(self.operands)
return instr
@dataclass
class Line:
index: int
content: List[
Union[
BlockComment, TrailingComment, Instruction, Global, Section, Include, Label
]
]
body: Optional[Union[Global, Section, Include, Label, Instruction]]
def emit(self) -> str:
return ' '.join([x.emit() for x in self.content])
space = regex(r'[ \t]+')
line_ending = regex('(\n)|(\r\n)').desc('newline')
pad = regex(r'[ \t]*')
block_comment = (
string('/*') >> regex(r'[\w\s]*').map(BlockComment) << string('*/')
).desc('block comment')
trailing_comment = (
string('#') >> pad >> regex(r'[^\n\r]*').map(TrailingComment)
).desc('trailing comment')
symbolname = regex(r'[a-zA-Z._$][a-zA-Z0-9._$?]*')
label = (symbolname.map(Label) << string(':')).desc('label')
delimited_string = (string('"') >> regex(r'[^"]*') << string('"')).desc(
'double-quote delimited string'
)
directive_include = string('include') >> space >> delimited_string.map(Include)
directive_section = seq(
name=string('section')
>> space
>> string('.')
>> regex(r'[a-z]+'),
flags=(pad >> string(',') >> space >> delimited_string).optional(),
).combine_dict(Section)
directive_global = string('global') >> space >> symbolname.map(Global)
directive = (
string('.')
>> (
directive_include
| directive_section
| directive_global
| string('text').result(Section('text', flags=None))
| string('data').result(Section('data', flags=None))
)
).desc('directive')
opcode = regex(r'[a-z_0-9]+\.?').concat().desc('opcode')
operand = regex(r'[^,#\s]+')
operands = operand.sep_by(string(',') << pad)
@generate
def instruction():
op = yield opcode
sp = yield space.optional()
if sp:
oprs = yield operands
else:
oprs = []
return Instruction(op, oprs)
@generate
def line():
line, _ = yield line_info
content = yield (pad >> block_comment << pad).many()
body = yield (directive | label | instruction).optional() << pad
if body:
content.append(body)
content += yield (pad >> block_comment).many()
trailing = yield (pad >> trailing_comment).optional()
if trailing:
content.append(trailing)
return Line(line, content, body)
asm = line.sep_by(line_ending)

373
tools/splitter/split.py Normal file
View File

@ -0,0 +1,373 @@
"""
split.py - 202x erin moon for zeldaret
"""
from typing import Iterable, List
from dataclasses import dataclass
from pathlib import Path, PosixPath
from textwrap import dedent
from loguru import logger
from datetime import datetime
import re
import click
from parser import asm, Emittable, Global, Label, Line, BlockComment, Instruction
from demangle import parse_framework_map, demangle
from util import PathPath, pairwise
from pprint import pprint
import pickle
import IPython
SDA_BASE = 0x80458580
SDA2_BASE = 0x80459A00
__version__ = 'v0.3'
def function_global_search(lines: List[Line]) -> Iterable[Line]:
i = 0
while i < len(lines):
if isinstance(lines[i].body, Global):
sym = lines[i].body.symbol
if isinstance(lines[i + 1].body, Label) and lines[i + 1].body.symbol == sym:
yield lines[i]
i += 2
else:
i += 1
def emit_lines(lines: List[Line]) -> str:
return '\n'.join([line.emit() for line in lines])
def comment_out(line: Line) -> Line:
return Line(line.index, [BlockComment(line.emit())], None)
def fix_sda_base_add(line: Line) -> Line:
if 'SDA' in line.body.operands[2]:
ops = line.body.operands[2].split('-')
lbl_addr = int(ops[0][4:], 16)
if ops[1] == '_SDA_BASE_':
sda_addr = SDA_BASE
elif ops[1] == '_SDA2_BASE_':
sda_addr = SDA2_BASE
else:
logger.error('Unknown SDABASE!')
return line
line.content.append(
BlockComment(f'SDA HACK; original: {line.body.operands[2]}')
)
line.body.operands[2] = f'0x{lbl_addr:X} - 0x{sda_addr:X}'
return line
QUANT_REG_RE = re.compile(r'qr(\d+)')
def patch_gqrs(line: Line) -> Line:
line.body.operands = [QUANT_REG_RE.sub(r'\1', o) for o in line.body.operands]
return line
@dataclass
class Function:
name: str
addr: int
lines: List[Line]
@property
def line_count(self):
return len(self.lines)
@property
def filename(self) -> str:
return f'func_{self.addr:X}.s'
def include_path(self, base: Path) -> str:
return str(PosixPath(base) / PosixPath(self.filename))
def find_functions(lines: List[Line], framework_map) -> Iterable[Function]:
for func_global_line, next_func_global_line in pairwise(
function_global_search(lines)
):
# some blocks weren't properly split, use the map to find missing functions
fr = func_global_line.index + 2
# if no next global, to = None => end of file
to = next_func_global_line and next_func_global_line.index
func_lines = lines[fr:to]
func_idx = []
for idx, line in enumerate(func_lines):
if isinstance(line.body, Instruction):
addr = int(line.content[0].text.strip().split()[0], 16)
if f'{addr:x}' in framework_map:
func_idx.append(idx)
for start_idx, end_idx in pairwise(func_idx):
sub_func_lines = func_lines[
start_idx : (len(func_lines) if end_idx == None else end_idx)
]
addr = int(sub_func_lines[0].content[0].text.strip().split()[0], 16)
yield Function(
name=func_global_line.body.symbol
if start_idx == 0
else f'func_{addr:X}',
addr=addr,
lines=sub_func_lines,
)
def emit_cxx_asmfn(inc_base: Path, func: Function) -> str:
return dedent(
'''\
asm void {name}(void) {{
nofralloc
#include "{inc}"
}}'''.format(
name=func.name, inc=func.include_path(inc_base)
)
)
def emit_cxx_extern_fns(tu_file: str, labels: Iterable[str]) -> str:
def decl(label):
return f'void {label}(void);'
defs = '\n '.join(decl(label) for label in labels)
return (
f'// additional symbols needed for {tu_file}\n'
f'// autogenerated by split.py {__version__} at {datetime.utcnow()}\n'
'extern "C" {\n'
' ' + defs + '\n}'
)
def emit_cxx_extern_vars(tu_file: str, labels: Iterable[str]) -> str:
def decl(label):
return f'extern u8 {label};'
return (
f'// additional symbols needed for {tu_file}\n'
f'// autogenerated by split.py {__version__} at {datetime.utcnow()}\n'
+ '\n'.join(decl(label) for label in labels)
+ '\n'
)
@click.command()
@click.argument('src', type=PathPath(file_okay=True, dir_okay=False, exists=True))
@click.argument('cxx_out', type=PathPath(file_okay=True, dir_okay=False))
@click.option(
'--funcs-out',
type=PathPath(file_okay=False, dir_okay=True),
default='include/funcs',
)
@click.option('--s-include-base', type=str, default='funcs')
@click.option(
'--extern-functions-file',
type=PathPath(file_okay=True, dir_okay=False),
default='include/functions.h',
)
@click.option(
'--extern-variables-file',
type=PathPath(file_okay=True, dir_okay=False),
default='include/variables.h',
)
@click.option(
'--framework-map-file',
type=PathPath(file_okay=True, dir_okay=False),
default='frameworkF.map',
)
@click.option(
'--ldscript-file',
type=PathPath(file_okay=True, dir_okay=False),
default='ldscript.lcf',
)
@click.option('--from-line', type=int)
@click.option('--to-line', type=int)
@click.option('--preparsed', is_flag=True)
@click.option('--forceactive',
type=click.Choice(['all', 'none', 'missingfunc']), default='missingfunc')
def split(
src,
cxx_out,
funcs_out,
s_include_base,
extern_functions_file,
extern_variables_file,
framework_map_file,
ldscript_file,
from_line,
to_line,
preparsed,
forceactive
):
funcs_out.mkdir(exist_ok=True, parents=True)
if preparsed:
logger.info('loading preparsed assembly')
with src.open('rb') as f:
lines = pickle.load(f)
else:
logger.info('parsing assembly')
lines = asm.parse(src.read_text())
lines = lines[
(from_line - 1 if from_line else 0) : (to_line - 1 if to_line else -1)
]
logger.info('reading extern func/vars files')
extern_funcs_src = extern_functions_file.read_text()
extern_vars_src = extern_variables_file.read_text()
logger.info('parsing map file')
framework_map = parse_framework_map(framework_map_file)
logger.debug(f'loaded {len(framework_map)} symbols from map')
logger.info('reading ldscript')
ldscript_file_content = ldscript_file.read_text()
new_ldfuncs = []
# -- get all defined labels and jump targets
jumped_labels = set()
defined_labels = set()
logger.info('scanning for branch targets')
for line in lines:
if isinstance(line.body, Label):
defined_labels.add(line.body.symbol)
if isinstance(line.body, Instruction):
if line.body.opcode[0] == 'b' and line.body.operands != []: # branch
jumped_labels.add(line.body.operands[0]) # jump target
# -- find everything of the form lbl_[hex] that's in an operand on the RHS of a l* instruction
# this is a relatively okay assumption given that any var that's *not* of the form lbl_ has
# probably already been renamed and thus is exported in variables.h
LBL_RE = re.compile(r'lbl_[0-9A-F]+')
def find_labels_in_operands(operands):
for operand in operands:
if match := LBL_RE.search(operand):
yield match.group()
logger.info('scanning for load/store target labels')
loaded_labels = set()
for line in lines:
if isinstance(line.body, Instruction):
if line.body.opcode[0] in {'l', 's'}: # load and store instructions, ish
loaded_labels |= set(find_labels_in_operands(line.body.operands))
# -- dump new variable labels to variables.h
logger.info('dumping variable labels to extern vars header')
vars_new = set()
for label in loaded_labels:
if label not in extern_vars_src:
logger.debug(f'adding extern var {label} to {extern_variables_file}')
vars_new.add(label)
if len(vars_new) > 0:
with open(extern_variables_file, 'a') as f:
f.write('\n\n')
f.write(emit_cxx_extern_vars(cxx_out.name, vars_new))
# -- find all defined functions and split them
functions = list(find_functions(lines, framework_map))
logger.info('splitting functions')
for func in functions:
logger.debug(
f'working on function {func.name} @ {func.addr:X} with {func.line_count} lines'
)
# comment out .globals
func.lines = [
comment_out(line) if isinstance(line.body, Global) else line
for line in func.lines
]
# fix SDA_BASE addi
func.lines = [
fix_sda_base_add(line)
if isinstance(line.body, Instruction)
and line.body.opcode == 'addi'
else line
for line in func.lines
]
# remove GQR mnemonics
func.lines = [
patch_gqrs(line)
if isinstance(line.body, Instruction)
and line.body.opcode.startswith('psq_')
else line
for line in func.lines
]
# check if needs to be defined in ldscript
if forceactive != 'none':
if not func.name in ldscript_file_content and (forceactive=='all' or func.name.startswith('func_')):
new_ldfuncs.append(func.name)
with open(out_path := funcs_out / func.filename, 'w') as f:
logger.debug(f'emitting {out_path}')
f.write(emit_lines(func.lines))
# -- dump new labels to functions.h
logger.info('dumping labels to extern functions header')
func_labels = jumped_labels - defined_labels
# add in everything we def to make sure asm can get backrefs
for func in functions:
func_labels.add(func.name)
# get rid of stuff already in functions.h. extremely hacky
funcs_new_labels = set()
for label in func_labels:
if label not in extern_funcs_src:
logger.info(f'adding extern func {label} to {extern_functions_file}')
funcs_new_labels.add(label)
if len(funcs_new_labels) > 0:
with open(extern_functions_file, 'a') as f:
f.write('\n\n')
f.write(emit_cxx_extern_fns(cxx_out.name, funcs_new_labels))
# -- write asm stubs to cxx_out (could've done this as part of previous loop but imo this is cleaner)
logger.info(f'emitting c++ asm stubs to {cxx_out}')
with open(cxx_out, 'w') as f:
f.write(
f'/* {cxx_out.name} autogenerated by split.py {__version__} at {datetime.utcnow()} */\n\n'
)
f.write('#include "global.h"\n\n')
f.write('extern "C" {\n')
for func in functions:
logger.debug(f'emitting asm stub for {func.name}')
mangled_func_name = framework_map[f'{func.addr:x}']
f.write(f'// {mangled_func_name}\n')
try:
demangled_func_name = demangle(mangled_func_name)
f.write(f'// {demangled_func_name}\n')
except Exception as e:
logger.warning(f"could not demangle symbol '{mangled_func_name}': {e}")
f.write(emit_cxx_asmfn(s_include_base, func))
f.write('\n\n')
f.write('};\n') # extern C end
# -- make defined functions FORCEACTIVE in ldscript.lcf
logger.info(f'writing to FORCEACTIVE in linker script')
forceactive_start = ldscript_file_content.find('FORCEACTIVE')
forceactive_end = ldscript_file_content.find('}', forceactive_start)
for func in new_ldfuncs:
ldscript_file_content = (
ldscript_file_content[:forceactive_end]
+ func
+ '\n'
+ ldscript_file_content[forceactive_end:]
)
ldscript_file.write_text(ldscript_file_content)
if __name__ == '__main__':
split()

17
tools/splitter/util.py Normal file
View File

@ -0,0 +1,17 @@
import itertools
import click
from pathlib import Path
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
next(b, None)
return itertools.zip_longest(a, b)
class PathPath(click.Path):
"""A Click path argument that returns a pathlib Path, not a string"""
def convert(self, value, param, ctx):
return Path(super().convert(value, param, ctx))