From 5ba23d8be8a7d401d6c5dda0d7b5038e83f301cd Mon Sep 17 00:00:00 2001 From: Shinku Date: Thu, 30 Nov 2023 01:06:06 +0800 Subject: [PATCH] Try to reuse request body on Retry When calling Request.Retry, and request has a body (e.g. POST form or file upload), try to rewind it. If it's not seekable, return ErrRetryBodyUnseekable error. --- colly.go | 12 +++++++- colly_test.go | 79 +++++++++++++++++++++++++++++++++++++++++++++++++++ request.go | 3 ++ 3 files changed, 93 insertions(+), 1 deletion(-) diff --git a/colly.go b/colly.go index 4a957d6c..451b6766 100644 --- a/colly.go +++ b/colly.go @@ -233,6 +233,8 @@ var ( ErrQueueFull = errors.New("Queue MaxSize reached") // ErrMaxRequests is the error returned when exceeding max requests ErrMaxRequests = errors.New("Max Requests limit reached") + // ErrRetryBodyUnseekable is the error when retry with not seekable body + ErrRetryBodyUnseekable = errors.New("Retry Body Unseekable") ) var envMap = map[string]func(*Collector, string){ @@ -629,6 +631,13 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c if _, ok := hdr["User-Agent"]; !ok { hdr.Set("User-Agent", c.UserAgent) } + if seeker, ok := requestData.(io.ReadSeeker); ok { + _, err := seeker.Seek(0, io.SeekStart) + if err != nil { + return err + } + } + req, err := http.NewRequest(method, parsedURL.String(), requestData) if err != nil { return err @@ -1440,7 +1449,8 @@ func createMultipartReader(boundary string, data map[string][]byte) io.Reader { buffer.WriteString("\n") } buffer.WriteString(dashBoundary + "--\n\n") - return buffer + return bytes.NewReader(buffer.Bytes()) + } // randomBoundary was borrowed from diff --git a/colly_test.go b/colly_test.go index e330fc2e..2382ecb1 100644 --- a/colly_test.go +++ b/colly_test.go @@ -1703,3 +1703,82 @@ func requireSessionCookieAuthPage(handler http.Handler) http.Handler { handler.ServeHTTP(w, r) }) } + +func TestCollectorPostRetry(t *testing.T) { + ts := newTestServer() + defer ts.Close() + + 
postValue := "hello" + c := NewCollector() + try := false + c.OnResponse(func(r *Response) { + if r.Ctx.Get("notFirst") == "" { + r.Ctx.Put("notFirst", "first") + _ = r.Request.Retry() + return + } + if postValue != string(r.Body) { + t.Error("Failed to send data with POST") + } + try = true + }) + + c.Post(ts.URL+"/login", map[string]string{ + "name": postValue, + }) + if !try { + t.Error("OnResponse Retry was not called") + } +} +func TestCollectorGetRetry(t *testing.T) { + ts := newTestServer() + defer ts.Close() + try := false + + c := NewCollector() + + c.OnResponse(func(r *Response) { + if r.Ctx.Get("notFirst") == "" { + r.Ctx.Put("notFirst", "first") + _ = r.Request.Retry() + return + } + if !bytes.Equal(r.Body, serverIndexResponse) { + t.Error("Response body does not match with the original content") + } + try = true + }) + + c.Visit(ts.URL) + if !try { + t.Error("OnResponse Retry was not called") + } +} + +func TestCollectorPostRetryUnseekable(t *testing.T) { + ts := newTestServer() + defer ts.Close() + try := false + postValue := "hello" + c := NewCollector() + + c.OnResponse(func(r *Response) { + if postValue != string(r.Body) { + t.Error("Failed to send data with POST") + } + + if r.Ctx.Get("notFirst") == "" { + r.Ctx.Put("notFirst", "first") + err := r.Request.Retry() + if !errors.Is(err, ErrRetryBodyUnseekable) { + t.Errorf("Unexpected error Type ErrRetryBodyUnseekable : %v", err) + } + return + } + try = true + }) + c.Request("POST", ts.URL+"/login", bytes.NewBuffer([]byte("name="+postValue)), nil, nil) + if try { + t.Error("OnResponse Retry was called but BodyUnseekable") + } +} diff --git a/request.go b/request.go index c2c15c76..5c80e2bb 100644 --- a/request.go +++ b/request.go @@ -152,6 +152,9 @@ func (r *Request) PostMultipart(URL string, requestData map[string][]byte) error // Retry submits HTTP request again with the same parameters func (r *Request) Retry() error { r.Headers.Del("Cookie") + if _, ok := r.Body.(io.ReadSeeker); r.Body != nil && 
!ok { + return ErrRetryBodyUnseekable + } return r.collector.scrape(r.URL.String(), r.Method, r.Depth, r.Body, r.Ctx, *r.Headers, false) }