Skip to content

Commit

Permalink
fix(decide): strip data= string from base64 encoded /decide reque…
Browse files Browse the repository at this point in the history
…sts before decoding them (#24229)
  • Loading branch information
dmarticus authored Aug 7, 2024
1 parent 892d0a2 commit 077b6bd
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 7 deletions.
40 changes: 40 additions & 0 deletions posthog/test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import base64
from datetime import datetime
import json
from unittest.mock import call, patch
from zoneinfo import ZoneInfo

Expand All @@ -18,6 +20,7 @@
from posthog.utils import (
PotentialSecurityProblemException,
absolute_uri,
base64_decode,
flatten,
format_query_params_absolute_url,
get_available_timezones_with_offsets,
Expand Down Expand Up @@ -459,6 +462,43 @@ def test_can_get_period_to_compare_when_interval_is_day(self) -> None:
) == (datetime(2021, 2, 27, 0, 0), datetime(2021, 12, 31, 23, 59, 59, 999999))


class TestUtilities(TestCase):
    def test_base64_decode(self):
        """Check base64_decode against each input shape the decoder is expected to accept."""
        expected_greeting = "Hello, World!"

        # A str produced by the standard encoder round-trips.
        greeting_b64 = base64.b64encode(expected_greeting.encode("utf-8")).decode("ascii")
        self.assertEqual(base64_decode(greeting_b64), expected_greeting)

        # bytes input is accepted as well as str.
        self.assertEqual(base64_decode(b"SGVsbG8sIFdvcmxkIQ=="), expected_greeting)

        # Multi-byte UTF-8 text (CJK first, then emoji) survives the round trip.
        for original_text in ("こんにちは、世界!", "Hello 👋 World 🌍!"):
            text_b64 = base64.b64encode(original_text.encode("utf-8")).decode("ascii")
            self.assertEqual(base64_decode(text_b64), original_text)

        # Payloads whose trailing "=" padding was stripped still decode.
        self.assertEqual(base64_decode("SGVsbG8sIFdvcmxkIQ"), expected_greeting)

        # Real-world URL-encoded request body, including the "data=" prefix.
        # from: https://posthog.sentry.io/issues/5680826999/
        raw_body = b"data=eyJ0b2tlbiI6InBoY191eEl4QmhLQ2NVZll0d1NoTmhlRVMyNTJBak45b0pYNzZmcElybTV3cWpmIiwiZGlzdGluY3RfaWQiOiIwMTkxMjliNi1kNTQwLTczZjUtYjY3YS1kODI3MTEzOWFmYTYiLCJncm91cHMiOnt9fQ%3D%3D"
        payload = json.loads(base64_decode(raw_body))

        self.assertEqual(payload["token"], "phc_uxIxBhKCcUfYtwShNheES252AjN9oJX76fpIrm5wqjf")
        self.assertEqual(payload["distinct_id"], "019129b6-d540-73f5-b67a-d8271139afa6")
        self.assertEqual(payload["groups"], {})


class TestFlatten(TestCase):
def test_flatten_lots_of_depth(self):
assert list(flatten([1, [2, 3], [[4], [5, [6, 7]]]])) == [1, 2, 3, 4, 5, 6, 7]
Expand Down
24 changes: 17 additions & 7 deletions posthog/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
cast,
)
from collections.abc import Generator, Mapping
from urllib.parse import urljoin, urlparse
from urllib.parse import unquote, urljoin, urlparse
from zoneinfo import ZoneInfo

import lzstring
def base64_decode(data):
    """
    Decodes base64 bytes into string taking into account necessary transformations to match client libraries.

    Transformations applied before decoding:

    * a leading ``data=`` form-field prefix is removed and the remainder is percent-decoded;
    * spaces are turned back into ``+`` (form decoding maps ``+`` to a space, and ``+`` is a
      valid base64 character, so dropping the space would silently corrupt the payload);
    * missing ``=`` padding is re-added.

    :param data: base64 payload as ``str`` or ``bytes``, optionally prefixed with ``data=``.
    :return: the decoded payload as a ``str`` (UTF-8, lone surrogates are passed through).
    :raises ValueError: if a ``str`` input is not ASCII (``UnicodeEncodeError``), the base64 is
        malformed (``binascii.Error``), or the decoded bytes are not UTF-8 (``UnicodeDecodeError``).
    """
    if isinstance(data, str):
        data = data.encode("ascii")

    # Some clients send the raw form body ("data=<percent-encoded base64>") instead of the bare value.
    form_prefix = b"data="
    if data.startswith(form_prefix):
        data = unquote(data[len(form_prefix) :].decode("ascii")).encode("ascii")

    # Restore "+" characters that were turned into spaces by form decoding.
    data = data.replace(b" ", b"+")

    # Client libraries may strip the trailing "=" padding; b64decode needs a multiple of 4.
    data += b"=" * (-len(data) % 4)

    return base64.b64decode(data).decode("utf-8", "surrogatepass")


def decompress(data: Any, compression: str):
Expand Down

0 comments on commit 077b6bd

Please sign in to comment.