tp.py: Add type annotations (#190)

This commit is contained in:
Jcw87 2022-05-07 11:38:20 -07:00 committed by GitHub
parent d91b1294bb
commit 88be83ca43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 26 additions and 26 deletions

View File

@ -14,7 +14,7 @@ import multiprocessing as mp
import shutil
from dataclasses import dataclass, field
from typing import Dict
from typing import Dict, List, Set, Tuple
from pathlib import Path
try:
@ -93,7 +93,7 @@ def tp():
default=DEFAULT_EXPECTED_PATH,
required=True,
)
def expected_copy(debug, build_path, expected_path):
def expected_copy(debug: bool, build_path: Path, expected_path: Path):
"""Copy the current build folder to the expected folder"""
if debug:
@ -127,7 +127,7 @@ def expected_copy(debug, build_path, expected_path):
default=DEFAULT_TOOLS_PATH,
required=True,
)
def setup(debug, game_path, tools_path):
def setup(debug: bool, game_path: Path, tools_path: Path):
"""Setup project"""
if debug:
@ -224,7 +224,7 @@ def setup(debug, game_path, tools_path):
c27_mwcceppc_orignal = c27.joinpath("mwcceppc.exe")
c27_mwcceppc_patched = c27.joinpath("mwcceppc_patched.exe")
def patch_compiler(src, dst, apply):
def patch_compiler(src: Path, dst: Path, apply: bool):
with src.open("rb") as src_file:
with dst.open("wb") as dst_file:
data = bytearray(src_file.read())
@ -308,7 +308,7 @@ def setup(debug, game_path, tools_path):
default=DEFAULT_BUILD_PATH,
required=True,
)
def progress(debug, matching, format, print_rels, build_path):
def progress(debug: bool, matching: bool, format: str, print_rels: bool, build_path: Path):
"""Calculate decompilation progress"""
if debug:
@ -340,7 +340,7 @@ def progress(debug, matching, format, print_rels, build_path):
default=DEFAULT_BUILD_PATH,
required=True,
)
def check(debug, rels, game_path, build_path):
def check(debug: bool, rels: bool, game_path: Path, build_path: Path):
"""Compare SHA1 Checksums"""
if debug:
@ -387,7 +387,7 @@ class ProgressGroup:
return 100 * (self.decompiled / self.size)
def calculate_rel_progress(build_path, matching, format, asm_files, ranges):
def calculate_rel_progress(build_path: Path, matching: bool, format: str, asm_files: Set[Path], ranges: List[Tuple[int, int]]):
results = []
start = time.time()
rel_paths = get_files_with_ext(build_path.joinpath("rel"), ".rel")
@ -422,7 +422,7 @@ def calculate_rel_progress(build_path, matching, format, asm_files, ranges):
return results
def calculate_dol_progress(build_path, matching, format, asm_files, ranges):
def calculate_dol_progress(build_path: Path, matching: bool, format: str, asm_files: Set[Path], ranges: List[Tuple[int, int]]):
# read .dol file
dol_path = build_path.joinpath("main.dol")
if not dol_path.exists():
@ -479,7 +479,7 @@ def calculate_dol_progress(build_path, matching, format, asm_files, ranges):
return ProgressGroup("main.dol", total_size, total_decompiled_size, sections)
def calculate_progress(build_path, matching, format, print_rels):
def calculate_progress(build_path: Path, matching: bool, format: str, print_rels: bool):
if not matching:
LOG.error("non-matching progress is not support yet.")
sys.exit(1)
@ -695,7 +695,7 @@ def calculate_progress(build_path, matching, format, print_rels):
def find_function_range(asm):
def find_function_range(asm: Path) -> Tuple[int, int]:
with asm.open("r", encoding="utf-8") as file:
lines = file.readlines()
for line in lines:
@ -711,7 +711,7 @@ def find_function_range(asm):
return (fast_first, fast_last)
def find_function_ranges(asm_files):
def find_function_ranges(asm_files: Set[Path]):
if len(asm_files) < 128:
return [find_function_range(x) for x in asm_files]
@ -734,7 +734,7 @@ def find_function_ranges(asm_files):
help="Remove all of the asm that is decompiled and not used anymore",
)
@click.option("--check", default=False, is_flag=True)
def remove_unused_asm_cmd(check):
def remove_unused_asm_cmd(check: bool):
result = remove_unused_asm(check)
if check:
if result == 0:
@ -744,7 +744,7 @@ def remove_unused_asm_cmd(check):
sys.exit(1)
def remove_unused_asm(check):
def remove_unused_asm(check: bool):
unused_files, error_files = find_unused_asm_files(False, use_progress_bar=not check)
if not check:
@ -789,7 +789,7 @@ def remove_unused_asm(check):
default=DEFAULT_BUILD_PATH,
required=True,
)
def pull_request(debug, rels, thread_count, game_path, build_path):
def pull_request(debug: bool, rels: bool, thread_count: int, game_path: Path, build_path: Path):
"""Verify that everything is OK before pull-request"""
if debug:
@ -846,13 +846,13 @@ def pull_request(debug, rels, thread_count, game_path, build_path):
calculate_progress(build_path, True, "FANCY", rels)
def find_all_asm_files():
def find_all_asm_files() -> Tuple[Set[Path], Set[Path]]:
"""Recursivly find all files in the 'asm/' folder"""
files = set()
errors = set()
def recursive(parent):
def recursive(parent: Path):
paths = sorted(
parent.iterdir(),
key=lambda path: (path.is_file(), path.name.lower()),
@ -879,7 +879,7 @@ def find_all_asm_files():
return files, errors
def find_unused_asm_files(non_matching, use_progress_bar=True):
def find_unused_asm_files(non_matching: bool, use_progress_bar: bool = True):
"""Search for unused asm function files."""
asm_files, error_files = find_all_asm_files()
@ -893,12 +893,12 @@ def find_unused_asm_files(non_matching, use_progress_bar=True):
return unused_asm_files, error_files
def find_all_header_files():
def find_all_header_files() -> Set[Path]:
"""Recursivly find all files in the 'include/' folder"""
files = set()
def recursive(parent):
def recursive(parent: Path):
paths = sorted(
parent.iterdir(),
key=lambda path: (path.is_file(), path.name.lower()),
@ -921,12 +921,12 @@ def find_all_header_files():
return files
def find_all_files():
def find_all_files() -> Set[Path]:
"""Recursively find all c/cpp files in '/src/', '/libs/', and '/rel/' """
files = set()
def recursive(parent):
def recursive(parent: Path):
paths = sorted(
parent.iterdir(),
key=lambda path: (path.is_file(), path.name.lower()),
@ -956,7 +956,7 @@ def find_all_files():
return files
def find_includes(lines, non_matching, ext=".s"):
def find_includes(lines: List[str], non_matching: bool, ext: str = ".s") -> Set[Path]:
includes = set()
for line in lines:
key = '#include "'
@ -976,7 +976,7 @@ def find_includes(lines, non_matching, ext=".s"):
return includes
def find_used_asm_files(non_matching, use_progress_bar=True):
def find_used_asm_files(non_matching: bool, use_progress_bar: bool = True) -> Set[Path]:
cpp_files = find_all_files()
includes = set()
@ -1003,7 +1003,7 @@ def find_used_asm_files(non_matching, use_progress_bar=True):
return includes
def rebuild(thread_count, include_rels):
def rebuild(thread_count: int, include_rels: bool):
LOG.debug("make clean")
with Progress(console=CONSOLE, transient=True, refresh_per_second=5) as progress:
task = progress.add_task(f"make clean", total=1000, start=False)
@ -1068,7 +1068,7 @@ def sha1_from_data(data):
return sha1.hexdigest().upper()
def get_files_with_ext(path, ext):
def get_files_with_ext(path: Path, ext: str):
return [x for x in path.glob(f"**/*{ext}") if x.is_file()]
@ -1076,7 +1076,7 @@ class CheckException(Exception):
...
def check_sha1(game_path, build_path, include_rels):
def check_sha1(game_path: Path, build_path: Path, include_rels: bool):
if include_rels:
rel_path = game_path.joinpath("rel/Final/Release")
if not rel_path.exists():