tools: Add support for data symbol checking for GOT entries

This commit is contained in:
Léo Lam 2020-11-11 15:09:25 +01:00
parent b53135a885
commit 68c9ffeede
No known key found for this signature in database
GPG Key ID: 0DF30F9081000741
3 changed files with 67 additions and 16 deletions

View File

@ -1,3 +1,4 @@
0x00000071023556B0,_ZTVN4sead14SafeStringBaseIcEE
0x000000710246F9E0,_ZN4ksys3gdt6detail13sCommonFlags0E 0x000000710246F9E0,_ZN4ksys3gdt6detail13sCommonFlags0E
0x00000071024709E0,_ZN4ksys3gdt6detail13sCommonFlags1E 0x00000071024709E0,_ZN4ksys3gdt6detail13sCommonFlags1E
0x00000071024719E0,_ZN4ksys3gdt6detail13sCommonFlags2E 0x00000071024719E0,_ZN4ksys3gdt6detail13sCommonFlags2E

1 0x000000710246F9E0 0x00000071023556B0 _ZN4ksys3gdt6detail13sCommonFlags0E _ZTVN4sead14SafeStringBaseIcEE
1 0x00000071023556B0 _ZTVN4sead14SafeStringBaseIcEE
2 0x000000710246F9E0 0x000000710246F9E0 _ZN4ksys3gdt6detail13sCommonFlags0E _ZN4ksys3gdt6detail13sCommonFlags0E
3 0x00000071024709E0 0x00000071024709E0 _ZN4ksys3gdt6detail13sCommonFlags1E _ZN4ksys3gdt6detail13sCommonFlags1E
4 0x00000071024719E0 0x00000071024719E0 _ZN4ksys3gdt6detail13sCommonFlags2E _ZN4ksys3gdt6detail13sCommonFlags2E

View File

@ -1,5 +1,6 @@
import struct
from collections import defaultdict from collections import defaultdict
from typing import Set, DefaultDict, Dict, Optional from typing import Set, DefaultDict, Dict, Optional, Tuple
import capstone as cs import capstone as cs
@ -18,6 +19,9 @@ class FunctionChecker:
self._mismatch_addr1 = -1 self._mismatch_addr1 = -1
self._mismatch_addr2 = -1 self._mismatch_addr2 = -1
self._mismatch_cause = "" self._mismatch_cause = ""
self._base_got_section = elf.base_elf.get_section_by_name(".got")
self._decomp_glob_data_table = elf.build_glob_data_table(elf.my_elf)
self._got_data_symbol_check_cache: Dict[Tuple[int, int], bool] = dict()
self.load_data_for_project() self.load_data_for_project()
@ -113,7 +117,7 @@ class FunctionChecker:
gprs1[reg] += i1.operands[1].value.mem.disp gprs1[reg] += i1.operands[1].value.mem.disp
gprs2[reg] += i2.operands[1].value.mem.disp gprs2[reg] += i2.operands[1].value.mem.disp
if not self._check_data_symbol(i1, i2, gprs1[reg], gprs2[reg]): if not self._check_data_symbol_load(i1, i2, gprs1[reg], gprs2[reg]):
return False return False
adrp_pair_registers.remove(reg) adrp_pair_registers.remove(reg)
@ -130,9 +134,9 @@ class FunctionChecker:
if reg not in adrp_pair_registers: if reg not in adrp_pair_registers:
return False return False
gprs1[reg] += i1.operands[1].value.mem.disp gprs1[reg] += i1.operands[2].value.mem.disp
gprs2[reg] += i2.operands[1].value.mem.disp gprs2[reg] += i2.operands[2].value.mem.disp
if not self._check_data_symbol(i1, i2, gprs1[reg], gprs2[reg]): if not self._check_data_symbol_load(i1, i2, gprs1[reg], gprs2[reg]):
return False return False
adrp_pair_registers.remove(reg) adrp_pair_registers.remove(reg)
@ -174,10 +178,30 @@ class FunctionChecker:
return True return True
if self._log_mismatch_cause: if self._log_mismatch_cause:
self._set_mismatch_cause(i1, i2, f"data symbol mismatch: {symbol.name} (original address: {orig_addr:#x})") self._set_mismatch_cause(i1, i2, f"data symbol mismatch: {symbol.name} (original address: {orig_addr:#x}, "
f"expected: {decomp_symbol.addr:#x}, "
f"actual: {decomp_addr:#x})")
return False return False
def _check_data_symbol_load(self, i1, i2, orig_addr: int, decomp_addr: int) -> bool:
cached_result = self._got_data_symbol_check_cache.get((orig_addr, decomp_addr), None)
if cached_result is not None:
return cached_result
if not elf.is_in_section(self._base_got_section, orig_addr, 8):
return True
ptr1, = struct.unpack("<Q", elf.read_from_elf(elf.base_elf, orig_addr, 8))
if self.dsymtab.get_symbol(ptr1) is None:
return True
ptr2 = self._decomp_glob_data_table[decomp_addr]
result = self._check_data_symbol(i1, i2, ptr1, ptr2)
self._got_data_symbol_check_cache[(orig_addr, decomp_addr)] = result
return result
def _check_function_call(self, i1, i2, orig_addr: int, decomp_addr: int) -> bool: def _check_function_call(self, i1, i2, orig_addr: int, decomp_addr: int) -> bool:
name = self.decompiled_fns.get(orig_addr, None) name = self.decompiled_fns.get(orig_addr, None)
if name is None: if name is None:

View File

@ -1,9 +1,9 @@
#!/usr/bin/env python3
from typing import Any, Dict, NamedTuple
import io import io
from typing import Any, Dict, NamedTuple
from elftools.elf.elffile import ELFFile from elftools.elf.elffile import ELFFile
from elftools.elf.relocation import RelocationSection
from elftools.elf.sections import Section
import diff_settings import diff_settings
from util import utils from util import utils
@ -40,7 +40,13 @@ def get_file_offset(elf, addr: int) -> int:
continue continue
if seg["p_vaddr"] <= addr < seg["p_vaddr"] + seg["p_filesz"]: if seg["p_vaddr"] <= addr < seg["p_vaddr"] + seg["p_filesz"]:
return addr - seg["p_vaddr"] + seg["p_offset"] return addr - seg["p_vaddr"] + seg["p_offset"]
assert False raise KeyError(f"No segment found for {addr:#x}")
def is_in_section(section: Section, addr: int, size: int) -> bool:
begin = section["sh_addr"]
end = begin + section["sh_size"]
return begin <= addr < end and begin <= addr + size < end
def get_symbol(table, name: str) -> Symbol: def get_symbol(table, name: str) -> Symbol:
@ -69,14 +75,34 @@ def build_name_to_symbol_table(symtab) -> Dict[str, Symbol]:
return {sym.name: Symbol(sym["st_value"], sym.name, sym["st_size"]) for sym in symtab.iter_symbols()} return {sym.name: Symbol(sym["st_value"], sym.name, sym["st_size"]) for sym in symtab.iter_symbols()}
def read_from_elf(elf: ELFFile, addr: int, size: int) -> bytes:
offset: int = get_file_offset(elf, addr)
elf.stream.seek(offset)
return elf.stream.read(size)
def get_fn_from_base_elf(addr: int, size: int) -> Function: def get_fn_from_base_elf(addr: int, size: int) -> Function:
offset = get_file_offset(base_elf, addr) return Function(read_from_elf(base_elf, addr, size), addr)
base_elf.stream.seek(offset)
return Function(base_elf.stream.read(size), addr)
def get_fn_from_my_elf(name: str) -> Function: def get_fn_from_my_elf(name: str) -> Function:
sym = get_symbol(my_symtab, name) sym = get_symbol(my_symtab, name)
offset = get_file_offset(my_elf, sym.addr) return Function(read_from_elf(my_elf, sym.addr, sym.size), sym.addr)
my_elf.stream.seek(offset)
return Function(my_elf.stream.read(sym.size), sym.addr)
R_AARCH64_GLOB_DAT = 1025
def build_glob_data_table(elf: ELFFile) -> Dict[int, int]:
table: Dict[int, int] = dict()
section = elf.get_section_by_name(".rela.dyn")
assert isinstance(section, RelocationSection)
symtab = elf.get_section(section["sh_link"])
for reloc in section.iter_relocations():
sym_value = symtab.get_symbol(reloc["r_info_sym"])["st_value"]
if reloc["r_info_type"] == R_AARCH64_GLOB_DAT:
table[reloc["r_offset"]] = sym_value + reloc["r_addend"]
return table