From fefbbf4c081b8bca17e75dab0ca1addf90689f0d Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 6 Jun 2024 23:07:27 +0530 Subject: [PATCH 01/26] Don't make use of transaction.atomic in handle_subscription_updated --- routers/account.py | 1 - routers/stripe.py | 1 - 2 files changed, 2 deletions(-) diff --git a/routers/account.py b/routers/account.py index 4df7afcf1..be7300485 100644 --- a/routers/account.py +++ b/routers/account.py @@ -229,7 +229,6 @@ def account_page_wrapper(request: Request, current_tab: TabData): yield -@transaction.atomic def paypal_handle_subscription_updated(subscription: paypal.Subscription): logger.info("Subscription updated") diff --git a/routers/stripe.py b/routers/stripe.py index 4ee3ef12a..e6df7123f 100644 --- a/routers/stripe.py +++ b/routers/stripe.py @@ -114,7 +114,6 @@ def _handle_checkout_session_completed(uid: str, session_data): ) -@transaction.atomic def _handle_subscription_updated(uid: str, subscription_data): logger.info("Subscription updated") product = stripe.Product.retrieve(subscription_data.plan.product) From e35d78f316ca4014f1392d5de49a6d9494ba8625 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 6 Jun 2024 23:08:09 +0530 Subject: [PATCH 02/26] paypal.Subscription.cancel: ignore when sub is already cancelled / expired --- daras_ai_v2/paypal.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/daras_ai_v2/paypal.py b/daras_ai_v2/paypal.py index 4c5229b98..eae5042d4 100644 --- a/daras_ai_v2/paypal.py +++ b/daras_ai_v2/paypal.py @@ -155,6 +155,8 @@ class Subscription(PaypalResource): billing_info: BillingInfo | None def cancel(self, *, reason: str = "cancellation_requested") -> None: + if self.status in ["CANCELLED", "EXPIRED"]: + return r = requests.post( str(self.get_resource_url() / "cancel"), headers=get_default_headers(), From 46e34b13c831add5c920fa2965f6b68b3ad54cc0 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 6 Jun 2024 23:09:02 +0530 Subject: [PATCH 03/26] Fix paypal next invoice time when it is None --- payments/models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/payments/models.py b/payments/models.py index cca4e1cbd..d55f15b53 100644 --- a/payments/models.py +++ b/payments/models.py @@ -143,7 +143,10 @@ def get_next_invoice_timestamp(self) -> float | None: return period_end elif self.payment_provider == PaymentProvider.PAYPAL: subscription = paypal.Subscription.retrieve(self.external_id) - if not subscription.billing_info: + if ( + not subscription.billing_info + or not subscription.billing_info.next_billing_time + ): return None return subscription.billing_info.next_billing_time.timestamp() else: From 965b040637c57ea439bf09a3fac448541fad0a43 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 7 Jun 2024 12:19:06 +0530 Subject: [PATCH 04/26] Payment processing page: run paypal subscription update in thread --- routers/account.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/routers/account.py b/routers/account.py index be7300485..3ff2335a0 100644 --- a/routers/account.py +++ b/routers/account.py @@ -35,18 +35,19 @@ def payment_processing_route( request: Request, provider: str = None, subscription_id: str = None ): - subtext = None waiting_time_sec = 3 + subtext = None if provider == "paypal": - if sub_id := subscription_id: - sub = paypal.Subscription.retrieve(sub_id) - paypal_handle_subscription_updated(sub) + if (sub_id := subscription_id) and st.run_in_thread( + threaded_paypal_handle_subscription_updated, args=[sub_id] + ): + waiting_time_sec = 0 else: + waiting_time_sec = 30 subtext = ( "PayPal transactions take up to a minute to reflect in your account" ) - waiting_time_sec = 30 with page_wrapper(request, className="m-auto"): with st.center(): @@ -261,3 +262,12 @@ def paypal_handle_subscription_updated(subscription: paypal.Subscription): user.subscription.full_clean() user.subscription.save() user.save(update_fields=["subscription"]) + + +def threaded_paypal_handle_subscription_updated(subscription_id: str) -> bool: + """ + Always returns True when completed (for use in st.run_in_thread()) + """ + subscription = paypal.Subscription.retrieve(subscription_id) + paypal_handle_subscription_updated(subscription) + return True From 43c250f3ecefad7298deb06cc51746074d442f77 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Sun, 9 Jun 2024 19:56:37 +0530 Subject: [PATCH 05/26] sentry_sdk: upgrade to 1.45, add loguru extra to capture logged info --- poetry.lock | 31 +++++++++++++++++++++++++++---- pyproject.toml | 2 +- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index 9382b7274..dbdbef4a5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2917,6 +2917,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -4464,6 +4474,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -4471,8 +4482,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -4489,6 +4507,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -4496,6 +4515,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -5103,17 +5123,18 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo [[package]] name = "sentry-sdk" -version = "1.34.0" +version = "1.45.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.34.0.tar.gz", hash = "sha256:e5d0d2b25931d88fa10986da59d941ac6037f742ab6ff2fce4143a27981d60c3"}, - {file = "sentry_sdk-1.34.0-py2.py3-none-any.whl", hash = "sha256:76dd087f38062ac6c1e30ed6feb533ee0037ff9e709974802db7b5dbf2e5db21"}, + {file = "sentry-sdk-1.45.0.tar.gz", hash = "sha256:509aa9678c0512344ca886281766c2e538682f8acfa50fd8d405f8c417ad0625"}, + {file = "sentry_sdk-1.45.0-py2.py3-none-any.whl", hash = "sha256:1ce29e30240cc289a027011103a8c83885b15ef2f316a60bcc7c5300afa144f1"}, ] [package.dependencies] certifi = "*" +loguru = {version = ">=0.5", optional = true, markers = "extra == \"loguru\""} urllib3 = {version = ">=1.26.11", markers = "python_version >= \"3.6\""} [package.extras] @@ -5123,6 +5144,7 @@ asyncpg = ["asyncpg (>=0.23)"] beam = ["apache-beam (>=2.12)"] bottle = ["bottle (>=0.12.13)"] celery = ["celery (>=3)"] +celery-redbeat = ["celery-redbeat (>=2)"] chalice = ["chalice (>=1.16.0)"] clickhouse-driver = ["clickhouse-driver (>=0.2.0)"] django = ["django (>=1.8)"] @@ -5133,6 +5155,7 @@ grpcio = ["grpcio (>=1.21.1)"] httpx = ["httpx (>=0.16.0)"] huey = ["huey (>=2)"] loguru = ["loguru (>=0.5)"] +openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"] opentelemetry = ["opentelemetry-distro (>=0.35b0)"] opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"] pure-eval = ["asttokens", "executing", "pure-eval"] @@ -6384,4 +6407,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "a4efb36ab8d78f27caa79189d06a8b977f20990654f9d5b4a096fe654465e3c5" +content-hash = "6ec610d060dbd9b6cd5e5926076d584404f1c268f61fe2f02e7825f9422a2b01" diff --git a/pyproject.toml b/pyproject.toml index 0e08956d2..463e6d97e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ stripe = "^5.0.0" python-multipart = "^0.0.5" html-sanitizer = "^1.9.3" plotly = "^5.11.0" -sentry-sdk = "^1.12.0" httpx = "^0.23.1" pyquery = "^1.4.3" redis = "^4.5.1" @@ -86,6 +85,7 @@ emoji = "^2.10.1" pyvespa = "^0.39.0" anthropic = "^0.25.5" azure-cognitiveservices-speech = "^1.37.0" +sentry-sdk = {version = "1.45.0", extras = ["loguru"]} [tool.poetry.group.dev.dependencies] watchdog = "^2.1.9" From 61e5da4f4faccd6b3c7ad76e1b065b03210b40dc Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 10 Jun 2024 16:04:44 +0530 Subject: [PATCH 06/26] Refactor auto-recharge functionality with exceptions To allow for blocking and non-blocking versions that handle exceptions differently --- celeryapp/tasks.py | 13 ++--- daras_ai_v2/auto_recharge.py | 64 ----------------------- payments/auto_recharge.py | 98 ++++++++++++++++++++++++++++++++++++ payments/tasks.py | 93 +++++++++++++++++++++++----------- 4 files changed, 164 insertions(+), 104 deletions(-) delete mode 100644 daras_ai_v2/auto_recharge.py create mode 100644 payments/auto_recharge.py diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index f535b6ff5..b5f23e2e2 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -18,15 +18,14 @@ from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings -from daras_ai_v2.auto_recharge import auto_recharge_user from daras_ai_v2.base import StateKeys, BasePage from daras_ai_v2.exceptions import UserError -from daras_ai_v2.redis_cache import redis_lock from daras_ai_v2.send_email import send_email_via_postmark, send_low_balance_email from daras_ai_v2.settings import templates from gooey_ui.pubsub import realtime_push from gooey_ui.state import set_query_params from gooeysite.bg_db_conn import db_middleware, next_db_safe +from payments.tasks import run_auto_recharge_async @app.task @@ -143,7 +142,9 @@ def save(done=False): save(done=True) if not is_api_call: send_email_on_completion(page, sr) + run_low_balance_email_check(uid) + run_auto_recharge_async.apply(kwargs={"uid": uid}) def err_msg_for_exc(e: Exception): @@ -253,11 +254,3 @@ def send_integration_attempt_email(*, user_id: int, platform: Platform, run_url: subject=f"{user.display_name} Attempted to Connect to {platform.label}", html_body=html_body, ) - - -@app.task -def auto_recharge(*, user_id: int): - redis_lock_key = f"gooey/auto_recharge/{user_id}" - with redis_lock(redis_lock_key): - user = AppUser.objects.get(id=user_id) - auto_recharge_user(user) diff --git a/daras_ai_v2/auto_recharge.py b/daras_ai_v2/auto_recharge.py deleted file mode 100644 index 15c47e42f..000000000 --- a/daras_ai_v2/auto_recharge.py +++ /dev/null @@ -1,64 +0,0 @@ -from loguru import logger - -from app_users.models import AppUser, PaymentProvider -from payments.tasks import send_email_budget_reached, send_email_auto_recharge_failed - - -def auto_recharge_user(user: AppUser): - if not user_should_auto_recharge(user): - logger.info(f"User doesn't need to auto-recharge: {user=}") - return - - dollars_spent = user.get_dollars_spent_this_month() - if ( - dollars_spent + user.subscription.auto_recharge_topup_amount - > user.subscription.monthly_spending_budget - ): - if not user.subscription.has_sent_monthly_budget_email_this_month(): - send_email_budget_reached.delay(user.id) - logger.info(f"User has reached the monthly budget: {user=}, {dollars_spent=}") - return - - match user.subscription.payment_provider: - case PaymentProvider.STRIPE: - customer = user.search_stripe_customer() - if not customer: - logger.error(f"User doesn't have a stripe customer: {user=}") - return - - try: - invoice = user.subscription.stripe_get_or_create_auto_invoice( - amount_in_dollars=user.subscription.auto_recharge_topup_amount, - metadata_key="auto_recharge", - ) - - if invoice.status == "open": - pm = user.subscription.stripe_get_default_payment_method() - invoice.pay(payment_method=pm) - logger.info( - f"Payment attempted for auto recharge invoice: {user=}, {invoice=}" - ) - elif invoice.status == "paid": - logger.info( - f"Auto recharge invoice already paid recently: {user=}, {invoice=}" - ) - except Exception as e: - logger.error( - f"Error while auto-recharging user: {user=}, {e=}, {invoice=}" - ) - send_email_auto_recharge_failed.delay(user.id) - - case PaymentProvider.PAYPAL: - logger.error(f"Auto-recharge not supported for PayPal: {user=}") - return - - -def user_should_auto_recharge(user: AppUser): - """ - whether an auto recharge should be attempted for the user - """ - return ( - user.subscription - and user.subscription.auto_recharge_enabled - and user.balance < user.subscription.auto_recharge_balance_threshold - ) diff --git a/payments/auto_recharge.py b/payments/auto_recharge.py new file mode 100644 index 000000000..0480fbeb3 --- /dev/null +++ b/payments/auto_recharge.py @@ -0,0 +1,98 @@ +from datetime import datetime, timedelta, timezone + +from app_users.models import AppUser, PaymentProvider +from daras_ai_v2 import settings +from daras_ai_v2.redis_cache import redis_lock + + +class AutoRechargeException(Exception): + pass + + +class MonthlyBudgetReachedException(AutoRechargeException): + def __init__(self, *args, budget: int, spending: float, **kwargs): + super().__init__(*args, **kwargs) + self.budget = budget + self.spending = spending + + +class PaymentFailedException(AutoRechargeException): + pass + + +class AutoRechargeCooldownException(AutoRechargeException): + pass + + +def auto_recharge_user(uid: str): + with redis_lock(f"gooey/auto_recharge_user/{uid}"): + user: AppUser = AppUser.objects.get(uid=uid) + if should_attempt_auto_recharge(user): + _auto_recharge_user(user) + + +def _auto_recharge_user(user: AppUser): + """ + Returns whether a charge was attempted + """ + from payments.webhooks import StripeWebhookHandler + + assert ( + user.subscription.payment_provider == PaymentProvider.STRIPE + ), "Auto recharge is only supported with Stripe" + + # check for monthly budget + dollars_spent = user.get_dollars_spent_this_month() + if ( + dollars_spent + user.subscription.auto_recharge_topup_amount + > user.subscription.monthly_spending_budget + ): + raise MonthlyBudgetReachedException( + "Performing this top-up would exceed your monthly recharge budget", + budget=user.subscription.monthly_spending_budget, + spending=dollars_spent, + ) + + # create invoice or get a recent one (recent = created in the last `settings.AUTO_RECHARGE_COOLDOWN_SECONDS` seconds) + cooldown_period_start = datetime.now(timezone.utc) - timedelta( + seconds=settings.AUTO_RECHARGE_COOLDOWN_SECONDS + ) + try: + invoice = user.subscription.stripe_get_or_create_auto_invoice( + amount_in_dollars=user.subscription.auto_recharge_topup_amount, + metadata_key="auto_recharge", + created_after=cooldown_period_start, + ) + except Exception as e: + raise PaymentFailedException("Failed to create auto-recharge invoice") from e + + # recent invoice was already paid + if invoice.status == "paid": + raise AutoRechargeCooldownException( + "An auto recharge invoice was paid recently" + ) + + # get default payment method and attempt payment + assert invoice.status == "open" # sanity check + pm = user.subscription.stripe_get_default_payment_method() + + try: + invoice_data = invoice.pay(payment_method=pm) + except Exception as e: + raise PaymentFailedException( + "Payment failed when attempting to auto-recharge" + ) from e + else: + assert invoice_data.paid + StripeWebhookHandler.handle_invoice_paid( + uid=user.uid, invoice_data=invoice_data + ) + + +def should_attempt_auto_recharge(user: AppUser): + return ( + user.subscription + and user.subscription.auto_recharge_enabled + and user.subscription.payment_provider + and user.balance < user.subscription.auto_recharge_balance_threshold + ) diff --git a/payments/tasks.py b/payments/tasks.py index d5177c8e3..ffd2816da 100644 --- a/payments/tasks.py +++ b/payments/tasks.py @@ -7,73 +7,106 @@ from daras_ai_v2.fastapi_tricks import get_route_url from daras_ai_v2.send_email import send_email_via_postmark from daras_ai_v2.settings import templates +from payments.auto_recharge import ( + AutoRechargeCooldownException, + MonthlyBudgetReachedException, + PaymentFailedException, + auto_recharge_user, +) @app.task -def send_email_budget_reached(user_id: int): +def send_monthly_spending_notification_email(*, uid: str): from routers.account import account_route - user = AppUser.objects.get(id=user_id) + user = AppUser.objects.get(uid=uid) if not user.email: + logger.error(f"User doesn't have an email: {user=}") return - email_body = templates.get_template("monthly_budget_reached_email.html").render( - user=user, - account_url=get_route_url(account_route), - ) + threshold = user.subscription.monthly_spending_notification_threshold + send_email_via_postmark( from_address=settings.SUPPORT_EMAIL, to_address=user.email, - subject="[Gooey.AI] Monthly Budget Reached", - html_body=email_body, + subject=f"[Gooey.AI] Monthly spending has exceeded ${threshold}", + html_body=templates.get_template( + "monthly_spending_notification_threshold_email.html" + ).render( + user=user, + account_url=get_route_url(account_route), + ), ) - user.subscription.monthly_budget_email_sent_at = timezone.now() - user.subscription.save(update_fields=["monthly_budget_email_sent_at"]) + # IMPORTANT: always use update_fields=... / select_for_update when updating + # subscription info. We don't want to overwrite other changes made to + # subscription during the same time + user.subscription.monthly_spending_notification_sent_at = timezone.now() + user.subscription.save(update_fields=["monthly_spending_notification_sent_at"]) @app.task -def send_email_auto_recharge_failed(user_id: int): +def run_auto_recharge_async(*, uid: str): + try: + auto_recharge_user(uid) + except AutoRechargeCooldownException as e: + logger.info( + f"Rejected auto-recharge because auto-recharge is in cooldown period for user" + f"{uid=}, {e=}" + ) + return + except MonthlyBudgetReachedException as e: + send_monthly_budget_reached_email(uid=uid) + logger.info( + f"Rejected auto-recharge because user has reached monthly budget" + f"{uid=}, spending=${e.spending}, budget=${e.budget}" + ) + return + except (PaymentFailedException, Exception) as e: + send_auto_recharge_failed_email(uid=uid) + logger.exception("Payment failed when attempting to auto-recharge", uid=uid) + return + + +def send_monthly_budget_reached_email(*, uid: str): from routers.account import account_route - user = AppUser.objects.get(id=user_id) + user = AppUser.objects.get(uid=uid) if not user.email: return - email_body = templates.get_template("auto_recharge_failed_email.html").render( + email_body = templates.get_template("monthly_budget_reached_email.html").render( user=user, account_url=get_route_url(account_route), ) send_email_via_postmark( from_address=settings.SUPPORT_EMAIL, to_address=user.email, - subject="[Gooey.AI] Auto-Recharge failed", + subject="[Gooey.AI] Monthly Budget Reached", html_body=email_body, ) + # IMPORTANT: always use update_fields=... when updating subscription + # info. We don't want to overwrite other changes made to subscription + # during the same time + user.subscription.monthly_budget_email_sent_at = timezone.now() + user.subscription.save(update_fields=["monthly_budget_email_sent_at"]) -@app.task -def send_monthly_spending_notification_email(user_id: int): + +def send_auto_recharge_failed_email(*, uid: str): from routers.account import account_route - user = AppUser.objects.get(id=user_id) + user = AppUser.objects.get(uid=uid) if not user.email: - logger.error(f"User doesn't have an email: {user=}") return - threshold = user.subscription.monthly_spending_notification_threshold - + email_body = templates.get_template("auto_recharge_failed_email.html").render( + user=user, + account_url=get_route_url(account_route), + ) send_email_via_postmark( from_address=settings.SUPPORT_EMAIL, to_address=user.email, - subject=f"[Gooey.AI] Monthly spending has exceeded ${threshold}", - html_body=templates.get_template( - "monthly_spending_notification_threshold_email.html" - ).render( - user=user, - account_url=get_route_url(account_route), - ), + subject="[Gooey.AI] Auto-Recharge failed", + html_body=email_body, ) - - user.subscription.monthly_spending_notification_sent_at = timezone.now() - user.subscription.save(update_fields=["monthly_spending_notification_sent_at"]) From 558505c42d3db2a40c4d99aac86dd6abef63dcd9 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 10 Jun 2024 16:06:06 +0530 Subject: [PATCH 07/26] Refactor payments webhook handling code into payments/webhooks.py --- payments/webhooks.py | 261 +++++++++++++++++++++++++++++++++++++++++++ routers/paypal.py | 112 +++++-------------- routers/stripe.py | 123 ++------------------ 3 files changed, 297 insertions(+), 199 deletions(-) create mode 100644 payments/webhooks.py diff --git a/payments/webhooks.py b/payments/webhooks.py new file mode 100644 index 000000000..62dbbb8e5 --- /dev/null +++ b/payments/webhooks.py @@ -0,0 +1,261 @@ +from urllib.parse import quote_plus + +import stripe +from django.db import transaction +from loguru import logger +from requests.models import HTTPError + +from .models import Subscription +from .plans import PricingPlan +from .tasks import send_monthly_spending_notification_email +from app_users.models import AppUser, PaymentProvider +from daras_ai_v2 import paypal + + +class WebhookHandler: + PROVIDER: PaymentProvider + + @classmethod + def _after_payment_completed( + cls, *, uid: str, invoice_id: str, credits: int, charged_amount: int + ): + user = AppUser.objects.get_or_create_from_uid(uid)[0] + user.add_balance( + payment_provider=cls.PROVIDER, + invoice_id=invoice_id, + amount=credits, + charged_amount=charged_amount, + ) + if not user.is_paying: + user.is_paying = True + user.save(update_fields=["is_paying"]) + if ( + user.subscription + and user.subscription.should_send_monthly_spending_notification() + ): + send_monthly_spending_notification_email.delay(kwargs={"uid": uid}) + + @classmethod + @transaction.atomic + def _after_subscription_updated(cls, *, uid: str, sub_id: str, plan: PricingPlan): + if not is_sub_active(provider=cls.PROVIDER, sub_id=sub_id): + # subscription is not in an active state, just ignore + logger.info( + "Subscription is not active. Ignoring event", + provider=cls.PROVIDER, + sub_id=sub_id, + ) + return + + user = AppUser.objects.get_or_create_from_uid(uid)[0] + # select_for_update: we want to lock the row until we are done reading & + # updating the subscription + # + # anther transaction shouldn't update the subscription in the meantime + user: AppUser = AppUser.objects.select_for_update().get(pk=user.pk) + if not user.subscription: + # new subscription + logger.info("Creating new subscription for user", uid=uid) + user.subscription = Subscription.objects.get_or_create( + payment_provider=cls.PROVIDER, + external_id=sub_id, + defaults={"plan": plan.db_value}, + )[0] + user.subscription.plan = plan.db_value + + elif is_same_sub(user.subscription, provider=cls.PROVIDER, sub_id=sub_id): + if user.subscription.plan == plan.db_value: + # same subscription exists with the same plan in DB + logger.info("Nothing to do") + return + else: + # provider & sub_id is same, but plan is different. so we update only the plan + logger.info("Updating plan for user", uid=uid) + user.subscription.plan = plan.db_value + + else: + logger.critical( + "Invalid state: last subscription was not cleared for user", uid=uid + ) + + # we have a different existing subscription in DB + # this is invalid state! we should cancel the subscription if it is active + if is_sub_active( + provider=user.subscription.payment_provider, + sub_id=user.subscription.external_id, + ): + logger.critical( + "Found existing active subscription for user. Cancelling that...", + uid=uid, + provider=user.subscription.get_payment_provider_display(), + sub_id=user.subscription.external_id, + ) + user.subscription.cancel() + + logger.info("Creating new subscription for user", uid=uid) + user.subscription = Subscription( + payment_provider=cls.PROVIDER, plan=plan, external_id=sub_id + ) + + user.subscription.full_clean() + user.subscription.save() + user.save(update_fields=["subscription"]) + + @classmethod + def _after_subscription_cancelled(cls, uid: str, sub_id: str): + user = AppUser.objects.get_or_create_from_uid(uid=uid)[0] + if user.subscription and is_same_sub( + user.subscription, provider=cls.PROVIDER, sub_id=sub_id + ): + user.subscription = None + user.save(update_fields=["subscription"]) + + +class PaypalWebhookHandler(WebhookHandler): + PROVIDER = PaymentProvider.PAYPAL + + @classmethod + def handle_sale_completed(cls, sale: paypal.Sale): + if not sale.billing_agreement_id: + logger.info(f"sale {sale} is not a subscription sale... skipping") + return + + pp_sub = paypal.Subscription.retrieve(sale.billing_agreement_id) + assert pp_sub.custom_id, "pp_sub is missing uid" + assert pp_sub.plan_id, "pp_sub is missing plan ID" + + plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id) + assert plan, f"Plan {pp_sub.plan_id} not found" + + charged_dollars = int(float(sale.amount.total)) # convert to dollars + if charged_dollars != plan.monthly_charge: + # log so that we can investigate, and record the payment as usual + logger.critical( + f"paypal: charged amount ${charged_dollars} does not match plan's monthly charge ${plan.monthly_charge}" + ) + + uid = pp_sub.custom_id + cls._after_payment_completed( + uid=uid, + invoice_id=sale.id, + credits=plan.credits, + charged_amount=charged_dollars * 100, + ) + + @classmethod + def handle_subscription_updated(cls, pp_sub: paypal.Subscription): + logger.info(f"Paypal subscription updated {pp_sub.id}") + + assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid" + assert pp_sub.plan_id, f"PayPal subscription {pp_sub.id} is missing plan ID" + + plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id) + assert plan, f"Plan {pp_sub.plan_id} not found" + + cls._after_subscription_updated( + uid=pp_sub.custom_id, sub_id=pp_sub.id, plan=plan + ) + + @classmethod + def handle_subscription_cancelled(cls, pp_sub: paypal.Subscription): + assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid" + cls._after_subscription_cancelled(uid=pp_sub.custom_id, sub_id=pp_sub.id) + + +class StripeWebhookHandler(WebhookHandler): + PROVIDER = PaymentProvider.STRIPE + + @classmethod + def handle_invoice_paid(cls, uid: str, invoice_data): + invoice_id = invoice_data.id + line_items = stripe.Invoice._static_request( + "get", + "/v1/invoices/{invoice}/lines".format(invoice=quote_plus(invoice_id)), + ) + + cls._after_payment_completed( + uid=uid, + invoice_id=invoice_id, + credits=line_items.data[0].quantity, + charged_amount=line_items.data[0].amount, + ) + + @classmethod + def handle_checkout_session_completed(cls, uid: str, session_data): + if setup_intent_id := session_data.get("setup_intent") is None: + # not a setup mode checkout -- do nothing + return + setup_intent = stripe.SetupIntent.retrieve(setup_intent_id) + + # subscription_id was passed to metadata when creating the session + sub_id = setup_intent.metadata["subscription_id"] + assert ( + sub_id + ), f"subscription_id is missing in setup_intent metadata {setup_intent}" + + if is_sub_active(provider=PaymentProvider.STRIPE, sub_id=sub_id): + stripe.Subscription.modify( + sub_id, default_payment_method=setup_intent.payment_method + ) + + @classmethod + def handle_subscription_updated(cls, uid: str, stripe_sub): + logger.info(f"Stripe subscription updated: {stripe_sub.id}") + + assert stripe_sub.plan, f"Stripe subscription {stripe_sub.id} is missing plan" + assert ( + stripe_sub.plan.product + ), f"Stripe subscription {stripe_sub.id} is missing product" + + product = stripe.Product.retrieve(stripe_sub.plan.product) + plan = PricingPlan.get_by_stripe_product(product) + if not plan: + raise Exception( + f"PricingPlan not found for product {stripe_sub.plan.product}" + ) + + cls._after_subscription_updated(uid=uid, sub_id=stripe_sub.id, plan=plan) + + @classmethod + def handle_subscription_cancelled(cls, uid: str, stripe_sub): + logger.info(f"Stripe subscription cancelled: {stripe_sub.id}") + + cls._after_subscription_cancelled(uid=uid, sub_id=stripe_sub.id) + + +def is_same_sub( + subscription: Subscription, *, provider: PaymentProvider, sub_id: str +) -> bool: + return ( + subscription.payment_provider == provider and subscription.external_id == sub_id + ) + + +def is_sub_active(*, provider: PaymentProvider, sub_id: str) -> bool: + match provider: + case PaymentProvider.PAYPAL: + try: + sub = paypal.Subscription.retrieve(sub_id) + except HTTPError as e: + if e.response.status_code != 404: + # if not 404, it likely means there is a bug in our code... + # we want to know about it, but not break the end user experience + logger.exception(f"Unexpected PayPal error for sub: {sub_id}") + return False + + return sub.status == "ACTIVE" + + case PaymentProvider.STRIPE: + try: + sub = stripe.Subscription.retrieve(sub_id) + except stripe.error.InvalidRequestError as e: + if e.http_status != 404: + # if not 404, it likely means there is a bug in our code... + # we want to know about it, but not break the end user experience + logger.exception(f"Unexpected Stripe error for sub: {sub_id}") + return False + except stripe.error.StripeError as e: + logger.exception(f"Unexpected Stripe error for sub: {sub_id}") + return False + + return sub.status == "active" diff --git a/routers/paypal.py b/routers/paypal.py index 2d540122a..178fb8979 100644 --- a/routers/paypal.py +++ b/routers/paypal.py @@ -17,12 +17,8 @@ from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.fastapi_tricks import fastapi_request_json, get_route_url from payments.models import PricingPlan -from payments.tasks import send_monthly_spending_notification_email -from routers.account import ( - paypal_handle_subscription_updated, - payment_processing_route, - account_route, -) +from payments.webhooks import PaypalWebhookHandler +from routers.account import payment_processing_route, account_route router = APIRouter() @@ -150,7 +146,27 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json): return JSONResponse(content=jsonable_encoder(pp_subscription), status_code=200) -@router.post("/__/paypal/webhook/") +# Capture payment for the created order to complete the transaction. +# @see https://developer.paypal.com/docs/api/orders/v2/#orders_capture +@router.post("/__/paypal/orders/{order_id}/capture/") +def capture_order(order_id: str): + response = requests.post( + str(furl(settings.PAYPAL_BASE) / f"v2/checkout/orders/{order_id}/capture"), + headers={ + "Content-Type": "application/json", + "Authorization": paypal.generate_auth_header(), + # Uncomment one of these to force an error for negative testing (in sandbox mode only). Documentation: + # https://developer.paypal.com/tools/sandbox/negative-testing/request-headers/ + # "PayPal-Mock-Response": '{"mock_application_codes": "INSTRUMENT_DECLINED"}' + # "PayPal-Mock-Response": '{"mock_application_codes": "TRANSACTION_REFUSED"}' + # "PayPal-Mock-Response": '{"mock_application_codes": "INTERNAL_SERVER_ERROR"}' + }, + ) + _handle_invoice_paid(order_id) + return JSONResponse(response.json(), response.status_code) + + +@router.post("/__/paypal/webhook") def webhook(request: Request, payload: dict = fastapi_request_json): if not paypal.verify_webhook_event(payload, headers=request.headers): logger.error("Invalid PayPal webhook signature") @@ -166,40 +182,20 @@ def webhook(request: Request, payload: dict = fastapi_request_json): match event.event_type: case "PAYMENT.SALE.COMPLETED": - event = SaleCompletedEvent.parse_obj(event) - _handle_sale_completed(event) + sale = SaleCompletedEvent.parse_obj(event).resource + PaypalWebhookHandler.handle_sale_completed(sale) case "BILLING.SUBSCRIPTION.ACTIVATED" | "BILLING.SUBSCRIPTION.UPDATED": - event = SubscriptionEvent.parse_obj(event) - paypal_handle_subscription_updated(event.resource) + subscription = SubscriptionEvent.parse_obj(event).resource + PaypalWebhookHandler.handle_subscription_updated(subscription) case "BILLING.SUBSCRIPTION.CANCELLED" | "BILLING.SUBSCRIPTION.EXPIRED": - event = SubscriptionEvent.parse_obj(event) - _handle_subscription_cancelled(event.resource) + subscription = SubscriptionEvent.parse_obj(event).resource + PaypalWebhookHandler.handle_subscription_cancelled(subscription) case _: logger.error(f"Unhandled PayPal webhook event: {event.event_type}") return JSONResponse({}, status_code=200) -# Capture payment for the created order to complete the transaction. -# @see https://developer.paypal.com/docs/api/orders/v2/#orders_capture -@router.post("/__/paypal/orders/{order_id}/capture/") -def capture_order(order_id: str): - response = requests.post( - str(furl(settings.PAYPAL_BASE) / f"v2/checkout/orders/{order_id}/capture"), - headers={ - "Content-Type": "application/json", - "Authorization": paypal.generate_auth_header(), - # Uncomment one of these to force an error for negative testing (in sandbox mode only). Documentation: - # https://developer.paypal.com/tools/sandbox/negative-testing/request-headers/ - # "PayPal-Mock-Response": '{"mock_application_codes": "INSTRUMENT_DECLINED"}' - # "PayPal-Mock-Response": '{"mock_application_codes": "TRANSACTION_REFUSED"}' - # "PayPal-Mock-Response": '{"mock_application_codes": "INTERNAL_SERVER_ERROR"}' - }, - ) - _handle_invoice_paid(order_id) - return JSONResponse(response.json(), response.status_code) - - def _handle_invoice_paid(order_id: str): response = requests.get( str(furl(settings.PAYPAL_BASE) / f"v2/checkout/orders/{order_id}"), @@ -219,53 +215,3 @@ def _handle_invoice_paid(order_id: str): if not user.is_paying: user.is_paying = True user.save(update_fields=["is_paying"]) - - -def _handle_sale_completed(event: SaleCompletedEvent): - sale = event.resource - if not sale.billing_agreement_id: - logger.warning(f"Sale {sale.id} is missing subscription ID") - return - - pp_subscription = paypal.Subscription.retrieve(sale.billing_agreement_id) - if not pp_subscription.custom_id: - logger.error(f"Subscription {pp_subscription.id} is missing custom ID") - return - - assert pp_subscription.plan_id, "Subscription is missing plan ID" - plan = PricingPlan.get_by_paypal_plan_id(pp_subscription.plan_id) - if not plan: - logger.error(f"Invalid plan ID: {pp_subscription.plan_id}") - return - - if float(sale.amount.total) == float(plan.monthly_charge): - new_credits = plan.credits - else: - new_credits = int(float(sale.amount.total) * settings.ADDON_CREDITS_PER_DOLLAR) - - user = AppUser.objects.get(uid=pp_subscription.custom_id) - user.add_balance( - payment_provider=PaymentProvider.PAYPAL, - invoice_id=sale.id, - amount=new_credits, - charged_amount=int(float(sale.amount.total) * 100), - ) - if not user.is_paying: - user.is_paying = True - user.save(update_fields=["is_paying"]) - if ( - user.subscription - and user.subscription.should_send_monthly_spending_notification() - ): - send_monthly_spending_notification_email.delay(user.id) - - -def _handle_subscription_cancelled(subscription: paypal.Subscription): - user = AppUser.objects.get(uid=subscription.custom_id) - if ( - user.subscription - and user.subscription.payment_provider == PaymentProvider.PAYPAL - and user.subscription.external_id == subscription.id - ): - user.subscription = None - user.save() diff --git a/routers/stripe.py b/routers/stripe.py index e6df7123f..d6ab35cb3 100644 --- a/routers/stripe.py +++ b/routers/stripe.py @@ -1,35 +1,14 @@ -from urllib.parse import quote_plus - import stripe -from django.db import transaction from fastapi import APIRouter, Request -from fastapi.responses import JSONResponse, RedirectResponse -from loguru import logger +from fastapi.responses import JSONResponse -from app_users.models import AppUser from daras_ai_v2 import settings -from daras_ai_v2.fastapi_tricks import ( - fastapi_request_body, - get_route_url, -) -from payments.models import PaymentProvider, Subscription -from payments.plans import PricingPlan -from payments.tasks import send_monthly_spending_notification_email -from routers.account import account_route +from daras_ai_v2.fastapi_tricks import fastapi_request_body +from payments.webhooks import StripeWebhookHandler router = APIRouter() -@router.post("/__/stripe/create-portal-session") -def customer_portal(request: Request): - customer = request.user.get_or_create_stripe_customer() - portal_session = stripe.billing_portal.Session.create( - customer=customer, - return_url=get_route_url(account_route), - ) - return RedirectResponse(portal_session.url, status_code=303) - - @router.post("/__/stripe/webhook") def webhook_received(request: Request, payload: bytes = fastapi_request_body): # Retrieve the event by verifying the signature using the raw body and secret if webhook signing is configured. @@ -58,100 +37,12 @@ def webhook_received(request: Request, payload: bytes = fastapi_request_body): # Get the type of webhook event sent - used to check the status of PaymentIntents. match event["type"]: case "invoice.paid": - handle_invoice_paid(uid, data) + StripeWebhookHandler.handle_invoice_paid(uid, data) case "checkout.session.completed": - _handle_checkout_session_completed(uid, data) + StripeWebhookHandler.handle_checkout_session_completed(uid, data) case "customer.subscription.created" | "customer.subscription.updated": - _handle_subscription_updated(uid, data) + StripeWebhookHandler.handle_subscription_updated(uid, data) case "customer.subscription.deleted": - _handle_subscription_cancelled(uid, data) + StripeWebhookHandler.handle_subscription_cancelled(uid, data) return JSONResponse({"status": "success"}) - - -def handle_invoice_paid(uid: str, invoice_data): - invoice_id = invoice_data.id - line_items = stripe.Invoice._static_request( - "get", - "/v1/invoices/{invoice}/lines".format(invoice=quote_plus(invoice_id)), - ) - user = AppUser.objects.get_or_create_from_uid(uid)[0] - user.add_balance( - payment_provider=PaymentProvider.STRIPE, - invoice_id=invoice_id, - amount=line_items.data[0].quantity, - charged_amount=line_items.data[0].amount, - ) - if not user.is_paying: - user.is_paying = True - user.save(update_fields=["is_paying"]) - if ( - user.subscription - and user.subscription.should_send_monthly_spending_notification() - ): - send_monthly_spending_notification_email.delay(user.id) - - -def _handle_checkout_session_completed(uid: str, session_data): - setup_intent_id = session_data.get("setup_intent") - if not setup_intent_id: - # not a setup mode checkout - return - - # set default payment method - user = AppUser.objects.get_or_create_from_uid(uid)[0] - setup_intent = stripe.SetupIntent.retrieve(setup_intent_id) - subscription_id = setup_intent.metadata.get("subscription_id") - if not ( - user.subscription.payment_provider == PaymentProvider.STRIPE - and user.subscription.external_id == subscription_id - ): - logger.error(f"Subscription {subscription_id} not found for user {user}") - return - - stripe.Subscription.modify( - subscription_id, default_payment_method=setup_intent.payment_method - ) - - -def _handle_subscription_updated(uid: str, subscription_data): - logger.info("Subscription updated") - product = stripe.Product.retrieve(subscription_data.plan.product) - plan = PricingPlan.get_by_stripe_product(product) - if not plan: - raise Exception( - f"PricingPlan not found for product {subscription_data.plan.product}" - ) - - if subscription_data.get("status") != "active": - logger.warning(f"Subscription {subscription_data.id} is not active") - return - - user = AppUser.objects.get_or_create_from_uid(uid)[0] - if user.subscription and ( - user.subscription.payment_provider != PaymentProvider.STRIPE - or user.subscription.external_id != subscription_data.id - ): - logger.warning( - f"User {user} has different existing subscription {user.subscription}. Cancelling that..." - ) - user.subscription.cancel() - user.subscription.delete() - elif not user.subscription: - user.subscription = Subscription() - - user.subscription.plan = plan.db_value - user.subscription.payment_provider = PaymentProvider.STRIPE - user.subscription.external_id = subscription_data.id - - user.subscription.full_clean() - user.subscription.save() - user.save(update_fields=["subscription"]) - - -def _handle_subscription_cancelled(uid: str, subscription_data): - subscription = Subscription.objects.get_by_stripe_subscription_id( - subscription_data.id - ) - logger.info(f"Subscription {subscription} cancelled. Deleting it...") - subscription.delete() From cdcc8ac2aef384d62e69f0c6aa79f66461175941 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 10 Jun 2024 16:07:41 +0530 Subject: [PATCH 08/26] Add BasePage.run_with_auto_recharge --- daras_ai_v2/base.py | 44 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 9276f7c35..7bd968504 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -11,6 +11,7 @@ from time import sleep from types import SimpleNamespace +from loguru import logger import sentry_sdk from django.db.models import Sum from django.utils import timezone @@ -24,6 +25,7 @@ ) from starlette.requests import Request +from daras_ai_v2.exceptions import UserError import gooey_ui as st from app_users.models import AppUser, AppUserTransaction from bots.models import ( @@ -37,7 +39,6 @@ from daras_ai.text_format import format_number_with_suffix from daras_ai_v2 import settings, urls from daras_ai_v2.api_examples_widget import api_example_generator -from daras_ai_v2.auto_recharge import user_should_auto_recharge from daras_ai_v2.breadcrumbs import render_breadcrumbs, get_title_breadcrumbs from daras_ai_v2.copy_to_clipboard_button_widget import ( copy_to_clipboard_button, @@ -71,6 +72,11 @@ from gooey_ui.components.pills import pill from gooey_ui.pubsub import realtime_pull from routers.account import AccountTabs +from payments.auto_recharge import ( + AutoRechargeException, + auto_recharge_user, + should_attempt_auto_recharge, +) from routers.root import RecipeTabs DEFAULT_META_IMG = ( @@ -1355,6 +1361,29 @@ def _render_help(self): def render_usage_guide(self): raise NotImplementedError + def run_with_auto_recharge(self, state: dict) -> typing.Iterator[str | None]: + if not self.check_credits() and should_attempt_auto_recharge(self.request.user): + yield "Low balance detected. Recharging..." + try: + auto_recharge_user(uid=self.request.user.uid) + except AutoRechargeException as e: + # raise this error only if another auto-recharge + # procedure didn't complete successfully + self.request.user.refresh_from_db() + if not self.check_credits(): + raise UserError(str(e)) from e + else: + self.request.user.refresh_from_db() + + if not self.check_credits(): + example_id, run_id, uid = extract_query_params(gooey_get_query_params()) + error_msg = self.generate_credit_error_message( + example_id=example_id, run_id=run_id, uid=uid + ) + raise UserError(error_msg) + + yield from self.run(state) + def run(self, state: dict) -> typing.Iterator[str | None]: # initialize request and response request = self.RequestModel.parse_obj(state) @@ -1505,8 +1534,6 @@ def estimate_run_duration(self) -> int | None: pass def on_submit(self): - from celeryapp.tasks import auto_recharge - try: example_id, run_id, uid = self.create_new_run(enable_rate_limits=True) except RateLimitExceeded as e: @@ -1514,10 +1541,10 @@ def on_submit(self): st.session_state[StateKeys.error_msg] = e.detail.get("error", "") return - if user_should_auto_recharge(self.request.user): - auto_recharge.delay(user_id=self.request.user.id) - - if settings.CREDITS_TO_DEDUCT_PER_RUN and not self.check_credits(): + if not self.check_credits() and not should_attempt_auto_recharge( + self.request.user + ): + # insufficient balance for this run and auto-recharge isn't setup st.session_state[StateKeys.run_status] = None st.session_state[StateKeys.error_msg] = self.generate_credit_error_message( example_id, run_id, uid @@ -1978,6 +2005,9 @@ def run_as_api_tab(self): manage_api_keys(self.request.user) def check_credits(self) -> bool: + if not settings.CREDITS_TO_DEDUCT_PER_RUN: + return True + assert self.request, "request must be set to check credits" assert self.request.user, "request.user must be set to check credits" return self.request.user.balance >= self.get_price_roundoff(st.session_state) From c42ffae204f1312cdd1047e306e56bbbdfce4298 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 10 Jun 2024 16:08:11 +0530 Subject: [PATCH 09/26] recipe runner task: s/page.run/page.run_with_auto_recharge --- celeryapp/tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index b5f23e2e2..0fb7b966e 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -96,13 +96,13 @@ def save(done=False): page.dump_state_to_sr(st.session_state | output, sr) try: - gen = page.run(st.session_state) + gen = page.run_with_auto_recharge(st.session_state) save() while True: # record time start_time = time() try: - # advance the generator (to further progress of run()) + # advance the generator (to further progress of run_with_auto_recharge()) yield_val = next_db_safe(gen) # increment total time taken after every iteration run_time += time() - start_time From d2314476721e223efcf68354cb3b5d82b2074a48 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 10 Jun 2024 17:27:52 +0530 Subject: [PATCH 10/26] api: don't auto-recharge before the run (we do that afterwards) --- celeryapp/tasks.py | 4 ++-- payments/tasks.py | 18 ++++++++++++------ routers/api.py | 9 ++------- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 0fb7b966e..97c4ce056 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -25,7 +25,7 @@ from gooey_ui.pubsub import realtime_push from gooey_ui.state import set_query_params from gooeysite.bg_db_conn import db_middleware, next_db_safe -from payments.tasks import run_auto_recharge_async +from payments.tasks import run_auto_recharge_gracefully @app.task @@ -144,7 +144,7 @@ def save(done=False): send_email_on_completion(page, sr) run_low_balance_email_check(uid) - run_auto_recharge_async.apply(kwargs={"uid": uid}) + run_auto_recharge_gracefully(uid) def err_msg_for_exc(e: Exception): diff --git a/payments/tasks.py b/payments/tasks.py index ffd2816da..eb122c4db 100644 --- a/payments/tasks.py +++ b/payments/tasks.py @@ -16,7 +16,7 @@ @app.task -def send_monthly_spending_notification_email(*, uid: str): +def send_monthly_spending_notification_email(uid: str): from routers.account import account_route user = AppUser.objects.get(uid=uid) @@ -46,7 +46,13 @@ def send_monthly_spending_notification_email(*, uid: str): @app.task -def run_auto_recharge_async(*, uid: str): +def run_auto_recharge_gracefully(uid: str): + """ + Wrapper over auto_recharge_user, that handles exceptions so that it can: + - log exceptions + - send emails when auto-recharge fails + - not retry if this is run as a background task + """ try: auto_recharge_user(uid) except AutoRechargeCooldownException as e: @@ -56,19 +62,19 @@ def run_auto_recharge_async(*, uid: str): ) return except MonthlyBudgetReachedException as e: - send_monthly_budget_reached_email(uid=uid) + send_monthly_budget_reached_email(uid) logger.info( f"Rejected auto-recharge because user has reached monthly budget" f"{uid=}, spending=${e.spending}, budget=${e.budget}" ) return except (PaymentFailedException, Exception) as e: - send_auto_recharge_failed_email(uid=uid) + send_auto_recharge_failed_email(uid) logger.exception("Payment failed when attempting to auto-recharge", uid=uid) return -def send_monthly_budget_reached_email(*, uid: str): +def send_monthly_budget_reached_email(uid: str): from routers.account import account_route user = AppUser.objects.get(uid=uid) @@ -93,7 +99,7 @@ def send_monthly_budget_reached_email(*, uid: str): user.subscription.save(update_fields=["monthly_budget_email_sent_at"]) -def send_auto_recharge_failed_email(*, uid: str): +def send_auto_recharge_failed_email(uid: str): from routers.account import account_route user = AppUser.objects.get(uid=uid) diff --git a/routers/api.py b/routers/api.py index 979af1c3a..1002218e3 100644 --- a/routers/api.py +++ b/routers/api.py @@ -19,8 +19,6 @@ from starlette.datastructures import UploadFile from starlette.requests import Request -from celeryapp.tasks import auto_recharge -from daras_ai_v2.auto_recharge import user_should_auto_recharge import gooey_ui as st from app_users.models import AppUser from auth.token_authentication import api_auth_header @@ -34,8 +32,8 @@ RecipeRunState, ) from daras_ai_v2.fastapi_tricks import fastapi_request_form -from daras_ai_v2.ratelimits import ensure_rate_limits from gooeysite.bg_db_conn import get_celery_result_db_safe +from payments.auto_recharge import should_attempt_auto_recharge from routers.account import AccountTabs app = APIRouter() @@ -369,11 +367,8 @@ def submit_api_call( st.set_session_state(state) st.set_query_params(query_params) - if user_should_auto_recharge(self.request.user): - auto_recharge.delay(user_id=self.request.user.id) - # check the balance - if settings.CREDITS_TO_DEDUCT_PER_RUN and not self.check_credits(): + if not self.check_credits() and not should_attempt_auto_recharge(self.request.user): account_url = furl(settings.APP_BASE_URL) / AccountTabs.billing.url_path raise HTTPException( status_code=402, From 5f48995228eee8b0b43189b07fcf626fecafec9a Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 10 Jun 2024 17:34:02 +0530 Subject: [PATCH 11/26] Remove subscription-change logic from routers/account.py --- payments/webhooks.py | 2 +- routers/account.py | 69 +++++++++++++------------------------------- 2 files changed, 21 insertions(+), 50 deletions(-) diff --git a/payments/webhooks.py b/payments/webhooks.py index 62dbbb8e5..8ad45f46a 100644 --- a/payments/webhooks.py +++ b/payments/webhooks.py @@ -33,7 +33,7 @@ def _after_payment_completed( user.subscription and user.subscription.should_send_monthly_spending_notification() ): - send_monthly_spending_notification_email.delay(kwargs={"uid": uid}) + send_monthly_spending_notification_email.delay(uid) @classmethod @transaction.atomic diff --git a/routers/account.py b/routers/account.py index 3ff2335a0..1d99d3d0a 100644 --- a/routers/account.py +++ b/routers/account.py @@ -2,29 +2,24 @@ from contextlib import contextmanager from enum import Enum -from django.db import transaction from fastapi import APIRouter from fastapi.requests import Request from furl import furl from loguru import logger +from requests.models import HTTPError import gooey_ui as st -from app_users.models import AppUser, PaymentProvider from bots.models import PublishedRun, PublishedRunVisibility, Workflow from daras_ai_v2 import icons, paypal from daras_ai_v2.base import RedirectException from daras_ai_v2.billing import billing_page -from daras_ai_v2.fastapi_tricks import ( - get_route_path, - get_route_url, -) +from daras_ai_v2.fastapi_tricks import get_route_path, get_route_url from daras_ai_v2.grid_layout_widget import grid_layout from daras_ai_v2.manage_api_keys_widget import manage_api_keys from daras_ai_v2.meta_content import raw_build_meta_tags from daras_ai_v2.profiles import edit_user_profile_page from gooey_ui.components.pills import pill -from payments.models import Subscription -from payments.plans import PricingPlan +from payments.webhooks import PaypalWebhookHandler from routers.root import page_wrapper, get_og_url_path app = APIRouter() @@ -33,17 +28,21 @@ @app.post("/payment-processing/") @st.route def payment_processing_route( - request: Request, provider: str = None, subscription_id: str = None + request: Request, provider: str | None = None, subscription_id: str | None = None ): waiting_time_sec = 3 subtext = None if provider == "paypal": - if (sub_id := subscription_id) and st.run_in_thread( - threaded_paypal_handle_subscription_updated, args=[sub_id] - ): + success = st.run_in_thread( + threaded_paypal_handle_subscription_updated, + args=[subscription_id], + ) + if success: + # immediately redirect waiting_time_sec = 0 else: + # either failed or still running. in either case, wait 30s before redirecting waiting_time_sec = 30 subtext = ( "PayPal transactions take up to a minute to reflect in your account" @@ -57,7 +56,9 @@ def payment_processing_route( style=dict(height="3rem", width="3rem"), ) st.write("# Processing payment...") - st.caption(subtext) + + if subtext: + st.caption(subtext) st.js( # language=JavaScript @@ -230,44 +231,14 @@ def account_page_wrapper(request: Request, current_tab: TabData): yield -def paypal_handle_subscription_updated(subscription: paypal.Subscription): - logger.info("Subscription updated") - - plan = PricingPlan.get_by_paypal_plan_id(subscription.plan_id) - if not plan: - logger.error(f"Invalid plan ID: {subscription.plan_id}") - return - - if not subscription.status == "ACTIVE": - logger.warning(f"Subscription {subscription.id} is not active") - return - - user = AppUser.objects.get(uid=subscription.custom_id) - if user.subscription and ( - user.subscription.payment_provider != PaymentProvider.PAYPAL - or user.subscription.external_id != subscription.id - ): - logger.warning( - f"User {user} has different existing subscription {user.subscription}. Cancelling that..." - ) - user.subscription.cancel() - user.subscription.delete() - elif not user.subscription: - user.subscription = Subscription() - - user.subscription.plan = plan.db_value - user.subscription.payment_provider = PaymentProvider.PAYPAL - user.subscription.external_id = subscription.id - - user.subscription.full_clean() - user.subscription.save() - user.save(update_fields=["subscription"]) - - def threaded_paypal_handle_subscription_updated(subscription_id: str) -> bool: """ Always returns True when completed (for use in st.run_in_thread()) """ - subscription = paypal.Subscription.retrieve(subscription_id) - paypal_handle_subscription_updated(subscription) + try: + subscription = paypal.Subscription.retrieve(subscription_id) + PaypalWebhookHandler.handle_subscription_updated(subscription) + except HTTPError: + logger.exception(f"Unexpected PayPal error for sub: {subscription_id}") + return False return True From 7f227db464a16a5ad1d1428c33866c104672b685 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 10 Jun 2024 17:38:19 +0530 Subject: [PATCH 12/26] billing_page: fix same-key bug in buttons, refactor for clarity --- daras_ai_v2/billing.py | 97 +++++++++++++++++++++++++++++------------- 1 file changed, 67 insertions(+), 30 deletions(-) diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index 353cee0ec..dbf1bf6a1 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -1,3 +1,5 @@ +from typing import Literal + import stripe from django.core.exceptions import ValidationError @@ -17,6 +19,9 @@ rounded_border = "w-100 border shadow-sm rounded py-4 px-3" +PlanActionLabel = Literal["Upgrade", "Downgrade", "Contact Us", "Your Plan"] + + def billing_page(user: AppUser): render_payments_setup() @@ -126,11 +131,11 @@ def render_all_plans(user: AppUser): st.write("## All Plans") plans_div = st.div(className="mb-1") - if user.subscription: - payment_provider = None + if user.subscription and user.subscription.payment_provider: + selected_payment_provider = None else: with st.div(): - payment_provider = PaymentProvider[ + selected_payment_provider = PaymentProvider[ payment_provider_radio() or PaymentProvider.STRIPE.name ] @@ -144,7 +149,9 @@ def _render_plan(plan: PricingPlan): className=f"{rounded_border} flex-grow-1 d-flex flex-column p-3 mb-2 {extra_class}" ): _render_plan_details(plan) - _render_plan_action_button(user, plan, current_plan, payment_provider) + _render_plan_action_button( + user, plan, current_plan, selected_payment_provider + ) with plans_div: grid_layout(4, all_plans, _render_plan, separator=False) @@ -188,30 +195,44 @@ def _render_plan_action_button( className=btn_classes + " btn btn-theme btn-primary", ): st.html("Contact Us") - elif current_plan is not PricingPlan.ENTERPRISE: - update_subscription_button( - user=user, - plan=plan, - current_plan=current_plan, - className=btn_classes, - payment_provider=payment_provider, - ) + elif user.subscription and not user.subscription.payment_provider: + # don't show upgrade/downgrade buttons for enterprise customers + # assumption: anyone without a payment provider attached is admin/enterprise + return + else: + if plan.credits > current_plan.credits: + label, btn_type = ("Upgrade", "primary") + else: + label, btn_type = ("Downgrade", "secondary") + + if user.subscription and user.subscription.payment_provider: + # subscription exists, show upgrade/downgrade button + _render_update_subscription_button( + label, + user=user, + current_plan=current_plan, + plan=plan, + className=f"{btn_classes} btn btn-theme btn-{btn_type}", + ) + else: + assert payment_provider is not None # for sanity + _render_create_subscription_button( + label, + btn_type=btn_type, + user=user, + plan=plan, + payment_provider=payment_provider, + ) -def update_subscription_button( +def _render_create_subscription_button( + label: PlanActionLabel, *, + btn_type: str, user: AppUser, - current_plan: PricingPlan, plan: PricingPlan, - className: str = "", - payment_provider: PaymentProvider | None = None, + payment_provider: PaymentProvider, ): - if plan.credits > current_plan.credits: - label, btn_type = ("Upgrade", "primary") - else: - label, btn_type = ("Downgrade", "secondary") - className += f" btn btn-theme btn-{btn_type}" - match payment_provider: case PaymentProvider.STRIPE: render_stripe_subscription_button( @@ -219,13 +240,26 @@ def update_subscription_button( ) case PaymentProvider.PAYPAL: render_paypal_subscription_button(plan=plan) - case _ if label == "Downgrade": + + +def _render_update_subscription_button( + label: PlanActionLabel, + *, + user: AppUser, + current_plan: PricingPlan, + plan: PricingPlan, + className: str = "", +): + match label: + case "Downgrade": downgrade_modal = Modal( "Confirm downgrade", - key=f"downgrade-plan-modal-{plan.key}", + key=f"change-sub-{plan.key}-modal", ) if st.button( - label, className=className, key=f"downgrade-button-{plan.key}" + label, + className=className, + key=f"change-sub-{plan.key}-modal-open-btn", ): downgrade_modal.open() @@ -404,7 +438,10 @@ def render_stripe_subscription_button( st.write("Stripe subscription not available") return - if st.button(label, type=btn_type): + # IMPORTANT: key=... is needed here to maintain uniqueness + # of buttons with the same label. otherwise, all buttons + # will be the same to the server + if st.button(label, key=f"sub-new-{plan.key}", type=btn_type): create_stripe_checkout_session(user=user, plan=plan) @@ -412,8 +449,8 @@ def create_stripe_checkout_session(user: AppUser, plan: PricingPlan): from routers.account import account_route from routers.account import payment_processing_route - if user.subscription and user.subscription.plan == plan.db_value: - # already subscribed to the same plan + if user.subscription: + # already subscribed to some plan return metadata = {settings.STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: plan.key} @@ -463,7 +500,7 @@ def render_payment_information(user: AppUser): provider = PaymentProvider(user.subscription.payment_provider) st.write(provider.label) with col3: - if st.button(f"{icons.edit} Edit", type="link"): + if st.button(f"{icons.edit} Edit", type="link", key="manage-sub"): raise RedirectException(user.subscription.get_external_management_url()) pm_summary = st.run_in_thread( @@ -482,7 +519,7 @@ def render_payment_information(user: AppUser): unsafe_allow_html=True, ) with col3: - if st.button(f"{icons.edit} Edit", type="link"): + if st.button(f"{icons.edit} Edit", type="link", key="change-pm"): change_payment_method(user) if pm_summary.billing_email: From b12dc0c08e33d8fe92431a423516bc16fd1f96a4 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 10 Jun 2024 17:39:19 +0530 Subject: [PATCH 13/26] Fix Subscription model: allow null-values for auto-recharge config in DB to make it easier to use for admins & enterprise users --- ...ription_auto_recharge_balance_threshold.py | 18 ++++++++++++++++++ payments/models.py | 19 ++++++++++++------- 2 files changed, 30 insertions(+), 7 deletions(-) create mode 100644 payments/migrations/0004_alter_subscription_auto_recharge_balance_threshold.py diff --git a/payments/migrations/0004_alter_subscription_auto_recharge_balance_threshold.py b/payments/migrations/0004_alter_subscription_auto_recharge_balance_threshold.py new file mode 100644 index 000000000..d5a1e922a --- /dev/null +++ b/payments/migrations/0004_alter_subscription_auto_recharge_balance_threshold.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-06-10 09:21 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0003_alter_subscription_external_id_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='subscription', + name='auto_recharge_balance_threshold', + field=models.IntegerField(blank=True, null=True), + ), + ] diff --git a/payments/models.py b/payments/models.py index cfb9c50c1..a8db57945 100644 --- a/payments/models.py +++ b/payments/models.py @@ -43,19 +43,23 @@ class Subscription(models.Model): blank=True, ) auto_recharge_enabled = models.BooleanField(default=True) - auto_recharge_balance_threshold = models.IntegerField() auto_recharge_topup_amount = models.IntegerField( - default=settings.ADDON_AMOUNT_CHOICES[0], + default=settings.ADDON_AMOUNT_CHOICES[0] ) + auto_recharge_balance_threshold = models.IntegerField( + null=True, blank=True + ) # dynamic default (see: full_clean) monthly_spending_budget = models.IntegerField( null=True, blank=True, help_text="In USD, pause auto-recharge just before the spending exceeds this amount in a calendar month", + # dynamic default (see: full_clean) ) monthly_spending_notification_threshold = models.IntegerField( null=True, blank=True, help_text="In USD, send an email when spending crosses this threshold in a calendar month", + # dynamic default (see: full_clean) ) monthly_spending_notification_sent_at = models.DateTimeField(null=True, blank=True) @@ -194,14 +198,15 @@ def stripe_get_or_create_auto_invoice( self, amount_in_dollars: int, metadata_key: str, + created_after: datetime, ) -> stripe.Invoice: """ Fetches the relevant invoice, or creates one if it doesn't exist. This is the fallback order: - - Fetch an open invoice with metadata_key in the metadata - - Fetch a $metadata_key invoice that was recently paid - - Create an invoice with amount=amount_in_dollars and $metadata_key + - An open invoice that has `metadata_key` set + - A paid invoice that has `metadata_key` set and was created after `created_after` + - A new invoice to charge `amount_in_dollars` USD (with `metadata_key` set to True) """ customer_id = self.stripe_get_customer_id() invoices = stripe.Invoice.list( @@ -259,7 +264,7 @@ def stripe_get_customer_id(self) -> str: raise ValueError("Invalid Payment Provider") def stripe_attempt_addon_purchase(self, amount_in_dollars: int) -> bool: - from routers.stripe import handle_invoice_paid + from payments.webhooks import StripeWebhookHandler invoice = self.stripe_create_auto_invoice( amount_in_dollars=amount_in_dollars, @@ -271,7 +276,7 @@ def stripe_attempt_addon_purchase(self, amount_in_dollars: int) -> bool: invoice = invoice.pay(payment_method=pm) if not invoice.paid: return False - handle_invoice_paid(self.user.uid, invoice) + StripeWebhookHandler.handle_invoice_paid(self.user.uid, invoice) return True def get_external_management_url(self) -> str: From 83cb1cf3ae4c18e1ac879f49f4ba591f970802a7 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 10 Jun 2024 17:43:13 +0530 Subject: [PATCH 14/26] Add missing migration for blank=True on AppUser.subscription --- .../0017_alter_appuser_subscription.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 app_users/migrations/0017_alter_appuser_subscription.py diff --git a/app_users/migrations/0017_alter_appuser_subscription.py b/app_users/migrations/0017_alter_appuser_subscription.py new file mode 100644 index 000000000..9d0aab4c5 --- /dev/null +++ b/app_users/migrations/0017_alter_appuser_subscription.py @@ -0,0 +1,20 @@ +# Generated by Django 4.2.7 on 2024-06-10 09:21 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0004_alter_subscription_auto_recharge_balance_threshold'), + ('app_users', '0016_appuser_disable_rate_limits'), + ] + + operations = [ + migrations.AlterField( + model_name='appuser', + name='subscription', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='user', to='payments.subscription'), + ), + ] From 45a30194bd3ae148c3222beb43b787d02528030c Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 10 Jun 2024 18:10:28 +0530 Subject: [PATCH 15/26] stripe auto invoice: use default setting for AUTO_RECHARGE_COOLDOWN without parametrizing --- payments/auto_recharge.py | 5 ----- payments/models.py | 7 +++---- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/payments/auto_recharge.py b/payments/auto_recharge.py index 0480fbeb3..4228c2bce 100644 --- a/payments/auto_recharge.py +++ b/payments/auto_recharge.py @@ -53,15 +53,10 @@ def _auto_recharge_user(user: AppUser): spending=dollars_spent, ) - # create invoice or get a recent one (recent = created in the last `settings.AUTO_RECHARGE_COOLDOWN_SECONDS` seconds) - cooldown_period_start = datetime.now(timezone.utc) - timedelta( - seconds=settings.AUTO_RECHARGE_COOLDOWN_SECONDS - ) try: invoice = user.subscription.stripe_get_or_create_auto_invoice( amount_in_dollars=user.subscription.auto_recharge_topup_amount, metadata_key="auto_recharge", - created_after=cooldown_period_start, ) except Exception as e: raise PaymentFailedException("Failed to create auto-recharge invoice") from e diff --git a/payments/models.py b/payments/models.py index a8db57945..fa86a22cd 100644 --- a/payments/models.py +++ b/payments/models.py @@ -198,15 +198,14 @@ def stripe_get_or_create_auto_invoice( self, amount_in_dollars: int, metadata_key: str, - created_after: datetime, ) -> stripe.Invoice: """ Fetches the relevant invoice, or creates one if it doesn't exist. This is the fallback order: - - An open invoice that has `metadata_key` set - - A paid invoice that has `metadata_key` set and was created after `created_after` - - A new invoice to charge `amount_in_dollars` USD (with `metadata_key` set to True) + - Fetch an open invoice that has `metadata_key` set + - Fetch a `metadata_key` invoice that was recently paid + - Create an invoice with amount=`amount_in_dollars` and `metadata_key` set to true """ customer_id = self.stripe_get_customer_id() invoices = stripe.Invoice.list( From 156a363d98b31c3dcbb9b292056ea0df95801101 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 10 Jun 2024 18:21:36 +0530 Subject: [PATCH 16/26] Remove unnecessary cosmetic changes --- daras_ai_v2/base.py | 1 - daras_ai_v2/billing.py | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 7bd968504..ed27b8d33 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -11,7 +11,6 @@ from time import sleep from types import SimpleNamespace -from loguru import logger import sentry_sdk from django.db.models import Sum from django.utils import timezone diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index dbf1bf6a1..bf18e3ca7 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -254,12 +254,10 @@ def _render_update_subscription_button( case "Downgrade": downgrade_modal = Modal( "Confirm downgrade", - key=f"change-sub-{plan.key}-modal", + key=f"downgrade-plan-modal-{plan.key}", ) if st.button( - label, - className=className, - key=f"change-sub-{plan.key}-modal-open-btn", + label, className=className, key=f"downgrade-button-{plan.key}" ): downgrade_modal.open() From bcbb02fe12b7863b7053d1067fabe7280589b459 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Tue, 11 Jun 2024 17:16:41 +0530 Subject: [PATCH 17/26] Fix bug: proration behavior Earlier, prorations were on, and someone who purchases a subscription today will be charged the prorated amount + the next month's charge in the next invoice. This commit fixes that behaviour by charging the full subscription amount on the same day. The billing cycle will also change to create the next invoice 1 month from the same day. --- daras_ai_v2/billing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index bf18e3ca7..56013ba3c 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -323,6 +323,10 @@ def change_subscription(user: AppUser, new_plan: PricingPlan): metadata={ settings.STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: new_plan.key, }, + # charge the full new amount today, without prorations + # see: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time + billing_cycle_anchor="now", + proration_behavior="none", ) raise RedirectException( get_route_url(payment_processing_route), status_code=303 From b5bc34b7c802125b96a5da07a81a4d8c8bd684cf Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 11 Jul 2024 19:22:40 +0530 Subject: [PATCH 18/26] rename gui_runner -> runner_task optimize run time of runner_task: check for credits only once in entire call chain, separate post run stuff into a separate task fix run complete email for empty prompts --- bots/admin.py | 2 +- bots/models.py | 27 +++++++-- celeryapp/tasks.py | 83 ++++++++++++++++---------- daras_ai_v2/base.py | 97 +++++++++++++------------------ daras_ai_v2/exceptions.py | 48 ++++++++++++++- daras_ai_v2/settings.py | 2 +- payments/auto_recharge.py | 58 +++++++++++++----- payments/tasks.py | 41 +------------ routers/api.py | 89 +++++++++++++++++----------- server.py | 10 +++- templates/run_complete_email.html | 24 +++----- 11 files changed, 280 insertions(+), 201 deletions(-) diff --git a/bots/admin.py b/bots/admin.py index 82da0aab2..e154f03b2 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -31,7 +31,7 @@ Workflow, ) from bots.tasks import create_personal_channels_for_all_members -from celeryapp.tasks import gui_runner +from celeryapp.tasks import runner_task from daras_ai_v2.fastapi_tricks import get_route_url from gooeysite.custom_actions import export_to_excel, export_to_csv from gooeysite.custom_filters import ( diff --git a/bots/models.py b/bots/models.py index 6ff6674ff..757766a4c 100644 --- a/bots/models.py +++ b/bots/models.py @@ -212,10 +212,24 @@ class SavedRun(models.Model): state = models.JSONField(default=dict, blank=True, encoder=PostgresJSONEncoder) - error_msg = models.TextField(default="", blank=True) + error_msg = models.TextField( + default="", + blank=True, + help_text="The error message. If this is not set, the run is deemed successful.", + ) run_time = models.DurationField(default=datetime.timedelta, blank=True) run_status = models.TextField(default="", blank=True) + error_code = models.IntegerField( + null=True, + default=None, + blank=True, + help_text="The HTTP status code of the error. If this is not set, 500 is assumed.", + ) + error_type = models.TextField( + default="", blank=True, help_text="The exception type" + ) + hidden = models.BooleanField(default=False) is_flagged = models.BooleanField(default=False) @@ -282,9 +296,12 @@ def __str__(self): def parent_published_run(self) -> typing.Optional["PublishedRun"]: return self.parent_version and self.parent_version.published_run - def get_app_url(self): + def get_app_url(self, query_params: dict = None): return Workflow(self.workflow).page_cls.app_url( - example_id=self.example_id, run_id=self.run_id, uid=self.uid + example_id=self.example_id, + run_id=self.run_id, + uid=self.uid, + query_params=query_params, ) def to_dict(self) -> dict: @@ -1624,9 +1641,9 @@ def duplicate( visibility=visibility, ) - def get_app_url(self): + def get_app_url(self, query_params: dict = None): return Workflow(self.workflow).page_cls.app_url( - example_id=self.published_run_id + example_id=self.published_run_id, query_params=query_params ) def add_version( diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index e949228e2..2e30e4379 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -14,7 +14,7 @@ import gooey_ui as st from app_users.models import AppUser, AppUserTransaction from bots.admin_links import change_obj_url -from bots.models import SavedRun, Platform +from bots.models import SavedRun, Platform, Workflow from celeryapp.celeryconfig import app from daras_ai.image_input import truncate_text_words from daras_ai_v2 import settings @@ -24,22 +24,24 @@ from daras_ai_v2.settings import templates from gooey_ui.pubsub import realtime_push from gooey_ui.state import set_query_params -from gooeysite.bg_db_conn import db_middleware, next_db_safe -from payments.tasks import run_auto_recharge_gracefully from gooeysite.bg_db_conn import db_middleware +from payments.auto_recharge import ( + should_attempt_auto_recharge, + run_auto_recharge_gracefully, +) DEFAULT_RUN_STATUS = "Running..." @app.task -def gui_runner( +def runner_task( *, page_cls: typing.Type[BasePage], user_id: int, run_id: str, uid: str, channel: str, -): +) -> int: start_time = time() error_msg = None @@ -89,36 +91,50 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False save_on_step() for val in page.main(sr, st.session_state): save_on_step(val) + # render errors nicely except Exception as e: - if isinstance(e, HTTPException) and e.status_code == 402: - error_msg = page.generate_credit_error_message(run_id, uid) - try: - raise UserError(error_msg) from e - except UserError as e: - sentry_sdk.capture_exception(e, level=e.sentry_level) + if isinstance(e, UserError): + sentry_level = e.sentry_level else: - if isinstance(e, UserError): - sentry_level = e.sentry_level - else: - sentry_level = "error" - traceback.print_exc() - sentry_sdk.capture_exception(e, level=sentry_level) - error_msg = err_msg_for_exc(e) + sentry_level = "error" + traceback.print_exc() + sentry_sdk.capture_exception(e, level=sentry_level) + error_msg = err_msg_for_exc(e) + sr.error_type = type(e).__qualname__ + sr.error_code = getattr(e, "status_code", None) + # run completed successfully, deduct credits else: sr.transaction, sr.price = page.deduct_credits(st.session_state) + + # save everything, mark run as completed finally: save_on_step(done=True) - if not sr.is_api_call: - send_email_on_completion(page, sr) - run_low_balance_email_check(user) - run_auto_recharge_gracefully(uid) + return sr.id + + +@app.task +def post_runner_tasks(saved_run_id: int): + sr = SavedRun.objects.get(id=saved_run_id) + user = AppUser.objects.get(uid=sr.uid) + + if not sr.is_api_call: + send_email_on_completion(sr) + + if should_attempt_auto_recharge(user): + run_auto_recharge_gracefully(user) + + run_low_balance_email_check(user) def err_msg_for_exc(e: Exception): - if isinstance(e, requests.HTTPError): + if isinstance(e, UserError): + return e.message + elif isinstance(e, HTTPException): + return f"(HTTP {e.status_code}) {e.detail})" + elif isinstance(e, requests.HTTPError): response: requests.Response = e.response try: err_body = response.json() @@ -135,10 +151,6 @@ def err_msg_for_exc(e: Exception): return f"(GPU) {err_type}: {err_str}" err_str = str(err_body) return f"(HTTP {response.status_code}) {html.escape(err_str[:1000])}" - elif isinstance(e, HTTPException): - return f"(HTTP {e.status_code}) {e.detail})" - elif isinstance(e, UserError): - return e.message else: return f"{type(e).__name__}: {e}" @@ -179,7 +191,7 @@ def run_low_balance_email_check(user: AppUser): user.save(update_fields=["low_balance_email_sent_at"]) -def send_email_on_completion(page: BasePage, sr: SavedRun): +def send_email_on_completion(sr: SavedRun): run_time_sec = sr.run_time.total_seconds() if ( run_time_sec <= settings.SEND_RUN_EMAIL_AFTER_SEC @@ -191,9 +203,16 @@ def send_email_on_completion(page: BasePage, sr: SavedRun): ) if not to_address: return - prompt = (page.preview_input(sr.state) or "").strip() - title = (sr.state.get("__title") or page.title).strip() - subject = f"🌻 “{truncate_text_words(prompt, maxlen=50)}” {title} is done" + + workflow = Workflow(sr.workflow) + page_cls = workflow.page_cls + prompt = (page_cls.preview_input(sr.state) or "").strip().replace("\n", " ") + recipe_title = page_cls.get_recipe_title() + + subject = ( + f"🌻 “{truncate_text_words(prompt, maxlen=50) or 'Run'}” {recipe_title} is done" + ) + send_email_via_postmark( from_address=settings.SUPPORT_EMAIL, to_address=to_address, @@ -202,7 +221,7 @@ def send_email_on_completion(page: BasePage, sr: SavedRun): run_time_sec=round(run_time_sec), app_url=sr.get_app_url(), prompt=prompt, - title=title, + recipe_title=recipe_title, ), message_stream="gooey-ai-workflows", ) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 3b7c5e24d..6a3d602ea 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -25,7 +25,6 @@ ) from starlette.requests import Request -from daras_ai_v2.exceptions import UserError import gooey_ui as st from app_users.models import AppUser, AppUserTransaction from bots.models import ( @@ -49,6 +48,7 @@ from daras_ai_v2.db import ( ANONYMOUS_USER_COOKIE, ) +from daras_ai_v2.exceptions import InsufficientCredits from daras_ai_v2.fastapi_tricks import get_route_path from daras_ai_v2.grid_layout_widget import grid_layout from daras_ai_v2.html_spinner_widget import html_spinner @@ -83,12 +83,11 @@ from gooey_ui.components.modal import Modal from gooey_ui.components.pills import pill from gooey_ui.pubsub import realtime_pull -from routers.account import AccountTabs from payments.auto_recharge import ( - AutoRechargeException, - auto_recharge_user, should_attempt_auto_recharge, + run_auto_recharge_gracefully, ) +from routers.account import AccountTabs from routers.root import RecipeTabs DEFAULT_META_IMG = ( @@ -1415,30 +1414,9 @@ def _render_help(self): def render_usage_guide(self): raise NotImplementedError - def run_with_auto_recharge(self, state: dict) -> typing.Iterator[str | None]: - if not self.check_credits() and should_attempt_auto_recharge(self.request.user): - yield "Low balance detected. Recharging..." - try: - auto_recharge_user(uid=self.request.user.uid) - except AutoRechargeException as e: - # raise this error only if another auto-recharge - # procedure didn't complete successfully - self.request.user.refresh_from_db() - if not self.check_credits(): - raise UserError(str(e)) from e - else: - self.request.user.refresh_from_db() - - if not self.check_credits(): - example_id, run_id, uid = extract_query_params(gooey_get_query_params()) - error_msg = self.generate_credit_error_message( - example_id=example_id, run_id=run_id, uid=uid - ) - raise UserError(error_msg) - - yield from self.run(state) - def main(self, sr: SavedRun, state: dict) -> typing.Iterator[str | None]: + yield from self.ensure_credits_and_auto_recharge(sr, state) + yield from call_recipe_functions( saved_run=sr, current_user=self.request.user, @@ -1465,15 +1443,12 @@ def run(self, state: dict) -> typing.Iterator[str | None]: response = self.ResponseModel.construct() # run the recipe - gen = self.run_v2(request, response) - while True: - try: - val = next(gen) - except StopIteration: - break - finally: + try: + for val in self.run_v2(request, response): state.update(response.dict(exclude_unset=True)) - yield val + yield val + finally: + state.update(response.dict(exclude_unset=True)) # validate the response if successful self.ResponseModel.validate(response) @@ -1634,15 +1609,7 @@ def on_submit(self): st.session_state[StateKeys.error_msg] = e.detail.get("error", "") return - if not self.check_credits() and not should_attempt_auto_recharge( - self.request.user - ): - # insufficient balance for this run and auto-recharge isn't setup - sr.run_status = "" - sr.error_msg = self.generate_credit_error_message(sr.run_id, sr.uid) - sr.save(update_fields=["run_status", "error_msg"]) - else: - self.call_runner_task(sr) + self.call_runner_task(sr) raise RedirectException(self.app_url(run_id=sr.run_id, uid=sr.uid)) @@ -1715,15 +1682,19 @@ def dump_state_to_sr(self, state: dict, sr: SavedRun): ) def call_runner_task(self, sr: SavedRun): - from celeryapp.tasks import gui_runner - - return gui_runner.delay( - page_cls=self.__class__, - user_id=self.request.user.id, - run_id=sr.run_id, - uid=sr.uid, - channel=self.realtime_channel_name(sr.run_id, sr.uid), + from celeryapp.tasks import runner_task, post_runner_tasks + + chain = ( + runner_task.s( + page_cls=self.__class__, + user_id=self.request.user.id, + run_id=sr.run_id, + uid=sr.uid, + channel=self.realtime_channel_name(sr.run_id, sr.uid), + ) + | post_runner_tasks.s() ) + return chain.apply_async() @classmethod def realtime_channel_name(cls, run_id, uid): @@ -2099,13 +2070,27 @@ def run_as_api_tab(self): manage_api_keys(self.request.user) - def check_credits(self) -> bool: + def ensure_credits_and_auto_recharge(self, sr: SavedRun, state: dict): if not settings.CREDITS_TO_DEDUCT_PER_RUN: - return True - + return assert self.request, "request must be set to check credits" assert self.request.user, "request.user must be set to check credits" - return self.request.user.balance >= self.get_price_roundoff(st.session_state) + + user = self.request.user + price = self.get_price_roundoff(state) + + if user.balance >= price: + return + + if should_attempt_auto_recharge(user): + yield "Low balance detected. Recharging..." + run_auto_recharge_gracefully(user) + user.refresh_from_db() + + if user.balance >= price: + return + + raise InsufficientCredits(self.request.user, sr) def deduct_credits(self, state: dict) -> tuple[AppUserTransaction, int]: assert ( diff --git a/daras_ai_v2/exceptions.py b/daras_ai_v2/exceptions.py index 78e3f053f..02acc4ec6 100644 --- a/daras_ai_v2/exceptions.py +++ b/daras_ai_v2/exceptions.py @@ -3,8 +3,16 @@ import typing import requests +from furl import furl from loguru import logger from requests import HTTPError +from starlette.status import HTTP_402_PAYMENT_REQUIRED + +from daras_ai_v2 import settings + +if typing.TYPE_CHECKING: + from bots.models import SavedRun + from bots.models import AppUser def raise_for_status(resp: requests.Response, is_user_url: bool = False): @@ -47,9 +55,12 @@ def _response_preview(resp: requests.Response) -> bytes: class UserError(Exception): - def __init__(self, message: str, sentry_level: str = "info"): + def __init__( + self, message: str, sentry_level: str = "info", status_code: int = None + ): self.message = message self.sentry_level = sentry_level + self.status_code = status_code super().__init__(message) @@ -57,6 +68,41 @@ class GPUError(UserError): pass +class InsufficientCredits(UserError): + def __init__(self, user: "AppUser", sr: "SavedRun"): + from daras_ai_v2.base import SUBMIT_AFTER_LOGIN_Q + + account_url = furl(settings.APP_BASE_URL) / "account/" + if user.is_anonymous: + account_url.query.params["next"] = sr.get_app_url( + query_params={SUBMIT_AFTER_LOGIN_Q: "1"}, + ) + # language=HTML + message = f""" +

+Doh! Please login to run more Gooey.AI workflows. +

+ +You’ll receive {settings.LOGIN_USER_FREE_CREDITS} Credits when you sign up via your phone #, Google, Apple or GitHub account +and can purchase more for $1/100 Credits. +""" + else: + # language=HTML + message = f""" +

+Doh! You’re out of Gooey.AI credits. +

+ +

+Please buy more to run more workflows. +

+ +We’re always on discord if you’ve got any questions. +""" + + super().__init__(message, status_code=HTTP_402_PAYMENT_REQUIRED) + + FFMPEG_ERR_MSG = ( "Unsupported File Format\n\n" "We encountered an issue processing your file as it appears to be in a format not supported by our system or may be corrupted. " diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 58878a9af..2b363732b 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -263,7 +263,7 @@ ADMIN_EMAILS = config("ADMIN_EMAILS", cast=Csv(), default="") SUPPORT_EMAIL = "Gooey.AI Support " SALES_EMAIL = "Gooey.AI Sales " -SEND_RUN_EMAIL_AFTER_SEC = config("SEND_RUN_EMAIL_AFTER_SEC", 60) +SEND_RUN_EMAIL_AFTER_SEC = config("SEND_RUN_EMAIL_AFTER_SEC", 5) DISALLOWED_TITLE_SLUGS = config("DISALLOWED_TITLE_SLUGS", cast=Csv(), default="") + [ # tab names diff --git a/payments/auto_recharge.py b/payments/auto_recharge.py index 4228c2bce..1e3572929 100644 --- a/payments/auto_recharge.py +++ b/payments/auto_recharge.py @@ -1,8 +1,14 @@ -from datetime import datetime, timedelta, timezone +import traceback + +import sentry_sdk +from loguru import logger from app_users.models import AppUser, PaymentProvider -from daras_ai_v2 import settings from daras_ai_v2.redis_cache import redis_lock +from payments.tasks import ( + send_monthly_budget_reached_email, + send_auto_recharge_failed_email, +) class AutoRechargeException(Exception): @@ -24,11 +30,42 @@ class AutoRechargeCooldownException(AutoRechargeException): pass -def auto_recharge_user(uid: str): - with redis_lock(f"gooey/auto_recharge_user/{uid}"): - user: AppUser = AppUser.objects.get(uid=uid) - if should_attempt_auto_recharge(user): +def should_attempt_auto_recharge(user: AppUser): + return ( + user.subscription + and user.subscription.auto_recharge_enabled + and user.subscription.payment_provider + and user.balance < user.subscription.auto_recharge_balance_threshold + ) + + +def run_auto_recharge_gracefully(user: AppUser): + """ + Wrapper over _auto_recharge_user, that handles exceptions so that it can: + - log exceptions + - send emails when auto-recharge fails + - not retry if this is run as a background task + + Meant to be used in conjunction with should_attempt_auto_recharge + """ + try: + with redis_lock(f"gooey/auto_recharge_user/v1/{user.uid}"): _auto_recharge_user(user) + except AutoRechargeCooldownException as e: + logger.info( + f"Rejected auto-recharge because auto-recharge is in cooldown period for user" + f"{user=}, {e=}" + ) + except MonthlyBudgetReachedException as e: + send_monthly_budget_reached_email(user) + logger.info( + f"Rejected auto-recharge because user has reached monthly budget" + f"{user=}, spending=${e.spending}, budget=${e.budget}" + ) + except Exception as e: + traceback.print_exc() + sentry_sdk.capture_exception(e) + send_auto_recharge_failed_email(user) def _auto_recharge_user(user: AppUser): @@ -82,12 +119,3 @@ def _auto_recharge_user(user: AppUser): StripeWebhookHandler.handle_invoice_paid( uid=user.uid, invoice_data=invoice_data ) - - -def should_attempt_auto_recharge(user: AppUser): - return ( - user.subscription - and user.subscription.auto_recharge_enabled - and user.subscription.payment_provider - and user.balance < user.subscription.auto_recharge_balance_threshold - ) diff --git a/payments/tasks.py b/payments/tasks.py index eb122c4db..252ac505b 100644 --- a/payments/tasks.py +++ b/payments/tasks.py @@ -7,12 +7,6 @@ from daras_ai_v2.fastapi_tricks import get_route_url from daras_ai_v2.send_email import send_email_via_postmark from daras_ai_v2.settings import templates -from payments.auto_recharge import ( - AutoRechargeCooldownException, - MonthlyBudgetReachedException, - PaymentFailedException, - auto_recharge_user, -) @app.task @@ -45,39 +39,9 @@ def send_monthly_spending_notification_email(uid: str): user.subscription.save(update_fields=["monthly_spending_notification_sent_at"]) -@app.task -def run_auto_recharge_gracefully(uid: str): - """ - Wrapper over auto_recharge_user, that handles exceptions so that it can: - - log exceptions - - send emails when auto-recharge fails - - not retry if this is run as a background task - """ - try: - auto_recharge_user(uid) - except AutoRechargeCooldownException as e: - logger.info( - f"Rejected auto-recharge because auto-recharge is in cooldown period for user" - f"{uid=}, {e=}" - ) - return - except MonthlyBudgetReachedException as e: - send_monthly_budget_reached_email(uid) - logger.info( - f"Rejected auto-recharge because user has reached monthly budget" - f"{uid=}, spending=${e.spending}, budget=${e.budget}" - ) - return - except (PaymentFailedException, Exception) as e: - send_auto_recharge_failed_email(uid) - logger.exception("Payment failed when attempting to auto-recharge", uid=uid) - return - - -def send_monthly_budget_reached_email(uid: str): +def send_monthly_budget_reached_email(user: AppUser): from routers.account import account_route - user = AppUser.objects.get(uid=uid) if not user.email: return @@ -99,10 +63,9 @@ def send_monthly_budget_reached_email(uid: str): user.subscription.save(update_fields=["monthly_budget_email_sent_at"]) -def send_auto_recharge_failed_email(uid: str): +def send_auto_recharge_failed_email(user: AppUser): from routers.account import account_route - user = AppUser.objects.get(uid=uid) if not user.email: return diff --git a/routers/api.py b/routers/api.py index 9b2aff2b3..d2e6dc4ee 100644 --- a/routers/api.py +++ b/routers/api.py @@ -5,11 +5,13 @@ import typing from types import SimpleNamespace +import line_profiler from fastapi import APIRouter from fastapi import Depends from fastapi import Form from fastapi import HTTPException from fastapi import Response +from fastapi.exceptions import RequestValidationError from furl import furl from pydantic import BaseModel, Field from pydantic import ValidationError @@ -18,25 +20,28 @@ from starlette.datastructures import FormData from starlette.datastructures import UploadFile from starlette.requests import Request +from starlette.status import ( + HTTP_402_PAYMENT_REQUIRED, + HTTP_429_TOO_MANY_REQUESTS, + HTTP_500_INTERNAL_SERVER_ERROR, + HTTP_400_BAD_REQUEST, +) import gooey_ui as st from app_users.models import AppUser from auth.token_authentication import api_auth_header from bots.models import RetentionPolicy -from celeryapp.tasks import auto_recharge from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import settings from daras_ai_v2.all_pages import all_api_pages -from daras_ai_v2.auto_recharge import user_should_auto_recharge from daras_ai_v2.base import ( BasePage, RecipeRunState, ) from daras_ai_v2.fastapi_tricks import fastapi_request_form +from daras_ai_v2.polygon_fitter import line from functions.models import CalledFunctionResponse from gooeysite.bg_db_conn import get_celery_result_db_safe -from payments.auto_recharge import should_attempt_auto_recharge -from routers.account import AccountTabs app = APIRouter() @@ -93,7 +98,7 @@ class AsyncApiResponseModelV3(BaseResponseModelV3): class AsyncStatusResponseModelV3(BaseResponseModelV3, typing.Generic[O]): - run_time_sec: int = Field(description="Total run time in seconds") + run_time_sec: float = Field(description="Total run time in seconds") status: RecipeRunState = Field(description="Status of the run") detail: str = Field( description="Details about the status of the run as a human readable string" @@ -130,14 +135,17 @@ def script_to_api(page_cls: typing.Type[BasePage]): ) common_errs = { - 402: {"model": GenericErrorResponse}, - 429: {"model": GenericErrorResponse}, + HTTP_402_PAYMENT_REQUIRED: {"model": GenericErrorResponse}, + HTTP_429_TOO_MANY_REQUESTS: {"model": GenericErrorResponse}, } @app.post( os.path.join(endpoint, ""), response_model=response_model, - responses={500: {"model": FailedReponseModelV2}, **common_errs}, + responses={ + HTTP_500_INTERNAL_SERVER_ERROR: {"model": FailedReponseModelV2}, + **common_errs, + }, operation_id=page_cls.slug_versions[0], tags=[page_cls.title], name=page_cls.title + " (v2 sync)", @@ -145,7 +153,10 @@ def script_to_api(page_cls: typing.Type[BasePage]): @app.post( endpoint, response_model=response_model, - responses={500: {"model": FailedReponseModelV2}, **common_errs}, + responses={ + HTTP_500_INTERNAL_SERVER_ERROR: {"model": FailedReponseModelV2}, + **common_errs, + }, include_in_schema=False, ) def run_api_json( @@ -164,13 +175,21 @@ def run_api_json( @app.post( os.path.join(endpoint, "form/"), response_model=response_model, - responses={500: {"model": FailedReponseModelV2}, **common_errs}, + responses={ + HTTP_500_INTERNAL_SERVER_ERROR: {"model": FailedReponseModelV2}, + HTTP_400_BAD_REQUEST: {"model": GenericErrorResponse}, + **common_errs, + }, include_in_schema=False, ) @app.post( os.path.join(endpoint, "form"), response_model=response_model, - responses={500: {"model": FailedReponseModelV2}, **common_errs}, + responses={ + HTTP_500_INTERNAL_SERVER_ERROR: {"model": FailedReponseModelV2}, + HTTP_400_BAD_REQUEST: {"model": GenericErrorResponse}, + **common_errs, + }, include_in_schema=False, ) def run_api_form( @@ -224,16 +243,22 @@ def run_api_json_async( @app.post( os.path.join(endpoint, "async/form/"), response_model=response_model, - responses=common_errs, + responses={ + HTTP_400_BAD_REQUEST: {"model": GenericErrorResponse}, + **common_errs, + }, include_in_schema=False, ) @app.post( os.path.join(endpoint, "async/form"), response_model=response_model, - responses=common_errs, + responses={ + HTTP_400_BAD_REQUEST: {"model": GenericErrorResponse}, + **common_errs, + }, include_in_schema=False, ) - def run_api_form( + def run_api_form_async( request: Request, response: Response, user: AppUser = Depends(api_auth_header), @@ -279,9 +304,10 @@ def get_run_status( "created_at": sr.created_at.isoformat(), "run_time_sec": sr.run_time.total_seconds(), } - if sr.error_msg: + if sr.error_code: + raise HTTPException(sr.error_code, detail=ret | {"error": sr.error_msg}) + elif sr.error_msg: ret |= {"status": "failed", "detail": sr.error_msg} - return ret else: status = self.get_run_state(sr.to_dict()) ret |= {"detail": sr.run_status or "", "status": status} @@ -311,7 +337,10 @@ def _parse_form_data( try: is_str = request_model.schema()["properties"][key]["type"] == "string" except KeyError: - raise HTTPException(status_code=400, detail=f'Inavlid file field "{key}"') + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=dict(error=f'Inavlid file field "{key}"'), + ) if is_str: page_request_data[key] = urls[0] else: @@ -320,7 +349,7 @@ def _parse_form_data( try: page_request = request_model.parse_obj(page_request_data) except ValidationError as e: - raise HTTPException(status_code=422, detail=e.errors()) + raise RequestValidationError(e.errors(), body=page_request_data) return page_request @@ -374,21 +403,15 @@ def submit_api_call( st.set_session_state(state) st.set_query_params(query_params) - # check the balance - if not self.check_credits() and not should_attempt_auto_recharge(self.request.user): - account_url = furl(settings.APP_BASE_URL) / AccountTabs.billing.url_path - raise HTTPException( - status_code=402, - detail=dict( - error=f"Doh! You need to purchase additional credits to run more Gooey.AI recipes: {account_url}" - ), - ) # create a new run - sr = self.create_new_run( - enable_rate_limits=enable_rate_limits, - is_api_call=True, - retention_policy=retention_policy or RetentionPolicy.keep, - ) + try: + sr = self.create_new_run( + enable_rate_limits=enable_rate_limits, + is_api_call=True, + retention_policy=retention_policy or RetentionPolicy.keep, + ) + except ValidationError as e: + raise RequestValidationError(e.errors(), body=request_body) # submit the task result = self.call_runner_task(sr) return self, result, sr.run_id, sr.uid @@ -427,7 +450,7 @@ def build_api_response( # check for errors if sr.error_msg: raise HTTPException( - status_code=500, + status_code=sr.error_code or HTTP_500_INTERNAL_SERVER_ERROR, detail={ "id": run_id, "url": web_url, diff --git a/server.py b/server.py index 2d8cc5630..c20b82846 100644 --- a/server.py +++ b/server.py @@ -1,10 +1,14 @@ from fastapi.exception_handlers import ( - request_validation_exception_handler, http_exception_handler, + request_validation_exception_handler, ) from fastapi.exceptions import RequestValidationError from starlette.exceptions import HTTPException from starlette.requests import Request +from starlette.status import ( + HTTP_404_NOT_FOUND, + HTTP_405_METHOD_NOT_ALLOWED, +) from daras_ai_v2.pydantic_validation import convert_errors from daras_ai_v2.settings import templates @@ -113,8 +117,8 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE return await request_validation_exception_handler(request, exc) -@app.exception_handler(404) -@app.exception_handler(405) +@app.exception_handler(HTTP_404_NOT_FOUND) +@app.exception_handler(HTTP_405_METHOD_NOT_ALLOWED) async def not_found_exception_handler(request: Request, exc: HTTPException): if not request.headers.get("accept", "").startswith("text/html"): return await http_exception_handler(request, exc) diff --git a/templates/run_complete_email.html b/templates/run_complete_email.html index 613ce3903..c7ae03bb0 100644 --- a/templates/run_complete_email.html +++ b/templates/run_complete_email.html @@ -1,20 +1,14 @@ -

- Your {{ title }} Gooey.AI run completed in {{ run_time_sec }} seconds. -

-

- View output here: {{ app_url }} -

-

- Your prompt: {{ prompt }} -

+

Your {{ recipe_title }} Gooey.AI run completed in {{ run_time_sec }} seconds.

+

View output here: {{ app_url }}

-

- We can’t wait to see what you build with Gooey! -

+{% if prompt %} +

Your prompt: {{ prompt }}

+{% endif %} + +

We can’t wait to see what you build with Gooey!

- Cheers, -
+ Cheers,
The Gooey.AI Team

-{{ "{{{ pm:unsubscribe }}}" }} +{{ "{{{ pm:unsubscribe }}}" }} From 4dbd145577c3c8497983104e48b96708782101e3 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Sat, 13 Jul 2024 04:21:22 +0530 Subject: [PATCH 19/26] fix tests --- ...error_code_savedrun_error_type_and_more.py | 28 +++++++++++++++++++ conftest.py | 12 ++++++-- routers/api.py | 2 -- tests/test_apis.py | 6 ++-- tests/test_integrations_api.py | 2 +- 5 files changed, 41 insertions(+), 9 deletions(-) create mode 100644 bots/migrations/0077_savedrun_error_code_savedrun_error_type_and_more.py diff --git a/bots/migrations/0077_savedrun_error_code_savedrun_error_type_and_more.py b/bots/migrations/0077_savedrun_error_code_savedrun_error_type_and_more.py new file mode 100644 index 000000000..44033b955 --- /dev/null +++ b/bots/migrations/0077_savedrun_error_code_savedrun_error_type_and_more.py @@ -0,0 +1,28 @@ +# Generated by Django 4.2.7 on 2024-07-12 19:30 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('bots', '0076_alter_workflowmetadata_default_image_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='savedrun', + name='error_code', + field=models.IntegerField(blank=True, default=None, help_text='The HTTP status code of the error. If this is not set, 500 is assumed.', null=True), + ), + migrations.AddField( + model_name='savedrun', + name='error_type', + field=models.TextField(blank=True, default='', help_text='The exception type'), + ), + migrations.AlterField( + model_name='savedrun', + name='error_msg', + field=models.TextField(blank=True, default='', help_text='The error message. If this is not set, the run is deemed successful.'), + ), + ] diff --git a/conftest.py b/conftest.py index 96fb837f1..a38c6a11a 100644 --- a/conftest.py +++ b/conftest.py @@ -51,16 +51,17 @@ def force_authentication(): @pytest.fixture -def mock_gui_runner(): +def mock_celery_tasks(): with ( - patch("celeryapp.tasks.gui_runner", _mock_gui_runner), + patch("celeryapp.tasks.runner_task", _mock_runner_task), + patch("celeryapp.tasks.post_runner_tasks", _mock_post_runner_tasks), patch("daras_ai_v2.bots.realtime_subscribe", _mock_realtime_subscribe), ): yield @app.task -def _mock_gui_runner( +def _mock_runner_task( *, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs ): sr = page_cls.run_doc_sr(run_id, uid) @@ -70,6 +71,11 @@ def _mock_gui_runner( _mock_realtime_push(channel, sr.to_dict()) +@app.task +def _mock_post_runner_tasks(*args, **kwargs): + pass + + def _mock_realtime_push(channel, value): redis_qs[channel].put(value) diff --git a/routers/api.py b/routers/api.py index d2e6dc4ee..30d834943 100644 --- a/routers/api.py +++ b/routers/api.py @@ -5,7 +5,6 @@ import typing from types import SimpleNamespace -import line_profiler from fastapi import APIRouter from fastapi import Depends from fastapi import Form @@ -39,7 +38,6 @@ RecipeRunState, ) from daras_ai_v2.fastapi_tricks import fastapi_request_form -from daras_ai_v2.polygon_fitter import line from functions.models import CalledFunctionResponse from gooeysite.bg_db_conn import get_celery_result_db_safe diff --git a/tests/test_apis.py b/tests/test_apis.py index bd3a915fb..fa897eb83 100644 --- a/tests/test_apis.py +++ b/tests/test_apis.py @@ -15,7 +15,7 @@ @pytest.mark.django_db -def test_apis_sync(mock_gui_runner, force_authentication, threadpool_subtest): +def test_apis_sync(mock_celery_tasks, force_authentication, threadpool_subtest): for page_cls in all_test_pages: threadpool_subtest(_test_api_sync, page_cls) @@ -32,7 +32,7 @@ def _test_api_sync(page_cls: typing.Type[BasePage]): @pytest.mark.django_db -def test_apis_async(mock_gui_runner, force_authentication, threadpool_subtest): +def test_apis_async(mock_celery_tasks, force_authentication, threadpool_subtest): for page_cls in all_test_pages: threadpool_subtest(_test_api_async, page_cls) @@ -65,7 +65,7 @@ def _test_api_async(page_cls: typing.Type[BasePage]): @pytest.mark.django_db -def test_apis_examples(mock_gui_runner, force_authentication, threadpool_subtest): +def test_apis_examples(mock_celery_tasks, force_authentication, threadpool_subtest): qs = ( PublishedRun.objects.exclude(is_approved_example=False) .exclude(published_run_id="") diff --git a/tests/test_integrations_api.py b/tests/test_integrations_api.py index c6f11c9d8..398fdd52a 100644 --- a/tests/test_integrations_api.py +++ b/tests/test_integrations_api.py @@ -11,7 +11,7 @@ @pytest.mark.django_db -def test_send_msg_streaming(mock_gui_runner, force_authentication): +def test_send_msg_streaming(mock_celery_tasks, force_authentication): r = client.post( "/v3/integrations/stream/", json={ From 9b012e1b8dc7a57baca026937dde9c06d8b23322 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Sun, 14 Jul 2024 13:06:40 +0530 Subject: [PATCH 20/26] record extra data about transactions --- app_users/admin.py | 9 +- ...nsaction_plan_appusertransaction_reason.py | 59 +++++ app_users/models.py | 85 ++++-- daras_ai_v2/base.py | 1 - daras_ai_v2/billing.py | 19 +- daras_ai_v2/settings.py | 3 +- payments/auto_recharge.py | 4 +- .../0005_alter_subscription_plan.py | 18 ++ payments/models.py | 3 +- payments/plans.py | 4 +- payments/tasks.py | 4 +- payments/webhooks.py | 247 ++++++++++-------- poetry.lock | 20 +- pyproject.toml | 2 +- routers/paypal.py | 43 ++- routers/stripe.py | 3 + 16 files changed, 334 insertions(+), 190 deletions(-) create mode 100644 app_users/migrations/0018_appusertransaction_plan_appusertransaction_reason.py create mode 100644 payments/migrations/0005_alter_subscription_plan.py diff --git a/app_users/admin.py b/app_users/admin.py index da50b57e9..121ca45a3 100644 --- a/app_users/admin.py +++ b/app_users/admin.py @@ -192,19 +192,24 @@ class AppUserTransactionAdmin(admin.ModelAdmin): "invoice_id", "user", "amount", + "dollar_amount", "end_balance", "payment_provider", - "dollar_amount", + "reason", + "plan", "created_at", ] readonly_fields = ["created_at"] list_filter = [ - "created_at", + "reason", ("payment_provider", admin.EmptyFieldListFilter), "payment_provider", + "plan", + "created_at", ] inlines = [SavedRunInline] ordering = ["-created_at"] + search_fields = ["invoice_id"] @admin.display(description="Charged Amount") def dollar_amount(self, obj: models.AppUserTransaction): diff --git a/app_users/migrations/0018_appusertransaction_plan_appusertransaction_reason.py b/app_users/migrations/0018_appusertransaction_plan_appusertransaction_reason.py new file mode 100644 index 000000000..2891795fd --- /dev/null +++ b/app_users/migrations/0018_appusertransaction_plan_appusertransaction_reason.py @@ -0,0 +1,59 @@ +# Generated by Django 4.2.7 on 2024-07-14 20:51 + +from django.db import migrations, models + + +def forwards_func(apps, schema_editor): + from payments.plans import PricingPlan + from app_users.models import TransactionReason + + # We get the model from the versioned app registry; + # if we directly import it, it'll be the wrong version + AppUserTransaction = apps.get_model("app_users", "AppUserTransaction") + db_alias = schema_editor.connection.alias + objects = AppUserTransaction.objects.using(db_alias) + + for transaction in objects.all(): + if transaction.amount <= 0: + transaction.reason = TransactionReason.DEDUCT + else: + # For old transactions, we didn't have a subscription field. + # It just so happened that all monthly subscriptions we offered had + # different amounts from the one-time purchases. + # This uses that heuristic to determine whether a transaction + # was a subscription payment or a one-time purchase. + transaction.reason = TransactionReason.ADDON + for plan in PricingPlan: + if ( + transaction.amount == plan.credits + and transaction.charged_amount == plan.monthly_charge * 100 + ): + transaction.plan = plan.db_value + transaction.reason = TransactionReason.SUBSCRIBE + transaction.save(update_fields=["reason", "plan"]) + + +class Migration(migrations.Migration): + + dependencies = [ + ('app_users', '0017_alter_appuser_subscription'), + ] + + operations = [ + migrations.AddField( + model_name='appusertransaction', + name='plan', + field=models.IntegerField(blank=True, choices=[(1, 'Basic Plan'), (2, 'Premium Plan'), (3, 'Starter'), (4, 'Creator'), (5, 'Business'), (6, 'Enterprise / Agency')], default=None, help_text="User's plan at the time of this transaction.", null=True), + ), + migrations.AddField( + model_name='appusertransaction', + name='reason', + field=models.IntegerField(choices=[(1, 'Deduct'), (2, 'Addon'), (3, 'Subscribe'), (4, 'Sub-Create'), (5, 'Sub-Cycle'), (6, 'Sub-Update'), (7, 'Auto-Recharge')], default=0, help_text='The reason for this transaction.

Deduct: Credits deducted due to a run.
Addon: User purchased an add-on.
Subscribe: Applies to subscriptions where no distinction was made between create, update and cycle.
Sub-Create: A subscription was created.
Sub-Cycle: A subscription advanced into a new period.
Sub-Update: A subscription was updated.
Auto-Recharge: Credits auto-recharged due to low balance.'), + ), + migrations.RunPython(forwards_func, migrations.RunPython.noop), + migrations.AlterField( + model_name='appusertransaction', + name='reason', + field=models.IntegerField(choices=[(1, 'Deduct'), (2, 'Addon'), (3, 'Subscribe'), (4, 'Sub-Create'), (5, 'Sub-Cycle'), (6, 'Sub-Update'), (7, 'Auto-Recharge')], help_text='The reason for this transaction.

Deduct: Credits deducted due to a run.
Addon: User purchased an add-on.
Subscribe: Applies to subscriptions where no distinction was made between create, update and cycle.
Sub-Create: A subscription was created.
Sub-Cycle: A subscription advanced into a new period.
Sub-Update: A subscription was updated.
Auto-Recharge: Credits auto-recharged due to low balance.'), + ), + ] diff --git a/app_users/models.py b/app_users/models.py index f4b29f490..380bd92be 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -172,6 +172,7 @@ def add_balance( user: AppUser = AppUser.objects.select_for_update().get(pk=self.pk) user.balance += amount user.save(update_fields=["balance"]) + kwargs.setdefault("plan", user.subscription and user.subscription.plan) return AppUserTransaction.objects.create( user=self, invoice_id=invoice_id, @@ -273,6 +274,18 @@ def get_dollars_spent_this_month(self) -> float: return (cents_spent or 0) / 100 +class TransactionReason(models.IntegerChoices): + DEDUCT = 1, "Deduct" + ADDON = 2, "Addon" + + SUBSCRIBE = 3, "Subscribe" + SUBSCRIPTION_CREATE = 4, "Sub-Create" + SUBSCRIPTION_CYCLE = 5, "Sub-Cycle" + SUBSCRIPTION_UPDATE = 6, "Sub-Update" + + AUTO_RECHARGE = 7, "Auto-Recharge" + + class AppUserTransaction(models.Model): user = models.ForeignKey( "AppUser", on_delete=models.CASCADE, related_name="transactions" @@ -307,6 +320,25 @@ class AppUserTransaction(models.Model): default=0, ) + reason = models.IntegerField( + choices=TransactionReason.choices, + help_text="The reason for this transaction.

" + f"{TransactionReason.DEDUCT.label}: Credits deducted due to a run.
" + f"{TransactionReason.ADDON.label}: User purchased an add-on.
" + f"{TransactionReason.SUBSCRIBE.label}: Applies to subscriptions where no distinction was made between create, update and cycle.
" + f"{TransactionReason.SUBSCRIPTION_CREATE.label}: A subscription was created.
" + f"{TransactionReason.SUBSCRIPTION_CYCLE.label}: A subscription advanced into a new period.
" + f"{TransactionReason.SUBSCRIPTION_UPDATE.label}: A subscription was updated.
" + f"{TransactionReason.AUTO_RECHARGE.label}: Credits auto-recharged due to low balance.", + ) + plan = models.IntegerField( + choices=PricingPlan.db_choices(), + help_text="User's plan at the time of this transaction.", + null=True, + blank=True, + default=None, + ) + created_at = models.DateTimeField(editable=False, blank=True, default=timezone.now) class Meta: @@ -320,32 +352,29 @@ class Meta: def __str__(self): return f"{self.invoice_id} ({self.amount})" - def get_subscription_plan(self) -> PricingPlan | None: - """ - It just so happened that all monthly subscriptions we offered had - different amounts from the one-time purchases. - This uses that heuristic to determine whether a transaction - was a subscription payment or a one-time purchase. - - TODO: Implement this more robustly - """ - if self.amount <= 0: - # credits deducted - return None - - for plan in PricingPlan: - if ( - self.amount == plan.credits - and self.charged_amount == plan.monthly_charge * 100 + def save(self, *args, **kwargs): + if self.reason is None: + if self.amount <= 0: + self.reason = TransactionReason.DEDUCT + else: + self.reason = TransactionReason.ADDON + super().save(*args, **kwargs) + + def reason_note(self) -> str: + match self.reason: + case ( + TransactionReason.SUBSCRIPTION_CREATE + | TransactionReason.SUBSCRIPTION_CYCLE + | TransactionReason.SUBSCRIPTION_UPDATE + | TransactionReason.SUBSCRIBE ): - return plan - - return None - - def note(self) -> str: - if self.amount <= 0: - return "" - elif plan := self.get_subscription_plan(): - return f"Subscription payment: {plan.title} (+{self.amount:,} credits)" - else: - return f"Addon purchase (+{self.amount:,} credits)" + ret = "Subscription payment" + if self.plan: + ret += f": {PricingPlan.from_db_value(self.plan).title}" + return ret + case TransactionReason.AUTO_RECHARGE: + return "Auto recharge" + case TransactionReason.ADDON: + return "Addon purchase" + case TransactionReason.DEDUCT: + return "Run deduction" diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 6a3d602ea..a447b4252 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -145,7 +145,6 @@ class BasePage: class RequestModel(BaseModel): functions: list[RecipeFunction] | None = Field( - None, title="🧩 Functions", ) variables: dict[str, typing.Any] = Field( diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index df2d90e71..43085fdde 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -579,16 +579,17 @@ def render_billing_history(user: AppUser, limit: int = 50): st.write("## Billing History", className="d-block") st.table( pd.DataFrame.from_records( - columns=[""] * 3, - data=[ - [ - txn.created_at.strftime("%m/%d/%Y"), - txn.note(), - f"${txn.charged_amount / 100:,.2f}", - ] + [ + { + "Date": txn.created_at.strftime("%m/%d/%Y"), + "Description": txn.reason_note(), + "Amount": f"-${txn.charged_amount / 100:,.2f}", + "Credits": f"+{txn.amount:,}", + "Balance": f"{txn.end_balance:,}", + } for txn in txns[:limit] - ], - ) + ] + ), ) if txns.count() > limit: st.caption(f"Showing only the most recent {limit} transactions.") diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 2b363732b..9b48a1c2e 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -295,9 +295,10 @@ LOW_BALANCE_EMAIL_ENABLED = config("LOW_BALANCE_EMAIL_ENABLED", True, cast=bool) STRIPE_SECRET_KEY = config("STRIPE_SECRET_KEY", None) +STRIPE_ENDPOINT_SECRET = config("STRIPE_ENDPOINT_SECRET", None) stripe.api_key = STRIPE_SECRET_KEY + STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: str = "subscription_key" -STRIPE_ENDPOINT_SECRET = config("STRIPE_ENDPOINT_SECRET", None) STRIPE_ADDON_PRODUCT_NAME = config( "STRIPE_ADDON_PRODUCT_NAME", "Gooey.AI Add-on Credits" ) diff --git a/payments/auto_recharge.py b/payments/auto_recharge.py index 1e3572929..1aa414cfb 100644 --- a/payments/auto_recharge.py +++ b/payments/auto_recharge.py @@ -116,6 +116,4 @@ def _auto_recharge_user(user: AppUser): ) from e else: assert invoice_data.paid - StripeWebhookHandler.handle_invoice_paid( - uid=user.uid, invoice_data=invoice_data - ) + StripeWebhookHandler.handle_invoice_paid(uid=user.uid, invoice=invoice_data) diff --git a/payments/migrations/0005_alter_subscription_plan.py b/payments/migrations/0005_alter_subscription_plan.py new file mode 100644 index 000000000..bfb74dfe3 --- /dev/null +++ b/payments/migrations/0005_alter_subscription_plan.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-07-14 08:52 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0004_alter_subscription_auto_recharge_balance_threshold'), + ] + + operations = [ + migrations.AlterField( + model_name='subscription', + name='plan', + field=models.IntegerField(choices=[(1, 'Basic Plan'), (2, 'Premium Plan'), (3, 'Starter'), (4, 'Creator'), (5, 'Business'), (6, 'Enterprise / Agency')]), + ), + ] diff --git a/payments/models.py b/payments/models.py index fa86a22cd..3ce730793 100644 --- a/payments/models.py +++ b/payments/models.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time import typing import stripe @@ -221,7 +222,7 @@ def stripe_get_or_create_auto_invoice( for inv in invoices: if ( inv.status == "paid" - and timezone.now().timestamp() - inv.created + and abs(time.time() - inv.created) < settings.AUTO_RECHARGE_COOLDOWN_SECONDS ): return inv diff --git a/payments/plans.py b/payments/plans.py index b2dd6fcdb..badfd2c1c 100644 --- a/payments/plans.py +++ b/payments/plans.py @@ -211,7 +211,7 @@ def __lt__(self, other: PricingPlan) -> bool: @classmethod def db_choices(cls): - return [(plan.db_value, plan.name) for plan in cls] + return [(plan.db_value, plan.title) for plan in cls] @classmethod def from_sub(cls, sub: "Subscription") -> PricingPlan: @@ -240,7 +240,7 @@ def get_by_paypal_plan_id(cls, plan_id: str) -> PricingPlan | None: return plan @classmethod - def get_by_key(cls, key: str): + def get_by_key(cls, key: str) -> PricingPlan | None: for plan in cls: if plan.key == key: return plan diff --git a/payments/tasks.py b/payments/tasks.py index 252ac505b..6c8b046d5 100644 --- a/payments/tasks.py +++ b/payments/tasks.py @@ -10,10 +10,10 @@ @app.task -def send_monthly_spending_notification_email(uid: str): +def send_monthly_spending_notification_email(user_id: int): from routers.account import account_route - user = AppUser.objects.get(uid=uid) + user = AppUser.objects.get(id=user_id) if not user.email: logger.error(f"User doesn't have an email: {user=}") return diff --git a/payments/webhooks.py b/payments/webhooks.py index 8ad45f46a..8efc91492 100644 --- a/payments/webhooks.py +++ b/payments/webhooks.py @@ -1,117 +1,122 @@ -from urllib.parse import quote_plus - import stripe from django.db import transaction from loguru import logger from requests.models import HTTPError +from app_users.models import AppUser, PaymentProvider, TransactionReason +from daras_ai_v2 import paypal from .models import Subscription from .plans import PricingPlan from .tasks import send_monthly_spending_notification_email -from app_users.models import AppUser, PaymentProvider -from daras_ai_v2 import paypal -class WebhookHandler: - PROVIDER: PaymentProvider +def add_balance_for_payment( + *, + uid: str, + amount: int, + invoice_id: str, + payment_provider: PaymentProvider, + charged_amount: int, + **kwargs, +): + user = AppUser.objects.get_or_create_from_uid(uid)[0] + user.add_balance( + amount=amount, + invoice_id=invoice_id, + charged_amount=charged_amount, + payment_provider=payment_provider, + **kwargs, + ) - @classmethod - def _after_payment_completed( - cls, *, uid: str, invoice_id: str, credits: int, charged_amount: int + if not user.is_paying: + user.is_paying = True + user.save(update_fields=["is_paying"]) + + if ( + user.subscription + and user.subscription.should_send_monthly_spending_notification() ): - user = AppUser.objects.get_or_create_from_uid(uid)[0] - user.add_balance( - payment_provider=cls.PROVIDER, - invoice_id=invoice_id, - amount=credits, - charged_amount=charged_amount, + send_monthly_spending_notification_email.delay(user.id) + + +@transaction.atomic +def _after_subscription_updated( + *, provider: PaymentProvider, uid: str, sub_id: str, plan: PricingPlan +): + if not is_sub_active(provider=provider, sub_id=sub_id): + # subscription is not in an active state, just ignore + logger.info( + "Subscription is not active. Ignoring event", + provider=provider, + sub_id=sub_id, ) - if not user.is_paying: - user.is_paying = True - user.save(update_fields=["is_paying"]) - if ( - user.subscription - and user.subscription.should_send_monthly_spending_notification() - ): - send_monthly_spending_notification_email.delay(uid) - - @classmethod - @transaction.atomic - def _after_subscription_updated(cls, *, uid: str, sub_id: str, plan: PricingPlan): - if not is_sub_active(provider=cls.PROVIDER, sub_id=sub_id): - # subscription is not in an active state, just ignore - logger.info( - "Subscription is not active. Ignoring event", - provider=cls.PROVIDER, - sub_id=sub_id, - ) + return + + user = AppUser.objects.get_or_create_from_uid(uid)[0] + # select_for_update: we want to lock the row until we are done reading & + # updating the subscription + # + # anther transaction shouldn't update the subscription in the meantime + user: AppUser = AppUser.objects.select_for_update().get(pk=user.pk) + if not user.subscription: + # new subscription + logger.info("Creating new subscription for user", uid=uid) + user.subscription = Subscription.objects.get_or_create( + payment_provider=provider, + external_id=sub_id, + defaults={"plan": plan.db_value}, + )[0] + user.subscription.plan = plan.db_value + + elif is_same_sub(user.subscription, provider=provider, sub_id=sub_id): + if user.subscription.plan == plan.db_value: + # same subscription exists with the same plan in DB + logger.info("Nothing to do") return - - user = AppUser.objects.get_or_create_from_uid(uid)[0] - # select_for_update: we want to lock the row until we are done reading & - # updating the subscription - # - # anther transaction shouldn't update the subscription in the meantime - user: AppUser = AppUser.objects.select_for_update().get(pk=user.pk) - if not user.subscription: - # new subscription - logger.info("Creating new subscription for user", uid=uid) - user.subscription = Subscription.objects.get_or_create( - payment_provider=cls.PROVIDER, - external_id=sub_id, - defaults={"plan": plan.db_value}, - )[0] + else: + # provider & sub_id is same, but plan is different. so we update only the plan + logger.info("Updating plan for user", uid=uid) user.subscription.plan = plan.db_value - elif is_same_sub(user.subscription, provider=cls.PROVIDER, sub_id=sub_id): - if user.subscription.plan == plan.db_value: - # same subscription exists with the same plan in DB - logger.info("Nothing to do") - return - else: - # provider & sub_id is same, but plan is different. so we update only the plan - logger.info("Updating plan for user", uid=uid) - user.subscription.plan = plan.db_value + else: + logger.critical( + "Invalid state: last subscription was not cleared for user", uid=uid + ) - else: + # we have a different existing subscription in DB + # this is invalid state! we should cancel the subscription if it is active + if is_sub_active( + provider=user.subscription.payment_provider, + sub_id=user.subscription.external_id, + ): logger.critical( - "Invalid state: last subscription was not cleared for user", uid=uid - ) - - # we have a different existing subscription in DB - # this is invalid state! we should cancel the subscription if it is active - if is_sub_active( - provider=user.subscription.payment_provider, + "Found existing active subscription for user. Cancelling that...", + uid=uid, + provider=user.subscription.get_payment_provider_display(), sub_id=user.subscription.external_id, - ): - logger.critical( - "Found existing active subscription for user. Cancelling that...", - uid=uid, - provider=user.subscription.get_payment_provider_display(), - sub_id=user.subscription.external_id, - ) - user.subscription.cancel() - - logger.info("Creating new subscription for user", uid=uid) - user.subscription = Subscription( - payment_provider=cls.PROVIDER, plan=plan, external_id=sub_id ) + user.subscription.cancel() - user.subscription.full_clean() - user.subscription.save() - user.save(update_fields=["subscription"]) + logger.info("Creating new subscription for user", uid=uid) + user.subscription = Subscription( + payment_provider=provider, plan=plan, external_id=sub_id + ) - @classmethod - def _after_subscription_cancelled(cls, uid: str, sub_id: str): - user = AppUser.objects.get_or_create_from_uid(uid=uid)[0] - if user.subscription and is_same_sub( - user.subscription, provider=cls.PROVIDER, sub_id=sub_id - ): - user.subscription = None - user.save(update_fields=["subscription"]) + user.subscription.full_clean() + user.subscription.save() + user.save(update_fields=["subscription"]) + + +def _after_subscription_cancelled(*, provider: PaymentProvider, uid: str, sub_id: str): + user = AppUser.objects.get_or_create_from_uid(uid=uid)[0] + if user.subscription and is_same_sub( + user.subscription, provider=provider, sub_id=sub_id + ): + user.subscription = None + user.save(update_fields=["subscription"]) -class PaypalWebhookHandler(WebhookHandler): +class PaypalWebhookHandler: PROVIDER = PaymentProvider.PAYPAL @classmethod @@ -135,11 +140,14 @@ def handle_sale_completed(cls, sale: paypal.Sale): ) uid = pp_sub.custom_id - cls._after_payment_completed( + add_balance_for_payment( uid=uid, + amount=plan.credits, invoice_id=sale.id, - credits=plan.credits, + payment_provider=cls.PROVIDER, charged_amount=charged_dollars * 100, + reason=TransactionReason.SUBSCRIBE, + plan=plan.db_value, ) @classmethod @@ -152,32 +160,49 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription): plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id) assert plan, f"Plan {pp_sub.plan_id} not found" - cls._after_subscription_updated( - uid=pp_sub.custom_id, sub_id=pp_sub.id, plan=plan + _after_subscription_updated( + provider=cls.PROVIDER, uid=pp_sub.custom_id, sub_id=pp_sub.id, plan=plan ) @classmethod def handle_subscription_cancelled(cls, pp_sub: paypal.Subscription): assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid" - cls._after_subscription_cancelled(uid=pp_sub.custom_id, sub_id=pp_sub.id) + _after_subscription_cancelled( + provider=cls.PROVIDER, uid=pp_sub.custom_id, sub_id=pp_sub.id + ) -class StripeWebhookHandler(WebhookHandler): +class StripeWebhookHandler: PROVIDER = PaymentProvider.STRIPE @classmethod - def handle_invoice_paid(cls, uid: str, invoice_data): - invoice_id = invoice_data.id - line_items = stripe.Invoice._static_request( - "get", - "/v1/invoices/{invoice}/lines".format(invoice=quote_plus(invoice_id)), - ) - - cls._after_payment_completed( + def handle_invoice_paid(cls, uid: str, invoice: stripe.Invoice): + kwargs = {} + if invoice.subscription: + kwargs["plan"] = PricingPlan.get_by_key( + invoice.subscription_details.metadata.get("subscription_key") + ).db_value + match invoice.billing_reason: + case "subscription_create": + reason = TransactionReason.SUBSCRIPTION_CREATE + case "subscription_cycle": + reason = TransactionReason.SUBSCRIPTION_CYCLE + case "subscription_update": + reason = TransactionReason.SUBSCRIPTION_UPDATE + case _: + reason = TransactionReason.SUBSCRIBE + elif invoice.metadata and invoice.metadata.get("auto_recharge"): + reason = TransactionReason.AUTO_RECHARGE + else: + reason = TransactionReason.ADDON + add_balance_for_payment( uid=uid, - invoice_id=invoice_id, - credits=line_items.data[0].quantity, - charged_amount=line_items.data[0].amount, + amount=invoice.lines.data[0].quantity, + invoice_id=invoice.id, + payment_provider=cls.PROVIDER, + charged_amount=invoice.lines.data[0].amount, + reason=reason, + **kwargs, ) @classmethod @@ -214,13 +239,17 @@ def handle_subscription_updated(cls, uid: str, stripe_sub): f"PricingPlan not found for product {stripe_sub.plan.product}" ) - cls._after_subscription_updated(uid=uid, sub_id=stripe_sub.id, plan=plan) + _after_subscription_updated( + provider=cls.PROVIDER, uid=uid, sub_id=stripe_sub.id, plan=plan + ) @classmethod def handle_subscription_cancelled(cls, uid: str, stripe_sub): logger.info(f"Stripe subscription cancelled: {stripe_sub.id}") - cls._after_subscription_cancelled(uid=uid, sub_id=stripe_sub.id) + _after_subscription_cancelled( + provider=cls.PROVIDER, uid=uid, sub_id=stripe_sub.id + ) def is_same_sub( diff --git a/poetry.lock b/poetry.lock index c63dd4849..9d7b61cb3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5122,17 +5122,18 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo [[package]] name = "sentry-sdk" -version = "1.34.0" +version = "1.45.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.34.0.tar.gz", hash = "sha256:e5d0d2b25931d88fa10986da59d941ac6037f742ab6ff2fce4143a27981d60c3"}, - {file = "sentry_sdk-1.34.0-py2.py3-none-any.whl", hash = "sha256:76dd087f38062ac6c1e30ed6feb533ee0037ff9e709974802db7b5dbf2e5db21"}, + {file = "sentry-sdk-1.45.0.tar.gz", hash = "sha256:509aa9678c0512344ca886281766c2e538682f8acfa50fd8d405f8c417ad0625"}, + {file = "sentry_sdk-1.45.0-py2.py3-none-any.whl", hash = "sha256:1ce29e30240cc289a027011103a8c83885b15ef2f316a60bcc7c5300afa144f1"}, ] [package.dependencies] certifi = "*" +loguru = {version = ">=0.5", optional = true, markers = "extra == \"loguru\""} urllib3 = {version = ">=1.26.11", markers = "python_version >= \"3.6\""} [package.extras] @@ -5142,6 +5143,7 @@ asyncpg = ["asyncpg (>=0.23)"] beam = ["apache-beam (>=2.12)"] bottle = ["bottle (>=0.12.13)"] celery = ["celery (>=3)"] +celery-redbeat = ["celery-redbeat (>=2)"] chalice = ["chalice (>=1.16.0)"] clickhouse-driver = ["clickhouse-driver (>=0.2.0)"] django = ["django (>=1.8)"] @@ -5152,6 +5154,7 @@ grpcio = ["grpcio (>=1.21.1)"] httpx = ["httpx (>=0.16.0)"] huey = ["huey (>=2)"] loguru = ["loguru (>=0.5)"] +openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"] opentelemetry = ["opentelemetry-distro (>=0.35b0)"] opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"] pure-eval = ["asttokens", "executing", "pure-eval"] @@ -5445,17 +5448,18 @@ snowflake = ["snowflake-connector-python (>=2.8.0)", "snowflake-snowpark-python [[package]] name = "stripe" -version = "5.5.0" +version = "10.3.0" description = "Python bindings for the Stripe API" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +python-versions = ">=3.6" files = [ - {file = "stripe-5.5.0-py2.py3-none-any.whl", hash = "sha256:b4947da66dbb3de8969004ba6398f9a019c6b1b3ffe6aa88d5b07ac560a52b28"}, - {file = "stripe-5.5.0.tar.gz", hash = "sha256:04a9732b37a46228ecf0e496163a3edd93596b0e6200029fbc48911638627e19"}, + {file = "stripe-10.3.0-py2.py3-none-any.whl", hash = "sha256:95aa10d34e325cb6a19784412d6196621442c278b0c9cd3fe7be2a7ef180c2f8"}, + {file = "stripe-10.3.0.tar.gz", hash = "sha256:56515faf0cbee82f27d9b066403988a107301fc80767500be9789a25d65f2bae"}, ] [package.dependencies] requests = {version = ">=2.20", markers = "python_version >= \"3.0\""} +typing-extensions = {version = ">=4.5.0", markers = "python_version >= \"3.7\""} [[package]] name = "tabulate" @@ -6442,4 +6446,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "cc79f3b414323945ade371a12c4071eb50b9988715f0d094e4e9ef34008c3fe2" +content-hash = "2777d2a014b924fe8a7c2dfe63ebccbb55b5572abea44515898ca8e7fd7a17b0" diff --git a/pyproject.toml b/pyproject.toml index c0dc40590..a470e162b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ google-cloud-texttospeech = "^2.12.1" Wand = "^0.6.10" readability-lxml = "^0.8.1" transformers = "^4.24.0" -stripe = "^5.0.0" +stripe = "^10.3.0" python-multipart = "^0.0.5" html-sanitizer = "^1.9.3" plotly = "^5.11.0" diff --git a/routers/paypal.py b/routers/paypal.py index 178fb8979..28f6ac48f 100644 --- a/routers/paypal.py +++ b/routers/paypal.py @@ -12,12 +12,12 @@ from loguru import logger from pydantic import BaseModel -from app_users.models import AppUser, PaymentProvider +from app_users.models import PaymentProvider, TransactionReason from daras_ai_v2 import paypal, settings from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.fastapi_tricks import fastapi_request_json, get_route_url from payments.models import PricingPlan -from payments.webhooks import PaypalWebhookHandler +from payments.webhooks import PaypalWebhookHandler, add_balance_for_payment from routers.account import payment_processing_route, account_route router = APIRouter() @@ -166,6 +166,24 @@ def capture_order(order_id: str): return JSONResponse(response.json(), response.status_code) +def _handle_invoice_paid(order_id: str): + response = requests.get( + str(furl(settings.PAYPAL_BASE) / f"v2/checkout/orders/{order_id}"), + headers={"Authorization": paypal.generate_auth_header()}, + ) + raise_for_status(response) + order = response.json() + purchase_unit = order["purchase_units"][0] + uid = purchase_unit["payments"]["captures"][0]["custom_id"] + add_balance_for_payment( + uid=uid, + amount=int(purchase_unit["items"][0]["quantity"]), + invoice_id=order_id, + payment_provider=PaymentProvider.PAYPAL, + charged_amount=int(float(purchase_unit["amount"]["value"]) * 100), + ) + + @router.post("/__/paypal/webhook") def webhook(request: Request, payload: dict = fastapi_request_json): if not paypal.verify_webhook_event(payload, headers=request.headers): @@ -194,24 +212,3 @@ def webhook(request: Request, payload: dict = fastapi_request_json): logger.error(f"Unhandled PayPal webhook event: {event.event_type}") return JSONResponse({}, status_code=200) - - -def _handle_invoice_paid(order_id: str): - response = requests.get( - str(furl(settings.PAYPAL_BASE) / f"v2/checkout/orders/{order_id}"), - headers={"Authorization": paypal.generate_auth_header()}, - ) - raise_for_status(response) - order = response.json() - purchase_unit = order["purchase_units"][0] - uid = purchase_unit["payments"]["captures"][0]["custom_id"] - user = AppUser.objects.get_or_create_from_uid(uid)[0] - user.add_balance( - payment_provider=PaymentProvider.PAYPAL, - invoice_id=order_id, - amount=int(purchase_unit["items"][0]["quantity"]), - charged_amount=int(float(purchase_unit["amount"]["value"]) * 100), - ) - if not user.is_paying: - user.is_paying = True - user.save(update_fields=["is_paying"]) diff --git a/routers/stripe.py b/routers/stripe.py index d6ab35cb3..3d481d47b 100644 --- a/routers/stripe.py +++ b/routers/stripe.py @@ -1,6 +1,7 @@ import stripe from fastapi import APIRouter, Request from fastapi.responses import JSONResponse +from loguru import logger from daras_ai_v2 import settings from daras_ai_v2.fastapi_tricks import fastapi_request_body @@ -34,6 +35,8 @@ def webhook_received(request: Request, payload: bytes = fastapi_request_body): status_code=400, ) + logger.info(f"Received event: {event['type']}") + # Get the type of webhook event sent - used to check the status of PaymentIntents. match event["type"]: case "invoice.paid": From 9ede7d8e3e1f0bd47aead1f5e8504e30cd252489 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Sun, 14 Jul 2024 18:17:15 +0530 Subject: [PATCH 21/26] fix: record the correct txn id for paypal --- routers/paypal.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/routers/paypal.py b/routers/paypal.py index 28f6ac48f..9bb31cf37 100644 --- a/routers/paypal.py +++ b/routers/paypal.py @@ -174,11 +174,11 @@ def _handle_invoice_paid(order_id: str): raise_for_status(response) order = response.json() purchase_unit = order["purchase_units"][0] - uid = purchase_unit["payments"]["captures"][0]["custom_id"] + payment_capture = purchase_unit["payments"]["captures"][0] add_balance_for_payment( - uid=uid, + uid=payment_capture["custom_id"], amount=int(purchase_unit["items"][0]["quantity"]), - invoice_id=order_id, + invoice_id=payment_capture["id"], payment_provider=PaymentProvider.PAYPAL, charged_amount=int(float(purchase_unit["amount"]["value"]) * 100), ) From 1edae2af9b7a1ea24ff6078a485a5849251ec70d Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Sun, 14 Jul 2024 18:17:23 +0530 Subject: [PATCH 22/26] re-enable one time payment buttons --- daras_ai_v2/billing.py | 100 +++++++++++++++++++++++++++-------------- 1 file changed, 66 insertions(+), 34 deletions(-) diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index 43085fdde..94e68cbf8 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -15,6 +15,7 @@ from gooey_ui.components.pills import pill from payments.models import PaymentMethodSummary from payments.plans import PricingPlan +from scripts.migrate_existing_subscriptions import available_subscriptions rounded_border = "w-100 border shadow-sm rounded py-4 px-3" @@ -32,14 +33,15 @@ def billing_page(user: AppUser): render_credit_balance(user) with st.div(className="my-5"): - render_all_plans(user) + selected_payment_provider = render_all_plans(user) + + with st.div(className="my-5"): + render_addon_section(user, selected_payment_provider) if user.subscription and user.subscription.payment_provider: if user.subscription.payment_provider == PaymentProvider.STRIPE: with st.div(className="my-5"): render_auto_recharge_section(user) - with st.div(className="my-5"): - render_addon_section(user) with st.div(className="my-5"): render_payment_information(user) @@ -120,7 +122,7 @@ def render_credit_balance(user: AppUser): ) -def render_all_plans(user: AppUser): +def render_all_plans(user: AppUser) -> PaymentProvider: current_plan = ( PricingPlan.from_sub(user.subscription) if user.subscription @@ -159,6 +161,8 @@ def _render_plan(plan: PricingPlan): with st.div(className="my-2 d-flex justify-content-center"): st.caption(f"**[See all features & benefits]({settings.PRICING_DETAILS_URL})**") + return selected_payment_provider + def _render_plan_details(plan: PricingPlan): with st.div(className="flex-grow-1"): @@ -370,44 +374,72 @@ def payment_provider_radio(**props) -> str | None: ) -def render_addon_section(user: AppUser): - assert user.subscription - - st.write("# Purchase More Credits") +def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvider): + if user.subscription: + st.write("# Purchase More Credits") + else: + st.write("# Purchase Credits") st.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits") - provider = PaymentProvider(user.subscription.payment_provider) + if user.subscription: + provider = PaymentProvider(user.subscription.payment_provider) + else: + provider = selected_payment_provider match provider: case PaymentProvider.STRIPE: - for amount in settings.ADDON_AMOUNT_CHOICES: - render_stripe_addon_button(amount, user=user) + render_stripe_addon_buttons(user) case PaymentProvider.PAYPAL: - for amount in settings.ADDON_AMOUNT_CHOICES: - render_paypal_addon_button(amount) - st.div( - id="paypal-addon-buttons", - className="mt-2", - style={"width": "fit-content"}, - ) - st.div(id="paypal-result-message") + render_paypal_addon_buttons() -def render_paypal_addon_button(amount: int): - st.html( - f""" - - """ +def render_paypal_addon_buttons(): + selected_amt = st.horizontal_radio( + "", + settings.ADDON_AMOUNT_CHOICES, + format_func=lambda amt: f"${amt:,}", + checked_by_default=False, + ) + if selected_amt: + st.js( + f"setPaypalAddonQuantity({int(selected_amt) * settings.ADDON_CREDITS_PER_DOLLAR})" + ) + st.div( + id="paypal-addon-buttons", + className="mt-2", + style={"width": "fit-content"}, ) + st.div(id="paypal-result-message") + +def render_stripe_addon_buttons(user: AppUser): + for dollat_amt in settings.ADDON_AMOUNT_CHOICES: + render_stripe_addon_button(dollat_amt, user) -def render_stripe_addon_button(amount: int, user: AppUser): - confirm_purchase_modal = Modal("Confirm Purchase", key=f"confirm-purchase-{amount}") - if st.button(f"${amount:,}", type="primary"): - confirm_purchase_modal.open() + +def render_stripe_addon_button(dollat_amt: int, user: AppUser): + confirm_purchase_modal = Modal( + "Confirm Purchase", key=f"confirm-purchase-{dollat_amt}" + ) + if st.button(f"${dollat_amt:,}", type="primary"): + if user.subscription: + confirm_purchase_modal.open() + else: + from routers.account import account_route + from routers.account import payment_processing_route + + line_item = available_subscriptions["addon"]["stripe"].copy() + line_item["quantity"] = dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR + + checkout_session = stripe.checkout.Session.create( + line_items=[line_item], + mode="payment", + success_url=get_route_url(payment_processing_route), + cancel_url=get_route_url(account_route), + customer=user.get_or_create_stripe_customer(), + invoice_creation={"enabled": True}, + allow_promotion_codes=True, + ) + raise RedirectException(checkout_session.url, status_code=303) if not confirm_purchase_modal.is_open(): return @@ -415,7 +447,7 @@ def render_stripe_addon_button(amount: int, user: AppUser): st.write( f""" Please confirm your purchase: - **{amount * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${amount}**. + **{dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollat_amt}**. """, className="py-4 d-block text-center", ) @@ -423,7 +455,7 @@ def render_stripe_addon_button(amount: int, user: AppUser): if st.session_state.get("--confirm-purchase"): success = st.run_in_thread( user.subscription.stripe_attempt_addon_purchase, - args=[amount], + args=[dollat_amt], placeholder="Processing payment...", ) if success is None: From 2de5fc52abceb79d449c1da5edeedc8a65ecd514 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Sun, 14 Jul 2024 18:45:40 +0530 Subject: [PATCH 23/26] fix: dont charge the user when downgrading the subscription --- app_users/admin.py | 10 +++++++++- app_users/models.py | 13 +++++++++++++ daras_ai_v2/billing.py | 13 ++++++++----- routers/paypal.py | 4 +--- 4 files changed, 31 insertions(+), 9 deletions(-) diff --git a/app_users/admin.py b/app_users/admin.py index 121ca45a3..caa61d223 100644 --- a/app_users/admin.py +++ b/app_users/admin.py @@ -199,7 +199,7 @@ class AppUserTransactionAdmin(admin.ModelAdmin): "plan", "created_at", ] - readonly_fields = ["created_at"] + readonly_fields = ["view_payment_provider_url", "created_at"] list_filter = [ "reason", ("payment_provider", admin.EmptyFieldListFilter), @@ -217,6 +217,14 @@ def dollar_amount(self, obj: models.AppUserTransaction): return return f"${obj.charged_amount / 100}" + @admin.display(description="Payment Provider URL") + def view_payment_provider_url(self, txn: models.AppUserTransaction): + url = txn.payment_provider_url() + if url: + return open_in_new_tab(url, label=url) + else: + raise txn.DoesNotExist + @admin.register(LogEntry) class LogEntryAdmin(admin.ModelAdmin): diff --git a/app_users/models.py b/app_users/models.py index 380bd92be..1e1016520 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -4,6 +4,7 @@ from django.db.models import Sum from django.utils import timezone from firebase_admin import auth +from furl import furl from phonenumber_field.modelfields import PhoneNumberField from bots.custom_fields import CustomURLField, StrippedTextField @@ -378,3 +379,15 @@ def reason_note(self) -> str: return "Addon purchase" case TransactionReason.DEDUCT: return "Run deduction" + + def payment_provider_url(self) -> str | None: + match self.payment_provider: + case PaymentProvider.STRIPE: + return str( + furl("https://dashboard.stripe.com/invoices/") / self.invoice_id + ) + case PaymentProvider.PAYPAL: + return str( + furl("https://www.paypal.com/unifiedtransactions/details/payment/") + / self.invoice_id + ) diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index 94e68cbf8..a0e91453e 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -297,7 +297,12 @@ def _render_update_subscription_button( downgrade_modal.close() case _: if st.button(label, className=className, key=key): - change_subscription(user, plan) + change_subscription( + user, + plan, + # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time + billing_cycle_anchor="now", + ) def fmt_price(plan: PricingPlan) -> str: @@ -307,7 +312,7 @@ def fmt_price(plan: PricingPlan) -> str: return "Free" -def change_subscription(user: AppUser, new_plan: PricingPlan): +def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs): from routers.account import account_route from routers.account import payment_processing_route @@ -338,9 +343,7 @@ def change_subscription(user: AppUser, new_plan: PricingPlan): metadata={ settings.STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: new_plan.key, }, - # charge the full new amount today, without prorations - # see: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time - billing_cycle_anchor="now", + **kwargs, proration_behavior="none", ) raise RedirectException( diff --git a/routers/paypal.py b/routers/paypal.py index 9bb31cf37..c59668fc1 100644 --- a/routers/paypal.py +++ b/routers/paypal.py @@ -208,7 +208,5 @@ def webhook(request: Request, payload: dict = fastapi_request_json): case "BILLING.SUBSCRIPTION.CANCELLED" | "BILLING.SUBSCRIPTION.EXPIRED": subscription = SubscriptionEvent.parse_obj(event).resource PaypalWebhookHandler.handle_subscription_cancelled(subscription) - case _: - logger.error(f"Unhandled PayPal webhook event: {event.event_type}") - return JSONResponse({}, status_code=200) + return JSONResponse({}) From 6bb0ac2f7ee90fe3b7923074def0167891af547a Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Sun, 14 Jul 2024 23:23:22 +0530 Subject: [PATCH 24/26] refactor --- daras_ai_v2/billing.py | 39 +++++++++++++++++++++------------------ tests/test_checkout.py | 4 ++-- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index a0e91453e..c9bf26847 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -427,22 +427,7 @@ def render_stripe_addon_button(dollat_amt: int, user: AppUser): if user.subscription: confirm_purchase_modal.open() else: - from routers.account import account_route - from routers.account import payment_processing_route - - line_item = available_subscriptions["addon"]["stripe"].copy() - line_item["quantity"] = dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR - - checkout_session = stripe.checkout.Session.create( - line_items=[line_item], - mode="payment", - success_url=get_route_url(payment_processing_route), - cancel_url=get_route_url(account_route), - customer=user.get_or_create_stripe_customer(), - invoice_creation={"enabled": True}, - allow_promotion_codes=True, - ) - raise RedirectException(checkout_session.url, status_code=303) + stripe_addon_checkout_redirect(user, dollat_amt) if not confirm_purchase_modal.is_open(): return @@ -475,6 +460,24 @@ def render_stripe_addon_button(dollat_amt: int, user: AppUser): st.button("Buy", type="primary", key="--confirm-purchase") +def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int): + from routers.account import account_route + from routers.account import payment_processing_route + + line_item = available_subscriptions["addon"]["stripe"].copy() + line_item["quantity"] = dollat_amt * settings.ADDON_CREDITS_PER_DOLLAR + checkout_session = stripe.checkout.Session.create( + line_items=[line_item], + mode="payment", + success_url=get_route_url(payment_processing_route), + cancel_url=get_route_url(account_route), + customer=user.get_or_create_stripe_customer(), + invoice_creation={"enabled": True}, + allow_promotion_codes=True, + ) + raise RedirectException(checkout_session.url, status_code=303) + + def render_stripe_subscription_button( *, label: str, @@ -491,10 +494,10 @@ def render_stripe_subscription_button( # of buttons with the same label. otherwise, all buttons # will be the same to the server if st.button(label, key=key, type=btn_type): - create_stripe_checkout_session(user=user, plan=plan) + stripe_subscription_checkout_redirect(user=user, plan=plan) -def create_stripe_checkout_session(user: AppUser, plan: PricingPlan): +def stripe_subscription_checkout_redirect(user: AppUser, plan: PricingPlan): from routers.account import account_route from routers.account import payment_processing_route diff --git a/tests/test_checkout.py b/tests/test_checkout.py index 4e412543a..19a988543 100644 --- a/tests/test_checkout.py +++ b/tests/test_checkout.py @@ -3,7 +3,7 @@ from app_users.models import AppUser from daras_ai_v2 import settings -from daras_ai_v2.billing import create_stripe_checkout_session +from daras_ai_v2.billing import stripe_subscription_checkout_redirect from gooey_ui import RedirectException from payments.plans import PricingPlan from server import app @@ -20,4 +20,4 @@ def test_create_checkout_session( return with pytest.raises(RedirectException): - create_stripe_checkout_session(force_authentication, plan) + stripe_subscription_checkout_redirect(force_authentication, plan) From b9df617ca3187a7e6ac3182fb85605ea50a6dfc7 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Sun, 14 Jul 2024 23:24:04 +0530 Subject: [PATCH 25/26] allow saved payment options in stripe addon checkout --- daras_ai_v2/billing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index c9bf26847..709fde926 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -474,6 +474,9 @@ def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int): customer=user.get_or_create_stripe_customer(), invoice_creation={"enabled": True}, allow_promotion_codes=True, + saved_payment_method_options={ + "payment_method_save": "enabled", + }, ) raise RedirectException(checkout_session.url, status_code=303) From e6b99558cf5889733a9df98af3ff606e36e29ff5 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 15 Jul 2024 00:38:13 +0530 Subject: [PATCH 26/26] remove network calls inside transaction.atomic() --- payments/webhooks.py | 253 +++++++++++++++++-------------------------- 1 file changed, 98 insertions(+), 155 deletions(-) diff --git a/payments/webhooks.py b/payments/webhooks.py index 8efc91492..b30cae120 100644 --- a/payments/webhooks.py +++ b/payments/webhooks.py @@ -1,7 +1,6 @@ import stripe from django.db import transaction from loguru import logger -from requests.models import HTTPError from app_users.models import AppUser, PaymentProvider, TransactionReason from daras_ai_v2 import paypal @@ -10,112 +9,6 @@ from .tasks import send_monthly_spending_notification_email -def add_balance_for_payment( - *, - uid: str, - amount: int, - invoice_id: str, - payment_provider: PaymentProvider, - charged_amount: int, - **kwargs, -): - user = AppUser.objects.get_or_create_from_uid(uid)[0] - user.add_balance( - amount=amount, - invoice_id=invoice_id, - charged_amount=charged_amount, - payment_provider=payment_provider, - **kwargs, - ) - - if not user.is_paying: - user.is_paying = True - user.save(update_fields=["is_paying"]) - - if ( - user.subscription - and user.subscription.should_send_monthly_spending_notification() - ): - send_monthly_spending_notification_email.delay(user.id) - - -@transaction.atomic -def _after_subscription_updated( - *, provider: PaymentProvider, uid: str, sub_id: str, plan: PricingPlan -): - if not is_sub_active(provider=provider, sub_id=sub_id): - # subscription is not in an active state, just ignore - logger.info( - "Subscription is not active. Ignoring event", - provider=provider, - sub_id=sub_id, - ) - return - - user = AppUser.objects.get_or_create_from_uid(uid)[0] - # select_for_update: we want to lock the row until we are done reading & - # updating the subscription - # - # anther transaction shouldn't update the subscription in the meantime - user: AppUser = AppUser.objects.select_for_update().get(pk=user.pk) - if not user.subscription: - # new subscription - logger.info("Creating new subscription for user", uid=uid) - user.subscription = Subscription.objects.get_or_create( - payment_provider=provider, - external_id=sub_id, - defaults={"plan": plan.db_value}, - )[0] - user.subscription.plan = plan.db_value - - elif is_same_sub(user.subscription, provider=provider, sub_id=sub_id): - if user.subscription.plan == plan.db_value: - # same subscription exists with the same plan in DB - logger.info("Nothing to do") - return - else: - # provider & sub_id is same, but plan is different. so we update only the plan - logger.info("Updating plan for user", uid=uid) - user.subscription.plan = plan.db_value - - else: - logger.critical( - "Invalid state: last subscription was not cleared for user", uid=uid - ) - - # we have a different existing subscription in DB - # this is invalid state! we should cancel the subscription if it is active - if is_sub_active( - provider=user.subscription.payment_provider, - sub_id=user.subscription.external_id, - ): - logger.critical( - "Found existing active subscription for user. Cancelling that...", - uid=uid, - provider=user.subscription.get_payment_provider_display(), - sub_id=user.subscription.external_id, - ) - user.subscription.cancel() - - logger.info("Creating new subscription for user", uid=uid) - user.subscription = Subscription( - payment_provider=provider, plan=plan, external_id=sub_id - ) - - user.subscription.full_clean() - user.subscription.save() - user.save(update_fields=["subscription"]) - - -def _after_subscription_cancelled(*, provider: PaymentProvider, uid: str, sub_id: str): - user = AppUser.objects.get_or_create_from_uid(uid=uid)[0] - if user.subscription and is_same_sub( - user.subscription, provider=provider, sub_id=sub_id - ): - user.subscription = None - user.save(update_fields=["subscription"]) - - class PaypalWebhookHandler: PROVIDER = PaymentProvider.PAYPAL @@ -158,17 +51,26 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription): assert pp_sub.plan_id, f"PayPal subscription {pp_sub.id} is missing plan ID" plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id) - assert plan, f"Plan {pp_sub.plan_id} not found" + assert plan, f"Plan with id={pp_sub.plan_id} not found" + + if pp_sub.status.lower() != "active": + logger.info( + "Subscription is not active. Ignoring event", subscription=pp_sub + ) + return - _after_subscription_updated( - provider=cls.PROVIDER, uid=pp_sub.custom_id, sub_id=pp_sub.id, plan=plan + _set_user_subscription( + provider=cls.PROVIDER, + plan=plan, + uid=pp_sub.custom_id, + external_id=pp_sub.id, ) @classmethod def handle_subscription_cancelled(cls, pp_sub: paypal.Subscription): assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid" - _after_subscription_cancelled( - provider=cls.PROVIDER, uid=pp_sub.custom_id, sub_id=pp_sub.id + _remove_subscription_for_user( + provider=cls.PROVIDER, uid=pp_sub.custom_id, external_id=pp_sub.id ) @@ -218,13 +120,12 @@ def handle_checkout_session_completed(cls, uid: str, session_data): sub_id ), f"subscription_id is missing in setup_intent metadata {setup_intent}" - if is_sub_active(provider=PaymentProvider.STRIPE, sub_id=sub_id): - stripe.Subscription.modify( - sub_id, default_payment_method=setup_intent.payment_method - ) + stripe.Subscription.modify( + sub_id, default_payment_method=setup_intent.payment_method + ) @classmethod - def handle_subscription_updated(cls, uid: str, stripe_sub): + def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription): logger.info(f"Stripe subscription updated: {stripe_sub.id}") assert stripe_sub.plan, f"Stripe subscription {stripe_sub.id} is missing plan" @@ -239,52 +140,94 @@ def handle_subscription_updated(cls, uid: str, stripe_sub): f"PricingPlan not found for product {stripe_sub.plan.product}" ) - _after_subscription_updated( - provider=cls.PROVIDER, uid=uid, sub_id=stripe_sub.id, plan=plan + if stripe_sub.status.lower() != "active": + logger.info( + "Subscription is not active. Ignoring event", subscription=stripe_sub + ) + return + + _set_user_subscription( + provider=cls.PROVIDER, + plan=plan, + uid=uid, + external_id=stripe_sub.id, ) @classmethod def handle_subscription_cancelled(cls, uid: str, stripe_sub): logger.info(f"Stripe subscription cancelled: {stripe_sub.id}") - - _after_subscription_cancelled( - provider=cls.PROVIDER, uid=uid, sub_id=stripe_sub.id + _remove_subscription_for_user( + provider=cls.PROVIDER, uid=uid, external_id=stripe_sub.id ) -def is_same_sub( - subscription: Subscription, *, provider: PaymentProvider, sub_id: str -) -> bool: - return ( - subscription.payment_provider == provider and subscription.external_id == sub_id +def add_balance_for_payment( + *, + uid: str, + amount: int, + invoice_id: str, + payment_provider: PaymentProvider, + charged_amount: int, + **kwargs, +): + user = AppUser.objects.get_or_create_from_uid(uid)[0] + user.add_balance( + amount=amount, + invoice_id=invoice_id, + charged_amount=charged_amount, + payment_provider=payment_provider, + **kwargs, ) + if not user.is_paying: + user.is_paying = True + user.save(update_fields=["is_paying"]) -def is_sub_active(*, provider: PaymentProvider, sub_id: str) -> bool: - match provider: - case PaymentProvider.PAYPAL: - try: - sub = paypal.Subscription.retrieve(sub_id) - except HTTPError as e: - if e.response.status_code != 404: - # if not 404, it likely means there is a bug in our code... - # we want to know about it, but not break the end user experience - logger.exception(f"Unexpected PayPal error for sub: {sub_id}") - return False - - return sub.status == "ACTIVE" - - case PaymentProvider.STRIPE: - try: - sub = stripe.Subscription.retrieve(sub_id) - except stripe.error.InvalidRequestError as e: - if e.http_status != 404: - # if not 404, it likely means there is a bug in our code... - # we want to know about it, but not break the end user experience - logger.exception(f"Unexpected Stripe error for sub: {sub_id}") - return False - except stripe.error.StripeError as e: - logger.exception(f"Unexpected Stripe error for sub: {sub_id}") - return False - - return sub.status == "active" + if ( + user.subscription + and user.subscription.should_send_monthly_spending_notification() + ): + send_monthly_spending_notification_email.delay(user.id) + + +def _set_user_subscription( + *, provider: PaymentProvider, plan: PricingPlan, uid: str, external_id: str +): + with transaction.atomic(): + subscription, created = Subscription.objects.get_or_create( + payment_provider=provider, + external_id=external_id, + defaults=dict(plan=plan.db_value), + ) + subscription.plan = plan.db_value + subscription.full_clean() + subscription.save() + + user = AppUser.objects.get_or_create_from_uid(uid)[0] + existing = user.subscription + + user.subscription = subscription + user.save(update_fields=["subscription"]) + + if not existing: + return + + # cancel existing subscription if it's not the same as the new one + if existing.external_id != external_id: + existing.cancel() + + # delete old db record if it exists + if existing.id != subscription.id: + existing.delete() + + +def _remove_subscription_for_user( + *, uid: str, provider: PaymentProvider, external_id: str +): + AppUser.objects.filter( + uid=uid, + subscription__payment_provider=provider, + subscription__external_id=external_id, + ).update( + subscription=None, + )