diff --git a/pinax/stripe/actions/plans.py b/pinax/stripe/actions/plans.py index b304a5110..48d9c2911 100644 --- a/pinax/stripe/actions/plans.py +++ b/pinax/stripe/actions/plans.py @@ -29,7 +29,9 @@ def sync_plan(plan, event=None): "name": plan["name"], "statement_descriptor": plan["statement_descriptor"] or "", "trial_period_days": plan["trial_period_days"], - "metadata": plan["metadata"] + "metadata": plan["metadata"], + "billing_scheme": plan["billing_scheme"], + "tiers_mode": plan["tiers_mode"] } obj, created = models.Plan.objects.get_or_create( @@ -37,3 +39,14 @@ def sync_plan(plan, event=None): defaults=defaults ) utils.update_with_defaults(obj, defaults, created) + + if plan["tiers"]: + obj.tiers.all().delete() # delete all tiers, since they don't have ids in Stripe + for tier in plan["tiers"]: + tier_obj = models.Tier.objects.create( + plan=obj, + amount=utils.convert_amount_for_db(tier["amount"], plan["currency"]), + flat_amount=utils.convert_amount_for_db(tier["flat_amount"], plan["currency"]), + up_to=tier["up_to"] + ) + obj.tiers.add(tier_obj) diff --git a/pinax/stripe/managers.py b/pinax/stripe/managers.py index bfc849926..a4396ddfd 100644 --- a/pinax/stripe/managers.py +++ b/pinax/stripe/managers.py @@ -71,3 +71,43 @@ def paid_totals_for(self, year, month): total_amount=models.Sum("amount"), total_refunded=models.Sum("amount_refunded") ) + + +class TieredPricingManager(models.Manager): + + TIERS_MODE_VOLUME = "volume" + TIERS_MODE_GRADUATED = "graduated" + TIERS_MODES = (TIERS_MODE_VOLUME, TIERS_MODE_GRADUATED) + + def closed_tiers(self, plan): + return self.filter(plan=plan, up_to__isnull=False).order_by("up_to") + + def open_tiers(self, plan): + return self.filter(plan=plan, up_to__isnull=True) + + def all_tiers(self, plan): + return list(self.closed_tiers(plan)) + list(self.open_tiers(plan)) + + def calculate_final_cost(self, plan, quantity, mode): + if mode not in self.TIERS_MODES: + raise Exception("Received wrong type of mode ({})".format(mode)) + + all_tiers = self.all_tiers(plan) + cost = 0 + if mode == self.TIERS_MODE_VOLUME: + applicable_tiers = list(filter(lambda t: not t.up_to or quantity <= t.up_to, all_tiers)) + tier = applicable_tiers[0] if applicable_tiers else all_tiers[-1] + cost = tier.calculate_cost(quantity) + + if mode == self.TIERS_MODE_GRADUATED: + quantity_billed = 0 + idx = 0 + while quantity > 0: + tier = all_tiers[idx] + quantity_to_bill = min(quantity, tier.up_to - quantity_billed) if tier.up_to else quantity + cost += tier.calculate_cost(quantity_to_bill) + quantity -= quantity_to_bill + quantity_billed += quantity_to_bill + idx += 1 + + return cost diff --git a/pinax/stripe/migrations/0015_auto_20190203_0949.py b/pinax/stripe/migrations/0015_auto_20190203_0949.py new file mode 100644 index 000000000..0b0a2bf6d --- /dev/null +++ b/pinax/stripe/migrations/0015_auto_20190203_0949.py @@ -0,0 +1,34 @@ +# Generated by Django 2.2a1 on 2019-02-03 14:49 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('pinax_stripe', '0014_auto_20180413_1959'), + ] + + operations = [ + migrations.AddField( + model_name='plan', + name='billing_scheme', + field=models.CharField(default='per_unit', max_length=15), + ), + migrations.AddField( + model_name='plan', + name='tiers_mode', + field=models.CharField(blank=True, max_length=15, null=True), + ), + migrations.CreateModel( + name='Tier', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('amount', models.DecimalField(decimal_places=2, max_digits=9)), + ('flat_amount', models.DecimalField(decimal_places=2, max_digits=9)), + ('up_to', models.IntegerField(blank=True, null=True)), + ('plan', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='tiers', to='pinax_stripe.Plan')), + ], + ), + ] diff --git a/pinax/stripe/models.py b/pinax/stripe/models.py index 584f34e56..3c9bcdc32 100644 --- a/pinax/stripe/models.py +++ b/pinax/stripe/models.py @@ -14,7 +14,7 @@ from jsonfield.fields import JSONField from .conf import settings -from .managers import ChargeManager, CustomerManager +from .managers import ChargeManager, CustomerManager, TieredPricingManager from .utils import CURRENCY_SYMBOLS @@ -76,6 +76,9 @@ def stripe_account_stripe_id(self): @python_2_unicode_compatible class Plan(UniquePerAccountStripeObject): + BILLING_SCHEME_PER_UNIT = "per_unit" + BILLING_SCHEME_TIERED = "tiered" + amount = models.DecimalField(decimal_places=2, max_digits=9) currency = models.CharField(max_length=15, blank=False) interval = models.CharField(max_length=15) @@ -84,6 +87,8 @@ class Plan(UniquePerAccountStripeObject): statement_descriptor = models.TextField(blank=True) trial_period_days = models.IntegerField(null=True, blank=True) metadata = JSONField(null=True, blank=True) + billing_scheme = models.CharField(max_length=15, default=BILLING_SCHEME_PER_UNIT) + tiers_mode = models.CharField(max_length=15, null=True, blank=True) def __str__(self): return "{} ({}{})".format(self.name, CURRENCY_SYMBOLS.get(self.currency, ""), self.amount) @@ -107,6 +112,14 @@ def stripe_plan(self): stripe_account=self.stripe_account_stripe_id, ) + def calculate_total_amount(self, quantity): + if self.billing_scheme == self.BILLING_SCHEME_PER_UNIT: + return self.amount * quantity + elif self.billing_scheme == self.BILLING_SCHEME_TIERED: + return Tier.pricing.calculate_final_cost(self, quantity, self.tiers_mode) + else: + raise Exception("The Plan ({}) received the wrong type of billing_scheme ({})".format(self.name, self.billing_scheme)) + @python_2_unicode_compatible class Coupon(StripeObject): @@ -379,7 +392,7 @@ def stripe_subscription(self): @property def total_amount(self): - return self.plan.amount * self.quantity + return self.plan.calculate_total_amount(self.quantity) def plan_display(self): return self.plan.name @@ -653,3 +666,17 @@ def stripe_bankaccount(self): return self.account.stripe_account.external_accounts.retrieve( self.stripe_id ) + + +class Tier(models.Model): + + plan = models.ForeignKey(Plan, related_name="tiers", on_delete=models.CASCADE) + amount = models.DecimalField(decimal_places=2, max_digits=9) + flat_amount = models.DecimalField(decimal_places=2, max_digits=9) + up_to = models.IntegerField(blank=True, null=True) + + objects = models.Manager() + pricing = TieredPricingManager() + + def calculate_cost(self, quantity): + return (self.amount * quantity) + self.flat_amount diff --git a/pinax/stripe/tests/__init__.py b/pinax/stripe/tests/__init__.py index be404df1a..f5a90faab 100644 --- a/pinax/stripe/tests/__init__.py +++ b/pinax/stripe/tests/__init__.py @@ -146,7 +146,10 @@ "currency": "usd", "created": 1498573686, "name": "Pro Plan", - "metadata": {} + "metadata": {}, + "billing_scheme": "per_unit", + "tiers_mode": None, + "tiers": None } }, "type": "plan.updated", diff --git a/pinax/stripe/tests/test_actions.py b/pinax/stripe/tests/test_actions.py index 4d5dc1b40..1813ba186 100644 --- a/pinax/stripe/tests/test_actions.py +++ b/pinax/stripe/tests/test_actions.py @@ -1212,7 +1212,10 @@ def test_sync_plans(self, PlanAutoPagerMock): "metadata": {}, "name": "The Pro Plan", "statement_descriptor": "ALTMAN", - "trial_period_days": 3 + "trial_period_days": 3, + "billing_scheme": "per_unit", + "tiers_mode": None, + "tiers": None }, { "id": "simple1", @@ -1226,12 +1229,45 @@ def test_sync_plans(self, PlanAutoPagerMock): "metadata": {}, "name": "The Simple Plan", "statement_descriptor": "ALTMAN", - "trial_period_days": 3 + "trial_period_days": 3, + "billing_scheme": "per_unit", + "tiers_mode": None, + "tiers": None + }, + { + "id": "tiered1", + "object": "plan", + "amount": None, + "created": 1448121054, + "currency": "usd", + "interval": "month", + "interval_count": 1, + "livemode": False, + "metadata": {}, + "name": "The Simple Plan", + "statement_descriptor": "ALTMAN", + "trial_period_days": 3, + "billing_scheme": "tiered", + "tiers_mode": "test", + "tiers": [ + { + "amount": None, + "flat_amount": 14900, + "up_to": 100 + }, + { + "amount": 100, + "flat_amount": None, + "up_to": None + } + ], }, ] plans.sync_plans() - self.assertTrue(Plan.objects.all().count(), 2) + self.assertEqual(Plan.objects.all().count(), len(PlanAutoPagerMock.return_value)) self.assertEqual(Plan.objects.get(stripe_id="simple1").amount, decimal.Decimal("9.99")) + self.assertTrue(Plan.objects.filter(stripe_id="tiered1").exists()) + self.assertEqual(Plan.objects.get(stripe_id="tiered1").tiers.count(), 2) @patch("stripe.Plan.auto_paging_iter", create=True) def test_sync_plans_update(self, PlanAutoPagerMock): @@ -1248,7 +1284,10 @@ def test_sync_plans_update(self, PlanAutoPagerMock): "metadata": {}, "name": "The Pro Plan", "statement_descriptor": "ALTMAN", - "trial_period_days": 3 + "trial_period_days": 3, + "billing_scheme": "per_unit", + "tiers_mode": None, + "tiers": None }, { "id": "simple1", @@ -1262,15 +1301,47 @@ def test_sync_plans_update(self, PlanAutoPagerMock): "metadata": {}, "name": "The Simple Plan", "statement_descriptor": "ALTMAN", - "trial_period_days": 3 + "trial_period_days": 3, + "billing_scheme": "per_unit", + "tiers_mode": None, + "tiers": None + }, + { + "id": "tiered1", + "object": "plan", + "amount": None, + "created": 1448121054, + "currency": "usd", + "interval": "month", + "interval_count": 1, + "livemode": False, + "metadata": {}, + "name": "The Simple Plan", + "statement_descriptor": "ALTMAN", + "trial_period_days": 3, + "billing_scheme": "tiered", + "tiers_mode": "test", + "tiers": [ + { + "amount": None, + "flat_amount": 14900, + "up_to": 100 + }, + { + "amount": 100, + "flat_amount": None, + "up_to": None + } + ], }, ] plans.sync_plans() - self.assertTrue(Plan.objects.all().count(), 2) + self.assertEqual(Plan.objects.all().count(), len(PlanAutoPagerMock.return_value)) self.assertEqual(Plan.objects.get(stripe_id="simple1").amount, decimal.Decimal("9.99")) PlanAutoPagerMock.return_value[1].update({"amount": 499}) plans.sync_plans() self.assertEqual(Plan.objects.get(stripe_id="simple1").amount, decimal.Decimal("4.99")) + self.assertEqual(Plan.objects.get(stripe_id="tiered1").tiers.count(), 2) def test_sync_plan(self): """ @@ -1295,7 +1366,10 @@ def test_sync_plan(self): "metadata": {}, "name": "Gold Plan", "statement_descriptor": "ALTMAN", - "trial_period_days": 3 + "trial_period_days": 3, + "billing_scheme": "per_unit", + "tiers_mode": None, + "tiers": None } plans.sync_plan(plan) self.assertTrue(Plan.objects.all().count(), 1) diff --git a/pinax/stripe/tests/test_commands.py b/pinax/stripe/tests/test_commands.py index 51f520e02..f76d98109 100644 --- a/pinax/stripe/tests/test_commands.py +++ b/pinax/stripe/tests/test_commands.py @@ -43,7 +43,10 @@ def test_plans_create(self, PlanAutoPagerMock): "statement_descriptor": None, "trial_period_days": None, "name": "Pro", - "metadata": {} + "metadata": {}, + "billing_scheme": "per_unit", + "tiers_mode": None, + "tiers": None }] management.call_command("sync_plans") self.assertEqual(Plan.objects.count(), 1) diff --git a/pinax/stripe/tests/test_managers.py b/pinax/stripe/tests/test_managers.py index 5f1b787e7..757e136b0 100644 --- a/pinax/stripe/tests/test_managers.py +++ b/pinax/stripe/tests/test_managers.py @@ -5,7 +5,7 @@ from django.test import TestCase from django.utils import timezone -from ..models import Charge, Customer, Plan, Subscription +from ..models import Charge, Customer, Plan, Subscription, Tier class CustomerManagerTest(TestCase): @@ -193,3 +193,64 @@ def test_paid_totals_for_dec(self): totals = Charge.objects.paid_totals_for(2013, 12) self.assertEqual(totals["total_amount"], None) self.assertEqual(totals["total_refunded"], None) + + +class TieredPricingManagerTests(TestCase): + def setUp(self): + self.plan = Plan.objects.create( + stripe_id="plan", amount=0, interval="monthly", interval_count=1, billing_scheme=Plan.BILLING_SCHEME_TIERED + ) + Tier.objects.create(plan=self.plan, up_to=5, amount=5, flat_amount=10) + Tier.objects.create(plan=self.plan, up_to=10, amount=4, flat_amount=20) + Tier.objects.create(plan=self.plan, up_to=15, amount=3, flat_amount=30) + Tier.objects.create(plan=self.plan, up_to=20, amount=2, flat_amount=40) + Tier.objects.create(plan=self.plan, up_to=None, amount=1, flat_amount=50) + + def test_calculate_final_cost_with_volume_tiers_mode(self): + test_cases = [ + (1, 5), + (5, 25), + (6, 24), + (20, 40), + (25, 25), + ] + self.plan.tiers.all().update(flat_amount=0) + for quantity, expected in test_cases: + cost = Tier.pricing.calculate_final_cost(self.plan, quantity, Tier.pricing.TIERS_MODE_VOLUME) + self.assertEqual(cost, expected) + + def test_calculate_final_cost_with_graduated_tiers_mode(self): + test_cases = [ + (1, 5), + (5, 25), + (6, 29), + (20, 70), + (25, 75), + ] + self.plan.tiers.all().update(flat_amount=0) + for quantity, expected in test_cases: + cost = Tier.pricing.calculate_final_cost(self.plan, quantity, Tier.pricing.TIERS_MODE_GRADUATED) + self.assertEqual(cost, expected) + + def test_calculate_final_cost_with_volume_tiers_and_flat_fees(self): + test_cases = [ + (12, 66) + ] + for quantity, expected in test_cases: + cost = Tier.pricing.calculate_final_cost(self.plan, quantity, Tier.pricing.TIERS_MODE_VOLUME) + self.assertEqual(cost, expected) + + def test_calculate_final_cost_with_graduated_tiers_and_flat_fees(self): + test_cases = [ + (12, 111) + ] + for quantity, expected in test_cases: + cost = Tier.pricing.calculate_final_cost(self.plan, quantity, Tier.pricing.TIERS_MODE_GRADUATED) + self.assertEqual(cost, expected) + + def test_calculate_final_cost_with_invalid_tier(self): + try: + Tier.pricing.calculate_final_cost(self.plan, 1, "invalid") + self.fail("Excepted an exception from calculate_total_amount") + except: + pass diff --git a/pinax/stripe/tests/test_models.py b/pinax/stripe/tests/test_models.py index ea9564eb9..046859b51 100644 --- a/pinax/stripe/tests/test_models.py +++ b/pinax/stripe/tests/test_models.py @@ -26,6 +26,7 @@ InvoiceItem, Plan, Subscription, + Tier, Transfer, UserAccount ) @@ -74,6 +75,28 @@ def test_plan_stripe_plan_with_account(self, RetrieveMock): self.assertTrue(RetrieveMock.call_args_list, [ call("plan", stripe_account="acct_A")]) + def test_plan_calculate_total_amount_per_unit_billing_scheme(self): + quantity = 10 + p = Plan(amount=decimal.Decimal("5"), stripe_id="plan", billing_scheme=Plan.BILLING_SCHEME_PER_UNIT) + self.assertEqual(p.calculate_total_amount(quantity), decimal.Decimal("50")) + + @patch("pinax.stripe.models.Tier.pricing") + def test_plan_calculate_total_amount_tiered_billing_scheme(self, TierPricingMock): + quantity = 10 + p = Plan(amount=0, stripe_id="plan", billing_scheme=Plan.BILLING_SCHEME_TIERED) + p.calculate_total_amount(quantity) + TierPricingMock.calculate_final_cost.assert_called_with(p, quantity, p.tiers_mode) + + @patch("pinax.stripe.models.Tier.pricing") + def test_plan_calculate_total_amount_raises_exception_for_invalid_billing_scheme(self, TierPricingMock): + quantity = 10 + p = Plan(amount=0, stripe_id="plan", billing_scheme="unknown") + try: + p.calculate_total_amount(quantity) + self.fail("Excepted an exception from calculate_total_amount") + except: + pass + def test_plan_per_account(self): Plan.objects.create(stripe_id="plan", amount=decimal.Decimal("100"), interval="monthly", interval_count=1) account = Account.objects.create(stripe_id="acct_A") @@ -302,6 +325,12 @@ def test_blank_with_null(self): if f.null: self.assertTrue(f.blank, msg="%s.%s should be blank=True" % (klass.__name__, f.name)) + def test_tier_calculate_cost(self): + quantity = 12 + p = Plan.objects.create(stripe_id="plan", amount=0, interval="monthly", interval_count=1) + t = Tier(plan=p, amount=4, flat_amount=20) + self.assertEqual(t.calculate_cost(quantity), 68) + class StripeObjectTests(TestCase): diff --git a/pinax/stripe/tests/test_utils.py b/pinax/stripe/tests/test_utils.py index fd21f3a47..f44b5c9e2 100644 --- a/pinax/stripe/tests/test_utils.py +++ b/pinax/stripe/tests/test_utils.py @@ -59,6 +59,11 @@ def test_convert_amount_for_db_none_currency(self): actual = convert_amount_for_db(999, currency=None) self.assertEqual(expected, actual) + def test_convert_amount_for_db_none_amount(self): + expected = decimal.Decimal("0.00") + actual = convert_amount_for_db(None) + self.assertEquals(expected, actual) + class ConvertAmountForApiTests(TestCase): diff --git a/pinax/stripe/utils.py b/pinax/stripe/utils.py index 648681f38..44c9883e0 100644 --- a/pinax/stripe/utils.py +++ b/pinax/stripe/utils.py @@ -33,6 +33,8 @@ def convert_tstamp(response, field_name=None): def convert_amount_for_db(amount, currency="usd"): if currency is None: # @@@ not sure if this is right; find out what we should do when API returns null for currency currency = "usd" + if amount is None: # @@@ not sure if this is right; find out what we should do when API returns null for amount + amount = 0 return (amount / decimal.Decimal("100")) if currency.lower() not in ZERO_DECIMAL_CURRENCIES else decimal.Decimal(amount) diff --git a/setup.py b/setup.py index 31fc32bfc..60df7eca0 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ tests_require = [ "mock", - "pytest", + "pytest!=4.2.0,>=3.6", "pytest-django", ]