Skip to content

Commit

Permalink
Revamp checks against spurious symbols after expressions
Browse files Browse the repository at this point in the history
Rewrite and reorder the checks that an expression has been parsed completely and that there aren't any spurious symbols following it:
– All functions ‘below’ `ParseExpression_Term` that parse expressions must check this and throw when the checks fail.
– Callers of such functions must trust that the called function has checked this when it returns and not attempt to repeat the checks.
  • Loading branch information
fernewelten committed Aug 31, 2024
1 parent 1b19d66 commit 9a399eb
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 135 deletions.
202 changes: 116 additions & 86 deletions Compiler/script2/cs_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1871,8 +1871,28 @@ void AGS::Parser::EvaluationResultToAx(EvaluationResult &eres)
}
}

// Peel off pairs of parentheses that enclose the whole of 'expression'.
// A leading '(' and a trailing ')' needn't be partners, e.g., in '(a+b) * (a-b)',
// so a pair is only removed when the symbol list is used up directly after
// the closer of the leading '('.
// The read cursor of 'expression' is not reset before returning.
void AGS::Parser::StripOutermostParens(SrcList &expression)
{
    for (;;)
    {
        // Restart at the first symbol on each round: It must be a '('
        expression.StartRead();
        bool const opens_with_paren = (kKW_OpenParenthesis == expression.GetNext());
        if (!opens_with_paren)
            return; // nothing (more) to strip

        // Move past the ')' that closes this '('
        expression.SkipToCloser();
        SkipNextSymbol(expression, kKW_CloseParenthesis);

        // Symbols left over mean that the leading '(' doesn't enclose everything
        bool const parens_enclose_everything = expression.ReachedEOF();
        if (!parens_enclose_everything)
            return; // nothing (more) to strip

        expression.EatFirstSymbol();
        expression.EatLastSymbol();
    }
}

void AGS::Parser::ParseExpression_New(SrcList &expression, EvaluationResult &eres)
{
expression.StartRead();
if (expression.ReachedEOF())
UserError("Expected a type after 'new' but didn't find any");
Vartype const argument_vartype = expression.GetNext();
Expand All @@ -1893,27 +1913,27 @@ void AGS::Parser::ParseExpression_New(SrcList &expression, EvaluationResult &ere

// Check for '[' with a handcrafted error message so that the user isn't led to
// fix their code by defining a dynamic array when this would be the wrong thing to do
Symbol const open_bracket = _src.GetNext();
Symbol const open_bracket = expression.GetNext();
if (kKW_OpenBracket != open_bracket)
UserError("Unexpected '%s'", _sym.GetName(open_bracket).c_str());

EvaluationResult bracketed_eres;
ParseIntegerExpression(_src, bracketed_eres);
ParseIntegerExpression(expression, bracketed_eres);
EvaluationResultToAx(bracketed_eres);
Expect(kKW_CloseBracket, _src.GetNext());
Expect(kKW_CloseBracket, expression.GetNext());

element_vartype = is_managed ? _sym.VartypeWithDynpointer(argument_vartype) : argument_vartype;
eres.Vartype = _sym.VartypeWithDynarray(element_vartype);

while (kKW_OpenBracket == expression.PeekNext())
{
SkipNextSymbol(_src, kKW_OpenBracket);
SkipNextSymbol(expression, kKW_OpenBracket);
Expect(kKW_CloseBracket, expression.GetNext());
element_vartype = eres.Vartype;
eres.Vartype = _sym.VartypeWithDynarray(element_vartype);
}
}
else
else // no bracket expression
{
if (_sym[argument_vartype].VartypeD->Flags[VTF::kUndefined])
UserError(
Expand Down Expand Up @@ -1969,14 +1989,16 @@ void AGS::Parser::ParseExpression_New(SrcList &expression, EvaluationResult &ere

_reg_track.SetRegister(SREG_AX);

ParseExpression_CheckUsedUp(expression);

eres.Type = eres.kTY_RunTimeValue;
eres.Location = eres.kLOC_AX;
eres.Symbol = kKW_NoSymbol;
// Vartype has already been set
}

// We're parsing an expression that starts with '-' (unary minus)
void AGS::Parser::ParseExpression_PrefixMinus(SrcList &expression, EvaluationResult &eres)
void AGS::Parser::ParseExpression_PrefixMinus(EvaluationResult &eres)
{
if (eres.kTY_Literal == eres.Type)
{
Expand Down Expand Up @@ -2007,11 +2029,8 @@ void AGS::Parser::ParseExpression_PrefixMinus(SrcList &expression, EvaluationRes
}

// We're parsing an expression that starts with '+' (unary plus)
void AGS::Parser::ParseExpression_PrefixPlus(SrcList &expression, EvaluationResult &eres)
void AGS::Parser::ParseExpression_PrefixPlus(EvaluationResult &eres)
{
expression.StartRead();

ParseExpression_Term(expression, eres);

if (_sym.IsAnyIntegerVartype(eres.Vartype) || kKW_Float == eres.Vartype)
return;
Expand All @@ -2020,7 +2039,7 @@ void AGS::Parser::ParseExpression_PrefixPlus(SrcList &expression, EvaluationResu
}

// We're parsing an expression that starts with '!' (boolean NOT) or '~' (bitwise Negate)
void AGS::Parser::ParseExpression_PrefixNegate(Symbol op_sym, SrcList &expression, EvaluationResult &eres)
void AGS::Parser::ParseExpression_PrefixNegate(Symbol op_sym, EvaluationResult &eres)
{
bool const bitwise_negation = kKW_BitNeg == op_sym;

Expand Down Expand Up @@ -2062,24 +2081,38 @@ void AGS::Parser::ParseExpression_PrefixNegate(Symbol op_sym, SrcList &expressio

void AGS::Parser::ParseExpression_PrefixCrement(Symbol op_sym, AGS::SrcList &expression, EvaluationResult &eres)
{
bool const op_is_inc = (kKW_Increment == op_sym);
SrcList inner = SrcList(expression);

expression.StartRead();
// Strip any enclosing '()' around the term
// so that you can do '++(foo)' as well as '++foo'
if (kKW_OpenParenthesis == inner[0u])
{
StripOutermostParens(inner);
if (kKW_OpenParenthesis == inner[0u])
{
// There must be spurious trailing symbols after the closing paren.
inner.StartRead();
inner.SkipToCloser();
SkipNextSymbol(inner, kKW_CloseParenthesis);
ParseExpression_CheckUsedUp(inner);
}
}

inner.StartRead();

ParseAssignment_ReadLHSForModification(expression, eres);
ParseAssignment_ReadLHSForModification(inner, eres);

std::string msg = "Argument of '<op>'";
msg.replace(msg.find("<op>"), 4u, _sym.GetName(op_sym).c_str());
CheckVartypeMismatch(eres.Vartype, kKW_Int, true, msg);

WriteCmd((op_is_inc ? SCMD_ADD : SCMD_SUB), SREG_AX, 1);
WriteCmd((kKW_Increment == op_sym) ? SCMD_ADD : SCMD_SUB, SREG_AX, 1);
_reg_track.SetRegister(SREG_AX);

// Really do the assignment the long way so that all the checks and safeguards will run.
// If a shortcut is possible then undo this and generate the shortcut instead.
RestorePoint before_long_way_modification = RestorePoint(_scrip);

AccessData_AssignTo(expression, eres);
AccessData_AssignTo(inner, eres);

if (EvaluationResult::kLOC_MemoryAtMAR == eres.Location)
{
Expand All @@ -2089,14 +2122,7 @@ void AGS::Parser::ParseExpression_PrefixCrement(Symbol op_sym, AGS::SrcList &exp
_reg_track.SetRegister(SREG_AX);
}
eres.SideEffects = true;
}

void AGS::Parser::ParseExpression_LongMin(EvaluationResult &eres)
{
eres.Type = eres.kTY_Literal;
eres.Location = eres.kLOC_SymbolTable;
eres.Symbol = _sym.Find("-2147483648");
eres.Vartype = kKW_Int;
ParseExpression_CheckUsedUp(inner);
}


Expand All @@ -2111,67 +2137,72 @@ void AGS::Parser::ParseExpression_Prefix(SrcList &expression, EvaluationResult &
"Expected a term after '%s' but didn't find any",
_sym.GetName(op_sym).c_str());

expression.EatFirstSymbol();
SrcList inner = SrcList(expression);
inner.EatFirstSymbol();

if (kKW_New == op_sym)
return ParseExpression_New(expression, eres);
return ParseExpression_New(inner, eres);

if (kKW_Decrement == op_sym || kKW_Increment == op_sym)
{
StripOutermostParens(expression);
return ParseExpression_PrefixCrement(op_sym, expression, eres);
}

return ParseExpression_PrefixCrement(op_sym, inner, eres);

// Special case: Lowest integer literal, written in decimal notation.
// We treat this here in the parser because the scanner doesn't know
// whether a minus symbol stands for a unary minus.
if (op_sym == kKW_Minus &&
expression.Length() == 1u &&
expression[0] == kKW_OnePastLongMax)
return ParseExpression_LongMin(eres);
inner.Length() == 1u &&
inner[0] == kKW_OnePastLongMax)
{
eres.Type = eres.kTY_Literal;
eres.Location = eres.kLOC_SymbolTable;
eres.Symbol = _sym.Find("-2147483648");
eres.Vartype = kKW_Int;
ParseExpression_CheckUsedUp(inner);
return;
}

ParseExpression_Term(expression, eres);
ParseExpression_Term(inner, eres);

switch (op_sym)
{
case kKW_BitNeg:
case kKW_Not:
return ParseExpression_PrefixNegate(op_sym, expression, eres);
return ParseExpression_PrefixNegate(op_sym, eres);

case kKW_Minus:
return ParseExpression_PrefixMinus(expression, eres);
return ParseExpression_PrefixMinus( eres);

case kKW_Plus:
return ParseExpression_PrefixPlus(expression, eres);
return ParseExpression_PrefixPlus(eres);
}

InternalError("Illegal prefix op '%s'", _sym.GetName(op_sym).c_str());
}

// Peel off pairs of parentheses that enclose the whole of 'expression'.
// A pair is only removed when the expression starts with '(', ends with ')',
// and skipping from just behind the leading '(' to its closer leaves the
// cursor exactly on that last symbol.
void AGS::Parser::StripOutermostParens(SrcList &expression)
{
    for (;;)
    {
        if (kKW_OpenParenthesis != expression[0])
            return;

        size_t const closer_idx = expression.Length() - 1u;
        if (expression[closer_idx] != kKW_CloseParenthesis)
            return;

        // Start reading behind the leading '(' and skip to its closer
        expression.SetCursor(1u);
        expression.SkipToCloser();
        if (closer_idx != expression.GetCursor())
            return; // the leading '(' is closed before the end of the expression

        expression.EatFirstSymbol();
        expression.EatLastSymbol();
    }
}

void AGS::Parser::ParseExpression_PostfixCrement(Symbol const op_sym, SrcList &expression, EvaluationResult &eres)
{
bool const op_is_inc = kKW_Increment == op_sym;

StripOutermostParens(expression);
expression.StartRead();
SrcList inner = SrcList(expression);

// Strip any enclosing '()' around the term
// so that you can do '(foo)++' as well as 'foo++'
if (kKW_OpenParenthesis == inner[0u])
{
StripOutermostParens(inner);
if (kKW_OpenParenthesis == inner[0u])
{
// There must be spurious trailing symbols after the closing paren.
inner.StartRead();
inner.SkipToCloser();
SkipNextSymbol(inner, kKW_CloseParenthesis);
ParseExpression_CheckUsedUp(inner);
}
}

ParseAssignment_ReadLHSForModification(expression, eres);
inner.StartRead();
ParseAssignment_ReadLHSForModification(inner, eres);
ParseExpression_CheckUsedUp(inner);

std::string msg = "Argument of '<op>'";
msg.replace(msg.find("<op>"), 4u, _sym.GetName(op_sym).c_str());
Expand All @@ -2183,7 +2214,7 @@ void AGS::Parser::ParseExpression_PostfixCrement(Symbol const op_sym, SrcList &e

PushReg(SREG_AX);
WriteCmd((op_is_inc ? SCMD_ADD : SCMD_SUB), SREG_AX, 1);
AccessData_AssignTo(expression, eres);
AccessData_AssignTo(inner, eres);
PopReg(SREG_AX);

if (EvaluationResult::kLOC_MemoryAtMAR == eres.Location)
Expand All @@ -2201,27 +2232,26 @@ void AGS::Parser::ParseExpression_PostfixCrement(Symbol const op_sym, SrcList &e

void AGS::Parser::ParseExpression_Postfix(SrcList &expression, EvaluationResult &eres, bool const result_used)
{
size_t const len = expression.Length();
size_t const expr_len = expression.Length();

if (0u == len)
if (0u == expr_len)
InternalError("Empty expression");

Symbol const op_sym = expression[len - 1u];
if (1u == len)
Symbol const op_sym = expression[expr_len - 1u];
if (1u == expr_len)
UserError("'%s' must either precede or follow some term to be modified", _sym.GetName(op_sym).c_str());

expression.EatLastSymbol();
SrcList inner = SrcList(expression);
inner.EatLastSymbol();

switch (op_sym)
{
case kKW_Decrement:
case kKW_Increment:
// If the result isn't used then take the more efficient version of increment / decrement
return result_used ?
ParseExpression_PostfixCrement(op_sym, expression, eres) : ParseExpression_PrefixCrement(op_sym, expression, eres);
}
if (op_sym != kKW_Decrement && op_sym != kKW_Increment)
UserError("Expected a term following the '%s', didn't find it", _sym.GetName(op_sym).c_str());

// If the result isn't used then take the more efficient version of increment / decrement
if (!result_used)
return ParseExpression_PrefixCrement(op_sym, inner, eres);

UserError("Expected a term following the '%s', didn't find it", _sym.GetName(op_sym).c_str());
ParseExpression_PostfixCrement(op_sym, inner, eres);
}

void AGS::Parser::ParseExpression_Ternary_Term2(EvaluationResult &eres_term1, bool term1_has_been_ripped_out, SrcList &term2, EvaluationResult &eres, bool const result_used)
Expand Down Expand Up @@ -2255,7 +2285,7 @@ void AGS::Parser::ParseExpression_Ternary_Term2(EvaluationResult &eres_term1, bo
ConvertAXStringToStringObject(_sym.GetStringStructPtrSym(), eres.Vartype);
}

void AGS::Parser::ParseExpression_Ternary(size_t tern_idx, SrcList &expression, EvaluationResult &eres, bool const result_used)
void AGS::Parser::ParseExpression_Ternary(size_t const tern_idx, SrcList &expression, bool const result_used, EvaluationResult &eres)
{
// First term ends before the '?'
SrcList term1 = SrcList(expression, 0, tern_idx);
Expand Down Expand Up @@ -2310,9 +2340,6 @@ void AGS::Parser::ParseExpression_Ternary(size_t tern_idx, SrcList &expression,
EvaluationResult eres_dummy = eres_term1;
EvaluationResultToAx(eres_dummy);

if (!term1.ReachedEOF())
InternalError("Unexpected '%s' after 1st term of ternary", _sym.GetName(term1.GetNext()).c_str());

bool term1_has_been_ripped_out = false;
if (term1_known)
{ // Don't need to do the test at runtime
Expand Down Expand Up @@ -2513,7 +2540,7 @@ void AGS::Parser::ParseExpression_Binary(size_t const op_idx, SrcList &expressio
eres = condition ? eres_rhs : eres_lhs;
else // kKW_Or
eres = condition ? eres_lhs : eres_rhs;

if (!_sym.IsAnyIntegerVartype(_sym[eres.Symbol].LiteralD->Vartype))
{ // Swap an int literal in (note: Don't change the vartype of the pre-existing literal)
bool const result = (0 != _sym[eres.Symbol].LiteralD->Value);
Expand Down Expand Up @@ -2553,9 +2580,10 @@ void AGS::Parser::ParseExpression_InParens(SrcList &expression, EvaluationResult
expression.SkipToCloser();
SkipNextSymbol(expression, kKW_CloseParenthesis);
ParseExpression_CheckUsedUp(expression);

StripOutermostParens(expression);
return ParseExpression_Term(expression, eres, result_used);

SrcList inner(expression);
StripOutermostParens(inner);
return ParseExpression_Term(inner, eres, result_used);
}

// We're in the parameter list of a function call, and we have less parameters than declared.
Expand Down Expand Up @@ -3000,12 +3028,13 @@ void AGS::Parser::ParseExpression_NoOps(SrcList &expression, EvaluationResult &e
return ParseExpression_InParens(expression, eres, result_used);

AccessData(VAC::kReading, expression, eres);
return ParseExpression_CheckUsedUp(expression);
ParseExpression_CheckUsedUp(expression);
}

void AGS::Parser::ParseExpression_Term(SrcList &expression, EvaluationResult &eres, bool const result_used)
{
if (expression.Length() == 0u)
size_t const exp_length = expression.Length();
if (0u == exp_length)
InternalError("Cannot parse empty subexpression");

int const least_binding_op_idx = IndexOfLeastBondingOperator(expression); // can be < 0
Expand All @@ -3017,10 +3046,11 @@ void AGS::Parser::ParseExpression_Term(SrcList &expression, EvaluationResult &er
else if (expression.Length() - 1u == least_binding_op_idx)
ParseExpression_Postfix(expression, eres, result_used);
else if (kKW_Tern == expression[least_binding_op_idx])
ParseExpression_Ternary(least_binding_op_idx, expression, eres, result_used);
ParseExpression_Ternary(least_binding_op_idx, expression, result_used, eres);
else
ParseExpression_Binary(least_binding_op_idx, expression, eres);


expression.SetCursor(exp_length);
return HandleStructOrArrayResult(eres);
}

Expand Down Expand Up @@ -3721,7 +3751,7 @@ void AGS::Parser::AccessData_AssignTo(SrcList &expression, EvaluationResult eres
[&]
{
AccessData(VAC::kWriting, expression, lhs_eres);
if (!expression.ReachedEOF() && lhs_eres.kTY_AttributeName != lhs_eres.Type)
if (!expression.ReachedEOF() && lhs_eres.kTY_AttributeName != lhs_eres.Type)
// Spurious characters follow the LHS in front of the assignment symbol, e.g., 'var 77 = 9;'
UserError("Unexpected '%s'", _sym.GetName(expression.PeekNext()).c_str());

Expand Down
Loading

0 comments on commit 9a399eb

Please sign in to comment.