Skip to content

Commit

Permalink
fix(ipgraph): strictly check the input length (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
kingster-will authored Dec 11, 2024
1 parent 3c2478a commit 9cd6d0e
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions core/vm/ipgraph.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ func (c *ipGraph) addParentIp(input []byte, evm *EVM, ipGraphAddress common.Addr
parentCount := new(big.Int).SetBytes(getData(input, 64, 32))
log.Info("addParentIp", "parentCount", parentCount)

if len(input) < int(96+parentCount.Uint64()*32) {
return nil, fmt.Errorf("input too short for parent IPs")
if len(input) != int(96+parentCount.Uint64()*32) {
return nil, fmt.Errorf("input length does not match parent count")
}

for i := 0; i < int(parentCount.Uint64()); i++ {
Expand All @@ -184,7 +184,7 @@ func (c *ipGraph) addParentIp(input []byte, evm *EVM, ipGraphAddress common.Addr
}

func (c *ipGraph) hasParentIp(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) {
if len(input) < 64 {
if len(input) != 64 {
return nil, fmt.Errorf("input too short for hasParentIp")
}
ipId := common.BytesToAddress(input[0:32])
Expand All @@ -209,8 +209,8 @@ func (c *ipGraph) hasParentIp(input []byte, evm *EVM, ipGraphAddress common.Addr

func (c *ipGraph) getParentIps(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) {
log.Info("getParentIps", "input", input)
if len(input) < 32 {
return nil, fmt.Errorf("input too short for getParentIps")
if len(input) != 32 {
return nil, fmt.Errorf("inputs too short for getParentIps")
}
ipId := common.BytesToAddress(input[0:32])

Expand All @@ -233,7 +233,7 @@ func (c *ipGraph) getParentIps(input []byte, evm *EVM, ipGraphAddress common.Add

func (c *ipGraph) getParentIpsCount(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) {
log.Info("getParentIpsCount", "input", input)
if len(input) < 32 {
if len(input) != 32 {
return nil, fmt.Errorf("input too short for getParentIpsCount")
}
ipId := common.BytesToAddress(input[0:32])
Expand All @@ -247,7 +247,7 @@ func (c *ipGraph) getParentIpsCount(input []byte, evm *EVM, ipGraphAddress commo

func (c *ipGraph) getAncestorIps(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) {
log.Info("getAncestorIps", "input", input)
if len(input) < 32 {
if len(input) != 32 {
return nil, fmt.Errorf("input too short for getAncestorIps")
}
ipId := common.BytesToAddress(input[0:32])
Expand All @@ -269,7 +269,7 @@ func (c *ipGraph) getAncestorIps(input []byte, evm *EVM, ipGraphAddress common.A

func (c *ipGraph) getAncestorIpsCount(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) {
log.Info("getAncestorIpsCount", "input", input)
if len(input) < 32 {
if len(input) != 32 {
return nil, fmt.Errorf("input too short for getAncestorIpsCount")
}
ipId := common.BytesToAddress(input[0:32])
Expand All @@ -281,7 +281,7 @@ func (c *ipGraph) getAncestorIpsCount(input []byte, evm *EVM, ipGraphAddress com
}

func (c *ipGraph) hasAncestorIp(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) {
if len(input) < 64 {
if len(input) != 64 {
return nil, fmt.Errorf("input too short for hasAncestorIp")
}
ipId := common.BytesToAddress(input[0:32])
Expand Down Expand Up @@ -338,7 +338,7 @@ func (c *ipGraph) setRoyalty(input []byte, evm *EVM, ipGraphAddress common.Addre
return nil, fmt.Errorf("setRoyalty can only be called with CALL, not %v", evm.currentPrecompileCallType)
}

if len(input) < 96 {
if len(input) != 128 {
return nil, fmt.Errorf("input too short for setRoyalty")
}
ipId := common.BytesToAddress(input[0:32])
Expand All @@ -360,7 +360,7 @@ func (c *ipGraph) setRoyalty(input []byte, evm *EVM, ipGraphAddress common.Addre

func (c *ipGraph) getRoyalty(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) {
log.Info("getRoyalty", "ipGraphAddress", ipGraphAddress, "input", input)
if len(input) < 64 {
if len(input) != 96 {
return nil, fmt.Errorf("input too short for getRoyalty")
}
ipId := common.BytesToAddress(input[0:32])
Expand Down Expand Up @@ -519,7 +519,7 @@ func (c *ipGraph) topologicalSort(ipId, ancestorIpId common.Address, evm *EVM, i
func (c *ipGraph) getRoyaltyStack(input []byte, evm *EVM, ipGraphAddress common.Address) ([]byte, error) {
log.Info("getRoyaltyStack", "ipGraphAddress", ipGraphAddress, "input", input)
totalRoyalty := big.NewInt(0)
if len(input) < 32 {
if len(input) != 64 {
return nil, fmt.Errorf("input too short for getRoyaltyStack")
}
ipId := common.BytesToAddress(input[0:32])
Expand Down

0 comments on commit 9cd6d0e

Please sign in to comment.