From b532ed9e31872edfa4edb02738cf508749277249 Mon Sep 17 00:00:00 2001 From: Robbie Date: Mon, 25 Mar 2024 13:45:30 +0000 Subject: [PATCH] Add logic to handle empty strings in channel type --- posthog/hogql/database/schema/channel_type.py | 17 ++++++++++------- .../database/schema/test/test_channel_type.py | 15 +++++++++++++++ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/posthog/hogql/database/schema/channel_type.py b/posthog/hogql/database/schema/channel_type.py index 5dee575fc59a3..4954cc5be2b29 100644 --- a/posthog/hogql/database/schema/channel_type.py +++ b/posthog/hogql/database/schema/channel_type.py @@ -62,6 +62,9 @@ def create_channel_type_expr( gclid: ast.Expr, gad_source: ast.Expr, ) -> ast.Expr: + def wrap_with_null_if_empty(expr: ast.Expr) -> ast.Expr: + return ast.Call(name="nullIf", args=[expr, ast.Constant(value="")]) + return parse_expr( """ multiIf( @@ -95,8 +98,8 @@ def create_channel_type_expr( ( {referring_domain} = '$direct' - AND ({medium} IS NULL OR {medium} = '') - AND ({source} IS NULL OR {source} IN ('', '(direct)', 'direct')) + AND ({medium} IS NULL) + AND ({source} IS NULL OR {source} IN ('(direct)', 'direct')) ), 'Direct', @@ -122,11 +125,11 @@ def create_channel_type_expr( )""", start=None, placeholders={ - "campaign": campaign, - "medium": medium, - "source": source, + "campaign": wrap_with_null_if_empty(campaign), + "medium": wrap_with_null_if_empty(medium), + "source": wrap_with_null_if_empty(source), "referring_domain": referring_domain, - "gclid": gclid, - "gad_source": gad_source, + "gclid": wrap_with_null_if_empty(gclid), + "gad_source": wrap_with_null_if_empty(gad_source), }, ) diff --git a/posthog/hogql/database/schema/test/test_channel_type.py b/posthog/hogql/database/schema/test/test_channel_type.py index 89e026ff3aed0..10cd4ea4ae009 100644 --- a/posthog/hogql/database/schema/test/test_channel_type.py +++ b/posthog/hogql/database/schema/test/test_channel_type.py @@ -106,6 +106,21 @@ def test_direct(self): ), ) + def test_direct_empty_string(self): + self.assertEqual( + "Direct", + self._get_initial_channel_type( + { + "$initial_referring_domain": "$direct", + "$initial_utm_source": "", + "$initial_utm_medium": "", + "$initial_utm_campaign": "", + "$initial_gclid": "", + "$initial_gad_source": "", + } + ), + ) + def test_cross_network(self): self.assertEqual( "Cross Network",