Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/search work by tag #99

Merged
merged 11 commits into from
Aug 3, 2022
3 changes: 0 additions & 3 deletions cruds/tags/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ def create_tag(db: Session, name: str, color_code: str) -> GetTag:
db.commit()
db.refresh(tag_orm)

print(tag_orm.name)
print(tag_orm.color)

tag = GetTag.from_orm(tag_orm)

return tag
Expand Down
1 change: 0 additions & 1 deletion cruds/url_infos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

def create_url_info(db: Session, url: str, url_type: str, work_id: str, user_id: str):
pattern = url_type_pattern.get(url_type, '')
print(pattern)
if not re.match(pattern, url):
raise HTTPException(status_code=400, detail='url pattern is invalid')
url_info_orm = models.UrlInfo(
Expand Down
65 changes: 40 additions & 25 deletions cruds/works/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List
from fastapi import HTTPException
from sqlalchemy import desc, func
from cruds.assets import delete_asset_by_id
from cruds.url_infos import create_url_info, delete_url_info
from db import models
Expand All @@ -11,22 +12,21 @@

# TODO: CASCADEを導入する

def set_work(db: Session, title: str, description: str, user_id: str,
def set_work(db: Session, title: str, description: str, user_id: str,
visibility: str, thumbnail_asset_id: str,
assets_id: List[str], urls: List[BaseUrlInfo], tags_id: List[str]) -> Work:


if title == '':
raise HTTPException(status_code=400, detail="Title is empty")

# DB書き込み
md = markdown.Markdown(extensions=['tables'])
work_orm = models.Work(
title = title,
description = description,
description_html = md.convert(description),
user_id = user_id,
visibility = visibility,
title=title,
description=description,
description_html=md.convert(description),
user_id=user_id,
visibility=visibility,
)
db.add(work_orm)
db.commit()
Expand Down Expand Up @@ -58,8 +58,8 @@ def set_work(db: Session, title: str, description: str, user_id: str,
if thumbnail is None:
raise HTTPException(status_code=400, detail='This thumbnail asset id is invalid.')
thumbnail_orm = models.Thumbnail(
work_id = work_orm.id,
asset_id = thumbnail.id
work_id=work_orm.id,
asset_id=thumbnail.id
)
db.add(thumbnail_orm)
db.commit()
Expand All @@ -76,12 +76,6 @@ def set_work(db: Session, title: str, description: str, user_id: str,

return work

def get_work_by_id(db: Session, work_id: str) -> Work:
work_orm = db.query(models.Work).get(work_id)
if work_orm == None:
return None
return Work.from_orm(work_orm)

def get_works_by_limit(db: Session, limit: int, oldest_id: str, auth: bool = False) -> List[Work]:
works_orm = db.query(models.Work).order_by(models.Work.created_at).filter(models.Work.visibility != models.Visibility.draft)
if oldest_id:
Expand All @@ -104,10 +98,10 @@ def get_work_by_id(db: Session, work_id: str, auth: bool = False) -> Work:
raise HTTPException(status_code=403, detail="This work is a private work. You need to sign in.")
return work

def replace_work(db: Session, work_id: str, title: str, description: str, user_id: str,
visibility: str, thumbnail_asset_id: str, assets_id: List[str],
def replace_work(db: Session, work_id: str, title: str, description: str, user_id: str,
visibility: str, thumbnail_asset_id: str, assets_id: List[str],
urls: List[BaseUrlInfo], tags_id: List[str]) -> Work:

work_orm = db.query(models.Work).get(work_id)

# 自分のWorkでなければ弾く
Expand All @@ -134,7 +128,7 @@ def replace_work(db: Session, work_id: str, title: str, description: str, user_i
asset_orm.work_id = None
db.commit()
db.refresh(asset_orm)

# 使われなくなったassetの削除
old_asset_ids = [asset_orm.id for asset_orm in assets_orm]
old_thumbnail_orm = db.query(models.Thumbnail).filter(models.Thumbnail.work_id == work_id).first()
Expand All @@ -146,7 +140,7 @@ def replace_work(db: Session, work_id: str, title: str, description: str, user_i
delete_asset_ids = set(old_asset_ids) - set(new_asset_ids)
for delete_asset_id in delete_asset_ids:
delete_asset_by_id(db, delete_asset_id)

# assetのwork_idの更新
for asset_id in assets_id:
asset_orm = db.query(models.Asset).get(asset_id)
Expand All @@ -158,7 +152,7 @@ def replace_work(db: Session, work_id: str, title: str, description: str, user_i
urls_orm = db.query(models.UrlInfo).filter(models.UrlInfo.work_id == work_id).all()
for url_orm in urls_orm:
delete_url_info(db, url_orm.id)

# url_infoテーブルへのインスタンスの作成
for url in urls:
create_url_info(db, url.get('url'), url.get('url_type', 'other'), work_id, user_id)
Expand Down Expand Up @@ -187,8 +181,8 @@ def replace_work(db: Session, work_id: str, title: str, description: str, user_i
if thumbnail is None:
raise HTTPException(status_code=400, detail='This thumbnail asset id is invalid.')
new_thumbnail_orm = models.Thumbnail(
work_id = work_id,
asset_id = thumbnail_asset_id
work_id=work_id,
asset_id=thumbnail_asset_id
)
db.add(new_thumbnail_orm)
db.commit()
Expand Down Expand Up @@ -233,7 +227,7 @@ def get_works_by_user_id(db: Session, user_id: str, at_me: bool = False, auth: b
user_orm = db.query(models.User).get(user_id)
if user_orm is None:
raise HTTPException(status_code=404, detail='this user is not exist')

works_orm = db.query(models.Work).filter(models.Work.user_id == user_id)

if at_me:
Expand All @@ -244,4 +238,25 @@ def get_works_by_user_id(db: Session, user_id: str, at_me: bool = False, auth: b
works_orm = works_orm.filter(models.Work.visibility == 'public').all()

works = list(map(Work.from_orm, works_orm))
return works
return works

def search_work_by_option(db: Session, limit: int, oldest_id: str, tags: list[str], auth: bool = False) -> list[Work]:
works_orm = db.query(models.Work).filter(models.Tagging.work_id == models.Work.id).filter(models.Tagging.tag_id == models.Tag.id).filter(models.Tag.id.in_(tags))
works_orm = works_orm.group_by(models.Work.id).having(func.count(models.Work.id) == len(tags))
Copy link
Contributor

@rkun123 rkun123 Jul 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[COMMENT]
とりあえず良さそう
taggingsテーブルのインデックスが上手に使えてない気がしているので、WorkとTagが増えたときにめっちゃ遅くなりそうな気持ちはある。

image

toybox=# EXPLAIN ANALYSE SELECT works.* FROM works, taggings, tags
WHERE taggings.work_id = works.id AND taggings.tag_id = tags.id AND tags.id IN ('0d435e98-1b63-4a8f-9433-1cfa7840e024', '51216642-ae4e-48d7-80ee-741ec76b0fd5', 'af040833-7ab9-4fff-844a-c234041f3f71') GROUP BY works.id
HAVING count(works.id) = 3;
                                                                                       QUERY PLAN                                                                                        
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 GroupAggregate  (cost=24.97..25.05 rows=1 width=1850) (actual time=0.017..0.018 rows=0 loops=1)
   Group Key: works.id
   Filter: (count(works.id) = 3)
   ->  Sort  (cost=24.97..24.98 rows=4 width=1850) (actual time=0.016..0.017 rows=0 loops=1)
         Sort Key: works.id
         Sort Method: quicksort  Memory: 25kB
         ->  Nested Loop  (cost=11.00..24.93 rows=4 width=1850) (actual time=0.005..0.006 rows=0 loops=1)
               ->  Hash Join  (cost=10.86..21.76 rows=4 width=516) (actual time=0.004..0.005 rows=0 loops=1)
                     Hash Cond: ((taggings.tag_id)::text = (tags.id)::text)
                     ->  Seq Scan on taggings  (cost=0.00..10.70 rows=70 width=1032) (actual time=0.003..0.004 rows=0 loops=1)
                     ->  Hash  (cost=10.82..10.82 rows=3 width=516) (never executed)
                           ->  Seq Scan on tags  (cost=0.00..10.82 rows=3 width=516) (never executed)
                                 Filter: ((id)::text = ANY ('{0d435e98-1b63-4a8f-9433-1cfa7840e024,51216642-ae4e-48d7-80ee-741ec76b0fd5,af040833-7ab9-4fff-844a-c234041f3f71}'::text[]))
               ->  Index Scan using works_pkey on works  (cost=0.14..0.79 rows=1 width=1850) (never executed)
                     Index Cond: ((id)::text = (taggings.work_id)::text)
 Planning Time: 0.296 ms
 Execution Time: 0.112 ms
(17 rows)


if auth:
works_orm = works_orm.filter(models.Work.visibility != models.Visibility.draft)
else:
works_orm = works_orm.filter(models.Work.visibility == models.Visibility.public)

if oldest_id:
oldest_work = db.query(models.Work).filter(models.Work.id == oldest_id).first()
if oldest_work is None:
raise HTTPException(status_code=400, detail='this oldest_id is invalid')
works_orm = works_orm.filter(models.Work.created_at < oldest_work.created_at)

works_orm = works_orm.order_by(desc(models.Work.created_at)).limit(limit).all()

works = list(map(Work.from_orm, works_orm))

return works
18 changes: 12 additions & 6 deletions routers/works/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from schemas.common import DeleteStatus
from schemas.work import PostWork, Work
from schemas.work import PostWork, SearchOption, Work
from schemas.user import User
from fastapi import APIRouter
from db import get_db
from fastapi.params import Depends
from sqlalchemy.orm import Session
from cruds.users.auth import GetCurrentUser
from cruds.works import delete_work_by_id, get_work_by_id, get_works_by_limit, replace_work, set_work
from cruds.works import delete_work_by_id, get_work_by_id, get_works_by_limit, replace_work, search_work_by_option, set_work
from typing import List

work_router = APIRouter()

@work_router.post('', response_model=Work)
async def post_work(payload: PostWork, db: Session = Depends(get_db), user: User = Depends(GetCurrentUser())):
work = set_work(db, payload.title, payload.description, user.id,
work = set_work(db, payload.title, payload.description, user.id,
payload.visibility, payload.thumbnail_asset_id, payload.assets_id, payload.urls, payload.tags_id)
return work

Expand All @@ -23,6 +23,12 @@ async def get_works(limit: int = 30, oldest_id: str = None, db: Session = Depend
works = get_works_by_limit(db, limit, oldest_id, auth=auth)
return works

@work_router.get('/search', response_model=list[Work])
async def search_work(payload: SearchOption, limit: int = 30, oldest_id: str = None, user: User = Depends(GetCurrentUser(auto_error=False)), db: Session = Depends(get_db)):
auth = user is not None
works = search_work_by_option(db, limit, oldest_id, payload.tags, auth)
return works

PigeonsHouse marked this conversation as resolved.
Show resolved Hide resolved
@work_router.get('/{work_id}', response_model=Work)
async def get_work(work_id: str, db: Session = Depends(get_db), user: User = Depends(GetCurrentUser(auto_error=False))):
auth = user is not None
Expand All @@ -31,12 +37,12 @@ async def get_work(work_id: str, db: Session = Depends(get_db), user: User = Dep

@work_router.put('/{work_id}', response_model=Work)
async def put_work(work_id: str, payload: PostWork, db: Session = Depends(get_db), user: User = Depends(GetCurrentUser())):
work = replace_work(db, work_id, payload.title, payload.description, user.id,
payload.visibility, payload.thumbnail_asset_id, payload.assets_id, payload.urls,
work = replace_work(db, work_id, payload.title, payload.description, user.id,
payload.visibility, payload.thumbnail_asset_id, payload.assets_id, payload.urls,
payload.tags_id)
return work

@work_router.delete('/{work_id}', response_model=DeleteStatus, dependencies=[Depends(GetCurrentUser())])
async def delete_work(work_id: str, db: Session = Depends(get_db)):
result = delete_work_by_id(db, work_id)
return result
return result
3 changes: 3 additions & 0 deletions schemas/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,6 @@ class Work(BaseModel):

class Config:
orm_mode = True

class SearchOption(BaseModel):
tags: list[str]
57 changes: 33 additions & 24 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,26 @@
from fastapi.datastructures import UploadFile
import os
import json
import sqlalchemy
import pytest
from sqlalchemy.orm import sessionmaker
import sqlalchemy_utils
import pytest
from datetime import timedelta
from typing import Callable, List, Optional
from fastapi.testclient import TestClient
from pytest import fixture
from fastapi.datastructures import UploadFile
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session
from sqlalchemy_utils.view import refresh_materialized_view
from cruds.works import set_work

from db.models import User
from schemas.url_info import BaseUrlInfo
from schemas.user import User as UserSchema, Token as TokenSchema, TokenResponse as TokenResponseSchema
from schemas.work import Work as WorkSchema
from schemas.asset import Asset as AssetSchema
from schemas.tag import GetTag as TagSchema
from db import Base, get_db
from main import app
import os
from datetime import timedelta
from cruds.users import auth
# from cruds.works import create_work
from db import Base, get_db
from db.models import User, Visibility
from cruds.assets import create_asset
from typing import Callable, List, Optional
from cruds.tags.tag import create_tag

import json
from cruds.users import auth
from cruds.works import set_work
from schemas.asset import Asset as AssetSchema
from schemas.tag import GetTag as TagSchema
from schemas.url_info import BaseUrlInfo
from schemas.user import User as UserSchema, TokenResponse as TokenResponseSchema
from schemas.work import Work as WorkSchema

DATABASE = 'postgresql'
USER = os.environ.get('POSTGRES_USER')
Expand All @@ -51,7 +46,6 @@ def use_test_db_fixture():
get_db関数をテストDBで上書きする
"""
if not sqlalchemy_utils.database_exists(DATABASE_URL):
print('[INFO] CREATE DATABASE')
sqlalchemy_utils.create_database(DATABASE_URL)

# Reset test tables
Expand Down Expand Up @@ -170,7 +164,7 @@ def work_for_test(
session_for_test: Session = session_for_test,
title: str = 'WorkTitleForTest',
description: str = 'this work is test',
visibility: str = 'public',
visibility: str = Visibility.public,
exist_thumbnail: bool = False,
asset_types: List[str] = ['image'],
urls: List[BaseUrlInfo] = [],
Expand Down Expand Up @@ -244,4 +238,19 @@ def tag_for_test(
Create test tag
"""
c = create_tag(session_for_test, name, color)
return c
return c

@pytest.fixture
def tag_factory_for_test(
session_for_test: Session,
) -> Callable[[str, str], TagSchema]:
def tag_for_test(
name: str = "test_tag",
color: str = "#FFFFFF"
) -> TagSchema:
"""
Create test tag
"""
c = create_tag(session_for_test, name, color)
return c
return tag_for_test
2 changes: 0 additions & 2 deletions tests/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ def test_post_asset_without_auth(use_test_db_fixture):
"asset_type": asset_type
})

print(res.request)

assert res.status_code == 403, 'Assetの投稿に失敗する'

def test_post_correct_image_asset(use_test_db_fixture, user_token_factory_for_test):
Expand Down
4 changes: 0 additions & 4 deletions tests/test_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def test_post_tag(use_test_db_fixture, user_token_factory_for_test):
assert res.status_code == 200, 'tagの作成に成功する'

res_json = res.json()
print(res_json)
assert res_json['name'] == name
assert res_json['color'] == color

Expand Down Expand Up @@ -81,7 +80,6 @@ def test_get_tag_by_tag_id(use_test_db_fixture, tag_for_test, user_token_factory

res_json = res.json()

print(res_json)
assert res_json['id'] == tag_id
assert res_json['name'] == name
assert res_json['color'] == color
Expand All @@ -106,7 +104,6 @@ def test_put_tag(use_test_db_fixture, tag_for_test, user_token_factory_for_test)
assert res.status_code == 200, 'タグの編集に成功する'

res_json = res.json()
print(res_json)
assert res_json['id'] == tag_id
assert res_json['name'] == name
assert res_json['color'] == color
Expand All @@ -127,5 +124,4 @@ def test_put_tag(use_test_db_fixture, tag_for_test, user_token_factory_for_test)
assert res.status_code == 200, 'タグの削除に成功する'

res_json = res.json()
print(res_json)
assert res_json == {'status': 'OK'}
Loading