tools: Add function call checking

This commit is contained in:
Léo Lam 2020-11-10 20:03:43 +01:00
parent a3abb115b9
commit dc9b346d7e
No known key found for this signature in database
GPG Key ID: 0DF30F9081000741
2 changed files with 25 additions and 5 deletions

View File

@ -40,7 +40,6 @@ def main() -> None:
new_matches: Dict[int, str] = dict() new_matches: Dict[int, str] = dict()
checker = util.checker.FunctionChecker() checker = util.checker.FunctionChecker()
checker.get_data_symtab().load_from_csv(utils.get_repo_root() / "data" / "data_symbols.csv")
# Given a list L of functions to identify and a small list of candidates C, this tool will attempt to # Given a list L of functions to identify and a small list of candidates C, this tool will attempt to
# automatically identify matches by checking each function in L against each function in C. # automatically identify matches by checking each function in L against each function in C.

View File

@ -1,9 +1,9 @@
from collections import defaultdict from collections import defaultdict
from typing import Set, DefaultDict from typing import Set, DefaultDict, Dict
import capstone as cs import capstone as cs
from util import dsym, elf from util import dsym, elf, utils
class FunctionChecker: class FunctionChecker:
@ -12,10 +12,16 @@ class FunctionChecker:
self.md.detail = True self.md.detail = True
self.my_symtab = elf.build_name_to_symbol_table(elf.my_symtab) self.my_symtab = elf.build_name_to_symbol_table(elf.my_symtab)
self.dsymtab = dsym.DataSymbolContainer() self.dsymtab = dsym.DataSymbolContainer()
self.decompiled_fns: Dict[int, str] = dict()
self.load_data_for_project()
def get_data_symtab(self) -> dsym.DataSymbolContainer: def get_data_symtab(self) -> dsym.DataSymbolContainer:
return self.dsymtab return self.dsymtab
def load_data_for_project(self) -> None:
self.decompiled_fns = {func.addr: func.decomp_name for func in utils.get_functions() if func.decomp_name}
self.get_data_symtab().load_from_csv(utils.get_repo_root() / "data" / "data_symbols.csv")
def check(self, base_fn: elf.Function, my_fn: elf.Function) -> bool: def check(self, base_fn: elf.Function, my_fn: elf.Function) -> bool:
gprs1: DefaultDict[int, int] = defaultdict(int) gprs1: DefaultDict[int, int] = defaultdict(int)
gprs2: DefaultDict[int, int] = defaultdict(int) gprs2: DefaultDict[int, int] = defaultdict(int)
@ -51,13 +57,20 @@ class FunctionChecker:
# Ignore some address differences until a fully matching executable can be generated. # Ignore some address differences until a fully matching executable can be generated.
if i1.mnemonic == 'bl': if i1.mnemonic == 'bl':
if not self._check_function_call(i1.operands[0].imm, i2.operands[0].imm):
return False
continue continue
if i1.mnemonic == 'b': if i1.mnemonic == 'b':
# Needed for tail calls. branch_target = i1.operands[0].imm
branch_target = int(i1.op_str[1:], 16) # If we are branching outside the function, this is likely a tail call.
# Treat this as a function call.
if not (base_fn.addr <= branch_target < base_fn.addr + size): if not (base_fn.addr <= branch_target < base_fn.addr + size):
if not self._check_function_call(branch_target, i2.operands[0].imm):
return False
continue continue
# Otherwise, it's a mismatch.
return False
if i1.mnemonic == 'adrp': if i1.mnemonic == 'adrp':
if i1.operands[0].reg != i2.operands[0].reg: if i1.operands[0].reg != i2.operands[0].reg:
@ -134,3 +147,11 @@ class FunctionChecker:
decomp_symbol = self.my_symtab[symbol.name] decomp_symbol = self.my_symtab[symbol.name]
return decomp_symbol.addr == decomp_addr return decomp_symbol.addr == decomp_addr
def _check_function_call(self, orig_addr: int, decomp_addr: int) -> bool:
name = self.decompiled_fns.get(orig_addr, None)
if name is None:
return True
decomp_symbol = self.my_symtab[name]
return decomp_symbol.addr == decomp_addr