From 28a83fd0a68b06462ea6ab32d8b05f30f769051b Mon Sep 17 00:00:00 2001 From: Fantix King Date: Tue, 18 Jun 2024 20:45:15 -0400 Subject: [PATCH] blocking client: fix connect and timeout (#499) * fix ipv6 connection issues * Fix wait_for() in blocking client to raise internal TimeoutError * Add manual name resolve in blocking client This will try to connect to all resolved addresses. Timeout handling is also improved. --------- Co-authored-by: Zachary Juang --- edgedb/blocking_client.py | 73 ++++++++++++++++++++++-------- edgedb/protocol/blocking_proto.pyx | 12 ++--- 2 files changed, 58 insertions(+), 27 deletions(-) diff --git a/edgedb/blocking_client.py b/edgedb/blocking_client.py index 97931501..e59fb2c7 100644 --- a/edgedb/blocking_client.py +++ b/edgedb/blocking_client.py @@ -47,21 +47,57 @@ async def connect_addr(self, addr, timeout): if isinstance(addr, str): # UNIX socket - sock = socket.socket(socket.AF_UNIX) + res_list = [(socket.AF_UNIX, socket.SOCK_STREAM, -1, None, addr)] else: - sock = socket.socket(socket.AF_INET) - - try: - sock.settimeout(timeout) + host, port = addr + try: + # getaddrinfo() doesn't take timeout!! + res_list = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, socket.SOCK_STREAM + ) + except socket.gaierror as e: + # All name resolution errors are considered temporary + err = errors.ClientConnectionFailedTemporarilyError(str(e)) + raise err from e + for i, res in enumerate(res_list): + af, socktype, proto, _, sa = res try: - sock.connect(addr) + sock = socket.socket(af, socktype, proto) + except OSError as e: + sock.close() + if i < len(res_list) - 1: + continue + else: + raise con_utils.wrap_error(e) from e + try: + await self._connect_addr(sock, addr, sa, deadline) + except TimeoutError: + raise + except Exception: + if i < len(res_list) - 1: + continue + else: + raise + else: + break - if not isinstance(addr, str): - time_left = deadline - time.monotonic() - if time_left <= 0: - raise TimeoutError + async def _connect_addr(self, sock, addr, sa, deadline): + try: + time_left = deadline - time.monotonic() + if time_left <= 0: + raise TimeoutError + try: + sock.settimeout(time_left) + sock.connect(sa) + except OSError as e: + raise con_utils.wrap_error(e) from e + if not isinstance(addr, str): + time_left = deadline - time.monotonic() + if time_left <= 0: + raise TimeoutError + try: # Upgrade to TLS sock.settimeout(time_left) try: @@ -74,12 +110,8 @@ async def connect_addr(self, addr, timeout): raise con_utils.wrap_error(e) from e else: con_utils.check_alpn_protocol(sock) - except socket.gaierror as e: - # All name resolution errors are considered temporary - err = errors.ClientConnectionFailedTemporarilyError(str(e)) - raise err from e - except OSError as e: - raise con_utils.wrap_error(e) from e + except OSError as e: + raise con_utils.wrap_error(e) from e time_left = deadline - time.monotonic() if time_left <= 0: @@ -92,9 +124,9 @@ async def connect_addr(self, addr, timeout): proto.set_connection(self) try: - sock.settimeout(time_left) - await proto.connect() - sock.settimeout(None) + await proto.wait_for(proto.connect(), time_left) + except TimeoutError: + raise except OSError as e: raise con_utils.wrap_error(e) from e @@ -133,6 +165,9 @@ async def close(self, timeout=None): await self._protocol.wait_for( self._protocol.wait_for_disconnect(), timeout ) + except TimeoutError: + self.terminate() + raise errors.QueryTimeoutError() except Exception: self.terminate() raise diff --git a/edgedb/protocol/blocking_proto.pyx b/edgedb/protocol/blocking_proto.pyx index 51c52140..ea4c1c16 100644 --- a/edgedb/protocol/blocking_proto.pyx +++ b/edgedb/protocol/blocking_proto.pyx @@ -63,18 +63,14 @@ cdef class BlockingIOProtocol(protocol.SansIOProtocolBackwardsCompatible): async def wait_for_message(self): cdef float timeout if self.deadline > 0: - timeout = self.deadline - time.monotonic() - if timeout <= 0: - self.abort() - raise errors.QueryTimeoutError() while not self.buffer.take_message(): + timeout = self.deadline - time.monotonic() + if timeout <= 0: + self.abort() + raise TimeoutError try: self.sock.settimeout(timeout) data = self.sock.recv(RECV_BUF) - timeout = self.deadline - time.monotonic() - if timeout <= 0: - self.abort() - raise TimeoutError except OSError as e: self._disconnect() raise con_utils.wrap_error(e) from e