Skip to content

Commit

Permalink
Merge pull request #2915 from darkdrag00nv2/range_type_check_paramete…
Browse files Browse the repository at this point in the history
…rized_type_recursive

Make checkParameterizedTypeIsInstantiated recursive
  • Loading branch information
SupunS authored Dec 11, 2023
2 parents 5d74897 + f82c07f commit 2a2c540
Show file tree
Hide file tree
Showing 5 changed files with 939 additions and 120 deletions.
45 changes: 1 addition & 44 deletions runtime/sema/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2048,7 +2048,7 @@ func (checker *Checker) checkTypeAnnotation(typeAnnotation TypeAnnotation, pos a
}

checker.checkInvalidInterfaceAsType(typeAnnotation.Type, pos)
checker.checkParameterizedTypeIsInstantiated(typeAnnotation.Type, pos)
typeAnnotation.Type.CheckInstantiated(pos, checker.memoryGauge, checker.report)
}

func (checker *Checker) checkInvalidInterfaceAsType(ty Type, pos ast.HasPosition) {
Expand All @@ -2068,49 +2068,6 @@ func (checker *Checker) checkInvalidInterfaceAsType(ty Type, pos ast.HasPosition
}
}

func (checker *Checker) checkParameterizedTypeIsInstantiated(ty Type, pos ast.HasPosition) {
parameterizedType, ok := ty.(ParameterizedType)
if !ok {
return
}

typeArgs := parameterizedType.TypeArguments()
typeParameters := parameterizedType.TypeParameters()

typeArgumentCount := len(typeArgs)
typeParameterCount := len(typeParameters)

if typeArgumentCount != typeParameterCount {
checker.report(
&InvalidTypeArgumentCountError{
TypeParameterCount: typeParameterCount,
TypeArgumentCount: typeArgumentCount,
Range: ast.NewRange(
checker.memoryGauge,
pos.StartPosition(),
pos.EndPosition(checker.memoryGauge),
),
},
)
}

// Ensure that each non-optional typeparameter is non-nil.
for index, typeParam := range typeParameters {
if !typeParam.Optional && typeArgs[index] == nil {
checker.report(
&MissingTypeArgumentError{
TypeArgumentName: typeParam.Name,
Range: ast.NewRange(
checker.memoryGauge,
pos.StartPosition(),
pos.EndPosition(checker.memoryGauge),
),
},
)
}
}
}

func (checker *Checker) ValueActivationDepth() int {
return checker.valueActivations.Depth()
}
Expand Down
4 changes: 4 additions & 0 deletions runtime/sema/simple_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,7 @@ func (t *SimpleType) CompositeKind() common.CompositeKind {
return common.CompositeKindStructure
}
}

func (t *SimpleType) CheckInstantiated(_ ast.HasPosition, _ common.MemoryGauge, _ func(err error)) {
// NO-OP
}
125 changes: 125 additions & 0 deletions runtime/sema/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ type Type interface {
Resolve(typeArguments *TypeParameterTypeOrderedMap) Type

GetMembers() map[string]MemberResolver

CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error))
}

// ValueIndexableType is a type which can be indexed into using a value
Expand Down Expand Up @@ -293,6 +295,34 @@ func MustInstantiate(t ParameterizedType, typeArguments ...Type) Type {
)
}

func CheckParameterizedTypeInstantiated(
t ParameterizedType,
pos ast.HasPosition,
memoryGauge common.MemoryGauge,
report func(err error),
) {
typeArgs := t.TypeArguments()
typeParameters := t.TypeParameters()

// The check for the argument and parameter count already happens in the checker, so we skip that here.

// Ensure that each non-optional typeparameter is non-nil.
for index, typeParam := range typeParameters {
if !typeParam.Optional && typeArgs[index] == nil {
report(
&MissingTypeArgumentError{
TypeArgumentName: typeParam.Name,
Range: ast.NewRange(
memoryGauge,
pos.StartPosition(),
pos.EndPosition(memoryGauge),
),
},
)
}
}
}

// TypeAnnotation

type TypeAnnotation struct {
Expand Down Expand Up @@ -694,6 +724,10 @@ func (t *OptionalType) Resolve(typeArguments *TypeParameterTypeOrderedMap) Type
}
}

func (t *OptionalType) CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error)) {
t.Type.CheckInstantiated(pos, memoryGauge, report)
}

const optionalTypeMapFunctionDocString = `
Returns an optional of the result of calling the given function
with the value of this optional when it is not nil.
Expand Down Expand Up @@ -903,6 +937,10 @@ func (t *GenericType) GetMembers() map[string]MemberResolver {
return withBuiltinMembers(t, nil)
}

func (t *GenericType) CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error)) {
t.TypeParameter.TypeBound.CheckInstantiated(pos, memoryGauge, report)
}

// IntegerRangedType

type IntegerRangedType interface {
Expand Down Expand Up @@ -1176,6 +1214,10 @@ func (t *NumericType) IsSuperType() bool {
return t.isSuperType
}

func (*NumericType) CheckInstantiated(_ ast.HasPosition, _ common.MemoryGauge, _ func(err error)) {
// NO-OP
}

// FixedPointNumericType represents all the types in the fixed-point range.
type FixedPointNumericType struct {
maxFractional *big.Int
Expand Down Expand Up @@ -1369,6 +1411,10 @@ func (t *FixedPointNumericType) IsSuperType() bool {
return t.isSuperType
}

func (*FixedPointNumericType) CheckInstantiated(_ ast.HasPosition, _ common.MemoryGauge, _ func(err error)) {
// NO-OP
}

// Numeric types

var (
Expand Down Expand Up @@ -2551,6 +2597,10 @@ func (t *VariableSizedType) Resolve(typeArguments *TypeParameterTypeOrderedMap)
}
}

func (t *VariableSizedType) CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error)) {
t.ElementType(false).CheckInstantiated(pos, memoryGauge, report)
}

// ConstantSizedType is a constant sized array type
type ConstantSizedType struct {
Type Type
Expand Down Expand Up @@ -2707,6 +2757,10 @@ func (t *ConstantSizedType) Resolve(typeArguments *TypeParameterTypeOrderedMap)
}
}

func (t *ConstantSizedType) CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error)) {
t.ElementType(false).CheckInstantiated(pos, memoryGauge, report)
}

// Parameter

func formatParameter(spaces bool, label, identifier, typeAnnotation string) string {
Expand Down Expand Up @@ -3372,6 +3426,18 @@ func (t *FunctionType) initializeMemberResolvers() {
})
}

func (t *FunctionType) CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error)) {
for _, tyParam := range t.TypeParameters {
tyParam.TypeBound.CheckInstantiated(pos, memoryGauge, report)
}

for _, param := range t.Parameters {
param.TypeAnnotation.Type.CheckInstantiated(pos, memoryGauge, report)
}

t.ReturnTypeAnnotation.Type.CheckInstantiated(pos, memoryGauge, report)
}

type ArgumentExpressionsCheck func(
checker *Checker,
argumentExpressions []ast.Expression,
Expand Down Expand Up @@ -4362,6 +4428,24 @@ func (t *CompositeType) SetNestedType(name string, nestedType ContainedType) {
nestedType.SetContainerType(t)
}

func (t *CompositeType) CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error)) {
if t.EnumRawType != nil {
t.EnumRawType.CheckInstantiated(pos, memoryGauge, report)
}

if t.baseType != nil {
t.baseType.CheckInstantiated(pos, memoryGauge, report)
}

for _, typ := range t.ImplicitTypeRequirementConformances {
typ.CheckInstantiated(pos, memoryGauge, report)
}

for _, typ := range t.ExplicitInterfaceConformances {
typ.CheckInstantiated(pos, memoryGauge, report)
}
}

// Member

type Member struct {
Expand Down Expand Up @@ -4848,6 +4932,12 @@ func (t *InterfaceType) FieldPosition(name string, declaration *ast.InterfaceDec
return declaration.Members.FieldPosition(name, declaration.CompositeKind)
}

func (t *InterfaceType) CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error)) {
for _, param := range t.InitializerParameters {
param.TypeAnnotation.Type.CheckInstantiated(pos, memoryGauge, report)
}
}

// DictionaryType consists of the key and value type
// for all key-value pairs in the dictionary:
// All keys have to be a subtype of the key type,
Expand Down Expand Up @@ -4977,6 +5067,11 @@ func (t *DictionaryType) RewriteWithRestrictedTypes() (Type, bool) {
}
}

func (t *DictionaryType) CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error)) {
t.KeyType.CheckInstantiated(pos, memoryGauge, report)
t.ValueType.CheckInstantiated(pos, memoryGauge, report)
}

const dictionaryTypeContainsKeyFunctionDocString = `
Returns true if the given key is in the dictionary
`
Expand Down Expand Up @@ -5448,6 +5543,10 @@ func (t *InclusiveRangeType) TypeArguments() []Type {
}
}

func (t *InclusiveRangeType) CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error)) {
CheckParameterizedTypeInstantiated(t, pos, memoryGauge, report)
}

var inclusiveRangeTypeParameter = &TypeParameter{
Name: "T",
TypeBound: IntegerType,
Expand Down Expand Up @@ -5806,6 +5905,10 @@ func (t *ReferenceType) Resolve(typeArguments *TypeParameterTypeOrderedMap) Type
}
}

func (t *ReferenceType) CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error)) {
t.Type.CheckInstantiated(pos, memoryGauge, report)
}

const AddressTypeName = "Address"

// AddressType represents the address type
Expand Down Expand Up @@ -5901,6 +6004,10 @@ func (t *AddressType) Resolve(_ *TypeParameterTypeOrderedMap) Type {
return t
}

func (*AddressType) CheckInstantiated(_ ast.HasPosition, _ common.MemoryGauge, _ func(err error)) {
// NO-OP
}

const AddressTypeToBytesFunctionName = `toBytes`

var AddressTypeToBytesFunctionType = &FunctionType{
Expand Down Expand Up @@ -6816,6 +6923,16 @@ func (t *TransactionType) Resolve(_ *TypeParameterTypeOrderedMap) Type {
return t
}

func (t *TransactionType) CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error)) {
for _, param := range t.PrepareParameters {
param.TypeAnnotation.Type.CheckInstantiated(pos, memoryGauge, report)
}

for _, param := range t.Parameters {
param.TypeAnnotation.Type.CheckInstantiated(pos, memoryGauge, report)
}
}

// RestrictedType
//
// No restrictions implies the type is fully restricted,
Expand Down Expand Up @@ -7121,6 +7238,10 @@ func (t *RestrictedType) IsValidIndexingType(ty Type) bool {
attachmentType.IsResourceType() == t.IsResourceType()
}

func (t *RestrictedType) CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error)) {
t.Type.CheckInstantiated(pos, memoryGauge, report)
}

// CapabilityType

type CapabilityType struct {
Expand Down Expand Up @@ -7324,6 +7445,10 @@ func (t *CapabilityType) TypeArguments() []Type {
}
}

func (t *CapabilityType) CheckInstantiated(pos ast.HasPosition, memoryGauge common.MemoryGauge, report func(err error)) {
CheckParameterizedTypeInstantiated(t, pos, memoryGauge, report)
}

func CapabilityTypeBorrowFunctionType(borrowType Type) *FunctionType {

var typeParameters []*TypeParameter
Expand Down
22 changes: 12 additions & 10 deletions runtime/tests/checker/range_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,13 +404,15 @@ func TestInclusiveRangeNonLeafIntegerTypes(t *testing.T) {

t.Parallel()

baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation)
baseValueActivation.DeclareValue(stdlib.InclusiveRangeConstructorFunction)

options := ParseAndCheckOptions{
Config: &sema.Config{
BaseValueActivation: baseValueActivation,
},
newOptions := func() ParseAndCheckOptions {
baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation)
baseValueActivation.DeclareValue(stdlib.InclusiveRangeConstructorFunction)

return ParseAndCheckOptions{
Config: &sema.Config{
BaseValueActivation: baseValueActivation,
},
}
}

test := func(t *testing.T, ty sema.Type) {
Expand All @@ -421,7 +423,7 @@ func TestInclusiveRangeNonLeafIntegerTypes(t *testing.T) {
let a: %[1]s = 0
let b: %[1]s = 10
var range = InclusiveRange<%[1]s>(a, b)
`, ty), options)
`, ty), newOptions())

errs := RequireCheckerErrors(t, err, 1)
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0])
Expand All @@ -434,7 +436,7 @@ func TestInclusiveRangeNonLeafIntegerTypes(t *testing.T) {
let a: %[1]s = 0
let b: %[1]s = 10
var range: InclusiveRange<%[1]s> = InclusiveRange<%[1]s>(a, b)
`, ty), options)
`, ty), newOptions())

// One for the invocation and another for the type.
errs := RequireCheckerErrors(t, err, 2)
Expand All @@ -448,7 +450,7 @@ func TestInclusiveRangeNonLeafIntegerTypes(t *testing.T) {
_, err := ParseAndCheckWithOptions(t, fmt.Sprintf(`
let a: InclusiveRange<Int> = InclusiveRange(0, 10)
let b: InclusiveRange<%s> = a
`, ty), options)
`, ty), newOptions())

errs := RequireCheckerErrors(t, err, 1)
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0])
Expand Down
Loading

0 comments on commit 2a2c540

Please sign in to comment.