tools: Deduplicate AI tool code

This commit is contained in:
Léo Lam 2020-12-26 23:19:01 +01:00
parent 18b90cea29
commit d9f6166ed0
No known key found for this signature in database
GPG Key ID: 0DF30F9081000741
4 changed files with 118 additions and 99 deletions

View File

@ -1,18 +1,13 @@
#!/usr/bin/env python3
import argparse
from collections import defaultdict
from pathlib import Path
from typing import Union
import yaml
# Vtable addresses that build_graph() treats as base classes: walking a class
# chain stops as soon as one of these is reached, to reduce noise in the graph.
# The first entry is named "ActionBase" in _known_vtables.
# NOTE(review): the names of the remaining entries are not visible here.
_base_classes = {
    0x71024d8d68,
    0x71025129f0,
    0x7102513278,
    0x71024d8ef0,
    0x710243c9b8,
}
from util import ai_common
from util.ai_common import BaseClasses
from util.graph import Graph
_known_vtables = {
0x71024d8d68: "ActionBase",
@ -34,43 +29,6 @@ def get_name_for_vtable(vtable: Union[str, int]):
return f"[V] {vtable:#x}"
def dfs(nodes: dict, start, visited: set):
    """Return the nodes reachable from ``start`` in depth-first order.

    Args:
        nodes: Adjacency mapping; ``nodes[x]`` is an iterable of neighbours of ``x``.
        start: Node at which the traversal begins.
        visited: Set of already-seen nodes. It is updated in place so that
            several traversals can share it.

    Returns:
        A list containing each newly visited node exactly once. The list is
        empty if ``start`` was already in ``visited``.
    """
    result = []
    to_visit = [start]
    while to_visit:
        x = to_visit.pop()
        # A node may have been pushed several times (once per neighbour that
        # saw it as unvisited) before it is popped for the first time. Skip the
        # stale entries so that no node is reported twice.
        if x in visited:
            continue
        visited.add(x)
        result.append(x)
        for y in nodes[x]:
            if y not in visited:
                to_visit.append(y)
    return result
class Graph:
    """Directed graph stored as a mapping from each node to its set of successors."""

    def __init__(self):
        # node -> set of nodes it has an edge to
        self.nodes = defaultdict(set)

    def add_edge(self, a, b):
        """Record a directed edge from ``a`` to ``b``."""
        self.nodes[a].add(b)

    def find_connected_components(self):
        """Return the connected components of the graph, ignoring edge direction.

        Each component is the list produced by ``dfs`` for the first
        not-yet-visited node of that component.
        """
        # Build an undirected adjacency list: every edge is recorded both ways.
        undirected = defaultdict(list)
        for source, targets in self.nodes.items():
            for target in targets:
                undirected[source].append(target)
                undirected[target].append(source)

        seen = set()
        components = []
        for node in undirected:
            # dfs() marks everything it reaches in ``seen``, so each component
            # is only started from once.
            if node not in seen:
                components.append(dfs(undirected, node, seen))
        return components
def guess_vtable_names(reverse_graph: Graph):
for u in reverse_graph.nodes:
targets = list(reverse_graph.nodes[u])
@ -88,7 +46,7 @@ def build_graph(all_vtables: dict, type_: str, graph: Graph, reverse_graph: Grap
from_ = classes[i]
to_ = classes[i + 1]
# Skip base classes to reduce noise.
if to_ in _base_classes:
if to_ in BaseClasses:
break
reverse_graph.add_edge(to_, from_)
@ -97,7 +55,7 @@ def build_graph(all_vtables: dict, type_: str, graph: Graph, reverse_graph: Grap
for name, vtables in all_vtables[type_].items():
classes = [name] + list(reversed(vtables))
for i in range(len(classes) - 1):
if classes[i + 1] in _base_classes:
if classes[i + 1] in BaseClasses:
break
from_ = get_name_for_vtable(classes[i])
to_ = get_name_for_vtable(classes[i + 1])
@ -106,14 +64,12 @@ def build_graph(all_vtables: dict, type_: str, graph: Graph, reverse_graph: Grap
def main() -> None:
parser = argparse.ArgumentParser(description="Shows AI classes with non-trivial class hierarchies.")
parser.add_argument("aidef_vtables", help="Path to aidef_vtables.yml")
parser.add_argument("--type", help="AI class type to visualise", choices=["Action", "AI", "Behavior", "Query"],
required=True)
parser.add_argument("--out-names", help="Path to which a vtable -> name map will be written", required=True)
args = parser.parse_args()
with Path(args.aidef_vtables).open() as f:
all_vtables: dict = yaml.load(f, Loader=yaml.CSafeLoader)
all_vtables = ai_common.get_ai_vtables()
graph = Graph()
reverse_graph = Graph()

View File

@ -1,46 +1,10 @@
import struct
from collections import defaultdict
from pathlib import Path
from typing import Dict
from ai_show_nontrivial_hierarchies import _base_classes
from util import utils
import yaml
from util import utils, ai_common
import idaapi
class Graph:
    """Directed graph stored as a mapping from each node to its set of successors."""

    def __init__(self):
        # node -> set of nodes it has an edge to
        self.nodes = defaultdict(set)

    def add_edge(self, a, b):
        """Record a directed edge from ``a`` to ``b``."""
        self.nodes[a].add(b)

    def topological_sort(self) -> list:
        """Return all nodes ordered so that every node precedes its successors.

        No cycle detection is performed.
        """
        seen = set()
        post_order = []

        def visit(node):
            if node in seen:
                return
            # Our graph is guaranteed to be acyclic since it's a graph of vtables...
            seen.add(node)
            # .get() is used so that looking up a node without outgoing edges
            # does not insert a new key into the defaultdict.
            for successor in self.nodes.get(node, set()):
                visit(successor)
            post_order.append(node)

        for root in self.nodes:
            visit(root)

        # Nodes were collected after their successors; flip to get the final order.
        post_order.reverse()
        return post_order
def build_graph(all_vtables: dict, graph: Graph):
    """Populate ``graph`` with edges derived from the "Action" vtable lists.

    For every entry of ``all_vtables["Action"]``, the vtable list is reversed
    and de-duplicated (keeping first occurrences), then an edge is added from
    each element to the element that precedes it in that sequence.
    """
    for vtables in all_vtables["Action"].values():
        chain = list(dict.fromkeys(reversed(vtables)))
        for current, following in zip(chain, chain[1:]):
            graph.add_edge(following, current)
from util.ai_common import BaseClasses
_vtable_fn_names = [
"_ZNK5uking6action{}27checkDerivedRuntimeTypeInfoEPKN4sead15RuntimeTypeInfo9InterfaceE",
@ -96,21 +60,15 @@ _ida_base = 0x7100000000
def main() -> None:
data_dir = utils.get_repo_root() / "data"
with Path(data_dir / "aidef_vtables.yml").open() as f:
all_vtables: dict = yaml.load(f, Loader=yaml.CSafeLoader)
with Path(data_dir / "aidef_action_vtables.yml").open() as f:
names: Dict[int, str] = yaml.load(f, Loader=yaml.CSafeLoader)
all_vtables = ai_common.get_ai_vtables()
names = ai_common.get_action_vtable_names()
not_decompiled = {func.addr for func in utils.get_functions() if func.status == utils.FunctionStatus.NotDecompiled}
new_names: Dict[int, str] = dict()
not_decompiled = {func.addr for func in utils.get_functions() if func.status == utils.FunctionStatus.NotDecompiled}
graph = Graph()
build_graph(all_vtables, graph)
order = graph.topological_sort()
order = ai_common.topologically_sort_vtables(all_vtables, "Action")
for vtable_addr in order:
if vtable_addr in _base_classes:
if vtable_addr in BaseClasses:
continue
class_name = names.get(vtable_addr)

44
tools/util/ai_common.py Normal file
View File

@ -0,0 +1,44 @@
from typing import Dict, List
import yaml
from util import utils
from util.graph import Graph
# Vtable addresses that the AI tools treat as base classes; callers test
# membership in this set to skip those vtables.
# NOTE(review): which class each address belongs to is not visible here.
BaseClasses = {
    0x71024d8d68,
    0x71025129f0,
    0x7102513278,
    0x71024d8ef0,
    0x710243c9b8,
}
def check_vtable_name_dict(names: Dict[int, str]):
    """Verify that no name is assigned to more than one vtable.

    Args:
        names: Mapping from a vtable key to its name.

    Raises:
        ValueError: If the same name appears as the value of two entries.
    """
    seen = set()
    # Only the names (values) are tracked. Mixing the keys into the same set
    # would flag a name as duplicated merely because it equals an earlier key.
    for name in names.values():
        if name in seen:
            raise ValueError(f"invalid vtable names: {name} appears twice")
        seen.add(name)
def get_ai_vtables() -> Dict[str, Dict[str, List[int]]]:
    """Load ``data/aidef_vtables.yml`` from the repository root and return its parsed contents."""
    path = utils.get_repo_root() / "data" / "aidef_vtables.yml"
    with path.open(encoding="utf-8") as f:
        vtables = yaml.load(f, Loader=yaml.CSafeLoader)
    return vtables
def get_action_vtable_names() -> Dict[int, str]:
    """Load ``data/aidef_action_vtables.yml`` from the repository root.

    The parsed mapping is passed through ``check_vtable_name_dict`` before it
    is returned, so an exception raised by that check propagates to the caller.
    """
    path = utils.get_repo_root() / "data" / "aidef_action_vtables.yml"
    with path.open(encoding="utf-8") as f:
        vtable_names = yaml.load(f, Loader=yaml.CSafeLoader)
    check_vtable_name_dict(vtable_names)
    return vtable_names
def topologically_sort_vtables(all_vtables: dict, type_: str) -> List[int]:
    """Return the vtables of ``all_vtables[type_]`` in topological order.

    For every entry, the vtable list is reversed and de-duplicated (keeping
    first occurrences); an edge is then added from each element to the element
    that precedes it in that sequence. The result is whatever
    ``Graph.topological_sort`` produces for those edges.
    """
    graph = Graph()
    for vtables in all_vtables[type_].values():
        chain = list(dict.fromkeys(reversed(vtables)))
        for current, following in zip(chain, chain[1:]):
            graph.add_edge(following, current)
    return graph.topological_sort()

61
tools/util/graph.py Normal file
View File

@ -0,0 +1,61 @@
from collections import defaultdict
# Node states used by Graph.topological_sort().
_Visiting = 0  # the node is on the current DFS path
_Visited = 1  # the node and everything reachable from it have been emitted


class Graph:
    """Directed graph stored as a mapping from each node to its set of successors."""

    def __init__(self):
        # node -> set of nodes it has an edge to
        self.nodes = defaultdict(set)

    def add_edge(self, a, b):
        """Record a directed edge from ``a`` to ``b``."""
        self.nodes[a].add(b)

    def find_connected_components(self):
        """Return the connected components of the graph, ignoring edge direction.

        Returns:
            A list of components; each component is a list that contains every
            node of that component exactly once, in depth-first order.
        """
        # Undirected adjacency list: every edge is recorded in both directions.
        nodes = defaultdict(list)
        for u in self.nodes:
            for v in self.nodes[u]:
                nodes[u].append(v)
                nodes[v].append(u)

        cc = []
        visited = set()

        def dfs(start):
            result = []
            to_visit = [start]
            while to_visit:
                x = to_visit.pop()
                # A node can be pushed once per neighbour that still saw it as
                # unvisited; skip the stale entries so it is reported only once.
                if x in visited:
                    continue
                visited.add(x)
                result.append(x)
                for y in nodes[x]:
                    if y not in visited:
                        to_visit.append(y)
            return result

        for u in nodes:
            if u not in visited:
                cc.append(dfs(u))
        return cc

    def topological_sort(self) -> list:
        """Return all nodes ordered so that every node precedes its successors.

        Raises:
            RuntimeError: If the graph contains a cycle.
        """
        result = []
        statuses = dict()

        def dfs(node):
            status = statuses.get(node)
            if status == _Visiting:
                # Reaching a node that is still on the current path means
                # there is a back edge, i.e. a cycle.
                raise RuntimeError("Graph is not acyclic")
            if status == _Visited:
                return

            statuses[node] = _Visiting
            # .get() avoids inserting keys into the defaultdict while the outer
            # loop is iterating over it.
            for y in self.nodes.get(node, ()):
                dfs(y)
            statuses[node] = _Visited
            result.append(node)

        for x in self.nodes:
            dfs(x)

        # Nodes were appended after their successors; reversing once gives the
        # same order as prepending each node, without the quadratic cost.
        result.reverse()
        return result