Skip to content

Commit

Permalink
Simplify processing args (#154)
Browse files Browse the repository at this point in the history
Instead of carrying a `cache` and `vm` around on each function, create a `Processor` struct to contain those
  • Loading branch information
julienduchesne authored Aug 26, 2024
1 parent dfed920 commit 8a00f52
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 34 deletions.
38 changes: 18 additions & 20 deletions pkg/ast/processing/find_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ import (
"reflect"
"strings"

"github.com/google/go-jsonnet"
"github.com/google/go-jsonnet/ast"
"github.com/grafana/jsonnet-language-server/pkg/cache"
"github.com/grafana/jsonnet-language-server/pkg/nodestack"
log "github.com/sirupsen/logrus"
)

func FindRangesFromIndexList(cache *cache.Cache, stack *nodestack.NodeStack, indexList []string, vm *jsonnet.VM, partialMatchFields bool) ([]ObjectRange, error) {
func (p *Processor) FindRangesFromIndexList(stack *nodestack.NodeStack, indexList []string, partialMatchFields bool) ([]ObjectRange, error) {
var foundDesugaredObjects []*ast.DesugaredObject
// First element will be super, self, or var name
start, indexList := indexList[0], indexList[1:]
Expand All @@ -32,13 +30,13 @@ func FindRangesFromIndexList(cache *cache.Cache, stack *nodestack.NodeStack, ind
if _, ok := tmpStack.Peek().(*ast.Binary); ok {
tmpStack.Pop()
}
foundDesugaredObjects = filterSelfScope(FindTopLevelObjects(cache, tmpStack, vm))
foundDesugaredObjects = filterSelfScope(p.FindTopLevelObjects(tmpStack))
case start == "std":
return nil, fmt.Errorf("cannot get definition of std lib")
case start == "$":
foundDesugaredObjects = FindTopLevelObjects(cache, nodestack.NewNodeStack(stack.From), vm)
foundDesugaredObjects = p.FindTopLevelObjects(nodestack.NewNodeStack(stack.From))
case strings.Contains(start, "."):
foundDesugaredObjects = FindTopLevelObjectsInFile(cache, vm, start, "")
foundDesugaredObjects = p.FindTopLevelObjectsInFile(start, "")

default:
if strings.Count(start, "(") == 1 && strings.Count(start, ")") == 1 {
Expand Down Expand Up @@ -66,15 +64,15 @@ func FindRangesFromIndexList(cache *cache.Cache, stack *nodestack.NodeStack, ind
foundDesugaredObjects = append(foundDesugaredObjects, bodyNode)
case *ast.Self:
tmpStack := nodestack.NewNodeStack(stack.From)
foundDesugaredObjects = FindTopLevelObjects(cache, tmpStack, vm)
foundDesugaredObjects = p.FindTopLevelObjects(tmpStack)
case *ast.Import:
filename := bodyNode.File.Value
foundDesugaredObjects = FindTopLevelObjectsInFile(cache, vm, filename, "")
foundDesugaredObjects = p.FindTopLevelObjectsInFile(filename, "")

case *ast.Index, *ast.Apply:
tempStack := nodestack.NewNodeStack(bodyNode)
indexList = append(tempStack.BuildIndexList(), indexList...)
return FindRangesFromIndexList(cache, stack, indexList, vm, partialMatchFields)
return p.FindRangesFromIndexList(stack, indexList, partialMatchFields)
case *ast.Function:
// If the function's body is an object, it means we can look for indexes within the function
if funcBody := findChildDesugaredObject(bodyNode.Body); funcBody != nil {
Expand All @@ -85,10 +83,10 @@ func FindRangesFromIndexList(cache *cache.Cache, stack *nodestack.NodeStack, ind
}
}

return extractObjectRangesFromDesugaredObjs(cache, vm, foundDesugaredObjects, indexList, partialMatchFields)
return p.extractObjectRangesFromDesugaredObjs(foundDesugaredObjects, indexList, partialMatchFields)
}

func extractObjectRangesFromDesugaredObjs(cache *cache.Cache, vm *jsonnet.VM, desugaredObjs []*ast.DesugaredObject, indexList []string, partialMatchFields bool) ([]ObjectRange, error) {
func (p *Processor) extractObjectRangesFromDesugaredObjs(desugaredObjs []*ast.DesugaredObject, indexList []string, partialMatchFields bool) ([]ObjectRange, error) {
var ranges []ObjectRange
for len(indexList) > 0 {
index := indexList[0]
Expand All @@ -112,7 +110,7 @@ func extractObjectRangesFromDesugaredObjs(cache *cache.Cache, vm *jsonnet.VM, de
return ranges, nil
}

fieldNodes, err := unpackFieldNodes(vm, foundFields)
fieldNodes, err := p.unpackFieldNodes(foundFields)
if err != nil {
return nil, err
}
Expand All @@ -126,7 +124,7 @@ func extractObjectRangesFromDesugaredObjs(cache *cache.Cache, vm *jsonnet.VM, de
// The target is a function and will be found by FindVarReference on the next loop
fieldNodes = append(fieldNodes, fieldNode.Target)
case *ast.Var:
varReference, err := FindVarReference(fieldNode, vm)
varReference, err := p.FindVarReference(fieldNode)
if err != nil {
return nil, err
}
Expand All @@ -143,11 +141,11 @@ func extractObjectRangesFromDesugaredObjs(cache *cache.Cache, vm *jsonnet.VM, de
// if we're trying to find the a definition which is an index,
// we need to find it from itself, meaning that we need to create a stack
// from the index's target and search from there
rootNode, _, _ := vm.ImportAST("", fieldNode.LocRange.FileName)
rootNode, _, _ := p.vm.ImportAST("", fieldNode.LocRange.FileName)
stack, _ := FindNodeByPosition(rootNode, fieldNode.Target.Loc().Begin)
if stack != nil {
additionalIndexList := append(nodestack.NewNodeStack(fieldNode).BuildIndexList(), indexList...)
result, _ := FindRangesFromIndexList(cache, stack, additionalIndexList, vm, partialMatchFields)
result, _ := p.FindRangesFromIndexList(stack, additionalIndexList, partialMatchFields)
if len(result) > 0 {
return result, err
}
Expand All @@ -158,7 +156,7 @@ func extractObjectRangesFromDesugaredObjs(cache *cache.Cache, vm *jsonnet.VM, de
desugaredObjs = append(desugaredObjs, findChildDesugaredObject(fieldNode.Body))
case *ast.Import:
filename := fieldNode.File.Value
newObjs := FindTopLevelObjectsInFile(cache, vm, filename, string(fieldNode.Loc().File.DiagnosticFileName))
newObjs := p.FindTopLevelObjectsInFile(filename, string(fieldNode.Loc().File.DiagnosticFileName))
desugaredObjs = append(desugaredObjs, newObjs...)
}
i++
Expand All @@ -178,13 +176,13 @@ func flattenBinary(node ast.Node) []ast.Node {
// unpackFieldNodes extracts nodes from fields
// - Binary nodes. A field could be either in the left or right side of the binary
// - Self nodes. We want the object self refers to, not the self node itself
func unpackFieldNodes(vm *jsonnet.VM, fields []*ast.DesugaredObjectField) ([]ast.Node, error) {
func (p *Processor) unpackFieldNodes(fields []*ast.DesugaredObjectField) ([]ast.Node, error) {
var fieldNodes []ast.Node
for _, foundField := range fields {
switch fieldNode := foundField.Body.(type) {
case *ast.Self:
filename := fieldNode.LocRange.FileName
rootNode, _, _ := vm.ImportAST("", filename)
rootNode, _, _ := p.vm.ImportAST("", filename)
tmpStack, err := FindNodeByPosition(rootNode, fieldNode.LocRange.Begin)
if err != nil {
return nil, err
Expand Down Expand Up @@ -254,8 +252,8 @@ func findChildDesugaredObject(node ast.Node) *ast.DesugaredObject {

// FindVarReference finds the object that the variable is referencing
// To do so, we get the stack where the var is used and search that stack for the var's definition
func FindVarReference(varNode *ast.Var, vm *jsonnet.VM) (ast.Node, error) {
varFileNode, _, _ := vm.ImportAST("", varNode.LocRange.FileName)
func (p *Processor) FindVarReference(varNode *ast.Var) (ast.Node, error) {
varFileNode, _, _ := p.vm.ImportAST("", varNode.LocRange.FileName)
varStack, err := FindNodeByPosition(varFileNode, varNode.Loc().Begin)
if err != nil {
return nil, fmt.Errorf("got the following error when finding the bind for %s: %w", varNode.Id, err)
Expand Down
18 changes: 18 additions & 0 deletions pkg/ast/processing/processor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package processing

import (
"github.com/google/go-jsonnet"
"github.com/grafana/jsonnet-language-server/pkg/cache"
)

type Processor struct {
cache *cache.Cache
vm *jsonnet.VM
}

func NewProcessor(cache *cache.Cache, vm *jsonnet.VM) *Processor {
return &Processor{
cache: cache,
vm: vm,
}
}
22 changes: 10 additions & 12 deletions pkg/ast/processing/top_level_objects.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
package processing

import (
"github.com/google/go-jsonnet"
"github.com/google/go-jsonnet/ast"
"github.com/grafana/jsonnet-language-server/pkg/cache"
"github.com/grafana/jsonnet-language-server/pkg/nodestack"
log "github.com/sirupsen/logrus"
)

func FindTopLevelObjectsInFile(cache *cache.Cache, vm *jsonnet.VM, filename, importedFrom string) []*ast.DesugaredObject {
v, ok := cache.GetTopLevelObject(filename, importedFrom)
func (p *Processor) FindTopLevelObjectsInFile(filename, importedFrom string) []*ast.DesugaredObject {
v, ok := p.cache.GetTopLevelObject(filename, importedFrom)
if !ok {
rootNode, _, _ := vm.ImportAST(importedFrom, filename)
v = FindTopLevelObjects(cache, nodestack.NewNodeStack(rootNode), vm)
cache.PutTopLevelObject(filename, importedFrom, v)
rootNode, _, _ := p.vm.ImportAST(importedFrom, filename)
v = p.FindTopLevelObjects(nodestack.NewNodeStack(rootNode))
p.cache.PutTopLevelObject(filename, importedFrom, v)
}
return v
}

// Find all ast.DesugaredObject's from NodeStack
func FindTopLevelObjects(cache *cache.Cache, stack *nodestack.NodeStack, vm *jsonnet.VM) []*ast.DesugaredObject {
func (p *Processor) FindTopLevelObjects(stack *nodestack.NodeStack) []*ast.DesugaredObject {
var objects []*ast.DesugaredObject
for !stack.IsEmpty() {
curr := stack.Pop()
Expand All @@ -33,7 +31,7 @@ func FindTopLevelObjects(cache *cache.Cache, stack *nodestack.NodeStack, vm *jso
stack.Push(curr.Body)
case *ast.Import:
filename := curr.File.Value
rootNode, _, _ := vm.ImportAST(string(curr.Loc().File.DiagnosticFileName), filename)
rootNode, _, _ := p.vm.ImportAST(string(curr.Loc().File.DiagnosticFileName), filename)
stack.Push(rootNode)
case *ast.Index:
indexValue, indexIsString := curr.Index.(*ast.LiteralString)
Expand All @@ -44,7 +42,7 @@ func FindTopLevelObjects(cache *cache.Cache, stack *nodestack.NodeStack, vm *jso
var container ast.Node
// If our target is a var, the container for the index is the var ref
if varTarget, targetIsVar := curr.Target.(*ast.Var); targetIsVar {
ref, err := FindVarReference(varTarget, vm)
ref, err := p.FindVarReference(varTarget)
if err != nil {
log.WithError(err).Errorf("Error finding var reference, ignoring this node")
continue
Expand All @@ -61,7 +59,7 @@ func FindTopLevelObjects(cache *cache.Cache, stack *nodestack.NodeStack, vm *jso
if containerObj, containerIsObj := container.(*ast.DesugaredObject); containerIsObj {
possibleObjects = []*ast.DesugaredObject{containerObj}
} else if containerImport, containerIsImport := container.(*ast.Import); containerIsImport {
possibleObjects = FindTopLevelObjectsInFile(cache, vm, containerImport.File.Value, string(containerImport.Loc().File.DiagnosticFileName))
possibleObjects = p.FindTopLevelObjectsInFile(containerImport.File.Value, string(containerImport.Loc().File.DiagnosticFileName))
}

for _, obj := range possibleObjects {
Expand All @@ -70,7 +68,7 @@ func FindTopLevelObjects(cache *cache.Cache, stack *nodestack.NodeStack, vm *jso
}
}
case *ast.Var:
varReference, err := FindVarReference(curr, vm)
varReference, err := p.FindVarReference(curr)
if err != nil {
log.WithError(err).Errorf("Error finding var reference, ignoring this node")
continue
Expand Down
3 changes: 2 additions & 1 deletion pkg/server/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ func (s *Server) completionFromStack(line string, stack *nodestack.NodeStack, vm
return items
}

ranges, err := processing.FindRangesFromIndexList(s.cache, stack, indexes, vm, true)
processor := processing.NewProcessor(s.cache, vm)
ranges, err := processor.FindRangesFromIndexList(stack, indexes, true)
if err != nil {
log.Errorf("Completion: error finding ranges: %v", err)
return []protocol.CompletionItem{}
Expand Down
3 changes: 2 additions & 1 deletion pkg/server/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func (s *Server) definitionLink(params *protocol.DefinitionParams) ([]protocol.D

func (s *Server) findDefinition(root ast.Node, params *protocol.DefinitionParams, vm *jsonnet.VM) ([]protocol.DefinitionLink, error) {
var response []protocol.DefinitionLink
processor := processing.NewProcessor(s.cache, vm)

searchStack, _ := processing.FindNodeByPosition(root, position.ProtocolToAST(params.Position))
deepestNode := searchStack.Pop()
Expand Down Expand Up @@ -93,7 +94,7 @@ func (s *Server) findDefinition(root ast.Node, params *protocol.DefinitionParams
indexSearchStack := nodestack.NewNodeStack(deepestNode)
indexList := indexSearchStack.BuildIndexList()
tempSearchStack := *searchStack
objectRanges, err := processing.FindRangesFromIndexList(s.cache, &tempSearchStack, indexList, vm, false)
objectRanges, err := processor.FindRangesFromIndexList(&tempSearchStack, indexList, false)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 8a00f52

Please sign in to comment.