Skip to content

Commit

Permalink
feat: adds log handler
Browse files Browse the repository at this point in the history
  • Loading branch information
jtyoung84 committed Dec 3, 2024
1 parent 121544f commit 95f6c35
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 24 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ server = [
'uvicorn[standard]',
'wtforms',
'requests==2.25.0',
'openpyxl'
'openpyxl',
'python-logging-loki'
]

[tool.setuptools.packages.find]
Expand Down
49 changes: 49 additions & 0 deletions src/aind_data_transfer_service/log_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Module to handle setting up logger"""

import logging
from typing import Literal, Optional

from logging_loki import LokiHandler
from pydantic import Field
from pydantic_settings import BaseSettings


class LoggingConfigs(BaseSettings):
"""Configs for logger"""

env_name: Optional[str] = Field(
default=None, description="Can be used to help tag logging source."
)
loki_uri: Optional[str] = Field(
default=None, description="URI of Loki logging server."
)
log_level: Literal[
"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
] = Field(default="INFO", description="Log level")

@property
def app_name(self):
"""Build app name from configs"""
package_name = __package__
base_name = package_name.split(".")[0].replace("_", "-")
app_name = (
base_name
if self.env_name is None
else f"{base_name}-{self.env_name}"
)
return app_name


def get_logger(log_configs: LoggingConfigs) -> logging.Logger:
"""Return a logger that can be used to log messages."""
level = logging.getLevelName(log_configs.log_level)
logger = logging.getLogger(__name__)
logger.setLevel(level)
if log_configs.loki_uri is not None:
handler = LokiHandler(
url=log_configs.loki_uri,
version="1",
tags={"application": log_configs.app_name},
)
logger.addHandler(handler)
return logger
42 changes: 32 additions & 10 deletions src/aind_data_transfer_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import csv
import io
import json
import logging
import os
from asyncio import sleep
from pathlib import PurePosixPath
Expand All @@ -29,6 +28,7 @@
)
from aind_data_transfer_service.hpc.client import HpcClient, HpcClientConfigs
from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings
from aind_data_transfer_service.log_handler import LoggingConfigs, get_logger
from aind_data_transfer_service.models import (
AirflowDagRun,
AirflowDagRunRequestParameters,
Expand Down Expand Up @@ -64,10 +64,16 @@
# AIND_AIRFLOW_SERVICE_JOBS_URL
# AIND_AIRFLOW_SERVICE_PASSWORD
# AIND_AIRFLOW_SERVICE_USER
# LOKI_URI
# ENV_NAME
# LOG_LEVEL

logger = get_logger(log_configs=LoggingConfigs())


async def validate_csv(request: Request):
"""Validate a csv or xlsx file. Return parsed contents as json."""
logger.info("Received request to validate csv")
async with request.form() as form:
basic_jobs = []
errors = []
Expand Down Expand Up @@ -169,13 +175,25 @@ async def validate_csv_legacy(request: Request):

async def submit_jobs(request: Request):
"""Post BasicJobConfigs raw json to hpc server to process."""
logger.info("Received request to submit jobs")
content = await request.json()
try:
model = SubmitJobRequest.model_validate_json(json.dumps(content))
full_content = json.loads(
model.model_dump_json(warnings=False, exclude_none=True)
)
# TODO: Replace with httpx async client
logger.info(
f"Valid request detected. Sending list of jobs. "
f"Job Type: {model.job_type}"
)
total_jobs = len(model.upload_jobs)
for job_index, job in enumerate(model.upload_jobs, 1):
logger.info(
f"{job.s3_prefix} sending to airflow. "
f"{job_index} of {total_jobs}."
)

response = requests.post(
url=os.getenv("AIND_AIRFLOW_SERVICE_URL"),
auth=(
Expand All @@ -193,6 +211,7 @@ async def submit_jobs(request: Request):
)

except ValidationError as e:
logger.warning(f"There were validation errors processing {content}")
return JSONResponse(
status_code=406,
content={
Expand All @@ -201,6 +220,7 @@ async def submit_jobs(request: Request):
},
)
except Exception as e:
logger.exception("Internal Server Error.")
return JSONResponse(
status_code=500,
content={
Expand Down Expand Up @@ -256,7 +276,7 @@ async def submit_basic_jobs(request: Request):
# Add pause to stagger job requests to the hpc
await sleep(0.2)
except Exception as e:
logging.error(f"{e.__class__.__name__}{e.args}")
logger.error(f"{e.__class__.__name__}{e.args}")
hpc_errors.append(
f"Error processing "
f"{hpc_job.basic_upload_job_configs.s3_prefix}"
Expand Down Expand Up @@ -382,7 +402,7 @@ async def submit_hpc_jobs(request: Request): # noqa: C901
# Add pause to stagger job requests to the hpc
await sleep(0.2)
except Exception as e:
logging.error(repr(e))
logger.error(repr(e))
hpc_errors.append(f"Error processing " f"{hpc_job_def.name}")
message = (
"There were errors submitting jobs to the hpc."
Expand Down Expand Up @@ -456,12 +476,14 @@ async def get_job_status_list(request: Request):
"errors": [response_jobs.json()],
}
except ValidationError as e:
logging.error(e)
logger.warning(
f"There was a validation error process job_status list: {e}"
)
status_code = 406
message = "Error validating request parameters"
data = {"errors": json.loads(e.json())}
except Exception as e:
logging.error(e)
logger.exception("Unable to retrieve job status list from airflow")
status_code = 500
message = "Unable to retrieve job status list from airflow"
data = {"errors": [f"{e.__class__.__name__}{e.args}"]}
Expand Down Expand Up @@ -516,12 +538,12 @@ async def get_tasks_list(request: Request):
"errors": [response_tasks.json()],
}
except ValidationError as e:
logging.error(e)
logger.warning(f"There was a validation error process task_list: {e}")
status_code = 406
message = "Error validating request parameters"
data = {"errors": json.loads(e.json())}
except Exception as e:
logging.error(e)
logger.exception("Unable to retrieve job tasks list from airflow")
status_code = 500
message = "Unable to retrieve job tasks list from airflow"
data = {"errors": [f"{e.__class__.__name__}{e.args}"]}
Expand Down Expand Up @@ -565,12 +587,12 @@ async def get_task_logs(request: Request):
"errors": [response_logs.json()],
}
except ValidationError as e:
logging.error(e)
logger.warning(f"Error validating request parameters: {e}")
status_code = 406
message = "Error validating request parameters"
data = {"errors": json.loads(e.json())}
except Exception as e:
logging.error(e)
logger.exception("Unable to retrieve job task_list from airflow")
status_code = 500
message = "Unable to retrieve task logs from airflow"
data = {"errors": [f"{e.__class__.__name__}{e.args}"]}
Expand Down Expand Up @@ -707,7 +729,7 @@ async def download_job_template(_: Request):
status_code=200,
)
except Exception as e:
logging.error(e)
logger.exception("Error creating job template")
return JSONResponse(
content={
"message": "Error creating job template",
Expand Down
43 changes: 43 additions & 0 deletions tests/test_log_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Tests methods in log_handler module"""

import unittest
from unittest.mock import MagicMock, call, patch

from aind_data_transfer_service.log_handler import LoggingConfigs, get_logger


class TestLoggingConfigs(unittest.TestCase):
"""Tests LoggingConfigs class"""

def test_app_name(self):
"""Tests app_name property"""

configs = LoggingConfigs(env_name="test", loki_uri=None)
self.assertEqual("aind-data-transfer-service-test", configs.app_name)

@patch("logging.getLogger")
@patch("aind_data_transfer_service.log_handler.LokiHandler")
def test_get_logger(
self, mock_loki_handler: MagicMock, mock_get_logger: MagicMock
):
"""Tests get_logger method"""

mock_get_logger.return_value = MagicMock()
configs = LoggingConfigs(
env_name="test", loki_uri="example.com", log_level="WARNING"
)
_ = get_logger(log_configs=configs)
mock_loki_handler.assert_called_once_with(
url="example.com",
version="1",
tags={"application": "aind-data-transfer-service-test"},
)
mock_get_logger.assert_has_calls(
[call("aind_data_transfer_service.log_handler")]
)
mock_get_logger.return_value.setLevel.assert_called_once_with(30)
mock_get_logger.return_value.addHandler.assert_called_once()


if __name__ == "__main__":
unittest.main()
29 changes: 16 additions & 13 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def test_submit_jobs(
@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("aind_data_transfer_service.server.sleep", return_value=None)
@patch("aind_data_transfer_service.hpc.client.HpcClient.submit_job")
@patch("logging.error")
@patch("logging.Logger.error")
def test_submit_jobs_server_error(
self,
mock_log_error: MagicMock,
Expand Down Expand Up @@ -238,7 +238,7 @@ def test_submit_jobs_server_error(
@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("aind_data_transfer_service.server.sleep", return_value=None)
@patch("aind_data_transfer_service.hpc.client.HpcClient.submit_job")
@patch("logging.error")
@patch("logging.Logger.error")
def test_submit_jobs_malformed_json(
self,
mock_log_error: MagicMock,
Expand Down Expand Up @@ -450,7 +450,7 @@ def test_submit_hpc_jobs_open_data(
@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("aind_data_transfer_service.server.sleep", return_value=None)
@patch("aind_data_transfer_service.hpc.client.HpcClient.submit_hpc_job")
@patch("logging.error")
@patch("logging.Logger.error")
def test_submit_hpc_jobs_error(
self,
mock_log_error: MagicMock,
Expand Down Expand Up @@ -497,7 +497,7 @@ def test_submit_hpc_jobs_error(
@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("aind_data_transfer_service.server.sleep", return_value=None)
@patch("aind_data_transfer_service.hpc.client.HpcClient.submit_hpc_job")
@patch("logging.error")
@patch("logging.Logger.error")
def test_submit_hpc_jobs_server_error(
self,
mock_log_error: MagicMock,
Expand Down Expand Up @@ -664,10 +664,10 @@ def test_get_job_status_list_query_params(

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("requests.get")
@patch("logging.error")
@patch("logging.Logger.warning")
def test_get_job_status_list_validation_error(
self,
mock_log_error: MagicMock,
mock_log_warning: MagicMock,
mock_get,
):
"""Tests get_job_status_list when query_params are invalid."""
Expand All @@ -693,7 +693,7 @@ def test_get_job_status_list_validation_error(
response_content["message"],
"Error validating request parameters",
)
mock_log_error.assert_called()
mock_log_warning.assert_called()
mock_get.assert_not_called()

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
Expand Down Expand Up @@ -745,7 +745,7 @@ def test_get_job_status_list_dag_run_id(
)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("logging.error")
@patch("logging.Logger.exception")
@patch("requests.get")
def test_get_job_status_list_error(
self,
Expand Down Expand Up @@ -990,7 +990,7 @@ def test_get_tasks_list_query_params(

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("requests.get")
@patch("logging.error")
@patch("logging.Logger.warning")
def test_get_tasks_list_validation_error(
self,
mock_log_error: MagicMock,
Expand All @@ -1014,7 +1014,7 @@ def test_get_tasks_list_validation_error(
mock_get.assert_not_called()

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("logging.error")
@patch("logging.Logger.exception")
@patch("requests.get")
def test_get_tasks_list_error(
self,
Expand Down Expand Up @@ -1080,7 +1080,7 @@ def test_get_task_logs_query_params(

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("requests.get")
@patch("logging.error")
@patch("logging.Logger.warning")
def test_get_task_logs_validation_error(
self,
mock_log_error: MagicMock,
Expand All @@ -1106,7 +1106,7 @@ def test_get_task_logs_validation_error(
mock_get.assert_not_called()

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("logging.error")
@patch("logging.Logger.exception")
@patch("requests.get")
def test_get_task_logs_error(
self,
Expand Down Expand Up @@ -1291,7 +1291,7 @@ def test_download_job_template(self):
self.assertEqual(200, response.status_code)

@patch("aind_data_transfer_service.server.JobUploadTemplate")
@patch("logging.error")
@patch("logging.Logger.exception")
def test_download_invalid_job_template(
self, mock_log_error: MagicMock, mock_job_template: MagicMock
):
Expand Down Expand Up @@ -1441,8 +1441,10 @@ def test_submit_v1_jobs_200(

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("requests.post")
@patch("logging.Logger.exception")
def test_submit_v1_jobs_500(
self,
mock_log_exception: MagicMock,
mock_post: MagicMock,
):
"""Tests submit jobs 500 response."""
Expand Down Expand Up @@ -1484,6 +1486,7 @@ def test_submit_v1_jobs_500(
url="/api/v1/submit_jobs", json=request_json
)
self.assertEqual(500, submit_job_response.status_code)
mock_log_exception.assert_called()

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("requests.post")
Expand Down

0 comments on commit 95f6c35

Please sign in to comment.