diff --git a/formatter/format_test.go b/formatter/format_test.go index c2353e6..ca8ecad 100644 --- a/formatter/format_test.go +++ b/formatter/format_test.go @@ -334,6 +334,50 @@ func TestCompilationUnit(t *testing.T) { } } +func TestEndOfFileComments(t *testing.T) { + if testing.Verbose() { + log.SetLevel(log.DebugLevel) + + } + tests := + []struct { + input string + output string + }{ + { + `private class T1Exception extends Exception {} //test`, + `private class T1Exception extends Exception {} +//test`}, + { + `public class MyClass { public static void noop() {}} + //test comment + // oie`, + `public class MyClass { + public static void noop() {} +} +//test comment +// oie`}, + } + for _, tt := range tests { + input := antlr.NewInputStream(tt.input) + lexer := parser.NewApexLexer(input) + stream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel) + + p := parser.NewApexParser(stream) + p.RemoveErrorListeners() + p.AddErrorListener(&testErrorListener{t: t}) + + v := NewFormatVisitor(stream) + out, ok := v.visitRule(p.CompilationUnit()).(string) + if !ok { + t.Errorf("Unexpected result parsing apex") + } + if out != tt.output { + t.Errorf("unexpected format. expected:\n%s\ngot:\n%s\n", tt.output, out) + } + } +} + func TestSOQL(t *testing.T) { if testing.Verbose() { log.SetLevel(log.DebugLevel) diff --git a/formatter/formatter.go b/formatter/formatter.go index 652f935..43974a0 100644 --- a/formatter/formatter.go +++ b/formatter/formatter.go @@ -74,7 +74,7 @@ func (f *Formatter) Format() error { if f.source == nil { src, err := readFile(f.filename, f.reader) if err != nil { - return fmt.Errorf("Failed to read file %s: %w", f.SourceName(), err) + return fmt.Errorf("failed to read file %s: %w", f.SourceName(), err) } f.source = src } diff --git a/formatter/visitors.go b/formatter/visitors.go index 720c77e..55028b2 100644 --- a/formatter/visitors.go +++ b/formatter/visitors.go @@ -10,15 +10,18 @@ import ( ) func (v *FormatVisitor) VisitCompilationUnit(ctx *parser.CompilationUnitContext) interface{} { + + result := "" + if trigger := ctx.TriggerUnit(); trigger != nil { - return v.visitRule(trigger) + result = v.visitRule(trigger).(string) } t := ctx.TypeDeclaration() switch { case t.ClassDeclaration() != nil: - return fmt.Sprintf("%s%s", v.Modifiers(t.AllModifier()), v.visitRule(t.ClassDeclaration()).(string)) + result = fmt.Sprintf("%s%s", v.Modifiers(t.AllModifier()), v.visitRule(t.ClassDeclaration()).(string)) case t.InterfaceDeclaration() != nil: - return fmt.Sprintf("%s%s", v.Modifiers(t.AllModifier()), v.visitRule(t.InterfaceDeclaration()).(string)) + result = fmt.Sprintf("%s%s", v.Modifiers(t.AllModifier()), v.visitRule(t.InterfaceDeclaration()).(string)) case t.EnumDeclaration() != nil: enum := t.EnumDeclaration() constants := []string{} @@ -27,9 +30,30 @@ func (v *FormatVisitor) VisitCompilationUnit(ctx *parser.CompilationUnitContext) constants = append(constants, e.GetText()) } } - return fmt.Sprintf("enum %s {%s}", v.visitRule(enum.Id()), strings.Join(constants, ", ")) + result = fmt.Sprintf("enum %s {%s}", v.visitRule(enum.Id()), strings.Join(constants, ", ")) } - return "" + + stop := ctx.GetStop() + if stop != nil { + eofComments := v.tokens.GetHiddenTokensToRight(stop.GetTokenIndex(), COMMENTS_CHANNEL) + + if eofComments != nil { + comments := []string{} + + for _, c := range eofComments { + if _, seen := v.commentsOutput[c.GetTokenIndex()]; !seen { + comments = append(comments, cleanWhitespace(c.GetText())) + v.commentsOutput[c.GetTokenIndex()] = struct{}{} + } + } + + if len(comments) > 0 { + return (fmt.Sprintf("%s\n%s", result, strings.Join(comments, "\n"))) + } + } + } + + return result } func (v *FormatVisitor) VisitClassDeclaration(ctx *parser.ClassDeclarationContext) interface{} {