diff --git a/src/Rfc6455ClientFactory.php b/src/Rfc6455ClientFactory.php index b5821f8..b93bddd 100644 --- a/src/Rfc6455ClientFactory.php +++ b/src/Rfc6455ClientFactory.php @@ -8,7 +8,7 @@ use Amp\Http\Server\Request; use Amp\Http\Server\Response; use Amp\Socket\Socket; -use Amp\Websocket\Compression\WebsocketCompressionContextFactory; +use Amp\Websocket\Compression\WebsocketCompressionContext; use Amp\Websocket\ConstantRateLimit; use Amp\Websocket\Parser\Rfc6455ParserFactory; use Amp\Websocket\Parser\WebsocketParserFactory; @@ -24,13 +24,10 @@ final class Rfc6455ClientFactory implements WebsocketClientFactory use ForbidSerialization; /** - * @param WebsocketCompressionContextFactory|null $compressionContextFactory Deprecated. This argument is unused. - * Compression is not supported in v3.x but will be in v4.x. * @param WebsocketHeartbeatQueue|null $heartbeatQueue Use null to disable automatic heartbeats (pings). * @param WebsocketRateLimit|null $rateLimit Use null to disable client rate limits. */ public function __construct( - private readonly ?WebsocketCompressionContextFactory $compressionContextFactory = null, private readonly ?WebsocketHeartbeatQueue $heartbeatQueue = new PeriodicHeartbeatQueue(), private readonly ?WebsocketRateLimit $rateLimit = new ConstantRateLimit(), private readonly WebsocketParserFactory $parserFactory = new Rfc6455ParserFactory(), @@ -43,6 +40,7 @@ public function createClient( Request $request, Response $response, Socket $socket, + ?WebsocketCompressionContext $compressionContext, ): WebsocketClient { if ($socket instanceof ResourceStream) { $socketResource = $socket->getResource(); @@ -69,6 +67,7 @@ public function createClient( socket: $socket, masked: false, parserFactory: $this->parserFactory, + compressionContext: $compressionContext, heartbeatQueue: $this->heartbeatQueue, rateLimit: $this->rateLimit, frameSplitThreshold: $this->frameSplitThreshold, diff --git a/src/Rfc7692CompressionNegotiator.php b/src/Rfc7692CompressionNegotiator.php new file mode 100644 index 0000000..305f3d9 --- /dev/null +++ b/src/Rfc7692CompressionNegotiator.php @@ -0,0 +1,34 @@ +compressionContextFactory = new Rfc7692CompressionFactory(); + } + + public function negotiateCompression(Request $request, Response $response): ?WebsocketCompressionContext + { + $extensions = Http\splitHeader($request, 'sec-websocket-extensions') ?? []; + foreach ($extensions as $extension) { + if ($compressionContext = $this->compressionContextFactory->fromClientHeader($extension, $headerLine)) { + /** @psalm-suppress PossiblyNullArgument */ + $response->setHeader('sec-websocket-extensions', $headerLine); + + return $compressionContext; + } + } + + return null; + } +} diff --git a/src/Websocket.php b/src/Websocket.php index c2ef9f0..e195975 100644 --- a/src/Websocket.php +++ b/src/Websocket.php @@ -11,6 +11,7 @@ use Amp\Http\Server\Request; use Amp\Http\Server\RequestHandler; use Amp\Http\Server\Response; +use Amp\Websocket\Compression\WebsocketCompressionContext; use Amp\Websocket\WebsocketClient; use Amp\Websocket\WebsocketCloseCode; use Amp\Websocket\WebsocketClosedException; @@ -32,6 +33,7 @@ public function __construct( private readonly PsrLogger $logger, private readonly WebsocketAcceptor $acceptor, private readonly WebsocketClientHandler $clientHandler, + private readonly ?WebsocketCompressionNegotiator $compressionNegotiator = null, private readonly WebsocketClientFactory $clientFactory = new Rfc6455ClientFactory(), ) { /** @psalm-suppress PropertyTypeCoercion */ @@ -45,27 +47,40 @@ public function handleRequest(Request $request): Response $response = $this->acceptor->handleHandshake($request); if ($response->getStatus() !== HttpStatus::SWITCHING_PROTOCOLS) { - $response->removeHeader('sec-websocket-accept'); - $response->setHeader('connection', 'close'); + return $this->modifyNonUpgradeResponse($response); + } + + $compressionContext = $this->compressionNegotiator?->negotiateCompression($request, $response); - return $response; + if ($response->getStatus() !== HttpStatus::SWITCHING_PROTOCOLS) { + return $this->modifyNonUpgradeResponse($response); } $response->upgrade(fn (UpgradedSocket $socket) => $this->reapClient( socket: $socket, request: $request, response: $response, + compressionContext: $compressionContext, )); return $response; } + private function modifyNonUpgradeResponse(Response $response): Response + { + $response->removeHeader('sec-websocket-accept'); + $response->setHeader('connection', 'close'); + + return $response; + } + private function reapClient( UpgradedSocket $socket, Request $request, Response $response, + ?WebsocketCompressionContext $compressionContext, ): void { - $client = $this->clientFactory->createClient($request, $response, $socket); + $client = $this->clientFactory->createClient($request, $response, $socket, $compressionContext); /** @psalm-suppress RedundantCondition */ \assert($this->logger->debug(\sprintf( diff --git a/src/WebsocketClientFactory.php b/src/WebsocketClientFactory.php index 7dfbc49..cb846a0 100644 --- a/src/WebsocketClientFactory.php +++ b/src/WebsocketClientFactory.php @@ -5,12 +5,18 @@ use Amp\Http\Server\Request; use Amp\Http\Server\Response; use Amp\Socket\Socket; +use Amp\Websocket\Compression\WebsocketCompressionContext; use Amp\Websocket\WebsocketClient; interface WebsocketClientFactory { /** - * Creates a Client object based on the given Request. + * Creates a {@see WebsocketClient} object based on the given Request. */ - public function createClient(Request $request, Response $response, Socket $socket): WebsocketClient; + public function createClient( + Request $request, + Response $response, + Socket $socket, + ?WebsocketCompressionContext $compressionContext, + ): WebsocketClient; } diff --git a/src/WebsocketCompressionNegotiator.php b/src/WebsocketCompressionNegotiator.php new file mode 100644 index 0000000..a182647 --- /dev/null +++ b/src/WebsocketCompressionNegotiator.php @@ -0,0 +1,17 @@ +createMock(WebsocketClientFactory::class); - $factory->method('createClient') + $clientFactory = $this->createMock(WebsocketClientFactory::class); + $clientFactory->method('createClient') ->willReturn($client); $deferred = new DeferredFuture; $webserver = $this->createWebsocketServer( - $factory, + $clientFactory, function (WebsocketGateway $gateway, WebsocketClient $client) use ($onConnect, $deferred): void { $deferred->complete($onConnect($gateway, $client)); } @@ -66,7 +66,7 @@ function (WebsocketGateway $gateway, WebsocketClient $client) use ($onConnect, $ * @param \Closure(WebsocketGateway, WebsocketClient):void $clientHandler */ protected function createWebsocketServer( - WebsocketClientFactory $factory, + WebsocketClientFactory $clientFactory, \Closure $clientHandler, WebsocketGateway $gateway = new WebsocketClientGateway(), ): SocketHttpServer { @@ -90,7 +90,7 @@ public function handleClient(WebsocketClient $client, Request $request, Response ($this->clientHandler)($this->gateway, $client); } }, - clientFactory: $factory, + clientFactory: $clientFactory, ); $httpServer->expose(new Socket\InternetAddress('127.0.0.1', 0)); @@ -124,6 +124,7 @@ public function testHandshake(Request $request, int $status, array $expectedHead logger: $logger, acceptor: $acceptor, clientHandler: $this->createMock(WebsocketClientHandler::class), + compressionNegotiator: new Rfc7692CompressionNegotiator(), ); $server->start($websocket, $this->createMock(ErrorHandler::class)); @@ -218,6 +219,33 @@ public function provideHandshakes(): iterable HttpStatus::BAD_REQUEST, ["sec-websocket-version" => ["13"]], ]; + + // 8 ----- compression: valid header ------------------------------------------------------> + $request = $this->createRequest(); + $request->setHeader("sec-websocket-extensions", "permessage-deflate; client_max_window_bits"); + yield 'With Valid Compression' => [ + $request, + HttpStatus::SWITCHING_PROTOCOLS, + [ + "upgrade" => ["websocket"], + "connection" => ["upgrade"], + "sec-websocket-accept" => ["HSmrc0sMlYUkAGmm5OPpG2HaGWk="], + "sec-websocket-extensions" => ["permessage-deflate; client_max_window_bits=15"], + ], + ]; + + // 9 ----- compression: invalid header ----------------------------------------------------> + $request = $this->createRequest(); + $request->setHeader("sec-websocket-extensions", "permessage-deflate; client_max_window_bits=8;"); + yield 'With Invalid Compression' => [ + $request, + HttpStatus::SWITCHING_PROTOCOLS, + [ + "upgrade" => ["websocket"], + "connection" => ["upgrade"], + "sec-websocket-accept" => ["HSmrc0sMlYUkAGmm5OPpG2HaGWk="], + ], + ]; } public function testBroadcast(): void