From a6cbf6eb08146dc94e8b1034ae5c97795d8cd733 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Tue, 28 May 2024 16:53:59 -0400 Subject: [PATCH] Add manual name resolve in blocking client This will try to connect to all resolved addresses. Timeout handling is also improved. --- edgedb/blocking_client.py | 72 +++++++++++++++++++++++++++------------ 1 file changed, 51 insertions(+), 21 deletions(-) diff --git a/edgedb/blocking_client.py b/edgedb/blocking_client.py index 1fa9b33e..e59fb2c7 100644 --- a/edgedb/blocking_client.py +++ b/edgedb/blocking_client.py @@ -47,23 +47,57 @@ async def connect_addr(self, addr, timeout): if isinstance(addr, str): # UNIX socket - sock = socket.socket(socket.AF_UNIX) + res_list = [(socket.AF_UNIX, socket.SOCK_STREAM, -1, None, addr)] else: - addr_info = socket.getaddrinfo(addr[0], addr[1], type=socket.SOCK_STREAM)[0] - addr = addr_info[4] - sock = socket.socket(addr_info[0], addr_info[1]) - - try: - sock.settimeout(timeout) + host, port = addr + try: + # getaddrinfo() doesn't take timeout!! + res_list = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, socket.SOCK_STREAM + ) + except socket.gaierror as e: + # All name resolution errors are considered temporary + err = errors.ClientConnectionFailedTemporarilyError(str(e)) + raise err from e + for i, res in enumerate(res_list): + af, socktype, proto, _, sa = res try: - sock.connect(addr) + sock = socket.socket(af, socktype, proto) + except OSError as e: + sock.close() + if i < len(res_list) - 1: + continue + else: + raise con_utils.wrap_error(e) from e + try: + await self._connect_addr(sock, addr, sa, deadline) + except TimeoutError: + raise + except Exception: + if i < len(res_list) - 1: + continue + else: + raise + else: + break - if not isinstance(addr, str): - time_left = deadline - time.monotonic() - if time_left <= 0: - raise TimeoutError + async def _connect_addr(self, sock, addr, sa, deadline): + try: + time_left = deadline - time.monotonic() + if time_left <= 0: + raise TimeoutError + try: + sock.settimeout(time_left) + sock.connect(sa) + except OSError as e: + raise con_utils.wrap_error(e) from e + if not isinstance(addr, str): + time_left = deadline - time.monotonic() + if time_left <= 0: + raise TimeoutError + try: # Upgrade to TLS sock.settimeout(time_left) try: @@ -76,12 +110,8 @@ async def connect_addr(self, addr, timeout): raise con_utils.wrap_error(e) from e else: con_utils.check_alpn_protocol(sock) - except socket.gaierror as e: - # All name resolution errors are considered temporary - err = errors.ClientConnectionFailedTemporarilyError(str(e)) - raise err from e - except OSError as e: - raise con_utils.wrap_error(e) from e + except OSError as e: + raise con_utils.wrap_error(e) from e time_left = deadline - time.monotonic() if time_left <= 0: @@ -94,9 +124,9 @@ async def connect_addr(self, addr, timeout): proto.set_connection(self) try: - sock.settimeout(time_left) - await proto.connect() - sock.settimeout(None) + await proto.wait_for(proto.connect(), time_left) + except TimeoutError: + raise except OSError as e: raise con_utils.wrap_error(e) from e