Verify only one field is filled out in "enums" #349

Open
ethanfrey opened this issue Sep 5, 2022 · 4 comments

Comments

@ethanfrey
Member

ethanfrey commented Sep 5, 2022

In Rust, we use enums or "union types" to set exactly one field of many possibilities. This is enforced by internal data structures as well as the JSON parser.

In Go, we use a struct with many fields to represent this, like CosmosMsg or QueryRequest (and their sub-types). If no fields or multiple fields are filled out, this may introduce logical errors later on in the consumer, such as the error reported in CosmWasm/wasmd#931 (which never happens when the data comes from a valid Rust type).

To eliminate this class of error and possible attack surface, we should enforce that these Go structs are actually enums (exactly one field is set). IMO, we should add some "Validate" method to do so, but more importantly, auto-execute the validate method in JSON unmarshalling. JSON unmarshalling catches all the cases where this unvalidated data is imported from an untrusted contract, and we should make it safe by default. Exposing that same logic via a "Validate" method is mainly to allow assertions in unit tests that manually construct some objects.
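
To make the "exactly one field is set" rule concrete, here is a minimal sketch of what such a Validate method could look like. The BankMsg, SendMsg and BurnMsg types are simplified stand-ins, and the method name and error messages are assumptions for illustration, not the project's actual API.

```go
package main

import (
	"errors"
	"fmt"
)

// SendMsg and BurnMsg are simplified stand-ins for the real message types.
type SendMsg struct {
	ToAddress string `json:"to_address"`
}

type BurnMsg struct{}

// BankMsg mimics a Rust enum: exactly one field should be non-nil.
type BankMsg struct {
	Send *SendMsg `json:"send,omitempty"`
	Burn *BurnMsg `json:"burn,omitempty"`
}

// Validate (hypothetical) returns an error unless exactly one variant is set.
func (m BankMsg) Validate() error {
	set := 0
	if m.Send != nil {
		set++
	}
	if m.Burn != nil {
		set++
	}
	switch set {
	case 1:
		return nil
	case 0:
		return errors.New("bank msg: no variant set")
	default:
		return fmt.Errorf("bank msg: %d variants set, expected exactly one", set)
	}
}

func main() {
	// One call per case: nothing set, one variant set, two variants set.
	fmt.Println(BankMsg{}.Validate())
	fmt.Println(BankMsg{Burn: &BurnMsg{}}.Validate())
	fmt.Println(BankMsg{Send: &SendMsg{}, Burn: &BurnMsg{}}.Validate())
}
```

A unit test that builds a message by hand can call Validate directly; the JSON path would call the same method after decoding.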

@ethanfrey
Member Author

I think it would look something like this:

type cosmosMsg CosmosMsg

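func (m *CosmosMsg) UnmarshalJSON(data []byte) error {
        // cosmosMsg has no UnmarshalJSON method, so this uses the default
        // decoder and does not recurse back into this function
        var raw cosmosMsg
        if err := json.Unmarshal(data, &raw); err != nil {
          return err
        }
        *m = CosmosMsg(raw)
        return m.Validate()
}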

(Maybe invalid go code, but you get the idea)

@alpe what do you think?

@alpe
Contributor

alpe commented Sep 13, 2022

It would help to have the assumptions on types also covered in Go. We may see other languages or frameworks in the future, so this would keep us on the safe side.

On the details, a custom unmarshal function does work, but calling a ValidateBasic() error method right after the unmarshal may be a simpler solution, as you would not have to maintain a data type. An interface could be introduced for this.
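
A minimal sketch of the interface idea, assuming a hypothetical Validator interface and a hypothetical UnmarshalValidated helper (neither exists in the codebase as far as this thread shows). The message types are simplified stand-ins.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// Validator is a hypothetical interface for types that can check their own invariants.
type Validator interface {
	ValidateBasic() error
}

// Simplified stand-ins for the real message types.
type SendMsg struct {
	ToAddress string `json:"to_address"`
}

type BurnMsg struct{}

type BankMsg struct {
	Send *SendMsg `json:"send,omitempty"`
	Burn *BurnMsg `json:"burn,omitempty"`
}

// ValidateBasic fails when both fields are nil or both are set.
func (m *BankMsg) ValidateBasic() error {
	if (m.Send != nil) == (m.Burn != nil) {
		return errors.New("bank msg: exactly one of send or burn must be set")
	}
	return nil
}

// UnmarshalValidated (hypothetical helper) decodes with the standard
// library and then runs the type's own validation.
func UnmarshalValidated(data []byte, out Validator) error {
	if err := json.Unmarshal(data, out); err != nil {
		return err
	}
	return out.ValidateBasic()
}

func main() {
	var one BankMsg
	fmt.Println(UnmarshalValidated([]byte(`{"burn":{}}`), &one))

	var two BankMsg
	fmt.Println(UnmarshalValidated([]byte(`{"burn":{},"send":{"to_address":"x"}}`), &two))
}
```

With this shape no extra data type is needed, but every call site has to go through the helper instead of calling json.Unmarshal directly.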

@ethanfrey
Member Author

but calling a ValidateBasic() error right after the unmarshal may be a simpler solution as you would not have to maintain a data type.

Slightly simpler to implement, but very easy for the caller to forget. I just suggest modifying the unmarshaller to use standard JSON decoding and then auto-validate. This would fix all places in wasmd where we might forget.

@webmaster128
Member

webmaster128 commented Jan 4, 2023

I built a generic UnmarshallEnum which does more or less the same as json.Unmarshal but it ensures only one top level item is set. It can be used like this in a pretty brain-dead copy/paste fashion:

func (c *CosmosMsg) UnmarshalJSON(data []byte) error {
	var d CosmosMsg
	if err := UnmarshallEnum(data, &d); err != nil {
		return err
	}
	*c = d
	return nil
}

The implementation uses reflect and then relies on the default JSON unmarshalling to do the nested decoding.
This would need some more cleanup and testing, especially for mixing pointer and non-pointer in an enum, but I think this is doing the job quite well. This does not try to guarantee anything about instance validation but ensures at JSON level that only one field is set.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"reflect"
	"strconv"
	"strings"
)

type BankMsg struct {
	Send *SendMsg `json:"send,omitempty"`
	Burn *BurnMsg `json:"burn,omitempty"`
}

// SendMsg contains instructions for a Cosmos-SDK/SendMsg
// It has a fixed interface here and should be converted into the proper SDK format before dispatching
type SendMsg struct {
	ToAddress string `json:"to_address"`
	Amount    Coins  `json:"amount"`
}

// BurnMsg will burn the given coins from the contract's account.
// There is no Cosmos SDK message that performs this, but it can be done by calling the bank keeper.
// Important if a contract controls significant token supply that must be retired.
type BurnMsg struct {
	Amount Coins `json:"amount"`
}

// Coin is a string representation of the sdk.Coin type (more portable than sdk.Int)
type Coin struct {
	Denom  string `json:"denom"`  // type, eg. "ATOM"
	Amount string `json:"amount"` // string encoding of decimal value, eg. "12.3456"
}

func NewCoin(amount uint64, denom string) Coin {
	return Coin{
		Denom:  denom,
		Amount: strconv.FormatUint(amount, 10),
	}
}

// Coins handles properly serializing empty amounts
type Coins []Coin

func UnmarshallEnum(data []byte, out any) error {

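	// out must be a non-nil pointer to a struct, otherwise the reflect calls below panic
	ptr := reflect.ValueOf(out)
	if ptr.Kind() != reflect.Pointer || ptr.IsNil() || ptr.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("out must be a non-nil pointer to a struct, got %T", out)
	}

	// Reset struct to zero state
	val := ptr.Elem()
	val.Set(reflect.Zero(val.Type()))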

	dec := json.NewDecoder(bytes.NewReader(data))

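	// read opening brace; return errors instead of exiting the process
	token, err := dec.Token()
	if err != nil {
		return fmt.Errorf("error reading top level object: %w", err)
	}
	if token != json.Delim('{') {
		return fmt.Errorf("expected top level object, got %v", token)
	}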
	// fmt.Printf("(Top level): %T: %v\n", token, token)

	// a map from JSON field name to instance field index
	outFields := make(map[string]int)
	outTypes := make(map[string]reflect.Type)

	for i := 0; i < val.Type().NumField(); i++ {
		field := val.Type().Field(i)
		fieldName := field.Name
		// fmt.Printf("field %d: %+v\n", i, field)

		switch jsonTag := field.Tag.Get("json"); jsonTag {
		case "-":
		case "":
			outFields[fieldName] = i
			outTypes[fieldName] = field.Type
		default:
			parts := strings.Split(jsonTag, ",")
			name := parts[0]
			if name == "" {
				name = fieldName
			}
			outFields[name] = i
			outTypes[name] = field.Type
		}
	}
	// fmt.Printf("Output instance fields: %+v\n", outFields)
	// fmt.Printf("Output instance types: %+v\n", outTypes)

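	// read the first (and only allowed) top level key
	token, err = dec.Token()
	if err != nil {
		return fmt.Errorf("error reading top level key: %w", err)
	}
	// fmt.Printf("(Top level): %T: %v\n", token, token)
	tokenStr, ok := token.(string)
	if !ok {
		// e.g. an empty object "{}" yields the closing brace here
		return fmt.Errorf("expected exactly one top level key, got %v", token)
	}
	t, found := outTypes[tokenStr]
	if !found {
		return fmt.Errorf("found key that does not match any field name: %q", tokenStr)
	} else {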
		// remember whether the target field is a pointer so the decoded value is assigned correctly
		isPointer := false
		switch t.Kind() {
		case reflect.Struct:
			// nothing to do
		case reflect.Pointer:
			isPointer = true
			t = t.Elem()
		default:
			return fmt.Errorf("found unsupported struct field kind: %v", t.Kind())
		}

		// decode the value that follows the key using the default JSON unmarshalling
		targetInstancePtr := reflect.New(t)
		if err := dec.Decode(targetInstancePtr.Interface()); err != nil {
			return fmt.Errorf("error decoding inner message: %w", err)
		}
		structField := val.Field(outFields[tokenStr])
		if isPointer {
			structField.Set(targetInstancePtr)
		} else {
			// non-pointer struct fields take the decoded value itself
			structField.Set(targetInstancePtr.Elem())
		}
	}

	if dec.More() {
		return fmt.Errorf("found more than one top level key")
	}

	// read closing brace
	if _, err = dec.Token(); err != nil {
		return fmt.Errorf("error closing top level object: %w", err)
	}

	return nil
}

func main() {
	const burn = `{
	  "burn": {
	    "amount": [{
	      "amount": "435",
	      "denom": "uhoh"
	    }]
	  }
	}
`
	const send = `{
  "send": {
	"to_address": "the king",
	"amount": [{
	  "amount": "2233",
	  "denom": "pennies"
	}]
  }
}
`
	const both = `{
  "burn": {
	"amount": [{
	  "amount": "435",
	  "denom": "uhoh"
	}]
  },
  "send": {
	"to_address": "the king",
	"amount": [{
	  "amount": "2233",
	  "denom": "pennies"
	}]
  }
}
`

	var out BankMsg

	err := UnmarshallEnum([]byte(burn), &out)
	if err != nil {
		log.Fatal("Error decoding enum: ", err)
	}
	log.Println(out)

	err = UnmarshallEnum([]byte(send), &out)
	if err != nil {
		log.Fatal("Error decoding enum: ", err)
	}
	log.Println(out)

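	// "both" sets two top level keys, so UnmarshallEnum returns an error
	// and the program exits here
	err = UnmarshallEnum([]byte(both), &out)
	if err != nil {
		log.Fatal("Error decoding enum: ", err)
	}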

	fmt.Printf("Done. Out is\n")
	fmt.Printf("    %+v\n", out)
	fmt.Printf("  = %#v\n", out)
}
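
For comparison, a reflection-free variant of the same JSON-level check can be sketched with the standard library only. This is a different technique from UnmarshallEnum above: it decodes the input twice, once into a map[string]json.RawMessage to count top level keys and once into the target. The name UnmarshalSingleKey and the simplified types are assumptions for illustration. Two caveats: duplicate keys collapse in the map, so a repeated key counts once, and DisallowUnknownFields also rejects unknown keys in nested objects.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// Simplified stand-ins for the real message types.
type SendMsg struct {
	ToAddress string `json:"to_address"`
}

type BurnMsg struct{}

type BankMsg struct {
	Send *SendMsg `json:"send,omitempty"`
	Burn *BurnMsg `json:"burn,omitempty"`
}

// UnmarshalSingleKey (hypothetical) decodes data into out only if the
// top level JSON object has exactly one key and that key is a known field.
func UnmarshalSingleKey(data []byte, out any) error {
	// First pass: count top level keys without decoding their values.
	var raw map[string]json.RawMessage
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}
	if len(raw) != 1 {
		return fmt.Errorf("expected exactly one top level key, found %d", len(raw))
	}

	// Second pass: regular decoding, rejecting keys that match no field.
	dec := json.NewDecoder(bytes.NewReader(data))
	dec.DisallowUnknownFields()
	return dec.Decode(out)
}

func main() {
	var msg BankMsg
	fmt.Println(UnmarshalSingleKey([]byte(`{"burn":{}}`), &msg))
	fmt.Println(UnmarshalSingleKey([]byte(`{"burn":{},"send":{"to_address":"x"}}`), &msg))
}
```

If this were called from inside an UnmarshalJSON method, the second pass would need to target a type without that method (as with the cosmosMsg alias earlier in the thread) to avoid infinite recursion.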
