
Merge pull request #5053 from systeminit/jkeiser/bug-690
Billing: don't alert when new users race with subscription import
jkeiser authored Dec 10, 2024
2 parents a1da5f7 + c04d9c6 commit 3b007a9
Showing 4 changed files with 114 additions and 61 deletions.
6 changes: 3 additions & 3 deletions component/lambda/functions/billing-set-prices.py
@@ -45,7 +45,7 @@ def get_subscription(self, external_subscription_id: ExternalSubscriptionId):
)

except LagoHTTPError as e:
-if e.json.get("status") != 404:
+if e.json and e.json.get("status") != 404:
raise
try:
logging.debug(
@@ -59,7 +59,7 @@ def get_subscription(self, external_subscription_id: ExternalSubscriptionId):
)

except LagoHTTPError as e:
-if e.json.get("status") != 404:
+if e.json and e.json.get("status") != 404:
raise

try:
@@ -74,7 +74,7 @@ def get_subscription(self, external_subscription_id: ExternalSubscriptionId):
)

except LagoHTTPError as e:
-if e.json.get("status") != 404:
+if e.json and e.json.get("status") != 404:
raise

logging.warning(
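For context, the three hunks above all apply the same guard: `LagoHTTPError.json` is now optional (see the si_lago_api.py changes below), so an error response with no parseable JSON body must not be dereferenced. A minimal, self-contained sketch of the behavior change; `FakeLagoHTTPError` and `should_reraise` are stand-ins invented here, not names from the codebase:

```python
class FakeLagoHTTPError(Exception):
    """Stand-in for si_lago_api.LagoHTTPError; `json` may now be None."""
    def __init__(self, json=None):
        self.json = json

def should_reraise(e: FakeLagoHTTPError) -> bool:
    # Old check: e.json.get("status") != 404 -> AttributeError when e.json is None.
    # New check: a missing/unparseable body no longer crashes; it is treated
    # like a 404, so get_subscription falls through to the next lookup attempt.
    return bool(e.json) and e.json.get("status") != 404

assert should_reraise(FakeLagoHTTPError(json={"status": 500}))
assert not should_reraise(FakeLagoHTTPError(json={"status": 404}))
assert not should_reraise(FakeLagoHTTPError(json=None))  # the racy case this PR fixes
```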
72 changes: 36 additions & 36 deletions component/lambda/functions/si_lago_api.py
@@ -1,6 +1,7 @@
from collections.abc import Iterable
from itertools import islice
from typing import (
+Any,
Literal,
NewType,
NotRequired,
@@ -11,10 +12,12 @@
)
from pip._vendor import requests
from pip._vendor.requests.exceptions import HTTPError
+from pip._vendor.requests.models import Response
from si_types import IsoTimestamp

-import urllib.parse
import logging
+import urllib.parse
+import sys

ExternalSubscriptionId = NewType("ExternalSubscriptionId", str)

@@ -29,28 +32,6 @@ class LagoEvent(TypedDict):

T = TypeVar("T")


-class LagoErrorResponse(TypedDict):
-code: str
-error_details: dict[str, dict[str, list[str]]]
-
-
-def batch(iterable: Iterable[T], n):
-"Batch data into lists of length n. The last batch may be shorter."
-# batched('ABCDEFG', 3) --> ABC DEF G
-it = iter(iterable)
-while True:
-batch = list(islice(it, n))
-if not batch:
-return
-yield batch
-
-
-class LagoHTTPError(Exception):
-def __init__(self, *args, json, **kwargs):
-super().__init__(*args, **kwargs)
-self.json = json
-
class LagoApi:
def __init__(self, lago_api_url: str, lago_api_token: str):
self.lago_api_url = lago_api_url
@@ -64,14 +45,17 @@ def request(self, method: str, path: str, **kwargs):
headers={"Authorization": f"Bearer {self.lago_api_token}"},
**kwargs,
)

try:
response.raise_for_status()
except HTTPError as e:
-raise LagoHTTPError(
-f"{e.response.status_code} for {method} {path}: {e.response.text}",
-json=e.response.json(),
-) from e

+message = f"{e.response.status_code} for {method} {path}: {e.response.text}"
+try:
+json = cast(LagoHTTPError.ResponseJson, e.response.json())
+except:
+logging.warning("Failed to parse JSON from Lago API response.", exc_info=sys.exc_info())
+json = None
+raise LagoHTTPError(message, response=e.response, json=json) from e
return response

def get(self, path: str):
@@ -114,16 +98,12 @@ def upload_events(self, events: Iterable[LagoEvent], *, dry_run = False):
new_events += len(event_batch)
logging.debug(f"Uploaded {len(event_batch)} events.")

+# If the batch failed because some events were already uploaded, retry the rest
except LagoHTTPError as e:

-# If the batch failed because some events were already uploaded, retry the rest
-if [e.json.get("status"), e.json.get("code")] != [
-422,
-"validation_errors",
-]:
+if not (e.json and e.json['status'] == 422 and e.json['code'] == "validation_errors"):
raise

-for error in e.json["error_details"].values():
+for error in e.json['error_details'].values():
if error.get("transaction_id") != ["value_already_exist"]:
raise

@@ -134,7 +114,7 @@ def upload_events(self, events: Iterable[LagoEvent], *, dry_run = False):
retry_event_batch = [
event
for (i, event) in enumerate(event_batch)
-if str(i) not in e.json["error_details"].keys()
+if str(i) not in e.json['error_details'].keys()
]

if len(retry_event_batch) > 0:
@@ -149,6 +129,15 @@ def upload_events(self, events: Iterable[LagoEvent], *, dry_run = False):

return new_events, total_events

+class LagoHTTPError(HTTPError):
+class ResponseJson(TypedDict):
+status: int
+code: str
+error_details: dict[str, Any]
+
+def __init__(self, *args, json: Optional[ResponseJson], **kwargs):
+super().__init__(*args, **kwargs)
+self.json = json

class LagoResponseMetadata(TypedDict):
current_page: int
@@ -410,3 +399,14 @@ class LagoInvoicesResponse(TypedDict):
class LagoSubscriptionsResponse(TypedDict):
subscriptions: list[LagoSubscription]
meta: LagoResponseMetadata


+def batch(iterable: Iterable[T], n):
+"Batch data into lists of length n. The last batch may be shorter."
+# batched('ABCDEFG', 3) --> ABC DEF G
+it = iter(iterable)
+while True:
+batch = list(islice(it, n))
+if not batch:
+return
+yield batch
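For reference, a sketch of the retry filtering `upload_events` performs when Lago rejects a batch with a 422 `validation_errors` response. The payload shape follows the `ResponseJson` TypedDict above; the indices and values here are made up:

```python
# Hypothetical error_details from a 422: errors are keyed by batch position,
# and "value_already_exist" means Lago already ingested that transaction_id.
error_details = {
    "0": {"transaction_id": ["value_already_exist"]},
    "2": {"transaction_id": ["value_already_exist"]},
}
event_batch = [{"transaction_id": f"evt-{i}"} for i in range(4)]

# Same rule as the diff: drop events whose index appears in error_details
# (Lago already has them) and retry only the remainder.
retry_event_batch = [
    event
    for i, event in enumerate(event_batch)
    if str(i) not in error_details.keys()
]
assert [e["transaction_id"] for e in retry_event_batch] == ["evt-1", "evt-3"]
```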
25 changes: 19 additions & 6 deletions component/lambda/functions/si_redshift.py
@@ -113,29 +113,42 @@ class Statement:
def __init__(self, redshift: 'Redshift', statement: 'ExecuteStatementOutputTypeDef'):
self.redshift = redshift
self.statement = statement
+self.response = None
+self.started_at = time.time()
+self.last_report = None

def wait_for_complete(self):
-last_report = time.time()
+if self.response is not None:
+return self.response
+
+def log_status(status: str):
+logging.log(logging.INFO,
+f"Query status: {status}. (Id={self.statement['Id']}, Elapsed={time.time() - self.started_at}s)"
+)
+self.last_report = time.time()

while True:
response = self._describe_statement()
status = response["Status"]

match status:
case "FINISHED":
+log_status(status)
+self.response = response
return response
case "FAILED":
+log_status(status)
+self.response = response
raise Exception(
f"Query failed: {response['Error']} (Id={self.statement['Id']})"
)
case "ABORTED":
+log_status(status)
+self.response = response
raise Exception(f"Query aborted (Id={self.statement['Id']})")

-if time.time() - last_report >= self.redshift._report_interval_seconds:
-last_report = time.time()
-logging.log(logging.INFO,
-f"Query status: {status}. Waiting {self.redshift._wait_interval_seconds}s for completion... (Id={self.statement['Id']})"
-)
+if time.time() - (self.last_report or self.started_at) >= self.redshift._report_interval_seconds:
+log_status(status)

time.sleep(self.redshift._wait_interval_seconds)

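A distilled sketch of the memoization added to `Statement.wait_for_complete` above: poll to a terminal status once, cache the response, and short-circuit every later call. Names here are stand-ins, and unlike the real class (which raises on `FAILED`/`ABORTED` via `_describe_statement` against Redshift), this version just returns the terminal response:

```python
import time

class StatementWaiter:
    """Stand-in for Statement; `describe` plays the role of _describe_statement."""

    TERMINAL = {"FINISHED", "FAILED", "ABORTED"}

    def __init__(self, describe):
        self._describe = describe
        self.response = None  # cached terminal response

    def wait_for_complete(self):
        if self.response is not None:  # repeat call: no extra polling
            return self.response
        while True:
            response = self._describe()
            if response["Status"] in self.TERMINAL:
                self.response = response
                return response
            time.sleep(0.01)  # stand-in for _wait_interval_seconds

# The fake describe() is consumed only twice; the second wait returns the cache.
calls = iter([{"Status": "STARTED"}, {"Status": "FINISHED"}])
waiter = StatementWaiter(lambda: next(calls))
assert waiter.wait_for_complete() is waiter.wait_for_complete()
```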
72 changes: 56 additions & 16 deletions component/lambda/functions/workspace_delegations_population.py
@@ -1,6 +1,6 @@
from collections.abc import Iterable
import json
-from typing import NotRequired, Optional, TypedDict, cast, overload
+from typing import NotRequired, Optional, TypedDict, Union, cast, overload
import time
from datetime import datetime
from si_lambda import SiLambda, SiLambdaEnv
@@ -20,26 +20,54 @@ def __init__(self, event: WorkspaceDelegationsPopulationEnv):
assert len(self.owner_pks) > 0, "SI_OWNER_PKS must be non-empty. Did you mean to not set it?"

def update_subscriptions(self, current_timestamp: SqlTimestamp):
+started_at = time.time()
+last_report = started_at

# Query all owner subscriptions and compare to what's in Lago
-latest_owner_subscriptions = cast(Iterable['LatestOwnerSubscription'], self.redshift.query("""
+latest_owner_subscriptions = self.redshift.query("""
SELECT owner_pk, subscription_id, subscription_start_date, subscription_end_date, plan_code, external_id
FROM workspace_operations.owners
LEFT OUTER JOIN workspace_operations.latest_owner_subscriptions USING (owner_pk)
ORDER BY owner_pk, start_time
"""))
""")

+# Removes the outer join "fake" row and returns 0 subscriptions instead
+def remove_fake_row(si_subscriptions_iter: Iterable[Union[LatestOwnerSubscription, OwnerWithoutSubscriptions]]):
+si_subscriptions = list(si_subscriptions_iter)
+result = cast(list[LatestOwnerSubscription], si_subscriptions)
+if len(si_subscriptions) == 1 and si_subscriptions[0]['subscription_id'] is None:
+result = []
+return result
+
+total_subscriptions = latest_owner_subscriptions.wait_for_complete()['ResultRows']
+
# Start all the inserts at once (if any)
-subscription_updates = [
-self.update_owner_subscriptions(owner_pk, current_timestamp, list(si_subscriptions), self.get_owner_lago_subscriptions(owner_pk))
-for owner_pk, si_subscriptions in groupby(latest_owner_subscriptions, lambda sub: sub['owner_pk'])
-if self.owner_pks is None or owner_pk in self.owner_pks
-]
+processed_subscriptions = 0
+subscription_updates = []
+for owner_pk, si_subscriptions_iter in groupby(
+cast(Iterable[Union[LatestOwnerSubscription, OwnerWithoutSubscriptions]], latest_owner_subscriptions),
+lambda sub: sub['owner_pk']
+):
+si_subscriptions = remove_fake_row(si_subscriptions_iter)
+
+if self.owner_pks is None or owner_pk in self.owner_pks:
+lago_subscriptions = self.get_owner_lago_subscriptions(owner_pk)
+subscription_updates.append(
+self.update_owner_subscriptions(owner_pk, current_timestamp, si_subscriptions, lago_subscriptions))
+
+processed_subscriptions += len(si_subscriptions)
+if time.time() - last_report > 5:
+logging.info(f"Updating subscriptions: {processed_subscriptions} / {total_subscriptions} subscriptions retrieved from Lago after {time.time() - started_at}s.")
+last_report = time.time()

# Return the completed inserts
-return [
-result.wait_for_complete()['Status']
-for result in subscription_updates
-if result is not None
+results = [
+update.wait_for_complete()
+for update in subscription_updates
+if update is not None
]
logging.info(f"Subscription update complete: {sum([result['ResultRows'] for result in results])} subscriptions updated.")
return [result['Status'] for result in results]

def update_owner_subscriptions(self, owner_pk: OwnerPk, current_timestamp: SqlTimestamp, si_subscriptions: list['LatestOwnerSubscription'], lago_subscriptions: Iterable[LagoSubscription]):
lago_subscriptions = list(lago_subscriptions)
@@ -55,8 +83,8 @@ def update_owner_subscriptions(self, owner_pk: OwnerPk, current_timestamp: SqlTi
should_update = False
if len(si_subscriptions) == 0:
if len(lago_subscriptions_by_id) != 0:
logging.info(f"New owner {owner_pk}! Adding {lago_subscriptions_by_id.keys()}")
should_update = True
logging.info(f"New owner {owner_pk}! Adding {lago_subscriptions_by_id.keys()}")
elif len(lago_subscriptions_by_id) == 0:
logging.error(f"Owner {owner_pk} has had all subscriptions removed from Lago!")
else:
@@ -74,8 +102,8 @@ def update_owner_subscriptions(self, owner_pk: OwnerPk, current_timestamp: SqlTi
si_subscription_ids = set([sub['external_id'] for sub in si_subscriptions])
for external_id in lago_subscriptions_by_id.keys():
if external_id not in si_subscription_ids:
logging.info(f"Owner {owner_pk} has a new subscription {external_id} in Lago! Adding to SI.")
should_update = True
logging.info(f"Owner {owner_pk} has a new subscription {external_id} in Lago! Adding to SI.")

if should_update:
return self.start_inserting_owner_subscriptions(owner_pk, lago_subscriptions_by_id.values(), current_timestamp)
@@ -140,10 +168,15 @@ def insert_missing_workspaces(self, current_timestamp: SqlTimestamp):
LIMIT 50
"""))
]
-return {
-workspace_id: insert.wait_for_complete()['Status']
+results = {
+workspace_id: insert.wait_for_complete()
for workspace_id, insert in missing_workspace_inserts
}
logging.info(f"Insert missing workspaces complete: {sum([result['ResultRows'] for result in results.values()])} workspaces inserted.")
return {
workspace_id: result['Status']
for workspace_id, result in results.items()
}

def start_inserting_workspace(self, workspace_id: WorkspaceId, workspace_owner_id: OwnerPk, timestamp):
# Prepare the columns and values for the workspace_owners table
@@ -185,6 +218,7 @@ def run(self):
}

# Result of the latest_owner_subscriptions query

class LatestOwnerSubscription(TypedDict):
owner_pk: OwnerPk
subscription_id: str
@@ -193,6 +227,12 @@ class LatestOwnerSubscription(TypedDict):
plan_code: str
external_id: ExternalSubscriptionId

+class OwnerWithoutSubscriptions(TypedDict):
+owner_pk: OwnerPk
+subscription_id: None
+plan_code: None
+external_id: None


# Convert ISO 8601 timestamp to the required format
@overload
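A self-contained sketch of the `remove_fake_row` idea above: the LEFT OUTER JOIN yields exactly one all-NULL subscription row for an owner with no subscriptions yet (the new-user race in this PR's title), and that row should count as zero subscriptions rather than triggering an alert. `Row` is a simplified stand-in for `LatestOwnerSubscription` / `OwnerWithoutSubscriptions`:

```python
from itertools import groupby
from typing import Optional, TypedDict

class Row(TypedDict):
    owner_pk: str
    subscription_id: Optional[str]

def remove_fake_row(rows: list[Row]) -> list[Row]:
    # One all-NULL row from the LEFT OUTER JOIN means "owner, no subscriptions".
    if len(rows) == 1 and rows[0]["subscription_id"] is None:
        return []
    return rows

rows: list[Row] = [
    {"owner_pk": "A", "subscription_id": "sub-1"},
    {"owner_pk": "B", "subscription_id": None},  # new user racing with the import
]
for owner_pk, group in groupby(rows, key=lambda r: r["owner_pk"]):
    subs = remove_fake_row(list(group))
    print(owner_pk, len(subs))  # A -> 1, B -> 0 (not treated as a real subscription)
```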
