Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: handle plan transition logic for "Upgrade" and "Downgrade" on Payment #974

Merged
merged 6 commits into from
Aug 24, 2024
Merged
2 changes: 2 additions & 0 deletions api/v1/models/billing_plan.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# app/models/billing_plan.py
from sqlalchemy import Column, String, ARRAY, ForeignKey, Numeric, Boolean
from sqlalchemy.orm import relationship
from sqlalchemy import DateTime
from api.v1.models.base_model import BaseTableModel


Expand Down Expand Up @@ -34,3 +35,4 @@ class UserSubscription(BaseTableModel):
user = relationship("User", back_populates="subscriptions")
billing_plan = relationship("BillingPlan", back_populates="user_subscriptions")
organisation = relationship("Organisation", back_populates="user_subscriptions")
billing_cycle = Column(DateTime, nullable=True)
39 changes: 25 additions & 14 deletions api/v1/routes/stripe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi import APIRouter, Depends, HTTPException, Request, status, Query
from sqlalchemy.orm import Session
import stripe
from api.v1.services.stripe_payment import stripe_payment_request, \
Expand Down Expand Up @@ -32,10 +32,10 @@ def stripe_payment(
db: Session = Depends(get_db),
current_user: User = Depends(user_service.get_current_user)
):
return stripe_payment_request(db, plan_upgrade_request.user_id, request, plan_upgrade_request.plan_name)
return stripe_payment_request(db, plan_upgrade_request.user_id, request, plan_upgrade_request.plan_id)

@subscription_.get("/stripe/success")
def success_upgrade(session_id: str):
def success_upgrade(session_id: str= Query(...)):
return success_response(
status_code=status.HTTP_200_OK,
message="Payment intent initiated. Please verify the payment using the session ID.",
Expand All @@ -53,17 +53,17 @@ async def verify_payment(session_id: str, db: Session = Depends(get_db)):
if session.payment_status == "paid":
# If payment was successful, update the user's plan
user_id = session.metadata["user_id"]
plan_name = session.metadata["plan_name"]
print(user_id, plan_name)
await update_user_plan(db, user_id, plan_name)

return { "status": "SUCCESS" }

# return success_response(
# status_code=status.HTTP_200_OK,
# message="Payment successful and plan updated.",
# data={"session_id": session_id, "payment_status": session.payment_status}
# )
plan_id = session.metadata["plan_id"]
print(user_id, plan_id)
await update_user_plan(db, user_id, plan_id)
#TODO Remember to uncomme
# return { "status": "SUCCESS" }

return success_response(
status_code=status.HTTP_200_OK,
message="Payment successful and plan updated.",
data={"status": "SUCCESS", "session_id": session_id, "payment_status": session.payment_status}
)
else:
return fail_response(
status_code=status.HTTP_400_BAD_REQUEST,
Expand All @@ -77,6 +77,17 @@ async def verify_payment(session_id: str, db: Session = Depends(get_db)):
raise HTTPException(status_code=500, detail=str(e))


@subscription_.post("/stripe/change-plan")
def change_plan(
plan_upgrade_request: PlanUpgradeRequest,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(user_service.get_current_user)
):
is_downgrade = plan_upgrade_request.is_downgrade
return update_user_plan(db, plan_upgrade_request.user_id, plan_upgrade_request.plan_id, is_downgrade=is_downgrade)


@subscription_.get("/stripe/cancel")
def cancel_upgrade():

Expand Down
5 changes: 3 additions & 2 deletions api/v1/schemas/stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def cvc_validator(cls, v):

class PlanUpgradeRequest(BaseModel):
user_id: str
plan_name: str
payment_info: Optional[PaymentInfo] = None
plan_id: str
is_downgrade: bool
#payment_info: Optional[PaymentInfo] = None


2 changes: 1 addition & 1 deletion api/v1/services/billing_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def create(self, db: Session, request: CreateBillingPlanSchema):

# Adjust the price if the duration is 'yearly'
if request.duration == "yearly":
request.price = request.price * 12 * 0.8 # Apply yearly discount
request.price = request.price * 12 * 0.8 # Apply yearly discount of 20%

# Create a BillingPlan instance using the modified request
plan = BillingPlan(**request.dict())
Expand Down
182 changes: 115 additions & 67 deletions api/v1/services/stripe_payment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,78 @@
from api.v1.models.user import User
from api.v1.models.billing_plan import BillingPlan, UserSubscription
from api.v1.models.organisation import Organisation
from api.v1.models.payment import Payment
import stripe
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import select, join
from fastapi.encoders import jsonable_encoder
from api.utils.success_response import success_response, fail_response
import os
from sqlalchemy import cast, DateTime
from fastapi import HTTPException, status, Request
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta

stripe.api_key = os.getenv('STRIPE_SECRET_KEY')



def get_plan_by_id(db: Session, plan_id: str):
return db.query(BillingPlan).filter(BillingPlan.id == plan_id).first()


def convert_duration_to_timedelta(duration: str) -> timedelta:
if duration == "monthly":
return timedelta(days=30) # Approximate month length
elif duration == "yearly":
return timedelta(days=365) # Approximate year length
else:
raise ValueError("Invalid duration")

def is_eligible_for_plan(db: Session, user_id: str, plan_id: str):
# Fetch the user's current subscription
user_subscription = db.query(UserSubscription).filter(
UserSubscription.user_id == user_id
).first()

# If the user has no subscription, they are eligible for the plan
if not user_subscription:
return True

# Check if the user's current subscription has ended
if user_subscription.end_date < datetime.utcnow():
return True

# If the user is trying to upgrade or downgrade, they are eligible
if user_subscription.plan_id != plan_id:
return True

# If none of the above conditions are met, the user is not eligible
return False


def calculate_prorated_amount(db: Session, user_id: str, plan_id: str):
# Fetch the user's current subscription
user_subscription = db.query(UserSubscription).filter(
UserSubscription.user_id == user_id
).first()

# Fetch the plan the user is trying to upgrade or downgrade to
plan = get_plan_by_id(db, plan_id)

# Calculate the number of days remaining in the current subscription
days_remaining = (user_subscription.end_date - datetime.utcnow()).days

# Calculate the total number of days in the current subscription
total_days = (user_subscription.end_date - user_subscription.start_date).days

# Calculate the prorated amount
prorated_amount = (plan.price / total_days) * days_remaining

return prorated_amount


def get_all_plans(db: Session):
"""
Retrieve all billing plan details.
Expand All @@ -28,28 +87,77 @@ def get_all_plans(db: Session):
raise HTTPException(status_code=500, detail="An error occurred while fetching billing plans")


def get_plan_by_name(db: Session, plan_name: str):
return db.query(BillingPlan).filter(BillingPlan.name == plan_name).first()
async def update_user_plan(db: Session, user_id: str, plan_id: str, is_downgrade: bool = False):
user = db.query(User).filter(User.id == user_id).first()
plan = get_plan_by_id(db, plan_id)

try:
duration = convert_duration_to_timedelta(plan.duration)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

user_subscription = db.query(UserSubscription).filter(
UserSubscription.user_id == user_id
).first()

if user_subscription:
old_plan = user_subscription.billing_plan
old_duration = convert_duration_to_timedelta(user_subscription.billing_plan.duration)
days_remaining = (datetime.strptime(user_subscription.end_date, "%Y-%m-%d %H:%M:%S.%f") - datetime.utcnow()).days
total_days = (datetime.strptime(user_subscription.end_date, "%Y-%m-%d %H:%M:%S.%f") - datetime.strptime(user_subscription.start_date, "%Y-%m-%d %H:%M:%S.%f")).days

prorated_amount = 0 # Initialize prorated_amount to 0
if is_downgrade:
prorated_amount = (old_plan.price / total_days) * days_remaining
#TODO Refund or credit the user's account (implement based on payment logic)
else:
prorated_amount = (plan.price - prorated_amount)
#TODO Charge the user's payment method (implement based on payment logic)

def stripe_payment_request(db: Session, user_id: str, request: Request, plan_name: str):
user_subscription.plan_id = plan.id
user_subscription.start_date = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")
user_subscription.end_date = (datetime.utcnow() + duration).strftime("%Y-%m-%d %H:%M:%S.%f")
user_subscription.billing_cycle = datetime.utcnow() + duration

else:
user_subscription = UserSubscription(
user_id=user_id,
plan_id=plan.id,
organisation_id=plan.organisation_id,
start_date=datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"),
end_date=(datetime.utcnow() + duration).strftime("%Y-%m-%d %H:%M:%S.%f"),
billing_cycle=datetime.utcnow() + duration
)
db.add(user_subscription)

db.commit()
db.refresh(user_subscription)
return user_subscription

# base_url = request.base_url
# base_urls = str(request.url.scheme) + "://" + str(request.url.netloc)

def stripe_payment_request(db: Session, user_id: str, request: Request, plan_id: str):

# base_urls = request.base_url
# base_urls = str(request.url.scheme) + "://" + str(request.url.netloc)

base_urls = "https://anchor-python.teams.hng.tech/"
success_url = f"{base_urls}payment" + "/success?session_id={CHECKOUT_SESSION_ID}"
cancel_url = f"{base_urls}payment/pricing"

# success_url = f"{base_urls}api/v1/payment/stripe" + "/success?session_id={CHECKOUT_SESSION_ID}"
# cancel_url = f"{base_urls}api/v1/payment/stripe/cancel"


user = db.query(User).filter(User.id == user_id).first()

if not user:
return fail_response(status_code=404, message="User not found")

plan = get_plan_by_name(db, plan_name)
plan = get_plan_by_id(db, plan_id)

if not plan:
return fail_response(status_code=404, message="Plan not found")


if plan.name != "Free":
try:
Expand All @@ -72,7 +180,7 @@ def stripe_payment_request(db: Session, user_id: str, request: Request, plan_nam
cancel_url=cancel_url,
metadata={
'user_id': user_id,
'plan_name': plan_name,
'plan_id': plan.id,
},
)

Expand Down Expand Up @@ -104,66 +212,6 @@ def stripe_payment_request(db: Session, user_id: str, request: Request, plan_nam
return fail_response(status_code=400, message="No payment is required for the Free plan")


def convert_duration_to_timedelta(duration: str) -> timedelta:
if duration == "monthly":
return timedelta(days=30) # Approximate month length
elif duration == "yearly":
return timedelta(days=365) # Approximate year length
else:
raise ValueError("Invalid duration")


async def update_user_plan(db: Session, user_id: str, plan_name: str):
# Fetch the user by ID
user = db.query(User).filter(User.id == user_id).first()

# Fetch the plan by name
plan = get_plan_by_name(db, plan_name)

# Check if the user exists
if not user:
raise HTTPException(status_code=404, detail="User not found")

# Check if the plan exists
if not plan:
raise HTTPException(status_code=404, detail="Plan not found")

# Convert duration from string to timedelta
try:
duration = convert_duration_to_timedelta(plan.duration)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

# Fetch the organisation ID from the plan
organisation_id = plan.organisation_id

# Update the user's subscription in the database
user_subscription = db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.organisation_id == organisation_id
).first()

if user_subscription:
user_subscription.plan_id = plan.id
user_subscription.start_date = datetime.utcnow()
user_subscription.end_date = datetime.utcnow() + duration
else:
user_subscription = UserSubscription(
user_id=user_id,
plan_id=plan.id,
organisation_id=organisation_id,
start_date=datetime.utcnow(),
end_date=datetime.utcnow() + duration
)
db.add(user_subscription)

# Commit the transaction
db.commit()
db.refresh(user_subscription) # Refresh the session to get the updated data

# Return the updated or newly created subscription
return user_subscription


def fetch_all_organisations_with_users_and_plans(db: Session):
# Perform a join to retrieve the relevant data
Expand Down
13 changes: 0 additions & 13 deletions tests/v1/billing_plan/test_stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,6 @@ def mock_fetch_all_organisations_with_users_and_plans():
with patch("api.v1.services.stripe_payment.fetch_all_organisations_with_users_and_plans") as mock_service:
yield mock_service

@pytest.mark.asyncio
async def test_subscribe_user_to_plan(mock_db_session, mock_subscribe_user_to_plan):
# Mock the behavior of the service function
mock_subscribe_user_to_plan.return_value = mock_subscription

# Call the actual service function
response = await update_user_plan(mock_db_session, user_id=user_id, plan_name="Premium")

# Assertions
assert response.user_id == user_id
assert response.plan_id == plan_id
assert response.organisation_id == org_id


@pytest.mark.usefixtures("mock_db_session", "mock_user_service")
def test_fetch_invalid_billing_plans(mock_user_service, mock_db_session):
Expand Down
Loading