tools: Do not use a global capstone disassembler instance

Instead, use one instance per function checker.
This commit is contained in:
Léo Lam 2020-11-09 20:02:19 +01:00
parent 5d09d99b1b
commit 844db4220c
No known key found for this signature in database
GPG Key ID: 0DF30F9081000741
3 changed files with 97 additions and 89 deletions

View File

@ -3,10 +3,11 @@
import sys
import util.elf
import util.checker
from util import utils
def check_function(addr: int, size: int, name: str, base_fn=None) -> bool:
def check_function(checker: util.checker.FunctionChecker, addr: int, size: int, name: str, base_fn=None) -> bool:
if base_fn is None:
try:
base_fn = util.elf.get_fn_from_base_elf(addr, size)
@ -15,7 +16,7 @@ def check_function(addr: int, size: int, name: str, base_fn=None) -> bool:
return False
my_fn = util.elf.get_fn_from_my_elf(name)
return util.elf.check_function_ex(addr, size, base_fn, my_fn)
return checker.check(addr, size, base_fn, my_fn)
def main() -> None:
@ -23,6 +24,8 @@ def main() -> None:
nonmatching_fns_with_dump = {p.stem: p.read_bytes() for p in (utils.get_repo_root() / "expected").glob("*.bin")}
checker = util.checker.FunctionChecker()
for func in utils.get_functions():
if not func.decomp_name:
continue
@ -34,17 +37,17 @@ def main() -> None:
continue
if func.status == utils.FunctionStatus.Matching:
if not check_function(func.addr, func.size, func.decomp_name):
if not check_function(checker, func.addr, func.size, func.decomp_name):
utils.print_error(
f"function {utils.format_symbol_name_for_msg(func.decomp_name)} is marked as matching but does not match")
failed = True
elif func.status == utils.FunctionStatus.Equivalent or func.status == utils.FunctionStatus.NonMatching:
if check_function(func.addr, func.size, func.decomp_name):
if check_function(checker, func.addr, func.size, func.decomp_name):
utils.print_note(
f"function {utils.format_symbol_name_for_msg(func.decomp_name)} is marked as non-matching but matches")
fn_dump = nonmatching_fns_with_dump.get(func.decomp_name, None)
if fn_dump is not None and not check_function(func.addr, len(fn_dump), func.decomp_name, fn_dump):
if fn_dump is not None and not check_function(checker, func.addr, len(fn_dump), func.decomp_name, fn_dump):
utils.print_error(
f"function {utils.format_symbol_name_for_msg(func.decomp_name)} does not match expected output")
failed = True

89
tools/util/checker.py Normal file
View File

@ -0,0 +1,89 @@
from typing import Set
import capstone as cs
class FunctionChecker:
    """Decide whether two builds of the same AArch64 function match.

    Each checker owns its own capstone disassembler, so no disassembler
    state is shared through a module-level global.
    """

    def __init__(self):
        self.md = cs.Cs(cs.CS_ARCH_ARM64, cs.CS_MODE_ARM)
        # Operand details are required: check() inspects registers and
        # memory base registers of individual instructions.
        self.md.detail = True

    def check(self, addr: int, size: int, base_fn: bytes, my_fn: bytes) -> bool:
        """Return True if ``my_fn`` matches ``base_fn``.

        Both byte strings are disassembled as if located at ``addr`` and
        compared instruction by instruction. Instructions whose encodings
        differ are tolerated only in these cases:

        * ``bl``: always tolerated.
        * ``b``: tolerated when the base instruction's target lies outside
          ``[addr, addr + size)`` (tail calls).
        * ``adrp``: tolerated when both write the same destination register;
          that register is then tracked.
        * ``ldr``/``str``/``ldp``/``stp``/``add``: tolerated when the compared
          register operands are equal and the base (or source) register is
          currently tracked from a preceding ``adrp``; the register is then
          untracked.

        Any other difference, or a difference in length, is a mismatch.

        :param addr: address used as the disassembly origin for both inputs.
        :param size: size of the function; bounds the tail-call test.
        :param base_fn: bytes of the reference function.
        :param my_fn: bytes of the function being checked.
        """
        # zip() below stops at the shorter instruction stream, so without
        # this guard a truncated (or empty) function that is a prefix of the
        # other one would be reported as matching.
        if len(base_fn) != len(my_fn):
            return False

        # Identical bytes disassemble identically; nothing left to compare.
        if base_fn == my_fn:
            return True

        # Destination registers of adrp instructions that have not yet been
        # consumed by a following ldr/str/ldp/stp/add.
        adrp_pair_registers: Set[int] = set()

        for i1, i2 in zip(self.md.disasm(base_fn, addr), self.md.disasm(my_fn, addr)):
            if i1.bytes == i2.bytes:
                # Identical instruction: only keep the adrp tracking in sync.
                if i1.mnemonic == 'adrp':
                    adrp_pair_registers.add(i1.operands[0].reg)
                elif i1.mnemonic == 'ldr':
                    adrp_pair_registers.discard(i1.operands[1].value.mem.base)
                elif i1.mnemonic == 'ldp':
                    adrp_pair_registers.discard(i1.operands[2].value.mem.base)
                elif i1.mnemonic == 'add':
                    adrp_pair_registers.discard(i1.operands[1].reg)
                continue

            if i1.mnemonic != i2.mnemonic:
                return False

            # Ignore some address differences until a fully matching executable can be generated.

            if i1.mnemonic == 'bl':
                continue

            if i1.mnemonic == 'b':
                # Needed for tail calls.
                # op_str is parsed as hex after dropping its first character
                # (presumably the '#' immediate prefix -- confirm).
                branch_target = int(i1.op_str[1:], 16)
                if not (addr <= branch_target < addr + size):
                    continue
                # A differing branch inside the function falls through to
                # the final mismatch below.

            if i1.mnemonic == 'adrp':
                if i1.operands[0].reg != i2.operands[0].reg:
                    return False
                adrp_pair_registers.add(i1.operands[0].reg)
                continue

            if i1.mnemonic == 'ldr' or i1.mnemonic == 'str':
                if i1.operands[0].reg != i2.operands[0].reg:
                    return False
                if i1.operands[1].value.mem.base != i2.operands[1].value.mem.base:
                    return False
                reg = i1.operands[1].value.mem.base
                if reg not in adrp_pair_registers:
                    return False
                adrp_pair_registers.remove(reg)
                continue

            if i1.mnemonic == 'ldp' or i1.mnemonic == 'stp':
                if i1.operands[0].reg != i2.operands[0].reg:
                    return False
                if i1.operands[1].reg != i2.operands[1].reg:
                    return False
                if i1.operands[2].value.mem.base != i2.operands[2].value.mem.base:
                    return False
                reg = i1.operands[2].value.mem.base
                if reg not in adrp_pair_registers:
                    return False
                adrp_pair_registers.remove(reg)
                continue

            if i1.mnemonic == 'add':
                if i1.operands[0].reg != i2.operands[0].reg:
                    return False
                if i1.operands[1].reg != i2.operands[1].reg:
                    return False
                reg = i1.operands[1].reg
                if reg not in adrp_pair_registers:
                    return False
                adrp_pair_registers.remove(reg)
                continue

            return False

        return True

View File

@ -19,9 +19,6 @@ my_symtab = my_elf.get_section_by_name(".symtab")
if not my_symtab:
utils.fail(f'{_config["myimg"]} has no symbol table')
md = cs.Cs(cs.CS_ARCH_ARM64, cs.CS_MODE_ARM)
md.detail = True
def get_file_offset(elf, addr: int) -> int:
for seg in elf.iter_segments():
@ -59,84 +56,3 @@ def get_fn_from_my_elf(name: str) -> bytes:
offset, size = get_symbol_file_offset_and_size(my_elf, my_symtab, name)
my_elf.stream.seek(offset)
return my_elf.stream.read(size)
def check_function_ex(addr: int, size: int, base_fn: bytes, my_fn: bytes) -> bool:
    """Return True if ``my_fn`` matches ``base_fn``.

    Both byte strings are disassembled with the module-level ``md``
    disassembler as if located at ``addr`` and compared instruction by
    instruction. Differing encodings are tolerated only for ``bl``, for
    ``b`` whose base target lies outside ``[addr, addr + size)``, for
    ``adrp`` with equal destination registers, and for
    ``ldr``/``str``/``ldp``/``stp``/``add`` whose compared register operands
    are equal and whose base (or source) register is tracked from a
    preceding ``adrp``. Any other difference, or a difference in length, is
    a mismatch.
    """
    # zip() below stops at the shorter instruction stream, so without this
    # guard a truncated (or empty) function that is a prefix of the other
    # one would be reported as matching.
    if len(base_fn) != len(my_fn):
        return False

    # Identical bytes disassemble identically; nothing left to compare.
    if base_fn == my_fn:
        return True

    # Destination registers of adrp instructions that have not yet been
    # consumed by a following ldr/str/ldp/stp/add.
    adrp_pair_registers: Set[int] = set()

    for i1, i2 in zip(md.disasm(base_fn, addr), md.disasm(my_fn, addr)):
        if i1.bytes == i2.bytes:
            # Identical instruction: only keep the adrp tracking in sync.
            if i1.mnemonic == 'adrp':
                adrp_pair_registers.add(i1.operands[0].reg)
            elif i1.mnemonic == 'ldr':
                adrp_pair_registers.discard(i1.operands[1].value.mem.base)
            elif i1.mnemonic == 'ldp':
                adrp_pair_registers.discard(i1.operands[2].value.mem.base)
            elif i1.mnemonic == 'add':
                adrp_pair_registers.discard(i1.operands[1].reg)
            continue

        if i1.mnemonic != i2.mnemonic:
            return False

        # Ignore some address differences until a fully matching executable can be generated.

        if i1.mnemonic == 'bl':
            continue

        if i1.mnemonic == 'b':
            # Needed for tail calls.
            # op_str is parsed as hex after dropping its first character
            # (presumably the '#' immediate prefix -- confirm).
            branch_target = int(i1.op_str[1:], 16)
            if not (addr <= branch_target < addr + size):
                continue
            # A differing branch inside the function falls through to the
            # final mismatch below.

        if i1.mnemonic == 'adrp':
            if i1.operands[0].reg != i2.operands[0].reg:
                return False
            adrp_pair_registers.add(i1.operands[0].reg)
            continue

        if i1.mnemonic == 'ldr' or i1.mnemonic == 'str':
            if i1.operands[0].reg != i2.operands[0].reg:
                return False
            if i1.operands[1].value.mem.base != i2.operands[1].value.mem.base:
                return False
            reg = i1.operands[1].value.mem.base
            if reg not in adrp_pair_registers:
                return False
            adrp_pair_registers.remove(reg)
            continue

        if i1.mnemonic == 'ldp' or i1.mnemonic == 'stp':
            if i1.operands[0].reg != i2.operands[0].reg:
                return False
            if i1.operands[1].reg != i2.operands[1].reg:
                return False
            if i1.operands[2].value.mem.base != i2.operands[2].value.mem.base:
                return False
            reg = i1.operands[2].value.mem.base
            if reg not in adrp_pair_registers:
                return False
            adrp_pair_registers.remove(reg)
            continue

        if i1.mnemonic == 'add':
            if i1.operands[0].reg != i2.operands[0].reg:
                return False
            if i1.operands[1].reg != i2.operands[1].reg:
                return False
            reg = i1.operands[1].reg
            if reg not in adrp_pair_registers:
                return False
            adrp_pair_registers.remove(reg)
            continue

        return False

    return True