mirror of https://github.com/zeldaret/botw.git
tools: Print mismatch cause
Makes it easier to identify what's wrong.
This commit is contained in:
parent
dc9b346d7e
commit
7e7a76f926
|
|
@ -27,7 +27,7 @@ def main() -> None:
|
|||
nonmatching_fns_with_dump = {p.stem: util.elf.Function(p.read_bytes(), 0) for p in
|
||||
(utils.get_repo_root() / "expected").glob("*.bin")}
|
||||
|
||||
checker = util.checker.FunctionChecker()
|
||||
checker = util.checker.FunctionChecker(log_mismatch_cause=True)
|
||||
|
||||
for func in utils.get_functions():
|
||||
if not func.decomp_name:
|
||||
|
|
@ -43,6 +43,8 @@ def main() -> None:
|
|||
if not check_function(checker, func.addr, func.size, func.decomp_name):
|
||||
utils.print_error(
|
||||
f"function {utils.format_symbol_name_for_msg(func.decomp_name)} is marked as matching but does not match")
|
||||
a1, a2, reason = checker.get_mismatch()
|
||||
sys.stderr.write(f" at {a1|0x7100000000:#x} : {reason}\n")
|
||||
failed = True
|
||||
elif func.status == utils.FunctionStatus.Equivalent or func.status == utils.FunctionStatus.NonMatching:
|
||||
if check_function(checker, func.addr, func.size, func.decomp_name):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections import defaultdict
|
||||
from typing import Set, DefaultDict, Dict
|
||||
from typing import Set, DefaultDict, Dict, Optional, Tuple
|
||||
|
||||
import capstone as cs
|
||||
|
||||
|
|
@ -7,28 +7,45 @@ from util import dsym, elf, utils
|
|||
|
||||
|
||||
class FunctionChecker:
|
||||
def __init__(self):
|
||||
def __init__(self, log_mismatch_cause: bool = False):
    """Create a checker and load the project data it compares against.

    Args:
        log_mismatch_cause: Stored on the instance; when true, ``check``
            records where and why two functions differ so the result can
            be read back with ``get_mismatch``.
    """
    # ARM64 disassembler; detail mode is enabled because the checker
    # inspects instruction operands, not just mnemonics.
    disassembler = cs.Cs(cs.CS_ARCH_ARM64, cs.CS_MODE_ARM)
    disassembler.detail = True
    self.md = disassembler

    # Lookup tables used when comparing addresses between the two builds.
    self.my_symtab = elf.build_name_to_symbol_table(elf.my_symtab)
    self.dsymtab = dsym.DataSymbolContainer()
    self.decompiled_fns: Dict[int, str] = {}

    # Mismatch bookkeeping: -1 / "" mean "nothing recorded".
    self._log_mismatch_cause = log_mismatch_cause
    self._mismatch_addr1 = -1
    self._mismatch_addr2 = -1
    self._mismatch_cause = ""

    self.load_data_for_project()
|
||||
|
||||
def _reset_mismatch(self) -> None:
    """Forget any previously recorded mismatch.

    Both addresses go back to -1 and the cause to the empty string, which
    are the same "nothing recorded" values the constructor assigns.
    """
    self._mismatch_addr1 = self._mismatch_addr2 = -1
    self._mismatch_cause = ""
|
||||
|
||||
def get_data_symtab(self) -> dsym.DataSymbolContainer:
    """Return the data symbol container owned by this checker."""
    container = self.dsymtab
    return container
|
||||
|
||||
def get_mismatch(self) -> Tuple[int, int, str]:
    """Return the most recently recorded mismatch.

    Returns:
        A ``(addr1, addr2, cause)`` tuple holding the recorded addresses
        for the first and second function and a description of the
        difference. The addresses are -1 and the cause is an empty string
        when nothing has been recorded.
    """
    return self._mismatch_addr1, self._mismatch_addr2, self._mismatch_cause
|
||||
|
||||
def load_data_for_project(self) -> None:
    """Load the function list and data symbols for this repository.

    Rebuilds ``decompiled_fns`` as an address -> decomp name mapping that
    covers only functions with a non-empty decomp name, then loads the
    data symbol table from ``data/data_symbols.csv`` under the repo root.
    """
    named_functions = {}
    for func in utils.get_functions():
        # Functions without a decomp name have nothing to compare against.
        if func.decomp_name:
            named_functions[func.addr] = func.decomp_name
    self.decompiled_fns = named_functions

    symtab = self.get_data_symtab()
    symtab.load_from_csv(utils.get_repo_root() / "data" / "data_symbols.csv")
|
||||
|
||||
def check(self, base_fn: elf.Function, my_fn: elf.Function) -> bool:
|
||||
self._reset_mismatch()
|
||||
gprs1: DefaultDict[int, int] = defaultdict(int)
|
||||
gprs2: DefaultDict[int, int] = defaultdict(int)
|
||||
adrp_pair_registers: Set[int] = set()
|
||||
|
||||
size = len(base_fn)
|
||||
if len(base_fn) != len(my_fn):
|
||||
if self._log_mismatch_cause:
|
||||
self._set_mismatch_cause(None, None, "different function length")
|
||||
return False
|
||||
|
||||
for i1, i2 in zip(self.md.disasm(base_fn.data, base_fn.addr), self.md.disasm(my_fn.data, my_fn.addr)):
|
||||
|
|
@ -52,12 +69,14 @@ class FunctionChecker:
|
|||
continue
|
||||
|
||||
if i1.mnemonic != i2.mnemonic:
|
||||
if self._log_mismatch_cause:
|
||||
self._set_mismatch_cause(i1, i2, "mnemonics are different")
|
||||
return False
|
||||
|
||||
# Ignore some address differences until a fully matching executable can be generated.
|
||||
|
||||
if i1.mnemonic == 'bl':
|
||||
if not self._check_function_call(i1.operands[0].imm, i2.operands[0].imm):
|
||||
if not self._check_function_call(i1, i2, i1.operands[0].imm, i2.operands[0].imm):
|
||||
return False
|
||||
continue
|
||||
|
||||
|
|
@ -66,7 +85,7 @@ class FunctionChecker:
|
|||
# If we are branching outside the function, this is likely a tail call.
|
||||
# Treat this as a function call.
|
||||
if not (base_fn.addr <= branch_target < base_fn.addr + size):
|
||||
if not self._check_function_call(branch_target, i2.operands[0].imm):
|
||||
if not self._check_function_call(i1, i2, branch_target, i2.operands[0].imm):
|
||||
return False
|
||||
continue
|
||||
# Otherwise, it's a mismatch.
|
||||
|
|
@ -94,7 +113,7 @@ class FunctionChecker:
|
|||
|
||||
gprs1[reg] += i1.operands[1].value.mem.disp
|
||||
gprs2[reg] += i2.operands[1].value.mem.disp
|
||||
if not self._check_data_symbol(gprs1[reg], gprs2[reg]):
|
||||
if not self._check_data_symbol(i1, i2, gprs1[reg], gprs2[reg]):
|
||||
return False
|
||||
|
||||
adrp_pair_registers.remove(reg)
|
||||
|
|
@ -113,7 +132,7 @@ class FunctionChecker:
|
|||
|
||||
gprs1[reg] += i1.operands[1].value.mem.disp
|
||||
gprs2[reg] += i2.operands[1].value.mem.disp
|
||||
if not self._check_data_symbol(gprs1[reg], gprs2[reg]):
|
||||
if not self._check_data_symbol(i1, i2, gprs1[reg], gprs2[reg]):
|
||||
return False
|
||||
|
||||
adrp_pair_registers.remove(reg)
|
||||
|
|
@ -130,7 +149,7 @@ class FunctionChecker:
|
|||
|
||||
gprs1[reg] += i1.operands[2].imm
|
||||
gprs2[reg] += i2.operands[2].imm
|
||||
if not self._check_data_symbol(gprs1[reg], gprs2[reg]):
|
||||
if not self._check_data_symbol(i1, i2, gprs1[reg], gprs2[reg]):
|
||||
return False
|
||||
|
||||
adrp_pair_registers.remove(reg)
|
||||
|
|
@ -140,18 +159,35 @@ class FunctionChecker:
|
|||
|
||||
return True
|
||||
|
||||
def _check_data_symbol(self, orig_addr: int, decomp_addr: int) -> bool:
|
||||
def _set_mismatch_cause(self, i1: Optional["cs.CsInsn"], i2: Optional["cs.CsInsn"], description: str) -> None:
    """Record where and why the two functions differ.

    Args:
        i1: Instruction from the first function, or None when the mismatch
            is not tied to a specific instruction.
        i2: Instruction from the second function, or None likewise.
        description: Human-readable explanation of the mismatch.
    """
    # Compare against None explicitly so that only a missing instruction,
    # never a falsy-looking object, maps to the -1 sentinel.
    self._mismatch_addr1 = i1.address if i1 is not None else -1
    self._mismatch_addr2 = i2.address if i2 is not None else -1
    self._mismatch_cause = description
|
||||
|
||||
def _check_data_symbol(self, i1, i2, orig_addr: int, decomp_addr: int) -> bool:
    """Check that a data reference resolves to the same symbol in both builds.

    Args:
        i1: Instruction from the first function; only used to record the
            mismatch location.
        i2: Instruction from the second function; only used to record the
            mismatch location.
        orig_addr: Address referenced by the first function.
        decomp_addr: Address referenced by the second function.

    Returns:
        True when no data symbol is known at ``orig_addr`` (nothing to
        verify) or when the same-named symbol in ``my_symtab`` sits at
        ``decomp_addr``; False otherwise.

    Raises:
        KeyError: If the symbol name is missing from ``my_symtab``
            (assuming it is a plain mapping).
    """
    symbol = self.dsymtab.get_symbol(orig_addr)
    if symbol is None:
        return True

    decomp_symbol = self.my_symtab[symbol.name]
    if decomp_symbol.addr == decomp_addr:
        return True

    # Only record the cause when the caller opted in.
    if self._log_mismatch_cause:
        self._set_mismatch_cause(i1, i2, f"data symbol mismatch: {symbol.name} (original address: {orig_addr:#x})")

    return False
|
||||
|
||||
def _check_function_call(self, i1, i2, orig_addr: int, decomp_addr: int) -> bool:
    """Check that a call targets the same function in both builds.

    Args:
        i1: Instruction from the first function; only used to record the
            mismatch location.
        i2: Instruction from the second function; only used to record the
            mismatch location.
        orig_addr: Call target in the first function.
        decomp_addr: Call target in the second function.

    Returns:
        True when ``orig_addr`` is not in ``decompiled_fns`` (nothing to
        verify) or when the same-named symbol in ``my_symtab`` sits at
        ``decomp_addr``; False otherwise.

    Raises:
        KeyError: If the function name is missing from ``my_symtab``
            (assuming it is a plain mapping).
    """
    name = self.decompiled_fns.get(orig_addr)
    if name is None:
        return True

    decomp_symbol = self.my_symtab[name]
    if decomp_symbol.addr == decomp_addr:
        return True

    # Only record the cause when the caller opted in.
    if self._log_mismatch_cause:
        self._set_mismatch_cause(i1, i2, f"function call mismatch: {name}")

    return False
|
||||
|
|
|
|||
Loading…
Reference in New Issue