diff --git a/cruds/works/__init__.py b/cruds/works/__init__.py index 77aca5c..6d5e896 100644 --- a/cruds/works/__init__.py +++ b/cruds/works/__init__.py @@ -1,5 +1,6 @@ from typing import List from fastapi import HTTPException +from sqlalchemy import desc, func from cruds.assets import delete_asset_by_id from cruds.url_infos import create_url_info, delete_url_info from db import models @@ -240,12 +241,13 @@ def get_works_by_user_id(db: Session, user_id: str, at_me: bool = False, auth: b return works def search_work_by_option(db: Session, limit: int, oldest_id: str, tags: list[str], auth: bool = False) -> list[Work]: - works_orm = db.query(models.Work).filter(models.Tagging.work_id == models.Work.id).filter(models.Tagging.tag_id == models.Tag.id).filter(models.Tag.id.in_(tags)).group_by(models.Work.id) + works_orm = db.query(models.Work).filter(models.Tagging.work_id == models.Work.id).filter(models.Tagging.tag_id == models.Tag.id).filter(models.Tag.id.in_(tags)) + works_orm = works_orm.group_by(models.Work.id).having(func.count(models.Work.id) == len(tags)) if auth: - works_orm = works_orm.filter(models.Work.visibility == models.Visibility.public) - else: works_orm = works_orm.filter(models.Work.visibility != models.Visibility.draft) + else: + works_orm = works_orm.filter(models.Work.visibility == models.Visibility.public) if oldest_id: oldest_work = db.query(models.Work).filter(models.Work.id == oldest_id).first() @@ -253,7 +255,7 @@ def search_work_by_option(db: Session, limit: int, oldest_id: str, tags: list[st raise HTTPException(status_code=400, detail='this oldest_id is invalid') works_orm = works_orm.filter(models.Work.created_at < oldest_work.created_at) - works_orm = works_orm.limit(limit).all() + works_orm = works_orm.order_by(desc(models.Work.created_at)).limit(limit).all() works = list(map(Work.from_orm, works_orm)) diff --git a/routers/works/__init__.py b/routers/works/__init__.py index 4156d7e..9835189 100644 --- a/routers/works/__init__.py +++ b/routers/works/__init__.py @@ -23,6 +23,12 @@ async def get_works(limit: int = 30, oldest_id: str = None, db: Session = Depend works = get_works_by_limit(db, limit, oldest_id, auth=auth) return works +@work_router.get('/search', response_model=list[Work]) +async def search_work(payload: SearchOption, limit: int = 30, oldest_id: str = None, user: User = Depends(GetCurrentUser(auto_error=False)), db: Session = Depends(get_db)): + auth = user is not None + works = search_work_by_option(db, limit, oldest_id, payload.tags, auth) + return works + @work_router.get('/{work_id}', response_model=Work) async def get_work(work_id: str, db: Session = Depends(get_db), user: User = Depends(GetCurrentUser(auto_error=False))): auth = user is not None @@ -40,9 +46,3 @@ async def put_work(work_id: str, payload: PostWork, db: Session = Depends(get_db async def delete_work(work_id: str, db: Session = Depends(get_db)): result = delete_work_by_id(db, work_id) return result - -@work_router.get('/search', response_model=list[Work]) -async def search_work(payload: SearchOption, limit: int = 30, oldest_id: str = None, user: User = Depends(GetCurrentUser(auto_error=False)), db: Session = Depends(get_db)): - auth = user is not None - works = search_work_by_option(db, limit, oldest_id, payload.tags, auth) - return works diff --git a/tests/fixtures.py b/tests/fixtures.py index 0ad650a..890c8fd 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,31 +1,26 @@ -from fastapi.datastructures import UploadFile +import os +import json import sqlalchemy -import pytest -from sqlalchemy.orm import sessionmaker import sqlalchemy_utils +import pytest +from datetime import timedelta +from typing import Callable, List, Optional from fastapi.testclient import TestClient -from pytest import fixture +from fastapi.datastructures import UploadFile +from sqlalchemy.orm import sessionmaker from sqlalchemy.orm.session import Session -from sqlalchemy_utils.view import refresh_materialized_view -from cruds.works import set_work - -from db.models import User -from schemas.url_info import BaseUrlInfo -from schemas.user import User as UserSchema, Token as TokenSchema, TokenResponse as TokenResponseSchema -from schemas.work import Work as WorkSchema -from schemas.asset import Asset as AssetSchema -from schemas.tag import GetTag as TagSchema -from db import Base, get_db from main import app -import os -from datetime import timedelta -from cruds.users import auth -# from cruds.works import create_work +from db import Base, get_db +from db.models import User, Visibility from cruds.assets import create_asset -from typing import Callable, List, Optional from cruds.tags.tag import create_tag - -import json +from cruds.users import auth +from cruds.works import set_work +from schemas.asset import Asset as AssetSchema +from schemas.tag import GetTag as TagSchema +from schemas.url_info import BaseUrlInfo +from schemas.user import User as UserSchema, TokenResponse as TokenResponseSchema +from schemas.work import Work as WorkSchema DATABASE = 'postgresql' USER = os.environ.get('POSTGRES_USER') @@ -51,7 +46,6 @@ def use_test_db_fixture(): get_db関数をテストDBで上書きする """ if not sqlalchemy_utils.database_exists(DATABASE_URL): - print('[INFO] CREATE DATABASE') sqlalchemy_utils.create_database(DATABASE_URL) # Reset test tables @@ -170,7 +164,7 @@ def work_for_test( session_for_test: Session = session_for_test, title: str = 'WorkTitleForTest', description: str = 'this work is test', - visibility: str = 'public', + visibility: str = Visibility.public, exist_thumbnail: bool = False, asset_types: List[str] = ['image'], urls: List[BaseUrlInfo] = [], @@ -244,4 +238,19 @@ def tag_for_test( Create test tag """ c = create_tag(session_for_test, name, color) - return c \ No newline at end of file + return c + +@pytest.fixture +def tag_factory_for_test( + session_for_test: Session, +) -> Callable[[str, str], TagSchema]: + def tag_for_test( + name: str = "test_tag", + color: str = "#FFFFFF" + ) -> TagSchema: + """ + Create test tag + """ + c = create_tag(session_for_test, name, color) + return c + return tag_for_test diff --git a/tests/test_work.py b/tests/test_work.py index 4e36157..09190c0 100644 --- a/tests/test_work.py +++ b/tests/test_work.py @@ -1,6 +1,8 @@ import json import pytest -from .fixtures import client, use_test_db_fixture, session_for_test, user_factory_for_test, user_token_factory_for_test, asset_factory_for_test, user_for_test, tag_for_test, work_factory_for_test + +from db.models import Visibility +from .fixtures import client, use_test_db_fixture, session_for_test, user_factory_for_test, user_token_factory_for_test, asset_factory_for_test, user_for_test, tag_for_test, work_factory_for_test, tag_factory_for_test @pytest.mark.usefixtures('use_test_db_fixture') class TestWork: @@ -360,7 +362,7 @@ def test_post_work_about_url(use_test_db_fixture, user_token_factory_for_test, a # 複数のWorkを取得する # """ - # def test_get_works_pagenation(use_test_db_fixture): + # def test_get_works_pagination(use_test_db_fixture): # """ # Work取得のページネーションを確認する # """ @@ -428,4 +430,140 @@ def test_get_not_exist_user_works(use_test_db_fixture, user_token_factory_for_te "Authorization": f"Bearer { token.access_token }" }) - assert res.status_code == 404, '作品の取得に失敗する' \ No newline at end of file + assert res.status_code == 404, '作品の取得に失敗する' + + def test_search_works_by_tag(use_test_db_fixture, user_token_factory_for_test, tag_factory_for_test, work_factory_for_test): + """ + タグから作品を絞り込んで検索する + """ + token = user_token_factory_for_test() + + test_tag1 = tag_factory_for_test(name='testtag1', color='#ff3030') + test_tag2 = tag_factory_for_test(name='testtag2', color='#30ff30') + test_tag3 = tag_factory_for_test(name='testtag3', color='#3030ff') + test_tag4 = tag_factory_for_test(name='testtag4', color='#44e099') + test_tag5 = tag_factory_for_test(name='testtag5', color='#30dda0') + + work1 = work_factory_for_test(title='testwork1', visibility=Visibility.public, tags_id=[test_tag1.id, test_tag2.id, test_tag5.id]) + work2 = work_factory_for_test(title='testwork2', visibility=Visibility.public, tags_id=[test_tag2.id, test_tag3.id, test_tag5.id]) + work3 = work_factory_for_test(title='testwork3', visibility=Visibility.private, tags_id=[test_tag1.id, test_tag4.id, test_tag5.id]) + + res = client.get('/api/v1/works/search', headers={ + 'Authorization': f'Bearer { token.access_token }' + }, json={ + 'tags': [ + test_tag1.id + ] + } + ) + + assert res.status_code == 200 + res_json = res.json() + assert len(res_json) == 2 + assert res_json[0].get('title') == work3.title + assert res_json[1].get('title') == work1.title + + res = client.get('/api/v1/works/search', headers={ + 'Authorization': f'Bearer { token.access_token }' + }, json={ + 'tags': [ + test_tag5.id + ] + } + ) + + assert res.status_code == 200 + res_json = res.json() + assert len(res_json) == 3 + assert res_json[0].get('title') == work3.title + assert res_json[1].get('title') == work2.title + assert res_json[2].get('title') == work1.title + + def test_search_works_by_some_tag(use_test_db_fixture, user_token_factory_for_test, tag_factory_for_test, work_factory_for_test): + """ + 複数のタグから作品を絞り込んで検索する + """ + token = user_token_factory_for_test() + + test_tag1 = tag_factory_for_test(name='testtag1', color='#ff3030') + test_tag2 = tag_factory_for_test(name='testtag2', color='#30ff30') + test_tag3 = tag_factory_for_test(name='testtag3', color='#3030ff') + test_tag4 = tag_factory_for_test(name='testtag4', color='#44e099') + test_tag5 = tag_factory_for_test(name='testtag5', color='#30dda0') + + work1 = work_factory_for_test(title='testwork1', visibility=Visibility.public, tags_id=[test_tag1.id, test_tag2.id, test_tag5.id]) + work2 = work_factory_for_test(title='testwork2', visibility=Visibility.public, tags_id=[test_tag2.id, test_tag3.id, test_tag5.id]) + work3 = work_factory_for_test(title='testwork3', visibility=Visibility.private, tags_id=[test_tag1.id, test_tag4.id, test_tag5.id]) + + res = client.get('/api/v1/works/search', headers={ + 'Authorization': f'Bearer { token.access_token }' + }, json={ + 'tags': [ + test_tag1.id, + test_tag4.id + ] + } + ) + + assert res.status_code == 200 + res_json = res.json() + assert len(res_json) == 1 + assert res_json[0].get('title') == work3.title + + def test_search_works_by_tag_without_auth(use_test_db_fixture, tag_factory_for_test, work_factory_for_test): + """ + 認証無しでタグから作品を絞り込んで検索する + """ + test_tag1 = tag_factory_for_test(name='testtag1', color='#ff3030') + test_tag2 = tag_factory_for_test(name='testtag2', color='#30ff30') + test_tag3 = tag_factory_for_test(name='testtag3', color='#3030ff') + test_tag4 = tag_factory_for_test(name='testtag4', color='#44e099') + test_tag5 = tag_factory_for_test(name='testtag5', color='#30dda0') + + work1 = work_factory_for_test(title='testwork1', visibility=Visibility.public, tags_id=[test_tag1.id, test_tag2.id, test_tag5.id]) + work2 = work_factory_for_test(title='testwork2', visibility=Visibility.public, tags_id=[test_tag2.id, test_tag3.id, test_tag5.id]) + work3 = work_factory_for_test(title='testwork3', visibility=Visibility.private, tags_id=[test_tag1.id, test_tag4.id, test_tag5.id]) + + res = client.get('/api/v1/works/search', json={ + 'tags': [ + test_tag1.id + ] + } + ) + + assert res.status_code == 200 + res_json = res.json() + assert len(res_json) == 1 + assert res_json[0].get('title') == work1.title + + def test_search_works_by_strict_tag(use_test_db_fixture, user_token_factory_for_test, tag_factory_for_test, work_factory_for_test): + """ + 存在しない条件でタグから作品を絞り込んで検索する + """ + token = user_token_factory_for_test() + + test_tag1 = tag_factory_for_test(name='testtag1', color='#ff3030') + test_tag2 = tag_factory_for_test(name='testtag2', color='#30ff30') + test_tag3 = tag_factory_for_test(name='testtag3', color='#3030ff') + test_tag4 = tag_factory_for_test(name='testtag4', color='#44e099') + test_tag5 = tag_factory_for_test(name='testtag5', color='#30dda0') + + work1 = work_factory_for_test(title='testwork1', visibility=Visibility.public, tags_id=[test_tag1.id, test_tag2.id, test_tag5.id]) + work2 = work_factory_for_test(title='testwork2', visibility=Visibility.public, tags_id=[test_tag2.id, test_tag3.id, test_tag5.id]) + work3 = work_factory_for_test(title='testwork3', visibility=Visibility.private, tags_id=[test_tag1.id, test_tag4.id, test_tag5.id]) + + res = client.get('/api/v1/works/search', headers={ + 'Authorization': f'Bearer { token.access_token }' + }, json={ + 'tags': [ + test_tag1.id, + test_tag2.id, + test_tag4.id + ] + } + ) + + assert res.status_code == 200 + res_json = res.json() + assert len(res_json) == 0 + assert res_json == []