diff --git a/pkg/ast/processing/find_field.go b/pkg/ast/processing/find_field.go index 0fe3627..a35a38b 100644 --- a/pkg/ast/processing/find_field.go +++ b/pkg/ast/processing/find_field.go @@ -17,7 +17,7 @@ func (p *Processor) FindRangesFromIndexList(stack *nodestack.NodeStack, indexLis switch { case start == "super": // Find the LHS desugared object of a binary node - lhsObject, err := findLHSDesugaredObject(stack) + lhsObject, err := p.findLHSDesugaredObject(stack) if err != nil { return nil, err } @@ -68,16 +68,13 @@ func (p *Processor) FindRangesFromIndexList(stack *nodestack.NodeStack, indexLis case *ast.Import: filename := bodyNode.File.Value foundDesugaredObjects = p.FindTopLevelObjectsInFile(filename, "") - case *ast.Index, *ast.Apply: tempStack := nodestack.NewNodeStack(bodyNode) indexList = append(tempStack.BuildIndexList(), indexList...) return p.FindRangesFromIndexList(stack, indexList, partialMatchFields) case *ast.Function: // If the function's body is an object, it means we can look for indexes within the function - if funcBody := findChildDesugaredObject(bodyNode.Body); funcBody != nil { - foundDesugaredObjects = append(foundDesugaredObjects, funcBody) - } + foundDesugaredObjects = append(foundDesugaredObjects, p.findChildDesugaredObjects(bodyNode.Body)...) default: return nil, fmt.Errorf("unexpected node type when finding bind for '%s': %s", start, reflect.TypeOf(bind.Body)) } @@ -119,6 +116,8 @@ func (p *Processor) extractObjectRangesFromDesugaredObjs(desugaredObjs []*ast.De for i < len(fieldNodes) { fieldNode := fieldNodes[i] switch fieldNode := fieldNode.(type) { + default: + desugaredObjs = append(desugaredObjs, p.findChildDesugaredObjects(fieldNode)...) case *ast.Apply: // Add the target of the Apply to the list of field nodes to look for // The target is a function and will be found by FindVarReference on the next loop @@ -130,13 +129,12 @@ func (p *Processor) extractObjectRangesFromDesugaredObjs(desugaredObjs []*ast.De } // If the reference is an object, add it directly to the list of objects to look in // Otherwise, add it back to the list for further processing - if varReferenceObj := findChildDesugaredObject(varReference); varReferenceObj != nil { - desugaredObjs = append(desugaredObjs, varReferenceObj) + if varReferenceObjs := p.findChildDesugaredObjects(varReference); len(varReferenceObjs) > 0 { + desugaredObjs = append(desugaredObjs, varReferenceObjs...) } else { fieldNodes = append(fieldNodes, varReference) } - case *ast.DesugaredObject: - desugaredObjs = append(desugaredObjs, fieldNode) + case *ast.Index: // if we're trying to find the a definition which is an index, // we need to find it from itself, meaning that we need to create a stack @@ -153,11 +151,27 @@ func (p *Processor) extractObjectRangesFromDesugaredObjs(desugaredObjs []*ast.De fieldNodes = append(fieldNodes, fieldNode.Target) case *ast.Function: - desugaredObjs = append(desugaredObjs, findChildDesugaredObject(fieldNode.Body)) + fieldNodes = append(fieldNodes, fieldNode.Body) case *ast.Import: filename := fieldNode.File.Value newObjs := p.FindTopLevelObjectsInFile(filename, string(fieldNode.Loc().File.DiagnosticFileName)) desugaredObjs = append(desugaredObjs, newObjs...) + case *ast.Binary: + fieldNodes = append(fieldNodes, flattenBinary(fieldNode)...) + case *ast.Self: + filename := fieldNode.LocRange.FileName + rootNode, _, _ := p.vm.ImportAST("", filename) + tmpStack, err := FindNodeByPosition(rootNode, fieldNode.LocRange.Begin) + if err != nil { + return nil, err + } + for !tmpStack.IsEmpty() { + node := tmpStack.Pop() + if castNode, ok := node.(*ast.DesugaredObject); ok { + desugaredObjs = append(desugaredObjs, castNode) + break + } + } } i++ } @@ -234,17 +248,15 @@ func findObjectFieldsInObject(objectNode *ast.DesugaredObject, index string, par return matchingFields } -func findChildDesugaredObject(node ast.Node) *ast.DesugaredObject { +func (p *Processor) findChildDesugaredObjects(node ast.Node) []*ast.DesugaredObject { switch node := node.(type) { case *ast.DesugaredObject: - return node + return []*ast.DesugaredObject{node} case *ast.Binary: - if res := findChildDesugaredObject(node.Left); res != nil { - return res - } - if res := findChildDesugaredObject(node.Right); res != nil { - return res - } + var res []*ast.DesugaredObject + res = append(res, p.findChildDesugaredObjects(node.Left)...) + res = append(res, p.findChildDesugaredObjects(node.Right)...) + return res } return nil } @@ -264,7 +276,7 @@ func (p *Processor) FindVarReference(varNode *ast.Var) (ast.Node, error) { return bind.Body, nil } -func findLHSDesugaredObject(stack *nodestack.NodeStack) (*ast.DesugaredObject, error) { +func (p *Processor) findLHSDesugaredObject(stack *nodestack.NodeStack) (*ast.DesugaredObject, error) { for !stack.IsEmpty() { curr := stack.Pop() switch curr := curr.(type) { @@ -276,8 +288,8 @@ func findLHSDesugaredObject(stack *nodestack.NodeStack) (*ast.DesugaredObject, e case *ast.Var: bind := FindBindByIDViaStack(stack, lhsNode.Id) if bind != nil { - if bindBody := findChildDesugaredObject(bind.Body); bindBody != nil { - return bindBody, nil + if binds := p.findChildDesugaredObjects(bind.Body); len(binds) > 0 { + return binds[0], nil } } } diff --git a/pkg/server/definition_test.go b/pkg/server/definition_test.go index 5789aa3..724ef78 100644 --- a/pkg/server/definition_test.go +++ b/pkg/server/definition_test.go @@ -973,6 +973,21 @@ var definitionTestCases = []definitionTestCase{ }, }}, }, + { + name: "goto builder pattern function", + filename: "testdata/goto-builder-pattern.jsonnet", + position: protocol.Position{Line: 21, Character: 62}, + results: []definitionResult{{ + targetRange: protocol.Range{ + Start: protocol.Position{Line: 16, Character: 6}, + End: protocol.Position{Line: 16, Character: 51}, + }, + targetSelectionRange: protocol.Range{ + Start: protocol.Position{Line: 16, Character: 6}, + End: protocol.Position{Line: 16, Character: 11}, + }, + }}, + }, } func TestDefinition(t *testing.T) { diff --git a/pkg/server/testdata/goto-builder-pattern.jsonnet b/pkg/server/testdata/goto-builder-pattern.jsonnet new file mode 100644 index 0000000..a5d297c --- /dev/null +++ b/pkg/server/testdata/goto-builder-pattern.jsonnet @@ -0,0 +1,23 @@ +{ + util:: { + new():: { + local this = self, + + attr: 'unset1', + attr2: 'unset2', + + withAttr(v):: self { // Intentionally using `self` instead of `this` + attr: v, + }, + + withAttr2(v):: this { // Intentionally using `this` instead of `self` + attr2: v, + }, + + build():: '%s + %s' % [self.attr, this.attr2], + }, + }, + + + test: self.util.new().withAttr('hello').withAttr2('world').build(), +}