diff --git a/entity.go b/entity.go index e2bd80c4..a4cad9d2 100644 --- a/entity.go +++ b/entity.go @@ -16,6 +16,10 @@ type Entity struct { mediaType string mediaParams map[string]string + + originalBody io.Reader + encoding string + charset string } // New makes a new message with the provided header and body. The entity's @@ -25,8 +29,10 @@ type Entity struct { // error that verifies IsUnknownCharset, but also returns an Entity that can // be read. func New(header Header, body io.Reader) (*Entity, error) { - var err error + originalBody := body + var encoding, charset string + var err error mediaType, mediaParams, _ := header.ContentType() // QUIRK: RFC 2045 section 6.4 specifies that multipart messages can't have @@ -35,8 +41,8 @@ func New(header Header, body io.Reader) (*Entity, error) { // e.g. "quoted-printable". So we just ignore it for multipart. // See https://github.com/emersion/go-message/issues/48 if !strings.HasPrefix(mediaType, "multipart/") { - enc := header.Get("Content-Transfer-Encoding") - if decoded, encErr := encodingReader(enc, body); encErr != nil { + encoding = header.Get("Content-Transfer-Encoding") + if decoded, encErr := encodingReader(encoding, body); encErr != nil { err = unknownEncodingError{encErr} } else { body = decoded @@ -45,8 +51,9 @@ func New(header Header, body io.Reader) (*Entity, error) { // RFC 2046 section 4.1.2: charset only applies to text/* if strings.HasPrefix(mediaType, "text/") { - if ch, ok := mediaParams["charset"]; ok { - if converted, charsetErr := charsetReader(ch, body); charsetErr != nil { + var ok bool + if charset, ok = mediaParams["charset"]; ok { + if converted, charsetErr := charsetReader(charset, body); charsetErr != nil { err = unknownCharsetError{charsetErr} } else { body = converted @@ -55,10 +62,13 @@ func New(header Header, body io.Reader) (*Entity, error) { } return &Entity{ - Header: header, - Body: body, - mediaType: mediaType, - mediaParams: mediaParams, + Header: header, + Body: body, + mediaType: mediaType, + mediaParams: mediaParams, + originalBody: originalBody, + encoding: encoding, + charset: charset, }, err } @@ -111,6 +121,8 @@ func (e *Entity) writeBodyTo(w *Writer) error { var err error if mb, ok := e.Body.(*multipartBody); ok { err = mb.writeBodyTo(w) + } else if w.encoding == e.encoding && w.charset == e.charset { + _, err = io.Copy(w.rawWriter, e.originalBody) } else { _, err = io.Copy(w, e.Body) } diff --git a/entity_test.go b/entity_test.go index 775b61ce..b0d8c9cf 100644 --- a/entity_test.go +++ b/entity_test.go @@ -150,6 +150,24 @@ func TestRead_single(t *testing.T) { } } +func TestEntity_WriteTo_original(t *testing.T) { + e := testMakeEntity() + + var b bytes.Buffer + if err := e.WriteTo(&b); err != nil { + t.Fatal("Expected no error while writing entity, got", err) + } + + expected := "Content-Transfer-Encoding: base64\r\n" + + "Content-Type: text/plain; charset=US-ASCII\r\n" + + "\r\n" + + "Y2Mgc2F2YQ==" + + if s := b.String(); s != expected { + t.Errorf("Expected written entity to be:\n%s\nbut got:\n%s", expected, s) + } +} + func TestEntity_WriteTo_decode(t *testing.T) { e := testMakeEntity() diff --git a/writer.go b/writer.go index 21c3cd3d..3f377915 100644 --- a/writer.go +++ b/writer.go @@ -21,12 +21,16 @@ type Writer struct { w io.Writer c io.Closer mw *textproto.MultipartWriter + + rawWriter io.Writer + encoding string + charset string } // createWriter creates a new Writer writing to w with the provided header. // Nothing is written to w when it is called. header is modified in-place. func createWriter(w io.Writer, header *Header) (*Writer, error) { - ww := &Writer{w: w} + ww := &Writer{w: w, rawWriter: w} mediaType, mediaParams, _ := header.ContentType() if strings.HasPrefix(mediaType, "multipart/") { @@ -46,7 +50,8 @@ func createWriter(w io.Writer, header *Header) (*Writer, error) { header.Del("Content-Transfer-Encoding") } else { - wc, err := encodingWriter(header.Get("Content-Transfer-Encoding"), ww.w) + ww.encoding = header.Get("Content-Transfer-Encoding") + wc, err := encodingWriter(ww.encoding, ww.w) if err != nil { return nil, err } @@ -54,12 +59,13 @@ func createWriter(w io.Writer, header *Header) (*Writer, error) { ww.c = wc } - switch strings.ToLower(mediaParams["charset"]) { + ww.charset = mediaParams["charset"] + switch strings.ToLower(ww.charset) { case "", "us-ascii", "utf-8": // This is OK default: // Anything else is invalid - return nil, fmt.Errorf("unhandled charset %q", mediaParams["charset"]) + return nil, fmt.Errorf("unhandled charset %q", ww.charset) } return ww, nil