Skip to content

Commit

Permalink
Docs/refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
jveski committed Dec 16, 2023
1 parent 4063d25 commit 7cda050
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 89 deletions.
96 changes: 7 additions & 89 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,27 +259,31 @@ func newStripeCheckoutHandler(env *conf.Env, kc *keycloak.Keycloak, pc *stripeut

// No active payment - sign them up!
checkoutParams := &stripe.CheckoutSessionParams{
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
CustomerEmail: &user.Email,
SuccessURL: stripe.String(env.SelfURL + "/profile?i=" + etag),
CancelURL: stripe.String(env.SelfURL + "/profile"),
}
checkoutParams.Context = r.Context()

// Calculate specific pricing based on the member's profile
priceID := r.URL.Query().Get("price")
checkoutParams.Mode = stripe.String(string(stripe.CheckoutSessionModeSubscription))
checkoutParams.LineItems = calculateLineItems(user, priceID, pc)
checkoutParams.Discounts = calculateDiscount(user, priceID, pc)
if checkoutParams.Discounts == nil {
// Stripe API doesn't allow Discounts and AllowPromotionCodes to be set
checkoutParams.AllowPromotionCodes = stripe.Bool(true)
}
checkoutParams.LineItems = calculateLineItems(user, priceID, pc)

checkoutParams.SubscriptionData = &stripe.CheckoutSessionSubscriptionDataParams{
Metadata: map[string]string{"etag": etag},
BillingCycleAnchor: calculateBillingCycleAnchor(user),
BillingCycleAnchor: calculateBillingCycleAnchor(user), // This enables migration from paypal
}
if checkoutParams.SubscriptionData.BillingCycleAnchor != nil {
// In this case, the member is already paid up - don't make them pay for the currenet period again
checkoutParams.SubscriptionData.ProrationBehavior = stripe.String("none")
}

s, err := session.New(checkoutParams)
if err != nil {
renderSystemError(w, "error while creating session: %s", err)
Expand Down Expand Up @@ -379,7 +383,6 @@ func newStripeWebhookHandler(env *conf.Env, kc *keycloak.Keycloak, pc *stripeuti
}
}

// TODO: Handle user 404s to avoid constantly failing webhooks?
err = kc.UpdateUserStripeInfo(r.Context(), customer, sub)
if err != nil {
log.Printf("error while updating Keycloak for Stripe subscription webhook event: %s", err)
Expand All @@ -405,91 +408,6 @@ func onlyLeadership(next http.HandlerFunc) http.HandlerFunc {
}
}

func calculateLineItems(user *keycloak.User, priceID string, pc *stripeutil.PriceCache) []*stripe.CheckoutSessionLineItemParams {
// Migrate existing paypal users at their current rate
if priceID == "paypal" {
interval := "month"
if user.LastPaypalTransactionPrice > 50 {
interval = "year"
}

cents := user.LastPaypalTransactionPrice * 100
productID := pc.GetPrices()[0].ProductID // all prices reference the same product
return []*stripe.CheckoutSessionLineItemParams{{
Quantity: stripe.Int64(1),
PriceData: &stripe.CheckoutSessionLineItemPriceDataParams{
Currency: stripe.String("usd"),
Product: &productID,
UnitAmountDecimal: &cents,
Recurring: &stripe.CheckoutSessionLineItemPriceDataRecurringParams{
Interval: &interval,
},
},
}}
}

return []*stripe.CheckoutSessionLineItemParams{{
Price: stripe.String(priceID),
Quantity: stripe.Int64(1),
}}
}

func calculateDiscount(user *keycloak.User, priceID string, pc *stripeutil.PriceCache) []*stripe.CheckoutSessionDiscountParams {
if user.DiscountType == "" || priceID == "" {
return nil
}
for _, price := range pc.GetPrices() {
if price.ID == priceID && price.CouponIDs != nil && price.CouponIDs[user.DiscountType] != "" {
return []*stripe.CheckoutSessionDiscountParams{{
Coupon: stripe.String(price.CouponIDs[user.DiscountType]),
}}
}
}
return nil
}

func calculateDiscounts(user *keycloak.User, prices []*stripeutil.Price) []*stripeutil.Price {
if user.DiscountType == "" {
return prices
}
out := make([]*stripeutil.Price, len(prices))
for i, price := range prices {
amountOff := price.CouponAmountsOff[user.DiscountType]
out[i] = &stripeutil.Price{
ID: price.ID,
ProductID: price.ProductID,
Annual: price.Annual,
Price: price.Price - (float64(amountOff) / 100),
CouponIDs: price.CouponIDs,
CouponAmountsOff: price.CouponAmountsOff,
}
}
return out
}

func calculateBillingCycleAnchor(user *keycloak.User) *int64 {
if user.LastPaypalTransactionPrice == 0 {
return nil
}

var end time.Time
if user.LastPaypalTransactionPrice > 41 {
// Annual
end = user.LastPaypalTransactionTime.Add(time.Hour * 24 * 365)
} else {
// Monthly
end = user.LastPaypalTransactionTime.Add(time.Hour * 24 * 30)
}

// Stripe will throw an error if the cycle anchor is before the current time
if time.Until(end) < time.Minute {
return nil
}

ts := end.Unix()
return &ts
}

// getUserID allows the oauth2proxy header to be overridden for testing.
func getUserID(r *http.Request) string {
user := r.Header.Get("X-Forwarded-Preferred-Username")
Expand Down
95 changes: 95 additions & 0 deletions pricing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package main

import (
"time"

"github.com/stripe/stripe-go/v75"

"github.com/TheLab-ms/profile/internal/keycloak"
"github.com/TheLab-ms/profile/internal/stripeutil"
)

func calculateLineItems(user *keycloak.User, priceID string, pc *stripeutil.PriceCache) []*stripe.CheckoutSessionLineItemParams {
// Migrate existing paypal users at their current rate
if priceID == "paypal" {
interval := "month"
if user.LastPaypalTransactionPrice > 50 {
interval = "year"
}

cents := user.LastPaypalTransactionPrice * 100
productID := pc.GetPrices()[0].ProductID // all prices reference the same product
return []*stripe.CheckoutSessionLineItemParams{{
Quantity: stripe.Int64(1),
PriceData: &stripe.CheckoutSessionLineItemPriceDataParams{
Currency: stripe.String("usd"),
Product: &productID,
UnitAmountDecimal: &cents,
Recurring: &stripe.CheckoutSessionLineItemPriceDataRecurringParams{
Interval: &interval,
},
},
}}
}

return []*stripe.CheckoutSessionLineItemParams{{
Price: stripe.String(priceID),
Quantity: stripe.Int64(1),
}}
}

func calculateDiscount(user *keycloak.User, priceID string, pc *stripeutil.PriceCache) []*stripe.CheckoutSessionDiscountParams {
if user.DiscountType == "" || priceID == "" {
return nil
}
for _, price := range pc.GetPrices() {
if price.ID == priceID && price.CouponIDs != nil && price.CouponIDs[user.DiscountType] != "" {
return []*stripe.CheckoutSessionDiscountParams{{
Coupon: stripe.String(price.CouponIDs[user.DiscountType]),
}}
}
}
return nil
}

func calculateDiscounts(user *keycloak.User, prices []*stripeutil.Price) []*stripeutil.Price {
if user.DiscountType == "" {
return prices
}
out := make([]*stripeutil.Price, len(prices))
for i, price := range prices {
amountOff := price.CouponAmountsOff[user.DiscountType]
out[i] = &stripeutil.Price{
ID: price.ID,
ProductID: price.ProductID,
Annual: price.Annual,
Price: price.Price - (float64(amountOff) / 100),
CouponIDs: price.CouponIDs,
CouponAmountsOff: price.CouponAmountsOff,
}
}
return out
}

func calculateBillingCycleAnchor(user *keycloak.User) *int64 {
if user.LastPaypalTransactionPrice == 0 {
return nil
}

var end time.Time
if user.LastPaypalTransactionPrice > 41 {
// Annual
end = user.LastPaypalTransactionTime.Add(time.Hour * 24 * 365)
} else {
// Monthly
end = user.LastPaypalTransactionTime.Add(time.Hour * 24 * 30)
}

// Stripe will throw an error if the cycle anchor is before the current time
if time.Until(end) < time.Minute {
return nil
}

ts := end.Unix()
return &ts
}

0 comments on commit 7cda050

Please sign in to comment.