From 077b6bd8cf73a0b2324fd4ddd63bc86d759035bd Mon Sep 17 00:00:00 2001
From: Dylan Martin
Date: Wed, 7 Aug 2024 13:45:13 -0400
Subject: [PATCH] fix(decide): strip `data=` string from base64 encoded `/decide` requests before decoding them (#24229)

---
 posthog/test/test_utils.py | 40 ++++++++++++++++++++++++++++++++++++++
 posthog/utils.py           | 24 ++++++++++++++++-------
 2 files changed, 57 insertions(+), 7 deletions(-)

diff --git a/posthog/test/test_utils.py b/posthog/test/test_utils.py
index 0504b20941233..ca77b19615aaf 100644
--- a/posthog/test/test_utils.py
+++ b/posthog/test/test_utils.py
@@ -1,4 +1,6 @@
+import base64
 from datetime import datetime
+import json
 from unittest.mock import call, patch
 from zoneinfo import ZoneInfo
 
@@ -18,6 +20,7 @@
 from posthog.utils import (
     PotentialSecurityProblemException,
     absolute_uri,
+    base64_decode,
     flatten,
     format_query_params_absolute_url,
     get_available_timezones_with_offsets,
@@ -459,6 +462,43 @@ def test_can_get_period_to_compare_when_interval_is_day(self) -> None:
         ) == (datetime(2021, 2, 27, 0, 0), datetime(2021, 12, 31, 23, 59, 59, 999999))
 
 
+class TestUtilities(TestCase):
+    def test_base64_decode(self):
+        # Test with a simple string
+        simple_string = "Hello, World!"
+        encoded = base64.b64encode(simple_string.encode("utf-8")).decode("ascii")
+        self.assertEqual(base64_decode(encoded), simple_string)
+
+        # Test with bytes input
+        bytes_input = b"SGVsbG8sIFdvcmxkIQ=="
+        self.assertEqual(base64_decode(bytes_input), simple_string)
+
+        # Test with Unicode characters
+        unicode_string = "こんにちは、世界!"
+        unicode_encoded = base64.b64encode(unicode_string.encode("utf-8")).decode("ascii")
+        self.assertEqual(base64_decode(unicode_encoded), unicode_string)
+
+        # Test with emojis
+        emoji_string = "Hello 👋 World 🌍!"
+        emoji_encoded = base64.b64encode(emoji_string.encode("utf-8")).decode("ascii")
+        self.assertEqual(base64_decode(emoji_encoded), emoji_string)
+
+        # Test with padding characters removed
+        no_padding = "SGVsbG8sIFdvcmxkIQ"
+        self.assertEqual(base64_decode(no_padding), simple_string)
+
+        # Test with real URL encoded data
+        # from: https://posthog.sentry.io/issues/5680826999/
+        encoded_data = b"data=eyJ0b2tlbiI6InBoY191eEl4QmhLQ2NVZll0d1NoTmhlRVMyNTJBak45b0pYNzZmcElybTV3cWpmIiwiZGlzdGluY3RfaWQiOiIwMTkxMjliNi1kNTQwLTczZjUtYjY3YS1kODI3MTEzOWFmYTYiLCJncm91cHMiOnt9fQ%3D%3D"
+
+        decoded = base64_decode(encoded_data)
+        decoded_json = json.loads(decoded)
+
+        self.assertEqual(decoded_json["token"], "phc_uxIxBhKCcUfYtwShNheES252AjN9oJX76fpIrm5wqjf")
+        self.assertEqual(decoded_json["distinct_id"], "019129b6-d540-73f5-b67a-d8271139afa6")
+        self.assertEqual(decoded_json["groups"], {})
+
+
 class TestFlatten(TestCase):
     def test_flatten_lots_of_depth(self):
         assert list(flatten([1, [2, 3], [[4], [5, [6, 7]]]])) == [1, 2, 3, 4, 5, 6, 7]
diff --git a/posthog/utils.py b/posthog/utils.py
index 760bab359381f..49399c0f4a205 100644
--- a/posthog/utils.py
+++ b/posthog/utils.py
@@ -24,7 +24,7 @@
     cast,
 )
 from collections.abc import Generator, Mapping
-from urllib.parse import urljoin, urlparse
+from urllib.parse import unquote, urljoin, urlparse
 from zoneinfo import ZoneInfo
 
 import lzstring
@@ -601,12 +601,22 @@ def base64_decode(data):
     """
     Decodes base64 bytes into string taking into account necessary transformations to match client libraries.
     """
-    if not isinstance(data, str):
-        data = data.decode()
-
-    data = base64.b64decode(data.replace(" ", "+") + "===")
-
-    return data.decode("utf8", "surrogatepass").encode("utf-16", "surrogatepass")
+    if isinstance(data, str):
+        data = data.encode("ascii")
+
+    # Check if the data is URL-encoded
+    if data.startswith(b"data="):
+        data = unquote(data.decode("ascii")).split("=", 1)[1]
+        data = data.encode("ascii")
+
+    # Form decoding turns "+" into " "; restore it (never drop it), then add padding if necessary
+    data = data.replace(b" ", b"+")
+    missing_padding = len(data) % 4
+    if missing_padding:
+        data += b"=" * (4 - missing_padding)
+
+    decoded = base64.b64decode(data)
+    return decoded.decode("utf-8", "surrogatepass")
 
 
 def decompress(data: Any, compression: str):