mirror of https://github.com/zeldaret/botw.git
				
				
				
			tools: Add function call checking
This commit is contained in:
		
							parent
							
								
									a3abb115b9
								
							
						
					
					
						commit
						dc9b346d7e
					
				| 
						 | 
				
			
			@ -40,7 +40,6 @@ def main() -> None:
 | 
			
		|||
 | 
			
		||||
    new_matches: Dict[int, str] = dict()
 | 
			
		||||
    checker = util.checker.FunctionChecker()
 | 
			
		||||
    checker.get_data_symtab().load_from_csv(utils.get_repo_root() / "data" / "data_symbols.csv")
 | 
			
		||||
 | 
			
		||||
    # Given a list L of functions to identify and a small list of candidates C, this tool will attempt to
 | 
			
		||||
    # automatically identify matches by checking each function in L against each function in C.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,9 +1,9 @@
 | 
			
		|||
from collections import defaultdict
 | 
			
		||||
from typing import Set, DefaultDict
 | 
			
		||||
from typing import Set, DefaultDict, Dict
 | 
			
		||||
 | 
			
		||||
import capstone as cs
 | 
			
		||||
 | 
			
		||||
from util import dsym, elf
 | 
			
		||||
from util import dsym, elf, utils
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FunctionChecker:
 | 
			
		||||
| 
						 | 
				
			
			@ -12,10 +12,16 @@ class FunctionChecker:
 | 
			
		|||
        self.md.detail = True
 | 
			
		||||
        self.my_symtab = elf.build_name_to_symbol_table(elf.my_symtab)
 | 
			
		||||
        self.dsymtab = dsym.DataSymbolContainer()
 | 
			
		||||
        self.decompiled_fns: Dict[int, str] = dict()
 | 
			
		||||
        self.load_data_for_project()
 | 
			
		||||
 | 
			
		||||
    def get_data_symtab(self) -> dsym.DataSymbolContainer:
 | 
			
		||||
        return self.dsymtab
 | 
			
		||||
 | 
			
		||||
    def load_data_for_project(self) -> None:
 | 
			
		||||
        self.decompiled_fns = {func.addr: func.decomp_name for func in utils.get_functions() if func.decomp_name}
 | 
			
		||||
        self.get_data_symtab().load_from_csv(utils.get_repo_root() / "data" / "data_symbols.csv")
 | 
			
		||||
 | 
			
		||||
    def check(self, base_fn: elf.Function, my_fn: elf.Function) -> bool:
 | 
			
		||||
        gprs1: DefaultDict[int, int] = defaultdict(int)
 | 
			
		||||
        gprs2: DefaultDict[int, int] = defaultdict(int)
 | 
			
		||||
| 
						 | 
				
			
			@ -51,13 +57,20 @@ class FunctionChecker:
 | 
			
		|||
            # Ignore some address differences until a fully matching executable can be generated.
 | 
			
		||||
 | 
			
		||||
            if i1.mnemonic == 'bl':
 | 
			
		||||
                if not self._check_function_call(i1.operands[0].imm, i2.operands[0].imm):
 | 
			
		||||
                    return False
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            if i1.mnemonic == 'b':
 | 
			
		||||
                # Needed for tail calls.
 | 
			
		||||
                branch_target = int(i1.op_str[1:], 16)
 | 
			
		||||
                branch_target = i1.operands[0].imm
 | 
			
		||||
                # If we are branching outside the function, this is likely a tail call.
 | 
			
		||||
                # Treat this as a function call.
 | 
			
		||||
                if not (base_fn.addr <= branch_target < base_fn.addr + size):
 | 
			
		||||
                    if not self._check_function_call(branch_target, i2.operands[0].imm):
 | 
			
		||||
                        return False
 | 
			
		||||
                    continue
 | 
			
		||||
                # Otherwise, it's a mismatch.
 | 
			
		||||
                return False
 | 
			
		||||
 | 
			
		||||
            if i1.mnemonic == 'adrp':
 | 
			
		||||
                if i1.operands[0].reg != i2.operands[0].reg:
 | 
			
		||||
| 
						 | 
				
			
			@ -134,3 +147,11 @@ class FunctionChecker:
 | 
			
		|||
 | 
			
		||||
        decomp_symbol = self.my_symtab[symbol.name]
 | 
			
		||||
        return decomp_symbol.addr == decomp_addr
 | 
			
		||||
 | 
			
		||||
    def _check_function_call(self, orig_addr: int, decomp_addr: int) -> bool:
 | 
			
		||||
        name = self.decompiled_fns.get(orig_addr, None)
 | 
			
		||||
        if name is None:
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        decomp_symbol = self.my_symtab[name]
 | 
			
		||||
        return decomp_symbol.addr == decomp_addr
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue