Skip to content

Commit

Permalink
[add] テストケースの追加とそれに伴った修正
Browse files Browse the repository at this point in the history
  • Loading branch information
PigeonsHouse committed Jul 5, 2022
1 parent 5841f38 commit f8d0107
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 37 deletions.
10 changes: 6 additions & 4 deletions cruds/works/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List
from fastapi import HTTPException
from sqlalchemy import desc, func
from cruds.assets import delete_asset_by_id
from cruds.url_infos import create_url_info, delete_url_info
from db import models
Expand Down Expand Up @@ -240,20 +241,21 @@ def get_works_by_user_id(db: Session, user_id: str, at_me: bool = False, auth: b
return works

def search_work_by_option(db: Session, limit: int, oldest_id: str, tags: list[str], auth: bool = False) -> list[Work]:
works_orm = db.query(models.Work).filter(models.Tagging.work_id == models.Work.id).filter(models.Tagging.tag_id == models.Tag.id).filter(models.Tag.id.in_(tags)).group_by(models.Work.id)
works_orm = db.query(models.Work).filter(models.Tagging.work_id == models.Work.id).filter(models.Tagging.tag_id == models.Tag.id).filter(models.Tag.id.in_(tags))
works_orm = works_orm.group_by(models.Work.id).having(func.count(models.Work.id) == len(tags))

if auth:
works_orm = works_orm.filter(models.Work.visibility == models.Visibility.public)
else:
works_orm = works_orm.filter(models.Work.visibility != models.Visibility.draft)
else:
works_orm = works_orm.filter(models.Work.visibility == models.Visibility.public)

if oldest_id:
oldest_work = db.query(models.Work).filter(models.Work.id == oldest_id).first()
if oldest_work is None:
raise HTTPException(status_code=400, detail='this oldest_id is invalid')
works_orm = works_orm.filter(models.Work.created_at < oldest_work.created_at)

works_orm = works_orm.limit(limit).all()
works_orm = works_orm.order_by(desc(models.Work.created_at)).limit(limit).all()

works = list(map(Work.from_orm, works_orm))

Expand Down
12 changes: 6 additions & 6 deletions routers/works/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ async def get_works(limit: int = 30, oldest_id: str = None, db: Session = Depend
works = get_works_by_limit(db, limit, oldest_id, auth=auth)
return works

@work_router.get('/search', response_model=list[Work])
async def search_work(payload: SearchOption, limit: int = 30, oldest_id: str = None, user: User = Depends(GetCurrentUser(auto_error=False)), db: Session = Depends(get_db)):
auth = user is not None
works = search_work_by_option(db, limit, oldest_id, payload.tags, auth)
return works

@work_router.get('/{work_id}', response_model=Work)
async def get_work(work_id: str, db: Session = Depends(get_db), user: User = Depends(GetCurrentUser(auto_error=False))):
auth = user is not None
Expand All @@ -40,9 +46,3 @@ async def put_work(work_id: str, payload: PostWork, db: Session = Depends(get_db
async def delete_work(work_id: str, db: Session = Depends(get_db)):
result = delete_work_by_id(db, work_id)
return result

@work_router.get('/search', response_model=list[Work])
async def search_work(payload: SearchOption, limit: int = 30, oldest_id: str = None, user: User = Depends(GetCurrentUser(auto_error=False)), db: Session = Depends(get_db)):
auth = user is not None
works = search_work_by_option(db, limit, oldest_id, payload.tags, auth)
return works
57 changes: 33 additions & 24 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,26 @@
from fastapi.datastructures import UploadFile
import os
import json
import sqlalchemy
import pytest
from sqlalchemy.orm import sessionmaker
import sqlalchemy_utils
import pytest
from datetime import timedelta
from typing import Callable, List, Optional
from fastapi.testclient import TestClient
from pytest import fixture
from fastapi.datastructures import UploadFile
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session
from sqlalchemy_utils.view import refresh_materialized_view
from cruds.works import set_work

from db.models import User
from schemas.url_info import BaseUrlInfo
from schemas.user import User as UserSchema, Token as TokenSchema, TokenResponse as TokenResponseSchema
from schemas.work import Work as WorkSchema
from schemas.asset import Asset as AssetSchema
from schemas.tag import GetTag as TagSchema
from db import Base, get_db
from main import app
import os
from datetime import timedelta
from cruds.users import auth
# from cruds.works import create_work
from db import Base, get_db
from db.models import User, Visibility
from cruds.assets import create_asset
from typing import Callable, List, Optional
from cruds.tags.tag import create_tag

import json
from cruds.users import auth
from cruds.works import set_work
from schemas.asset import Asset as AssetSchema
from schemas.tag import GetTag as TagSchema
from schemas.url_info import BaseUrlInfo
from schemas.user import User as UserSchema, TokenResponse as TokenResponseSchema
from schemas.work import Work as WorkSchema

DATABASE = 'postgresql'
USER = os.environ.get('POSTGRES_USER')
Expand All @@ -51,7 +46,6 @@ def use_test_db_fixture():
get_db関数をテストDBで上書きする
"""
if not sqlalchemy_utils.database_exists(DATABASE_URL):
print('[INFO] CREATE DATABASE')
sqlalchemy_utils.create_database(DATABASE_URL)

# Reset test tables
Expand Down Expand Up @@ -170,7 +164,7 @@ def work_for_test(
session_for_test: Session = session_for_test,
title: str = 'WorkTitleForTest',
description: str = 'this work is test',
visibility: str = 'public',
visibility: str = Visibility.public,
exist_thumbnail: bool = False,
asset_types: List[str] = ['image'],
urls: List[BaseUrlInfo] = [],
Expand Down Expand Up @@ -244,4 +238,19 @@ def tag_for_test(
Create test tag
"""
c = create_tag(session_for_test, name, color)
return c
return c

@pytest.fixture
def tag_factory_for_test(
session_for_test: Session,
) -> Callable[[str, str], TagSchema]:
def tag_for_test(
name: str = "test_tag",
color: str = "#FFFFFF"
) -> TagSchema:
"""
Create test tag
"""
c = create_tag(session_for_test, name, color)
return c
return tag_for_test
144 changes: 141 additions & 3 deletions tests/test_work.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import pytest
from .fixtures import client, use_test_db_fixture, session_for_test, user_factory_for_test, user_token_factory_for_test, asset_factory_for_test, user_for_test, tag_for_test, work_factory_for_test

from db.models import Visibility
from .fixtures import client, use_test_db_fixture, session_for_test, user_factory_for_test, user_token_factory_for_test, asset_factory_for_test, user_for_test, tag_for_test, work_factory_for_test, tag_factory_for_test

@pytest.mark.usefixtures('use_test_db_fixture')
class TestWork:
Expand Down Expand Up @@ -360,7 +362,7 @@ def test_post_work_about_url(use_test_db_fixture, user_token_factory_for_test, a
# 複数のWorkを取得する
# """

# def test_get_works_pagenation(use_test_db_fixture):
# def test_get_works_pagination(use_test_db_fixture):
# """
# Work取得のページネーションを確認する
# """
Expand Down Expand Up @@ -428,4 +430,140 @@ def test_get_not_exist_user_works(use_test_db_fixture, user_token_factory_for_te
"Authorization": f"Bearer { token.access_token }"
})

assert res.status_code == 404, '作品の取得に失敗する'
assert res.status_code == 404, '作品の取得に失敗する'

def test_search_works_by_tag(use_test_db_fixture, user_token_factory_for_test, tag_factory_for_test, work_factory_for_test):
"""
タグから作品を絞り込んで検索する
"""
token = user_token_factory_for_test()

test_tag1 = tag_factory_for_test(name='testtag1', color='#ff3030')
test_tag2 = tag_factory_for_test(name='testtag2', color='#30ff30')
test_tag3 = tag_factory_for_test(name='testtag3', color='#3030ff')
test_tag4 = tag_factory_for_test(name='testtag4', color='#44e099')
test_tag5 = tag_factory_for_test(name='testtag5', color='#30dda0')

work1 = work_factory_for_test(title='testwork1', visibility=Visibility.public, tags_id=[test_tag1.id, test_tag2.id, test_tag5.id])
work2 = work_factory_for_test(title='testwork2', visibility=Visibility.public, tags_id=[test_tag2.id, test_tag3.id, test_tag5.id])
work3 = work_factory_for_test(title='testwork3', visibility=Visibility.private, tags_id=[test_tag1.id, test_tag4.id, test_tag5.id])

res = client.get('/api/v1/works/search', headers={
'Authorization': f'Bearer { token.access_token }'
}, json={
'tags': [
test_tag1.id
]
}
)

assert res.status_code == 200
res_json = res.json()
assert len(res_json) == 2
assert res_json[0].get('title') == work3.title
assert res_json[1].get('title') == work1.title

res = client.get('/api/v1/works/search', headers={
'Authorization': f'Bearer { token.access_token }'
}, json={
'tags': [
test_tag5.id
]
}
)

assert res.status_code == 200
res_json = res.json()
assert len(res_json) == 3
assert res_json[0].get('title') == work3.title
assert res_json[1].get('title') == work2.title
assert res_json[2].get('title') == work1.title

def test_search_works_by_some_tag(use_test_db_fixture, user_token_factory_for_test, tag_factory_for_test, work_factory_for_test):
"""
複数のタグから作品を絞り込んで検索する
"""
token = user_token_factory_for_test()

test_tag1 = tag_factory_for_test(name='testtag1', color='#ff3030')
test_tag2 = tag_factory_for_test(name='testtag2', color='#30ff30')
test_tag3 = tag_factory_for_test(name='testtag3', color='#3030ff')
test_tag4 = tag_factory_for_test(name='testtag4', color='#44e099')
test_tag5 = tag_factory_for_test(name='testtag5', color='#30dda0')

work1 = work_factory_for_test(title='testwork1', visibility=Visibility.public, tags_id=[test_tag1.id, test_tag2.id, test_tag5.id])
work2 = work_factory_for_test(title='testwork2', visibility=Visibility.public, tags_id=[test_tag2.id, test_tag3.id, test_tag5.id])
work3 = work_factory_for_test(title='testwork3', visibility=Visibility.private, tags_id=[test_tag1.id, test_tag4.id, test_tag5.id])

res = client.get('/api/v1/works/search', headers={
'Authorization': f'Bearer { token.access_token }'
}, json={
'tags': [
test_tag1.id,
test_tag4.id
]
}
)

assert res.status_code == 200
res_json = res.json()
assert len(res_json) == 1
assert res_json[0].get('title') == work3.title

def test_search_works_by_tag_without_auth(use_test_db_fixture, tag_factory_for_test, work_factory_for_test):
"""
認証無しでタグから作品を絞り込んで検索する
"""
test_tag1 = tag_factory_for_test(name='testtag1', color='#ff3030')
test_tag2 = tag_factory_for_test(name='testtag2', color='#30ff30')
test_tag3 = tag_factory_for_test(name='testtag3', color='#3030ff')
test_tag4 = tag_factory_for_test(name='testtag4', color='#44e099')
test_tag5 = tag_factory_for_test(name='testtag5', color='#30dda0')

work1 = work_factory_for_test(title='testwork1', visibility=Visibility.public, tags_id=[test_tag1.id, test_tag2.id, test_tag5.id])
work2 = work_factory_for_test(title='testwork2', visibility=Visibility.public, tags_id=[test_tag2.id, test_tag3.id, test_tag5.id])
work3 = work_factory_for_test(title='testwork3', visibility=Visibility.private, tags_id=[test_tag1.id, test_tag4.id, test_tag5.id])

res = client.get('/api/v1/works/search', json={
'tags': [
test_tag1.id
]
}
)

assert res.status_code == 200
res_json = res.json()
assert len(res_json) == 1
assert res_json[0].get('title') == work1.title

def test_search_works_by_strict_tag(use_test_db_fixture, user_token_factory_for_test, tag_factory_for_test, work_factory_for_test):
"""
存在しない条件でタグから作品を絞り込んで検索する
"""
token = user_token_factory_for_test()

test_tag1 = tag_factory_for_test(name='testtag1', color='#ff3030')
test_tag2 = tag_factory_for_test(name='testtag2', color='#30ff30')
test_tag3 = tag_factory_for_test(name='testtag3', color='#3030ff')
test_tag4 = tag_factory_for_test(name='testtag4', color='#44e099')
test_tag5 = tag_factory_for_test(name='testtag5', color='#30dda0')

work1 = work_factory_for_test(title='testwork1', visibility=Visibility.public, tags_id=[test_tag1.id, test_tag2.id, test_tag5.id])
work2 = work_factory_for_test(title='testwork2', visibility=Visibility.public, tags_id=[test_tag2.id, test_tag3.id, test_tag5.id])
work3 = work_factory_for_test(title='testwork3', visibility=Visibility.private, tags_id=[test_tag1.id, test_tag4.id, test_tag5.id])

res = client.get('/api/v1/works/search', headers={
'Authorization': f'Bearer { token.access_token }'
}, json={
'tags': [
test_tag1.id,
test_tag2.id,
test_tag4.id
]
}
)

assert res.status_code == 200
res_json = res.json()
assert len(res_json) == 0
assert res_json == []

0 comments on commit f8d0107

Please sign in to comment.