Skip to content

Commit

Permalink
Add support for binding Unix sockets in Linux's abstract namespace.
Browse files Browse the repository at this point in the history
  • Loading branch information
Dadeos-Menlo committed Jul 2, 2024
1 parent bdfc017 commit e169be7
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 24 deletions.
23 changes: 13 additions & 10 deletions tornado/netutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,20 @@ def bind_unix_socket(
# Hurd doesn't support SO_REUSEADDR
raise
sock.setblocking(False)
try:
st = os.stat(file)
except FileNotFoundError:
pass
else:
if stat.S_ISSOCK(st.st_mode):
os.remove(file)
if not file.startswith("\0"):
try:
st = os.stat(file)
except FileNotFoundError:
pass
else:
raise ValueError("File %s exists and is not a socket", file)
sock.bind(file)
os.chmod(file, mode)
if stat.S_ISSOCK(st.st_mode):
os.remove(file)
else:
raise ValueError("File %s exists and is not a socket", file)
sock.bind(file)
os.chmod(file, mode)
else:
sock.bind(file)
sock.listen(backlog)
return sock

Expand Down
44 changes: 30 additions & 14 deletions tornado/test/httpserver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,37 +834,53 @@ def setUp(self):
super().setUp()
self.tmpdir = tempfile.mkdtemp()
self.sockfile = os.path.join(self.tmpdir, "test.sock")
sock = netutil.bind_unix_socket(self.sockfile)
app = Application([("/hello", HelloWorldRequestHandler)])
self.server = HTTPServer(app)
self.server.add_socket(sock)
self.stream = IOStream(socket.socket(socket.AF_UNIX))
self.io_loop.run_sync(lambda: self.stream.connect(self.sockfile))
if sys.platform.startswith("linux"):
self.sockabstract = "\0" + os.path.basename(self.tmpdir)
self.server.add_socket(netutil.bind_unix_socket(self.sockabstract))
self.server.add_socket(netutil.bind_unix_socket(self.sockfile))

def tearDown(self):
self.stream.close()
self.io_loop.run_sync(self.server.close_all_connections)
self.server.stop()
shutil.rmtree(self.tmpdir)
super().tearDown()

@gen_test
def test_unix_socket(self):
self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
response = yield self.stream.read_until(b"\r\n")
self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
header_data = yield self.stream.read_until(b"\r\n\r\n")
headers = HTTPHeaders.parse(header_data.decode("latin1"))
body = yield self.stream.read_bytes(int(headers["Content-Length"]))
self.assertEqual(body, b"Hello world")
with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream:
stream.connect(self.sockfile)
stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
response = yield stream.read_until(b"\r\n")
self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
header_data = yield stream.read_until(b"\r\n\r\n")
headers = HTTPHeaders.parse(header_data.decode("latin1"))
body = yield stream.read_bytes(int(headers["Content-Length"]))
self.assertEqual(body, b"Hello world")

@unittest.skipUnless(sys.platform.startswith("linux"), "requires Linux")
@gen_test
def test_unix_socket_abstract(self):
with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream:
stream.connect(self.sockabstract)
stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
response = yield stream.read_until(b"\r\n")
self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
header_data = yield stream.read_until(b"\r\n\r\n")
headers = HTTPHeaders.parse(header_data.decode("latin1"))
body = yield stream.read_bytes(int(headers["Content-Length"]))
self.assertEqual(body, b"Hello world")

@gen_test
def test_unix_socket_bad_request(self):
# Unix sockets don't have remote addresses so they just return an
# empty string.
with ExpectLog(gen_log, "Malformed HTTP message from", level=logging.INFO):
self.stream.write(b"garbage\r\n\r\n")
response = yield self.stream.read_until_close()
with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream:
stream.connect(self.sockfile)
stream.write(b"garbage\r\n\r\n")
response = yield stream.read_until_close()
self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n")


Expand Down

0 comments on commit e169be7

Please sign in to comment.