From 6f0feae446816e32d7d9faad7f2da8edf2dc3b0f Mon Sep 17 00:00:00 2001 From: Julien Duchesne Date: Sun, 25 Aug 2024 21:27:05 -0400 Subject: [PATCH] Global cache for documents + top level jsonnet objects (#153) * Global cache for documents + top level jsonnet objects Closes https://github.com/grafana/jsonnet-language-server/issues/133 There are two caches currently: - One for protocol documents. This one is instantiated by the server and maintained up-to-date as documents are opened, changed, and closed. - One for jsonnet objects. This one is a global var and is only added to. Modified objects are never removed/modified from the cache. By merging the two caches, we can expand the first cache's behavior to also invalidate modified objects from the global cache when a document is changed. * Simplify processing args (#154) Instead of carrying a `cache` and `vm` around on each function, create a `Processor` struct to contain those * Fix linting --- .golangci.toml | 3 +- pkg/ast/processing/find_field.go | 38 +++++------ pkg/ast/processing/processor.go | 18 +++++ pkg/ast/processing/top_level_objects.go | 27 ++++---- pkg/{server => cache}/cache.go | 90 +++++++++++++++---------- pkg/server/completion.go | 13 ++-- pkg/server/configuration_test.go | 4 +- pkg/server/definition.go | 15 +++-- pkg/server/diagnostics.go | 47 ++++++------- pkg/server/diagnostics_test.go | 4 +- pkg/server/execute.go | 4 +- pkg/server/formatting.go | 6 +- pkg/server/hover.go | 12 ++-- pkg/server/server.go | 37 ++++++---- pkg/server/symbols.go | 6 +- 15 files changed, 185 insertions(+), 139 deletions(-) create mode 100644 pkg/ast/processing/processor.go rename pkg/{server => cache}/cache.go (54%) diff --git a/.golangci.toml b/.golangci.toml index 1178a69..9311a87 100644 --- a/.golangci.toml +++ b/.golangci.toml @@ -1,14 +1,13 @@ [linters] enable = [ + "copyloopvar", "dogsled", - "exportloopref", "forcetypeassert", "goconst", "gocritic", "gocyclo", "goimports", "goprintffuncname", - "gosec", "gosimple", "govet", "ineffassign", diff --git a/pkg/ast/processing/find_field.go b/pkg/ast/processing/find_field.go index f2497fd..0fe3627 100644 --- a/pkg/ast/processing/find_field.go +++ b/pkg/ast/processing/find_field.go @@ -5,13 +5,12 @@ import ( "reflect" "strings" - "github.com/google/go-jsonnet" "github.com/google/go-jsonnet/ast" "github.com/grafana/jsonnet-language-server/pkg/nodestack" log "github.com/sirupsen/logrus" ) -func FindRangesFromIndexList(stack *nodestack.NodeStack, indexList []string, vm *jsonnet.VM, partialMatchFields bool) ([]ObjectRange, error) { +func (p *Processor) FindRangesFromIndexList(stack *nodestack.NodeStack, indexList []string, partialMatchFields bool) ([]ObjectRange, error) { var foundDesugaredObjects []*ast.DesugaredObject // First element will be super, self, or var name start, indexList := indexList[0], indexList[1:] @@ -31,13 +30,13 @@ func FindRangesFromIndexList(stack *nodestack.NodeStack, indexList []string, vm if _, ok := tmpStack.Peek().(*ast.Binary); ok { tmpStack.Pop() } - foundDesugaredObjects = filterSelfScope(FindTopLevelObjects(tmpStack, vm)) + foundDesugaredObjects = filterSelfScope(p.FindTopLevelObjects(tmpStack)) case start == "std": return nil, fmt.Errorf("cannot get definition of std lib") case start == "$": - foundDesugaredObjects = FindTopLevelObjects(nodestack.NewNodeStack(stack.From), vm) + foundDesugaredObjects = p.FindTopLevelObjects(nodestack.NewNodeStack(stack.From)) case strings.Contains(start, "."): - foundDesugaredObjects = FindTopLevelObjectsInFile(vm, start, "") + foundDesugaredObjects = p.FindTopLevelObjectsInFile(start, "") default: if strings.Count(start, "(") == 1 && strings.Count(start, ")") == 1 { @@ -65,15 +64,15 @@ func FindRangesFromIndexList(stack *nodestack.NodeStack, indexList []string, vm foundDesugaredObjects = append(foundDesugaredObjects, bodyNode) case *ast.Self: tmpStack := nodestack.NewNodeStack(stack.From) - foundDesugaredObjects = FindTopLevelObjects(tmpStack, vm) + foundDesugaredObjects = p.FindTopLevelObjects(tmpStack) case *ast.Import: filename := bodyNode.File.Value - foundDesugaredObjects = FindTopLevelObjectsInFile(vm, filename, "") + foundDesugaredObjects = p.FindTopLevelObjectsInFile(filename, "") case *ast.Index, *ast.Apply: tempStack := nodestack.NewNodeStack(bodyNode) indexList = append(tempStack.BuildIndexList(), indexList...) - return FindRangesFromIndexList(stack, indexList, vm, partialMatchFields) + return p.FindRangesFromIndexList(stack, indexList, partialMatchFields) case *ast.Function: // If the function's body is an object, it means we can look for indexes within the function if funcBody := findChildDesugaredObject(bodyNode.Body); funcBody != nil { @@ -84,10 +83,10 @@ func FindRangesFromIndexList(stack *nodestack.NodeStack, indexList []string, vm } } - return extractObjectRangesFromDesugaredObjs(vm, foundDesugaredObjects, indexList, partialMatchFields) + return p.extractObjectRangesFromDesugaredObjs(foundDesugaredObjects, indexList, partialMatchFields) } -func extractObjectRangesFromDesugaredObjs(vm *jsonnet.VM, desugaredObjs []*ast.DesugaredObject, indexList []string, partialMatchFields bool) ([]ObjectRange, error) { +func (p *Processor) extractObjectRangesFromDesugaredObjs(desugaredObjs []*ast.DesugaredObject, indexList []string, partialMatchFields bool) ([]ObjectRange, error) { var ranges []ObjectRange for len(indexList) > 0 { index := indexList[0] @@ -111,7 +110,7 @@ func extractObjectRangesFromDesugaredObjs(vm *jsonnet.VM, desugaredObjs []*ast.D return ranges, nil } - fieldNodes, err := unpackFieldNodes(vm, foundFields) + fieldNodes, err := p.unpackFieldNodes(foundFields) if err != nil { return nil, err } @@ -125,7 +124,7 @@ func extractObjectRangesFromDesugaredObjs(vm *jsonnet.VM, desugaredObjs []*ast.D // The target is a function and will be found by FindVarReference on the next loop fieldNodes = append(fieldNodes, fieldNode.Target) case *ast.Var: - varReference, err := FindVarReference(fieldNode, vm) + varReference, err := p.FindVarReference(fieldNode) if err != nil { return nil, err } @@ -142,11 +141,11 @@ func extractObjectRangesFromDesugaredObjs(vm *jsonnet.VM, desugaredObjs []*ast.D // if we're trying to find the a definition which is an index, // we need to find it from itself, meaning that we need to create a stack // from the index's target and search from there - rootNode, _, _ := vm.ImportAST("", fieldNode.LocRange.FileName) + rootNode, _, _ := p.vm.ImportAST("", fieldNode.LocRange.FileName) stack, _ := FindNodeByPosition(rootNode, fieldNode.Target.Loc().Begin) if stack != nil { additionalIndexList := append(nodestack.NewNodeStack(fieldNode).BuildIndexList(), indexList...) - result, _ := FindRangesFromIndexList(stack, additionalIndexList, vm, partialMatchFields) + result, _ := p.FindRangesFromIndexList(stack, additionalIndexList, partialMatchFields) if len(result) > 0 { return result, err } @@ -157,7 +156,7 @@ func extractObjectRangesFromDesugaredObjs(vm *jsonnet.VM, desugaredObjs []*ast.D desugaredObjs = append(desugaredObjs, findChildDesugaredObject(fieldNode.Body)) case *ast.Import: filename := fieldNode.File.Value - newObjs := FindTopLevelObjectsInFile(vm, filename, string(fieldNode.Loc().File.DiagnosticFileName)) + newObjs := p.FindTopLevelObjectsInFile(filename, string(fieldNode.Loc().File.DiagnosticFileName)) desugaredObjs = append(desugaredObjs, newObjs...) } i++ @@ -177,13 +176,13 @@ func flattenBinary(node ast.Node) []ast.Node { // unpackFieldNodes extracts nodes from fields // - Binary nodes. A field could be either in the left or right side of the binary // - Self nodes. We want the object self refers to, not the self node itself -func unpackFieldNodes(vm *jsonnet.VM, fields []*ast.DesugaredObjectField) ([]ast.Node, error) { +func (p *Processor) unpackFieldNodes(fields []*ast.DesugaredObjectField) ([]ast.Node, error) { var fieldNodes []ast.Node for _, foundField := range fields { switch fieldNode := foundField.Body.(type) { case *ast.Self: filename := fieldNode.LocRange.FileName - rootNode, _, _ := vm.ImportAST("", filename) + rootNode, _, _ := p.vm.ImportAST("", filename) tmpStack, err := FindNodeByPosition(rootNode, fieldNode.LocRange.Begin) if err != nil { return nil, err @@ -220,7 +219,6 @@ func findObjectFieldsInObject(objectNode *ast.DesugaredObject, index string, par var matchingFields []*ast.DesugaredObjectField for _, field := range objectNode.Fields { - field := field literalString, isString := field.Name.(*ast.LiteralString) if !isString { continue @@ -253,8 +251,8 @@ func findChildDesugaredObject(node ast.Node) *ast.DesugaredObject { // FindVarReference finds the object that the variable is referencing // To do so, we get the stack where the var is used and search that stack for the var's definition -func FindVarReference(varNode *ast.Var, vm *jsonnet.VM) (ast.Node, error) { - varFileNode, _, _ := vm.ImportAST("", varNode.LocRange.FileName) +func (p *Processor) FindVarReference(varNode *ast.Var) (ast.Node, error) { + varFileNode, _, _ := p.vm.ImportAST("", varNode.LocRange.FileName) varStack, err := FindNodeByPosition(varFileNode, varNode.Loc().Begin) if err != nil { return nil, fmt.Errorf("got the following error when finding the bind for %s: %w", varNode.Id, err) diff --git a/pkg/ast/processing/processor.go b/pkg/ast/processing/processor.go new file mode 100644 index 0000000..32d8313 --- /dev/null +++ b/pkg/ast/processing/processor.go @@ -0,0 +1,18 @@ +package processing + +import ( + "github.com/google/go-jsonnet" + "github.com/grafana/jsonnet-language-server/pkg/cache" +) + +type Processor struct { + cache *cache.Cache + vm *jsonnet.VM +} + +func NewProcessor(cache *cache.Cache, vm *jsonnet.VM) *Processor { + return &Processor{ + cache: cache, + vm: vm, + } +} diff --git a/pkg/ast/processing/top_level_objects.go b/pkg/ast/processing/top_level_objects.go index 13014dd..4bef0fc 100644 --- a/pkg/ast/processing/top_level_objects.go +++ b/pkg/ast/processing/top_level_objects.go @@ -1,26 +1,23 @@ package processing import ( - "github.com/google/go-jsonnet" "github.com/google/go-jsonnet/ast" "github.com/grafana/jsonnet-language-server/pkg/nodestack" log "github.com/sirupsen/logrus" ) -var fileTopLevelObjectsCache = make(map[string][]*ast.DesugaredObject) - -func FindTopLevelObjectsInFile(vm *jsonnet.VM, filename, importedFrom string) []*ast.DesugaredObject { - cacheKey := importedFrom + ":" + filename - if _, ok := fileTopLevelObjectsCache[cacheKey]; !ok { - rootNode, _, _ := vm.ImportAST(importedFrom, filename) - fileTopLevelObjectsCache[cacheKey] = FindTopLevelObjects(nodestack.NewNodeStack(rootNode), vm) +func (p *Processor) FindTopLevelObjectsInFile(filename, importedFrom string) []*ast.DesugaredObject { + v, ok := p.cache.GetTopLevelObject(filename, importedFrom) + if !ok { + rootNode, _, _ := p.vm.ImportAST(importedFrom, filename) + v = p.FindTopLevelObjects(nodestack.NewNodeStack(rootNode)) + p.cache.PutTopLevelObject(filename, importedFrom, v) } - - return fileTopLevelObjectsCache[cacheKey] + return v } // Find all ast.DesugaredObject's from NodeStack -func FindTopLevelObjects(stack *nodestack.NodeStack, vm *jsonnet.VM) []*ast.DesugaredObject { +func (p *Processor) FindTopLevelObjects(stack *nodestack.NodeStack) []*ast.DesugaredObject { var objects []*ast.DesugaredObject for !stack.IsEmpty() { curr := stack.Pop() @@ -34,7 +31,7 @@ func FindTopLevelObjects(stack *nodestack.NodeStack, vm *jsonnet.VM) []*ast.Desu stack.Push(curr.Body) case *ast.Import: filename := curr.File.Value - rootNode, _, _ := vm.ImportAST(string(curr.Loc().File.DiagnosticFileName), filename) + rootNode, _, _ := p.vm.ImportAST(string(curr.Loc().File.DiagnosticFileName), filename) stack.Push(rootNode) case *ast.Index: indexValue, indexIsString := curr.Index.(*ast.LiteralString) @@ -45,7 +42,7 @@ func FindTopLevelObjects(stack *nodestack.NodeStack, vm *jsonnet.VM) []*ast.Desu var container ast.Node // If our target is a var, the container for the index is the var ref if varTarget, targetIsVar := curr.Target.(*ast.Var); targetIsVar { - ref, err := FindVarReference(varTarget, vm) + ref, err := p.FindVarReference(varTarget) if err != nil { log.WithError(err).Errorf("Error finding var reference, ignoring this node") continue @@ -62,7 +59,7 @@ func FindTopLevelObjects(stack *nodestack.NodeStack, vm *jsonnet.VM) []*ast.Desu if containerObj, containerIsObj := container.(*ast.DesugaredObject); containerIsObj { possibleObjects = []*ast.DesugaredObject{containerObj} } else if containerImport, containerIsImport := container.(*ast.Import); containerIsImport { - possibleObjects = FindTopLevelObjectsInFile(vm, containerImport.File.Value, string(containerImport.Loc().File.DiagnosticFileName)) + possibleObjects = p.FindTopLevelObjectsInFile(containerImport.File.Value, string(containerImport.Loc().File.DiagnosticFileName)) } for _, obj := range possibleObjects { @@ -71,7 +68,7 @@ func FindTopLevelObjects(stack *nodestack.NodeStack, vm *jsonnet.VM) []*ast.Desu } } case *ast.Var: - varReference, err := FindVarReference(curr, vm) + varReference, err := p.FindVarReference(curr) if err != nil { log.WithError(err).Errorf("Error finding var reference, ignoring this node") continue diff --git a/pkg/server/cache.go b/pkg/cache/cache.go similarity index 54% rename from pkg/server/cache.go rename to pkg/cache/cache.go index c350249..f402aea 100644 --- a/pkg/server/cache.go +++ b/pkg/cache/cache.go @@ -1,9 +1,10 @@ -package server +package cache import ( "errors" "fmt" "os" + "path/filepath" "strings" "sync" @@ -11,59 +12,63 @@ import ( "github.com/jdbaldry/go-language-server-protocol/lsp/protocol" ) -type document struct { +type Document struct { // From DidOpen and DidChange - item protocol.TextDocumentItem + Item protocol.TextDocumentItem // Contains the last successfully parsed AST. If doc.err is not nil, it's out of date. - ast ast.Node - linesChangedSinceAST map[int]bool + AST ast.Node + LinesChangedSinceAST map[int]bool // From diagnostics - val string - err error - diagnostics []protocol.Diagnostic + Val string + Err error + Diagnostics []protocol.Diagnostic } -// newCache returns a document cache. -func newCache() *cache { - return &cache{ - mu: sync.RWMutex{}, - docs: make(map[protocol.DocumentURI]*document), - diagQueue: make(map[protocol.DocumentURI]struct{}), - } +// Cache caches documents. +type Cache struct { + mu sync.RWMutex + docs map[protocol.DocumentURI]*Document + topLevelObjects map[string][]*ast.DesugaredObject } -// cache caches documents. -type cache struct { - mu sync.RWMutex - docs map[protocol.DocumentURI]*document - - diagMutex sync.RWMutex - diagQueue map[protocol.DocumentURI]struct{} - diagRunning sync.Map +// New returns a document cache. +func New() *Cache { + return &Cache{ + mu: sync.RWMutex{}, + docs: make(map[protocol.DocumentURI]*Document), + topLevelObjects: make(map[string][]*ast.DesugaredObject), + } } -// put adds or replaces a document in the cache. -func (c *cache) put(new *document) error { +// Put adds or replaces a document in the cache. +func (c *Cache) Put(new *Document) error { c.mu.Lock() defer c.mu.Unlock() - uri := new.item.URI + uri := new.Item.URI if old, ok := c.docs[uri]; ok { - if old.item.Version > new.item.Version { + if old.Item.Version > new.Item.Version { return errors.New("newer version of the document is already in the cache") } } c.docs[uri] = new + // Invalidate the TopLevelObject cache + for k := range c.topLevelObjects { + if strings.HasSuffix(k, filepath.Base(uri.SpanURI().Filename())) { + delete(c.topLevelObjects, k) + } + } + return nil } -// get retrieves a document from the cache. -func (c *cache) get(uri protocol.DocumentURI) (*document, error) { - c.mu.Lock() - defer c.mu.Unlock() +// Get retrieves a document from the cache. +func (c *Cache) Get(uri protocol.DocumentURI) (*Document, error) { + c.mu.RLock() + defer c.mu.RUnlock() doc, ok := c.docs[uri] if !ok { @@ -73,11 +78,11 @@ func (c *cache) get(uri protocol.DocumentURI) (*document, error) { return doc, nil } -func (c *cache) getContents(uri protocol.DocumentURI, position protocol.Range) (string, error) { +func (c *Cache) GetContents(uri protocol.DocumentURI, position protocol.Range) (string, error) { text := "" - doc, err := c.get(uri) + doc, err := c.Get(uri) if err == nil { - text = doc.item.Text + text = doc.Item.Text } else { // Read the file from disk (TODO: cache this) bytes, err := os.ReadFile(uri.SpanURI().Filename()) @@ -118,3 +123,20 @@ func (c *cache) getContents(uri protocol.DocumentURI, position protocol.Range) ( return contentBuilder.String(), nil } + +func (c *Cache) GetTopLevelObject(filename, importedFrom string) ([]*ast.DesugaredObject, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + cacheKey := importedFrom + ":" + filename + v, ok := c.topLevelObjects[cacheKey] + return v, ok +} + +func (c *Cache) PutTopLevelObject(filename, importedFrom string, objects []*ast.DesugaredObject) { + c.mu.Lock() + defer c.mu.Unlock() + + cacheKey := importedFrom + ":" + filename + c.topLevelObjects[cacheKey] = objects +} diff --git a/pkg/server/completion.go b/pkg/server/completion.go index e1a1138..153420e 100644 --- a/pkg/server/completion.go +++ b/pkg/server/completion.go @@ -18,12 +18,12 @@ import ( ) func (s *Server) Completion(_ context.Context, params *protocol.CompletionParams) (*protocol.CompletionList, error) { - doc, err := s.cache.get(params.TextDocument.URI) + doc, err := s.cache.Get(params.TextDocument.URI) if err != nil { return nil, utils.LogErrorf("Completion: %s: %w", errorRetrievingDocument, err) } - line := getCompletionLine(doc.item.Text, params.Position) + line := getCompletionLine(doc.Item.Text, params.Position) // Short-circuit if it's a stdlib completion if items := s.completionStdLib(line); len(items) > 0 { @@ -31,18 +31,18 @@ func (s *Server) Completion(_ context.Context, params *protocol.CompletionParams } // Otherwise, parse the AST and search for completions - if doc.ast == nil { + if doc.AST == nil { log.Errorf("Completion: document was never successfully parsed, can't autocomplete") return nil, nil } - searchStack, err := processing.FindNodeByPosition(doc.ast, position.ProtocolToAST(params.Position)) + searchStack, err := processing.FindNodeByPosition(doc.AST, position.ProtocolToAST(params.Position)) if err != nil { log.Errorf("Completion: error computing node: %v", err) return nil, nil } - vm := s.getVM(doc.item.URI.SpanURI().Filename()) + vm := s.getVM(doc.Item.URI.SpanURI().Filename()) items := s.completionFromStack(line, searchStack, vm, params.Position) return &protocol.CompletionList{IsIncomplete: false, Items: items}, nil @@ -84,7 +84,8 @@ func (s *Server) completionFromStack(line string, stack *nodestack.NodeStack, vm return items } - ranges, err := processing.FindRangesFromIndexList(stack, indexes, vm, true) + processor := processing.NewProcessor(s.cache, vm) + ranges, err := processor.FindRangesFromIndexList(stack, indexes, true) if err != nil { log.Errorf("Completion: error finding ranges: %v", err) return []protocol.CompletionItem{} diff --git a/pkg/server/configuration_test.go b/pkg/server/configuration_test.go index 020b3ae..ce39f8e 100644 --- a/pkg/server/configuration_test.go +++ b/pkg/server/configuration_test.go @@ -141,10 +141,10 @@ func TestConfiguration(t *testing.T) { vm := s.getVM("any") - doc, err := s.cache.get(fileURI) + doc, err := s.cache.Get(fileURI) assert.NoError(t, err) - json, err := vm.Evaluate(doc.ast) + json, err := vm.Evaluate(doc.AST) assert.NoError(t, err) assert.JSONEq(t, tc.expectedFileOutput, json) }) diff --git a/pkg/server/definition.go b/pkg/server/definition.go index 071d585..e64a47a 100644 --- a/pkg/server/definition.go +++ b/pkg/server/definition.go @@ -39,21 +39,21 @@ func (s *Server) Definition(_ context.Context, params *protocol.DefinitionParams } func (s *Server) definitionLink(params *protocol.DefinitionParams) ([]protocol.DefinitionLink, error) { - doc, err := s.cache.get(params.TextDocument.URI) + doc, err := s.cache.Get(params.TextDocument.URI) if err != nil { return nil, utils.LogErrorf("Definition: %s: %w", errorRetrievingDocument, err) } // Only find definitions, if the the line we're trying to find a definition for hasn't changed since last successful AST parse - if doc.ast == nil { + if doc.AST == nil { return nil, utils.LogErrorf("Definition: document was never successfully parsed, can't find definitions") } - if doc.linesChangedSinceAST[int(params.Position.Line)] { + if doc.LinesChangedSinceAST[int(params.Position.Line)] { return nil, utils.LogErrorf("Definition: document line %d was changed since last successful parse, can't find definitions", params.Position.Line) } - vm := s.getVM(doc.item.URI.SpanURI().Filename()) - responseDefLinks, err := findDefinition(doc.ast, params, vm) + vm := s.getVM(doc.Item.URI.SpanURI().Filename()) + responseDefLinks, err := s.findDefinition(doc.AST, params, vm) if err != nil { return nil, err } @@ -61,8 +61,9 @@ func (s *Server) definitionLink(params *protocol.DefinitionParams) ([]protocol.D return responseDefLinks, nil } -func findDefinition(root ast.Node, params *protocol.DefinitionParams, vm *jsonnet.VM) ([]protocol.DefinitionLink, error) { +func (s *Server) findDefinition(root ast.Node, params *protocol.DefinitionParams, vm *jsonnet.VM) ([]protocol.DefinitionLink, error) { var response []protocol.DefinitionLink + processor := processing.NewProcessor(s.cache, vm) searchStack, _ := processing.FindNodeByPosition(root, position.ProtocolToAST(params.Position)) deepestNode := searchStack.Pop() @@ -93,7 +94,7 @@ func findDefinition(root ast.Node, params *protocol.DefinitionParams, vm *jsonne indexSearchStack := nodestack.NewNodeStack(deepestNode) indexList := indexSearchStack.BuildIndexList() tempSearchStack := *searchStack - objectRanges, err := processing.FindRangesFromIndexList(&tempSearchStack, indexList, vm, false) + objectRanges, err := processor.FindRangesFromIndexList(&tempSearchStack, indexList, false) if err != nil { return nil, err } diff --git a/pkg/server/diagnostics.go b/pkg/server/diagnostics.go index a13e7b9..283ff34 100644 --- a/pkg/server/diagnostics.go +++ b/pkg/server/diagnostics.go @@ -10,6 +10,7 @@ import ( "time" "github.com/google/go-jsonnet/linter" + "github.com/grafana/jsonnet-language-server/pkg/cache" position "github.com/grafana/jsonnet-language-server/pkg/position_conversion" "github.com/jdbaldry/go-language-server-protocol/lsp/protocol" log "github.com/sirupsen/logrus" @@ -74,25 +75,25 @@ func parseErrRegexpMatch(match []string) (string, protocol.Range) { } func (s *Server) queueDiagnostics(uri protocol.DocumentURI) { - s.cache.diagMutex.Lock() - defer s.cache.diagMutex.Unlock() - s.cache.diagQueue[uri] = struct{}{} + s.diagMutex.Lock() + defer s.diagMutex.Unlock() + s.diagQueue[uri] = struct{}{} } func (s *Server) diagnosticsLoop() { go func() { for { - s.cache.diagMutex.Lock() - for uri := range s.cache.diagQueue { - if _, ok := s.cache.diagRunning.Load(uri); ok { + s.diagMutex.Lock() + for uri := range s.diagQueue { + if _, ok := s.diagRunning.Load(uri); ok { continue } go func() { - s.cache.diagRunning.Store(uri, true) + s.diagRunning.Store(uri, true) log.Debug("Publishing diagnostics for ", uri) - doc, err := s.cache.get(uri) + doc, err := s.cache.Get(uri) if err != nil { log.Errorf("publishDiagnostics: %s: %v\n", errorRetrievingDocument, err) return @@ -133,30 +134,30 @@ func (s *Server) diagnosticsLoop() { log.Errorf("publishDiagnostics: unable to publish diagnostics: %v\n", err) } - doc.diagnostics = diags + doc.Diagnostics = diags log.Debug("Done publishing diagnostics for ", uri) - s.cache.diagRunning.Delete(uri) + s.diagRunning.Delete(uri) }() - delete(s.cache.diagQueue, uri) + delete(s.diagQueue, uri) } - s.cache.diagMutex.Unlock() + s.diagMutex.Unlock() time.Sleep(1 * time.Second) } }() } -func (s *Server) getEvalDiags(doc *document) (diags []protocol.Diagnostic) { - if doc.err == nil && s.configuration.EnableEvalDiagnostics { - vm := s.getVM(doc.item.URI.SpanURI().Filename()) - doc.val, doc.err = vm.EvaluateAnonymousSnippet(doc.item.URI.SpanURI().Filename(), doc.item.Text) +func (s *Server) getEvalDiags(doc *cache.Document) (diags []protocol.Diagnostic) { + if doc.Err == nil && s.configuration.EnableEvalDiagnostics { + vm := s.getVM(doc.Item.URI.SpanURI().Filename()) + doc.Val, doc.Err = vm.EvaluateAnonymousSnippet(doc.Item.URI.SpanURI().Filename(), doc.Item.Text) } - if doc.err != nil { + if doc.Err != nil { diag := protocol.Diagnostic{Source: "jsonnet evaluation"} - lines := strings.Split(doc.err.Error(), "\n") + lines := strings.Split(doc.Err.Error(), "\n") if len(lines) == 0 { log.Errorf("publishDiagnostics: expected at least two lines of Jsonnet evaluation error output, got: %v\n", lines) return diags @@ -173,7 +174,7 @@ func (s *Server) getEvalDiags(doc *document) (diags []protocol.Diagnostic) { message, rang := parseErrRegexpMatch(match) if runtimeErr { - diag.Message = doc.err.Error() + diag.Message = doc.Err.Error() diag.Severity = protocol.SeverityWarning } else { diag.Message = message @@ -187,7 +188,7 @@ func (s *Server) getEvalDiags(doc *document) (diags []protocol.Diagnostic) { return diags } -func (s *Server) getLintDiags(doc *document) (diags []protocol.Diagnostic) { +func (s *Server) getLintDiags(doc *cache.Document) (diags []protocol.Diagnostic) { result, err := s.lintWithRecover(doc) if err != nil { log.Errorf("getLintDiags: %s: %v\n", errorRetrievingDocument, err) @@ -202,18 +203,18 @@ func (s *Server) getLintDiags(doc *document) (diags []protocol.Diagnostic) { return diags } -func (s *Server) lintWithRecover(doc *document) (result string, err error) { +func (s *Server) lintWithRecover(doc *cache.Document) (result string, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("error linting: %v", r) } }() - vm := s.getVM(doc.item.URI.SpanURI().Filename()) + vm := s.getVM(doc.Item.URI.SpanURI().Filename()) buf := &bytes.Buffer{} linter.LintSnippet(vm, buf, []linter.Snippet{ - {FileName: doc.item.URI.SpanURI().Filename(), Code: doc.item.Text}, + {FileName: doc.Item.URI.SpanURI().Filename(), Code: doc.Item.Text}, }) result = buf.String() diff --git a/pkg/server/diagnostics_test.go b/pkg/server/diagnostics_test.go index 6299018..ee4eabb 100644 --- a/pkg/server/diagnostics_test.go +++ b/pkg/server/diagnostics_test.go @@ -57,7 +57,7 @@ local unused = 'test'; for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { s, fileURI := testServerWithFile(t, nil, tc.fileContent) - doc, err := s.cache.get(fileURI) + doc, err := s.cache.Get(fileURI) if err != nil { t.Fatalf("%s: %v", errorRetrievingDocument, err) } @@ -145,7 +145,7 @@ func TestGetEvalDiags(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { s, fileURI := testServerWithFile(t, nil, tc.fileContent) - doc, err := s.cache.get(fileURI) + doc, err := s.cache.Get(fileURI) if err != nil { t.Fatalf("%s: %v", errorRetrievingDocument, err) } diff --git a/pkg/server/execute.go b/pkg/server/execute.go index e5956af..ce76df0 100644 --- a/pkg/server/execute.go +++ b/pkg/server/execute.go @@ -43,12 +43,12 @@ func (s *Server) evalItem(params *protocol.ExecuteCommandParams) (interface{}, e return nil, fmt.Errorf("failed to unmarshal position: %v", err) } - doc, err := s.cache.get(protocol.URIFromPath(fileName)) + doc, err := s.cache.Get(protocol.URIFromPath(fileName)) if err != nil { return nil, utils.LogErrorf("evalItem: %s: %w", errorRetrievingDocument, err) } - stack, err := processing.FindNodeByPosition(doc.ast, position.ProtocolToAST(p)) + stack, err := processing.FindNodeByPosition(doc.AST, position.ProtocolToAST(p)) if err != nil { return nil, err } diff --git a/pkg/server/formatting.go b/pkg/server/formatting.go index fa3cd21..019ced9 100644 --- a/pkg/server/formatting.go +++ b/pkg/server/formatting.go @@ -12,18 +12,18 @@ import ( ) func (s *Server) Formatting(_ context.Context, params *protocol.DocumentFormattingParams) ([]protocol.TextEdit, error) { - doc, err := s.cache.get(params.TextDocument.URI) + doc, err := s.cache.Get(params.TextDocument.URI) if err != nil { return nil, utils.LogErrorf("Formatting: %s: %w", errorRetrievingDocument, err) } - formatted, err := formatter.Format(params.TextDocument.URI.SpanURI().Filename(), doc.item.Text, s.configuration.FormattingOptions) + formatted, err := formatter.Format(params.TextDocument.URI.SpanURI().Filename(), doc.Item.Text, s.configuration.FormattingOptions) if err != nil { log.Errorf("error formatting document: %v", err) return nil, nil } - return getTextEdits(doc.item.Text, formatted), nil + return getTextEdits(doc.Item.Text, formatted), nil } func getTextEdits(before, after string) []protocol.TextEdit { diff --git a/pkg/server/hover.go b/pkg/server/hover.go index f3e95cc..f3b4a4f 100644 --- a/pkg/server/hover.go +++ b/pkg/server/hover.go @@ -14,18 +14,18 @@ import ( ) func (s *Server) Hover(_ context.Context, params *protocol.HoverParams) (*protocol.Hover, error) { - doc, err := s.cache.get(params.TextDocument.URI) + doc, err := s.cache.Get(params.TextDocument.URI) if err != nil { return nil, utils.LogErrorf("Hover: %s: %w", errorRetrievingDocument, err) } - if doc.err != nil { + if doc.Err != nil { // Hover triggers often. Throwing an error on each request is noisy log.Errorf("Hover: %s", errorParsingDocument) return nil, nil } - stack, err := processing.FindNodeByPosition(doc.ast, position.ProtocolToAST(params.Position)) + stack, err := processing.FindNodeByPosition(doc.AST, position.ProtocolToAST(params.Position)) if err != nil { return nil, err } @@ -41,7 +41,7 @@ func (s *Server) Hover(_ context.Context, params *protocol.HoverParams) (*protoc _, isVar := node.(*ast.Var) lineIndex := uint32(node.Loc().Begin.Line) - 1 startIndex := uint32(node.Loc().Begin.Column) - 1 - line := strings.Split(doc.item.Text, "\n")[lineIndex] + line := strings.Split(doc.Item.Text, "\n")[lineIndex] if (isIndex || isVar) && strings.HasPrefix(line[startIndex:], "std") { functionNameIndex := startIndex + 4 if functionNameIndex < uint32(len(line)) { @@ -67,7 +67,7 @@ func (s *Server) Hover(_ context.Context, params *protocol.HoverParams) (*protoc definitionParams := &protocol.DefinitionParams{ TextDocumentPositionParams: params.TextDocumentPositionParams, } - definitions, err := findDefinition(doc.ast, definitionParams, s.getVM(doc.item.URI.SpanURI().Filename())) + definitions, err := s.findDefinition(doc.AST, definitionParams, s.getVM(doc.Item.URI.SpanURI().Filename())) if err != nil { log.Debugf("Hover: error finding definition: %s", err) return nil, nil @@ -89,7 +89,7 @@ func (s *Server) Hover(_ context.Context, params *protocol.HoverParams) (*protoc contentBuilder.WriteString(fmt.Sprintf("## `%s`\n", header)) } - targetContent, err := s.cache.getContents(def.TargetURI, def.TargetRange) + targetContent, err := s.cache.GetContents(def.TargetURI, def.TargetRange) if err != nil { log.Debugf("Hover: error reading target content: %s", err) return nil, nil diff --git a/pkg/server/server.go b/pkg/server/server.go index abaf1cb..a4336a8 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -4,9 +4,11 @@ import ( "context" "path/filepath" "strings" + "sync" "github.com/google/go-jsonnet" "github.com/google/go-jsonnet/ast" + "github.com/grafana/jsonnet-language-server/pkg/cache" "github.com/grafana/jsonnet-language-server/pkg/stdlib" "github.com/grafana/jsonnet-language-server/pkg/utils" tankaJsonnet "github.com/grafana/tanka/pkg/jsonnet/implementations/goimpl" @@ -25,9 +27,11 @@ func NewServer(name, version string, client protocol.ClientCloser, configuration server := &Server{ name: name, version: version, - cache: newCache(), + cache: cache.New(), client: client, configuration: configuration, + + diagQueue: make(map[protocol.DocumentURI]struct{}), } return server @@ -38,10 +42,15 @@ type Server struct { name, version string stdlib []stdlib.Function - cache *cache + cache *cache.Cache client protocol.ClientCloser configuration Configuration + + // Diagnostics + diagMutex sync.RWMutex + diagQueue map[protocol.DocumentURI]struct{} + diagRunning sync.Map } func (s *Server) getVM(path string) *jsonnet.VM { @@ -69,29 +78,29 @@ func (s *Server) getVM(path string) *jsonnet.VM { func (s *Server) DidChange(_ context.Context, params *protocol.DidChangeTextDocumentParams) error { defer s.queueDiagnostics(params.TextDocument.URI) - doc, err := s.cache.get(params.TextDocument.URI) + doc, err := s.cache.Get(params.TextDocument.URI) if err != nil { return utils.LogErrorf("DidChange: %s: %w", errorRetrievingDocument, err) } - if params.TextDocument.Version > doc.item.Version && len(params.ContentChanges) != 0 { - oldText := doc.item.Text - doc.item.Text = params.ContentChanges[len(params.ContentChanges)-1].Text + if params.TextDocument.Version > doc.Item.Version && len(params.ContentChanges) != 0 { + oldText := doc.Item.Text + doc.Item.Text = params.ContentChanges[len(params.ContentChanges)-1].Text var ast ast.Node - ast, doc.err = jsonnet.SnippetToAST(doc.item.URI.SpanURI().Filename(), doc.item.Text) + ast, doc.Err = jsonnet.SnippetToAST(doc.Item.URI.SpanURI().Filename(), doc.Item.Text) // If the AST parsed correctly, set it on the document // Otherwise, keep the old AST, and find all the lines that have changed since last AST if ast != nil { - doc.ast = ast - doc.linesChangedSinceAST = map[int]bool{} + doc.AST = ast + doc.LinesChangedSinceAST = map[int]bool{} } else { splitOldText := strings.Split(oldText, "\n") - splitNewText := strings.Split(doc.item.Text, "\n") + splitNewText := strings.Split(doc.Item.Text, "\n") for index, oldLine := range splitOldText { if index >= len(splitNewText) || oldLine != splitNewText[index] { - doc.linesChangedSinceAST[index] = true + doc.LinesChangedSinceAST[index] = true } } } @@ -102,11 +111,11 @@ func (s *Server) DidChange(_ context.Context, params *protocol.DidChangeTextDocu func (s *Server) DidOpen(_ context.Context, params *protocol.DidOpenTextDocumentParams) (err error) { defer s.queueDiagnostics(params.TextDocument.URI) - doc := &document{item: params.TextDocument, linesChangedSinceAST: map[int]bool{}} + doc := &cache.Document{Item: params.TextDocument, LinesChangedSinceAST: map[int]bool{}} if params.TextDocument.Text != "" { - doc.ast, doc.err = jsonnet.SnippetToAST(params.TextDocument.URI.SpanURI().Filename(), params.TextDocument.Text) + doc.AST, doc.Err = jsonnet.SnippetToAST(params.TextDocument.URI.SpanURI().Filename(), params.TextDocument.Text) } - return s.cache.put(doc) + return s.cache.Put(doc) } func (s *Server) Initialize(_ context.Context, _ *protocol.ParamInitialize) (*protocol.InitializeResult, error) { diff --git a/pkg/server/symbols.go b/pkg/server/symbols.go index 4338c92..096f5bf 100644 --- a/pkg/server/symbols.go +++ b/pkg/server/symbols.go @@ -15,19 +15,19 @@ import ( ) func (s *Server) DocumentSymbol(_ context.Context, params *protocol.DocumentSymbolParams) ([]interface{}, error) { - doc, err := s.cache.get(params.TextDocument.URI) + doc, err := s.cache.Get(params.TextDocument.URI) if err != nil { return nil, utils.LogErrorf("DocumentSymbol: %s: %w", errorRetrievingDocument, err) } - if doc.err != nil { + if doc.Err != nil { // Returning an error too often can lead to the client killing the language server // Logging the errors is sufficient log.Errorf("DocumentSymbol: %s", errorParsingDocument) return nil, nil } - symbols := buildDocumentSymbols(doc.ast) + symbols := buildDocumentSymbols(doc.AST) result := make([]interface{}, len(symbols)) for i, symbol := range symbols {