Skip to content

Commit

Permalink
Merge pull request #862 from MikeSoft007/feat/stripe_checkout
Browse files Browse the repository at this point in the history
feat: enhance error handling and response structure for billing plan management
  • Loading branch information
joboy-dev authored Aug 12, 2024
2 parents 37f63e3 + a77330b commit 8cc5e25
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 9 deletions.
30 changes: 30 additions & 0 deletions alembic/versions/b99c97c70536_bug_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""bug fix
Revision ID: b99c97c70536
Revises: 9a4e3d412f8e
Create Date: 2024-08-12 12:05:57.900484
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'b99c97c70536'
down_revision: Union[str, None] = '9a4e3d412f8e'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
14 changes: 12 additions & 2 deletions api/v1/routes/stripe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy.orm import Session
import stripe
from api.v1.services.stripe_payment import stripe_payment_request, update_user_plan, fetch_all_organisations_with_users_and_plans
from api.v1.services.stripe_payment import stripe_payment_request, \
update_user_plan, fetch_all_organisations_with_users_and_plans, get_all_plans
import json
from api.v1.schemas.stripe import PlanUpgradeRequest
from typing import List
Expand Down Expand Up @@ -36,6 +37,15 @@ def success_upgrade():
def cancel_upgrade():
return {"message" : "Payment canceled"}


@subscription_.get("/plans")
def get_plans(
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(user_service.get_current_user)):
data = get_all_plans(db)
return data

@subscription_.post("/webhook")
async def webhook_received(
request: Request,
Expand Down Expand Up @@ -75,7 +85,7 @@ async def get_organisations_with_users_and_plans(db: Session = Depends(get_db),
try:
data = fetch_all_organisations_with_users_and_plans(db)
if not data:
return {"status_code": 404, "success": False, "message": "No data found"}
raise HTTPException(status_code=404, detail="No data found")
return success_response(
status_code=status.HTTP_302_FOUND,
message='billing details successfully retrieved',
Expand Down
14 changes: 14 additions & 0 deletions api/v1/schemas/plans.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,17 @@ class SubscriptionPlanResponse(CreateSubscriptionPlan):

class Config:
from_attributes = True


class BillingPlanSchema(BaseModel):
id: str
organisation_id: str
name: str
price: float
currency: str
duration: str
description: Optional[str] = None
features: List[str]

class Config:
orm_mode = True
35 changes: 29 additions & 6 deletions api/v1/services/billing_plan.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from api.v1.models.billing_plan import BillingPlan
from typing import Any, Optional
from api.core.base.services import Service
from api.v1.schemas.plans import CreateSubscriptionPlan
from api.utils.db_validators import check_model_existence
from fastapi import HTTPException
from fastapi import HTTPException, status


class BillingPlanService(Service):
Expand All @@ -14,13 +15,35 @@ def create(self, db: Session, request: CreateSubscriptionPlan):
"""
Create and return a new billing plan
"""

plan = BillingPlan(**request.dict())
db.add(plan)
db.commit()
db.refresh(plan)

try:
db.add(plan)
db.commit()
db.refresh(plan)
return plan

except IntegrityError as e:
db.rollback()
# Check if it's a foreign key violation error
if "foreign key constraint" in str(e.orig):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Organisation with id {request.organisation_id} not found."
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="A database integrity error occurred."
)

except SQLAlchemyError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="A database error occurred."
)

return plan

def delete(self, db: Session, id: str):
"""
Expand Down
17 changes: 17 additions & 0 deletions api/v1/services/stripe_payment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from api.v1.models.billing_plan import BillingPlan, UserSubscription
from api.v1.models.organisation import Organisation
import stripe
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import select, join
from fastapi.encoders import jsonable_encoder
from api.utils.success_response import success_response
Expand All @@ -12,9 +14,24 @@

stripe.api_key = os.getenv('STRIPE_SECRET_KEY')


def get_all_plans(db: Session):
"""
Retrieve all billing plan details.
"""
try:
data = db.query(BillingPlan).all()
if not data:
raise HTTPException(status_code=404, detail="No billing plans found")
return success_response(status_code=status.HTTP_302_FOUND, message="Plans successfully retrieved", data=data)
except SQLAlchemyError as e:
raise HTTPException(status_code=500, detail="An error occurred while fetching billing plans")


def get_plan_by_name(db: Session, plan_name: str):
return db.query(BillingPlan).filter(BillingPlan.name == plan_name).first()


def stripe_payment_request(db: Session, user_id: str, request: Request, plan_name: str):

base_url = request.base_url
Expand Down
17 changes: 16 additions & 1 deletion tests/v1/billing_plan/test_stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from api.v1.models.user import User
from api.v1.models.billing_plan import UserSubscription, BillingPlan
from main import app
from fastapi import status
from api.v1.services.user import user_service
from api.db.database import get_db
from datetime import datetime, timezone, timedelta
Expand Down Expand Up @@ -78,4 +79,18 @@ async def test_subscribe_user_to_plan(mock_db_session, mock_subscribe_user_to_pl
# Assertions
assert response.user_id == user_id
assert response.plan_id == plan_id
assert response.organisation_id == org_id
assert response.organisation_id == org_id


@pytest.mark.usefixtures("mock_db_session", "mock_user_service")
def test_fetch_invalid_billing_plans(mock_user_service, mock_db_session):
"""Billing plan fetch test."""
mock_user = create_mock_user(mock_user_service, mock_db_session)
access_token = user_service.create_access_token(user_id=str(uuid7()))

response = client.get(
"/api/v1/payment/plans",
headers={"Authorization": f"Bearer {access_token}"},
)
print(response.json())
assert response.status_code == 404

0 comments on commit 8cc5e25

Please sign in to comment.