Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
leej3 committed Oct 29, 2024
1 parent 1c28826 commit 953f3ec
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 69 deletions.
3 changes: 1 addition & 2 deletions osm/schemas/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Annotated, Optional, Union

import pandas as pd
from bson import ObjectId
from pydantic import (
BaseModel,
BeforeValidator,
Expand Down Expand Up @@ -107,7 +106,7 @@ class Invocation(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
arbitrary_types_allowed=True,
json_encoders={datetime.datetime: lambda dt: dt.isoformat(), ObjectId: str},
json_encoders={datetime.datetime: lambda dt: dt.isoformat()},
)
# class Settings:
# keep_nulls = False
Expand Down
124 changes: 67 additions & 57 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,91 +3,101 @@

import nest_asyncio
import pytest
from bson import ObjectId

from osm.schemas import Client, Component, Invocation, RtransparentMetrics, Work
from osm.schemas import Invocation

nest_asyncio.apply()
SIMPLE_INV = Invocation(
work=Work(),
client=Client(compute_context_id=1),
)
SAMPLE_INV = Invocation(
work=Work(filename="test_file"),
client=Client(compute_context_id="1"),
osm_version="1.0.0",
user_comment="Test comment",
components=[Component(name="TestComponent", version="1.0", sample=b"long text")],
funder=["Test Funder"],
rtransparent_metrics=RtransparentMetrics(
is_open_code=True, is_open_data=False, funding_text="long text"
),
data_tags=["test_tag"],
created_at=datetime.datetime.now(datetime.UTC).replace(microsecond=0),
)
SIMPLE_DOC = {"work": {}, "client": {"compute_context_id": 1}}


@pytest.mark.asyncio
async def test_swagger_ui(client):
response = await client.get("/docs")
assert response.status_code == 200
assert response.headers["content-type"] == "text/html; charset=utf-8"
assert "swagger-ui-dist" in response.text
assert (
"oauth2RedirectUrl: window.location.origin + '/docs/oauth2-redirect'"
in response.text
)


@pytest.mark.parametrize("_", [x for x in range(3)])
@pytest.mark.asyncio
async def test_db_upload(db, _):
asyncio.run(db.invocations.insert_one(SIMPLE_INV.model_dump(mode="json")))
saved_inv = asyncio.run(db.invocations.find_one())
assert saved_inv is not None
assert saved_inv["client"]["compute_context_id"] == 1
SAMPLE_DOC = {
"work": {"filename": "test_file"},
"client": {"compute_context_id": 1},
"osm_version": "1.0.0",
"user_comment": "Test comment",
"components": [
{"name": "TestComponent", "version": "1.0", "sample": "bG9uZyB0ZXh0"}
],
"funder": ["Test Funder"],
"rtransparent_metrics": {
"is_open_code": True,
"is_open_data": False,
"funding_text": "long text",
},
"data_tags": ["test_tag"],
"created_at": datetime.datetime.now(datetime.UTC)
.replace(microsecond=0)
.isoformat(),
}


@pytest.mark.asyncio
async def test_upload_path(client, db):
# from requests import post
# response = post("/upload/", json=SAMPLE_DOC.model_dump(mode="json", exclude_none=True))
# breakpoint()
response = asyncio.run(
client.put(
client.post(
"/upload/",
json=SAMPLE_INV.model_dump(mode="json", exclude_none=True),
json=SAMPLE_DOC,
)
)

assert response.status_code == 200, response.text
assert response.status_code == 201, response.text

# Retrieve the saved invocation to verify it was saved correctly
record = asyncio.run(db.invocations.find_one())
record = asyncio.run(
db.invocations.find_one({"_id": ObjectId(response.json()["id"])})
)
breakpoint()
assert record is not None, "Saved invocation not found"
saved_invocation = Invocation(**record)
assert saved_invocation["id"] == SAMPLE_INV.id, "IDs do not match"
assert saved_invocation.id is not None, "IDs should be created automatically"
breakpoint()
assert isinstance(saved_invocation.id, str), "ID should be a string"
assert (
saved_invocation["work"]["filename"] == SAMPLE_INV.work.filename
saved_invocation.work.filename == SAMPLE_DOC.work.filename
), "Work data does not match"
assert (
saved_invocation["client"]["compute_context_id"]
== SAMPLE_INV.client.compute_context_id
saved_invocation.client.compute_context_id
== SAMPLE_DOC.client.compute_context_id
), "Client data does not match"
assert (
saved_invocation["osm_version"] == SAMPLE_INV.osm_version
saved_invocation.osm_version == SAMPLE_DOC.osm_version
), "OSM version does not match"
assert (
saved_invocation["user_comment"] == SAMPLE_INV.user_comment
saved_invocation.user_comment == SAMPLE_DOC.user_comment
), "User comment does not match"
assert (
saved_invocation["components"][0]["name"] == SAMPLE_INV.components[0].name
saved_invocation.components[0].name == SAMPLE_DOC.components[0].name
), "Component name does not match"
assert (
saved_invocation["components"][0]["version"] == SAMPLE_INV.components[0].version
saved_invocation.components[0].version == SAMPLE_DOC.components[0].version
), "Component version does not match"
assert saved_invocation["funder"] == SAMPLE_INV.funder, "Funder does not match"
assert (
saved_invocation["data_tags"] == SAMPLE_INV.data_tags
), "Data tags do not match"
assert saved_invocation.funder == SAMPLE_DOC.funder, "Funder does not match"
assert saved_invocation.data_tags == SAMPLE_DOC.data_tags, "Data tags do not match"
# TODO: fix tzinfo issue
breakpoint()

assert (
saved_invocation["created_at"] == SAMPLE_INV.created_at
saved_invocation.created_at == SAMPLE_DOC.created_at
), "Created at does not match"


@pytest.mark.parametrize("doc", [SIMPLE_DOC, SAMPLE_DOC])
@pytest.mark.asyncio
async def test_db_upload(db, doc):
res = asyncio.run(db.invocations.insert_one(doc))
saved_inv = asyncio.run(db.invocations.find_one({"_id": res.inserted_id}))
assert saved_inv is not None
assert saved_inv["client"]["compute_context_id"] == 1


@pytest.mark.asyncio
async def test_swagger_ui(client):
response = await client.get("/docs")
assert response.status_code == 200
assert response.headers["content-type"] == "text/html; charset=utf-8"
assert "swagger-ui-dist" in response.text
assert (
"oauth2RedirectUrl: window.location.origin + '/docs/oauth2-redirect'"
in response.text
)
35 changes: 25 additions & 10 deletions web/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from bson import ObjectId
from fastapi import Body, FastAPI, File, Form, HTTPException, UploadFile, status
from fastapi.responses import JSONResponse

from osm import db as osm_db
Expand All @@ -22,17 +23,26 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan)


@app.put("/upload/", response_model=Invocation)
@app.post(
"/upload/",
response_description="Upload the OSM client output",
response_model=Invocation,
status_code=status.HTTP_201_CREATED,
response_model_by_alias=False,
)
async def upload_invocation(
invocation: Invocation,
invocation: Invocation = Body(...),
):
db = osm_db.get_mongo_db()
result = await db.invocations.insert_one(invocation.model_dump(mode="json"))
result = await db.invocations.insert_one(
invocation.model_dump(by_alias=True, exclude=["id"])
)
inserted_invocation = await db.invocations.find_one({"_id": result.inserted_id})
return Invocation(**{k: v for k, v in inserted_invocation.items() if k != "id"})
return inserted_invocation
# return Invocation(**{k: v for k, v in inserted_invocation.items() if k != "id"})


@app.put("/payload_error/", response_model=PayloadError)
@app.post("/payload_error/", response_model=PayloadError)
async def upload_failed_payload_construction(payload_error: PayloadError):
db = osm_db.get_mongo_db()
result = await db.payload_errors.insert_one(payload_error.model_dump(mode="json"))
Expand All @@ -44,15 +54,15 @@ async def upload_failed_payload_construction(payload_error: PayloadError):
)


@app.put("/quarantine/", response_model=Quarantine)
@app.post("/quarantine/", response_model=Quarantine)
async def upload_failed_invocation(quarantine: Quarantine):
db = osm_db.get_mongo_db()
result = await db.quarantines.insert_one(quarantine.model_dump(mode="json"))
inserted_quarantine = await db.quarantines.find_one({"_id": result.inserted_id})
return Quarantine(**{k: v for k, v in inserted_quarantine.items() if k != "id"})


@app.put("/quarantine2/")
@app.post("/quarantine2/")
async def upload_failed_quarantine(
file: UploadFile = File(...), error_message: str = Form(...)
):
Expand All @@ -65,9 +75,14 @@ async def upload_failed_quarantine(
)


@app.get("/invocations/{id}", response_model=Invocation)
@app.get(
"/invocations/{id}",
response_description="Get a single record by its ID",
response_model=Invocation,
response_model_by_alias=False,
)
async def get_invocation_by_id(id: str):
invocation = await Invocation.find_one(id)
invocation = await Invocation.find_one({"_id": ObjectId(id)})
if invocation is None:
raise HTTPException(404)
return invocation
Expand Down

0 comments on commit 953f3ec

Please sign in to comment.