Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add getting notifications endpoint #76

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/v1/models/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Notification(AbstractBaseModel):
status: Mapped[str] = mapped_column(
SQLAlchemyEnum(NotificationStatus), server_default="unread"
)

user = relationship("User", back_populates="notifications")

def __str__(self):
Expand Down
2 changes: 2 additions & 0 deletions api/v1/routes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from api.v1.routes.user import users
from api.v1.routes.post import posts
from api.v1.routes.post_comment import comments
from api.v1.routes.notification import notifications

# version 1 routes

Expand All @@ -12,3 +13,4 @@
version_one.include_router(users)
version_one.include_router(posts)
version_one.include_router(comments)
version_one.include_router(notifications)
32 changes: 32 additions & 0 deletions api/v1/routes/notification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from api.v1.models.user import User
from api.v1.services.user import user_service
from api.v1.services.notification import notification_service
from api.v1.utils.dependencies import get_db
from api.v1.responses.success_response import success_response
from typing import List

notifications = APIRouter(prefix="/notifications", tags=["notification"])


@notifications.get("/sse")
async def sse_endpoint(user: User = Depends(user_service.get_current_user)):
return StreamingResponse(
notification_service.event_generator(user.id), media_type="text/event_stream"
)


@notifications.get("")
async def user_notifications(
user: User = Depends(user_service.get_current_user), db: Session = Depends(get_db)
):

notifications: List = notification_service.notifications(user=user, db=db)

return success_response(
status_code=200,
message="Notifications returned successfully",
data=notifications,
)
10 changes: 6 additions & 4 deletions api/v1/routes/post.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, status, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, status, WebSocket, WebSocketDisconnect, BackgroundTasks
from sqlalchemy.orm import Session
from typing import List

Expand Down Expand Up @@ -96,15 +96,16 @@ async def websocket_post_endpoint(websocket: WebSocket):
@posts.patch("/{id}/like", status_code=status.HTTP_200_OK)
async def like_post(
id: str,
background_task: BackgroundTasks = BackgroundTasks(),
db: Session = Depends(get_db),
user: User = Depends(user_service.get_current_user),
):

liked_post = post_service.like_post(db=db, user=user, post_id=id)
liked_post = post_service.like_post(db=db, user=user, post_id=id, background_task=background_task)

return success_response(
status_code=status.HTTP_200_OK,
message="Post updated successfully",
message="Post liked successfully",
)


Expand All @@ -128,11 +129,12 @@ async def get_likes(
async def repost(
id: str,
schema: RepostCreate,
background_task: BackgroundTasks = BackgroundTasks(),
db: Session = Depends(get_db),
user: User = Depends(user_service.get_current_user),
):

repost = post_service.repost(db=db, post_id=id, schema=schema, user=user)
repost = post_service.repost(db=db, post_id=id, schema=schema, user=user, background_task=background_task)

manager.broadcast(repost)

Expand Down
5 changes: 3 additions & 2 deletions api/v1/routes/post_comment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, status
from fastapi import APIRouter, Depends, status, BackgroundTasks
from sqlalchemy.orm import Session
from api.v1.schemas.post_comment import CreateCommentSchema, CommentResponse
from api.v1.schemas.user import UserResponse
Expand Down Expand Up @@ -34,12 +34,13 @@ async def get_comments(
async def create_comment(
post_id: str,
comment: CreateCommentSchema,
background_task: BackgroundTasks = BackgroundTasks(),
db: Session = Depends(get_db),
user: User = Depends(user_service.get_current_user),
):

new_comment: CommentResponse = comment_service.create(
db=db, user=user, post_id=post_id, schema=comment
db=db, user=user, post_id=post_id, schema=comment, background_task=background_task
)

return success_response(
Expand Down
9 changes: 6 additions & 3 deletions api/v1/routes/user.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Query, status
from fastapi import APIRouter, Depends, Query, status, BackgroundTasks
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
from api.v1.models.user import User
Expand Down Expand Up @@ -67,11 +67,13 @@ async def get_users(search: str = "", db: Session = Depends(get_db)):
@users.patch("/{followee_id}/follow", summary="Follow a particular user")
async def follow(
followee_id: str,
background_task: BackgroundTasks = BackgroundTasks(),
user: User = Depends(user_service.get_current_user),
db: Session = Depends(get_db),
):

user_service.follow_user(db=db, user=user, user_id=followee_id)
user_service.follow_user(db=db, user=user, user_id=followee_id, background_task=background_task)

return success_response(
status_code=200,
message="User followed successfully",
Expand All @@ -81,11 +83,12 @@ async def follow(
@users.delete("/{followee_id}/unfollow", summary="Unfollow the user with the id")
async def unfollow(
followee_id: str,
background_task: BackgroundTasks = BackgroundTasks(),
user: User = Depends(user_service.get_current_user),
db: Session = Depends(get_db),
):

user_service.unfollow_user(db=db, user_id=followee_id, user=user)
user_service.unfollow_user(db=db, user_id=followee_id, user=user, background_task=background_task)

return success_response(
status_code=200,
Expand Down
30 changes: 30 additions & 0 deletions api/v1/services/notification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from fastapi import BackgroundTasks
from sqlalchemy.orm import Session
from typing import Dict
import asyncio
from api.v1.models.user import User
from api.v1.models.notification import Notification


class NotificationService:
def __init__(self):

self.user_event_queues: Dict[str, asyncio.Queue] = {}

async def event_generator(self, user_id: str):
if user_id not in self.user_event_queues:
self.user_event_queues[user_id] = asyncio.Queue()

while True:
event = await self.user_event_queues[user_id].get()
yield f"data: {event}"

def notifications(self, user: User, db: Session):
notifications = (
db.query(Notification).filter(Notification.user_id == user.id).all()
)

return notifications


notification_service = NotificationService()
39 changes: 34 additions & 5 deletions api/v1/services/post.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import HTTPException, status
from fastapi import HTTPException, status, BackgroundTasks
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session, joinedload
from api.v1.models.post import Post, Like
Expand All @@ -14,7 +14,8 @@
)
from api.v1.schemas.user import UserResponse
from api.v1.services.user import user_service

from api.v1.models.notification import Notification
from api.v1.services.notification import notification_service

class PostService:
def get_post(self, db: Session, user: User, post_id: str):
Expand Down Expand Up @@ -103,11 +104,11 @@ def update(self, db: Session, user: User, post_id: str, schema: UpdatePostSchema
return jsonable_encoder(post)


def like_post(self, db: Session, user: User, post_id: str):
def like_post(self, db: Session, user: User, post_id: str, background_task: BackgroundTasks):

# get the post
post = (
db.query(Post).filter(Post.id == post_id, Post.user_id == user.id).first()
db.query(Post).filter(Post.id == post_id).first()
)

if not post:
Expand All @@ -125,12 +126,32 @@ def like_post(self, db: Session, user: User, post_id: str):
if like:
db.delete(like)
db.commit()

# notification for unliking a post
notification = Notification(user_id=post.user_id, message=f"{user.username} recently unliked your post")

db.add(notification)
db.commit()

background_task.add_task(notification_service.user_event_queues[notification.user_id].put, notification.message)

else:
like = Like(user_id=user.id, post_id=post_id)
like.liked = True
db.add(like)
db.commit()

# add notification for like

notification = Notification(user_id=post.user_id, message=f"{user.username} recently liked your post")

db.add(notification)
db.commit()

# background task for sse notification

background_task.add_task(notification_service.user_event_queues[notification.user_id].put, notification.message)


def get_likes(self, db: Session, post_id: str, user: User):

Expand Down Expand Up @@ -162,7 +183,7 @@ def get_likes(self, db: Session, post_id: str, user: User):
return likes_response


def repost(self, db: Session, post_id: str, user: User, schema: RepostCreate):
def repost(self, db: Session, post_id: str, user: User, schema: RepostCreate, background_task: BackgroundTasks):

original_post = self.get_post(db=db, user=user, post_id=post_id)

Expand All @@ -189,6 +210,14 @@ def repost(self, db: Session, post_id: str, user: User, schema: RepostCreate):
new_post_response["user"] = jsonable_encoder(new_post_owner)
new_post_response["post"] = original_post_response

# repost notification
notification = Notification(user_id=original_post.user_id,essage=f"{user.username} shared your post")
db.add(notification)
db.commit()

# background task for notificatiom
background_task.add_task(notification_service.user_event_queues[notification.user_id].put, notification.message)

return RepostResponse(**new_post_response)


Expand Down
18 changes: 15 additions & 3 deletions api/v1/services/post_comment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import HTTPException, status
from fastapi import HTTPException, status, BackgroundTasks
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
from api.v1.schemas.post_comment import (
Expand All @@ -10,7 +10,9 @@
from api.v1.models.post_comment import PostComment
from api.v1.models.post import Post
from api.v1.models.user import User
from api.v1.models.notification import Notification
from api.v1.services.user import user_service
from api.v1.services.notification import notification_service


class CommentService:
Expand All @@ -27,7 +29,7 @@ class CommentService:

# class methods
def create(
self, db: Session, user: User, post_id: str, schema: CreateCommentSchema
self, db: Session, user: User, post_id: str, schema: CreateCommentSchema, background_task: BackgroundTasks
):
schema_dict = schema.model_dump()

Expand All @@ -39,7 +41,7 @@ def create(

# get the post
post = (
db.query(Post).filter(Post.user_id == user.id, Post.id == post_id).first()
db.query(Post).filter(Post.id == post_id).first()
)

if not post:
Expand All @@ -58,6 +60,16 @@ def create(
encoded = jsonable_encoder(comment)
encoded["user"] = response_user

# Comment Notification
notification = Notification(user_id=post.user_id, message=f"{user.username} commented on your post")

db.add(notification)
db.commit()

# add background task to send notifcation
background_task.add_task(notification_service.user_event_queues[notification.user_id].put, notification.message)


return CommentResponse(**encoded)

def update(
Expand Down
12 changes: 9 additions & 3 deletions api/v1/services/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from api.v1.utils.dependencies import get_db

load_dotenv()
from fastapi import Depends, HTTPException, status
from fastapi import Depends, HTTPException, status, BackgroundTasks
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import or_, text
Expand All @@ -22,6 +22,7 @@
from api.v1.schemas.user import UserCreate, UserUpdateSchema, UserResponse
from api.v1.utils.storage import upload
from api.v1.models.notification import Notification
from api.v1.services.notification import notification_service

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
hash_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
Expand Down Expand Up @@ -350,7 +351,7 @@ def fetch_all(self, db: Session, search: str = ""):

return jsonable_encoder(users, exclude={"password"})

def follow_user(self, db: Session, user_id: str, user: User):
def follow_user(self, db: Session, user_id: str, user: User, background_task: BackgroundTasks):

followee = db.query(User).filter(User.id == user_id).first()
if not followee:
Expand All @@ -369,7 +370,9 @@ def follow_user(self, db: Session, user_id: str, user: User):
db.add(notification)
db.commit()

def unfollow_user(self, db: Session, user_id: str, user: User):
background_task.add_task(notification_service.user_event_queues[notification.user_id].put, notification.message)

def unfollow_user(self, db: Session, user_id: str, user: User, background_task: BackgroundTasks):
user_to_unfollow = db.query(User).filter(User.id == user_id).first()

if not user_to_unfollow:
Expand All @@ -389,6 +392,9 @@ def unfollow_user(self, db: Session, user_id: str, user: User):
db.add(notification)
db.commit()

background_task.add_task(notification_service.user_event_queues[notification.user_id].put, notification.message)


def followers(self, db: Session, user: User):

followers = [
Expand Down
4 changes: 3 additions & 1 deletion api/v1/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import sys


sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))
sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
)

from fastapi import HTTPException, status
import pytest
Expand Down
Loading
Loading