Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/search work by tag #99

Merged
merged 11 commits into from
Aug 3, 2022
21 changes: 10 additions & 11 deletions cruds/works/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List
from fastapi import HTTPException
from sqlalchemy import desc, func
from cruds.assets import delete_asset_by_id
from cruds.url_infos import create_url_info, delete_url_info
from db import models
Expand Down Expand Up @@ -76,18 +77,16 @@ def set_work(db: Session, title: str, description: str, user_id: str,

return work

def get_work_by_id(db: Session, work_id: str) -> Work:
work_orm = db.query(models.Work).get(work_id)
if work_orm == None:
return None
return Work.from_orm(work_orm)


def get_works_by_limit(db: Session, limit: int, visibility: models.Visibility, oldest_id: str, auth: bool = False) -> List[Work]:
works_orm = db.query(models.Work).order_by(models.Work.created_at).filter(
models.Work.visibility != models.Visibility.draft)
def get_works_by_limit(db: Session, limit: int, visibility: models.Visibility, oldest_id: str, tags: str, auth: bool = False) -> List[Work]:
works_orm = db.query(models.Work).order_by(desc(models.Work.created_at)).filter(models.Work.visibility != models.Visibility.draft)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

タグで絞り込んで作品数が減った後にソートした方が処理が早い気がする

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sqlalchemyはall()やfirst()などが付くまではSQLを構成するだけで実行はされないため全て記述した順番通りに実行されるわけではないと記憶しております.
また,そもそもPostgreSQLではソートを行ったあとで絞り込みを行うという記述は実行することが出来ないため,正常に動作していることからも問題ないと考えています.
image

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

そっか
ほんまやわw

if tags:
tag_list = tags.split(',')
works_orm = works_orm.filter(models.Tagging.tag_id.in_(tag_list)).filter(models.Tagging.work_id == models.Work.id)
works_orm = works_orm.group_by(models.Work.id).having(func.count(models.Work.id) == len(tag_list))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

これって何してるん?
いまいち分かってなくて、、、

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tag_list = tags.split(',')

','で区切ったtag_idを配列にして,

works_orm = works_orm.filter(models.Tagging.tag_id.in_(tag_list)).filter(models.Tagging.work_id == models.Work.id)

Taggingテーブルからそのタグが付いたデータを抽出して,そのデータのwork_idと一致する作品情報をWorkテーブルから持ってきて,

works_orm = works_orm.group_by(models.Work.id).having(func.count(models.Work.id) == len(tag_list))

重複してる作品をGROUPで一纏めにして,重複したデータ数がタグの数と一致したもので更に絞り込んでます.
1つの作品と検索してるタグそれぞれのTaggingデータがあればちょうどタグの数の分だけ重複するから最後のhavingがついてます.

SQLっぽく書くと,

SELECT works.* FROM works, tagging
WHERE tagging.tag_id in ("tagid1", ..., "tagid2") AND tagging.work_id = works.id
GROUP BY works.id HAVING count(works.id) = 4; -- 4はタグの数

って感じです.

Copy link
Member

@Simo-C3 Simo-C3 Aug 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

works_orm = works_orm.group_by(models.Work.id).having(func.count(models.Work.id) == len(tag_list))
重複してる作品をGROUPで一纏めにして,重複したデータ数がタグの数と一致したもので更に絞り込んでます.
1つの作品と検索してるタグそれぞれのTaggingデータがあればちょうどタグの数の分だけ重複するから最後のhavingがついてます.

なるほど、理解

if oldest_id:
limit_work = db.query(models.Work).filter(models.Work.id == oldest_id).first()
if limit_work is None:
raise HTTPException(status_code=400, detail='this oldest_id is invalid')
limit_created_at = limit_work.created_at
works_orm = works_orm.filter(models.Work.created_at > limit_created_at)
if not auth:
Expand Down Expand Up @@ -249,4 +248,4 @@ def get_works_by_user_id(db: Session, user_id: str, at_me: bool = False, auth: b
works_orm = works_orm.filter(models.Work.visibility == 'public').all()

works = list(map(Work.from_orm, works_orm))
return works
return works
6 changes: 3 additions & 3 deletions routers/works/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from schemas.common import DeleteStatus
from schemas.work import PostWork, Work
from schemas.work import PostWork, SearchOption, Work
from schemas.user import User
from fastapi import APIRouter
from db import get_db,models
Expand All @@ -18,9 +18,9 @@ async def post_work(payload: PostWork, db: Session = Depends(get_db), user: User
return work

@work_router.get('', response_model=List[Work])
async def get_works(limit: int = 30, visibility: models.Visibility = None, oldest_id: str = None, db: Session = Depends(get_db), user: User = Depends(GetCurrentUser(auto_error=False))):
async def get_works(limit: int = 30, visibility: models.Visibility = None, oldest_id: str = None, tags: str = None, db: Session = Depends(get_db), user: User = Depends(GetCurrentUser(auto_error=False))):
auth = user is not None
works = get_works_by_limit(db, limit, visibility, oldest_id, auth=auth)
works = get_works_by_limit(db, limit, visibility, oldest_id, tags, auth=auth)
return works

@work_router.get('/{work_id}', response_model=Work)
Expand Down
3 changes: 3 additions & 0 deletions schemas/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,6 @@ class Work(BaseModel):

class Config:
orm_mode = True

class SearchOption(BaseModel):
tags: list[str]
21 changes: 18 additions & 3 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sqlalchemy_utils.view import refresh_materialized_view
from cruds.works import set_work

from db.models import User
from db.models import User, Visibility
from schemas.url_info import BaseUrlInfo
from schemas.user import User as UserSchema, Token as TokenSchema, TokenResponse as TokenResponseSchema
from schemas.work import Work as WorkSchema
Expand Down Expand Up @@ -170,7 +170,7 @@ def work_for_test(
session_for_test: Session = session_for_test,
title: str = 'WorkTitleForTest',
description: str = 'this work is test',
visibility: str = 'public',
visibility: str = Visibility.public,
exist_thumbnail: bool = False,
asset_types: List[str] = ['image'],
urls: List[BaseUrlInfo] = [],
Expand Down Expand Up @@ -244,4 +244,19 @@ def tag_for_test(
Create test tag
"""
c = create_tag(session_for_test, name, color)
return c
return c

@pytest.fixture
def tag_factory_for_test(
session_for_test: Session,
) -> Callable[[str, str], TagSchema]:
def tag_for_test(
name: str = "test_tag",
color: str = "#FFFFFF"
) -> TagSchema:
"""
Create test tag
"""
c = create_tag(session_for_test, name, color)
return c
return tag_for_test
118 changes: 115 additions & 3 deletions tests/test_work.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import pytest
from .fixtures import client, use_test_db_fixture, session_for_test, user_factory_for_test, user_token_factory_for_test, asset_factory_for_test, user_for_test, tag_for_test, work_factory_for_test
from .fixtures import client, use_test_db_fixture, session_for_test, user_factory_for_test, user_token_factory_for_test, asset_factory_for_test, user_for_test, tag_for_test, work_factory_for_test, tag_factory_for_test
from db.models import Visibility

@pytest.mark.usefixtures('use_test_db_fixture')
class TestWork:
Expand Down Expand Up @@ -360,7 +361,7 @@ def test_post_work_about_url(use_test_db_fixture, user_token_factory_for_test, a
# 複数のWorkを取得する
# """

# def test_get_works_pagenation(use_test_db_fixture):
# def test_get_works_pagination(use_test_db_fixture):
# """
# Work取得のページネーションを確認する
# """
Expand Down Expand Up @@ -428,4 +429,115 @@ def test_get_not_exist_user_works(use_test_db_fixture, user_token_factory_for_te
"Authorization": f"Bearer { token.access_token }"
})

assert res.status_code == 404, '作品の取得に失敗する'
assert res.status_code == 404, '作品の取得に失敗する'

def test_search_works_by_tag(use_test_db_fixture, user_token_factory_for_test, tag_factory_for_test, work_factory_for_test):
"""
タグから作品を絞り込んで検索する
"""
token = user_token_factory_for_test()

test_tag1 = tag_factory_for_test(name='testtag1', color='#ff3030')
test_tag2 = tag_factory_for_test(name='testtag2', color='#30ff30')
test_tag3 = tag_factory_for_test(name='testtag3', color='#3030ff')
test_tag4 = tag_factory_for_test(name='testtag4', color='#44e099')
test_tag5 = tag_factory_for_test(name='testtag5', color='#30dda0')

work1 = work_factory_for_test(title='testwork1', visibility=Visibility.public, tags_id=[test_tag1.id, test_tag2.id, test_tag5.id])
work2 = work_factory_for_test(title='testwork2', visibility=Visibility.public, tags_id=[test_tag2.id, test_tag3.id, test_tag5.id])
work3 = work_factory_for_test(title='testwork3', visibility=Visibility.private, tags_id=[test_tag1.id, test_tag4.id, test_tag5.id])

res = client.get(f'/api/v1/works?tags={test_tag1.id}', headers={
'Authorization': f'Bearer { token.access_token }'
}
)

assert res.status_code == 200
res_json = res.json()
assert len(res_json) == 2
assert res_json[0].get('title') == work3.title
assert res_json[1].get('title') == work1.title

res = client.get(f'/api/v1/works?tags={test_tag5.id}', headers={
'Authorization': f'Bearer { token.access_token }'
}
)

assert res.status_code == 200
res_json = res.json()
assert len(res_json) == 3
assert res_json[0].get('title') == work3.title
assert res_json[1].get('title') == work2.title
assert res_json[2].get('title') == work1.title

def test_search_works_by_some_tag(use_test_db_fixture, user_token_factory_for_test, tag_factory_for_test, work_factory_for_test):
"""
複数のタグから作品を絞り込んで検索する
"""
token = user_token_factory_for_test()

test_tag1 = tag_factory_for_test(name='testtag1', color='#ff3030')
test_tag2 = tag_factory_for_test(name='testtag2', color='#30ff30')
test_tag3 = tag_factory_for_test(name='testtag3', color='#3030ff')
test_tag4 = tag_factory_for_test(name='testtag4', color='#44e099')
test_tag5 = tag_factory_for_test(name='testtag5', color='#30dda0')

work_factory_for_test(title='testwork1', visibility=Visibility.public, tags_id=[test_tag1.id, test_tag2.id, test_tag5.id])
work_factory_for_test(title='testwork2', visibility=Visibility.public, tags_id=[test_tag2.id, test_tag3.id, test_tag5.id])
work3 = work_factory_for_test(title='testwork3', visibility=Visibility.private, tags_id=[test_tag1.id, test_tag4.id, test_tag5.id])

res = client.get(f'/api/v1/works?tags={test_tag1.id},{test_tag4.id}', headers={
'Authorization': f'Bearer { token.access_token }'
}
)

assert res.status_code == 200
res_json = res.json()
assert len(res_json) == 1
assert res_json[0].get('title') == work3.title

def test_search_works_by_tag_without_auth(use_test_db_fixture, tag_factory_for_test, work_factory_for_test):
"""
認証無しでタグから作品を絞り込んで検索する
"""
test_tag1 = tag_factory_for_test(name='testtag1', color='#ff3030')
test_tag2 = tag_factory_for_test(name='testtag2', color='#30ff30')
test_tag3 = tag_factory_for_test(name='testtag3', color='#3030ff')
test_tag4 = tag_factory_for_test(name='testtag4', color='#44e099')
test_tag5 = tag_factory_for_test(name='testtag5', color='#30dda0')

work1 = work_factory_for_test(title='testwork1', visibility=Visibility.public, tags_id=[test_tag1.id, test_tag2.id, test_tag5.id])
work_factory_for_test(title='testwork2', visibility=Visibility.public, tags_id=[test_tag2.id, test_tag3.id, test_tag5.id])
work_factory_for_test(title='testwork3', visibility=Visibility.private, tags_id=[test_tag1.id, test_tag4.id, test_tag5.id])

res = client.get(f'/api/v1/works?tags={test_tag1.id}')

assert res.status_code == 200
res_json = res.json()
assert len(res_json) == 1
assert res_json[0].get('title') == work1.title

def test_search_works_by_strict_tag(use_test_db_fixture, user_token_factory_for_test, tag_factory_for_test, work_factory_for_test):
"""
存在しない条件でタグから作品を絞り込んで検索する
"""
token = user_token_factory_for_test()

test_tag1 = tag_factory_for_test(name='testtag1', color='#ff3030')
test_tag2 = tag_factory_for_test(name='testtag2', color='#30ff30')
test_tag3 = tag_factory_for_test(name='testtag3', color='#3030ff')
test_tag4 = tag_factory_for_test(name='testtag4', color='#44e099')
test_tag5 = tag_factory_for_test(name='testtag5', color='#30dda0')

work_factory_for_test(title='testwork1', visibility=Visibility.public, tags_id=[test_tag1.id, test_tag2.id, test_tag5.id])
work_factory_for_test(title='testwork2', visibility=Visibility.public, tags_id=[test_tag2.id, test_tag3.id, test_tag5.id])
work_factory_for_test(title='testwork3', visibility=Visibility.private, tags_id=[test_tag1.id, test_tag4.id, test_tag5.id])

res = client.get(f'/api/v1/works?tags={test_tag1.id},{test_tag2.id},{test_tag4.id}', headers={
'Authorization': f'Bearer { token.access_token }'
})

assert res.status_code == 200
res_json = res.json()
assert len(res_json) == 0
assert res_json == []