diff --git a/tools/check.py b/tools/check.py
index 5fdd3bd4..d9d95bd0 100755
--- a/tools/check.py
+++ b/tools/check.py
@@ -3,10 +3,11 @@
 import sys
 
 import util.elf
+import util.checker
 from util import utils
 
 
-def check_function(addr: int, size: int, name: str, base_fn=None) -> bool:
+def check_function(checker: util.checker.FunctionChecker, addr: int, size: int, name: str, base_fn=None) -> bool:
     if base_fn is None:
         try:
             base_fn = util.elf.get_fn_from_base_elf(addr, size)
@@ -15,7 +16,7 @@ def check_function(addr: int, size: int, name: str, base_fn=None) -> bool:
             return False
 
     my_fn = util.elf.get_fn_from_my_elf(name)
-    return util.elf.check_function_ex(addr, size, base_fn, my_fn)
+    return checker.check(addr, size, base_fn, my_fn)
 
 
 def main() -> None:
@@ -23,6 +24,8 @@ def main() -> None:
 
     nonmatching_fns_with_dump = {p.stem: p.read_bytes() for p in (utils.get_repo_root() / "expected").glob("*.bin")}
 
+    checker = util.checker.FunctionChecker()
+
     for func in utils.get_functions():
         if not func.decomp_name:
             continue
@@ -34,17 +37,17 @@ def main() -> None:
             continue
 
         if func.status == utils.FunctionStatus.Matching:
-            if not check_function(func.addr, func.size, func.decomp_name):
+            if not check_function(checker, func.addr, func.size, func.decomp_name):
                 utils.print_error(
                     f"function {utils.format_symbol_name_for_msg(func.decomp_name)} is marked as matching but does not match")
                 failed = True
         elif func.status == utils.FunctionStatus.Equivalent or func.status == utils.FunctionStatus.NonMatching:
-            if check_function(func.addr, func.size, func.decomp_name):
+            if check_function(checker, func.addr, func.size, func.decomp_name):
                 utils.print_note(
                     f"function {utils.format_symbol_name_for_msg(func.decomp_name)} is marked as non-matching but matches")
 
             fn_dump = nonmatching_fns_with_dump.get(func.decomp_name, None)
-            if fn_dump is not None and not check_function(func.addr, len(fn_dump), func.decomp_name, fn_dump):
+            if fn_dump is not None and not check_function(checker, func.addr, len(fn_dump), func.decomp_name, fn_dump):
                 utils.print_error(
                     f"function {utils.format_symbol_name_for_msg(func.decomp_name)} does not match expected output")
                 failed = True
diff --git a/tools/util/checker.py b/tools/util/checker.py
new file mode 100644
index 00000000..cdf3397f
--- /dev/null
+++ b/tools/util/checker.py
@@ -0,0 +1,121 @@
+"""Relocation-tolerant comparison of two AArch64 builds of a function."""
+
+from typing import Set
+
+import capstone as cs
+
+
+class FunctionChecker:
+    """Compare two encodings of a function, instruction by instruction.
+
+    The Capstone handle is created once and reused by every check() call.
+    """
+
+    def __init__(self):
+        self.md = cs.Cs(cs.CS_ARCH_ARM64, cs.CS_MODE_ARM)
+        # Operand details are needed to read registers and memory bases.
+        self.md.detail = True
+
+    def check(self, addr: int, size: int, base_fn: bytes, my_fn: bytes) -> bool:
+        """Return True if my_fn matches base_fn, ignoring tolerated address differences.
+
+        Args:
+            addr: Address at which both byte strings are disassembled.
+            size: Function length; a differing `b` whose target in base_fn lies
+                outside [addr, addr + size) is accepted as a tail call.
+            base_fn: Reference machine code.
+            my_fn: Machine code being checked.
+
+        Byte-identical instructions always pass. A differing pair must share
+        a mnemonic and be a `bl`, an out-of-function `b`, an `adrp` with the
+        same destination, or an ldr/str/ldp/stp/add with the same register
+        operands that consumes a register set by an earlier `adrp`. Anything
+        else is a mismatch.
+        """
+        # zip() stops at the shorter instruction stream, so without this guard
+        # a function that is only a prefix of the other would pass.
+        if len(base_fn) != len(my_fn):
+            return False
+
+        # Registers written by an adrp whose paired access is still pending.
+        adrp_pair_registers: Set[int] = set()
+
+        for i1, i2 in zip(self.md.disasm(base_fn, addr), self.md.disasm(my_fn, addr)):
+            if i1.bytes == i2.bytes:
+                # Identical encoding: only keep the adrp bookkeeping current.
+                # NOTE(review): str/stp are not handled here although the
+                # mismatch path below treats them as pair consumers - confirm.
+                if i1.mnemonic == 'adrp':
+                    adrp_pair_registers.add(i1.operands[0].reg)
+                elif i1.mnemonic == 'ldr':
+                    reg = i1.operands[1].value.mem.base
+                    if reg in adrp_pair_registers:
+                        adrp_pair_registers.remove(reg)
+                elif i1.mnemonic == 'ldp':
+                    reg = i1.operands[2].value.mem.base
+                    if reg in adrp_pair_registers:
+                        adrp_pair_registers.remove(reg)
+                elif i1.mnemonic == 'add':
+                    reg = i1.operands[1].reg
+                    if reg in adrp_pair_registers:
+                        adrp_pair_registers.remove(reg)
+                continue
+
+            if i1.mnemonic != i2.mnemonic:
+                return False
+
+            # Ignore some address differences until a fully matching executable can be generated.
+
+            if i1.mnemonic == 'bl':
+                continue
+
+            if i1.mnemonic == 'b':
+                # Needed for tail calls.
+                branch_target = int(i1.op_str[1:], 16)
+                if not (addr <= branch_target < addr + size):
+                    continue
+
+            if i1.mnemonic == 'adrp':
+                if i1.operands[0].reg != i2.operands[0].reg:
+                    return False
+                adrp_pair_registers.add(i1.operands[0].reg)
+                continue
+
+            if i1.mnemonic == 'ldr' or i1.mnemonic == 'str':
+                if i1.operands[0].reg != i2.operands[0].reg:
+                    return False
+                if i1.operands[1].value.mem.base != i2.operands[1].value.mem.base:
+                    return False
+                reg = i1.operands[1].value.mem.base
+                if reg not in adrp_pair_registers:
+                    return False
+                adrp_pair_registers.remove(reg)
+                continue
+
+            if i1.mnemonic == 'ldp' or i1.mnemonic == 'stp':
+                if i1.operands[0].reg != i2.operands[0].reg:
+                    return False
+                if i1.operands[1].reg != i2.operands[1].reg:
+                    return False
+                if i1.operands[2].value.mem.base != i2.operands[2].value.mem.base:
+                    return False
+                reg = i1.operands[2].value.mem.base
+                if reg not in adrp_pair_registers:
+                    return False
+                adrp_pair_registers.remove(reg)
+                continue
+
+            if i1.mnemonic == 'add':
+                if i1.operands[0].reg != i2.operands[0].reg:
+                    return False
+                if i1.operands[1].reg != i2.operands[1].reg:
+                    return False
+                reg = i1.operands[1].reg
+                if reg not in adrp_pair_registers:
+                    return False
+                adrp_pair_registers.remove(reg)
+                continue
+
+            return False
+
+        return True
diff --git a/tools/util/elf.py b/tools/util/elf.py
index 0d23e4b2..56fcc535 100644
--- a/tools/util/elf.py
+++ b/tools/util/elf.py
@@ -19,9 +19,6 @@ my_symtab = my_elf.get_section_by_name(".symtab")
 if not my_symtab:
     utils.fail(f'{_config["myimg"]} has no symbol table')
 
-md = cs.Cs(cs.CS_ARCH_ARM64, cs.CS_MODE_ARM)
-md.detail = True
-
 
 def get_file_offset(elf, addr: int) -> int:
     for seg in elf.iter_segments():
@@ -59,84 +56,3 @@ def get_fn_from_my_elf(name: str) -> bytes:
     offset, size = get_symbol_file_offset_and_size(my_elf, my_symtab, name)
     my_elf.stream.seek(offset)
     return my_elf.stream.read(size)
-
-
-def check_function_ex(addr: int, size: int, base_fn: bytes, my_fn: bytes) -> bool:
-    adrp_pair_registers: Set[int] = set()
-
-    for i1, i2 in zip(md.disasm(base_fn, addr), md.disasm(my_fn, addr)):
-        if i1.bytes == i2.bytes:
-            if i1.mnemonic == 'adrp':
-                adrp_pair_registers.add(i1.operands[0].reg)
-            elif i1.mnemonic == 'ldr':
-                reg = i1.operands[1].value.mem.base
-                if reg in adrp_pair_registers:
-                    adrp_pair_registers.remove(reg)
-            elif i1.mnemonic == 'ldp':
-                reg = i1.operands[2].value.mem.base
-                if reg in adrp_pair_registers:
-                    adrp_pair_registers.remove(reg)
-            elif i1.mnemonic == 'add':
-                reg = i1.operands[1].reg
-                if reg in adrp_pair_registers:
-                    adrp_pair_registers.remove(reg)
-            continue
-
-        if i1.mnemonic != i2.mnemonic:
-            return False
-
-        # Ignore some address differences until a fully matching executable can be generated.
-
-        if i1.mnemonic == 'bl':
-            continue
-
-        if i1.mnemonic == 'b':
-            # Needed for tail calls.
-            branch_target = int(i1.op_str[1:], 16)
-            if not (addr <= branch_target < addr + size):
-                continue
-
-        if i1.mnemonic == 'adrp':
-            if i1.operands[0].reg != i2.operands[0].reg:
-                return False
-            adrp_pair_registers.add(i1.operands[0].reg)
-            continue
-
-        if i1.mnemonic == 'ldr' or i1.mnemonic == 'str':
-            if i1.operands[0].reg != i2.operands[0].reg:
-                return False
-            if i1.operands[1].value.mem.base != i2.operands[1].value.mem.base:
-                return False
-            reg = i1.operands[1].value.mem.base
-            if reg not in adrp_pair_registers:
-                return False
-            adrp_pair_registers.remove(reg)
-            continue
-
-        if i1.mnemonic == 'ldp' or i1.mnemonic == 'stp':
-            if i1.operands[0].reg != i2.operands[0].reg:
-                return False
-            if i1.operands[1].reg != i2.operands[1].reg:
-                return False
-            if i1.operands[2].value.mem.base != i2.operands[2].value.mem.base:
-                return False
-            reg = i1.operands[2].value.mem.base
-            if reg not in adrp_pair_registers:
-                return False
-            adrp_pair_registers.remove(reg)
-            continue
-
-        if i1.mnemonic == 'add':
-            if i1.operands[0].reg != i2.operands[0].reg:
-                return False
-            if i1.operands[1].reg != i2.operands[1].reg:
-                return False
-            reg = i1.operands[1].reg
-            if reg not in adrp_pair_registers:
-                return False
-            adrp_pair_registers.remove(reg)
-            continue
-
-        return False
-
-    return True