From 474ed09607c6b122d88bddfb9cbb94f625502649 Mon Sep 17 00:00:00 2001 From: Zherphy <1123678689@qq.com> Date: Wed, 11 Dec 2024 17:17:25 +0800 Subject: [PATCH] add: add server_test.go MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 增加单元测试 Test_server_handleBatch --- server/server_test.go | 178 +++++++++++++++++++++++++++++++++++------- 1 file changed, 151 insertions(+), 27 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index eba780a..af40dd7 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1,17 +1,19 @@ package server import ( + "context" "encoding/base64" "errors" "fmt" + "github.com/go-chi/chi" "github.com/huaweicloud/huaweicloud-sdk-go-obs/obs" "github.com/metalogical/BigFiles/auth" "github.com/metalogical/BigFiles/batch" "net/http" "net/http/httptest" - "net/url" "reflect" "regexp" + "strings" "testing" "time" ) @@ -27,10 +29,10 @@ type ServerInfo struct { var serverInfo = ServerInfo{ ttl: time.Hour, - isAuthorized: auth.GiteeAuth(), bucket: "Bucket", prefix: "Prefix", cdnDomain: "CDNDomain", + isAuthorized: auth.GiteeAuth(), } func TestNew(t *testing.T) { @@ -253,7 +255,6 @@ func Test_server_dealWithAuthError(t *testing.T) { args args wantErr bool }{ - // TODO: Add test cases. { name: "deal with auth without username and password", fields: serverInfo, @@ -306,11 +307,26 @@ func Test_server_downloadObject(t *testing.T) { out *batch.Object } tests := []struct { - name string - fields ServerInfo - args args + name string + fields ServerInfo + args args + wantErr bool }{ - // TODO: Add test cases. + { + name: "download object failed", + fields: serverInfo, + args: args{ + in: &batch.RequestObject{ + OID: "123456789", + Size: 100, + }, + out: &batch.Object{ + OID: "123456789", + Size: 100, + }, + }, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -322,6 +338,19 @@ func Test_server_downloadObject(t *testing.T) { cdnDomain: tt.fields.cdnDomain, isAuthorized: tt.fields.isAuthorized, } + defer func() { + if r := recover(); r != nil { + // 如果捕获到了panic,检查错误信息是否符合预期 + _, ok := r.(error) + if ok && tt.wantErr { + return + } else { + t.Errorf("unexpected panic value or wantErr mismatch") + } + } else if tt.wantErr { + t.Errorf("expected panic but none occurred") + } + }() s.downloadObject(tt.args.in, tt.args.out) }) } @@ -332,12 +361,26 @@ func Test_server_generateDownloadUrl(t *testing.T) { getObjectInput *obs.CreateSignedUrlInput } tests := []struct { - name string - fields ServerInfo - args args - want *url.URL + name string + fields ServerInfo + args args + wantErr bool }{ // TODO: Add test cases. + { + name: "generate download url", + fields: serverInfo, + args: args{ + getObjectInput: &obs.CreateSignedUrlInput{ + Method: obs.HttpMethodGet, + Bucket: serverInfo.bucket, + Key: "123456789", + Expires: int(serverInfo.ttl / time.Second), + Headers: map[string]string{contentType: "application/octet-stream"}, + }, + }, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -349,8 +392,21 @@ func Test_server_generateDownloadUrl(t *testing.T) { cdnDomain: tt.fields.cdnDomain, isAuthorized: tt.fields.isAuthorized, } - if got := s.generateDownloadUrl(tt.args.getObjectInput); !reflect.DeepEqual(got, tt.want) { - t.Errorf("generateDownloadUrl() = %v, want %v", got, tt.want) + defer func() { + if r := recover(); r != nil { + // 如果捕获到了panic,检查错误信息是否符合预期 + _, ok := r.(error) + if ok && tt.wantErr { + return + } else { + t.Errorf("unexpected panic value or wantErr mismatch") + } + } else if tt.wantErr { + t.Errorf("expected panic but none occurred") + } + }() + if got := s.generateDownloadUrl(tt.args.getObjectInput); got != nil { + t.Errorf("generateDownloadUrl() = %v", got) } }) } @@ -361,13 +417,19 @@ func Test_server_getObjectMetadataInput(t *testing.T) { key string } tests := []struct { - name string - fields ServerInfo - args args - wantOutput *obs.GetObjectMetadataOutput - wantErr bool + name string + fields ServerInfo + args args + wantErr bool }{ - // TODO: Add test cases. + { + name: "getObjectMetadataInput success", + fields: serverInfo, + args: args{ + key: "123456789", + }, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -379,14 +441,24 @@ func Test_server_getObjectMetadataInput(t *testing.T) { cdnDomain: tt.fields.cdnDomain, isAuthorized: tt.fields.isAuthorized, } - gotOutput, err := s.getObjectMetadataInput(tt.args.key) + defer func() { + if r := recover(); r != nil { + // 如果捕获到了panic,检查错误信息是否符合预期 + _, ok := r.(error) + if ok && tt.wantErr { + return + } else { + t.Errorf("unexpected panic value or wantErr mismatch") + } + } else if tt.wantErr { + t.Errorf("expected panic but none occurred") + } + }() + _, err := s.getObjectMetadataInput(tt.args.key) if (err != nil) != tt.wantErr { t.Errorf("getObjectMetadataInput() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(gotOutput, tt.wantOutput) { - t.Errorf("getObjectMetadataInput() gotOutput = %v, want %v", gotOutput, tt.wantOutput) - } }) } } @@ -396,12 +468,49 @@ func Test_server_handleBatch(t *testing.T) { w http.ResponseWriter r *http.Request } + requestBodyText := `{ + "operation": "download", + "objects": [ + { + "oid": "123456", + "Size": 100 + } + ] + }` + requestBody := strings.NewReader(requestBodyText) + owner := "test_owner" + repo := "test_repo" + // 创建一个带有路径参数的请求路径,这里将owner作为路径参数添加到URL中 + requestPath := "/{owner}/{repo}/objects/batch" + req := httptest.NewRequest(http.MethodGet, requestPath, requestBody) + ctx := chi.NewRouteContext() + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, ctx)) + ctx.URLParams.Add("owner", owner) + ctx.URLParams.Add("repo", repo) + validatecfg.ownerRegexp, _ = regexp.Compile("^[a-zA-Z]([-_.]?[a-zA-Z0-9]+)*$") + validatecfg.reponameRegexp, _ = regexp.Compile("^[a-zA-Z0-9_.-]{1,189}[a-zA-Z0-9]$") tests := []struct { - name string - fields ServerInfo - args args + name string + fields ServerInfo + args args + wantErr bool }{ - // TODO: Add test cases. + { + name: "server handleBatch success with nil requestBody", + fields: serverInfo, + args: args{ + r: httptest.NewRequest(http.MethodGet, "/owner/repo/objects/batch", nil), + }, + wantErr: false, + }, + { + name: "server handleBatch success", + fields: serverInfo, + args: args{ + r: req, + }, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -413,6 +522,21 @@ func Test_server_handleBatch(t *testing.T) { cdnDomain: tt.fields.cdnDomain, isAuthorized: tt.fields.isAuthorized, } + w := httptest.NewRecorder() + tt.args.w = w + defer func() { + if r := recover(); r != nil { + // 如果捕获到了panic,检查错误信息是否符合预期 + _, ok := r.(error) + if ok && tt.wantErr { + return + } else { + t.Errorf("unexpected panic value or wantErr mismatch") + } + } else if tt.wantErr { + t.Errorf("expected panic but none occurred") + } + }() s.handleBatch(tt.args.w, tt.args.r) }) }