diff --git a/data/data_symbols.csv b/data/data_symbols.csv new file mode 100644 index 00000000..a05715ca --- /dev/null +++ b/data/data_symbols.csv @@ -0,0 +1,4 @@ +0x000000710246F9E0,_ZN4ksys3gdt6detail13sCommonFlags0E +0x00000071024709E0,_ZN4ksys3gdt6detail13sCommonFlags1E +0x00000071024719E0,_ZN4ksys3gdt6detail13sCommonFlags2E +0x00000071024729E0,_ZN4ksys3gdt6detail13sCommonFlags3E diff --git a/tools/check.py b/tools/check.py index d9d95bd0..648301e0 100755 --- a/tools/check.py +++ b/tools/check.py @@ -1,13 +1,15 @@ #!/usr/bin/env python3 import sys +from typing import Optional import util.elf import util.checker from util import utils -def check_function(checker: util.checker.FunctionChecker, addr: int, size: int, name: str, base_fn=None) -> bool: +def check_function(checker: util.checker.FunctionChecker, addr: int, size: int, name: str, + base_fn: Optional[util.elf.Function] = None) -> bool: if base_fn is None: try: base_fn = util.elf.get_fn_from_base_elf(addr, size) @@ -16,13 +18,14 @@ def check_function(checker: util.checker.FunctionChecker, addr: int, size: int, return False my_fn = util.elf.get_fn_from_my_elf(name) - return checker.check(addr, size, base_fn, my_fn) + return checker.check(base_fn, my_fn) def main() -> None: failed = False - nonmatching_fns_with_dump = {p.stem: p.read_bytes() for p in (utils.get_repo_root() / "expected").glob("*.bin")} + nonmatching_fns_with_dump = {p.stem: util.elf.Function(p.read_bytes(), 0) for p in + (utils.get_repo_root() / "expected").glob("*.bin")} checker = util.checker.FunctionChecker() diff --git a/tools/dump_function.py b/tools/dump_function.py index 638cc501..f95f3b7a 100755 --- a/tools/dump_function.py +++ b/tools/dump_function.py @@ -12,7 +12,7 @@ def dump_fn(name: str) -> None: fn = util.elf.get_fn_from_my_elf(name) path = expected_dir / f"{name}.bin" path.parent.mkdir(exist_ok=True) - path.write_bytes(fn) + path.write_bytes(fn.data) except KeyError: utils.fail("could not find function") diff --git a/tools/identify_matching_functions.py b/tools/identify_matching_functions.py new file mode 100755 index 00000000..842ee3d9 --- /dev/null +++ b/tools/identify_matching_functions.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 + +import argparse +from colorama import Fore +import csv +import sys +from pathlib import Path +from typing import Dict + +import util.checker +import util.elf +from util import utils + + +def read_candidates(path: Path) -> Dict[str, util.elf.Function]: + candidates: Dict[str, util.elf.Function] = dict() + + for candidate in path.read_text().splitlines(): + columns = candidate.split() + if len(columns) == 3: + candidate = columns[2] + + candidates[candidate] = util.elf.get_fn_from_my_elf(candidate) + + return candidates + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("csv_path", + help="Path to a list of functions to identify (in the same format as the main function CSV)") + parser.add_argument("candidates_path", + help="Path to a list of candidates (names only)") + args = parser.parse_args() + + csv_path = Path(args.csv_path) + candidates_path = Path(args.candidates_path) + + candidates = read_candidates(candidates_path) + + new_matches: Dict[int, str] = dict() + checker = util.checker.FunctionChecker() + checker.get_data_symtab().load_from_csv(utils.get_repo_root() / "data" / "data_symbols.csv") + + # Given a list L of functions to identify and a small list of candidates C, this tool will attempt to + # automatically identify matches by checking each function in L against each function in C. + # + # This matching algorithm is quite naive (quadratic time complexity if both lists have about the same size) + # but this should work well enough for short lists of candidates... + for func in utils.get_functions(csv_path): + if func.status != utils.FunctionStatus.NotDecompiled: + continue + + match_name = "" + + for candidate_name, candidate in candidates.items(): + if len(candidate.data) != func.size: + continue + if checker.check(util.elf.get_fn_from_base_elf(func.addr, func.size), candidate): + match_name = candidate_name + break + + if match_name: + new_matches[func.addr] = match_name + utils.print_note( + f"found new match: {Fore.BLUE}{match_name}{Fore.RESET} ({func.addr | 0x71_00000000:#018x})") + # This is no longer a candidate. + del candidates[match_name] + else: + utils.warn(f"no match found for {Fore.BLUE}{func.name}{Fore.RESET} ({func.addr | 0x71_00000000:#018x})") + + # Output the modified function CSV. + writer = csv.writer(sys.stdout, lineterminator="\n") + for func in utils.get_functions(): + if func.status == utils.FunctionStatus.NotDecompiled and func.addr in new_matches: + func.raw_row[3] = new_matches[func.addr] + writer.writerow(func.raw_row) + + +if __name__ == "__main__": + main() diff --git a/tools/show_vtable.py b/tools/show_vtable.py index 649f4e13..b659cacc 100755 --- a/tools/show_vtable.py +++ b/tools/show_vtable.py @@ -27,7 +27,7 @@ def bold(s) -> str: def dump_table(name: str) -> None: try: - symbols = util.elf.build_symbol_table(util.elf.my_symtab) + symbols = util.elf.build_addr_to_symbol_table(util.elf.my_symtab) decomp_symbols = {fn.decomp_name for fn in utils.get_functions() if fn.decomp_name} offset, size = util.elf.get_symbol_file_offset_and_size(util.elf.my_elf, util.elf.my_symtab, name) diff --git a/tools/util/checker.py b/tools/util/checker.py index cdf3397f..e12f5c35 100644 --- a/tools/util/checker.py +++ b/tools/util/checker.py @@ -1,19 +1,35 @@ -from typing import Set +from collections import defaultdict +from typing import Set, DefaultDict import capstone as cs +from util import dsym, elf + class FunctionChecker: def __init__(self): self.md = cs.Cs(cs.CS_ARCH_ARM64, cs.CS_MODE_ARM) self.md.detail = True + self.my_symtab = elf.build_name_to_symbol_table(elf.my_symtab) + self.dsymtab = dsym.DataSymbolContainer() - def check(self, addr: int, size: int, base_fn: bytes, my_fn: bytes) -> bool: + def get_data_symtab(self) -> dsym.DataSymbolContainer: + return self.dsymtab + + def check(self, base_fn: elf.Function, my_fn: elf.Function) -> bool: + gprs1: DefaultDict[int, int] = defaultdict(int) + gprs2: DefaultDict[int, int] = defaultdict(int) adrp_pair_registers: Set[int] = set() - for i1, i2 in zip(self.md.disasm(base_fn, addr), self.md.disasm(my_fn, addr)): + size = len(base_fn) + if len(base_fn) != len(my_fn): + return False + + for i1, i2 in zip(self.md.disasm(base_fn.data, base_fn.addr), self.md.disasm(my_fn.data, my_fn.addr)): if i1.bytes == i2.bytes: if i1.mnemonic == 'adrp': + gprs1[i1.operands[0].reg] = i1.operands[1].imm + gprs2[i2.operands[0].reg] = i2.operands[1].imm adrp_pair_registers.add(i1.operands[0].reg) elif i1.mnemonic == 'ldr': reg = i1.operands[1].value.mem.base @@ -40,13 +56,18 @@ class FunctionChecker: if i1.mnemonic == 'b': # Needed for tail calls. branch_target = int(i1.op_str[1:], 16) - if not (addr <= branch_target < addr + size): + if not (base_fn.addr <= branch_target < base_fn.addr + size): continue if i1.mnemonic == 'adrp': if i1.operands[0].reg != i2.operands[0].reg: return False - adrp_pair_registers.add(i1.operands[0].reg) + reg = i1.operands[0].reg + + gprs1[reg] = i1.operands[1].imm + gprs2[reg] = i2.operands[1].imm + + adrp_pair_registers.add(reg) continue if i1.mnemonic == 'ldr' or i1.mnemonic == 'str': @@ -57,6 +78,12 @@ class FunctionChecker: reg = i1.operands[1].value.mem.base if reg not in adrp_pair_registers: return False + + gprs1[reg] += i1.operands[1].value.mem.disp + gprs2[reg] += i2.operands[1].value.mem.disp + if not self._check_data_symbol(gprs1[reg], gprs2[reg]): + return False + adrp_pair_registers.remove(reg) continue @@ -70,6 +97,12 @@ class FunctionChecker: reg = i1.operands[2].value.mem.base if reg not in adrp_pair_registers: return False + + gprs1[reg] += i1.operands[1].value.mem.disp + gprs2[reg] += i2.operands[1].value.mem.disp + if not self._check_data_symbol(gprs1[reg], gprs2[reg]): + return False + adrp_pair_registers.remove(reg) continue @@ -81,9 +114,23 @@ class FunctionChecker: reg = i1.operands[1].reg if reg not in adrp_pair_registers: return False + + gprs1[reg] += i1.operands[2].imm + gprs2[reg] += i2.operands[2].imm + if not self._check_data_symbol(gprs1[reg], gprs2[reg]): + return False + adrp_pair_registers.remove(reg) continue return False return True + + def _check_data_symbol(self, orig_addr: int, decomp_addr: int) -> bool: + symbol = self.dsymtab.get_symbol(orig_addr) + if symbol is None: + return True + + decomp_symbol = self.my_symtab[symbol.name] + return decomp_symbol.addr == decomp_addr diff --git a/tools/util/dsym.py b/tools/util/dsym.py new file mode 100644 index 00000000..3c2976af --- /dev/null +++ b/tools/util/dsym.py @@ -0,0 +1,60 @@ +import csv +from pathlib import Path +import typing as tp + +import util.elf + + +class DataSymbol(tp.NamedTuple): + addr: int # without the 0x7100000000 base + name: str + size: int + + +_IDA_BASE = 0x7100000000 + + +class DataSymbolContainer: + def __init__(self) -> None: + self.symbols: tp.List[DataSymbol] = [] + + def load_from_csv(self, path: Path): + symtab = util.elf.build_name_to_symbol_table(util.elf.my_symtab) + + with path.open("r") as f: + for i, line in enumerate(csv.reader(f)): + if len(line) != 2: + raise RuntimeError(f"Invalid line format at line {i}") + + addr = int(line[0], 16) - _IDA_BASE + name = line[1] + size = symtab[name].size + + self.symbols.append(DataSymbol(addr, name, size)) + + # Sort the list, just in case the entries were not sorted in the CSV. + self.symbols.sort(key=lambda sym: sym.addr) + + def get_symbol(self, addr: int) -> tp.Optional[DataSymbol]: + """If addr is part of a known data symbol, this function returns the corresponding symbol.""" + + # Perform a binary search on self.symbols. + a = 0 + b = len(self.symbols) - 1 + while a <= b: + m = (a + b) // 2 + + symbol: DataSymbol = self.symbols[m] + addr_begin = symbol.addr + addr_end = addr_begin + symbol.size + + if addr_begin <= addr < addr_end: + return symbol + if addr <= addr_begin: + b = m - 1 + elif addr >= addr_end: + a = m + 1 + else: + return None + + return None diff --git a/tools/util/elf.py b/tools/util/elf.py index 56fcc535..76d90e38 100644 --- a/tools/util/elf.py +++ b/tools/util/elf.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 -from typing import Any, Dict, Set +from typing import Any, Dict, NamedTuple -import capstone as cs from elftools.elf.elffile import ELFFile import diff_settings @@ -20,6 +19,17 @@ if not my_symtab: utils.fail(f'{_config["myimg"]} has no symbol table') +class Symbol(NamedTuple): + addr: int + name: str + size: int + + +class Function(NamedTuple): + data: bytes + addr: int + + def get_file_offset(elf, addr: int) -> int: for seg in elf.iter_segments(): if seg.header["p_type"] != "PT_LOAD": @@ -29,14 +39,19 @@ def get_file_offset(elf, addr: int) -> int: assert False -def get_symbol_file_offset_and_size(elf, table, name: str) -> (int, int): +def get_symbol(table, name: str) -> Symbol: syms = table.get_symbol_by_name(name) if not syms or len(syms) != 1: raise KeyError(name) - return get_file_offset(elf, syms[0]["st_value"]), syms[0]["st_size"] + return Symbol(syms[0]["st_value"], name, syms[0]["st_size"]) -def build_symbol_table(symtab) -> Dict[int, str]: +def get_symbol_file_offset_and_size(elf, table, name: str) -> (int, int): + sym = get_symbol(table, name) + return get_file_offset(elf, sym.addr), sym.size + + +def build_addr_to_symbol_table(symtab) -> Dict[int, str]: table = dict() for sym in symtab.iter_symbols(): addr = sym["st_value"] @@ -46,13 +61,18 @@ def build_symbol_table(symtab) -> Dict[int, str]: return table -def get_fn_from_base_elf(addr: int, size: int) -> bytes: +def build_name_to_symbol_table(symtab) -> Dict[str, Symbol]: + return {sym.name: Symbol(sym["st_value"], sym.name, sym["st_size"]) for sym in symtab.iter_symbols()} + + +def get_fn_from_base_elf(addr: int, size: int) -> Function: offset = get_file_offset(base_elf, addr) base_elf.stream.seek(offset) - return base_elf.stream.read(size) + return Function(base_elf.stream.read(size), addr) -def get_fn_from_my_elf(name: str) -> bytes: - offset, size = get_symbol_file_offset_and_size(my_elf, my_symtab, name) +def get_fn_from_my_elf(name: str) -> Function: + sym = get_symbol(my_symtab, name) + offset = get_file_offset(my_elf, sym.addr) my_elf.stream.seek(offset) - return my_elf.stream.read(size) + return Function(my_elf.stream.read(sym.size), sym.addr) diff --git a/tools/util/utils.py b/tools/util/utils.py index f126e341..711652c8 100644 --- a/tools/util/utils.py +++ b/tools/util/utils.py @@ -21,6 +21,7 @@ class FunctionInfo(tp.NamedTuple): size: int decomp_name: str status: FunctionStatus + raw_row: tp.List[str] _markers = { @@ -44,7 +45,7 @@ def parse_function_csv_entry(row) -> FunctionInfo: status = FunctionStatus.NotDecompiled addr = int(ea, 16) - 0x7100000000 - return FunctionInfo(addr, name, int(size, 0), decomp_name, status) + return FunctionInfo(addr, name, int(size, 0), decomp_name, status, row) def get_functions(path: tp.Optional[Path] = None) -> tp.Iterable[FunctionInfo]: