diff --git a/internal/elements.go b/internal/elements.go index db7d960..7fb6d3f 100644 --- a/internal/elements.go +++ b/internal/elements.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/url" + "regexp" "strconv" "strings" "time" @@ -377,10 +378,19 @@ type GetETag struct { type ETag string func (etag *ETag) UnmarshalText(b []byte) error { - s, err := strconv.Unquote(string(b)) + s := string(b) + shouldUnquote, err := regexp.MatchString("^(['\"\\x60])", s) if err != nil { - return fmt.Errorf("webdav: failed to unquote ETag: %v", err) + return fmt.Errorf("webdav: unquote check failed: %v", err) } + + if shouldUnquote { + s, err = strconv.Unquote(s) + if err != nil { + return fmt.Errorf("webdav: failed to unquote ETag: %v", err) + } + } + *etag = ETag(s) return nil } diff --git a/internal/elements_test.go b/internal/elements_test.go index 73a6ef5..3f47c7b 100644 --- a/internal/elements_test.go +++ b/internal/elements_test.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "encoding/binary" "encoding/xml" "strings" "testing" @@ -63,3 +64,69 @@ func TestTimeRoundTrip(t *testing.T) { t.Fatalf("invalid round-trip:\ngot= %s\nwant=%s", got, want) } } + +func TestETag_UnmarshalText(t *testing.T) { + type args struct { + b []byte + } + tests := []struct { + name string + etag ETag + args args + wantErr bool + }{ + { + name: "should unmarshal unquoted string", + etag: "", + args: args{ + b: []byte("hello world"), + }, + wantErr: false, + }, + { + name: "should unmarshal double quoted string", + etag: "", + args: args{ + b: []byte("\"hello world\""), + }, + wantErr: false, + }, + { + name: "shouldn't unmarshal single quoted string", + etag: "", + args: args{ + b: []byte("'hello world'"), + }, + wantErr: true, + }, + { + name: "should unmarshal backward quoted string", + etag: "", + args: args{ + b: []byte("`hello world`"), + }, + wantErr: false, + }, + { + name: "should unmarshal int", + etag: "", + args: args{ + b: func() []byte { + buf := new(bytes.Buffer) + num := 162392347123 + binary.Write(buf, binary.LittleEndian, num) + return buf.Bytes() + }(), + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.etag.UnmarshalText(tt.args.b); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}