Skip to content

Commit

Permalink
chore(python): upgrade pydantic and datamodel-code-generator (#17477)
Browse files Browse the repository at this point in the history
  • Loading branch information
thmsobrmlr authored Sep 18, 2023
1 parent 1de6d5c commit 26f29f1
Show file tree
Hide file tree
Showing 16 changed files with 385 additions and 370 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"build:esbuild": "node frontend/build.mjs",
"schema:build": "pnpm run schema:build:json && pnpm run schema:build:python",
"schema:build:json": "ts-json-schema-generator -f tsconfig.json --path 'frontend/src/*.ts' --type 'QuerySchema' --no-type-check > frontend/src/queries/schema.json && prettier --write frontend/src/queries/schema.json",
"schema:build:python": "datamodel-codegen --collapse-root-models --disable-timestamp --use-subclass-enum --input frontend/src/queries/schema.json --input-file-type jsonschema --output posthog/schema.py && black posthog/schema.py",
"schema:build:python": "datamodel-codegen --collapse-root-models --disable-timestamp --use-one-literal-as-default --use-default-kwarg --use-subclass-enum --input frontend/src/queries/schema.json --input-file-type jsonschema --output posthog/schema.py --output-model-type pydantic_v2.BaseModel && black posthog/schema.py",
"grammar:build": "cd posthog/hogql/grammar && antlr -Dlanguage=Python3 HogQLLexer.g4 && antlr -visitor -no-listener -Dlanguage=Python3 HogQLParser.g4",
"packages:build": "pnpm packages:build:apps-common && pnpm packages:build:lemon-ui",
"packages:build:apps-common": "cd frontend/@posthog/apps-common && pnpm i && pnpm build",
Expand Down
2 changes: 1 addition & 1 deletion posthog/api/insight.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def parse(self, stream, media_type=None, parser_context=None):
try:
query = data.get("query", None)
if query:
schema.Model.parse_obj(query)
schema.Model.model_validate(query)
except Exception as error:
raise ParseError(detail=str(error))
else:
Expand Down
8 changes: 4 additions & 4 deletions posthog/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class QuerySchemaParser(JSONParser):
@staticmethod
def validate_query(data) -> Dict:
try:
schema.Model.parse_obj(data)
schema.Model.model_validate(data)
# currently we have to return data not the parsed Model
# because pydantic doesn't know to discriminate on 'kind'
# if we can get this correctly typed we can return the parsed model
Expand Down Expand Up @@ -204,11 +204,11 @@ def process_query(team: Team, query_json: Dict, default_limit: Optional[int] = N
tag_queries(query=query_json)

if query_kind == "EventsQuery":
events_query = EventsQuery.parse_obj(query_json)
events_query = EventsQuery.model_validate(query_json)
events_response = run_events_query(query=events_query, team=team, default_limit=default_limit)
return _unwrap_pydantic_dict(events_response)
elif query_kind == "HogQLQuery":
hogql_query = HogQLQuery.parse_obj(query_json)
hogql_query = HogQLQuery.model_validate(query_json)
hogql_response = execute_hogql_query(
query_type="HogQLQuery",
query=hogql_query.query,
Expand All @@ -218,7 +218,7 @@ def process_query(team: Team, query_json: Dict, default_limit: Optional[int] = N
)
return _unwrap_pydantic_dict(hogql_response)
elif query_kind == "HogQLMetadata":
metadata_query = HogQLMetadata.parse_obj(query_json)
metadata_query = HogQLMetadata.model_validate(query_json)
metadata_response = get_hogql_metadata(query=metadata_query, team=team)
return _unwrap_pydantic_dict(metadata_response)
elif query_kind == "LifecycleQuery":
Expand Down
7 changes: 3 additions & 4 deletions posthog/api/test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def test_full_hogql_query(self):
with freeze_time("2020-01-10 12:14:00"):
query = HogQLQuery(query="select event, distinct_id, properties.key from events order by timestamp")
api_response = self.client.post(f"/api/projects/{self.team.id}/query/", {"query": query.dict()}).json()
query.response = HogQLQueryResponse.parse_obj(api_response)
query.response = HogQLQueryResponse.model_validate(api_response)

self.assertEqual(query.response.results and len(query.response.results), 4)
self.assertEqual(
Expand Down Expand Up @@ -475,7 +475,7 @@ def test_invalid_query_kind(self):
assert api_response.status_code == 400
assert api_response.json()["code"] == "parse_error"
assert "validation errors for Model" in api_response.json()["detail"]
assert "type=value_error.const; given=Tomato Soup" in api_response.json()["detail"]
assert "type=literal_error, input_value='Tomato Soup'" in api_response.json()["detail"]

@snapshot_clickhouse_queries
def test_full_hogql_query_view(self):
Expand All @@ -498,7 +498,6 @@ def test_full_hogql_query_view(self):
flush_persons_and_events()

with freeze_time("2020-01-10 12:14:00"):

self.client.post(
f"/api/projects/{self.team.id}/warehouse_saved_queries/",
{
Expand All @@ -511,7 +510,7 @@ def test_full_hogql_query_view(self):
)
query = HogQLQuery(query="select * from event_view")
api_response = self.client.post(f"/api/projects/{self.team.id}/query/", {"query": query.dict()}).json()
query.response = HogQLQueryResponse.parse_obj(api_response)
query.response = HogQLQueryResponse.model_validate(api_response)

self.assertEqual(query.response.results and len(query.response.results), 4)
self.assertEqual(
Expand Down
5 changes: 2 additions & 3 deletions posthog/hogql/constants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import date, datetime
from typing import Optional, Literal, TypeAlias, Tuple, List
from uuid import UUID
from pydantic import BaseModel, Extra
from pydantic import ConfigDict, BaseModel

ConstantDataType: TypeAlias = Literal[
"int", "float", "str", "bool", "array", "tuple", "date", "datetime", "uuid", "unknown"
Expand All @@ -24,8 +24,7 @@

# Settings applied on top of all HogQL queries.
class HogQLSettings(BaseModel):
class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")

readonly: Optional[int] = 2
max_execution_time: Optional[int] = 60
Expand Down
11 changes: 5 additions & 6 deletions posthog/hogql/database/database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Literal, Optional, TypedDict
from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from pydantic import BaseModel, Extra
from pydantic import ConfigDict, BaseModel

from posthog.hogql.database.models import (
FieldTraverser,
Expand Down Expand Up @@ -33,8 +33,7 @@


class Database(BaseModel):
class Config:
extra = Extra.allow
model_config = ConfigDict(extra="allow")

# Users can query from the tables below
events: EventsTable = EventsTable()
Expand All @@ -58,7 +57,7 @@ class Config:
numbers: NumbersTable = NumbersTable()

# clunky: keep table names in sync with above
_table_names: List[str] = [
_table_names: ClassVar[List[str]] = [
"events",
"groups",
"person",
Expand Down Expand Up @@ -182,7 +181,7 @@ class SerializedField(_SerializedFieldBase, total=False):
def serialize_database(database: Database) -> Dict[str, List[SerializedField]]:
tables: Dict[str, List[SerializedField]] = {}

for table_key in database.__fields__.keys():
for table_key in database.model_fields.keys():
field_input: Dict[str, Any] = {}
table = getattr(database, table_key, None)
if isinstance(table, FunctionCallTable):
Expand Down
21 changes: 7 additions & 14 deletions posthog/hogql/database/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
from pydantic import BaseModel, Extra
from pydantic import ConfigDict, BaseModel

from posthog.hogql.errors import HogQLException, NotImplementedException

Expand All @@ -16,8 +16,7 @@ class DatabaseField(FieldOrTable):
Base class for a field in a database table.
"""

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")

name: str
array: Optional[bool] = None
Expand Down Expand Up @@ -57,17 +56,14 @@ class BooleanDatabaseField(DatabaseField):


class FieldTraverser(FieldOrTable):
class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")

chain: List[str]


class Table(FieldOrTable):
fields: Dict[str, FieldOrTable]

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")

def has_field(self, name: str) -> bool:
return name in self.fields
Expand Down Expand Up @@ -102,8 +98,7 @@ def get_asterisk(self):


class LazyJoin(FieldOrTable):
class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")

join_function: Callable[[str, str, Dict[str, Any]], Any]
join_table: Table
Expand All @@ -115,8 +110,7 @@ class LazyTable(Table):
A table that is replaced with a subquery returned from `lazy_select(requested_fields: Dict[name, chain])`
"""

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")

def lazy_select(self, requested_fields: Dict[str, List[str]]) -> Any:
raise NotImplementedException("LazyTable.lazy_select not overridden")
Expand All @@ -127,8 +121,7 @@ class VirtualTable(Table):
A nested table that reuses the parent for storage. E.g. events.person.* fields with PoE enabled.
"""

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class FunctionCallTable(Table):
Expand Down
8 changes: 4 additions & 4 deletions posthog/hogql/database/schema/numbers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Optional

from posthog.hogql.database.models import (
IntegerDatabaseField,
Expand All @@ -14,9 +14,9 @@
class NumbersTable(FunctionCallTable):
fields: Dict[str, FieldOrTable] = NUMBERS_TABLE_FIELDS

name = "numbers"
min_args = 1
max_args = 2
name: str = "numbers"
min_args: Optional[int] = 1
max_args: Optional[int] = 2

def to_printed_clickhouse(self, context):
return "numbers"
Expand Down
6 changes: 4 additions & 2 deletions posthog/hogql/test/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ class TestMetadata(ClickhouseTestMixin, APIBaseTest):
maxDiff = None

def _expr(self, query: str) -> HogQLMetadataResponse:
return get_hogql_metadata(query=HogQLMetadata(expr=query), team=self.team)
return get_hogql_metadata(query=HogQLMetadata(kind="HogQLMetadata", expr=query, response=None), team=self.team)

def _select(self, query: str) -> HogQLMetadataResponse:
return get_hogql_metadata(query=HogQLMetadata(select=query), team=self.team)
return get_hogql_metadata(
query=HogQLMetadata(kind="HogQLMetadata", select=query, response=None), team=self.team
)

def test_metadata_valid_expr_select(self):
metadata = self._expr("select 1")
Expand Down
2 changes: 1 addition & 1 deletion posthog/hogql/test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,7 +1446,7 @@ def test_hogql_query_filters(self):
)
query = "SELECT event, distinct_id from events WHERE distinct_id={distinct_id} and {filters}"
filters = HogQLFilters(
properties=[EventPropertyFilter(key="index", operator="exact", value=4, type="event")]
properties=[EventPropertyFilter(key="index", operator="exact", value="4", type="event")]
)
placeholders = {"distinct_id": ast.Constant(value=random_uuid)}
response = execute_hogql_query(query, team=self.team, filters=filters, placeholders=placeholders)
Expand Down
2 changes: 1 addition & 1 deletion posthog/hogql_queries/lifecycle_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, query: LifecycleQuery | Dict[str, Any], team: Team, timings:
if isinstance(query, LifecycleQuery):
self.query = query
else:
self.query = LifecycleQuery.parse_obj(query)
self.query = LifecycleQuery.model_validate(query)

def to_query(self) -> ast.SelectQuery:
placeholders = {
Expand Down
Loading

0 comments on commit 26f29f1

Please sign in to comment.