diff --git a/rbdeal/deal_repair.go b/rbdeal/deal_repair.go index 804c2cb..f9541da 100644 --- a/rbdeal/deal_repair.go +++ b/rbdeal/deal_repair.go @@ -259,8 +259,13 @@ func (r *ribs) fetchGroupHttp(ctx context.Context, workerID int, group ribs2.Gro } }() + var repairTxIdleTimeout = 20 * time.Second + + rc := r.repairFetchCounters.Get(group) + rw := ributil.NewRateEnforcingReader(resp.Body, rc, repairTxIdleTimeout) + cc := new(ributil.DataCidWriter) - commdReader := io.TeeReader(resp.Body, cc) + commdReader := io.TeeReader(rw, cc) _, err = io.Copy(f, commdReader) done() diff --git a/rbdeal/ribs.go b/rbdeal/ribs.go index 29c96f5..d8a5ab5 100644 --- a/rbdeal/ribs.go +++ b/rbdeal/ribs.go @@ -108,10 +108,6 @@ type ribs struct { rateCounters *ributil.RateCounters[peer.ID] - repairDir string - repairStats map[int]*iface.RepairJob // workerid -> repair job - repairStatsLk sync.Mutex - /* car upload offload (S3) */ s3 *s3.S3 @@ -137,6 +133,13 @@ type ribs struct { /* retrieval checker */ rckToDo, rckStarted, rckSuccess, rckFail, rckSuccessAll, rckFailAll atomic.Int64 + + /* repair */ + repairDir string + repairStats map[int]*iface.RepairJob // workerid -> repair job + repairStatsLk sync.Mutex + + repairFetchCounters *ributil.RateCounters[iface.GroupKey] } func (r *ribs) Wallet() iface.Wallet { @@ -197,6 +200,8 @@ func Open(root string, opts ...OpenOption) (iface.RIBS, error) { marketWatchClosed: make(chan struct{}), moreDealsLocks: map[iface.GroupKey]struct{}{}, + + repairFetchCounters: ributil.NewRateCounters[iface.GroupKey](ributil.MinAvgGlobalLogPeerRate(float64(minTransferMbps), float64(linkSpeedMbps))), } rp, err := newRetrievalProvider(context.TODO(), r) diff --git a/ributil/minratewriter.go b/ributil/minratewriter.go index 06f45f1..e9978a2 100644 --- a/ributil/minratewriter.go +++ b/ributil/minratewriter.go @@ -167,20 +167,16 @@ func (rew *RateEnforcingWriter) Write(p []byte) (int, error) { rew.writeError = xerrors.Errorf("write rate over past %s is too slow: %w", rew.windowDuration, checkErr) return 0, rew.writeError } - - // Set write deadline - if w, ok := rew.w.(interface{ SetWriteDeadline(time.Time) error }); ok { - _ = w.SetWriteDeadline(now.Add(rew.windowDuration * 2)) - } } else if rew.lastSpeedCheck.IsZero() { // Set last speed check time and transferred bytes snapshot rew.lastSpeedCheck = now rew.bytesTransferredSnap = rew.rc.transferred.Load() - // Set write deadline - if w, ok := rew.w.(interface{ SetWriteDeadline(time.Time) error }); ok { - _ = w.SetWriteDeadline(now.Add(rew.windowDuration * 2)) - } + } + + // Set write deadline + if w, ok := rew.w.(interface{ SetWriteDeadline(time.Time) error }); ok { + _ = w.SetWriteDeadline(now.Add(rew.windowDuration * 2)) } n, err := rew.w.Write(p) @@ -197,3 +193,75 @@ func (rew *RateEnforcingWriter) Done() { rew.rc.Release() } } + +type RateEnforcingReader struct { + r io.Reader + + readError error + + rc *RateCounter + + bytesTransferredSnap int64 + lastSpeedCheck time.Time + windowDuration time.Duration +} + +func NewRateEnforcingReader(r io.Reader, rc *RateCounter, windowDuration time.Duration) *RateEnforcingReader { + return &RateEnforcingReader{ + r: r, + rc: rc, + windowDuration: windowDuration, + } +} + +func (rer *RateEnforcingReader) Read(p []byte) (int, error) { + if rer.readError != nil { + return 0, rer.readError + } + + now := time.Now() + + if !rer.lastSpeedCheck.IsZero() && now.Sub(rer.lastSpeedCheck) >= rer.windowDuration { + elapsedTime := now.Sub(rer.lastSpeedCheck) + + checkErr := rer.rc.Check(func() error { + ctrTransferred := rer.rc.transferred.Load() + transferredInWindow := ctrTransferred - rer.bytesTransferredSnap + + rer.bytesTransferredSnap = ctrTransferred + rer.lastSpeedCheck = now + + transferSpeedMbps := float64(transferredInWindow*8) / 1e6 / elapsedTime.Seconds() + + return rer.rc.rateFunc(transferSpeedMbps, rer.rc.transfers.Load(), rer.rc.globalTransfers.Load()) + }) + + if checkErr != nil { + rer.readError = xerrors.Errorf("read rate over past %s is too slow: %w", rer.windowDuration, checkErr) + return 0, rer.readError + } + } else if rer.lastSpeedCheck.IsZero() { + // Initialize last speed check time and transferred bytes snapshot + rer.lastSpeedCheck = now + rer.bytesTransferredSnap = rer.rc.transferred.Load() + } + + // Set read deadline + if w, ok := rer.r.(interface{ SetReadDeadline(time.Time) error }); ok { + _ = w.SetReadDeadline(now.Add(rer.windowDuration * 2)) + } + + n, err := rer.r.Read(p) + rer.rc.transferred.Add(int64(n)) + return n, err +} + +func (rer *RateEnforcingReader) ReadError() error { + return rer.readError +} + +func (rer *RateEnforcingReader) Done() { + if rer.readError == nil { + rer.rc.Release() + } +} diff --git a/ributil/minratewriter_test.go b/ributil/minratewriter_test.go index d7647b7..5d6d90a 100644 --- a/ributil/minratewriter_test.go +++ b/ributil/minratewriter_test.go @@ -80,3 +80,82 @@ func (d *deadlineWriter) SetWriteDeadline(t time.Time) error { d.writeDeadline = t return nil } + +func TestRateEnforcingReader(t *testing.T) { + t.Run("should read without error when rate is above minimum", func(t *testing.T) { + data := make([]byte, 1024) + buf := bytes.NewBuffer(data) + + rc := NewRateCounters[int](MinAvgGlobalLogPeerRate(1024, 1000)).Get(0) + rer := NewRateEnforcingReader(buf, rc, 50*time.Millisecond) + defer rer.Done() + + readData := make([]byte, 1024) + time.Sleep(50 * time.Millisecond) + n, err := rer.Read(readData) + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if n != len(data) { + t.Fatalf("expected to read %d bytes, read %d", len(data), n) + } + }) + + t.Run("should read with error when rate is below minimum", func(t *testing.T) { + data := make([]byte, 1024) + buf := bytes.NewBuffer(data) + + rc := NewRateCounters[int](MinAvgGlobalLogPeerRate(1024, 1000)).Get(0) + rer := NewRateEnforcingReader(buf, rc, 50*time.Millisecond) + defer rer.Done() + + readData := make([]byte, 1024) + _, err := rer.Read(readData) + if err != nil { + t.Fatal(err) + } + + time.Sleep(60 * time.Millisecond) // Increase the sleep duration to make sure the rate is below the minimum + n, err := rer.Read(readData) + t.Log(err) + if !errors.Is(err, rer.readError) { + t.Fatalf("expected error, got: %v", err) + } + if n != 0 { + t.Fatalf("expected to read 0 bytes, read %d", n) + } + }) + + t.Run("should support SetReadDeadline on the underlying reader", func(t *testing.T) { + var buf deadlineReader + buf.buf = bytes.NewBuffer(make([]byte, 2000)) + + rc := NewRateCounters[int](MinAvgGlobalLogPeerRate(1024, 1000)).Get(0) + rer := NewRateEnforcingReader(&buf, rc, 50*time.Millisecond) + defer rer.Done() + + data := make([]byte, 1024) + _, err := rer.Read(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if buf.readDeadline.IsZero() { + t.Fatal("expected read deadline to be set") + } + }) +} + +type deadlineReader struct { + buf *bytes.Buffer + readDeadline time.Time +} + +func (d *deadlineReader) Read(p []byte) (n int, err error) { + return d.buf.Read(p) +} + +func (d *deadlineReader) SetReadDeadline(t time.Time) error { + d.readDeadline = t + return nil +}