mirror of https://github.com/zeldaret/botw.git
tools: Deduplicate AI tool code
This commit is contained in:
parent
18b90cea29
commit
d9f6166ed0
|
|
@ -1,18 +1,13 @@
|
|||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
_base_classes = {
|
||||
0x71024d8d68,
|
||||
0x71025129f0,
|
||||
0x7102513278,
|
||||
0x71024d8ef0,
|
||||
0x710243c9b8,
|
||||
}
|
||||
from util import ai_common
|
||||
from util.ai_common import BaseClasses
|
||||
from util.graph import Graph
|
||||
|
||||
_known_vtables = {
|
||||
0x71024d8d68: "ActionBase",
|
||||
|
|
@ -34,43 +29,6 @@ def get_name_for_vtable(vtable: Union[str, int]):
|
|||
return f"[V] {vtable:#x}"
|
||||
|
||||
|
||||
def dfs(nodes: dict, start, visited: set):
|
||||
result = []
|
||||
to_visit = [start]
|
||||
while to_visit:
|
||||
x = to_visit.pop()
|
||||
result.append(x)
|
||||
visited.add(x)
|
||||
|
||||
for y in nodes[x]:
|
||||
if y not in visited:
|
||||
to_visit.append(y)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Graph:
|
||||
def __init__(self):
|
||||
self.nodes = defaultdict(set)
|
||||
|
||||
def add_edge(self, a, b):
|
||||
self.nodes[a].add(b)
|
||||
|
||||
def find_connected_components(self):
|
||||
nodes = defaultdict(list)
|
||||
for u in self.nodes:
|
||||
for v in self.nodes[u]:
|
||||
nodes[u].append(v)
|
||||
nodes[v].append(u)
|
||||
cc = []
|
||||
visited = set()
|
||||
for u in nodes.keys():
|
||||
if u in visited:
|
||||
continue
|
||||
cc.append(dfs(nodes, u, visited))
|
||||
return cc
|
||||
|
||||
|
||||
def guess_vtable_names(reverse_graph: Graph):
|
||||
for u in reverse_graph.nodes:
|
||||
targets = list(reverse_graph.nodes[u])
|
||||
|
|
@ -88,7 +46,7 @@ def build_graph(all_vtables: dict, type_: str, graph: Graph, reverse_graph: Grap
|
|||
from_ = classes[i]
|
||||
to_ = classes[i + 1]
|
||||
# Skip base classes to reduce noise.
|
||||
if to_ in _base_classes:
|
||||
if to_ in BaseClasses:
|
||||
break
|
||||
reverse_graph.add_edge(to_, from_)
|
||||
|
||||
|
|
@ -97,7 +55,7 @@ def build_graph(all_vtables: dict, type_: str, graph: Graph, reverse_graph: Grap
|
|||
for name, vtables in all_vtables[type_].items():
|
||||
classes = [name] + list(reversed(vtables))
|
||||
for i in range(len(classes) - 1):
|
||||
if classes[i + 1] in _base_classes:
|
||||
if classes[i + 1] in BaseClasses:
|
||||
break
|
||||
from_ = get_name_for_vtable(classes[i])
|
||||
to_ = get_name_for_vtable(classes[i + 1])
|
||||
|
|
@ -106,14 +64,12 @@ def build_graph(all_vtables: dict, type_: str, graph: Graph, reverse_graph: Grap
|
|||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Shows AI classes with non-trivial class hierarchies.")
|
||||
parser.add_argument("aidef_vtables", help="Path to aidef_vtables.yml")
|
||||
parser.add_argument("--type", help="AI class type to visualise", choices=["Action", "AI", "Behavior", "Query"],
|
||||
required=True)
|
||||
parser.add_argument("--out-names", help="Path to which a vtable -> name map will be written", required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
with Path(args.aidef_vtables).open() as f:
|
||||
all_vtables: dict = yaml.load(f, Loader=yaml.CSafeLoader)
|
||||
all_vtables = ai_common.get_ai_vtables()
|
||||
|
||||
graph = Graph()
|
||||
reverse_graph = Graph()
|
||||
|
|
|
|||
|
|
@ -1,46 +1,10 @@
|
|||
import struct
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from ai_show_nontrivial_hierarchies import _base_classes
|
||||
from util import utils
|
||||
import yaml
|
||||
from util import utils, ai_common
|
||||
import idaapi
|
||||
|
||||
|
||||
class Graph:
|
||||
def __init__(self):
|
||||
self.nodes = defaultdict(set)
|
||||
|
||||
def add_edge(self, a, b):
|
||||
self.nodes[a].add(b)
|
||||
|
||||
def topological_sort(self) -> list:
|
||||
result = []
|
||||
visited = set()
|
||||
|
||||
def dfs(node):
|
||||
if node in visited:
|
||||
return
|
||||
# Our graph is guaranteed to be acyclic since it's a graph of vtables...
|
||||
visited.add(node)
|
||||
for y in self.nodes.get(node, set()):
|
||||
dfs(y)
|
||||
result.insert(0, node)
|
||||
|
||||
for x in self.nodes:
|
||||
dfs(x)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def build_graph(all_vtables: dict, graph: Graph):
|
||||
for name, vtables in all_vtables["Action"].items():
|
||||
classes = list(dict.fromkeys(reversed(vtables)))
|
||||
for i in range(len(classes) - 1):
|
||||
graph.add_edge(classes[i + 1], classes[i])
|
||||
|
||||
from util.ai_common import BaseClasses
|
||||
|
||||
_vtable_fn_names = [
|
||||
"_ZNK5uking6action{}27checkDerivedRuntimeTypeInfoEPKN4sead15RuntimeTypeInfo9InterfaceE",
|
||||
|
|
@ -96,21 +60,15 @@ _ida_base = 0x7100000000
|
|||
|
||||
|
||||
def main() -> None:
|
||||
data_dir = utils.get_repo_root() / "data"
|
||||
with Path(data_dir / "aidef_vtables.yml").open() as f:
|
||||
all_vtables: dict = yaml.load(f, Loader=yaml.CSafeLoader)
|
||||
with Path(data_dir / "aidef_action_vtables.yml").open() as f:
|
||||
names: Dict[int, str] = yaml.load(f, Loader=yaml.CSafeLoader)
|
||||
all_vtables = ai_common.get_ai_vtables()
|
||||
names = ai_common.get_action_vtable_names()
|
||||
not_decompiled = {func.addr for func in utils.get_functions() if func.status == utils.FunctionStatus.NotDecompiled}
|
||||
|
||||
new_names: Dict[int, str] = dict()
|
||||
|
||||
not_decompiled = {func.addr for func in utils.get_functions() if func.status == utils.FunctionStatus.NotDecompiled}
|
||||
|
||||
graph = Graph()
|
||||
build_graph(all_vtables, graph)
|
||||
order = graph.topological_sort()
|
||||
order = ai_common.topologically_sort_vtables(all_vtables, "Action")
|
||||
for vtable_addr in order:
|
||||
if vtable_addr in _base_classes:
|
||||
if vtable_addr in BaseClasses:
|
||||
continue
|
||||
|
||||
class_name = names.get(vtable_addr)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
from typing import Dict, List
|
||||
import yaml
|
||||
|
||||
from util import utils
|
||||
from util.graph import Graph
|
||||
|
||||
BaseClasses = {
|
||||
0x71024d8d68,
|
||||
0x71025129f0,
|
||||
0x7102513278,
|
||||
0x71024d8ef0,
|
||||
0x710243c9b8,
|
||||
}
|
||||
|
||||
|
||||
def check_vtable_name_dict(names: Dict[int, str]):
|
||||
seen = set()
|
||||
for k, v in names.items():
|
||||
if v in seen:
|
||||
raise ValueError(f"invalid vtable names: {v} appears twice")
|
||||
seen.add(k)
|
||||
seen.add(v)
|
||||
|
||||
|
||||
def get_ai_vtables() -> Dict[str, Dict[str, List[int]]]:
|
||||
with (utils.get_repo_root() / "data" / "aidef_vtables.yml").open(encoding="utf-8") as f:
|
||||
return yaml.load(f, Loader=yaml.CSafeLoader)
|
||||
|
||||
|
||||
def get_action_vtable_names() -> Dict[int, str]:
|
||||
with (utils.get_repo_root() / "data" / "aidef_action_vtables.yml").open(encoding="utf-8") as f:
|
||||
names = yaml.load(f, Loader=yaml.CSafeLoader)
|
||||
|
||||
check_vtable_name_dict(names)
|
||||
return names
|
||||
|
||||
|
||||
def topologically_sort_vtables(all_vtables: dict, type_: str) -> List[int]:
|
||||
graph = Graph()
|
||||
for name, vtables in all_vtables[type_].items():
|
||||
classes = list(dict.fromkeys(reversed(vtables)))
|
||||
for i in range(len(classes) - 1):
|
||||
graph.add_edge(classes[i + 1], classes[i])
|
||||
return graph.topological_sort()
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
from collections import defaultdict
|
||||
|
||||
_Visiting = 0
|
||||
_Visited = 1
|
||||
|
||||
|
||||
class Graph:
|
||||
def __init__(self):
|
||||
self.nodes = defaultdict(set)
|
||||
|
||||
def add_edge(self, a, b):
|
||||
self.nodes[a].add(b)
|
||||
|
||||
def find_connected_components(self):
|
||||
nodes = defaultdict(list)
|
||||
for u in self.nodes:
|
||||
for v in self.nodes[u]:
|
||||
nodes[u].append(v)
|
||||
nodes[v].append(u)
|
||||
cc = []
|
||||
visited = set()
|
||||
|
||||
def dfs(start):
|
||||
result = []
|
||||
to_visit = [start]
|
||||
while to_visit:
|
||||
x = to_visit.pop()
|
||||
result.append(x)
|
||||
visited.add(x)
|
||||
for y in nodes[x]:
|
||||
if y not in visited:
|
||||
to_visit.append(y)
|
||||
return result
|
||||
|
||||
for u in nodes.keys():
|
||||
if u in visited:
|
||||
continue
|
||||
cc.append(dfs(u))
|
||||
return cc
|
||||
|
||||
def topological_sort(self) -> list:
|
||||
result = []
|
||||
statuses = dict()
|
||||
|
||||
def dfs(node):
|
||||
if statuses.get(node) == _Visiting:
|
||||
raise RuntimeError("Graph is not acyclic")
|
||||
if statuses.get(node) == _Visited:
|
||||
return
|
||||
|
||||
statuses[node] = _Visiting
|
||||
for y in self.nodes.get(node, set()):
|
||||
dfs(y)
|
||||
|
||||
statuses[node] = _Visited
|
||||
result.insert(0, node)
|
||||
|
||||
for x in self.nodes:
|
||||
dfs(x)
|
||||
|
||||
return result
|
||||
Loading…
Reference in New Issue