#!/usr/bin/env python3
# PYTHON_ARGCOMPLETE_OK

"""
This script will extract literals and strings data from sections located in the baserom.dol.
Useful when trying to match .rodata and .sdata2.

Usage: ./tools/section2cpp.py --section .rodata --string --object JKRSolidHeap.o

The generated C++ is printed to stdout; errors and warnings are printed to
stderr so the output can be redirected into a source file safely.
"""

import argparse
import sys
import os
import struct

from decimal import getcontext, Decimal
from pathlib import Path, PurePath, PureWindowsPath
from typing import (
    Any,
    Dict,
    List,
    Match,
    NamedTuple,
    NoReturn,
    Optional,
    Set,
    Tuple,
    Union,
    Callable,
    Pattern,
)

try:
    import numpy
except ImportError:
    print("error: missing numpy", file=sys.stderr)
    sys.exit(1)

try:
    import argcomplete  # type: ignore
except ModuleNotFoundError:
    argcomplete = None

parser = argparse.ArgumentParser(description="Extract section data and generate C++ code.")
parser.add_argument(
    "--section",
    dest="section",
    type=str,
    metavar="SECTION",
    help="SECTION to extract data from.",
    required=True,
)
# The value is subtracted from each symbol address to obtain its position in
# the baserom file (see _symbol_literals).
parser.add_argument(
    "--file-offset",
    dest="file_offset",
    type=lambda x: int(x, 0),
    metavar="OFFSET",
    help="OFFSET in the baserom for the SECTION.",
)
parser.add_argument(
    "--object",
    dest="object_name",
    type=str,
    metavar="OBJECT",
    help="OBJECT filename to extract data from. (e.g. JKRSolidHeap.o)",
)
parser.add_argument(
    "--map",
    dest="map_path",
    type=str,
    metavar="MAP",
    help="frameworkF.map path",
    default="frameworkF.map",
)
parser.add_argument(
    "--baserom",
    dest="baserom",
    type=str,
    metavar="DOL",
    help="baserom.dol path",
    default="baserom.dol",
)
parser.add_argument(
    "--string",
    dest="as_string",
    action="store_true",
    help="Print arrays as strings",
)
parser.add_argument(
    "--array",
    dest="as_array",
    action="store_true",
    help="Print everything as u8 arrays",
)
parser.add_argument(
    "--shift-jis",
    dest="shift_jis",
    action="store_true",
    help="Convert shift-jis to utf-8",
)

#
#
#

# Characters that have a named C/C++ escape sequence (or must be escaped
# inside a double-quoted literal).
_NAMED_ESCAPES: Dict[str, str] = {
    "\n": "\\n",
    "\t": "\\t",
    "\v": "\\v",
    "\b": "\\b",
    "\r": "\\r",
    "\f": "\\f",
    "\a": "\\a",
    "\\": "\\\\",
    "\"": "\\\"",
}


def _fatal(message: str) -> NoReturn:
    """Print 'error: <message>' to stderr and exit with status 1."""
    print("error: %s" % message, file=sys.stderr)
    sys.exit(1)


def _itersplit(items, splitters):
    """Yield the runs of *items* between elements contained in *splitters*.

    Always yields one final run (possibly empty) after the last splitter.
    """
    current = []
    for item in items:
        if item in splitters:
            yield current
            current = []
        else:
            current.append(item)
    yield current


def magicsplit(l, *splitters):
    """Split the iterable *l* on any of *splitters*, returning a list of lists."""
    return list(_itersplit(l, splitters))


def str_encoding(data: bytes) -> Optional[str]:
    """Classify a NUL-terminated byte string.

    Returns "utf-8" if it decodes as UTF-8, otherwise "shift-jis" if it
    decodes as Shift_JIS-2004, otherwise None. Data that is empty or not
    NUL-terminated also yields None.
    """
    if not data or data[-1] != 0:
        return None
    for codec, name in (("utf-8", "utf-8"), ("shift_jisx0213", "shift-jis")):
        try:
            data.decode(codec)
        except UnicodeDecodeError:
            continue
        return name
    return None


def _escape_byte(byte: int) -> str:
    """Return one raw byte as C string-literal text.

    Printable ASCII is kept; named escapes are used where available; every
    other byte becomes a three-digit octal escape.
    """
    if 32 <= byte < 127 or chr(byte) in _NAMED_ESCAPES:
        return escape_char(chr(byte))
    return "\\%03o" % byte


def encoding_char_list(encoding, data, shift_jis=None):
    """Return *data* as a list of C string-literal fragments, one per character.

    Each list element is a complete unit (a character or a whole escape
    sequence), so the list can be chunked across lines without splitting an
    escape.

    Args:
        encoding: Result of str_encoding() for *data*, or None.
        data: Raw bytes, without the terminating NUL.
        shift_jis: Decode Shift-JIS text to Unicode instead of escaping its
            bytes. Defaults to the --shift-jis command-line flag when
            arguments have been parsed, otherwise False.
    """
    if shift_jis is None:
        shift_jis = args is not None and args.shift_jis
    if shift_jis and encoding == "shift-jis":
        try:
            return [escape_char(c) for c in data.decode("shift_jisx0213")]
        except UnicodeDecodeError:
            pass  # fall back to escaping the raw bytes
    return [_escape_byte(x) for x in data]


def raw_string(data):
    """Concatenate a sequence of string fragments."""
    return "".join(data)


def raw_array(data):
    """Format *data* as comma-separated hex numbers (no padding)."""
    return ",".join([hex(x) for x in list(data)])


def escape_char(v):
    """Return the single character *v* as it must appear inside a C string literal.

    Named escapes are used where they exist. Other control characters (below
    0x20, and DEL) become three-digit octal escapes; octal is used because a
    hex escape would absorb any hex digit that follows it. Everything else,
    including non-ASCII characters, is returned unchanged.
    """
    named = _NAMED_ESCAPES.get(v)
    if named is not None:
        return named
    code = ord(v)
    if code < 32 or code == 127:
        return "\\%03o" % code
    return v


def escape(v):
    """Escape every character of the string *v* with escape_char()."""
    return "".join(escape_char(x) for x in v)


def _bytes2float(data, size, dtype):
    """Decode the first *size* bytes of *data* as a big-endian float of *dtype*.

    Returns None if *data* is shorter than *size*.
    """
    if len(data) < size:
        return None
    return numpy.frombuffer(bytes(data[:size]), dtype=dtype)[0]


def bytes2float32(data):
    """Decode the first 4 bytes of *data* as a big-endian float32 (None if too short)."""
    return _bytes2float(data, 4, ">f4")


def bytes2float64(data):
    """Decode the first 8 bytes of *data* as a big-endian float64 (None if too short)."""
    return _bytes2float(data, 8, ">f8")


def _is_nice_float(f):
    """Return True if *f* has at most three decimal places.

    Non-numeric, NaN and infinite values (and values that overflow when
    scaled) return False.
    """
    try:
        return any(int(f * scale) == f * scale for scale in (1000, 100, 10, 1))
    except (TypeError, ValueError, OverflowError):
        return False


def is_nice_float32(f):
    """Return True if the float32 *f* is readable as a short decimal literal."""
    return _is_nice_float(f)


def is_nice_float64(f):
    """Return True if the float64 *f* is readable as a short decimal literal."""
    return _is_nice_float(f)


def _exact_fraction_table(make_float: Callable[[Decimal], Any]) -> Dict[Any, Tuple[int, int]]:
    """Map floats equal to i / j (1 <= i, j < 32) to the pair (i, j).

    Integral quotients are skipped. Quotients whose float prints exactly like
    the decimal value are also skipped, since they read fine as plain
    literals. The first (i, j) pair found for a value wins.
    """
    table: Dict[Any, Tuple[int, int]] = {}
    for i in range(1, 32):
        for j in range(1, 32):
            if i % j == 0:
                continue
            d = Decimal(i) / Decimal(j)
            f = make_float(d)
            if str(f) != str(d) and f not in table:
                table[f] = (i, j)
    return table


getcontext().prec = 64
float32_exact: Dict[numpy.float32, Tuple[int, int]] = _exact_fraction_table(numpy.float32)
float64_exact: Dict[numpy.float64, Tuple[int, int]] = _exact_fraction_table(numpy.float64)


class Symbol:
    """A symbol from the map file.

    Attributes:
        name: Symbol name.
        addr: Start address.
        size: Size in bytes.
        padding: Gap in bytes between the end of this symbol and whatever
            follows it.
    """

    def __init__(self, name, addr, size):
        self.name = name
        self.addr = addr
        self.size = size
        self.padding = 0

    def __str__(self):
        return "    %s %s %s+%s %s" % (
            self.name.ljust(40, " "),
            hex(self.addr),
            hex(self.addr + self.size),
            hex(self.padding),
            hex(self.size),
        )


class ObjectFile:
    """The symbols one object file contributes to the selected section, in map order."""

    def __init__(self, path):
        self.path = path
        self.symbols: List[Symbol] = []
        self.start = 0
        self.end = 0
        self.mk = False

    def addSymbol(self, name, str_addr, str_size):
        """Append a symbol given hex address/size strings.

        Any gap between the previous symbol's end and this address is added
        to the previous symbol's padding. The first symbol sets self.start.
        """
        addr = int(str_addr, base=16)
        size = int(str_size, base=16)
        symbol = Symbol(name, addr, size)
        if not self.symbols:
            self.start = symbol.addr
        else:
            last_symbol = self.symbols[-1]
            last_addr = last_symbol.addr + last_symbol.size
            if last_addr != addr:
                last_symbol.padding += addr - last_addr
        self.symbols.append(symbol)

    def setEnd(self, end):
        """Record the end address and pad the last symbol up to it."""
        self.end = end
        if not self.symbols:
            return
        last_symbol = self.symbols[-1]
        last_symbol.padding = self.end - (last_symbol.addr + last_symbol.size)


# Run configuration; assigned by main() from the command line.
args: Optional[argparse.Namespace] = None
section: str = ""
object_name: Optional[str] = None
file_offset: Optional[int] = None
baserom: Path = Path("baserom.dol")
map_path: Path = Path("frameworkF.map")
# Object filename -> ObjectFile, in the order the map file lists them.
object_map: Dict[str, ObjectFile] = {}


def _map_tokens(line):
    """Split a map-file line on spaces, stripping each token and dropping empty ones."""
    return [token.strip() for token in line.split(" ") if token.strip()]


def find_symbols():
    """Fill object_map with the symbols of *section* listed in the map file at *map_path*.

    Each object's last symbol is padded up to the next object's start. The
    section's final symbol is padded to a 0x20-aligned end.
    Exits with an error if no symbols are found, or if an object file
    appears in two separate runs.
    """
    with map_path.open("r") as file:
        lines = file.readlines()

    in_section = False
    last_obj = None
    for line in lines:
        data = _map_tokens(line)
        if len(data) == 3:
            # Three-token lines are treated as section headers.
            in_section = data[0] == section
            continue
        if not in_section:
            continue
        if len(data) < 6 or len(data) > 7:
            continue

        # The object filename is the last column; drop any Windows-style directory prefix.
        obj = data[-1].split("\\")[-1]
        if obj != last_obj:
            if obj in object_map:
                _fatal("object '%s' appears more than once in section %s" % (obj, section))
            object_map[obj] = ObjectFile(obj)
            last_obj = obj

        size, addr, name = data[1], data[2], data[4]
        object_map[obj].addSymbol(name, addr, size)

    if not object_map:
        _fatal("no symbols found for section '%s' in '%s'" % (section, map_path))

    objects = list(object_map.values())
    for current, following in zip(objects, objects[1:]):
        current.setEnd(following.start)

    # total size of rodata must be aligned to 0x20
    last_symbol = objects[-1].symbols[-1]
    last_addr = last_symbol.addr + last_symbol.size
    last_symbol.padding = ((last_addr + 31) & ~31) - last_addr


def chunks(lst, n):
    """Yield successive slices of *lst* of length *n* (the last may be shorter)."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


def data_as_string(data):
    """Format byte values as a ', '-separated list of two-digit hex literals."""
    return ", ".join("0x%02x" % x for x in data)


class Literal:
    """A static const C++ declaration with an optional trailing // comment."""

    def __init__(self, name, type, value, comment=None):
        self.name = name
        self.type = type
        self.value = value
        self.comment = comment

    def format(self):
        """Return the initializer expression."""
        return str(self.value)

    def lines(self):
        """Return the declaration as a list of output lines."""
        line = "static const %s %s = %s;" % (self.type, self.name, self.format())
        if self.comment:
            line = line.ljust(90, " ") + " // " + self.comment
        return [line]

    def __str__(self):
        return "\n".join(self.lines())


class Label(Literal):
    """A section header comment naming an object file."""

    def __init__(self, name):
        super().__init__(name, "", None, None)

    def lines(self):
        return ["", "", "// " + self.name]


class Float32Literal(Literal):
    """A float literal printed with an 'f' suffix."""

    def __init__(self, name, value, comment=None):
        super().__init__(name, "float", value, comment)

    def format(self):
        return "%sf" % self.value


class Float64Literal(Literal):
    """A double literal."""

    def __init__(self, name, value, comment=None):
        super().__init__(name, "double", value, comment)


class FractionFloat32Literal(Literal):
    """A float written as a division of two integers; value is an (i, j) pair."""

    def __init__(self, name, value, comment=None):
        super().__init__(name, "float", value, comment)

    def format(self):
        return "%i.0f / %i.0f" % self.value


class FractionFloat64Literal(Literal):
    """A double written as a division of two integers; value is an (i, j) pair."""

    def __init__(self, name, value, comment=None):
        super().__init__(name, "double", value, comment)

    def format(self):
        return "%i.0 / %i.0" % self.value


class U32Literal(Literal):
    def __init__(self, name, value, comment=None):
        super().__init__(name, "u32", value, comment)


class S32Literal(Literal):
    def __init__(self, name, value, comment=None):
        super().__init__(name, "s32", value, comment)


class S64Literal(Literal):
    def __init__(self, name, value, comment=None):
        super().__init__(name, "s64", value, comment)


class U64Literal(Literal):
    def __init__(self, name, value, comment=None):
        super().__init__(name, "u64", value, comment)


class ArrayLiteral(Literal):
    """A u8 array initialised from raw bytes; wraps at 16 bytes per line when long."""

    def __init__(self, name, value, comment=None):
        super().__init__(name, "u8", value, comment)

    def lines(self):
        declaration = "static const %s %s[%i]" % (self.type, self.name, len(self.value))
        one_line = "%s = { %s };" % (declaration, data_as_string(self.value))
        if len(one_line) < 90:
            lines = [one_line]
        else:
            lines = ["%s = {" % declaration]
            # Each row ends with ',' so consecutive rows stay a valid initializer list.
            lines += ["    %s," % data_as_string(chunk) for chunk in chunks(list(self.value), 16)]
            lines.append("};")
        if self.comment:
            lines[0] = lines[0].ljust(90, " ") + " // " + self.comment
        return lines


class StringLiteral(Literal):
    """A char array initialised from a NUL-terminated byte string.

    The terminating NUL is dropped from the stored value because the C
    string literal supplies it.
    """

    def __init__(self, name, encoding, value, comment=None):
        if not value or value[-1] != 0:
            raise ValueError("StringLiteral value must be NUL-terminated")
        super().__init__(name, "char", value[:-1], comment)
        self.encoding = encoding

    def lines(self):
        char_list = encoding_char_list(self.encoding, self.value)
        declaration = "static const %s %s[]" % (self.type, self.name)
        one_line = "%s = \"%s\";" % (declaration, raw_string(char_list))
        if len(one_line) < 90:
            lines = [one_line]
        else:
            # Adjacent string literals are concatenated by the compiler.
            lines = ["%s =" % declaration]
            lines += ["    \"%s\"" % raw_string(chunk) for chunk in chunks(char_list, 16)]
            lines[-1] += ";"
        if self.comment:
            lines[0] = lines[0].ljust(90, " ") + " // " + self.comment
        return lines


# Doubles that the script labels as compiler-generated int-to-float constants.
_COMPILER_DOUBLES = {
    0x4330000000000000: "u32 to float (compiler-generated)",
    0x4330000080000000: "s32 to float (compiler-generated)",
}


def _scalar_literal(label, data):
    """Try to express a 4- or 8-byte symbol as a number.

    Small integers become s32/s64. Floats that are i/j fractions become
    divisions, and floats with few decimals become plain float literals.
    Returns None if no numeric form fits.
    """
    if len(data) == 4:
        u32_data, = struct.unpack(">I", data)
        s32_data, = struct.unpack(">i", data)
        float_data = bytes2float32(data)
        if -4096 <= s32_data <= 4096:
            return S32Literal(label, s32_data)
        if float_data in float32_exact:
            return FractionFloat32Literal(
                label, float32_exact[float_data], "%sf %s" % (float_data, hex(u32_data))
            )
        if is_nice_float32(float_data):
            return Float32Literal(label, float_data, hex(u32_data))
    elif len(data) == 8:
        u64_data, = struct.unpack(">Q", data)
        s64_data, = struct.unpack(">q", data)
        double_data = bytes2float64(data)
        if u64_data in _COMPILER_DOUBLES:
            return Float64Literal(
                label, double_data, "%s | %s" % (hex(u64_data), _COMPILER_DOUBLES[u64_data])
            )
        if -4096 <= s64_data <= 4096:
            return S64Literal(label, s64_data)
        if double_data in float64_exact:
            return FractionFloat64Literal(
                label, float64_exact[double_data], "%s %s" % (double_data, hex(u64_data))
            )
        if is_nice_float64(double_data):
            return Float64Literal(label, double_data, hex(u64_data))
    return None


def _string_literals(symbol, data, padding):
    """Split a symbol's data on NUL bytes and emit one literal per string (--string mode).

    Decodable strings become StringLiterals; the rest become ArrayLiterals.
    Bytes after the last NUL are emitted as an array rather than dropped.
    Padding becomes a StringLiteral if NUL-terminated, else an ArrayLiteral.
    """
    literals: List[Literal] = []
    offset = 0
    segments = magicsplit(data, 0)
    for segment in segments[:-1]:
        str_data = bytes(segment + [0])
        encoding = str_encoding(str_data)
        str_label = "lbl_%X" % (symbol.addr + offset)
        if encoding == "shift-jis":
            literals.append(StringLiteral(str_label, "shift-jis", str_data, "TODO: shift-jis strings in Metrowerks"))
        elif encoding == "utf-8":
            literals.append(StringLiteral(str_label, "utf-8", str_data))
        else:
            literals.append(ArrayLiteral(str_label, str_data, "undecodable string"))
        offset += len(str_data)

    tail = bytes(segments[-1])
    if tail:
        literals.append(ArrayLiteral("lbl_%X" % (symbol.addr + offset), tail, "unterminated data"))

    if padding:
        padding_label = "lbl_%X" % (symbol.addr + symbol.size)
        if padding[-1] == 0:
            literals.append(StringLiteral(padding_label, None, padding, "padding"))
        else:
            literals.append(ArrayLiteral(padding_label, padding, "padding"))
    return literals


def _symbol_literals(br, br_size, symbol):
    """Read one symbol (plus its padding) from the open baserom *br* and build its literals."""
    label = "lbl_%X" % symbol.addr
    symbol_file_offset = symbol.addr - file_offset
    # Overlapping symbols yield negative padding; read() treats a negative
    # size as "read to end of file", so clamp it.
    padding_size = max(symbol.padding, 0)
    if symbol_file_offset < 0:
        _fatal("symbol %s at %s lies before the section start %s" % (symbol.name, hex(symbol.addr), hex(file_offset)))
    symbol_end = symbol_file_offset + symbol.size + padding_size
    if symbol_end > br_size:
        # Best effort: keep going with whatever bytes can still be read.
        print("error: reading outside baserom file. (%i, %i)" % (symbol_end, br_size), file=sys.stderr)

    br.seek(symbol_file_offset, os.SEEK_SET)
    data = br.read(symbol.size)
    padding = br.read(padding_size)

    if args.as_string:
        return _string_literals(symbol, data, padding)

    if args.as_array:
        lit = ArrayLiteral(label, data)
    else:
        lit = _scalar_literal(label, data) or ArrayLiteral(label, data)
    literals: List[Literal] = [lit]
    if padding:
        literals.append(ArrayLiteral("lbl_%X" % (symbol.addr + symbol.size), padding, "padding"))
    return literals


def output_cpp():
    """Print C++ declarations for the selected object file (or for all of them)."""
    if object_name:
        if object_name not in object_map:
            _fatal("%s object file not found!" % object_name)
        object_names = [object_name]
    else:
        object_names = list(object_map)

    literals: List[Literal] = []
    with baserom.open("rb") as br:
        br_size = br.seek(0, os.SEEK_END)
        for obj_name in object_names:
            literals.append(Label(obj_name))
            for symbol in object_map[obj_name].symbols:
                literals += _symbol_literals(br, br_size, symbol)

    for lit in literals:
        print(lit)


#
#
#

# Default --file-offset per section (the value subtracted from symbol addresses).
file_offsets = {
    ".rodata": 0x80003000,
    ".sdata": 0x800802A0,
    ".sdata2": 0x800811A0,
}


def main():
    """Parse the command line, validate inputs, then read the map and print C++."""
    global args, section, object_name, file_offset, baserom, map_path

    if argcomplete is not None:
        argcomplete.autocomplete(parser)
    # argparse prints its own usage/help and exits with the proper status.
    args = parser.parse_args()

    section = args.section
    object_name = args.object_name
    file_offset = args.file_offset
    baserom = Path(args.baserom)
    map_path = Path(args.map_path)

    # Compare against None so an explicit "--file-offset 0" is honoured.
    if file_offset is None:
        if section not in file_offsets:
            _fatal("missing --file-offset")
        file_offset = file_offsets[section]

    if not baserom.exists():
        _fatal("baserom '%s' not found!" % args.baserom)
    if not map_path.exists():
        _fatal("frameworkF.map '%s' not found!" % args.map_path)

    object_map.clear()
    find_symbols()
    output_cpp()


if __name__ == "__main__":
    main()