Skip to content

Commit

Permalink
Add class for rating field type and add field model entry
Browse files Browse the repository at this point in the history
  • Loading branch information
Nginearing authored Nov 30, 2024
1 parent 09c6ebf commit 2612536
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 47 deletions.
16 changes: 16 additions & 0 deletions tagstudio/src/core/library/alchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def __eq__(self, value) -> bool:
elif isinstance(value, (TagBoxField, DatetimeField)):
return False
raise NotImplementedError

# Add class for RatingField


class TagBoxField(BaseField):
Expand All @@ -101,6 +103,19 @@ def __eq__(self, value) -> bool:
return self.__key() == value.__key()
raise NotImplementedError

class RatingBoxField(BaseField):
__tablename__ = "rating_box_fields"

value: Mapped[int | None]

def __key(self):
return (self.type, self.value)

def __eq__(self, value) -> bool:
if isinstance(value, RatingBoxField):
return self.__key() == value.__key()
raise NotImplementedError


class DatetimeField(BaseField):
__tablename__ = "datetime_fields"
Expand Down Expand Up @@ -158,3 +173,4 @@ class _FieldID(Enum):
GUEST_ARTIST = DefaultField(id=28, name="Guest Artist", type=FieldTypeEnum.TEXT_LINE)
COMPOSER = DefaultField(id=29, name="Composer", type=FieldTypeEnum.TEXT_LINE)
COMMENTS = DefaultField(id=30, name="Comments", type=FieldTypeEnum.TEXT_LINE)
RATING = DefaultField(id=31, name="Rating", type=FieldTypeEnum.RATING_BOX)

Check failure on line 176 in tagstudio/src/core/library/alchemy/fields.py

View workflow job for this annotation

GitHub Actions / mypy

[mypy] tagstudio/src/core/library/alchemy/fields.py#L176

"type[FieldTypeEnum]" has no attribute "RATING_BOX" [attr-defined]
Raw output
/home/runner/work/TagStudioTesting/TagStudioTesting/tagstudio/src/core/library/alchemy/fields.py:176:54: error: "type[FieldTypeEnum]" has no attribute "RATING_BOX"  [attr-defined]
161 changes: 114 additions & 47 deletions tagstudio/src/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import (
Session,
aliased,
contains_eager,
make_transient,
selectinload,
Expand All @@ -36,6 +37,7 @@
TS_FOLDER_NAME,
)
from ...enums import LibraryPrefs
from ...media_types import MediaCategories
from .db import make_tables
from .enums import FieldTypeEnum, FilterState, TagColor
from .fields import (
Expand Down Expand Up @@ -394,6 +396,13 @@ def has_path_entry(self, path: Path) -> bool:
with Session(self.engine) as session:
return session.query(exists().where(Entry.path == path)).scalar()

def get_paths(self, glob: str | None = None) -> list[str]:
with Session(self.engine) as session:
paths = session.scalars(select(Entry.path)).unique()

path_strings: list[str] = list(map(lambda x: x.as_posix(), paths))
return path_strings

def search_library(
self,
search: FilterState,
Expand All @@ -409,13 +418,18 @@ def search_library(
statement = select(Entry)

if search.tag:
SubtagAlias = aliased(Tag) # noqa: N806
statement = (
statement.join(Entry.tag_box_fields)
.join(TagBoxField.tags)
.outerjoin(Tag.aliases)
.outerjoin(SubtagAlias, Tag.subtags)
.where(
or_(
Tag.name.ilike(search.tag),
Tag.shorthand.ilike(search.tag),
TagAlias.name.ilike(search.tag),
SubtagAlias.name.ilike(search.tag),
)
)
)
Expand All @@ -437,7 +451,20 @@ def search_library(
)
)
elif search.path:
statement = statement.where(Entry.path.ilike(f"%{search.path}%"))
search_str = str(search.path).replace("*", "%")
statement = statement.where(Entry.path.ilike(search_str))
elif search.filetype:
statement = statement.where(Entry.suffix.ilike(f"{search.filetype}"))
elif search.mediatype:
extensions: set[str] = set[str]()
for media_cat in MediaCategories.ALL_CATEGORIES:
if search.mediatype == media_cat.name:

Check failure on line 461 in tagstudio/src/core/library/alchemy/library.py

View workflow job for this annotation

GitHub Actions / mypy

[mypy] tagstudio/src/core/library/alchemy/library.py#L461

"MediaCategory" has no attribute "name" [attr-defined]
Raw output
/home/runner/work/TagStudioTesting/TagStudioTesting/tagstudio/src/core/library/alchemy/library.py:461:44: error: "MediaCategory" has no attribute "name"  [attr-defined]
extensions = extensions | media_cat.extensions
break
# just need to map it to search db - suffixes do not have '.'
statement = statement.where(
Entry.suffix.in_(map(lambda x: x.replace(".", ""), extensions))
)

extensions = self.prefs(LibraryPrefs.EXTENSION_LIST)
is_exclude_list = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST)
Expand Down Expand Up @@ -700,7 +727,16 @@ def add_entry_field_type(
assert isinstance(value, list)
for tag in value:
field_model.tags.add(Tag(name=tag))

elif field.type == FieldTypeEnum.DATETIME:
field_model = DatetimeField(
type_key=field.key,
value=value,
)
elif field.type == FieldTypeEnum.RATING_BOX:

Check failure on line 735 in tagstudio/src/core/library/alchemy/library.py

View workflow job for this annotation

GitHub Actions / mypy

[mypy] tagstudio/src/core/library/alchemy/library.py#L735

"type[FieldTypeEnum]" has no attribute "RATING_BOX" [attr-defined]
Raw output
/home/runner/work/TagStudioTesting/TagStudioTesting/tagstudio/src/core/library/alchemy/library.py:735:28: error: "type[FieldTypeEnum]" has no attribute "RATING_BOX"  [attr-defined]
field_model = RatingBoxField(

Check failure on line 736 in tagstudio/src/core/library/alchemy/library.py

View workflow job for this annotation

GitHub Actions / mypy

[mypy] tagstudio/src/core/library/alchemy/library.py#L736

Name "RatingBoxField" is not defined [name-defined]
Raw output
/home/runner/work/TagStudioTesting/TagStudioTesting/tagstudio/src/core/library/alchemy/library.py:736:27: error: Name "RatingBoxField" is not defined  [name-defined]
type_key=field.key,
value=value,
)
elif field.type == FieldTypeEnum.DATETIME:
field_model = DatetimeField(
type_key=field.key,
Expand Down Expand Up @@ -731,18 +767,23 @@ def add_entry_field_type(
)
return True

def add_tag(self, tag: Tag, subtag_ids: list[int] | None = None) -> Tag | None:
def add_tag(
self,
tag: Tag,
subtag_ids: set[int] | None = None,
alias_names: set[str] | None = None,
alias_ids: set[int] | None = None,
) -> Tag | None:
with Session(self.engine, expire_on_commit=False) as session:
try:
session.add(tag)
session.flush()

for subtag_id in subtag_ids or []:
subtag = TagSubtag(
parent_id=tag.id,
child_id=subtag_id,
)
session.add(subtag)
if subtag_ids is not None:
self.update_subtags(tag, subtag_ids, session)

if alias_ids is not None and alias_names is not None:
self.update_aliases(tag, alias_ids, alias_names, session)

session.commit()

Expand Down Expand Up @@ -826,75 +867,101 @@ def save_library_backup_to_disk(self) -> Path:

def get_tag(self, tag_id: int) -> Tag:
with Session(self.engine) as session:
tags_query = select(Tag).options(selectinload(Tag.subtags))
tags_query = select(Tag).options(selectinload(Tag.subtags), selectinload(Tag.aliases))
tag = session.scalar(tags_query.where(Tag.id == tag_id))

session.expunge(tag)
for subtag in tag.subtags:
session.expunge(subtag)

for alias in tag.aliases:
session.expunge(alias)

return tag

def get_alias(self, tag_id: int, alias_id: int) -> TagAlias:
with Session(self.engine) as session:
alias_query = select(TagAlias).where(TagAlias.id == alias_id, TagAlias.tag_id == tag_id)
alias = session.scalar(alias_query.where(TagAlias.id == alias_id))

return alias

def add_subtag(self, base_id: int, new_tag_id: int) -> bool:
if base_id == new_tag_id:
return False

# open session and save as parent tag
with Session(self.engine) as session:
tag = TagSubtag(
subtag = TagSubtag(
parent_id=base_id,
child_id=new_tag_id,
)

try:
session.add(tag)
session.add(subtag)
session.commit()
return True
except IntegrityError:
session.rollback()
logger.exception("IntegrityError")
return False

def update_tag(self, tag: Tag, subtag_ids: list[int]) -> None:
def remove_subtag(self, base_id: int, remove_tag_id: int) -> bool:
with Session(self.engine) as session:
p_id = base_id
r_id = remove_tag_id
remove = session.query(TagSubtag).filter_by(parent_id=p_id, child_id=r_id).one()
session.delete(remove)
session.commit()

return True

def update_tag(
self,
tag: Tag,
subtag_ids: set[int] | None = None,
alias_names: set[str] | None = None,
alias_ids: set[int] | None = None,
) -> None:
"""Edit a Tag in the Library."""
# TODO - maybe merge this with add_tag?
self.add_tag(tag, subtag_ids, alias_names, alias_ids)

if tag.shorthand:
tag.shorthand = slugify(tag.shorthand)
def update_aliases(self, tag, alias_ids, alias_names, session):
prev_aliases = session.scalars(select(TagAlias).where(TagAlias.tag_id == tag.id)).all()

if tag.aliases:
# TODO
...
for alias in prev_aliases:
if alias.id not in alias_ids or alias.name not in alias_names:
session.delete(alias)
else:
alias_ids.remove(alias.id)
alias_names.remove(alias.name)

# save the tag
with Session(self.engine) as session:
try:
# update the existing tag
session.add(tag)
session.flush()
for alias_name in alias_names:
alias = TagAlias(alias_name, tag.id)
session.add(alias)

# load all tag's subtag to know which to remove
prev_subtags = session.scalars(
select(TagSubtag).where(TagSubtag.parent_id == tag.id)
).all()
def update_subtags(self, tag, subtag_ids, session):
if tag.id in subtag_ids:
subtag_ids.remove(tag.id)

for subtag in prev_subtags:
if subtag.child_id not in subtag_ids:
session.delete(subtag)
else:
# no change, remove from list
subtag_ids.remove(subtag.child_id)
# load all tag's subtag to know which to remove
prev_subtags = session.scalars(select(TagSubtag).where(TagSubtag.parent_id == tag.id)).all()

# create remaining items
for subtag_id in subtag_ids:
# add new subtag
subtag = TagSubtag(
parent_id=tag.id,
child_id=subtag_id,
)
session.add(subtag)
for subtag in prev_subtags:
if subtag.child_id not in subtag_ids:
session.delete(subtag)
else:
# no change, remove from list
subtag_ids.remove(subtag.child_id)

session.commit()
except IntegrityError:
session.rollback()
logger.exception("IntegrityError")
# create remaining items
for subtag_id in subtag_ids:
# add new subtag
subtag = TagSubtag(
parent_id=tag.id,
child_id=subtag_id,
)
session.add(subtag)

def prefs(self, key: LibraryPrefs) -> Any:
# load given item from Preferences table
Expand Down

0 comments on commit 2612536

Please sign in to comment.