add match dsl syntax

This commit is contained in:
Alex Bates 2020-11-08 12:16:10 +00:00
parent 5890d885ee
commit 6ecf63e992
4 changed files with 150 additions and 65 deletions

View File

@ -28,6 +28,8 @@ astyle ${FILES} \
--indent-labels \
--pad-oper --pad-comma --pad-header --unpad-paren \
--attach-return-type \
--keep-one-line-blocks \
--keep-one-line-statements
# add newline at eof
find ${FILES} -exec sed -i -e '$a\' {} \;

View File

@ -13,11 +13,9 @@ MapConfig M(config) = {
.tattle = MessageID_TATTLE_KMR_12,
};
Script M(PlayMusic) = {
SI_CALL(SetMusicTrack, 0, Song_PLEASANT_PATH, 0, 8),
SI_RETURN(),
SI_END(),
};
Script M(PlayMusic) = SCRIPT({
SetMusicTrack(0, Song_PLEASANT_PATH, 0, 8)
});
ApiStatus GetGoomba(ScriptInstance* script, s32 isInitialCall) {
script->varTable[0] = get_enemy_safe(NpcId_GOOMBA);

View File

@ -5,6 +5,8 @@ from lark import Lark, exceptions, Tree, Transformer, Visitor, v_args, Token
from lark.visitors import Discard
import traceback
DEBUG_OUTPUT = None
def eprint(*args, **kwargs):
print(*args, file=stderr, **kwargs)
@ -14,6 +16,9 @@ def write(s):
#write_buf += s
print(s, end="")
if DEBUG_OUTPUT:
print(s, file=DEBUG_OUTPUT, end="")
ANSI_RED = "\033[1;31;40m"
ANSI_RESET = "\u001b[0m"
@ -26,12 +31,16 @@ def pairs(seq):
script_parser = Lark(r"""
block: "{" NEWLINE* (stmt STMT_SEP)* NEWLINE* "}"
block: "{" NEWLINE* (_block STMT_SEP*)? "}"
_block: stmt STMT_SEP _block
| stmt
?stmt: call
| label ":" [stmt] -> label_decl
| "goto" label -> label_goto
| if_stmt
| match_stmt
| "return" -> return_stmt
| "break" -> break_stmt
| "sleep" expr -> sleep_stmt
@ -40,7 +49,6 @@ script_parser = Lark(r"""
| "await" expr -> await_stmt
| lhs "=" "spawn" expr -> spawn_set_stmt
| lhs set_op expr -> set_stmt
| "const" lhs set_op expr -> set_const_stmt
| bind_stmt
| bind_set_stmt
| "unbind" -> unbind_stmt
@ -64,6 +72,18 @@ script_parser = Lark(r"""
| ">=" -> if_op_ge
| "<=" -> if_op_le
match_stmt: "match" expr "{" NEWLINE* (match_block STMT_SEP*)? "}"
match_block: match_case STMT_SEP match_block
| match_case
match_case: "else" block -> case_else
| "=="? expr block -> case_eq
| "!=" expr block -> case_ne
| ">" expr block -> case_gt
| "<" expr block -> case_lt
| ">=" expr block -> case_gt
| "<=" expr block -> case_lt
| "?" expr block -> case_flag
suspend_stmt: "suspend" control_type expr ("," control_type expr)* [","]
resume_stmt: "resume" control_type expr ("," control_type expr)* [","]
kill_stmt: "kill" control_type expr ("," control_type expr)* [","]
@ -94,9 +114,6 @@ script_parser = Lark(r"""
| "%=" -> set_op_mod
| "&=" -> set_op_and
| "|=" -> set_op_or
| ":=" -> set_op_eq_const
| ":&=" -> set_op_and_const
| ":|=" -> set_op_or_const
c_const_expr: c_const_expr_internal
c_const_expr_internal: "(" (c_const_expr_internal | NOT_PARENS)+ ")"
@ -176,7 +193,7 @@ class RootCtx(CmdCtx):
class IfCtx(CmdCtx):
pass
class SwitchCtx(CmdCtx):
class MatchCtx(CmdCtx):
def break_opcode(self, meta):
return 0x22
@ -262,7 +279,7 @@ class Compile(Transformer):
# flatten children list
flat = []
for node in tree.children:
if type(node) == list:
if type(node) is list:
flat += node
elif isinstance(node, BaseCmd):
flat.append(node)
@ -271,6 +288,11 @@ class Compile(Transformer):
else:
raise Exception(f"block statment {type(node)} is not a BaseCmd: {node}")
return flat
def _block(self, tree):
if len(tree.children) == 1:
return [tree.children[0]]
else:
return [tree.children[0], *tree.children[2]]
def call(self, tree):
# TODO: type checking etc
@ -289,6 +311,43 @@ class Compile(Transformer):
def if_op_le(self, tree): return 0x0E
def if_op_ge(self, tree): return 0x0F
def match_stmt(self, tree):
expr = tree.children[0]
cases = []
for node in tree.children[1:]:
if type(node) is list:
for el in node:
if type(el) is list:
cases += el
else:
cases.append(el)
for cmd in cases:
if isinstance(cmd, BaseCmd):
cmd.add_context(MatchCtx())
else:
raise Exception(f"uncompiled match case: {cmd}")
return [
Cmd(0x14, expr, meta=tree.meta),
*cases,
Cmd(0x24),
]
def match_block(self, tree):
if len(tree.children) == 1:
return [tree.children[0]]
else:
return [tree.children[0], *tree.children[2]]
def case_eq(self, tree): return [Cmd(0x16, tree.children[0]), *tree.children[1]]
def case_ne(self, tree): return [Cmd(0x17, tree.children[0]), *tree.children[1]]
def case_lt(self, tree): return [Cmd(0x18, tree.children[0]), *tree.children[1]]
def case_gt(self, tree): return [Cmd(0x19, tree.children[0]), *tree.children[1]]
def case_le(self, tree): return [Cmd(0x1A, tree.children[0]), *tree.children[1]]
def case_ge(self, tree): return [Cmd(0x1B, tree.children[0]), *tree.children[1]]
def case_else(self, tree): return [Cmd(0x1C), *tree.children[0]]
def case_flag(self, tree): return [Cmd(0x1F, tree.children[0]), *tree.children[1]]
def loop_stmt(self, tree):
expr = tree.children.pop(0) if len(tree.children) > 1 else 0
block = tree.children[0]
@ -400,12 +459,12 @@ class Compile(Transformer):
if not opcode:
raise CompileError(f"operation `{opcodes['__op__']}' not supported for ints", tree.meta)
return Cmd(opcode, lhs, rhs)
def set_const_stmt(self, tree):
lhs, opcodes, rhs = tree.children
opcode = opcodes.get("const", None)
if not opcode:
raise CompileError(f"operation `{opcodes['__op__']}' not supported for consts", tree.meta)
return Cmd(opcode, lhs, rhs)
# def set_const_stmt(self, tree):
# lhs, opcodes, rhs = tree.children
# opcode = opcodes.get("const", None)
# if not opcode:
# raise CompileError(f"operation `{opcodes['__op__']}' not supported for consts", tree.meta)
# return Cmd(opcode, lhs, rhs)
def set_op_eq(self, tree):
return {
"__op__": "=",
@ -581,6 +640,9 @@ def gen_line_map(source, source_line_no = 1):
# Expects output from C preprocessor on argv
if __name__ == "__main__":
if DEBUG_OUTPUT is not None:
DEBUG_OUTPUT = open(DEBUG_OUTPUT, "w")
line_no = 1
char_no = 1
file_info = []

View File

@ -326,6 +326,11 @@ class UnsupportedScript(Exception):
pass
class ScriptDSLDisassembler(ScriptDisassembler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.in_case = False
def var(self, arg):
if arg in self.symbol_map:
return self.symbol_map[arg]
@ -414,63 +419,81 @@ class ScriptDSLDisassembler(ScriptDisassembler):
elif opcode == 0x13:
self.indent -= 1
self.write_line("}")
# elif opcode == 0x14:
# self.write_line(f"SI_SWITCH({self.var(argv[0])}),")
# self.indent += 2
elif opcode == 0x14:
self.write_line(f"match {self.var(argv[0])} {{")
self.indent += 2
self.in_case = False
# elif opcode == 0x15:
# self.write_line(f"SI_SWITCH_CONST(0x{argv[0]:X}),")
# self.indent += 2
# elif opcode == 0x16:
# self.indent -= 1
# self.write_line(f"SI_CASE_EQ({self.var(argv[0])}),")
# self.indent += 1
# elif opcode == 0x17:
# self.indent -= 1
# self.write_line(f"SI_CASE_NE({self.var(argv[0])}),")
# self.indent += 1
# elif opcode == 0x18:
# self.indent -= 1
# self.write_line(f"SI_CASE_LT({self.var(argv[0])}),")
# self.indent += 1
# elif opcode == 0x19:
# self.indent -= 1
# self.write_line(f"SI_CASE_GT({self.var(argv[0])}),")
# self.indent += 1
# elif opcode == 0x1A:
# self.indent -= 1
# self.write_line(f"SI_CASE_LE({self.var(argv[0])}),")
# self.indent += 1
# elif opcode == 0x1B:
# self.indent -= 1
# self.write_line(f"SI_CASE_GE({self.var(argv[0])}),")
# self.indent += 1
# elif opcode == 0x1C:
# self.indent -= 1
# self.write_line(f"SI_CASE_DEFAULT(),")
# self.indent += 1
elif opcode == 0x16:
self.indent -= 1
if self.in_case: self.write_line("}")
self.in_case = True
self.write_line(f"{self.var(argv[0])} {{")
self.indent += 1
elif opcode == 0x17:
self.indent -= 1
if self.in_case: self.write_line("}")
self.in_case = True
self.write_line(f"!= {self.var(argv[0])} {{")
self.indent += 1
elif opcode == 0x18:
self.indent -= 1
if self.in_case: self.write_line("}")
self.in_case = True
self.write_line(f"< {self.var(argv[0])} {{")
self.indent += 1
elif opcode == 0x19:
self.indent -= 1
if self.in_case: self.write_line("}")
self.in_case = True
self.write_line(f"> {self.var(argv[0])} {{")
self.indent += 1
elif opcode == 0x1A:
self.indent -= 1
if self.in_case: self.write_line("}")
self.in_case = True
self.write_line(f"<= {self.var(argv[0])} {{")
self.indent += 1
elif opcode == 0x1B:
self.indent -= 1
if self.in_case: self.write_line("}")
self.in_case = True
self.write_line(f">= {self.var(argv[0])} {{")
self.indent += 1
elif opcode == 0x1C:
self.indent -= 1
if self.in_case: self.write_line("}")
self.in_case = True
self.write_line(f"else {{")
self.indent += 1
# elif opcode == 0x1D:
# self.indent -= 1
# self.write_line(f"SI_CASE_OR_EQ({self.var(argv[0])}),")
# self.indent += 1
# # opcode 0x1E?
# elif opcode == 0x1F:
# self.indent -= 1
# self.write_line(f"SI_CASE_BITS_ON({self.var(argv[0])}),")
# self.indent += 1
# opcode 0x1E?
elif opcode == 0x1F:
self.indent -= 1
self.write_line(f"& {self.var(argv[0])}")
self.indent += 1
# elif opcode == 0x20:
# self.indent -= 1
# self.write_line(f"SI_END_MULTI_CASE(),")
# self.indent += 1
# elif opcode == 0x21:
# self.indent -= 1
# self.write_line(f"case {self.var(argv[0])}..{self.var(argv[1])}:")
# self.indent += 1
# elif opcode == 0x22: self.write_line("break")
# elif opcode == 0x23:
# self.indent -= 2
# self.write_line("}")
elif opcode == 0x21:
self.indent -= 1
self.write_line(f"{self.var(argv[0])}..{self.var(argv[1])} {{")
self.indent += 1
elif opcode == 0x22: self.write_line("break")
elif opcode == 0x23:
self.indent -= 1
if self.in_case: self.write_line("}")
self.in_case = False
self.indent -= 1
self.write_line("}")
elif opcode == 0x24: self.write_line(f"{self.var(argv[0])} = {self.var(argv[1])}")
elif opcode == 0x25: self.write_line(f"const {self.var(argv[0])} = 0x{argv[1]:X}")
#elif opcode == 0x25: self.write_line(f"{self.var(argv[0])} #= 0x{argv[1]:X}")
elif opcode == 0x26: self.write_line(f"{self.var(argv[0])} = {self.verify_float(self.var(argv[1]))}")
elif opcode == 0x27: self.write_line(f"{self.var(argv[0])} += {self.var(argv[1])}")
elif opcode == 0x28: self.write_line(f"{self.var(argv[0])} -= {self.var(argv[1])}")
@ -483,8 +506,8 @@ class ScriptDSLDisassembler(ScriptDisassembler):
elif opcode == 0x2F: self.write_line(f"{self.var(argv[0])} /= {self.verify_float(self.var(argv[1]))}")
elif opcode == 0x3F: self.write_line(f"{self.var(argv[0])} &= {self.var(argv[1])}")
elif opcode == 0x40: self.write_line(f"{self.var(argv[0])} |= {self.var(argv[1])}")
elif opcode == 0x41: self.write_line(f"const {self.var(argv[0])} &= {argv[1]:X})")
elif opcode == 0x42: self.write_line(f"const {self.var(argv[0])} |= {argv[1]:X})")
#elif opcode == 0x41: self.write_line(f"{self.var(argv[0])} #&= {argv[1]:X})")
#elif opcode == 0x42: self.write_line(f"{self.var(argv[0])} #|= {argv[1]:X})")
elif opcode == 0x43:
argv_str = ", ".join(self.var(arg) for arg in argv[1:])
self.write_line(f"{self.addr_ref(argv[0])}({argv_str})")