diff --git a/pkg/payouts/audit.go b/pkg/payouts/audit.go index 71d3623..819e3f4 100644 --- a/pkg/payouts/audit.go +++ b/pkg/payouts/audit.go @@ -1,9 +1,7 @@ package payouts import ( - "bytes" "context" - stdcsv "encoding/csv" "math/big" "os" "time" @@ -13,6 +11,7 @@ import ( "storj.io/crypto-batch-payment/pkg/eth" "storj.io/crypto-batch-payment/pkg/payer" "storj.io/crypto-batch-payment/pkg/pipelinedb" + "storj.io/crypto-batch-payment/pkg/receipts" "storj.io/crypto-batch-payment/pkg/zksync" "storj.io/crypto-batch-payment/pkg/zksyncera" @@ -80,12 +79,6 @@ func Audit(ctx context.Context, dir string, csvPath string, payerType payer.Type } csvPayouts := FromCSV(rows) - receiptsBuf := new(bytes.Buffer) - receiptsCSV := stdcsv.NewWriter(receiptsBuf) - if err := receiptsCSV.Write([]string{"wallet", "amount", "txhash", "mechanism"}); err != nil { - return nil, err - } - // Load the database sink.ReportStatusf("Loading database...") dbDir, err := dbDirFromCSVPath(dir, csvPath) @@ -207,6 +200,8 @@ func Audit(ctx context.Context, dir string, csvPath string, payerType payer.Type } } + var receipts receipts.Buffer + // For each payout, ensure it belongs to a payout group with a confirmed // transaction. Reconfirm the transaction against the blockchain. sink.ReportStatusf("Checking payouts status...") @@ -215,9 +210,7 @@ func Audit(ctx context.Context, dir string, csvPath string, payerType payer.Type for _, dbPayout := range dbPayouts { if txHash, ok := payoutGroupStatus[dbPayout.PayoutGroupID]; ok { if txHash != "" { - if err := receiptsCSV.Write([]string{dbPayout.Payee.String(), dbPayout.USD.String(), txHash, payerType.String()}); err != nil { - return nil, errs.Wrap(err) - } + receipts.Emit(dbPayout.Payee, dbPayout.USD, txHash, payerType) } continue } @@ -295,9 +288,7 @@ func Audit(ctx context.Context, dir string, csvPath string, payerType payer.Type if confirmedCount > 0 { txHash := confirmed[0].Hash payoutGroupStatus[dbPayout.PayoutGroupID] = txHash - if err := receiptsCSV.Write([]string{dbPayout.Payee.String(), dbPayout.USD.String(), txHash, payerType.String()}); err != nil { - return nil, errs.Wrap(err) - } + receipts.Emit(dbPayout.Payee, dbPayout.USD, txHash, payerType) payoutsConfirmed += numPayouts } @@ -319,9 +310,8 @@ func Audit(ctx context.Context, dir string, csvPath string, payerType payer.Type switch { case receiptsOut == "": case payoutsConfirmed == stats.Total || receiptsForce: - receiptsCSV.Flush() sink.ReportStatusf("Writing receipts to %s...", receiptsOut) - if err := os.WriteFile(receiptsOut, receiptsBuf.Bytes(), 0644); err != nil { + if err := os.WriteFile(receiptsOut, receipts.Finalize(), 0644); err != nil { return nil, errs.Wrap(err) } default: diff --git a/pkg/receipts/buffer.go b/pkg/receipts/buffer.go new file mode 100644 index 0000000..f43c560 --- /dev/null +++ b/pkg/receipts/buffer.go @@ -0,0 +1,36 @@ +package receipts + +import ( + "bytes" + "encoding/csv" + + "github.com/ethereum/go-ethereum/common" + "github.com/shopspring/decimal" + "storj.io/crypto-batch-payment/pkg/payer" +) + +type Buffer struct { + buf bytes.Buffer + csv *csv.Writer +} + +func (b *Buffer) Emit(wallet common.Address, amount decimal.Decimal, txHash string, mechanism payer.Type) { + b.init() + b.write(wallet.String(), amount.String(), txHash, mechanism.String()) +} + +func (b *Buffer) Finalize() []byte { + b.csv.Flush() + return b.buf.Bytes() +} + +func (b *Buffer) init() { + if b.csv == nil { + b.csv = csv.NewWriter(&b.buf) + b.write("wallet", "amount", "txhash", "mechanism") + } +} + +func (b *Buffer) write(c1, c2, c3, c4 string) { + _ = b.csv.Write([]string{c1, c2, c3, c4}) +} diff --git a/pkg/receipts/buffer_test.go b/pkg/receipts/buffer_test.go new file mode 100644 index 0000000..6739b71 --- /dev/null +++ b/pkg/receipts/buffer_test.go @@ -0,0 +1,29 @@ +package receipts_test + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" + "storj.io/crypto-batch-payment/pkg/payer" + "storj.io/crypto-batch-payment/pkg/receipts" +) + +func TestBuffer(t *testing.T) { + address1 := common.BytesToAddress(bytes.Repeat([]byte{1}, common.AddressLength)) + address2 := common.BytesToAddress(bytes.Repeat([]byte{2}, common.AddressLength)) + address3 := common.BytesToAddress(bytes.Repeat([]byte{3}, common.AddressLength)) + + var b receipts.Buffer + b.Emit(address1, decimal.NewFromInt(1), "hash1", payer.Eth) + b.Emit(address2, decimal.NewFromInt(2), "hash2", payer.ZkSync) + b.Emit(address3, decimal.NewFromInt(3), "hash3", payer.ZkSyncEra) + receipts := b.Finalize() + require.Equal(t, `wallet,amount,txhash,mechanism +0x0101010101010101010101010101010101010101,1,hash1,eth +0x0202020202020202020202020202020202020202,2,hash2,zksync +0x0303030303030303030303030303030303030303,3,hash3,zksync-era +`, string(receipts)) +}