Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adds log handler #190

Merged
merged 2 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ server = [
'uvicorn[standard]',
'wtforms',
'requests==2.25.0',
'openpyxl'
'openpyxl',
'python-logging-loki'
]

[tool.setuptools.packages.find]
Expand Down
49 changes: 49 additions & 0 deletions src/aind_data_transfer_service/log_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Module to handle setting up logger"""

import logging
from typing import Literal, Optional

from logging_loki import LokiHandler
from pydantic import Field
from pydantic_settings import BaseSettings


class LoggingConfigs(BaseSettings):
    """Settings controlling how the application logger is configured.

    Values may be supplied via environment variables (pydantic
    BaseSettings behavior).
    """

    env_name: Optional[str] = Field(
        default=None, description="Can be used to help tag logging source."
    )
    loki_uri: Optional[str] = Field(
        default=None, description="URI of Loki logging server."
    )
    log_level: Literal[
        "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
    ] = Field(default="INFO", description="Log level")

    @property
    def app_name(self):
        """Application name used to tag log records.

        Derived from the top-level package name with underscores
        replaced by hyphens; the env_name is appended when set.
        """
        base_name = __package__.split(".")[0].replace("_", "-")
        if self.env_name is None:
            return base_name
        return f"{base_name}-{self.env_name}"


def get_logger(log_configs: "LoggingConfigs") -> logging.Logger:
    """Return a logger configured from the given settings.

    Parameters
    ----------
    log_configs : LoggingConfigs
        Provides the log level, an optional Loki server URI, and the
        application name used to tag Loki log records.

    Returns
    -------
    logging.Logger
        The module-level logger with its level set. When a Loki URI is
        configured, a LokiHandler pointed at that URI is attached.
    """
    # log_level is constrained to the standard level names, so the
    # numeric constant can be looked up directly on the logging module.
    # This avoids logging.getLevelName(), whose name->number lookup is
    # documented as an unintended side effect and is discouraged.
    level = getattr(logging, log_configs.log_level)
    logger = logging.getLogger(__name__)
    logger.setLevel(level)
    if log_configs.loki_uri is not None:
        # NOTE(review): each call attaches a new handler to the same
        # module-level logger, so calling this more than once would emit
        # duplicate Loki records. Fine if called once at startup -- confirm.
        handler = LokiHandler(
            url=log_configs.loki_uri,
            version="1",
            tags={"application": log_configs.app_name},
        )
        logger.addHandler(handler)
    return logger
42 changes: 32 additions & 10 deletions src/aind_data_transfer_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import csv
import io
import json
import logging
import os
from asyncio import sleep
from pathlib import PurePosixPath
Expand All @@ -29,6 +28,7 @@
)
from aind_data_transfer_service.hpc.client import HpcClient, HpcClientConfigs
from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings
from aind_data_transfer_service.log_handler import LoggingConfigs, get_logger
from aind_data_transfer_service.models import (
AirflowDagRun,
AirflowDagRunRequestParameters,
Expand Down Expand Up @@ -64,10 +64,16 @@
# AIND_AIRFLOW_SERVICE_JOBS_URL
# AIND_AIRFLOW_SERVICE_PASSWORD
# AIND_AIRFLOW_SERVICE_USER
# LOKI_URI
# ENV_NAME
# LOG_LEVEL

logger = get_logger(log_configs=LoggingConfigs())


async def validate_csv(request: Request):
"""Validate a csv or xlsx file. Return parsed contents as json."""
logger.info("Received request to validate csv")
async with request.form() as form:
basic_jobs = []
errors = []
Expand Down Expand Up @@ -169,13 +175,25 @@ async def validate_csv_legacy(request: Request):

async def submit_jobs(request: Request):
"""Post BasicJobConfigs raw json to hpc server to process."""
logger.info("Received request to submit jobs")
content = await request.json()
try:
model = SubmitJobRequest.model_validate_json(json.dumps(content))
full_content = json.loads(
model.model_dump_json(warnings=False, exclude_none=True)
)
# TODO: Replace with httpx async client
logger.info(
f"Valid request detected. Sending list of jobs. "
f"Job Type: {model.job_type}"
)
total_jobs = len(model.upload_jobs)
for job_index, job in enumerate(model.upload_jobs, 1):
logger.info(
f"{job.s3_prefix} sending to airflow. "
f"{job_index} of {total_jobs}."
)

response = requests.post(
url=os.getenv("AIND_AIRFLOW_SERVICE_URL"),
auth=(
Expand All @@ -193,6 +211,7 @@ async def submit_jobs(request: Request):
)

except ValidationError as e:
logger.warning(f"There were validation errors processing {content}")
return JSONResponse(
status_code=406,
content={
Expand All @@ -201,6 +220,7 @@ async def submit_jobs(request: Request):
},
)
except Exception as e:
logger.exception("Internal Server Error.")
return JSONResponse(
status_code=500,
content={
Expand Down Expand Up @@ -256,7 +276,7 @@ async def submit_basic_jobs(request: Request):
# Add pause to stagger job requests to the hpc
await sleep(0.2)
except Exception as e:
logging.error(f"{e.__class__.__name__}{e.args}")
logger.error(f"{e.__class__.__name__}{e.args}")
hpc_errors.append(
f"Error processing "
f"{hpc_job.basic_upload_job_configs.s3_prefix}"
Expand Down Expand Up @@ -382,7 +402,7 @@ async def submit_hpc_jobs(request: Request): # noqa: C901
# Add pause to stagger job requests to the hpc
await sleep(0.2)
except Exception as e:
logging.error(repr(e))
logger.error(repr(e))
hpc_errors.append(f"Error processing " f"{hpc_job_def.name}")
message = (
"There were errors submitting jobs to the hpc."
Expand Down Expand Up @@ -456,12 +476,14 @@ async def get_job_status_list(request: Request):
"errors": [response_jobs.json()],
}
except ValidationError as e:
logging.error(e)
logger.warning(
f"There was a validation error process job_status list: {e}"
)
status_code = 406
message = "Error validating request parameters"
data = {"errors": json.loads(e.json())}
except Exception as e:
logging.error(e)
logger.exception("Unable to retrieve job status list from airflow")
status_code = 500
message = "Unable to retrieve job status list from airflow"
data = {"errors": [f"{e.__class__.__name__}{e.args}"]}
Expand Down Expand Up @@ -516,12 +538,12 @@ async def get_tasks_list(request: Request):
"errors": [response_tasks.json()],
}
except ValidationError as e:
logging.error(e)
logger.warning(f"There was a validation error process task_list: {e}")
status_code = 406
message = "Error validating request parameters"
data = {"errors": json.loads(e.json())}
except Exception as e:
logging.error(e)
logger.exception("Unable to retrieve job tasks list from airflow")
status_code = 500
message = "Unable to retrieve job tasks list from airflow"
data = {"errors": [f"{e.__class__.__name__}{e.args}"]}
Expand Down Expand Up @@ -565,12 +587,12 @@ async def get_task_logs(request: Request):
"errors": [response_logs.json()],
}
except ValidationError as e:
logging.error(e)
logger.warning(f"Error validating request parameters: {e}")
status_code = 406
message = "Error validating request parameters"
data = {"errors": json.loads(e.json())}
except Exception as e:
logging.error(e)
logger.exception("Unable to retrieve job task_list from airflow")
status_code = 500
message = "Unable to retrieve task logs from airflow"
data = {"errors": [f"{e.__class__.__name__}{e.args}"]}
Expand Down Expand Up @@ -707,7 +729,7 @@ async def download_job_template(_: Request):
status_code=200,
)
except Exception as e:
logging.error(e)
logger.exception("Error creating job template")
return JSONResponse(
content={
"message": "Error creating job template",
Expand Down
43 changes: 43 additions & 0 deletions tests/test_log_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Tests methods in log_handler module"""

import unittest
from unittest.mock import MagicMock, call, patch

from aind_data_transfer_service.log_handler import LoggingConfigs, get_logger


class TestLoggingConfigs(unittest.TestCase):
    """Tests LoggingConfigs class"""

    def test_app_name(self):
        """Tests app_name property"""
        settings = LoggingConfigs(env_name="test", loki_uri=None)
        self.assertEqual(
            "aind-data-transfer-service-test", settings.app_name
        )

    def test_get_logger(self):
        """Tests get_logger method"""
        settings = LoggingConfigs(
            env_name="test", loki_uri="example.com", log_level="WARNING"
        )
        with patch(
            "aind_data_transfer_service.log_handler.LokiHandler"
        ) as mock_loki_handler, patch(
            "logging.getLogger"
        ) as mock_get_logger:
            mock_get_logger.return_value = MagicMock()
            _ = get_logger(log_configs=settings)
        mock_loki_handler.assert_called_once_with(
            url="example.com",
            version="1",
            tags={"application": "aind-data-transfer-service-test"},
        )
        mock_get_logger.assert_has_calls(
            [call("aind_data_transfer_service.log_handler")]
        )
        mock_get_logger.return_value.setLevel.assert_called_once_with(30)
        mock_get_logger.return_value.addHandler.assert_called_once()


# Allow running this test module directly: `python test_log_handler.py`.
if __name__ == "__main__":
    unittest.main()
29 changes: 16 additions & 13 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def test_submit_jobs(
@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("aind_data_transfer_service.server.sleep", return_value=None)
@patch("aind_data_transfer_service.hpc.client.HpcClient.submit_job")
@patch("logging.error")
@patch("logging.Logger.error")
def test_submit_jobs_server_error(
self,
mock_log_error: MagicMock,
Expand Down Expand Up @@ -238,7 +238,7 @@ def test_submit_jobs_server_error(
@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("aind_data_transfer_service.server.sleep", return_value=None)
@patch("aind_data_transfer_service.hpc.client.HpcClient.submit_job")
@patch("logging.error")
@patch("logging.Logger.error")
def test_submit_jobs_malformed_json(
self,
mock_log_error: MagicMock,
Expand Down Expand Up @@ -450,7 +450,7 @@ def test_submit_hpc_jobs_open_data(
@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("aind_data_transfer_service.server.sleep", return_value=None)
@patch("aind_data_transfer_service.hpc.client.HpcClient.submit_hpc_job")
@patch("logging.error")
@patch("logging.Logger.error")
def test_submit_hpc_jobs_error(
self,
mock_log_error: MagicMock,
Expand Down Expand Up @@ -497,7 +497,7 @@ def test_submit_hpc_jobs_error(
@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("aind_data_transfer_service.server.sleep", return_value=None)
@patch("aind_data_transfer_service.hpc.client.HpcClient.submit_hpc_job")
@patch("logging.error")
@patch("logging.Logger.error")
def test_submit_hpc_jobs_server_error(
self,
mock_log_error: MagicMock,
Expand Down Expand Up @@ -664,10 +664,10 @@ def test_get_job_status_list_query_params(

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("requests.get")
@patch("logging.error")
@patch("logging.Logger.warning")
def test_get_job_status_list_validation_error(
self,
mock_log_error: MagicMock,
mock_log_warning: MagicMock,
mock_get,
):
"""Tests get_job_status_list when query_params are invalid."""
Expand All @@ -693,7 +693,7 @@ def test_get_job_status_list_validation_error(
response_content["message"],
"Error validating request parameters",
)
mock_log_error.assert_called()
mock_log_warning.assert_called()
mock_get.assert_not_called()

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
Expand Down Expand Up @@ -745,7 +745,7 @@ def test_get_job_status_list_dag_run_id(
)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("logging.error")
@patch("logging.Logger.exception")
@patch("requests.get")
def test_get_job_status_list_error(
self,
Expand Down Expand Up @@ -990,7 +990,7 @@ def test_get_tasks_list_query_params(

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("requests.get")
@patch("logging.error")
@patch("logging.Logger.warning")
def test_get_tasks_list_validation_error(
self,
mock_log_error: MagicMock,
Expand All @@ -1014,7 +1014,7 @@ def test_get_tasks_list_validation_error(
mock_get.assert_not_called()

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("logging.error")
@patch("logging.Logger.exception")
@patch("requests.get")
def test_get_tasks_list_error(
self,
Expand Down Expand Up @@ -1080,7 +1080,7 @@ def test_get_task_logs_query_params(

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("requests.get")
@patch("logging.error")
@patch("logging.Logger.warning")
def test_get_task_logs_validation_error(
self,
mock_log_error: MagicMock,
Expand All @@ -1106,7 +1106,7 @@ def test_get_task_logs_validation_error(
mock_get.assert_not_called()

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("logging.error")
@patch("logging.Logger.exception")
@patch("requests.get")
def test_get_task_logs_error(
self,
Expand Down Expand Up @@ -1291,7 +1291,7 @@ def test_download_job_template(self):
self.assertEqual(200, response.status_code)

@patch("aind_data_transfer_service.server.JobUploadTemplate")
@patch("logging.error")
@patch("logging.Logger.exception")
def test_download_invalid_job_template(
self, mock_log_error: MagicMock, mock_job_template: MagicMock
):
Expand Down Expand Up @@ -1441,8 +1441,10 @@ def test_submit_v1_jobs_200(

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("requests.post")
@patch("logging.Logger.exception")
def test_submit_v1_jobs_500(
self,
mock_log_exception: MagicMock,
mock_post: MagicMock,
):
"""Tests submit jobs 500 response."""
Expand Down Expand Up @@ -1484,6 +1486,7 @@ def test_submit_v1_jobs_500(
url="/api/v1/submit_jobs", json=request_json
)
self.assertEqual(500, submit_job_response.status_code)
mock_log_exception.assert_called()

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("requests.post")
Expand Down
Loading