Skip to content

Commit

Permalink
Feature/sckan 284 (#262)
Browse files Browse the repository at this point in the history
* SCKAN-284 chore: Update tests

* SCKAN-284 feat: Replace placeholder data for Origin widget

* SCKAN-284 fix: Update journey calculation to work with implicit from_entities
  • Loading branch information
afonsobspinto authored Apr 11, 2024
1 parent 384e2d3 commit b551772
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 55 deletions.
8 changes: 4 additions & 4 deletions backend/composer/services/graph_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ def create_paths_from_origin(origin, vias, destinations, current_path, destinati
# This checks if the last node in the current path is one of the nodes that can lead to the current via.
# In other words, it checks if there is a valid connection
# from the last node in the current path to the current via.
if current_path[-1][0] in list(
a.name for a in current_via.from_entities.all()) or not current_via.from_entities.exists():
if (current_path[-1][0] in list(a.name for a in current_via.from_entities.all())
or (not current_via.from_entities.exists() and current_path[-1][1] == via_layer - 1)):
for entity in current_via.anatomical_entities.all():
# Build new sub-paths including the current via entity
new_sub_path = current_path + [(entity.name, via_layer)]
# Recursively call to build paths from the next vias
new_paths.extend(
create_paths_from_origin(origin, vias[idx + 1:], destinations, new_sub_path, destination_layer))
new_paths.extend(create_paths_from_origin(origin, vias[idx + 1:], destinations,
new_sub_path, destination_layer))

# Check for direct connections to destinations from the current via
for dest in destinations:
Expand Down
6 changes: 4 additions & 2 deletions backend/tests/models/test_vias.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.test import TestCase
from composer.models import ConnectivityStatement, Via, Sentence, AnatomicalEntity
from composer.models import ConnectivityStatement, Via, Sentence, AnatomicalEntity, AnatomicalEntityMeta


class ViaModelTestCase(TestCase):

Expand Down Expand Up @@ -44,7 +45,8 @@ def test_via_deletion_updates_order(self):

def test_via_order_change_clears_from_entities(self):
statement, initial_vias = self.create_initial_state()
anatomical_entity = AnatomicalEntity.objects.create(name="Test Entity")
anatomical_entity_meta = AnatomicalEntityMeta.objects.create(name="Test Entity")
anatomical_entity = AnatomicalEntity.objects.create(simple_entity=anatomical_entity_meta)

for via in initial_vias:
via.from_entities.add(anatomical_entity)
Expand Down
153 changes: 105 additions & 48 deletions backend/tests/test_journey.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
from django.db import connection
from django.test import TestCase, override_settings

from composer.models import Sentence, ConnectivityStatement, AnatomicalEntity, Via, Destination
from composer.models import Sentence, ConnectivityStatement, AnatomicalEntity, AnatomicalEntityMeta, Via, Destination
from composer.services.graph_service import generate_paths, consolidate_paths


@override_settings(DEBUG=True)
class JourneyTestCase(TestCase):

def setUp(self):
self.created_entities = {}

def create_or_get_anatomical_entity(self, name):
if name not in self.created_entities:
meta, _ = AnatomicalEntityMeta.objects.get_or_create(name=name, ontology_uri=name)
entity, _ = AnatomicalEntity.objects.get_or_create(simple_entity=meta)
self.created_entities[name] = entity
return self.created_entities[name]

def test_journey_simple_graph_with_jump(self):
# Test setup
sentence = Sentence.objects.create()
cs = ConnectivityStatement.objects.create(sentence=sentence)

origin1 = AnatomicalEntity.objects.create(name='Oa')
origin2 = AnatomicalEntity.objects.create(name='Ob')
via1 = AnatomicalEntity.objects.create(name='V1a')
destination1 = AnatomicalEntity.objects.create(name='Da')
origin1 = self.create_or_get_anatomical_entity("Oa")
origin2 = self.create_or_get_anatomical_entity("Ob")
via1 = self.create_or_get_anatomical_entity('V1a')
destination1 = self.create_or_get_anatomical_entity('Da')

cs.origins.add(origin1, origin2)

Expand All @@ -42,10 +52,7 @@ def test_journey_simple_graph_with_jump(self):
[('Ob', 0), ('Da', 2)]
]

initial_query_count = len(connection.queries)
all_paths = generate_paths(origins, vias, destinations)
new_query_count = len(connection.queries) - initial_query_count
self.assertTrue(new_query_count == 0)

all_paths.sort()
expected_paths.sort()
Expand All @@ -70,9 +77,9 @@ def test_journey_simple_direct_graph(self):
cs = ConnectivityStatement.objects.create(sentence=sentence)

# Create Anatomical Entities
origin1 = AnatomicalEntity.objects.create(name='Oa')
origin2 = AnatomicalEntity.objects.create(name='Ob')
destination1 = AnatomicalEntity.objects.create(name='Da')
origin1 = self.create_or_get_anatomical_entity("Oa")
origin2 = self.create_or_get_anatomical_entity("Ob")
destination1 = self.create_or_get_anatomical_entity('Da')

# Add origins
cs.origins.add(origin1, origin2)
Expand Down Expand Up @@ -115,10 +122,10 @@ def test_journey_simple_graph_no_jumps(self):
cs = ConnectivityStatement.objects.create(sentence=sentence)

# Create Anatomical Entities
origin1 = AnatomicalEntity.objects.create(name='Oa')
origin2 = AnatomicalEntity.objects.create(name='Ob')
via1 = AnatomicalEntity.objects.create(name='V1a')
destination1 = AnatomicalEntity.objects.create(name='Da')
origin1 = self.create_or_get_anatomical_entity("Oa")
origin2 = self.create_or_get_anatomical_entity("Ob")
via1 = self.create_or_get_anatomical_entity('V1a')
destination1 = self.create_or_get_anatomical_entity('Da')

# Add origins
cs.origins.add(origin1, origin2)
Expand Down Expand Up @@ -170,11 +177,11 @@ def test_journey_multiple_vias_no_jumps(self):
cs = ConnectivityStatement.objects.create(sentence=sentence)

# Create Anatomical Entities
origin1 = AnatomicalEntity.objects.create(name='Oa')
origin2 = AnatomicalEntity.objects.create(name='Ob')
via1 = AnatomicalEntity.objects.create(name='V1a')
via2 = AnatomicalEntity.objects.create(name='V1b')
destination1 = AnatomicalEntity.objects.create(name='Da')
origin1 = self.create_or_get_anatomical_entity("Oa")
origin2 = self.create_or_get_anatomical_entity("Ob")
via1 = self.create_or_get_anatomical_entity('V1a')
via2 = self.create_or_get_anatomical_entity('V1b')
destination1 = self.create_or_get_anatomical_entity('Da')

# Add origins
cs.origins.add(origin1, origin2)
Expand Down Expand Up @@ -232,15 +239,14 @@ def test_journey_complex_graph(self):
sentence = Sentence.objects.create()
cs = ConnectivityStatement.objects.create(sentence=sentence)

# Create Anatomical Entities
origin_a = AnatomicalEntity.objects.create(name='Oa')
origin_b = AnatomicalEntity.objects.create(name='Ob')
via1_a = AnatomicalEntity.objects.create(name='V1a')
via2_a = AnatomicalEntity.objects.create(name='V2a')
via2_b = AnatomicalEntity.objects.create(name='V2b')
via3_a = AnatomicalEntity.objects.create(name='V3a')
via4_a = AnatomicalEntity.objects.create(name='V4a')
destination_a = AnatomicalEntity.objects.create(name='Da')
origin_a = self.create_or_get_anatomical_entity("Oa")
origin_b = self.create_or_get_anatomical_entity("Ob")
via1_a = self.create_or_get_anatomical_entity('V1a')
via2_a = self.create_or_get_anatomical_entity('V2a')
via2_b = self.create_or_get_anatomical_entity('V2b')
via3_a = self.create_or_get_anatomical_entity('V3a')
via4_a = self.create_or_get_anatomical_entity('V4a')
destination_a = self.create_or_get_anatomical_entity('Da')

# Add origins
cs.origins.add(origin_a, origin_b)
Expand Down Expand Up @@ -304,18 +310,17 @@ def test_journey_complex_graph_2(self):
sentence = Sentence.objects.create()
cs = ConnectivityStatement.objects.create(sentence=sentence)

# Create Anatomical Entities
origin_a = AnatomicalEntity.objects.create(name='Oa')
origin_b = AnatomicalEntity.objects.create(name='Ob')
via1_a = AnatomicalEntity.objects.create(name='V1a')
via2_a = AnatomicalEntity.objects.create(name='V2a')
via2_b = AnatomicalEntity.objects.create(name='V2b')
via3_a = AnatomicalEntity.objects.create(name='V3a')
via4_a = AnatomicalEntity.objects.create(name='V4a')
via5_a = AnatomicalEntity.objects.create(name='V5a')
via5_b = AnatomicalEntity.objects.create(name='V5b')
via6_a = AnatomicalEntity.objects.create(name='V6a')
destination_a = AnatomicalEntity.objects.create(name='Da')
origin_a = self.create_or_get_anatomical_entity("Oa")
origin_b = self.create_or_get_anatomical_entity("Ob")
via1_a = self.create_or_get_anatomical_entity('V1a')
via2_a = self.create_or_get_anatomical_entity('V2a')
via2_b = self.create_or_get_anatomical_entity('V2b')
via3_a = self.create_or_get_anatomical_entity('V3a')
via4_a = self.create_or_get_anatomical_entity('V4a')
via5_a = self.create_or_get_anatomical_entity('V5a')
via5_b = self.create_or_get_anatomical_entity('V5b')
via6_a = self.create_or_get_anatomical_entity('V6a')
destination_a = self.create_or_get_anatomical_entity('Da')

# Add origins
cs.origins.add(origin_a, origin_b)
Expand Down Expand Up @@ -405,9 +410,9 @@ def test_journey_cycles(self):
cs = ConnectivityStatement.objects.create(sentence=sentence)

# Create Anatomical Entities
origin1 = AnatomicalEntity.objects.create(name='Oa')
origin2 = AnatomicalEntity.objects.create(name='Ob')
destination1 = AnatomicalEntity.objects.create(name='Da')
origin1 = self.create_or_get_anatomical_entity("Oa")
origin2 = self.create_or_get_anatomical_entity("Ob")
destination1 = self.create_or_get_anatomical_entity('Da')

# Add origins
cs.origins.add(origin1, origin2)
Expand Down Expand Up @@ -459,10 +464,10 @@ def test_journey_nonconsecutive_vias(self):
sentence = Sentence.objects.create()
cs = ConnectivityStatement.objects.create(sentence=sentence)

origin1 = AnatomicalEntity.objects.create(name='Oa')
via1 = AnatomicalEntity.objects.create(name='V1a')
via2 = AnatomicalEntity.objects.create(name='V2a')
destination1 = AnatomicalEntity.objects.create(name='Da')
origin1 = self.create_or_get_anatomical_entity("Oa")
via1 = self.create_or_get_anatomical_entity('V1a')
via2 = self.create_or_get_anatomical_entity("V2a")
destination1 = self.create_or_get_anatomical_entity('Da')

cs.origins.add(origin1)

Expand Down Expand Up @@ -508,3 +513,55 @@ def test_journey_nonconsecutive_vias(self):
expected_journey.sort()
self.assertTrue(journey_paths == expected_journey,
f"Expected journey {expected_journey}, but found {journey_paths}")

def test_journey_implicit_from_entities(self):
# Test setup
sentence = Sentence.objects.create()
cs = ConnectivityStatement.objects.create(sentence=sentence)

origin1 = self.create_or_get_anatomical_entity("Myenteric")
via1 = self.create_or_get_anatomical_entity('Longitudinal')
via2 = self.create_or_get_anatomical_entity("Serosa")
via3 = self.create_or_get_anatomical_entity("lumbar")
destination1 = self.create_or_get_anatomical_entity('inferior')

cs.origins.add(origin1)

via_a = Via.objects.create(connectivity_statement=cs)
via_a.anatomical_entities.add(via1)

via_b = Via.objects.create(connectivity_statement=cs)
via_b.anatomical_entities.add(via2)

via_c = Via.objects.create(connectivity_statement=cs)
via_c.anatomical_entities.add(via3)

destination = Destination.objects.create(connectivity_statement=cs)
destination.anatomical_entities.add(destination1)

# Prefetch related data
origins = list(cs.origins.all())
vias = list(
Via.objects.filter(connectivity_statement=cs).prefetch_related('anatomical_entities', 'from_entities'))
destinations = list(
Destination.objects.filter(connectivity_statement=cs).prefetch_related('anatomical_entities',
'from_entities'))

expected_paths = [
[('Myenteric', 0), ('Longitudinal', 1), ('Serosa', 2), ('lumbar', 3), ('inferior', 4)],
]

all_paths = generate_paths(origins, vias, destinations)

all_paths.sort()
expected_paths.sort()
self.assertTrue(all_paths == expected_paths, f"Expected paths {expected_paths}, but found {all_paths}")

journey_paths = consolidate_paths(all_paths)
expected_journey = [
[('Myenteric', 0), ('Longitudinal', 1), ('Serosa', 2), ('lumbar', 3), ('inferior', 4)],
]
journey_paths.sort()
expected_journey.sort()
self.assertTrue(journey_paths == expected_journey,
f"Expected journey {expected_journey}, but found {journey_paths}")
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ export const OriginNodeWidget: React.FC<OriginNodeProps> = ({
lineHeight: "1.25rem",
}}
>
Intermediolateral nucleus of eleventh thoracic segment
{model.name}
</Typography>
<Typography
sx={{
Expand Down

0 comments on commit b551772

Please sign in to comment.