diff --git a/tools/identify_matching_functions.py b/tools/identify_matching_functions.py index 842ee3d9..afb149d2 100755 --- a/tools/identify_matching_functions.py +++ b/tools/identify_matching_functions.py @@ -40,7 +40,6 @@ def main() -> None: new_matches: Dict[int, str] = dict() checker = util.checker.FunctionChecker() - checker.get_data_symtab().load_from_csv(utils.get_repo_root() / "data" / "data_symbols.csv") # Given a list L of functions to identify and a small list of candidates C, this tool will attempt to # automatically identify matches by checking each function in L against each function in C. diff --git a/tools/util/checker.py b/tools/util/checker.py index e12f5c35..e7bdf4c4 100644 --- a/tools/util/checker.py +++ b/tools/util/checker.py @@ -1,9 +1,9 @@ from collections import defaultdict -from typing import Set, DefaultDict +from typing import Set, DefaultDict, Dict import capstone as cs -from util import dsym, elf +from util import dsym, elf, utils class FunctionChecker: @@ -12,10 +12,16 @@ class FunctionChecker: self.md.detail = True self.my_symtab = elf.build_name_to_symbol_table(elf.my_symtab) self.dsymtab = dsym.DataSymbolContainer() + self.decompiled_fns: Dict[int, str] = dict() + self.load_data_for_project() def get_data_symtab(self) -> dsym.DataSymbolContainer: return self.dsymtab + def load_data_for_project(self) -> None: + self.decompiled_fns = {func.addr: func.decomp_name for func in utils.get_functions() if func.decomp_name} + self.get_data_symtab().load_from_csv(utils.get_repo_root() / "data" / "data_symbols.csv") + def check(self, base_fn: elf.Function, my_fn: elf.Function) -> bool: gprs1: DefaultDict[int, int] = defaultdict(int) gprs2: DefaultDict[int, int] = defaultdict(int) @@ -51,13 +57,20 @@ class FunctionChecker: # Ignore some address differences until a fully matching executable can be generated. if i1.mnemonic == 'bl': + if not self._check_function_call(i1.operands[0].imm, i2.operands[0].imm): + return False continue if i1.mnemonic == 'b': - # Needed for tail calls. - branch_target = int(i1.op_str[1:], 16) + branch_target = i1.operands[0].imm + # If we are branching outside the function, this is likely a tail call. + # Treat this as a function call. if not (base_fn.addr <= branch_target < base_fn.addr + size): + if not self._check_function_call(branch_target, i2.operands[0].imm): + return False continue + # Otherwise, it's a mismatch. + return False if i1.mnemonic == 'adrp': if i1.operands[0].reg != i2.operands[0].reg: @@ -134,3 +147,11 @@ class FunctionChecker: decomp_symbol = self.my_symtab[symbol.name] return decomp_symbol.addr == decomp_addr + + def _check_function_call(self, orig_addr: int, decomp_addr: int) -> bool: + name = self.decompiled_fns.get(orig_addr, None) + if name is None: + return True + + decomp_symbol = self.my_symtab[name] + return decomp_symbol.addr == decomp_addr