diff --git a/client/cmd/split.go b/client/cmd/split.go index 3def31f2..98446c6c 100644 --- a/client/cmd/split.go +++ b/client/cmd/split.go @@ -1,6 +1,7 @@ package cmd import ( + "bytes" "context" "encoding/hex" "fmt" @@ -8,24 +9,76 @@ import ( "os" "strings" + "github.com/iden3/go-iden3-crypto/poseidon" "github.com/shopspring/decimal" "github.com/spf13/cobra" "source.quilibrium.com/quilibrium/monorepo/node/protobufs" ) +var parts int +var partAmount string var splitCmd = &cobra.Command{ Use: "split", Short: "Splits a coin into multiple coins", Long: `Splits a coin into multiple coins: - + split ... - + split <--parts PARTS> [--part-amount AMOUNT] + OfCoin - the address of the coin to split Amounts - the sets of amounts to split + + Example - Split a coin into the specified amounts: + $ qclient token coins + 1.000000000000 QUIL (Coin 0x1234) + $ qclient token split 0x1234 0.5 0.25 0.25 + $ qclient token coins + 0.250000000000 QUIL (Coin 0x1111) + 0.250000000000 QUIL (Coin 0x2222) + 0.500000000000 QUIL (Coin 0x3333) + + Example - Split a coin into three parts: + $ qclient token coins + 1.000000000000 QUIL (Coin 0x1234) + $ qclient token split 0x1234 --parts 3 + $ qclient token coins + 0.000000000250 QUIL (Coin 0x1111) + 0.333333333250 QUIL (Coin 0x2222) + 0.333333333250 QUIL (Coin 0x3333) + 0.333333333250 QUIL (Coin 0x4444) + + **Note:** Coin 0x1111 is the remainder. + + Example - Split a coin into two parts using the specified amounts: + $ qclient token coins + 1.000000000000 QUIL (Coin 0x1234) + $ qclient token split 0x1234 --parts 2 --part-amount 0.35 + $ qclient token coins + 0.300000000000 QUIL (Coin 0x1111) + 0.350000000000 QUIL (Coin 0x2222) + 0.350000000000 QUIL (Coin 0x3333) + + **Note:** Coin 0x1111 is the remainder. `, Run: func(cmd *cobra.Command, args []string) { - if len(args) < 3 { - fmt.Println("invalid command") + if len(args) < 3 && parts == 1 { + fmt.Println("did you forget to specify and ?") + os.Exit(1) + } + if len(args) < 1 && parts > 1 { + fmt.Println("did you forget to specify ?") + os.Exit(1) + } + if len(args) > 1 && parts > 1 { + fmt.Println("-p/--parts can't be combined with ") + os.Exit(1) + } + if len(args) > 1 && partAmount != "" { + fmt.Println("-a/--part-amount can't be combined with ") + os.Exit(1) + } + if parts > 100 { + fmt.Println("too many parts, maximum is 100") os.Exit(1) } @@ -40,18 +93,32 @@ var splitCmd = &cobra.Command{ } payload = append(payload, coinaddr...) - conversionFactor, _ := new(big.Int).SetString("1DCD65000", 16) + // Get the amount of the coin to be split + totalAmount := getCoinAmount(coinaddr) + amounts := [][]byte{} - for _, amt := range args[1:] { - amount, err := decimal.NewFromString(amt) + + // Split the coin into the user specified amounts + if parts == 1 { + amounts, payload, err = Split(args[1:], amounts, payload, totalAmount) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + } + + // Split the coin into parts + if parts > 1 && partAmount == "" { + amounts, payload = SplitIntoParts(amounts, payload, totalAmount, parts) + } + + // Split the coin into parts of the user specified amount + if parts > 1 && partAmount != "" { + amounts, payload, err = SplitIntoPartsAmount(amounts, payload, totalAmount, parts, partAmount) if err != nil { - fmt.Println("invalid amount") + fmt.Println(err) os.Exit(1) } - amount = amount.Mul(decimal.NewFromBigInt(conversionFactor, 0)) - amountBytes := amount.BigInt().FillBytes(make([]byte, 32)) - amounts = append(amounts, amountBytes) - payload = append(payload, amountBytes...) } conn, err := GetGRPCClient() @@ -100,5 +167,151 @@ var splitCmd = &cobra.Command{ } func init() { + splitCmd.Flags().IntVarP(&parts, "parts", "p", 1, "number of parts to split the coin into") + splitCmd.Flags().StringVarP(&partAmount, "part-amount", "a", "", "amount of each part") tokenCmd.AddCommand(splitCmd) } + +func Split(args []string, amounts [][]byte, payload []byte, totalAmount *big.Int) ([][]byte, []byte, error) { + conversionFactor, _ := new(big.Int).SetString("1DCD65000", 16) + inputAmount := new(big.Int) + for _, amt := range args { + amount, err := decimal.NewFromString(amt) + if err != nil { + return nil, nil, fmt.Errorf("invalid amount, must be a decimal number like 0.02 or 2") + } + amount = amount.Mul(decimal.NewFromBigInt(conversionFactor, 0)) + inputAmount = inputAmount.Add(inputAmount, amount.BigInt()) + amountBytes := amount.BigInt().FillBytes(make([]byte, 32)) + amounts = append(amounts, amountBytes) + payload = append(payload, amountBytes...) + } + + // Check if the user specified amounts sum to the total amount of the coin + if inputAmount.Cmp(totalAmount) != 0 { + return nil, nil, fmt.Errorf("the specified amounts must sum to the total amount of the coin") + } + return amounts, payload, nil +} + +func SplitIntoParts(amounts [][]byte, payload []byte, totalAmount *big.Int, parts int) ([][]byte, []byte) { + amount := new(big.Int).Div(totalAmount, big.NewInt(int64(parts))) + amountBytes := amount.FillBytes(make([]byte, 32)) + for i := int64(0); i < int64(parts); i++ { + amounts = append(amounts, amountBytes) + payload = append(payload, amountBytes...) + } + + // If there is a remainder, we need to add it as a separate amount + // because the amounts must sum to the original coin amount. + remainder := new(big.Int).Mod(totalAmount, big.NewInt(int64(parts))) + if remainder.Cmp(big.NewInt(0)) != 0 { + remainderBytes := remainder.FillBytes(make([]byte, 32)) + amounts = append(amounts, remainderBytes) + payload = append(payload, remainderBytes...) + } + return amounts, payload +} + +func SplitIntoPartsAmount(amounts [][]byte, payload []byte, totalAmount *big.Int, parts int, partAmount string) ([][]byte, []byte, error) { + conversionFactor, _ := new(big.Int).SetString("1DCD65000", 16) + amount, err := decimal.NewFromString(partAmount) + if err != nil { + return nil, nil, fmt.Errorf("invalid amount, must be a decimal number like 0.02 or 2") + } + amount = amount.Mul(decimal.NewFromBigInt(conversionFactor, 0)) + inputAmount := new(big.Int).Mul(amount.BigInt(), big.NewInt(int64(parts))) + amountBytes := amount.BigInt().FillBytes(make([]byte, 32)) + for i := int64(0); i < int64(parts); i++ { + amounts = append(amounts, amountBytes) + payload = append(payload, amountBytes...) + } + + // If there is a remainder, we need to add it as a separate amount + // because the amounts must sum to the original coin amount. + remainder := new(big.Int).Sub(totalAmount, inputAmount) + if remainder.Cmp(big.NewInt(0)) != 0 { + remainderBytes := remainder.FillBytes(make([]byte, 32)) + amounts = append(amounts, remainderBytes) + payload = append(payload, remainderBytes...) + } + + // Check if the user specified amounts sum to the total amount of the coin + if new(big.Int).Add(inputAmount, new(big.Int).Abs(remainder)).Cmp(totalAmount) != 0 { + return nil, nil, fmt.Errorf("the specified amounts must sum to the total amount of the coin") + } + return amounts, payload, nil +} + +func getCoinAmount(coinaddr []byte) *big.Int { + conn, err := GetGRPCClient() + if err != nil { + panic(err) + } + defer conn.Close() + + client := protobufs.NewNodeServiceClient(conn) + peerId := GetPeerIDFromConfig(NodeConfig) + privKey, err := GetPrivKeyFromConfig(NodeConfig) + if err != nil { + panic(err) + } + + pub, err := privKey.GetPublic().Raw() + if err != nil { + panic(err) + } + + addr, err := poseidon.HashBytes([]byte(peerId)) + if err != nil { + panic(err) + } + + addrBytes := addr.FillBytes(make([]byte, 32)) + resp, err := client.GetTokensByAccount( + context.Background(), + &protobufs.GetTokensByAccountRequest{ + Address: addrBytes, + }, + ) + if err != nil { + panic(err) + } + + if len(resp.Coins) != len(resp.FrameNumbers) { + panic("invalid response from RPC") + } + + altAddr, err := poseidon.HashBytes([]byte(pub)) + if err != nil { + panic(err) + } + + altAddrBytes := altAddr.FillBytes(make([]byte, 32)) + resp2, err := client.GetTokensByAccount( + context.Background(), + &protobufs.GetTokensByAccountRequest{ + Address: altAddrBytes, + }, + ) + if err != nil { + panic(err) + } + + if len(resp.Coins) != len(resp.FrameNumbers) { + panic("invalid response from RPC") + } + + var amount *big.Int + for i, coin := range resp.Coins { + if bytes.Equal(resp.Addresses[i], coinaddr) { + amount = new(big.Int).SetBytes(coin.Amount) + } + } + for i, coin := range resp2.Coins { + if bytes.Equal(resp.Addresses[i], coinaddr) { + amount = new(big.Int).SetBytes(coin.Amount) + } + } + return amount +} diff --git a/client/cmd/split_test.go b/client/cmd/split_test.go new file mode 100644 index 00000000..ff03241d --- /dev/null +++ b/client/cmd/split_test.go @@ -0,0 +1,232 @@ +package cmd_test + +import ( + "encoding/hex" + "math/big" + "reflect" + "strings" + "testing" + + "github.com/shopspring/decimal" + "source.quilibrium.com/quilibrium/monorepo/client/cmd" +) + +func TestSplit(t *testing.T) { + tests := []struct { + name string + args []string + totalAmount string + amounts [][]byte + payload []byte + expectError bool + }{ + { + name: "Valid split - specified amounts", + args: []string{"0x1234", "0.5", "0.25", "0.25"}, + totalAmount: "1.0", + amounts: [][]byte{ + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 238, 107, 40, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 119, 53, 148, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 119, 53, 148, 0}, + }, + payload: []byte{ + 115, 112, 108, 105, 116, + 18, 52, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 238, 107, 40, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 119, 53, 148, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 119, 53, 148, 0, + }, + expectError: false, + }, + { + name: "Invalid split - amounts do not sum to the total amount of the coin", + args: []string{"0x1234", "0.5", "0.25"}, + totalAmount: "1.0", + amounts: [][]byte{}, + payload: []byte{}, + expectError: true, + }, + { + name: "Invalid split - amounts exceed total amount of the coin", + args: []string{"0x1234", "0.5", "1"}, + totalAmount: "1.0", + amounts: [][]byte{}, + payload: []byte{}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + payload := []byte("split") + coinaddrHex, _ := strings.CutPrefix(tc.args[0], "0x") + coinaddr, err := hex.DecodeString(coinaddrHex) + if err != nil { + panic(err) + } + payload = append(payload, coinaddr...) + + conversionFactor, _ := new(big.Int).SetString("1DCD65000", 16) + totalAmount, _ := decimal.NewFromString(tc.totalAmount) + totalAmount = totalAmount.Mul(decimal.NewFromBigInt(conversionFactor, 0)) + + amounts := [][]byte{} + + if tc.expectError { + _, _, err = cmd.Split(tc.args[1:], amounts, payload, totalAmount.BigInt()) + if err == nil { + t.Errorf("want error for invalid split, got nil") + } + } else { + amounts, payload, err = cmd.Split(tc.args[1:], amounts, payload, totalAmount.BigInt()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(tc.amounts, amounts) { + t.Errorf("expected amounts: %v, got: %v", tc.amounts, amounts) + } + if !reflect.DeepEqual(tc.payload, payload) { + t.Errorf("expected payloads: %v, got: %v", tc.payload, payload) + } + } + }) + } +} + +func TestSplitParts(t *testing.T) { + tests := []struct { + name string + args []string + parts int + totalAmount string + amounts [][]byte + payload []byte + }{ + { + name: "Valid split - into parts", + args: []string{"0x1234"}, + parts: 3, + totalAmount: "1.0", + amounts: [][]byte{ + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 242, 26, 170}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 242, 26, 170}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 242, 26, 170}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}, + }, + payload: []byte{ + 115, 112, 108, 105, 116, + 18, 52, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 242, 26, 170, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 242, 26, 170, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 242, 26, 170, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + payload := []byte("split") + coinaddrHex, _ := strings.CutPrefix(tc.args[0], "0x") + coinaddr, err := hex.DecodeString(coinaddrHex) + if err != nil { + panic(err) + } + payload = append(payload, coinaddr...) + + conversionFactor, _ := new(big.Int).SetString("1DCD65000", 16) + totalAmount, _ := decimal.NewFromString(tc.totalAmount) + totalAmount = totalAmount.Mul(decimal.NewFromBigInt(conversionFactor, 0)) + + amounts := [][]byte{} + + amounts, payload = cmd.SplitIntoParts(amounts, payload, totalAmount.BigInt(), tc.parts) + if !reflect.DeepEqual(tc.amounts, amounts) { + t.Errorf("expected amounts: %v, got: %v", tc.amounts, amounts) + } + if !reflect.DeepEqual(tc.payload, payload) { + t.Errorf("expected payloads: %v, got: %v", tc.payload, payload) + } + }) + } +} + +func TestSplitIntoPartsAmount(t *testing.T) { + tests := []struct { + name string + args []string + parts int + partAmount string + totalAmount string + amounts [][]byte + payload []byte + expectError bool + }{ + { + name: "Valid split - into parts of specified amount", + args: []string{"0x1234"}, + parts: 2, + partAmount: "0.35", + totalAmount: "1.0", + amounts: [][]byte{ + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 166, 228, 156, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 166, 228, 156, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 143, 13, 24, 0}, + }, + payload: []byte{ + 115, 112, 108, 105, 116, + 18, 52, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 166, 228, 156, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 166, 228, 156, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 143, 13, 24, 0, + }, + expectError: false, + }, + { + name: "Invalid split - amounts exceed total amount of the coin", + args: []string{"0x1234"}, + parts: 3, + partAmount: "0.5", + totalAmount: "1.0", + amounts: [][]byte{}, + payload: []byte{}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + payload := []byte("split") + coinaddrHex, _ := strings.CutPrefix(tc.args[0], "0x") + coinaddr, err := hex.DecodeString(coinaddrHex) + if err != nil { + panic(err) + } + payload = append(payload, coinaddr...) + + conversionFactor, _ := new(big.Int).SetString("1DCD65000", 16) + totalAmount, _ := decimal.NewFromString(tc.totalAmount) + totalAmount = totalAmount.Mul(decimal.NewFromBigInt(conversionFactor, 0)) + + amounts := [][]byte{} + + if tc.expectError { + _, _, err = cmd.SplitIntoPartsAmount(amounts, payload, totalAmount.BigInt(), tc.parts, tc.partAmount) + if err == nil { + t.Errorf("want error for invalid split, got nil") + } + } else { + amounts, payload, err = cmd.SplitIntoPartsAmount(amounts, payload, totalAmount.BigInt(), tc.parts, tc.partAmount) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(tc.amounts, amounts) { + t.Errorf("expected amounts: %v, got: %v", tc.amounts, amounts) + } + if !reflect.DeepEqual(tc.payload, payload) { + t.Errorf("expected payloads: %v, got: %v", tc.payload, payload) + } + } + }) + } +}