diff --git a/main.go b/main.go index fc2aeb6..1222c87 100644 --- a/main.go +++ b/main.go @@ -259,27 +259,31 @@ func newStripeCheckoutHandler(env *conf.Env, kc *keycloak.Keycloak, pc *stripeut // No active payment - sign them up! checkoutParams := &stripe.CheckoutSessionParams{ + Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), CustomerEmail: &user.Email, SuccessURL: stripe.String(env.SelfURL + "/profile?i=" + etag), CancelURL: stripe.String(env.SelfURL + "/profile"), } checkoutParams.Context = r.Context() + // Calculate specific pricing based on the member's profile priceID := r.URL.Query().Get("price") - checkoutParams.Mode = stripe.String(string(stripe.CheckoutSessionModeSubscription)) + checkoutParams.LineItems = calculateLineItems(user, priceID, pc) checkoutParams.Discounts = calculateDiscount(user, priceID, pc) if checkoutParams.Discounts == nil { // Stripe API doesn't allow Discounts and AllowPromotionCodes to be set checkoutParams.AllowPromotionCodes = stripe.Bool(true) } - checkoutParams.LineItems = calculateLineItems(user, priceID, pc) + checkoutParams.SubscriptionData = &stripe.CheckoutSessionSubscriptionDataParams{ Metadata: map[string]string{"etag": etag}, - BillingCycleAnchor: calculateBillingCycleAnchor(user), + BillingCycleAnchor: calculateBillingCycleAnchor(user), // This enables migration from paypal } if checkoutParams.SubscriptionData.BillingCycleAnchor != nil { + // In this case, the member is already paid up - don't make them pay for the currenet period again checkoutParams.SubscriptionData.ProrationBehavior = stripe.String("none") } + s, err := session.New(checkoutParams) if err != nil { renderSystemError(w, "error while creating session: %s", err) @@ -379,7 +383,6 @@ func newStripeWebhookHandler(env *conf.Env, kc *keycloak.Keycloak, pc *stripeuti } } - // TODO: Handle user 404s to avoid constantly failing webhooks? err = kc.UpdateUserStripeInfo(r.Context(), customer, sub) if err != nil { log.Printf("error while updating Keycloak for Stripe subscription webhook event: %s", err) @@ -405,91 +408,6 @@ func onlyLeadership(next http.HandlerFunc) http.HandlerFunc { } } -func calculateLineItems(user *keycloak.User, priceID string, pc *stripeutil.PriceCache) []*stripe.CheckoutSessionLineItemParams { - // Migrate existing paypal users at their current rate - if priceID == "paypal" { - interval := "month" - if user.LastPaypalTransactionPrice > 50 { - interval = "year" - } - - cents := user.LastPaypalTransactionPrice * 100 - productID := pc.GetPrices()[0].ProductID // all prices reference the same product - return []*stripe.CheckoutSessionLineItemParams{{ - Quantity: stripe.Int64(1), - PriceData: &stripe.CheckoutSessionLineItemPriceDataParams{ - Currency: stripe.String("usd"), - Product: &productID, - UnitAmountDecimal: ¢s, - Recurring: &stripe.CheckoutSessionLineItemPriceDataRecurringParams{ - Interval: &interval, - }, - }, - }} - } - - return []*stripe.CheckoutSessionLineItemParams{{ - Price: stripe.String(priceID), - Quantity: stripe.Int64(1), - }} -} - -func calculateDiscount(user *keycloak.User, priceID string, pc *stripeutil.PriceCache) []*stripe.CheckoutSessionDiscountParams { - if user.DiscountType == "" || priceID == "" { - return nil - } - for _, price := range pc.GetPrices() { - if price.ID == priceID && price.CouponIDs != nil && price.CouponIDs[user.DiscountType] != "" { - return []*stripe.CheckoutSessionDiscountParams{{ - Coupon: stripe.String(price.CouponIDs[user.DiscountType]), - }} - } - } - return nil -} - -func calculateDiscounts(user *keycloak.User, prices []*stripeutil.Price) []*stripeutil.Price { - if user.DiscountType == "" { - return prices - } - out := make([]*stripeutil.Price, len(prices)) - for i, price := range prices { - amountOff := price.CouponAmountsOff[user.DiscountType] - out[i] = &stripeutil.Price{ - ID: price.ID, - ProductID: price.ProductID, - Annual: price.Annual, - Price: price.Price - (float64(amountOff) / 100), - CouponIDs: price.CouponIDs, - CouponAmountsOff: price.CouponAmountsOff, - } - } - return out -} - -func calculateBillingCycleAnchor(user *keycloak.User) *int64 { - if user.LastPaypalTransactionPrice == 0 { - return nil - } - - var end time.Time - if user.LastPaypalTransactionPrice > 41 { - // Annual - end = user.LastPaypalTransactionTime.Add(time.Hour * 24 * 365) - } else { - // Monthly - end = user.LastPaypalTransactionTime.Add(time.Hour * 24 * 30) - } - - // Stripe will throw an error if the cycle anchor is before the current time - if time.Until(end) < time.Minute { - return nil - } - - ts := end.Unix() - return &ts -} - // getUserID allows the oauth2proxy header to be overridden for testing. func getUserID(r *http.Request) string { user := r.Header.Get("X-Forwarded-Preferred-Username") diff --git a/pricing.go b/pricing.go new file mode 100644 index 0000000..ba90f2c --- /dev/null +++ b/pricing.go @@ -0,0 +1,95 @@ +package main + +import ( + "time" + + "github.com/stripe/stripe-go/v75" + + "github.com/TheLab-ms/profile/internal/keycloak" + "github.com/TheLab-ms/profile/internal/stripeutil" +) + +func calculateLineItems(user *keycloak.User, priceID string, pc *stripeutil.PriceCache) []*stripe.CheckoutSessionLineItemParams { + // Migrate existing paypal users at their current rate + if priceID == "paypal" { + interval := "month" + if user.LastPaypalTransactionPrice > 50 { + interval = "year" + } + + cents := user.LastPaypalTransactionPrice * 100 + productID := pc.GetPrices()[0].ProductID // all prices reference the same product + return []*stripe.CheckoutSessionLineItemParams{{ + Quantity: stripe.Int64(1), + PriceData: &stripe.CheckoutSessionLineItemPriceDataParams{ + Currency: stripe.String("usd"), + Product: &productID, + UnitAmountDecimal: ¢s, + Recurring: &stripe.CheckoutSessionLineItemPriceDataRecurringParams{ + Interval: &interval, + }, + }, + }} + } + + return []*stripe.CheckoutSessionLineItemParams{{ + Price: stripe.String(priceID), + Quantity: stripe.Int64(1), + }} +} + +func calculateDiscount(user *keycloak.User, priceID string, pc *stripeutil.PriceCache) []*stripe.CheckoutSessionDiscountParams { + if user.DiscountType == "" || priceID == "" { + return nil + } + for _, price := range pc.GetPrices() { + if price.ID == priceID && price.CouponIDs != nil && price.CouponIDs[user.DiscountType] != "" { + return []*stripe.CheckoutSessionDiscountParams{{ + Coupon: stripe.String(price.CouponIDs[user.DiscountType]), + }} + } + } + return nil +} + +func calculateDiscounts(user *keycloak.User, prices []*stripeutil.Price) []*stripeutil.Price { + if user.DiscountType == "" { + return prices + } + out := make([]*stripeutil.Price, len(prices)) + for i, price := range prices { + amountOff := price.CouponAmountsOff[user.DiscountType] + out[i] = &stripeutil.Price{ + ID: price.ID, + ProductID: price.ProductID, + Annual: price.Annual, + Price: price.Price - (float64(amountOff) / 100), + CouponIDs: price.CouponIDs, + CouponAmountsOff: price.CouponAmountsOff, + } + } + return out +} + +func calculateBillingCycleAnchor(user *keycloak.User) *int64 { + if user.LastPaypalTransactionPrice == 0 { + return nil + } + + var end time.Time + if user.LastPaypalTransactionPrice > 41 { + // Annual + end = user.LastPaypalTransactionTime.Add(time.Hour * 24 * 365) + } else { + // Monthly + end = user.LastPaypalTransactionTime.Add(time.Hour * 24 * 30) + } + + // Stripe will throw an error if the cycle anchor is before the current time + if time.Until(end) < time.Minute { + return nil + } + + ts := end.Unix() + return &ts +}