diff --git a/api/v1/routes/faq.py b/api/v1/routes/faq.py index fd9a34f66..ab6573180 100644 --- a/api/v1/routes/faq.py +++ b/api/v1/routes/faq.py @@ -21,23 +21,22 @@ async def get_all_faqs( keyword: Optional[str] = Query(None, min_length=1) ): """Endpoint to get all FAQs or search by keyword in both question and answer""" - + query_params = {} if keyword: - """Search by both question and answer fields""" query_params["question"] = keyword query_params["answer"] = keyword - faqs = faq_service.fetch_all(db=db, **query_params) + grouped_faqs = faq_service.fetch_all_grouped_by_category( + db=db, **query_params) return success_response( status_code=200, message="FAQs retrieved successfully", - data=jsonable_encoder(faqs), + data=jsonable_encoder(grouped_faqs), ) - @faq.post("", response_model=success_response, status_code=201) async def create_faq( schema: CreateFAQ, diff --git a/api/v1/services/faq.py b/api/v1/services/faq.py index 06f19ec38..161913246 100644 --- a/api/v1/services/faq.py +++ b/api/v1/services/faq.py @@ -19,6 +19,26 @@ def create(self, db: Session, schema: CreateFAQ): return new_faq + def fetch_all_grouped_by_category(self, db: Session, **query_params: Optional[Any]): + """Fetch all FAQs grouped by category""" + query = db.query(FAQ.category, FAQ.question, FAQ.answer) + + if query_params: + for column, value in query_params.items(): + if hasattr(FAQ, column) and value: + query = query.filter( + getattr(FAQ, column).ilike(f"%{value}%")) + faqs = query.order_by(FAQ.category).all() + + grouped_faqs = {} + for faq in faqs: + if faq.category not in grouped_faqs: + grouped_faqs[faq.category] = [] + grouped_faqs[faq.category].append( + {"question": faq.question, "answer": faq.answer}) + + return grouped_faqs + def fetch_all(self, db: Session, **query_params: Optional[Any]): """Fetch all FAQs with option to search using query parameters""" @@ -28,7 +48,8 @@ def fetch_all(self, db: Session, **query_params: Optional[Any]): if query_params: for column, value in query_params.items(): if hasattr(FAQ, column) and value: - query = query.filter(getattr(FAQ, column).ilike(f"%{value}%")) + query = query.filter( + getattr(FAQ, column).ilike(f"%{value}%")) return query.all() diff --git a/tests/v1/faq/get_all_faqs_test.py b/tests/v1/faq/get_all_faqs_test.py index 6f4df5704..2169947f1 100644 --- a/tests/v1/faq/get_all_faqs_test.py +++ b/tests/v1/faq/get_all_faqs_test.py @@ -1,11 +1,10 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session from api.db.database import get_db -from api.v1.models.faq import FAQ from api.v1.services.faq import faq_service from main import app @@ -15,6 +14,7 @@ def mock_db_session(): db_session = MagicMock(spec=Session) return db_session + @pytest.fixture def client(mock_db_session): app.dependency_overrides[get_db] = lambda: mock_db_session @@ -23,19 +23,24 @@ def client(mock_db_session): app.dependency_overrides = {} -def test_get_all_faqs(mock_db_session, client): - """Test to verify the pagination response for FAQs.""" - # Mock data - mock_faq_data = [ - FAQ(id=1, question="Question 1", answer="Answer 1"), - FAQ(id=2, question="Question 2", answer="Answer 2"), - FAQ(id=3, question="Question 3", answer="Answer 3"), - ] +def test_get_all_faqs_grouped_by_category(mock_db_session, client): + """Test to verify the response for FAQs grouped by category.""" + + mock_faq_data_grouped = { + "General": [ + {"question": "What is FastAPI?", + "answer": "FastAPI is a modern web framework for Python."}, + {"question": "What is SQLAlchemy?", + "answer": "SQLAlchemy is a SQL toolkit and ORM for Python."} + ], + "Billing": [ + {"question": "How do I update my billing information?", + "answer": "You can update your billing information in the account settings."} + ] + } - app.dependency_overrides[faq_service.fetch_all] = mock_faq_data + with patch.object(faq_service, 'fetch_all_grouped_by_category', return_value=mock_faq_data_grouped): - # Perform the GET request - response = client.get('/api/v1/faqs') + response = client.get('/api/v1/faqs') - # Verify the response assert response.status_code == 200