tools: Deduplicate code

This commit is contained in:
Léo Lam 2020-11-09 19:13:35 +01:00
parent d5bdc23ef5
commit 5d09d99b1b
No known key found for this signature in database
GPG Key ID: 0DF30F9081000741
9 changed files with 185 additions and 228 deletions

View File

@ -1,159 +1,34 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import capstone as cs
from elftools.elf.elffile import ELFFile
import diff_settings
from pathlib import Path
import sys import sys
from typing import Any, Dict, Set
import utils
config: Dict[str, Any] = {} import util.elf
diff_settings.apply(config, {}) from util import utils
root = Path(__file__).parent.parent
base_elf = ELFFile((root / config["baseimg"]).open("rb"))
my_elf = ELFFile((root / config["myimg"]).open("rb"))
my_symtab = my_elf.get_section_by_name(".symtab")
if not my_symtab:
utils.fail(f'{config["myimg"]} has no symbol table')
md = cs.Cs(cs.CS_ARCH_ARM64, cs.CS_MODE_ARM)
md.detail = True
def get_file_offset(elf, addr: int) -> int:
for seg in elf.iter_segments():
if seg.header["p_type"] != "PT_LOAD":
continue
if seg["p_vaddr"] <= addr < seg["p_vaddr"] + seg["p_filesz"]:
return addr - seg["p_vaddr"] + seg["p_offset"]
assert False
def get_symbol_file_offset(elf, table, name: str) -> int:
syms = table.get_symbol_by_name(name)
if not syms or len(syms) != 1:
raise KeyError(name)
return get_file_offset(elf, syms[0]["st_value"])
def get_fn_from_base_elf(addr: int, size: int) -> bytes:
offset = get_file_offset(base_elf, addr)
base_elf.stream.seek(offset)
return base_elf.stream.read(size)
def get_fn_from_my_elf(name: str, size: int) -> bytes:
offset = get_symbol_file_offset(my_elf, my_symtab, name)
my_elf.stream.seek(offset)
return my_elf.stream.read(size)
def check_function_ex(addr: int, size: int, base_fn: bytes, my_fn: bytes) -> bool:
adrp_pair_registers: Set[int] = set()
for i1, i2 in zip(md.disasm(base_fn, addr), md.disasm(my_fn, addr)):
if i1.bytes == i2.bytes:
if i1.mnemonic == 'adrp':
adrp_pair_registers.add(i1.operands[0].reg)
elif i1.mnemonic == 'ldr':
reg = i1.operands[1].value.mem.base
if reg in adrp_pair_registers:
adrp_pair_registers.remove(reg)
elif i1.mnemonic == 'ldp':
reg = i1.operands[2].value.mem.base
if reg in adrp_pair_registers:
adrp_pair_registers.remove(reg)
elif i1.mnemonic == 'add':
reg = i1.operands[1].reg
if reg in adrp_pair_registers:
adrp_pair_registers.remove(reg)
continue
if i1.mnemonic != i2.mnemonic:
return False
# Ignore some address differences until a fully matching executable can be generated.
if i1.mnemonic == 'bl':
continue
if i1.mnemonic == 'b':
# Needed for tail calls.
branch_target = int(i1.op_str[1:], 16)
if not (addr <= branch_target < addr + size):
continue
if i1.mnemonic == 'adrp':
if i1.operands[0].reg != i2.operands[0].reg:
return False
adrp_pair_registers.add(i1.operands[0].reg)
continue
if i1.mnemonic == 'ldr' or i1.mnemonic == 'str':
if i1.operands[0].reg != i2.operands[0].reg:
return False
if i1.operands[1].value.mem.base != i2.operands[1].value.mem.base:
return False
reg = i1.operands[1].value.mem.base
if reg not in adrp_pair_registers:
return False
adrp_pair_registers.remove(reg)
continue
if i1.mnemonic == 'ldp' or i1.mnemonic == 'stp':
if i1.operands[0].reg != i2.operands[0].reg:
return False
if i1.operands[1].reg != i2.operands[1].reg:
return False
if i1.operands[2].value.mem.base != i2.operands[2].value.mem.base:
return False
reg = i1.operands[2].value.mem.base
if reg not in adrp_pair_registers:
return False
adrp_pair_registers.remove(reg)
continue
if i1.mnemonic == 'add':
if i1.operands[0].reg != i2.operands[0].reg:
return False
if i1.operands[1].reg != i2.operands[1].reg:
return False
reg = i1.operands[1].reg
if reg not in adrp_pair_registers:
return False
adrp_pair_registers.remove(reg)
continue
return False
return True
def check_function(addr: int, size: int, name: str, base_fn=None) -> bool: def check_function(addr: int, size: int, name: str, base_fn=None) -> bool:
if base_fn is None: if base_fn is None:
try: try:
base_fn = get_fn_from_base_elf(addr, size) base_fn = util.elf.get_fn_from_base_elf(addr, size)
except KeyError: except KeyError:
utils.print_error(f"couldn't find base function 0x{addr:016x} for {utils.format_symbol_name_for_msg(name)}") utils.print_error(f"couldn't find base function 0x{addr:016x} for {utils.format_symbol_name_for_msg(name)}")
return False return False
my_fn = get_fn_from_my_elf(name, size) my_fn = util.elf.get_fn_from_my_elf(name)
return check_function_ex(addr, size, base_fn, my_fn) return util.elf.check_function_ex(addr, size, base_fn, my_fn)
def main() -> None: def main() -> None:
failed = False failed = False
nonmatching_fns_with_dump = {p.stem: p.read_bytes() for p in (root / "expected").glob("*.bin")} nonmatching_fns_with_dump = {p.stem: p.read_bytes() for p in (utils.get_repo_root() / "expected").glob("*.bin")}
for func in utils.get_functions(): for func in utils.get_functions():
if not func.decomp_name: if not func.decomp_name:
continue continue
try: try:
get_fn_from_my_elf(func.decomp_name, 0) util.elf.get_fn_from_my_elf(func.decomp_name)
except KeyError: except KeyError:
utils.warn(f"couldn't find {utils.format_symbol_name_for_msg(func.decomp_name)}") utils.warn(f"couldn't find {utils.format_symbol_name_for_msg(func.decomp_name)}")
continue continue

View File

@ -1,10 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
from colorama import Fore, Style
import cxxfilt
import subprocess import subprocess
import sys
import utils import cxxfilt
from colorama import Fore, Style
from util import utils
parser = argparse.ArgumentParser(description="Diff assembly") parser = argparse.ArgumentParser(description="Diff assembly")
parser.add_argument( parser.add_argument(
@ -31,7 +32,8 @@ if info is not None:
if not info.decomp_name: if not info.decomp_name:
utils.fail(f"{args.function} has not been decompiled") utils.fail(f"{args.function} has not been decompiled")
print(f"diffing: {Style.BRIGHT}{Fore.BLUE}{cxxfilt.demangle(info.decomp_name)}{Style.RESET_ALL} {Style.DIM}({info.decomp_name}){Style.RESET_ALL}") print(
f"diffing: {Style.BRIGHT}{Fore.BLUE}{cxxfilt.demangle(info.decomp_name)}{Style.RESET_ALL} {Style.DIM}({info.decomp_name}){Style.RESET_ALL}")
addr_end = info.addr + info.size addr_end = info.addr + info.size
subprocess.call(["tools/asm-differ/diff.py", "-I", "-e", info.decomp_name, "0x%016x" % subprocess.call(["tools/asm-differ/diff.py", "-I", "-e", info.decomp_name, "0x%016x" %
info.addr, "0x%016x" % addr_end] + unknown) info.addr, "0x%016x" % addr_end] + unknown)

View File

@ -1,48 +1,15 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
from elftools.elf.elffile import ELFFile
import diff_settings
from pathlib import Path
from typing import Any, Dict
import utils
config: Dict[str, Any] = {} import util.elf
diff_settings.apply(config, {}) from util import utils
root = Path(__file__).parent.parent
my_elf = ELFFile((root / config["myimg"]).open("rb"))
my_symtab = my_elf.get_section_by_name(".symtab")
if not my_symtab:
utils.fail(f'{config["myimg"]} has no symbol table')
def get_file_offset(elf, addr: int) -> int:
for seg in elf.iter_segments():
if seg.header["p_type"] != "PT_LOAD":
continue
if seg["p_vaddr"] <= addr < seg["p_vaddr"] + seg["p_filesz"]:
return addr - seg["p_vaddr"] + seg["p_offset"]
assert False
def get_symbol_file_offset_and_size(elf, table, name: str) -> (int, int):
syms = table.get_symbol_by_name(name)
if not syms or len(syms) != 1:
raise KeyError(name)
return get_file_offset(elf, syms[0]["st_value"]), syms[0]["st_size"]
def get_fn_from_my_elf(name: str) -> bytes:
offset, size = get_symbol_file_offset_and_size(my_elf, my_symtab, name)
my_elf.stream.seek(offset)
return my_elf.stream.read(size)
def dump_fn(name: str) -> None: def dump_fn(name: str) -> None:
expected_dir = root / "expected" expected_dir = utils.get_repo_root() / "expected"
try: try:
fn = get_fn_from_my_elf(name) fn = util.elf.get_fn_from_my_elf(name)
path = expected_dir / f"{name}.bin" path = expected_dir / f"{name}.bin"
path.parent.mkdir(exist_ok=True) path.parent.mkdir(exist_ok=True)
path.write_bytes(fn) path.write_bytes(fn)

View File

@ -3,7 +3,7 @@ import argparse
from colorama import Fore, Style from colorama import Fore, Style
import diff_settings import diff_settings
import subprocess import subprocess
import utils from util import utils
parser = argparse.ArgumentParser(description="Prints build/uking.elf symbols") parser = argparse.ArgumentParser(description="Prints build/uking.elf symbols")
parser.add_argument("--print-undefined", "-u", parser.add_argument("--print-undefined", "-u",
@ -15,7 +15,6 @@ parser.add_argument("--hide-unknown", "-H",
parser.add_argument("--all", "-a", action="store_true") parser.add_argument("--all", "-a", action="store_true")
args = parser.parse_args() args = parser.parse_args()
listed_decomp_symbols = {info.decomp_name for info in utils.get_functions()} listed_decomp_symbols = {info.decomp_name for info in utils.get_functions()}
original_symbols = {info.name for info in utils.get_functions()} original_symbols = {info.name for info in utils.get_functions()}
@ -31,7 +30,8 @@ for entry in entries:
symbol_type: str = entry[1] symbol_type: str = entry[1]
name = entry[2] name = entry[2]
if (symbol_type == "t" or symbol_type == "T" or symbol_type == "W") and (args.all or name not in listed_decomp_symbols): if (symbol_type == "t" or symbol_type == "T" or symbol_type == "W") and (
args.all or name not in listed_decomp_symbols):
c1_name = name.replace("C2", "C1") c1_name = name.replace("C2", "C1")
is_c2_ctor = "C2" in name and c1_name in listed_decomp_symbols and utils.are_demangled_names_equal( is_c2_ctor = "C2" in name and c1_name in listed_decomp_symbols and utils.are_demangled_names_equal(
c1_name, name) c1_name, name)

View File

@ -4,8 +4,8 @@ from collections import defaultdict
from colorama import Back, Fore, Style from colorama import Back, Fore, Style
import enum import enum
from pathlib import Path from pathlib import Path
import utils from util import utils
from utils import FunctionStatus from util.utils import FunctionStatus
import typing as tp import typing as tp
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -1,57 +1,22 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
from colorama import Fore, Style
import cxxfilt
from elftools.elf.elffile import ELFFile
import diff_settings
from pathlib import Path
import struct import struct
from typing import Any, Dict, Optional from typing import Optional
import utils
config: Dict[str, Any] = {} import cxxfilt
diff_settings.apply(config, {}) from colorama import Fore, Style
root = Path(__file__).parent.parent import util.elf
my_elf = ELFFile((root / config["myimg"]).open("rb")) from util import utils
my_symtab = my_elf.get_section_by_name(".symtab")
if not my_symtab:
utils.fail(f'{config["myimg"]} has no symbol table')
def get_file_offset(elf, addr: int) -> int:
for seg in elf.iter_segments():
if seg.header["p_type"] != "PT_LOAD":
continue
if seg["p_vaddr"] <= addr < seg["p_vaddr"] + seg["p_filesz"]:
return addr - seg["p_vaddr"] + seg["p_offset"]
assert False
def get_symbol_file_offset_and_size(elf, table, name: str) -> (int, int):
syms = table.get_symbol_by_name(name)
if not syms or len(syms) != 1:
raise KeyError(name)
return get_file_offset(elf, syms[0]["st_value"]), syms[0]["st_size"]
def build_symbol_table(symtab) -> Dict[int, str]:
table = dict()
for sym in symtab.iter_symbols():
addr = sym["st_value"]
existing_value = table.get(addr, None)
if existing_value is None or not existing_value.startswith("_Z"):
table[addr] = sym.name
return table
def find_vtable(symtab, class_name: str) -> Optional[str]: def find_vtable(symtab, class_name: str) -> Optional[str]:
OFFSET = len("vtable for ") name_offset = len("vtable for ")
for sym in symtab.iter_symbols(): for sym in symtab.iter_symbols():
if not sym.name.startswith("_ZTV"): if not sym.name.startswith("_ZTV"):
continue continue
if cxxfilt.demangle(sym.name)[OFFSET:] == class_name: if cxxfilt.demangle(sym.name)[name_offset:] == class_name:
return sym.name return sym.name
return None return None
@ -62,12 +27,12 @@ def bold(s) -> str:
def dump_table(name: str) -> None: def dump_table(name: str) -> None:
try: try:
symbols = build_symbol_table(my_symtab) symbols = util.elf.build_symbol_table(util.elf.my_symtab)
decomp_symbols = {fn.decomp_name for fn in utils.get_functions() if fn.decomp_name} decomp_symbols = {fn.decomp_name for fn in utils.get_functions() if fn.decomp_name}
offset, size = get_symbol_file_offset_and_size(my_elf, my_symtab, name) offset, size = util.elf.get_symbol_file_offset_and_size(util.elf.my_elf, util.elf.my_symtab, name)
my_elf.stream.seek(offset) util.elf.my_elf.stream.seek(offset)
vtable_bytes = my_elf.stream.read(size) vtable_bytes = util.elf.my_elf.stream.read(size)
if not vtable_bytes: if not vtable_bytes:
utils.fail( utils.fail(
@ -106,7 +71,7 @@ def main() -> None:
symbol_name: str = args.symbol_name symbol_name: str = args.symbol_name
if not symbol_name.startswith("_ZTV"): if not symbol_name.startswith("_ZTV"):
symbol_name = find_vtable(my_symtab, args.symbol_name) symbol_name = find_vtable(util.elf.my_symtab, args.symbol_name)
dump_table(symbol_name) dump_table(symbol_name)

0
tools/util/__init__.py Normal file
View File

142
tools/util/elf.py Normal file
View File

@ -0,0 +1,142 @@
#!/usr/bin/env python3
from typing import Any, Dict, Set, Tuple

import capstone as cs
from elftools.elf.elffile import ELFFile

import diff_settings
from util import utils
_config: Dict[str, Any] = {}
diff_settings.apply(_config, {})
_root = utils.get_repo_root()
base_elf = ELFFile((_root / _config["baseimg"]).open("rb"))
my_elf = ELFFile((_root / _config["myimg"]).open("rb"))
my_symtab = my_elf.get_section_by_name(".symtab")
if not my_symtab:
utils.fail(f'{_config["myimg"]} has no symbol table')
md = cs.Cs(cs.CS_ARCH_ARM64, cs.CS_MODE_ARM)
md.detail = True
def get_file_offset(elf, addr: int) -> int:
    """Translate a virtual address into an offset inside the ELF file.

    Only ``PT_LOAD`` segments are considered, and within a segment only the
    range covered by ``p_filesz`` (the part that has bytes in the file).

    :param elf: object exposing ``iter_segments()``; the module passes
        pyelftools ``ELFFile`` instances.
    :param addr: virtual address to translate.
    :returns: the file offset corresponding to ``addr``.
    :raises KeyError: if no loadable segment contains ``addr``.
    """
    for seg in elf.iter_segments():
        if seg.header["p_type"] != "PT_LOAD":
            continue
        seg_start = seg["p_vaddr"]
        if seg_start <= addr < seg_start + seg["p_filesz"]:
            return addr - seg_start + seg["p_offset"]
    # A failed lookup is reported as KeyError, like the symbol lookup helper
    # does, so callers can handle "not found" uniformly. The previous
    # `assert False` raised AssertionError instead and was stripped entirely
    # under `python -O`, silently returning None.
    raise KeyError(f"no PT_LOAD segment contains address 0x{addr:016x}")
def get_symbol_file_offset_and_size(elf, table, name: str) -> Tuple[int, int]:
    """Locate a symbol by name and return where its data lives in the file.

    :param elf: ELF object the symbol's address is resolved against.
    :param table: symbol table exposing ``get_symbol_by_name()``.
    :param name: exact symbol name to look up.
    :returns: ``(file_offset, size)`` taken from the symbol's ``st_value``
        (translated to a file offset) and ``st_size``.
    :raises KeyError: if the name is not found, or if it resolves to more
        than one symbol (the lookup would be ambiguous).
    """
    syms = table.get_symbol_by_name(name)
    if not syms or len(syms) != 1:
        raise KeyError(name)
    sym = syms[0]
    return get_file_offset(elf, sym["st_value"]), sym["st_size"]
def build_symbol_table(symtab) -> Dict[int, str]:
    """Map each symbol address (``st_value``) to one symbol name.

    When several symbols share an address, a name starting with ``_Z`` is
    kept once it has been recorded; any other recorded name is replaced by
    the next symbol seen at that address.

    :param symtab: object exposing ``iter_symbols()``.
    :returns: dictionary from address to the chosen symbol name.
    """
    names_by_addr: Dict[int, str] = {}
    for symbol in symtab.iter_symbols():
        address = symbol["st_value"]
        current = names_by_addr.get(address)
        # Never overwrite an already-recorded "_Z"-prefixed name.
        if current is not None and current.startswith("_Z"):
            continue
        names_by_addr[address] = symbol.name
    return names_by_addr
def get_fn_from_base_elf(addr: int, size: int) -> bytes:
    """Read ``size`` bytes from the base ELF, starting at virtual address ``addr``.

    The address is translated to a file offset first; the shared stream of
    ``base_elf`` is repositioned as a side effect.
    """
    file_offset = get_file_offset(base_elf, addr)
    stream = base_elf.stream
    stream.seek(file_offset)
    return stream.read(size)
def get_fn_from_my_elf(name: str) -> bytes:
    """Read the bytes of symbol ``name`` from the locally built ELF.

    The number of bytes read is the symbol's own ``st_size``. A KeyError from
    the symbol lookup propagates to the caller. The shared stream of
    ``my_elf`` is repositioned as a side effect.
    """
    start, length = get_symbol_file_offset_and_size(my_elf, my_symtab, name)
    stream = my_elf.stream
    stream.seek(start)
    return stream.read(length)
def check_function_ex(addr: int, size: int, base_fn: bytes, my_fn: bytes) -> bool:
    """Compare two builds of one function, tolerating some address differences.

    Both byte strings are disassembled with the module-level ``md``
    disassembler as if located at ``addr`` and compared instruction by
    instruction. Byte-identical instructions always match. Instructions that
    differ are accepted only in the specific cases handled below (``bl``,
    ``b`` leaving the function, ``adrp`` and the ``ldr``/``str``/``ldp``/
    ``stp``/``add`` that consumes an ``adrp`` result).

    :param addr: address of the function in the base executable.
    :param size: size of the function, used to decide whether a ``b`` target
        lies inside the function.
    :param base_fn: function bytes from the base executable.
    :param my_fn: function bytes from the executable being checked.
    :returns: True if the functions are considered matching.
    """
    # zip() below stops at the shorter instruction stream, so without this
    # check an empty or truncated function would be reported as matching.
    if len(base_fn) != len(my_fn):
        return False

    # Registers written by an `adrp` whose consuming instruction has not been
    # seen yet.
    adrp_pair_registers: Set[int] = set()

    for i1, i2 in zip(md.disasm(base_fn, addr), md.disasm(my_fn, addr)):
        if i1.bytes == i2.bytes:
            # Identical encoding: nothing to verify, only keep the adrp
            # bookkeeping up to date.
            if i1.mnemonic == 'adrp':
                adrp_pair_registers.add(i1.operands[0].reg)
            elif i1.mnemonic == 'ldr':
                reg = i1.operands[1].value.mem.base
                if reg in adrp_pair_registers:
                    adrp_pair_registers.remove(reg)
            elif i1.mnemonic == 'ldp':
                reg = i1.operands[2].value.mem.base
                if reg in adrp_pair_registers:
                    adrp_pair_registers.remove(reg)
            elif i1.mnemonic == 'add':
                reg = i1.operands[1].reg
                if reg in adrp_pair_registers:
                    adrp_pair_registers.remove(reg)
            continue

        if i1.mnemonic != i2.mnemonic:
            return False

        # Ignore some address differences until a fully matching executable can be generated.

        if i1.mnemonic == 'bl':
            continue

        if i1.mnemonic == 'b':
            # Needed for tail calls.
            # op_str is parsed as hex after dropping its first character.
            branch_target = int(i1.op_str[1:], 16)
            if not (addr <= branch_target < addr + size):
                continue
            # A differing branch that stays inside the function matches none
            # of the cases below and ends at the final `return False`.

        if i1.mnemonic == 'adrp':
            if i1.operands[0].reg != i2.operands[0].reg:
                return False
            adrp_pair_registers.add(i1.operands[0].reg)
            continue

        if i1.mnemonic == 'ldr' or i1.mnemonic == 'str':
            if i1.operands[0].reg != i2.operands[0].reg:
                return False
            if i1.operands[1].value.mem.base != i2.operands[1].value.mem.base:
                return False
            # The difference is tolerated only if the base register comes
            # from a pending adrp.
            reg = i1.operands[1].value.mem.base
            if reg not in adrp_pair_registers:
                return False
            adrp_pair_registers.remove(reg)
            continue

        if i1.mnemonic == 'ldp' or i1.mnemonic == 'stp':
            if i1.operands[0].reg != i2.operands[0].reg:
                return False
            if i1.operands[1].reg != i2.operands[1].reg:
                return False
            if i1.operands[2].value.mem.base != i2.operands[2].value.mem.base:
                return False
            reg = i1.operands[2].value.mem.base
            if reg not in adrp_pair_registers:
                return False
            adrp_pair_registers.remove(reg)
            continue

        if i1.mnemonic == 'add':
            if i1.operands[0].reg != i2.operands[0].reg:
                return False
            if i1.operands[1].reg != i2.operands[1].reg:
                return False
            reg = i1.operands[1].reg
            if reg not in adrp_pair_registers:
                return False
            adrp_pair_registers.remove(reg)
            continue

        # Any other differing instruction is a mismatch.
        return False

    return True

View File

@ -47,8 +47,10 @@ def parse_function_csv_entry(row) -> FunctionInfo:
return FunctionInfo(addr, name, int(size, 0), decomp_name, status) return FunctionInfo(addr, name, int(size, 0), decomp_name, status)
def get_functions() -> tp.Iterable[FunctionInfo]: def get_functions(path: tp.Optional[Path] = None) -> tp.Iterable[FunctionInfo]:
with (Path(__file__).parent.parent / "data" / "uking_functions.csv").open() as f: if path is None:
path = get_repo_root() / "data" / "uking_functions.csv"
with path.open() as f:
for row in csv.reader(f): for row in csv.reader(f):
yield parse_function_csv_entry(row) yield parse_function_csv_entry(row)
@ -86,3 +88,7 @@ def print_error(msg: str, prefix: str = ""):
def fail(msg: str, prefix: str = ""): def fail(msg: str, prefix: str = ""):
print_error(msg, prefix) print_error(msg, prefix)
sys.exit(1) sys.exit(1)
def get_repo_root() -> Path:
    """Return the directory three levels above this source file.

    NOTE(review): presumably this file lives two directories below the
    repository root (e.g. ``tools/util/``) - confirm if the file is moved.
    """
    containing_dir = Path(__file__).parent
    grandparent_dir = containing_dir.parent
    return grandparent_dir.parent