diff --git a/progress/example_test.go b/progress/example_test.go new file mode 100644 index 00000000..2f7754dd --- /dev/null +++ b/progress/example_test.go @@ -0,0 +1,83 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package progress_test + +import ( + "crypto/rand" + "fmt" + "io" + + "oras.land/oras-go/v2/progress" +) + +// ExampleTrackReader demonstrates how to track the transmission progress of a +// reader. +func ExampleTrackReader() { + // Set up a progress tracker. + total := int64(11) + tracker := progress.TrackerFunc(func(status progress.Status, err error) error { + if err != nil { + fmt.Printf("Error: %v\n", err) + return nil + } + switch status.State { + case progress.StateInitialized: + fmt.Println("Start reading content") + case progress.StateTransmitting: + fmt.Printf("Progress: %d/%d bytes\n", status.Offset, total) + case progress.StateTransmitted: + fmt.Println("Finish reading content") + default: + // Ignore other states. + } + return nil + }) + // Close takes no effect for TrackerFunc but should be called for general + // Tracker implementations. + defer tracker.Close() + + // Wrap a reader of a random content generator with the progress tracker. + r := io.LimitReader(rand.Reader, total) + rc := progress.TrackReader(tracker, r) + + // Start tracking the transmission. + if err := progress.Start(tracker); err != nil { + panic(err) + } + + // Read from the random content generator and discard the content, while + // tracking the progress. + // Note: io.Discard is wrapped with a io.MultiWriter for dropping + // the io.ReadFrom interface for demonstration purposes. + buf := make([]byte, 3) + w := io.MultiWriter(io.Discard) + if _, err := io.CopyBuffer(w, rc, buf); err != nil { + panic(err) + } + + // Finish tracking the transmission. + if err := progress.Done(tracker); err != nil { + panic(err) + } + + // Output: + // Start reading content + // Progress: 3/11 bytes + // Progress: 6/11 bytes + // Progress: 9/11 bytes + // Progress: 11/11 bytes + // Finish reading content +} diff --git a/progress/manager.go b/progress/manager.go new file mode 100644 index 00000000..439b90d1 --- /dev/null +++ b/progress/manager.go @@ -0,0 +1,48 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package progress tracks the status of descriptors being processed. +package progress + +import ( + "io" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" +) + +// Manager tracks the progress of multiple descriptors. +type Manager interface { + io.Closer + + // Track starts tracking the progress of a descriptor. + Track(desc ocispec.Descriptor) (Tracker, error) +} + +// ManagerFunc is an adapter to allow the use of ordinary functions as Managers. +// If f is a function with the appropriate signature, ManagerFunc(f) is a +// [Manager] that calls f. +type ManagerFunc func(desc ocispec.Descriptor, status Status, err error) error + +// Close closes the manager. +func (f ManagerFunc) Close() error { + return nil +} + +// Track starts tracking the progress of a descriptor. +func (f ManagerFunc) Track(desc ocispec.Descriptor) (Tracker, error) { + return TrackerFunc(func(status Status, err error) error { + return f(desc, status, err) + }), nil +} diff --git a/progress/status.go b/progress/status.go new file mode 100644 index 00000000..e6c4d1cb --- /dev/null +++ b/progress/status.go @@ -0,0 +1,40 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package progress + +// State represents the state of a descriptor. +type State int + +// Registered states. +const ( + StateUnknown State = iota // unknown state + StateInitialized // progress initialized + StateTransmitting // transmitting content + StateTransmitted // content transmitted + StateExists // content exists + StateSkipped // content skipped + StateMounted // content mounted +) + +// Status represents the status of a descriptor. +type Status struct { + // State represents the state of the descriptor. + State State + + // Offset represents the current offset of the descriptor. + // Offset is discarded if set to a negative value. + Offset int64 +} diff --git a/progress/tracker.go b/progress/tracker.go new file mode 100644 index 00000000..0431a767 --- /dev/null +++ b/progress/tracker.go @@ -0,0 +1,166 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package progress + +import "io" + +// Tracker updates the status of a descriptor. +type Tracker interface { + io.Closer + + // Update updates the status of the descriptor. + Update(status Status) error + + // Fail marks the descriptor as failed. + // Fail should return nil on successful failure marking. + Fail(err error) error +} + +// TrackerFunc is an adapter to allow the use of ordinary functions as Trackers. +// If f is a function with the appropriate signature, TrackerFunc(f) is a +// [Tracker] that calls f. +type TrackerFunc func(status Status, err error) error + +// Close closes the tracker. +func (f TrackerFunc) Close() error { + return nil +} + +// Update updates the status of the descriptor. +func (f TrackerFunc) Update(status Status) error { + return f(status, nil) +} + +// Fail marks the descriptor as failed. +func (f TrackerFunc) Fail(err error) error { + return f(Status{}, err) +} + +// Start starts tracking the transmission. +func Start(t Tracker) error { + return t.Update(Status{ + State: StateInitialized, + Offset: -1, + }) +} + +// Done marks the transmission as complete. +// Done should be called after the transmission is complete. +// Note: Reading all content from the reader does not imply the transmission is +// complete. +func Done(t Tracker) error { + return t.Update(Status{ + State: StateTransmitted, + Offset: -1, + }) +} + +// TrackReader bind a reader with a tracker. +func TrackReader(t Tracker, r io.Reader) io.Reader { + rt := readTracker{ + base: r, + tracker: t, + } + if _, ok := r.(io.WriterTo); ok { + return &readTrackerWriteTo{rt} + } + return &rt +} + +// readTracker tracks the transmission based on the read operation. +type readTracker struct { + base io.Reader + tracker Tracker + offset int64 +} + +// Read reads from the base reader and updates the status. +// On partial read, the tracker treats it as two reads: a successful read with +// status update and a failed read with failure report. +func (rt *readTracker) Read(p []byte) (int, error) { + n, err := rt.base.Read(p) + rt.offset += int64(n) + if n > 0 { + if updateErr := rt.tracker.Update(Status{ + State: StateTransmitting, + Offset: rt.offset, + }); updateErr != nil { + err = updateErr + } + } + if err != nil && err != io.EOF { + if failErr := rt.tracker.Fail(err); failErr != nil { + return n, failErr + } + } + return n, err +} + +// readTrackerWriteTo is readTracker with WriteTo support. +type readTrackerWriteTo struct { + readTracker +} + +// WriteTo writes to the base writer and updates the status. +// On partial write, the tracker treats it as two writes: a successful write +// with status update and a failed write with failure report. +func (rt *readTrackerWriteTo) WriteTo(w io.Writer) (int64, error) { + wt := &writeTracker{ + base: w, + tracker: rt.tracker, + offset: rt.offset, + } + n, err := rt.base.(io.WriterTo).WriteTo(wt) + rt.offset = wt.offset + if err != nil && wt.trackerErr == nil { + if failErr := rt.tracker.Fail(err); failErr != nil { + return n, failErr + } + } + return n, err +} + +// writeTracker tracks the transmission based on the write operation. +type writeTracker struct { + base io.Writer + tracker Tracker + offset int64 + trackerErr error +} + +// Write writes to the base writer and updates the status. +// On partial write, the tracker treats it as two writes: a successful write +// with status update and a failed write with failure report. +func (wt *writeTracker) Write(p []byte) (int, error) { + n, err := wt.base.Write(p) + wt.offset += int64(n) + if n > 0 { + if updateErr := wt.tracker.Update(Status{ + State: StateTransmitting, + Offset: wt.offset, + }); updateErr != nil { + wt.trackerErr = updateErr + err = updateErr + } + } + if err != nil { + if failErr := wt.tracker.Fail(err); failErr != nil { + wt.trackerErr = failErr + return n, failErr + } + } + return n, err +} diff --git a/progress/tracker_test.go b/progress/tracker_test.go new file mode 100644 index 00000000..f4190cc0 --- /dev/null +++ b/progress/tracker_test.go @@ -0,0 +1,414 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package progress + +import ( + "bytes" + "errors" + "io" + "testing" +) + +func TestTrackerFunc_Close(t *testing.T) { + var f TrackerFunc + if err := f.Close(); err != nil { + t.Errorf("TrackerFunc.Close() error = %v, wantErr false", err) + } +} + +func TestTrackerFunc_Update(t *testing.T) { + wantStatus := Status{ + State: StateTransmitted, + Offset: 42, + } + var wantErr error + tracker := TrackerFunc(func(status Status, err error) error { + if status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != nil { + t.Errorf("TrackerFunc err = %v, want nil", err) + } + return wantErr + }) + + if err := tracker.Update(wantStatus); err != wantErr { + t.Errorf("TrackerFunc.Update() error = %v, want %v", err, wantErr) + } + + wantErr = errors.New("fail to track") + if err := tracker.Update(wantStatus); err != wantErr { + t.Errorf("TrackerFunc.Update() error = %v, want %v", err, wantErr) + } +} + +func TestTrackerFunc_Fail(t *testing.T) { + reportErr := errors.New("fail to process") + var wantStatus Status + var wantErr error + tracker := TrackerFunc(func(status Status, err error) error { + if status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != reportErr { + t.Errorf("TrackerFunc err = %v, want %v", err, reportErr) + } + return wantErr + }) + + if err := tracker.Fail(reportErr); err != wantErr { + t.Errorf("TrackerFunc.Fail() error = %v, want %v", err, wantErr) + } + + wantErr = errors.New("fail to track") + if err := tracker.Fail(reportErr); err != wantErr { + t.Errorf("TrackerFunc.Fail() error = %v, want %v", err, wantErr) + } +} + +func TestStart(t *testing.T) { + tests := []struct { + name string + t Tracker + wantErr bool + }{ + { + name: "successful report initialization", + t: TrackerFunc(func(status Status, err error) error { + if status.State != StateInitialized { + t.Errorf("expected state to be StateInitialized, got %v", status.State) + } + return nil + }), + }, + { + name: "fail to report initialization", + t: TrackerFunc(func(status Status, err error) error { + return errors.New("fail to track") + }), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := Start(tt.t); (err != nil) != tt.wantErr { + t.Errorf("Start() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDone(t *testing.T) { + tests := []struct { + name string + t Tracker + wantErr bool + }{ + { + name: "successful report initialization", + t: TrackerFunc(func(status Status, err error) error { + if status.State != StateTransmitted { + t.Errorf("expected state to be StateTransmitted, got %v", status.State) + } + return nil + }), + }, + { + name: "fail to report initialization", + t: TrackerFunc(func(status Status, err error) error { + return errors.New("fail to track") + }), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := Done(tt.t); (err != nil) != tt.wantErr { + t.Errorf("Done() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestTrackReader(t *testing.T) { + const bufSize = 6 + content := []byte("hello world") + t.Run("track io.Reader", func(t *testing.T) { + var wantStatus Status + tracker := TrackerFunc(func(status Status, err error) error { + if status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != nil { + t.Errorf("TrackerFunc err = %v, want nil", err) + } + return nil + }) + var reader io.Reader = bytes.NewReader(content) + reader = io.LimitReader(reader, int64(len(content))) // remove the io.WriterTo interface + gotReader := TrackReader(tracker, reader) + if _, ok := gotReader.(*readTracker); !ok { + t.Fatalf("TrackReader() = %v, want *readTracker", gotReader) + } + + wantStatus = Status{ + State: StateTransmitting, + Offset: bufSize, + } + buf := make([]byte, bufSize) + n, err := gotReader.Read(buf) + if err != nil { + t.Fatalf("TrackReader() error = %v, want nil", err) + } + if n != bufSize { + t.Fatalf("TrackReader() n = %v, want %v", n, bufSize) + } + if want := content[:bufSize]; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + + wantStatus = Status{ + State: StateTransmitting, + Offset: int64(len(content)), + } + n, err = gotReader.Read(buf) + if err != nil { + t.Fatalf("TrackReader() error = %v, want nil", err) + } + if want := len(content) - bufSize; n != want { + t.Fatalf("TrackReader() n = %v, want %v", n, want) + } + buf = buf[:n] + if want := content[bufSize:]; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + }) + + t.Run("track io.Reader + io.WriterTo", func(t *testing.T) { + var wantStatus Status + tracker := TrackerFunc(func(status Status, err error) error { + if status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != nil { + t.Errorf("TrackerFunc err = %v, want nil", err) + } + return nil + }) + var reader io.Reader = bytes.NewReader(content) + gotReader := TrackReader(tracker, reader) + if _, ok := gotReader.(*readTrackerWriteTo); !ok { + t.Fatalf("TrackReader() = %v, want *readTrackerWriteTo", gotReader) + } + + wantStatus = Status{ + State: StateTransmitting, + Offset: bufSize, + } + buf := make([]byte, bufSize) + n, err := gotReader.Read(buf) + if err != nil { + t.Fatalf("TrackReader() error = %v, want nil", err) + } + if n != bufSize { + t.Fatalf("TrackReader() n = %v, want %v", n, bufSize) + } + if want := content[:bufSize]; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + + wantStatus = Status{ + State: StateTransmitting, + Offset: int64(len(content)), + } + writeBuf := bytes.NewBuffer(nil) + wn, err := gotReader.(io.WriterTo).WriteTo(writeBuf) + if err != nil { + t.Fatalf("TrackReader() error = %v, want nil", err) + } + if want := len(content) - bufSize; wn != int64(want) { + t.Fatalf("TrackReader() n = %v, want %v", wn, want) + } + buf = writeBuf.Bytes() + if want := content[bufSize:]; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + }) + + t.Run("empty io.Reader", func(t *testing.T) { + tracker := TrackerFunc(func(status Status, err error) error { + t.Errorf("TrackerFunc should not be called for empty read") + return nil + }) + gotReader := TrackReader(tracker, bytes.NewReader(nil)) + + buf := make([]byte, bufSize) + n, err := gotReader.Read(buf) + if want := io.EOF; err != want { + t.Fatalf("TrackReader() error = %v, want %v", err, want) + } + if want := 0; n != want { + t.Fatalf("TrackReader() n = %v, want %v", n, want) + } + + writeBuf := bytes.NewBuffer(nil) + wn, err := gotReader.(io.WriterTo).WriteTo(writeBuf) + if err != nil { + t.Fatalf("TrackReader() error = %v, want nil", err) + } + if want := int64(0); wn != want { + t.Fatalf("TrackReader() n = %v, want %v", wn, want) + } + buf = writeBuf.Bytes() + if want := []byte{}; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + }) + + t.Run("report failure", func(t *testing.T) { + var wantStatus Status + wantErr := errors.New("fail to track") + trackerMockStage := 0 + tracker := TrackerFunc(func(status Status, err error) error { + defer func() { + trackerMockStage++ + }() + switch trackerMockStage { + case 0: + if status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != nil { + t.Errorf("TrackerFunc err = %v, want nil", err) + } + return wantErr + case 1: + var emptyStatus Status + if wantStatus := emptyStatus; status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != wantErr { + t.Errorf("TrackerFunc err = %v, want %v", err, wantErr) + } + return nil + default: + t.Errorf("TrackerFunc should not be called") + return nil + } + }) + gotReader := TrackReader(tracker, bytes.NewReader(content)) + + wantStatus = Status{ + State: StateTransmitting, + Offset: bufSize, + } + buf := make([]byte, bufSize) + n, err := gotReader.Read(buf) + if err != wantErr { + t.Fatalf("TrackReader() error = %v, want %v", err, wantErr) + } + if n != bufSize { + t.Fatalf("TrackReader() n = %v, want %v", n, bufSize) + } + if want := content[:bufSize]; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + + wantStatus = Status{ + State: StateTransmitting, + Offset: int64(len(content)), + } + trackerMockStage = 0 + writeBuf := bytes.NewBuffer(nil) + wn, err := gotReader.(io.WriterTo).WriteTo(writeBuf) + if err != wantErr { + t.Fatalf("TrackReader() error = %v, want %v", err, wantErr) + } + if want := len(content) - bufSize; wn != int64(want) { + t.Fatalf("TrackReader() n = %v, want %v", wn, want) + } + buf = writeBuf.Bytes() + if want := content[bufSize:]; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + }) + + t.Run("process failure", func(t *testing.T) { + reportErr := io.ErrClosedPipe + var wantStatus Status + var wantErr error + tracker := TrackerFunc(func(status Status, err error) error { + if status != wantStatus { + t.Errorf("TrackerFunc status = %v, want %v", status, wantStatus) + } + if err != reportErr { + t.Errorf("TrackerFunc err = %v, want %v", err, reportErr) + } + return wantErr + }) + pipeReader, pipeWriter := io.Pipe() + pipeReader.Close() + pipeWriter.Close() + gotReader := TrackReader(tracker, pipeReader) + + buf := make([]byte, bufSize) + n, err := gotReader.Read(buf) + if err != reportErr { + t.Fatalf("TrackReader() error = %v, want %v", err, reportErr) + } + if want := 0; n != want { + t.Fatalf("TrackReader() n = %v, want %v", n, want) + } + + wantErr = errors.New("fail to track") + n, err = gotReader.Read(buf) + if err != wantErr { + t.Fatalf("TrackReader() error = %v, want %v", err, wantErr) + } + if want := 0; n != want { + t.Fatalf("TrackReader() n = %v, want %v", n, want) + } + + gotReader = TrackReader(tracker, io.MultiReader(pipeReader)) // wrap io.WriteTo + wantErr = nil + writeBuf := bytes.NewBuffer(nil) + wn, err := gotReader.(io.WriterTo).WriteTo(writeBuf) + if err != reportErr { + t.Fatalf("TrackReader() error = %v, want %v", err, reportErr) + } + if want := int64(0); wn != want { + t.Fatalf("TrackReader() n = %v, want %v", wn, want) + } + buf = writeBuf.Bytes() + if want := []byte{}; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + + gotReader = TrackReader(tracker, io.MultiReader(pipeReader)) // wrap io.WriteTo + wantErr = errors.New("fail to track") + wn, err = gotReader.(io.WriterTo).WriteTo(writeBuf) + if err != wantErr { + t.Fatalf("TrackReader() error = %v, want %v", err, wantErr) + } + if want := int64(0); wn != want { + t.Fatalf("TrackReader() n = %v, want %v", wn, want) + } + buf = writeBuf.Bytes() + if want := []byte{}; !bytes.Equal(buf, want) { + t.Fatalf("TrackReader() buf = %v, want %v", buf, want) + } + }) +}