Skip to content

Commit

Permalink
finally stable logger
Browse files — browse the repository at this point in the history
  • Loading branch information
cpelley committed Sep 27, 2024
1 parent 047c2c9 commit 844d664
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 68 deletions.
11 changes: 8 additions & 3 deletions dagrunner/tests/utils/logging/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def server(sqlite_filepath):
"logger",
"--sqlite-filepath",
str(sqlite_filepath),
"--port",
"12345",
"--verbose",
],
env=env,
Expand All @@ -58,7 +60,7 @@ def test_sqlitedb(server, caplog):
["Indentation defines code blocks.", "myapp.area2", "warning"],
["Libraries extend Pythons capabilities.", "myapp.area2", "error"],
)
client_attach_socket_handler()
client_attach_socket_handler(port=12345)
for msg, lvlname, name in test_inputs:
getattr(logging.getLogger(lvlname), name)(msg)

Expand Down Expand Up @@ -91,7 +93,7 @@ def test_sqlitedb(server, caplog):
records = cursor.execute("SELECT * FROM logs").fetchall()
for test_input, record in zip(test_inputs, records):
tar_format = (
float,
str,
test_input[1],
test_input[2].upper(),
test_input[0],
Expand All @@ -104,7 +106,10 @@ def test_sqlitedb(server, caplog):
for tar, rec in zip(tar_format, record):
if isinstance(tar, type):
# simply check it is the correct type
assert type(eval(rec)) is tar
try:
assert type(eval(rec)) is tar
except SyntaxError:
continue
else:
assert rec == tar
conn.close()
110 changes: 45 additions & 65 deletions dagrunner/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
`socketserver.StreamRequestHandler`, responsible for 'getting' log records.
"""

import datetime
import logging
import logging.handlers
import os
import pickle
import queue
import socket
import socketserver
import sqlite3
Expand Down Expand Up @@ -73,7 +73,6 @@ def client_attach_socket_handler(
class LogRecordStreamHandler(socketserver.StreamRequestHandler):
"""
Handler for a streaming logging request.
Specialisation of the `socketserver.StreamRequestHandler` class to handle log
records and customise logging events.
"""
Expand All @@ -94,12 +93,7 @@ def handle(self):
chunk = chunk + self.connection.recv(slen - len(chunk))
obj = self.unpickle(chunk)
record = logging.makeLogRecord(obj)
# Modify record to include hostname
record.hostname = socket.gethostname()
# Push log record to the queue for database writing
if self.server.log_queue is not None:
self.server.log_queue.put(record)

self.handle_log_record(record)

def unpickle(self, data):
Expand All @@ -122,10 +116,7 @@ def handle_log_record(self, record):

class LogRecordSocketReceiver(socketserver.ThreadingTCPServer):
"""
Simple TCP socket-based logging receiver.
Specialisation of the `socketserver.ThreadingTCPServer` class to handle
log records.
Simple TCP socket-based logging receiver suitable for testing.
"""

allow_reuse_address = True
Expand All @@ -135,45 +126,40 @@ def __init__(
host="localhost",
port=logging.handlers.DEFAULT_TCP_LOGGING_PORT,
handler=LogRecordStreamHandler,
log_queue=None,
):
socketserver.ThreadingTCPServer.__init__(self, (host, port), handler)
self.abort = 0
self.timeout = 1
self.logname = None
self.log_queue = log_queue # Store the reference to the log queue

def serve_until_stopped(self, queue_handler=None):
def serve_until_stopped(self):
import select

abort = 0
while not abort:
rd, wr, ex = select.select([self.socket.fileno()], [], [], self.timeout)
if rd:
self.handle_request()
if queue_handler:
queue_handler.write(self.log_queue)
abort = self.abort
if queue_handler:
queue_handler.write(self.log_queue) # Ensure all records are written
queue_handler.close()

def stop(self):
    """
    Request shutdown of the receiver.

    Sets the abort flag so that `serve_until_stopped` exits its polling
    loop on the next timeout tick, then closes the listening socket so
    the bound port is released immediately.

    NOTE(review): already-accepted connections being handled in worker
    threads are not interrupted by this call — presumably they drain on
    their own; confirm against `socketserver.ThreadingTCPServer` usage.
    """
    self.abort = 1  # Set abort flag to stop the server loop
    self.server_close()  # Close the server socket

class SQLiteHandler(logging.Handler):
"""
Custom logging handler to write log messages to an SQLite database.
"""

class SQLiteQueueHandler:
def __init__(self, sqfile="logs.sqlite", verbose=False):
def __init__(self, sqfile="logs.sqlite"):
logging.Handler.__init__(self)
self._sqfile = sqfile
self._conn = None
self._verbose = verbose
self._debug = False
sqlite3.enable_callback_tracebacks(self._debug)

def write_table(self, cursor):
if self._verbose:
print(f"Writing sqlite file table: {self._sqfile}")
self._create_table()

def _create_table(self):
"""
Creates a table to store the logs if it doesn't exist.
"""
conn = sqlite3.connect(self._sqfile)
cursor = conn.cursor()
print(f"Writing sqlite file table: {self._sqfile}")
cursor.execute("""
CREATE TABLE IF NOT EXISTS logs (
created TEXT,
Expand All @@ -185,31 +171,27 @@ def write_table(self, cursor):
thread TEXT
)
""") # Create the 'logs' table if it doesn't exist
conn.commit()
cursor.close()
conn.close()

def write(self, log_queue):
if self._conn is None:
# SQLite objects created in a thread can only be used in that same thread
# for flexibility we create a new connection here.
self._conn = sqlite3.connect(self._sqfile)
cursor = self._conn.cursor()
self.write_table(cursor)
else:
# NOTE: cursors are not thread-safe
cursor = self._conn.cursor()

if self._verbose:
print(f"Writing row to sqlite file: {self._sqfile}")
while not log_queue.empty():
record = log_queue.get()
if self._verbose:
print("Dequeued item:", record)
def emit(self, record):
"""
Emit a log record, and insert it into the database.
"""
try:
conn = sqlite3.connect(self._sqfile)
cursor = conn.cursor()
print("Dequeued item:", record)
cursor.execute(
"\n"
"INSERT INTO logs "
"(created, name, level, message, hostname, process, thread)\n"
"VALUES (?, ?, ?, ?, ?, ?, ?)\n",
(
record.created,
datetime.datetime.fromtimestamp(record.created).strftime(
"%Y-%m-%d %H:%M:%S"
),
record.name,
record.levelname,
record.getMessage(),
Expand All @@ -218,12 +200,17 @@ def write(self, log_queue):
record.thread,
),
)
self._conn.commit() # Commit the transaction after all writes
cursor.close()
conn.commit()
cursor.close()
conn.close()
except sqlite3.Error as e:
print(f"SQLite error: {e}")

def close(self):
if self._conn:
self._conn.close()
"""
Ensure the database connection is closed cleanly.
"""
super().close()


class CustomFormatter(logging.Formatter):
Expand Down Expand Up @@ -263,22 +250,15 @@ def start_logging_server(
datefmt="%Y-%m-%dT%H:%M:%S", # Date in ISO 8601 format
)

log_queue = queue.Queue()

sqlitequeue = None
tcpserver = LogRecordSocketReceiver(host=host, port=port)
if sqlite_filepath:
sqlitequeue = SQLiteQueueHandler(sqfile=sqlite_filepath, verbose=verbose)

tcpserver = LogRecordSocketReceiver(
host=host,
port=port,
log_queue=log_queue,
)
sqlite_handler = SQLiteHandler(sqlite_filepath)
logging.getLogger("").addHandler(sqlite_handler)
print(
"About to start TCP server...\n",
f"HOST: {host}; PORT: {port}; PID: {os.getpid()}; SQLITE: {sqlite_filepath}\n",
)
tcpserver.serve_until_stopped(queue_handler=sqlitequeue)
tcpserver.serve_until_stopped()


def main():
Expand Down

0 comments on commit 844d664

Please sign in to comment.