Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
boolangery committed Oct 21, 2024
1 parent 101f177 commit e04ef32
Show file tree
Hide file tree
Showing 9 changed files with 2,491 additions and 1,040 deletions.
362 changes: 358 additions & 4 deletions luaparser/ast.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,371 @@
import ast

from antlr4 import InputStream, CommonTokenStream
from antlr4.tree.Tree import ParseTreeVisitor, TerminalNodeImpl, ErrorNodeImpl

from luaparser.parser.LuaLexer import LuaLexer
from luaparser.astnodes import *
from luaparser import printers
from luaparser.builder import Builder
from luaparser.parser.LuaParser import LuaParser
from luaparser.parser.LuaParserVisitor import LuaParserVisitor
from luaparser.utils.visitor import *
from antlr4.error.ErrorListener import ErrorListener
import json
from typing import Generator


def _listify(obj):
if not isinstance(obj, list):
return [obj]
else:
return obj


class MyVisitor(LuaParserVisitor):
# Visit a parse tree produced by LuaParser#start_.
def visitStart_(self, ctx: LuaParser.Start_Context):
return self.visitChildren(ctx)

def visitTerminal(self, node: TerminalNodeImpl):
match node.symbol.type:
case LuaParser.EOF:
return None
case LuaParser.NAME:
return Name(node.getText())
case _:
return node.getText()

def visitErrorNode(self, node: ErrorNodeImpl):
return "error:" + node.getText()

def defaultResult(self):
return None

def aggregateResult(self, aggregate, nextResult):
if aggregate is None:
return nextResult
if type(nextResult) == list:
nextResult.append(aggregate)
return nextResult
if type(aggregate) == list:
aggregate.append(nextResult)
return aggregate
if nextResult == None:
return aggregate
return [nextResult, aggregate]

# Visit a parse tree produced by LuaParser#chunk.
def visitChunk(self, ctx: LuaParser.ChunkContext):
return Chunk(
body=self.visitChildren(ctx)
)

# Visit a parse tree produced by LuaParser#block.
def visitBlock(self, ctx: LuaParser.BlockContext):
statements = [self.visit(stat) for stat in ctx.stat()]
if ctx.retstat():
statements.append(self.visit(ctx.retstat()))
return Block(
body=statements
)

# Visit a parse tree produced by LuaParser#stat.
def visitStat(self, ctx: LuaParser.StatContext):
if ctx.SEMI():
return SemiColon()
elif ctx.BREAK():
return Break()
else:
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#assign.
def visitAssign(self, ctx: LuaParser.AssignContext):
return Assign(
targets=_listify(self.visit(ctx.varlist())),
values=_listify(self.visit(ctx.explist())),
)

# Visit a parse tree produced by LuaParser#goto.
def visitGoto(self, ctx: LuaParser.GotoContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#do.
def visitDo(self, ctx: LuaParser.DoContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#while.
def visitWhile(self, ctx: LuaParser.WhileContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#repeat.
def visitRepeat(self, ctx: LuaParser.RepeatContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#if.
def visitIf(self, ctx: LuaParser.IfContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#for.
def visitFor(self, ctx: LuaParser.ForContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#forin.
def visitForin(self, ctx: LuaParser.ForinContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#functiondef.
def visitFunctiondef(self, ctx: LuaParser.FunctiondefContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#localfunction.
def visitLocalfunction(self, ctx: LuaParser.LocalfunctionContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#localassign.
def visitLocalassign(self, ctx: LuaParser.LocalassignContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#attnamelist.
def visitAttnamelist(self, ctx: LuaParser.AttnamelistContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#attrib.
def visitAttrib(self, ctx: LuaParser.AttribContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#retstat.
def visitRetstat(self, ctx: LuaParser.RetstatContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#label.
def visitLabel(self, ctx: LuaParser.LabelContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#funcname.
def visitFuncname(self, ctx: LuaParser.FuncnameContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#varlist.
def visitVarlist(self, ctx: LuaParser.VarlistContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#namelist.
def visitNamelist(self, ctx: LuaParser.NamelistContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#explist.
def visitExplist(self, ctx: LuaParser.ExplistContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#exp.
def visitExp(self, ctx: LuaParser.ExpContext):
if ctx.NIL():
return Nil()
elif ctx.FALSE():
return FalseExpr()
elif ctx.TRUE():
return TrueExpr()
elif ctx.DDD():
return Dots()
else:
expressions = ctx.exp()

if len(expressions) == 2 and ctx.CARET():
left = self.visit(expressions[0])
right = self.visit(expressions[1])
return ExpoOp(left, right)
elif len(expressions) == 1:
left = self.visit(expressions[0])

if ctx.NOT():
return ULNotOp(left)
elif ctx.POUND():
return ULengthOP(left)
elif ctx.MINUS():
return UMinusOp(left)
elif ctx.SQUIG():
return UBNotOp(left)
elif len(expressions) == 2:
left = self.visit(expressions[0])
right = self.visit(expressions[1])

if ctx.STAR():
return MultOp(left, right)
elif ctx.SLASH():
return FloatDivOp(left, right)
elif ctx.PER():
return ModOp(left, right)
elif ctx.SS():
return FloorDivOp(left, right)
elif ctx.PLUS():
return AddOp(left, right)
elif ctx.MINUS():
return SubOp(left, right)
elif ctx.DD():
return Concat(left, right)
elif ctx.LT():
return LessThanOp(left, right)
elif ctx.GT():
return GreaterThanOp(left, right)
elif ctx.LE():
return LessOrEqThanOp(left, right)
elif ctx.GE():
return GreaterOrEqThanOp(left, right)
elif ctx.SQEQ():
return NotEqToOp(left, right)
elif ctx.EE():
return EqToOp(left, right)
elif ctx.AND():
return AndLoOp(left, right)
elif ctx.OR():
return OrLoOp(left, right)
elif ctx.AMP():
return BAndOp(left, right)
elif ctx.PIPE():
return BOrOp(left, right)
elif ctx.SQUIG():
return BXorOp(left, right)
elif ctx.LL():
return BShiftLOp(left, right)
elif ctx.GG():
return BShiftROp(left, right)
else:
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#var.
def visitVar(self, ctx: LuaParser.VarContext):
if ctx.NAME():
return Name(ctx.NAME().getText())
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#prefixexp.
def visitPrefixexp(self, ctx: LuaParser.PrefixexpContext):
tail = self.visit(ctx.nestedtail())

if ctx.NAME(): # NAME nestedtail
if isinstance(tail, Index):
tail.value = self.visit(ctx.NAME())
return tail
else:
raise Exception("Invalid tail type")
elif ctx.functioncall(): # functioncall nestedtail
raise Exception("functioncall not implemented")
else: # '(' exp ')' nestedtail
exp = self.visit(ctx.exp())
exp.wrapped = True
# TODO: handle tail
return exp

# Visit a parse tree produced by LuaParser#functioncall.
def visitFunctioncall(self, ctx: LuaParser.FunctioncallContext):
args = self.visit(ctx.args())
names = ctx.NAME()

tails = None
if ctx.tail():
tails = [self.visit(t) for t in ctx.tail()]

if len(names) == 1: # NAME tail* args
if isinstance(args, Call):
args.func = self.visit(names[0])
return args

return Call(
func=ctx.NAME(),
args=self.visit(ctx.args()),
)

def visitAnonfunctiondef(self, ctx: LuaParser.AnonfunctiondefContext):
return self.visitChildren(ctx)

def visitNestedtail(self, ctx: LuaParser.NestedtailContext):
return self.visitChildren(ctx)

def visitTail(self, ctx: LuaParser.TailContext):
if ctx.OB() and ctx.CB():
return Index(
idx=self.visit(ctx.exp()),
value=Name(""), # value must be set in parent
notation=IndexNotation.SQUARE,
)
else:
return Index(
idx=Name(self.visit(ctx.NAME())),
value=Name(""), # value must be set in parent
notation=IndexNotation.DOT,
)

# Visit a parse tree produced by LuaParser#args.
def visitArgs(self, ctx: LuaParser.ArgsContext):
if ctx.OP() and ctx.CP():
exp_list = []
if ctx.explist():
exp_list = self.visit(ctx.explist())

return Call(None, exp_list)
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#functiondef.
def visitFunctiondef(self, ctx: LuaParser.FunctiondefContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#funcbody.
def visitFuncbody(self, ctx: LuaParser.FuncbodyContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#parlist.
def visitParlist(self, ctx: LuaParser.ParlistContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#tableconstructor.
def visitTableconstructor(self, ctx: LuaParser.TableconstructorContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#fieldlist.
def visitFieldlist(self, ctx: LuaParser.FieldlistContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#field.
def visitField(self, ctx: LuaParser.FieldContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#fieldsep.
def visitFieldsep(self, ctx: LuaParser.FieldsepContext):
return self.visitChildren(ctx)

# Visit a parse tree produced by LuaParser#number.
def visitNumber(self, ctx: LuaParser.NumberContext):
number_text = self.visitChildren(ctx)
try:
number = ast.literal_eval(number_text)
except:
# exception occurs with leading zero number: 002
number = float(number_text)
return Number(
number,
)


# Visit a parse tree produced by LuaParser#string.
def visitString(self, ctx: LuaParser.StringContext):
return self.visitChildren(ctx)


def parse(source: str) -> Chunk:
"""Parse Lua source to a Chunk."""
return Builder(source).process()
lexer = LuaLexer(InputStream(source))
stream = CommonTokenStream(lexer)
parser = LuaParser(stream)
tree = parser.start_()

if parser.getNumberOfSyntaxErrors() > 0:
raise SyntaxException("syntax errors")
else:
v = MyVisitor()
val = v.visit(tree)
print(val)
return val


def get_token_stream(source: str) -> CommonTokenStream:
Expand Down Expand Up @@ -354,16 +708,16 @@ def syntaxError(self, recognizer, offending_symbol, line, column, msg, e):
raise SyntaxException(str(line) + ":" + str(column) + ": " + str(msg))

def reportAmbiguity(
self, recognizer, dfa, start_index, stop_index, exact, ambig_alts, configs
self, recognizer, dfa, start_index, stop_index, exact, ambig_alts, configs
):
pass

def reportAttemptingFullContext(
self, recognizer, dfa, start_index, stop_index, conflicting_alts, configs
self, recognizer, dfa, start_index, stop_index, conflicting_alts, configs
):
pass

def reportContextSensitivity(
self, recognizer, dfa, start_index, stop_index, prediction, configs
self, recognizer, dfa, start_index, stop_index, prediction, configs
):
pass
Loading

0 comments on commit e04ef32

Please sign in to comment.