Skip to content

Commit

Permalink
Allow simple usage of ast.Call in table joins
Browse files Browse the repository at this point in the history
  • Loading branch information
Gilbert09 committed Dec 31, 2024
1 parent 3aa25f3 commit d286216
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 15 deletions.
18 changes: 12 additions & 6 deletions posthog/hogql/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,14 +437,20 @@ def define_mappings(warehouse: dict[str, Table], get_table: Callable):
joining_table = database.get_table(join.joining_table_name)

field = parse_expr(join.source_table_key)
if not isinstance(field, ast.Field):
raise ResolutionError("Data Warehouse Join HogQL expression should be a Field node")
from_field = field.chain
if isinstance(field, ast.Field):
from_field = field.chain
elif isinstance(field, ast.Call) and isinstance(field.args[0], ast.Field):
from_field = field.args[0].chain
else:
raise ResolutionError("Data Warehouse Join HogQL expression should be a Field or Call node")

field = parse_expr(join.joining_table_key)
if not isinstance(field, ast.Field):
raise ResolutionError("Data Warehouse Join HogQL expression should be a Field node")
to_field = field.chain
if isinstance(field, ast.Field):
to_field = field.chain
elif isinstance(field, ast.Call) and isinstance(field.args[0], ast.Field):
to_field = field.args[0].chain
else:
raise ResolutionError("Data Warehouse Join HogQL expression should be a Field or Call node")

source_table.fields[join.field_name] = LazyJoin(
from_field=from_field,
Expand Down
2 changes: 1 addition & 1 deletion posthog/warehouse/api/test/test_view_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_create_saved_query_join_key_function(self):
"field_name": "some_field",
},
)
self.assertEqual(response.status_code, 400, response.content)
self.assertEqual(response.status_code, 201, response.content)

def test_update_with_configuration(self):
join = DataWarehouseJoin.objects.create(
Expand Down
6 changes: 3 additions & 3 deletions posthog/warehouse/api/view_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.api.shared import UserBasicSerializer
from posthog.hogql.ast import Field
from posthog.hogql.ast import Field, Call
from posthog.hogql.database.database import create_hogql_database
from posthog.hogql.parser import parse_expr
from posthog.warehouse.models import DataWarehouseJoin
Expand Down Expand Up @@ -71,8 +71,8 @@ def _validate_join_key(self, join_key: Optional[str], table: Optional[str], team
raise serializers.ValidationError(f"Invalid table: {table}")

node = parse_expr(join_key)
if not isinstance(node, Field):
raise serializers.ValidationError(f"Join key {join_key} must be a table field - no function calls allowed")
if not isinstance(node, Field) and not (isinstance(node, Call) and isinstance(node.args[0], Field)):
raise serializers.ValidationError(f"Join key {join_key} must be a table field")

return

Expand Down
16 changes: 11 additions & 5 deletions posthog/warehouse/models/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,20 @@ def _join_function(
raise ResolutionError(f"No fields requested from {join_to_add.to_table}")

left = parse_expr(_source_table_key)
if not isinstance(left, ast.Field):
if isinstance(left, ast.Field):
left.chain = [join_to_add.from_table, *left.chain]
elif isinstance(left, ast.Call) and isinstance(left.args[0], ast.Field):
left.args[0].chain = [join_to_add.from_table, *left.args[0].chain]
else:
raise ResolutionError("Data Warehouse Join HogQL expression should be a Field node")
left.chain = [join_to_add.from_table, *left.chain]

right = parse_expr(_joining_table_key)
if not isinstance(right, ast.Field):
raise ResolutionError("Data Warehouse Join HogQL expression should be a Field node")
right.chain = [join_to_add.to_table, *right.chain]
if isinstance(right, ast.Field):
right.chain = [join_to_add.to_table, *right.chain]
elif isinstance(right, ast.Call) and isinstance(right.args[0], ast.Field):
right.args[0].chain = [join_to_add.to_table, *right.args[0].chain]
else:
raise ResolutionError("Data Warehouse Join HogQL expression should be a Field or Call node")

join_expr = ast.JoinExpr(
table=ast.SelectQuery(
Expand Down

0 comments on commit d286216

Please sign in to comment.