mirror of https://github.com/zeldaret/tp.git
490 lines
14 KiB
Python
490 lines
14 KiB
Python
"""
|
|
|
|
tp.py - Various tools used for the zeldaret/tp project
|
|
|
|
progress: Calculates decompilation progress. By assuming that the code was generated by 'dol2asm'
|
|
and that all non-code sections are fully decompiled. The script calculate the amount of bytes
|
|
that are left to decompile (all code in the .s files).
|
|
|
|
pull-request: Helps people make sure that everything is OK before making pull-requests.
|
|
The script does three things: remove unused asm files, rebuild the full project, and
|
|
clang-format every file.
|
|
|
|
"""
|
|
|
|
import click
|
|
import sys
|
|
import os
|
|
import rich
|
|
import logging
|
|
import subprocess
|
|
import time
|
|
import hashlib
|
|
import json
|
|
import git
|
|
import libdol
|
|
|
|
from pathlib import Path
|
|
from rich.logging import RichHandler
|
|
from rich.console import Console
|
|
from rich.progress import Progress
|
|
from rich.text import Text
|
|
from rich.table import Table
|
|
|
|
import multiprocessing as mp
|
|
|
|
VERSION = "1.0"
|
|
CONSOLE = Console()
|
|
|
|
logging.basicConfig(
|
|
level="NOTSET",
|
|
format="%(message)s",
|
|
datefmt="[%X]",
|
|
handlers=[RichHandler(console=CONSOLE, rich_tracebacks=True)]
|
|
)
|
|
|
|
LOG = logging.getLogger("rich")
|
|
LOG.setLevel(logging.INFO)
|
|
|
|
@click.group()
|
|
@click.version_option(VERSION)
|
|
def tp():
|
|
""" Tools to help the decompilation of "The Legend of Zelda: Twilight Princess" """
|
|
pass
|
|
|
|
|
|
@tp.command(name="progress")
|
|
@click.option('--debug/--no-debug')
|
|
@click.option('--matchning/--no-matching', default=True, is_flag=True)
|
|
@click.option('--format', '-f', default="FANCY", type=click.Choice(['FANCY', 'CSV', 'JSON-SHIELD'], case_sensitive=False))
|
|
def progress(debug, matchning, format):
|
|
""" Calculate decompilation progress """
|
|
|
|
if debug:
|
|
LOG.setLevel(logging.DEBUG)
|
|
|
|
calculate_progress(matchning, format)
|
|
|
|
|
|
def calculate_progress(matchning, format):
|
|
if not matchning:
|
|
LOG.error("non-matching progress is not support yet.")
|
|
sys.exit(1)
|
|
|
|
# read .dol file
|
|
dol_path = Path("build/dolzel2/main.dol")
|
|
if not dol_path.exists():
|
|
LOG.error(f"Unable to read '{dol_path}'")
|
|
sys.exit(1)
|
|
|
|
with dol_path.open("rb") as file:
|
|
data = file.read()
|
|
dol = libdol.read(data)
|
|
|
|
# get section sizes
|
|
total_size = len(data)
|
|
format_size = 0x100
|
|
|
|
init = dol.get_named_section(".init")
|
|
assert init
|
|
init_decompiled_size = init.size
|
|
|
|
text = dol.get_named_section(".text")
|
|
assert text
|
|
text_decompiled_size = text.size
|
|
|
|
data_sections = [
|
|
section
|
|
for section in dol.sections
|
|
if section.data and not section.addr in {init.addr, text.addr}
|
|
]
|
|
|
|
data_size = sum([section.size for section in data_sections])
|
|
|
|
# find all _used_ asm files
|
|
asm_files = find_used_asm_files(not matchning)
|
|
|
|
# calculate the range each asm file occupies
|
|
ranges = find_function_ranges(asm_files)
|
|
|
|
LOG.debug(f"init {init.addr:08X}-{init.addr + init.size:08X}")
|
|
LOG.debug(f"text {text.addr:08X}-{text.addr + text.size:08X}")
|
|
|
|
# substract the size of each asm function
|
|
for function_range in ranges:
|
|
if function_range[0] >= init.addr and function_range[1] < init.addr + init.size:
|
|
init_decompiled_size -= (function_range[1] - function_range[0])
|
|
elif function_range[0] >= text.addr and function_range[1] < text.addr + text.size:
|
|
text_decompiled_size -= (function_range[1] - function_range[0])
|
|
|
|
# calculate the progress
|
|
init_result = init_decompiled_size / init.size
|
|
text_result = text_decompiled_size / text.size
|
|
total_decompiled_size = (init_decompiled_size +
|
|
text_decompiled_size + data_size + format_size)
|
|
total_result = total_decompiled_size / total_size
|
|
|
|
init_pct = 100 * init_result
|
|
text_pct = 100 * text_result
|
|
total_pct = 100 * total_result
|
|
|
|
if format == "FANCY":
|
|
table = Table(title="main.dol")
|
|
table.add_column("Section", justify="right",
|
|
style="cyan", no_wrap=True)
|
|
table.add_column("Percentage", style="green")
|
|
table.add_column("Decompiled (bytes)",
|
|
justify="right", style="bright_yellow")
|
|
table.add_column("Total (bytes)", justify="right",
|
|
style="bright_magenta")
|
|
|
|
table.add_row(".init", f"{init_pct:10.6f}%",
|
|
f"{init_decompiled_size}", f"{init.size}")
|
|
table.add_row(".text", f"{text_pct:10.6f}%",
|
|
f"{text_decompiled_size}", f"{text.size}")
|
|
table.add_row("total", f"{total_pct:10.6f}%",
|
|
f"{total_decompiled_size}", f"{total_size}")
|
|
CONSOLE.print(table)
|
|
elif format == "CSV":
|
|
version = 1
|
|
git_object = git.Repo().head.object
|
|
timestamp = str(git_object.committed_date)
|
|
git_hash = git_object.hexsha
|
|
data = [
|
|
str(version), timestamp, git_hash,
|
|
str(init_decompiled_size), str(init.size),
|
|
str(text_decompiled_size), str(text.size),
|
|
str(total_decompiled_size), str(total_size),
|
|
|
|
]
|
|
print(",".join(data))
|
|
elif format == "JSON-SHIELD":
|
|
# https://shields.io/endpoint
|
|
print(json.dumps({
|
|
"schemaVersion": 1,
|
|
"label": "progress",
|
|
"message": f"{total_pct:.3g}%",
|
|
"color": 'yellow',
|
|
}))
|
|
else:
|
|
print(init_pct, text_pct, total_pct)
|
|
LOG.error("unknown format: '{format}'")
|
|
|
|
|
|
def find_function_ranges(asm_files):
|
|
function_ranges = []
|
|
for asm in asm_files:
|
|
with asm.open('r') as file:
|
|
first = None
|
|
last = None
|
|
for line in file.readlines():
|
|
line_start = line.find("/* ")
|
|
line_end = line.find(" */", 3)
|
|
|
|
if line_start < 0 or line_end < 0:
|
|
continue
|
|
|
|
line_values = line[line_start+3:line_end].split(" ")
|
|
assert len(line_values) == 6
|
|
addr = int(line_values[0], 16)
|
|
if not first:
|
|
first = addr
|
|
last = addr + 4
|
|
|
|
function_ranges.append((first, last))
|
|
|
|
return function_ranges
|
|
|
|
|
|
@tp.command(name="pull-request")
|
|
@click.option('--debug/--no-debug')
|
|
@click.option('--thread-count', '-j', 'thread_count', help="Thread that should be used. This option is passed forward to any 'make' command.", default=4)
|
|
def pull_request(debug, thread_count):
|
|
""" Verify that everything is OK before pull-request """
|
|
|
|
if debug:
|
|
LOG.setLevel(logging.DEBUG)
|
|
|
|
text = Text("Pull-Request Checklist:")
|
|
text.stylize("bold")
|
|
CONSOLE.print(text)
|
|
|
|
#
|
|
text = Text("--- Removing Unused '.s' Files")
|
|
text.stylize("bold magenta")
|
|
CONSOLE.print(text)
|
|
|
|
unused_files, error_files = find_unused_asm_files(False)
|
|
|
|
for unused_file in unused_files:
|
|
unused_file.unlink()
|
|
CONSOLE.print(f"removed '{unused_file}'")
|
|
|
|
text = Text(" OK")
|
|
text.stylize("bold green")
|
|
CONSOLE.print(text)
|
|
|
|
#
|
|
text = Text("--- Full Rebuild")
|
|
text.stylize("bold magenta")
|
|
CONSOLE.print(text)
|
|
|
|
if rebuild(thread_count):
|
|
text = Text(" OK")
|
|
text.stylize("bold green")
|
|
CONSOLE.print(text)
|
|
else:
|
|
text = Text(" ERR")
|
|
text.stylize("bold red")
|
|
CONSOLE.print(text)
|
|
sys.exit(1)
|
|
|
|
#
|
|
text = Text("--- Clang-Format")
|
|
text.stylize("bold magenta")
|
|
CONSOLE.print(text)
|
|
|
|
if clang_format(thread_count):
|
|
text = Text(" OK")
|
|
text.stylize("bold green")
|
|
CONSOLE.print(text)
|
|
else:
|
|
text = Text(" ERR")
|
|
text.stylize("bold red")
|
|
CONSOLE.print(text)
|
|
sys.exit(1)
|
|
|
|
#
|
|
text = Text("--- Calculate Progress")
|
|
text.stylize("bold magenta")
|
|
CONSOLE.print(text)
|
|
|
|
calculate_progress(True, "FANCY")
|
|
|
|
|
|
def find_all_asm_files():
|
|
""" Recursivly find all files in the 'asm/' folder """
|
|
|
|
files = set()
|
|
errors = set()
|
|
|
|
def recursive(parent):
|
|
paths = sorted(
|
|
parent.iterdir(),
|
|
key=lambda path: (path.is_file(), path.name.lower()),
|
|
)
|
|
for path in paths:
|
|
if path.name.startswith("."):
|
|
continue
|
|
if path.is_dir():
|
|
recursive(path)
|
|
else:
|
|
if path.suffix == '.s':
|
|
files.add(path)
|
|
else:
|
|
errors.add(path)
|
|
|
|
root = Path("./asm/")
|
|
assert root.exists()
|
|
|
|
recursive(root)
|
|
|
|
LOG.debug(
|
|
f"find_all_asm_files: found {len(files)} .s files and {len(errors)} bad files")
|
|
return files, errors
|
|
|
|
|
|
def find_unused_asm_files(non_matching):
|
|
""" Search for unused asm function files. """
|
|
|
|
asm_files, error_files = find_all_asm_files()
|
|
included_asm_files = find_used_asm_files(non_matching)
|
|
|
|
unused_asm_files = asm_files - included_asm_files
|
|
LOG.debug(
|
|
f"find_unused_asm_files: found {len(unused_asm_files)} unused .s files")
|
|
|
|
return unused_asm_files, error_files
|
|
|
|
|
|
def find_all_header_files():
|
|
""" Recursivly find all files in the 'include/' folder """
|
|
|
|
files = set()
|
|
|
|
def recursive(parent):
|
|
paths = sorted(
|
|
parent.iterdir(),
|
|
key=lambda path: (path.is_file(), path.name.lower()),
|
|
)
|
|
for path in paths:
|
|
# Remove hidden files
|
|
if path.name.startswith("."):
|
|
continue
|
|
if path.is_dir():
|
|
recursive(path)
|
|
else:
|
|
if path.suffix == '.h':
|
|
files.add(path)
|
|
|
|
root = Path("./include/")
|
|
assert root.exists()
|
|
recursive(root)
|
|
|
|
LOG.debug(f"find_all_header_files: found {len(files)} .h files")
|
|
return files
|
|
|
|
|
|
def find_all_cpp_files():
|
|
""" Recursivly find all files in the 'cpp/' folder """
|
|
|
|
files = set()
|
|
|
|
def recursive(parent):
|
|
paths = sorted(
|
|
parent.iterdir(),
|
|
key=lambda path: (path.is_file(), path.name.lower()),
|
|
)
|
|
for path in paths:
|
|
# Remove hidden files
|
|
if path.name.startswith("."):
|
|
continue
|
|
if path.is_dir():
|
|
recursive(path)
|
|
else:
|
|
if path.suffix == '.cpp':
|
|
files.add(path)
|
|
|
|
src_root = Path("./src/")
|
|
libs_root = Path("./libs/")
|
|
rel_root = Path("./rel/")
|
|
assert src_root.exists()
|
|
assert libs_root.exists()
|
|
assert rel_root.exists()
|
|
|
|
recursive(src_root)
|
|
recursive(libs_root)
|
|
recursive(rel_root)
|
|
|
|
LOG.debug(f"find_all_cpp_files: found {len(files)} .cpp files")
|
|
return files
|
|
|
|
|
|
def find_includes(lines, non_matching, ext=".s"):
|
|
includes = set()
|
|
for line in lines:
|
|
key = '#include "'
|
|
start = line.find(key)
|
|
if start < 0:
|
|
continue
|
|
|
|
start += len(key)
|
|
end = line.find('"', start)
|
|
if end < 0:
|
|
continue
|
|
|
|
include_path = line[start:end]
|
|
if include_path.endswith(ext):
|
|
includes.add(Path(include_path))
|
|
|
|
return includes
|
|
|
|
|
|
def find_used_asm_files(non_matching):
|
|
|
|
cpp_files = find_all_cpp_files()
|
|
includes = set()
|
|
|
|
with Progress(console=CONSOLE, transient=True, refresh_per_second=1) as progress:
|
|
task = progress.add_task(f"preprocessing...", total=len(cpp_files))
|
|
|
|
for cpp_file in cpp_files:
|
|
with cpp_file.open("r") as file:
|
|
includes.update(find_includes(file.readlines(), non_matching))
|
|
|
|
progress.update(task, advance=1)
|
|
|
|
# TODO: NON_MATCHING
|
|
|
|
LOG.debug(f"find_used_asm_files: found {len(includes)} included .s files")
|
|
|
|
return includes
|
|
|
|
|
|
def clang_format_impl(file):
|
|
cmd = ["clang-format", "-i", str(file)]
|
|
cf = subprocess.run(args=cmd, stdout=subprocess.PIPE,
|
|
stderr=subprocess.PIPE)
|
|
|
|
|
|
def clang_format(thread_count):
|
|
|
|
cpp_files = find_all_cpp_files()
|
|
h_files = find_all_header_files()
|
|
files = cpp_files | h_files
|
|
|
|
with mp.Pool(processes=2 * thread_count) as pool:
|
|
result = pool.map_async(clang_format_impl, files)
|
|
jobs_left = len(files)
|
|
with Progress(console=CONSOLE, transient=True, refresh_per_second=5) as progress:
|
|
task = progress.add_task(f"clang-formating...", total=len(files))
|
|
|
|
while result._number_left > 0:
|
|
left = result._number_left * result._chunksize
|
|
change = jobs_left - left
|
|
jobs_left = left
|
|
progress.update(
|
|
task, description=f"clang-formating... ({left} left)", advance=change)
|
|
time.sleep(1/5)
|
|
|
|
progress.update(task, advance=jobs_left)
|
|
|
|
return True
|
|
|
|
|
|
def rebuild(thread_count):
|
|
LOG.debug("make clean")
|
|
with Progress(console=CONSOLE, transient=True, refresh_per_second=5) as progress:
|
|
task = progress.add_task(f"make clean", total=1000, start=False)
|
|
|
|
cmd = ["make", f"-j{thread_count}", "clean"]
|
|
subprocess.run(args=cmd, stdout=subprocess.PIPE,
|
|
stderr=subprocess.PIPE)
|
|
LOG.debug("make clean complete")
|
|
|
|
LOG.debug("make main.dol")
|
|
with Progress(console=CONSOLE, transient=True, refresh_per_second=5) as progress:
|
|
task = progress.add_task(f"make", total=1000, start=False)
|
|
|
|
cmd = ["make", f"-j{thread_count}", "build/dolzel2/main.dol"]
|
|
subprocess.run(args=cmd, stdout=subprocess.PIPE,
|
|
stderr=subprocess.PIPE)
|
|
LOG.debug("make main.dol complete")
|
|
|
|
dol = Path("build/dolzel2/main.dol")
|
|
if not dol.exists():
|
|
return False
|
|
|
|
with dol.open("rb") as file:
|
|
data = file.read()
|
|
|
|
# TODO: move?
|
|
expected = "4997D93B9692620C40E90374A0F1DBF0E4889395"
|
|
|
|
sha1 = hashlib.sha1()
|
|
sha1.update(data)
|
|
current = sha1.hexdigest().upper()
|
|
|
|
LOG.debug(f"expected: '{expected}'")
|
|
LOG.debug(f"current: '{current}'")
|
|
|
|
if expected != current:
|
|
LOG.error("main.dol is not OK!")
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
if __name__ == "__main__":
|
|
tp()
|