From d9f6166ed0162cc916ee8d6c3ae439a96c623c6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Lam?= Date: Sat, 26 Dec 2020 23:19:01 +0100 Subject: [PATCH] tools: Deduplicate AI tool code --- tools/ai_show_nontrivial_hierarchies.py | 56 +++-------------------- tools/ida_ai_rename_action_vfns.py | 56 +++-------------------- tools/util/ai_common.py | 44 ++++++++++++++++++ tools/util/graph.py | 61 +++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 99 deletions(-) create mode 100644 tools/util/ai_common.py create mode 100644 tools/util/graph.py diff --git a/tools/ai_show_nontrivial_hierarchies.py b/tools/ai_show_nontrivial_hierarchies.py index 726331e2..32ce620c 100755 --- a/tools/ai_show_nontrivial_hierarchies.py +++ b/tools/ai_show_nontrivial_hierarchies.py @@ -1,18 +1,13 @@ #!/usr/bin/env python3 import argparse -from collections import defaultdict from pathlib import Path from typing import Union import yaml -_base_classes = { - 0x71024d8d68, - 0x71025129f0, - 0x7102513278, - 0x71024d8ef0, - 0x710243c9b8, -} +from util import ai_common +from util.ai_common import BaseClasses +from util.graph import Graph _known_vtables = { 0x71024d8d68: "ActionBase", @@ -34,43 +29,6 @@ def get_name_for_vtable(vtable: Union[str, int]): return f"[V] {vtable:#x}" -def dfs(nodes: dict, start, visited: set): - result = [] - to_visit = [start] - while to_visit: - x = to_visit.pop() - result.append(x) - visited.add(x) - - for y in nodes[x]: - if y not in visited: - to_visit.append(y) - - return result - - -class Graph: - def __init__(self): - self.nodes = defaultdict(set) - - def add_edge(self, a, b): - self.nodes[a].add(b) - - def find_connected_components(self): - nodes = defaultdict(list) - for u in self.nodes: - for v in self.nodes[u]: - nodes[u].append(v) - nodes[v].append(u) - cc = [] - visited = set() - for u in nodes.keys(): - if u in visited: - continue - cc.append(dfs(nodes, u, visited)) - return cc - - def guess_vtable_names(reverse_graph: Graph): for u in reverse_graph.nodes: targets = list(reverse_graph.nodes[u]) @@ -88,7 +46,7 @@ def build_graph(all_vtables: dict, type_: str, graph: Graph, reverse_graph: Grap from_ = classes[i] to_ = classes[i + 1] # Skip base classes to reduce noise. - if to_ in _base_classes: + if to_ in BaseClasses: break reverse_graph.add_edge(to_, from_) @@ -97,7 +55,7 @@ def build_graph(all_vtables: dict, type_: str, graph: Graph, reverse_graph: Grap for name, vtables in all_vtables[type_].items(): classes = [name] + list(reversed(vtables)) for i in range(len(classes) - 1): - if classes[i + 1] in _base_classes: + if classes[i + 1] in BaseClasses: break from_ = get_name_for_vtable(classes[i]) to_ = get_name_for_vtable(classes[i + 1]) @@ -106,14 +64,12 @@ def build_graph(all_vtables: dict, type_: str, graph: Graph, reverse_graph: Grap def main() -> None: parser = argparse.ArgumentParser(description="Shows AI classes with non-trivial class hierarchies.") - parser.add_argument("aidef_vtables", help="Path to aidef_vtables.yml") parser.add_argument("--type", help="AI class type to visualise", choices=["Action", "AI", "Behavior", "Query"], required=True) parser.add_argument("--out-names", help="Path to which a vtable -> name map will be written", required=True) args = parser.parse_args() - with Path(args.aidef_vtables).open() as f: - all_vtables: dict = yaml.load(f, Loader=yaml.CSafeLoader) + all_vtables = ai_common.get_ai_vtables() graph = Graph() reverse_graph = Graph() diff --git a/tools/ida_ai_rename_action_vfns.py b/tools/ida_ai_rename_action_vfns.py index 07804cb6..4ad0bcdd 100644 --- a/tools/ida_ai_rename_action_vfns.py +++ b/tools/ida_ai_rename_action_vfns.py @@ -1,46 +1,10 @@ import struct -from collections import defaultdict -from pathlib import Path from typing import Dict -from ai_show_nontrivial_hierarchies import _base_classes -from util import utils -import yaml +from util import utils, ai_common import idaapi - -class Graph: - def __init__(self): - self.nodes = defaultdict(set) - - def add_edge(self, a, b): - self.nodes[a].add(b) - - def topological_sort(self) -> list: - result = [] - visited = set() - - def dfs(node): - if node in visited: - return - # Our graph is guaranteed to be acyclic since it's a graph of vtables... - visited.add(node) - for y in self.nodes.get(node, set()): - dfs(y) - result.insert(0, node) - - for x in self.nodes: - dfs(x) - - return result - - -def build_graph(all_vtables: dict, graph: Graph): - for name, vtables in all_vtables["Action"].items(): - classes = list(dict.fromkeys(reversed(vtables))) - for i in range(len(classes) - 1): - graph.add_edge(classes[i + 1], classes[i]) - +from util.ai_common import BaseClasses _vtable_fn_names = [ "_ZNK5uking6action{}27checkDerivedRuntimeTypeInfoEPKN4sead15RuntimeTypeInfo9InterfaceE", @@ -96,21 +60,15 @@ _ida_base = 0x7100000000 def main() -> None: - data_dir = utils.get_repo_root() / "data" - with Path(data_dir / "aidef_vtables.yml").open() as f: - all_vtables: dict = yaml.load(f, Loader=yaml.CSafeLoader) - with Path(data_dir / "aidef_action_vtables.yml").open() as f: - names: Dict[int, str] = yaml.load(f, Loader=yaml.CSafeLoader) + all_vtables = ai_common.get_ai_vtables() + names = ai_common.get_action_vtable_names() + not_decompiled = {func.addr for func in utils.get_functions() if func.status == utils.FunctionStatus.NotDecompiled} new_names: Dict[int, str] = dict() - not_decompiled = {func.addr for func in utils.get_functions() if func.status == utils.FunctionStatus.NotDecompiled} - - graph = Graph() - build_graph(all_vtables, graph) - order = graph.topological_sort() + order = ai_common.topologically_sort_vtables(all_vtables, "Action") for vtable_addr in order: - if vtable_addr in _base_classes: + if vtable_addr in BaseClasses: continue class_name = names.get(vtable_addr) diff --git a/tools/util/ai_common.py b/tools/util/ai_common.py new file mode 100644 index 00000000..6de539ea --- /dev/null +++ b/tools/util/ai_common.py @@ -0,0 +1,44 @@ +from typing import Dict, List +import yaml + +from util import utils +from util.graph import Graph + +BaseClasses = { + 0x71024d8d68, + 0x71025129f0, + 0x7102513278, + 0x71024d8ef0, + 0x710243c9b8, +} + + +def check_vtable_name_dict(names: Dict[int, str]): + seen = set() + for k, v in names.items(): + if v in seen: + raise ValueError(f"invalid vtable names: {v} appears twice") + seen.add(k) + seen.add(v) + + +def get_ai_vtables() -> Dict[str, Dict[str, List[int]]]: + with (utils.get_repo_root() / "data" / "aidef_vtables.yml").open(encoding="utf-8") as f: + return yaml.load(f, Loader=yaml.CSafeLoader) + + +def get_action_vtable_names() -> Dict[int, str]: + with (utils.get_repo_root() / "data" / "aidef_action_vtables.yml").open(encoding="utf-8") as f: + names = yaml.load(f, Loader=yaml.CSafeLoader) + + check_vtable_name_dict(names) + return names + + +def topologically_sort_vtables(all_vtables: dict, type_: str) -> List[int]: + graph = Graph() + for name, vtables in all_vtables[type_].items(): + classes = list(dict.fromkeys(reversed(vtables))) + for i in range(len(classes) - 1): + graph.add_edge(classes[i + 1], classes[i]) + return graph.topological_sort() diff --git a/tools/util/graph.py b/tools/util/graph.py new file mode 100644 index 00000000..9657103f --- /dev/null +++ b/tools/util/graph.py @@ -0,0 +1,61 @@ +from collections import defaultdict + +_Visiting = 0 +_Visited = 1 + + +class Graph: + def __init__(self): + self.nodes = defaultdict(set) + + def add_edge(self, a, b): + self.nodes[a].add(b) + + def find_connected_components(self): + nodes = defaultdict(list) + for u in self.nodes: + for v in self.nodes[u]: + nodes[u].append(v) + nodes[v].append(u) + cc = [] + visited = set() + + def dfs(start): + result = [] + to_visit = [start] + while to_visit: + x = to_visit.pop() + result.append(x) + visited.add(x) + for y in nodes[x]: + if y not in visited: + to_visit.append(y) + return result + + for u in nodes.keys(): + if u in visited: + continue + cc.append(dfs(u)) + return cc + + def topological_sort(self) -> list: + result = [] + statuses = dict() + + def dfs(node): + if statuses.get(node) == _Visiting: + raise RuntimeError("Graph is not acyclic") + if statuses.get(node) == _Visited: + return + + statuses[node] = _Visiting + for y in self.nodes.get(node, set()): + dfs(y) + + statuses[node] = _Visited + result.insert(0, node) + + for x in self.nodes: + dfs(x) + + return result