Skip to content

Commit

Permalink
Merge pull request #47 from ambiata/fix/allow-pydantic-2
Browse files Browse the repository at this point in the history
[#patch] Allow Pydantic 2.x
  • Loading branch information
dashambiata authored Oct 31, 2024
2 parents 18f2298 + 867598d commit afd6975
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 78 deletions.
2 changes: 1 addition & 1 deletion atmosphere/custom_activity/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def format_prediction_payload_response(
You can format the prediction the way you want based
on the information returned by default
"""
return default_prediction_response
return default_prediction_response.model_dump()

def get_exclusion_rule_conditions(self) -> ExclusionRuleConditionListResponse:
"""
Expand Down
14 changes: 7 additions & 7 deletions atmosphere/custom_activity/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Versions(BaseModelForbiddingExtraFields):

class PredictionResponsePayloadFormat(BaseModel):
name: str
description: Optional[str]
description: Optional[str] = None


class PredictionResponsePayloadFormatListResponse(BaseModel):
Expand All @@ -35,15 +35,15 @@ class Phase(BaseModel):
description: str
status: str
start_date: datetime
end_date: Optional[datetime]
end_date: Optional[datetime] = None


class Activity(BaseModel):
id: UUID
name: str
description: str
start_date: datetime
end_date: Optional[datetime]
end_date: Optional[datetime] = None
status: str
current_category: str
endpoint: str
Expand All @@ -54,16 +54,16 @@ class Process(BaseModel):
name: str
description: str
category: str
start_date: Optional[datetime]
end_date: Optional[datetime]
start_date: Optional[datetime] = None
end_date: Optional[datetime] = None
status: str


class Method(BaseModel):
id: UUID
name: str
method_type: str
config: Optional[dict]
config: Optional[dict] = None


class Action(BaseModel):
Expand All @@ -80,7 +80,7 @@ class AllocationLog(BaseModel):
class InferenceInfo(BaseModel):
activity: Activity
current_process: Process
method: Optional[Method]
method: Optional[Method] = None
action: Action
allocation: AllocationLog

Expand Down
187 changes: 130 additions & 57 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.9,<3.13"
pydantic = ">=1.10.18,<2.0"
pydantic = ">=2.0,<3.0"
apscheduler = ">=3.10.1,<4.0.0"
fastapi = ">=0.115.2,<1.0.0"
httpx = ">=0.27.2,<1.0.0"
Expand Down
2 changes: 1 addition & 1 deletion scripts/test.ps1
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pytest --cov=atmosphere --cov-report=term-missing "${@}"
pytest -rP -vv --cov=atmosphere --cov-report=term-missing "${@}"
4 changes: 2 additions & 2 deletions tests/custom_activity/activity_for_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def send_mock_prediction_request(
return (
requests.post(
url_prediction_endpoint,
json=ExpectedModel(a="a", b=3).dict(),
json=ExpectedModel(a="a", b=3).model_dump(),
timeout=5,
),
prediction_extra_info,
Expand All @@ -55,5 +55,5 @@ def send_mock_outcome_request(
) -> Response:
assert info_from_prediction == prediction_extra_info
return requests.post(
url_outcome_endpoint, json=ExpectedModel(a="b", b=2).dict(), timeout=5
url_outcome_endpoint, json=ExpectedModel(a="b", b=2).model_dump(), timeout=5
)
10 changes: 5 additions & 5 deletions tests/custom_activity/api/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _assert_204(response):

def test_validate_prediction(client: TestClient, example) -> None:
response = client.post(
"/validate-prediction-request", json=example.good_prediction.dict()
"/validate-prediction-request", json=example.good_prediction.model_dump()
)
_assert_204(response)

Expand All @@ -27,7 +27,7 @@ def test_validate_prediction_not_valid(client: TestClient, example) -> None:

def test_validate_outcome(client: TestClient, example) -> None:
response = client.post(
"/validate-outcome-request", json=example.good_prediction.dict()
"/validate-outcome-request", json=example.good_prediction.model_dump()
)
_assert_204(response)

Expand All @@ -37,7 +37,7 @@ def test_validate_outcome_not_valid(client: TestClient, example) -> None:


def test_compute_rewards(client: TestClient, example) -> None:
response = client.post("/compute-reward", json=example.good_prediction.dict())
response = client.post("/compute-reward", json=example.good_prediction.model_dump())
assert response.status_code == 200
# Raise an exception if not if the model does not validate the payload
compute_reward_response = ComputeRewardResponse.parse_obj(response.json())
Expand All @@ -63,13 +63,13 @@ def _failed_validation(client, path, example):
assert response.status_code == 422

# Extra field
data = example.good_prediction.dict()
data = example.good_prediction.model_dump()
data["c"] = 2
response = client.post(path, json=data)
assert response.status_code == 422

# Wrong type
data = example.good_prediction.dict()
data = example.good_prediction.model_dump()
data["b"] = "def"
response = client.post(path, json=data)
assert response.status_code == 422
9 changes: 5 additions & 4 deletions tests/custom_activity/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def fixture_default_prediction():
"id": "c22d6e7a-fd33-4f60-8e5a-0a0ab9bb1da1",
"name": "kana activity",
"description": "Kana mock activity",
"start_date": "2020-11-06T03:17:29.600944+00:00",
"start_date": "2020-11-06T03:17:29.600944Z",
"end_date": None,
"status": "running",
"current_category": "deployment",
Expand All @@ -26,7 +26,7 @@ def fixture_default_prediction():
"name": "kana",
"description": "Kana mock",
"category": "deployment",
"start_date": "2020-11-06T03:17:29.600944+00:00",
"start_date": "2020-11-06T03:17:29.600944Z",
"end_date": None,
"status": "running",
},
Expand Down Expand Up @@ -84,8 +84,9 @@ def test_get_prediction_response_payload_formats(test_client: TestClient):
def test_format_prediction_payload_response(
test_client: TestClient, default_prediction: dict
):
resp = test_client.post(
"/format-prediction-payload-response",
resp = test_client.request(
method="post",
url="/format-prediction-payload-response",
params={"payload_format": "test"},
json=default_prediction,
)
Expand Down

0 comments on commit afd6975

Please sign in to comment.