Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
aleks-iv committed Oct 30, 2024
1 parent 87a7a15 commit 2638155
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 29 deletions.
70 changes: 41 additions & 29 deletions ckanext/relationship/model/relationship.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from sqlalchemy import Column, Text, or_

from ckan import logic, model
Expand All @@ -22,59 +24,66 @@ class Relationship(Base):
def __repr__(self):
return (
"Relationship("
f"id={self.id!r},"
f"subject_id={self.subject_id!r},"
f"object_id={self.object_id!r},"
f"relation_type={self.relation_type!r},"
")"
f"id={self.id!r}, "
f"subject_id={self.subject_id!r}, "
f"object_id={self.object_id!r}, "
f"relation_type={self.relation_type!r})"
)

def as_dict(self):
id = self.id
subject_id = self.subject_id
object_id = self.object_id
relation_type = self.relation_type
return {
"id": id,
"subject_id": subject_id,
"object_id": object_id,
"relation_type": relation_type,
"id": self.id,
"subject_id": self.subject_id,
"object_id": self.object_id,
"relation_type": self.relation_type,
}

@classmethod
def by_object_id(cls, subject_id: str, object_id: str, relation_type: str):
subject_name = _entity_name_by_id(subject_id)
subject_identifiers = [subject_id]
if subject_name is not None:
subject_identifiers.append(subject_name)

object_name = _entity_name_by_id(object_id)
object_identifiers = [object_id]
if object_name is not None:
object_identifiers.append(object_name)

return (
model.Session.query(cls)
.filter(
or_(
cls.subject_id == subject_id,
cls.subject_id == subject_name,
),
cls.subject_id.in_(subject_identifiers),
cls.object_id.in_(object_identifiers),
cls.relation_type == relation_type,
)
.filter(or_(cls.object_id == object_id, cls.object_id == object_name))
.filter(cls.relation_type == relation_type)
.one_or_none()
)

@classmethod
def by_subject_id(
cls, subject_id: str, object_entity: str, object_type: str, relation_type: str
cls,
subject_id: str,
object_entity: str | None = None,
object_type: str | None = None,
relation_type: str | None = None,
):
subject_name = _entity_name_by_id(subject_id)
subject_identifiers = [subject_id]
if subject_name is not None:
subject_identifiers.append(subject_name)

q = model.Session.query(cls).filter(
or_(cls.subject_id == subject_id, cls.subject_id == subject_name),
cls.subject_id.in_(subject_identifiers),
)

if object_entity:
object_class = logic.model_name_to_class(model, object_entity)
q = q.filter(
q = q.join(
object_class,
or_(
object_class.id == cls.object_id,
object_class.name == cls.object_id,
cls.object_id == object_class.id,
cls.object_id == object_class.name,
),
)

Expand All @@ -84,23 +93,26 @@ def by_subject_id(
if relation_type:
q = q.filter(cls.relation_type == relation_type)

return q.distinct().all()
return q.all()


def _entity_name_by_id(entity_id: str):
"""Returns entity (package or organization or group) name by its id."""
def _entity_name_by_id(entity_id: str) -> str | None:
"""Returns the name of an entity (package or group) given its ID."""
if not entity_id:
return None

pkg = (
model.Session.query(model.Package)
.filter(model.Package.id == entity_id)
.one_or_none()
)
if pkg:
return pkg.name

group = (
model.Session.query(model.Group)
.filter(model.Group.id == entity_id)
.one_or_none()
)
if pkg:
if group:
return group.name
return None
84 changes: 84 additions & 0 deletions ckanext/relationship/tests/logic/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,90 @@ def test_relation_is_added_to_db(self):
assert relation_reverse.object_id == subject_id
assert relation_reverse.relation_type == relation_type

def test_creation_by_name(self):
"""We can create a relation by name instead of ID.
Should be revised in the future to avoid this kind of behavior.
"""
subject_dataset = factories.Dataset()
object_dataset = factories.Dataset()

subject_id = subject_dataset["name"]
object_id = object_dataset["name"]
relation_type = "related_to"

result = call_action(
"relationship_relation_create",
{"ignore_auth": True},
subject_id=subject_id,
object_id=object_id,
relation_type=relation_type,
)

assert result[0]["subject_id"] == subject_id
assert result[0]["object_id"] == object_id
assert result[0]["relation_type"] == relation_type

assert result[1]["subject_id"] == object_id
assert result[1]["object_id"] == subject_id
assert result[1]["relation_type"] == relation_type

def test_get_by_id_relation_created_by_name(self):
"""We can get a relation by ID if it was created by name."""
subject_dataset = factories.Dataset()
object_dataset = factories.Dataset()

subject_id = subject_dataset["id"]

subject_name = subject_dataset["name"]
object_name = object_dataset["name"]
relation_type = "related_to"

call_action(
"relationship_relation_create",
{"ignore_auth": True},
subject_id=subject_name,
object_id=object_name,
relation_type=relation_type,
)

result = call_action(
"relationship_relations_list",
{"ignore_auth": True},
subject_id=subject_id,
)

assert result[0]["subject_id"] == subject_name
assert result[0]["object_id"] == object_name
assert result[0]["relation_type"] == relation_type

def test_get_by_name_relation_created_by_id(self):
"""We cannot get a relation by name if it was created by ID."""
subject_dataset = factories.Dataset()
object_dataset = factories.Dataset()

subject_id = subject_dataset["id"]
object_id = object_dataset["id"]

subject_name = subject_dataset["name"]
relation_type = "related_to"

call_action(
"relationship_relation_create",
{"ignore_auth": True},
subject_id=subject_id,
object_id=object_id,
relation_type=relation_type,
)

result = call_action(
"relationship_relations_list",
{"ignore_auth": True},
subject_id=subject_name,
)

assert result == []

@pytest.mark.parametrize(
"relation_type",
[
Expand Down

0 comments on commit 2638155

Please sign in to comment.