From 7e7a76f92695535c31618879ebb23753e72dc822 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Lam?= Date: Tue, 10 Nov 2020 21:10:58 +0100 Subject: [PATCH] tools: Print mismatch cause Makes it easier to identify what's wrong. --- tools/check.py | 4 ++- tools/util/checker.py | 58 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 50 insertions(+), 12 deletions(-) diff --git a/tools/check.py b/tools/check.py index 648301e0..84421225 100755 --- a/tools/check.py +++ b/tools/check.py @@ -27,7 +27,7 @@ def main() -> None: nonmatching_fns_with_dump = {p.stem: util.elf.Function(p.read_bytes(), 0) for p in (utils.get_repo_root() / "expected").glob("*.bin")} - checker = util.checker.FunctionChecker() + checker = util.checker.FunctionChecker(log_mismatch_cause=True) for func in utils.get_functions(): if not func.decomp_name: @@ -43,6 +43,8 @@ def main() -> None: if not check_function(checker, func.addr, func.size, func.decomp_name): utils.print_error( f"function {utils.format_symbol_name_for_msg(func.decomp_name)} is marked as matching but does not match") + a1, a2, reason = checker.get_mismatch() + sys.stderr.write(f" at {a1|0x7100000000:#x} : {reason}\n") failed = True elif func.status == utils.FunctionStatus.Equivalent or func.status == utils.FunctionStatus.NonMatching: if check_function(checker, func.addr, func.size, func.decomp_name): diff --git a/tools/util/checker.py b/tools/util/checker.py index e7bdf4c4..261d7c96 100644 --- a/tools/util/checker.py +++ b/tools/util/checker.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Set, DefaultDict, Dict +from typing import Set, DefaultDict, Dict, Optional import capstone as cs @@ -7,28 +7,45 @@ from util import dsym, elf, utils class FunctionChecker: - def __init__(self): + def __init__(self, log_mismatch_cause: bool = False): self.md = cs.Cs(cs.CS_ARCH_ARM64, cs.CS_MODE_ARM) self.md.detail = True self.my_symtab = elf.build_name_to_symbol_table(elf.my_symtab) self.dsymtab = dsym.DataSymbolContainer() self.decompiled_fns: Dict[int, str] = dict() + + self._log_mismatch_cause = log_mismatch_cause + self._mismatch_addr1 = -1 + self._mismatch_addr2 = -1 + self._mismatch_cause = "" + self.load_data_for_project() + def _reset_mismatch(self) -> None: + self._mismatch_addr1 = -1 + self._mismatch_addr2 = -1 + self._mismatch_cause = "" + def get_data_symtab(self) -> dsym.DataSymbolContainer: return self.dsymtab + def get_mismatch(self) -> (int, int, str): + return self._mismatch_addr1, self._mismatch_addr2, self._mismatch_cause + def load_data_for_project(self) -> None: self.decompiled_fns = {func.addr: func.decomp_name for func in utils.get_functions() if func.decomp_name} self.get_data_symtab().load_from_csv(utils.get_repo_root() / "data" / "data_symbols.csv") def check(self, base_fn: elf.Function, my_fn: elf.Function) -> bool: + self._reset_mismatch() gprs1: DefaultDict[int, int] = defaultdict(int) gprs2: DefaultDict[int, int] = defaultdict(int) adrp_pair_registers: Set[int] = set() size = len(base_fn) if len(base_fn) != len(my_fn): + if self._log_mismatch_cause: + self._set_mismatch_cause(None, None, "different function length") return False for i1, i2 in zip(self.md.disasm(base_fn.data, base_fn.addr), self.md.disasm(my_fn.data, my_fn.addr)): @@ -52,12 +69,14 @@ class FunctionChecker: continue if i1.mnemonic != i2.mnemonic: + if self._log_mismatch_cause: + self._set_mismatch_cause(i1, i2, "mnemonics are different") return False # Ignore some address differences until a fully matching executable can be generated. if i1.mnemonic == 'bl': - if not self._check_function_call(i1.operands[0].imm, i2.operands[0].imm): + if not self._check_function_call(i1, i2, i1.operands[0].imm, i2.operands[0].imm): return False continue @@ -66,7 +85,7 @@ class FunctionChecker: # If we are branching outside the function, this is likely a tail call. # Treat this as a function call. if not (base_fn.addr <= branch_target < base_fn.addr + size): - if not self._check_function_call(branch_target, i2.operands[0].imm): + if not self._check_function_call(i1, i2, branch_target, i2.operands[0].imm): return False continue # Otherwise, it's a mismatch. @@ -94,7 +113,7 @@ class FunctionChecker: gprs1[reg] += i1.operands[1].value.mem.disp gprs2[reg] += i2.operands[1].value.mem.disp - if not self._check_data_symbol(gprs1[reg], gprs2[reg]): + if not self._check_data_symbol(i1, i2, gprs1[reg], gprs2[reg]): return False adrp_pair_registers.remove(reg) @@ -113,7 +132,7 @@ class FunctionChecker: gprs1[reg] += i1.operands[1].value.mem.disp gprs2[reg] += i2.operands[1].value.mem.disp - if not self._check_data_symbol(gprs1[reg], gprs2[reg]): + if not self._check_data_symbol(i1, i2, gprs1[reg], gprs2[reg]): return False adrp_pair_registers.remove(reg) @@ -130,7 +149,7 @@ class FunctionChecker: gprs1[reg] += i1.operands[2].imm gprs2[reg] += i2.operands[2].imm - if not self._check_data_symbol(gprs1[reg], gprs2[reg]): + if not self._check_data_symbol(i1, i2, gprs1[reg], gprs2[reg]): return False adrp_pair_registers.remove(reg) @@ -140,18 +159,35 @@ class FunctionChecker: return True - def _check_data_symbol(self, orig_addr: int, decomp_addr: int) -> bool: + def _set_mismatch_cause(self, i1: Optional[any], i2: Optional[any], description: str) -> None: + self._mismatch_addr1 = i1.address if i1 else -1 + self._mismatch_addr2 = i2.address if i2 else -1 + self._mismatch_cause = description + + def _check_data_symbol(self, i1, i2, orig_addr: int, decomp_addr: int) -> bool: symbol = self.dsymtab.get_symbol(orig_addr) if symbol is None: return True decomp_symbol = self.my_symtab[symbol.name] - return decomp_symbol.addr == decomp_addr + if decomp_symbol.addr == decomp_addr: + return True - def _check_function_call(self, orig_addr: int, decomp_addr: int) -> bool: + if self._log_mismatch_cause: + self._set_mismatch_cause(i1, i2, f"data symbol mismatch: {symbol.name} (original address: {orig_addr:#x})") + + return False + + def _check_function_call(self, i1, i2, orig_addr: int, decomp_addr: int) -> bool: name = self.decompiled_fns.get(orig_addr, None) if name is None: return True decomp_symbol = self.my_symtab[name] - return decomp_symbol.addr == decomp_addr + if decomp_symbol.addr == decomp_addr: + return True + + if self._log_mismatch_cause: + self._set_mismatch_cause(i1, i2, f"function call mismatch: {name}") + + return False