diff --git a/Sources/MultipartKit/MultipartParser+parse.swift b/Sources/MultipartKit/MultipartParser+parse.swift index 98c689c..b303e3e 100644 --- a/Sources/MultipartKit/MultipartParser+parse.swift +++ b/Sources/MultipartKit/MultipartParser+parse.swift @@ -1,16 +1,16 @@ import HTTPTypes extension MultipartParser { - public static func parse(_ data: some Collection, boundary: some Collection) throws -> [MultipartPart>] { + public func parse(_ data: some Collection) throws -> [MultipartPart>] { var output: [MultipartPart>] = [] - var parser = MultipartParser(boundary: boundary) - + var parser = MultipartParser(boundary: self.boundary) + var currentHeaders: HTTPFields? var currentBody = ArraySlice() // Append data to the parser and process the sections parser.append(buffer: data) - + while true { switch parser.read() { case .success(let optionalPart): diff --git a/Sources/MultipartKit/MultipartParser.swift b/Sources/MultipartKit/MultipartParser.swift index 8ac7ab7..0d79f7c 100644 --- a/Sources/MultipartKit/MultipartParser.swift +++ b/Sources/MultipartKit/MultipartParser.swift @@ -21,13 +21,13 @@ public struct MultipartParser { let boundary: ArraySlice private var state: State - + init(boundary: some Collection) { self.boundary = .init([45, 45] + boundary) self.state = .initial } - init(boundary: String) { + public init(boundary: String) { self.boundary = [45, 45] + ArraySlice(boundary.utf8) self.state = .initial } diff --git a/Sources/MultipartKit/MultipartPart.swift b/Sources/MultipartKit/MultipartPart.swift index 9522cec..e5a16f4 100644 --- a/Sources/MultipartKit/MultipartPart.swift +++ b/Sources/MultipartKit/MultipartPart.swift @@ -1,11 +1,18 @@ import HTTPTypes -public struct MultipartPart>: Equatable, Sendable where Body: Sendable & Equatable { - public let headerFields: HTTPFields +public typealias MultipartPartBodyElement = Collection & Equatable & Sendable + +public struct MultipartPart: Equatable, Sendable { + public var headerFields: HTTPFields public var body: Body - + public init(headerFields: HTTPFields, body: Body) { self.headerFields = headerFields self.body = body } + + public var name: String? { + get { self.headerFields.getParameter(.contentDisposition, "name") } + set { self.headerFields.setParameter(.contentDisposition, "name", to: newValue, defaultValue: "form-data") } + } } diff --git a/Sources/MultipartKit/MultipartSerializer.swift b/Sources/MultipartKit/MultipartSerializer.swift index 9db78b7..c918771 100644 --- a/Sources/MultipartKit/MultipartSerializer.swift +++ b/Sources/MultipartKit/MultipartSerializer.swift @@ -1,7 +1,14 @@ /// Serializes `MultipartForm`s to `Data`. /// /// See `MultipartParser` for more information about the multipart encoding. -public enum MultipartSerializer: Sendable { +public struct MultipartSerializer: Sendable { + let boundary: String + + /// Creates a new `MultipartSerializer`. + init(boundary: String) { + self.boundary = boundary + } + /// Serializes the `MultipartForm` to data. /// /// let data = try MultipartSerializer().serialize(parts: [part], boundary: "123") @@ -12,9 +19,9 @@ public enum MultipartSerializer: Sendable { /// - boundary: Multipart boundary to use for encoding. This must not appear anywhere in the encoded data. /// - throws: Any errors that may occur during serialization. /// - returns: `multipart`-encoded `Data`. - public static func serialize(parts: [MultipartPart>], boundary: String) throws -> String { + public func serialize(parts: [MultipartPart>]) throws -> String { var buffer = [UInt8]() - try self.serialize(parts: parts, boundary: boundary, into: &buffer) + try self.serialize(parts: parts, into: &buffer) return String(decoding: buffer, as: UTF8.self) } @@ -29,7 +36,7 @@ public enum MultipartSerializer: Sendable { /// - boundary: Multipart boundary to use for encoding. This must not appear anywhere in the encoded data. /// - buffer: Buffer to write to. /// - throws: Any errors that may occur during serialization. - public static func serialize(parts: [MultipartPart>], boundary: String, into buffer: inout [UInt8]) throws { + public func serialize(parts: [MultipartPart>], into buffer: inout [UInt8]) throws { for part in parts { buffer.append(contentsOf: Array("--\(boundary)\r\n".utf8)) for field in part.headerFields { diff --git a/Sources/MultipartKit/Utilities.swift b/Sources/MultipartKit/Utilities.swift new file mode 100644 index 0000000..ad0727c --- /dev/null +++ b/Sources/MultipartKit/Utilities.swift @@ -0,0 +1,50 @@ +import Foundation +import HTTPTypes + +extension HTTPFields { + func getParameter(_ name: HTTPField.Name, _ key: String) -> String? { + headerParts(name: name)? + .filter { $0.contains("\(key)=") } + .first? + .split(separator: "=") + .last? + .trimmingCharacters(in: .quotes) + } + + mutating func setParameter( + _ name: HTTPField.Name, + _ key: String, + to value: String?, + defaultValue: String + ) { + var current: [String] + + if let existing = self.headerParts(name: name) { + current = existing.filter { !$0.hasPrefix("\(key)=") } + } else { + current = [defaultValue] + } + + if let value = value { + current.append("\(key)=\"\(value)\"") + } + + let new = current.joined(separator: "; ").trimmingCharacters(in: .whitespaces) + + self[name] = new + } + + func headerParts(name: HTTPField.Name) -> [String]? { + self[name] + .flatMap { + $0.split(separator: ";") + .map { $0.trimmingCharacters(in: .whitespaces) } + } + } +} + +extension CharacterSet { + static var quotes: CharacterSet { + return .init(charactersIn: #""'"#) + } +} diff --git a/Tests/MultipartKitTests/ParserTests.swift b/Tests/MultipartKitTests/ParserTests.swift index 59ad4c1..fcabc31 100644 --- a/Tests/MultipartKitTests/ParserTests.swift +++ b/Tests/MultipartKitTests/ParserTests.swift @@ -126,7 +126,7 @@ struct ParserTests { } } } - + @Test("Parse Synchronously") func parseSynchronously() async throws { let boundary = "boundary123" @@ -138,14 +138,17 @@ struct ParserTests { 123e4567-e89b-12d3-a456-426655440000\r \(boundary)-- """ - - let parts = try MultipartParser.parse([UInt8](message.utf8), boundary: [UInt8](boundary.utf8)) - + + let parts = try MultipartParser(boundary: boundary) + .parse([UInt8](message.utf8)) + #expect(parts.count == 1) - #expect(parts[0].headerFields == .init([ - .init(name: .contentDisposition, value: "form-data; name=\"id\""), - .init(name: .contentType, value: "text/plain"), - ])) + #expect( + parts[0].headerFields + == .init([ + .init(name: .contentDisposition, value: "form-data; name=\"id\""), + .init(name: .contentType, value: "text/plain"), + ])) #expect(parts[0].body == ArraySlice("123e4567-e89b-12d3-a456-426655440000".utf8)) } diff --git a/Tests/MultipartKitTests/SerializerTests.swift b/Tests/MultipartKitTests/SerializerTests.swift index 2ae08de..0a9ecd0 100644 --- a/Tests/MultipartKitTests/SerializerTests.swift +++ b/Tests/MultipartKitTests/SerializerTests.swift @@ -13,9 +13,9 @@ struct SerializerTests { .init(name: .contentType, value: "text/plain"), ]), body: ArraySlice("Hello, world!".utf8) - ), + ) ] - + let serialized = try MultipartSerializer.serialize(parts: example, boundary: "boundary123") } }