diff --git a/.env.sample b/.env.sample index bf53fdac0..72cfca777 100644 --- a/.env.sample +++ b/.env.sample @@ -34,5 +34,8 @@ TWILIO_PHONE_NUMBER="TWILIO_PHONE_NUMBER" FLUTTERWAVE_SECRET="" PAYSTACK_SECRET="" +STRIPE_SECRET_KEY="" +STRIPE_WEBHOOK_SECRET="" + MAILJET_API_KEY='MAIL JET API KEY' MAILJET_API_SECRET='SECRET KEY' diff --git a/alembic/versions/9a4e3d412f8e_updated_billing_model_ensuing_that_plan_.py b/alembic/versions/9a4e3d412f8e_updated_billing_model_ensuing_that_plan_.py new file mode 100644 index 000000000..d2b84df60 --- /dev/null +++ b/alembic/versions/9a4e3d412f8e_updated_billing_model_ensuing_that_plan_.py @@ -0,0 +1,50 @@ +"""updated billing model ensuing that plan name is unique + +Revision ID: 9a4e3d412f8e +Revises: af8459ffc616 +Create Date: 2024-08-11 21:16:54.902038 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '9a4e3d412f8e' +down_revision: Union[str, None] = 'af8459ffc616' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('user_subscriptions', + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('plan_id', sa.String(), nullable=False), + sa.Column('organisation_id', sa.String(), nullable=False), + sa.Column('active', sa.Boolean(), nullable=True), + sa.Column('start_date', sa.String(), nullable=False), + sa.Column('end_date', sa.String(), nullable=True), + sa.Column('id', sa.String(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.ForeignKeyConstraint(['organisation_id'], ['organisations.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['plan_id'], ['billing_plans.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_user_subscriptions_id'), 'user_subscriptions', ['id'], unique=False) + op.create_unique_constraint(None, 'billing_plans', ['name']) + op.add_column('user_organisation_roles', sa.Column('status', sa.String(length=20), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('user_organisation_roles', 'status') + op.drop_constraint(None, 'billing_plans', type_='unique') + op.drop_index(op.f('ix_user_subscriptions_id'), table_name='user_subscriptions') + op.drop_table('user_subscriptions') + # ### end Alembic commands ### diff --git a/api/v1/models/billing_plan.py b/api/v1/models/billing_plan.py index 784296197..a59b96c09 100644 --- a/api/v1/models/billing_plan.py +++ b/api/v1/models/billing_plan.py @@ -1,5 +1,5 @@ # app/models/billing_plan.py -from sqlalchemy import Column, String, ARRAY, ForeignKey, Numeric +from sqlalchemy import Column, String, ARRAY, ForeignKey, Numeric, Boolean from sqlalchemy.orm import relationship from api.v1.models.base_model import BaseTableModel @@ -10,7 +10,7 @@ class BillingPlan(BaseTableModel): organisation_id = Column( String, ForeignKey("organisations.id", ondelete="CASCADE"), nullable=False ) - name = Column(String, nullable=False) + name = Column(String, nullable=False, unique=True) price = Column(Numeric, nullable=False) currency = Column(String, nullable=False) duration = Column(String, nullable=False) @@ -18,3 +18,19 @@ class BillingPlan(BaseTableModel): features = Column(ARRAY(String), nullable=False) organisation = relationship("Organisation", back_populates="billing_plans") + user_subscriptions = relationship("UserSubscription", back_populates="billing_plan", cascade="all, delete-orphan") + + +class UserSubscription(BaseTableModel): + __tablename__ = "user_subscriptions" + + user_id = Column(String, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + plan_id = Column(String, ForeignKey("billing_plans.id", ondelete="CASCADE"), nullable=False) + organisation_id = Column(String, ForeignKey("organisations.id", ondelete="CASCADE"), nullable=False) + active = Column(Boolean, default=True) + start_date = Column(String, nullable=False) + end_date = Column(String, nullable=True) + + user = relationship("User", back_populates="subscriptions") + billing_plan = relationship("BillingPlan", back_populates="user_subscriptions") + organisation = relationship("Organisation", back_populates="user_subscriptions") diff --git a/api/v1/models/organisation.py b/api/v1/models/organisation.py index ee44cd9b7..baa3d2cb7 100644 --- a/api/v1/models/organisation.py +++ b/api/v1/models/organisation.py @@ -27,15 +27,7 @@ class Organisation(BaseTableModel): products = relationship("Product", back_populates="organisation", cascade="all, delete-orphan") contact_us = relationship("ContactUs", back_populates="organisation", cascade="all, delete-orphan") - billing_plans = relationship( - "BillingPlan", back_populates="organisation", cascade="all, delete-orphan" - ) - invitations = relationship( - "Invitation", back_populates="organisation", cascade="all, delete-orphan" - ) - products = relationship( - "Product", back_populates="organisation", cascade="all, delete-orphan" - ) + user_subscriptions = relationship("UserSubscription", back_populates="organisation", cascade="all, delete-orphan") sales = relationship('Sales', back_populates='organisation', cascade='all, delete-orphan') def __str__(self): diff --git a/api/v1/models/permissions/user_org_role.py b/api/v1/models/permissions/user_org_role.py index 32d84eaf4..52248728c 100644 --- a/api/v1/models/permissions/user_org_role.py +++ b/api/v1/models/permissions/user_org_role.py @@ -6,5 +6,6 @@ Column("user_id", String, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True), Column("organisation_id", String, ForeignKey("organisations.id", ondelete="CASCADE"), primary_key=True), Column('role_id', String, ForeignKey('roles.id', ondelete='CASCADE'), nullable=True), - Column('is_owner', Boolean, server_default='false') + Column('is_owner', Boolean, server_default='false'), + Column('status', String(20), nullable=False, default="active") ) \ No newline at end of file diff --git a/api/v1/models/user.py b/api/v1/models/user.py index d1d5f74c0..13d1ca9cd 100644 --- a/api/v1/models/user.py +++ b/api/v1/models/user.py @@ -86,6 +86,10 @@ class User(BaseTableModel): ) product_comments = relationship("ProductComment", back_populates="user", cascade="all, delete-orphan") + subscriptions = relationship( + "UserSubscription", back_populates="user", cascade="all, delete-orphan" + ) + def to_dict(self): obj_dict = super().to_dict() obj_dict.pop("password") diff --git a/api/v1/routes/__init__.py b/api/v1/routes/__init__.py index 5de9042b6..05671c8f5 100644 --- a/api/v1/routes/__init__.py +++ b/api/v1/routes/__init__.py @@ -43,6 +43,7 @@ from api.v1.routes.privacy import privacies from api.v1.routes.settings import settings from api.v1.routes.terms_and_conditions import terms_and_conditions +from api.v1.routes.stripe import subscription_ api_version_one = APIRouter(prefix="/api/v1") @@ -89,3 +90,4 @@ api_version_one.include_router(team) api_version_one.include_router(terms_and_conditions) api_version_one.include_router(product_comment) +api_version_one.include_router(subscription_) \ No newline at end of file diff --git a/api/v1/routes/stripe.py b/api/v1/routes/stripe.py new file mode 100644 index 000000000..a9b597f66 --- /dev/null +++ b/api/v1/routes/stripe.py @@ -0,0 +1,85 @@ +from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy.orm import Session +import stripe +from api.v1.services.stripe_payment import stripe_payment_request, update_user_plan, fetch_all_organisations_with_users_and_plans +import json +from api.v1.schemas.stripe import PlanUpgradeRequest +from typing import List +from api.db.database import get_db +import os +from api.utils.success_response import success_response +from api.v1.models.user import User +from api.v1.services.user import user_service +from dotenv import load_dotenv, find_dotenv + +load_dotenv(find_dotenv()) + +stripe.api_key = os.getenv('STRIPE_SECRET_KEY') +endpoint_secret = os.getenv('STRIPE_WEBHOOK_SECRET') + +subscription_ = APIRouter(prefix="/payment", tags=["subscribe-plan"]) + +@subscription_.post("/stripe/upgrade-plan") +def stripe_payment( + plan_upgrade_request: PlanUpgradeRequest, + request: Request, + db: Session = Depends(get_db), + current_user: User = Depends(user_service.get_current_user) + ): + return stripe_payment_request(db, plan_upgrade_request.user_id, request, plan_upgrade_request.plan_name) + +@subscription_.get("/stripe/success") +def success_upgrade(): + return {"message" : "Payment successful"} + +@subscription_.get("/stripe/cancel") +def cancel_upgrade(): + return {"message" : "Payment canceled"} + +@subscription_.post("/webhook") +async def webhook_received( + request: Request, + db: Session = Depends(get_db) + ): + + payload = await request.body() + event = None + + try: + event = stripe.Event.construct_from(json.loads(payload), stripe.api_key) + except ValueError as e: + print("Invalid payload") + raise HTTPException(status_code=400, detail="Invalid payload") + except stripe.error.SignatureVerificationError as e: + print("Invalid signature") + raise HTTPException(status_code=400, detail="Invalid signature") + + if event["type"] == "checkout.session.completed": + payment = event["data"]["object"] + response_details = { + "amount": payment["amount_total"], + "currency": payment["currency"], + "user_id": payment["metadata"]["user_id"], + "user_email": payment["customer_details"]["email"], + "user_name": payment["customer_details"]["name"], + "order_id": payment["id"] + } + # Save to DB + # Send email in background task + await update_user_plan(db, payment["metadata"]["user_id"], payment["metadata"]["plan_name"]) + return {"message": response_details} + + +@subscription_.get("/organisations/users/plans") +async def get_organisations_with_users_and_plans(db: Session = Depends(get_db), current_user: User = Depends(user_service.get_current_super_admin)): + try: + data = fetch_all_organisations_with_users_and_plans(db) + if not data: + return {"status_code": 404, "success": False, "message": "No data found"} + return success_response( + status_code=status.HTTP_302_FOUND, + message='billing details successfully retrieved', + data=data, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/api/v1/schemas/stripe.py b/api/v1/schemas/stripe.py new file mode 100644 index 000000000..1c445fa69 --- /dev/null +++ b/api/v1/schemas/stripe.py @@ -0,0 +1,30 @@ +from typing import List, Optional + +from pydantic import BaseModel, Field, validator + + +class PaymentInfo(BaseModel): + card_number: str = Field(..., min_length=16, max_length=16) + exp_month: int + exp_year: int + cvc: str = Field(..., min_length=3, max_length=4) + + @validator('card_number') + def card_number_validator(cls, v): + if not v.isdigit() or len(v) != 16: + raise ValueError('Card number must be 16 digits') + return v + + @validator('cvc') + def cvc_validator(cls, v): + if not v.isdigit() or not (3 <= len(v) <= 4): + raise ValueError('CVC must be 3 or 4 digits') + return v + + +class PlanUpgradeRequest(BaseModel): + user_id: str + plan_name: str + payment_info: Optional[PaymentInfo] = None + + diff --git a/api/v1/services/stripe_payment.py b/api/v1/services/stripe_payment.py new file mode 100644 index 000000000..4c138ff90 --- /dev/null +++ b/api/v1/services/stripe_payment.py @@ -0,0 +1,194 @@ +from sqlalchemy.orm import Session +from api.v1.models.user import User +from api.v1.models.billing_plan import BillingPlan, UserSubscription +from api.v1.models.organisation import Organisation +import stripe +from sqlalchemy import select, join +from fastapi.encoders import jsonable_encoder +from api.utils.success_response import success_response +import os +from fastapi import HTTPException, status, Request +from datetime import datetime, timedelta + +stripe.api_key = os.getenv('STRIPE_SECRET_KEY') + +def get_plan_by_name(db: Session, plan_name: str): + return db.query(BillingPlan).filter(BillingPlan.name == plan_name).first() + +def stripe_payment_request(db: Session, user_id: str, request: Request, plan_name: str): + + base_url = request.base_url + + success_url = f"{base_url}api/v1/payment/stripe/success" + cancel_url = f"{base_url}api/v1/payment/stripe/cancel" + + user = db.query(User).filter(User.id == user_id).first() + + if not user: + raise HTTPException(status_code=404, detail="User not found") + + plan = get_plan_by_name(db, plan_name) + + if not plan: + raise HTTPException(status_code=404, detail="Plan not found") + + if plan.name != "Free": + try: + # Create a checkout session + checkout_session = stripe.checkout.Session.create( + payment_method_types=['card'], + line_items=[{ + 'price_data': { + 'currency': plan.currency, + 'product_data': { + 'name': plan.name, + }, + 'unit_amount': int(plan.price * 100), # Convert to the smallest unit + }, + 'quantity': 1, + }], + mode='payment', + customer_email=user.email, # Automatically fill in the user's email in the checkout + success_url=success_url, + cancel_url=cancel_url, + metadata={ + 'user_id': user_id, + 'plan_name': plan_name, + }, + ) + + if checkout_session: + data = { + "cancel_url": checkout_session["cancel_url"], + "success_url": checkout_session["success_url"], + "customer_details": checkout_session["customer_details"], + "customer_email": checkout_session["customer_email"], + "created_at": checkout_session["created"], + "expires_at": checkout_session["expires_at"], + "metadata": checkout_session["metadata"], + "payment_method_types": checkout_session["payment_method_types"], + "checkout_url": checkout_session["url"], + "amount_total": checkout_session["amount_total"] + } + + return success_response( + status_code=status.HTTP_201_CREATED, + message='payment in progress', + data=data, + ) + + except stripe.error.StripeError as e: + # Handle Stripe error + raise HTTPException(status_code=500, detail=f"Payment failed: {str(e)}") + + else: + raise HTTPException(status_code=400, detail="No payment is required for the Free plan") + + +def convert_duration_to_timedelta(duration: str) -> timedelta: + if duration == "monthly": + return timedelta(days=30) # Approximate month length + elif duration == "yearly": + return timedelta(days=365) # Approximate year length + else: + raise ValueError("Invalid duration") + + +async def update_user_plan(db: Session, user_id: str, plan_name: str): + # Fetch the user by ID + user = db.query(User).filter(User.id == user_id).first() + + # Fetch the plan by name + plan = get_plan_by_name(db, plan_name) + + # Check if the user exists + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Check if the plan exists + if not plan: + raise HTTPException(status_code=404, detail="Plan not found") + + # Convert duration from string to timedelta + try: + duration = convert_duration_to_timedelta(plan.duration) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + # Fetch the organisation ID from the plan + organisation_id = plan.organisation_id + + # Update the user's subscription in the database + user_subscription = db.query(UserSubscription).filter( + UserSubscription.user_id == user_id, + UserSubscription.organisation_id == organisation_id + ).first() + + if user_subscription: + user_subscription.plan_id = plan.id + user_subscription.start_date = datetime.utcnow() + user_subscription.end_date = datetime.utcnow() + duration + else: + user_subscription = UserSubscription( + user_id=user_id, + plan_id=plan.id, + organisation_id=organisation_id, + start_date=datetime.utcnow(), + end_date=datetime.utcnow() + duration + ) + db.add(user_subscription) + + # Commit the transaction + db.commit() + db.refresh(user_subscription) # Refresh the session to get the updated data + + # Return the updated or newly created subscription + return user_subscription + + +def fetch_all_organisations_with_users_and_plans(db: Session): + # Perform a join to retrieve the relevant data + stmt = ( + select( + Organisation.id, + Organisation.name, + User.id.label("user_id"), + (User.first_name + " " + User.last_name).label("user_name"), + BillingPlan.name.label("plan_name"), + BillingPlan.price, + BillingPlan.currency, + BillingPlan.duration, + UserSubscription.start_date, + UserSubscription.end_date + ) + .join(UserSubscription, Organisation.id == UserSubscription.organisation_id) + .join(User, User.id == UserSubscription.user_id) + .join(BillingPlan, BillingPlan.id == UserSubscription.plan_id) + ) + + result = db.execute(stmt).all() + + # Organize the data by organizations, users, and their plans + organizations_data = {} + for row in result: + org_id = row.id + if org_id not in organizations_data: + organizations_data[org_id] = { + "organisation_name": row.name, + "users": [] + } + + user_info = { + "user_id": row.user_id, + "user_name": row.user_name, + "plan_name": row.plan_name, + "price": row.price, + "currency": row.currency, + "duration": row.duration, + "start_date": row.start_date, + "end_date": row.end_date + } + + organizations_data[org_id]["users"].append(user_info) + + return list(organizations_data.values()) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c03d6bba9..48a67aac4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -110,3 +110,4 @@ watchfiles==0.22.0 webencodings==0.5.1 websockets==12.0 yarl==1.9.4 +stripe==10.7.0 diff --git a/t b/t new file mode 100644 index 000000000..46ba6e06f --- /dev/null +++ b/t @@ -0,0 +1,8 @@ + bugfix/billing_test + dev + feat/accept_invite + feat/expose_test_cases + feat/mail_service + feat/request_password_reset + feat/roles_permissions +* feat/stripe_checkout diff --git a/tests/v1/billing_plan/test_stripe.py b/tests/v1/billing_plan/test_stripe.py new file mode 100644 index 000000000..590b65643 --- /dev/null +++ b/tests/v1/billing_plan/test_stripe.py @@ -0,0 +1,81 @@ +import pytest +from unittest.mock import MagicMock, patch +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session +from api.v1.models.user import User +from api.v1.models.billing_plan import UserSubscription, BillingPlan +from main import app +from api.v1.services.user import user_service +from api.db.database import get_db +from datetime import datetime, timezone, timedelta +from uuid_extensions import uuid7 +from api.v1.services.stripe_payment import update_user_plan, fetch_all_organisations_with_users_and_plans + +client = TestClient(app) + +# Mock Data +email = "test@gmail.com" +user_id = "user_123" +plan_id = "plan_123" +org_id = "org_123" +start_date = datetime.utcnow() +end_date = start_date + timedelta(days=30) + +mock_user = User(id=user_id, email=email, first_name="Mike", last_name="Zeus", is_superadmin=True) +mock_plan = BillingPlan(id=plan_id, name="Premium", price=29.99, currency="USD", duration="monthly", organisation_id=org_id) +mock_subscription = UserSubscription(user_id=user_id, plan_id=plan_id, organisation_id=org_id, start_date=start_date, end_date=end_date) + +@pytest.fixture +def mock_db_session(): + session = MagicMock(spec=Session) + session.query().filter().first.side_effect = lambda: { + User: mock_user, + BillingPlan: mock_plan, + UserSubscription: mock_subscription + }[session.query.call_args[0][0]] + return session + +@pytest.fixture +def mock_subscribe_user_to_plan(): + with patch("api.v1.services.stripe_payment.update_user_plan") as mock_service: + yield mock_service + +@pytest.fixture +def mock_user_service(): + """Fixture to create a mock user service.""" + with patch("api.v1.services.user.user_service", autospec=True) as mock_service: + yield mock_service + +def create_mock_user(mock_user_service, mock_db_session): + """Create a mock user in the mock database session.""" + mock_user = User( + id=user_id, + email="testuser@gmail.com", + password=user_service.hash_password("Testpassword@123"), + first_name="Test", + last_name="User", + is_active=True, + is_superadmin=True, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + mock_db_session.query.return_value.filter.return_value.first.return_value = mock_user + return mock_user + +@pytest.fixture +def mock_fetch_all_organisations_with_users_and_plans(): + with patch("api.v1.services.stripe_payment.fetch_all_organisations_with_users_and_plans") as mock_service: + yield mock_service + +@pytest.mark.asyncio +async def test_subscribe_user_to_plan(mock_db_session, mock_subscribe_user_to_plan): + # Mock the behavior of the service function + mock_subscribe_user_to_plan.return_value = mock_subscription + + # Call the actual service function + response = await update_user_plan(mock_db_session, user_id=user_id, plan_name="Premium") + + # Assertions + assert response.user_id == user_id + assert response.plan_id == plan_id + assert response.organisation_id == org_id \ No newline at end of file