diff --git a/backend/composer/admin.py b/backend/composer/admin.py index f43cc983..62348f49 100644 --- a/backend/composer/admin.py +++ b/backend/composer/admin.py @@ -231,7 +231,7 @@ class ConnectivityStatementAdmin( list_per_page = 10 # The name of one or more FSMFields on the model to transition fsm_field = ("state",) - readonly_fields = ("state",) + readonly_fields = ("state", "journey_path") autocomplete_fields = ("sentence", "origins") date_hierarchy = "modified_date" list_display = ( diff --git a/backend/composer/api/views.py b/backend/composer/api/views.py index b383ebe8..d2516399 100644 --- a/backend/composer/api/views.py +++ b/backend/composer/api/views.py @@ -252,20 +252,23 @@ class ModelRetrieveViewSet( # mixins.DestroyModelMixin, # mixins.ListModelMixin, viewsets.GenericViewSet, -): ... +): + ... class ModelCreateRetrieveViewSet( ModelRetrieveViewSet, mixins.CreateModelMixin, mixins.ListModelMixin, -): ... +): + ... class ModelNoDeleteViewSet( ModelCreateRetrieveViewSet, mixins.UpdateModelMixin, -): ... +): + ... class AnatomicalEntityViewSet(viewsets.ReadOnlyModelViewSet): @@ -506,7 +509,8 @@ def my(self, request, *args, **kwargs): msg = "User not logged in." raise ValidationError(msg, code="authorization") - profile, created = Profile.objects.get_or_create(user=self.request.user) + profile, created = Profile.objects.get_or_create( + user=self.request.user) return Response(self.get_serializer(profile).data) @@ -547,6 +551,7 @@ class StatementAlertViewSet(viewsets.ModelViewSet): serializer_class = StatementAlertSerializer permission_classes = [IsOwnerOfConnectivityStatementOrReadOnly] + @extend_schema( responses=OpenApiTypes.OBJECT, ) diff --git a/backend/composer/migrations/0066_connectivitystatement_journey_path.py b/backend/composer/migrations/0066_connectivitystatement_journey_path.py new file mode 100644 index 00000000..42b80b08 --- /dev/null +++ b/backend/composer/migrations/0066_connectivitystatement_journey_path.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.4 on 2024-12-18 08:25 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("composer", "0065_alter_statementalert_text"), + ] + + operations = [ + migrations.AddField( + model_name="connectivitystatement", + name="journey_path", + field=models.JSONField(blank=True, null=True), + ), + ] diff --git a/backend/composer/migrations/0067_auto_20241218_0928.py b/backend/composer/migrations/0067_auto_20241218_0928.py new file mode 100644 index 00000000..df4fa235 --- /dev/null +++ b/backend/composer/migrations/0067_auto_20241218_0928.py @@ -0,0 +1,23 @@ +# Generated by Django 4.1.4 on 2024-12-18 08:28 + +from django.db import migrations +from django.db import migrations +from composer.services.graph_service import compile_journey + + +def update_journey_fields(apps, schema_editor): + ConnectivityStatement = apps.get_model('composer', 'ConnectivityStatement') + for cs in ConnectivityStatement.objects.all(): + cs.journey_path = compile_journey(cs) + cs.save(update_fields=["journey_path"]) + + +class Migration(migrations.Migration): + + dependencies = [ + ("composer", "0066_connectivitystatement_journey_path"), + ] + + operations = [ + migrations.RunPython(update_journey_fields), + ] diff --git a/backend/composer/models.py b/backend/composer/models.py index 0d780632..8aa68c2e 100644 --- a/backend/composer/models.py +++ b/backend/composer/models.py @@ -4,6 +4,7 @@ from django.db.models.expressions import F from django.forms.widgets import Input as InputWidget from django_fsm import FSMField, transition +from composer.services.graph_service import build_journey_description, build_journey_entities from composer.services.layers_service import update_from_entities_on_deletion from composer.services.state_services import ( @@ -21,7 +22,6 @@ ViaType, Projection, ) -from .services.graph_service import compile_journey from .utils import doi_uri, pmcid_uri, pmid_uri, create_reference_uri @@ -567,6 +567,7 @@ class ConnectivityStatement(models.Model): ) created_date = models.DateTimeField(auto_now_add=True, db_index=True) modified_date = models.DateTimeField(auto_now=True, db_index=True) + journey_path = models.JSONField(null=True, blank=True) def __str__(self): suffix = "" @@ -683,11 +684,10 @@ def get_previous_layer_entities(self, via_order): return set(self.via_set.get(order=via_order - 1).anatomical_entities.all()) def get_journey(self): - return compile_journey(self)['journey'] + return build_journey_description(self.journey_path, self) def get_entities_journey(self): - entities_journey = compile_journey(self)['entities'] - return entities_journey + return build_journey_entities(self.journey_path, self) def get_laterality_description(self): laterality_map = { @@ -716,7 +716,6 @@ def save(self, *args, **kwargs): self.reference_uri = create_reference_uri(self.pk) self.save(update_fields=["reference_uri"]) - def set_origins(self, origin_ids): self.origins.set(origin_ids, clear=True) self.save() diff --git a/backend/composer/services/graph_service.py b/backend/composer/services/graph_service.py index e2b36d13..537b44d4 100644 --- a/backend/composer/services/graph_service.py +++ b/backend/composer/services/graph_service.py @@ -88,9 +88,7 @@ def consolidate_paths(paths): paths = consolidated + [paths[i] for i in range(len(paths)) if i not in used_indices] - return paths, [[((node[1].replace(JOURNEY_DELIMITER, ' or '), node[2]) if ( - node[2] == 0 or path.index(node) == len(path) - 1) else ( - node[1].replace(JOURNEY_DELIMITER, ', '), node[2])) for node in path] for path in paths] + return paths def can_merge(path1, path2): @@ -144,7 +142,7 @@ def merge_paths(path1, path2): return merged_path -def compile_journey(connectivity_statement) -> dict: +def compile_journey(connectivity_statement) -> List[str]: """ Generates a string of descriptions of journey paths for a given connectivity statement. @@ -157,17 +155,55 @@ def compile_journey(connectivity_statement) -> dict: # Extract origins, vias, and destinations from the connectivity statement Via = apps.get_model('composer', 'Via') Destination = apps.get_model('composer', 'Destination') + Origin = apps.get_model('composer', 'AnatomicalEntity') - origins = list(connectivity_statement.origins.all()) - - vias = list(Via.objects.filter(connectivity_statement=connectivity_statement)) - destinations = list(Destination.objects.filter(connectivity_statement=connectivity_statement)) + vias = list(Via.objects.filter(connectivity_statement__id=connectivity_statement.id)) + destinations = list(Destination.objects.filter(connectivity_statement__id=connectivity_statement.id)) + origins = list(Origin.objects.filter( + origins_relations__id=connectivity_statement.id).distinct()) # Generate all paths and then consolidate them all_paths2 = generate_paths(origins, vias, destinations) - consolidated_paths, journey_paths = consolidate_paths(all_paths2) + consolidated_paths = consolidate_paths(all_paths2) + return consolidated_paths + + +def get_journey_path_from_consolidated_paths(consolidated_paths): + journey_paths = [[((node[1].replace(JOURNEY_DELIMITER, ' or '), node[2]) if ( + node[2] == 0 or path.index(node) == len(path) - 1) else ( + node[1].replace(JOURNEY_DELIMITER, ', '), node[2])) for node in path] for path in consolidated_paths] + return journey_paths + + +def build_journey_description(consolidated_paths, connectivity_statement): + journey_paths = get_journey_path_from_consolidated_paths( + consolidated_paths) + Via = apps.get_model('composer', 'Via') + vias = list(Via.objects.filter( + connectivity_statement=connectivity_statement)) + + # Create sentences for each journey path + journey_descriptions = [] + for path in journey_paths: + origin_names = path[0][0] + destination_names = path[-1][0] + via_names = ' via '.join([node for node, layer in path if 0 < layer < len(vias) + 1]) + + if via_names: + sentence = f"from {origin_names} to {destination_names} via {via_names}" + else: + sentence = f"from {origin_names} to {destination_names}" + journey_descriptions.append(sentence) + return journey_descriptions + + +def build_journey_entities(consolidated_paths, connectivity_statement): entities = [] + Via = apps.get_model('composer', 'Via') + vias = list(Via.objects.filter( + connectivity_statement=connectivity_statement)) + for path in consolidated_paths: origin_splits = path[0][0].split(JOURNEY_DELIMITER) destination_splits = path[-1][0].split(JOURNEY_DELIMITER) @@ -181,22 +217,9 @@ def compile_journey(connectivity_statement) -> dict: 'vias': [{'label': node, 'id': node_id} for node_id, node, layer in path if 0 < layer < len(vias) + 1] } entities.append(entity) + return entities - # Create sentences for each journey path - journey_descriptions = [] - for path in journey_paths: - origin_names = path[0][0] - destination_names = path[-1][0] - via_names = ' via '.join([node for node, layer in path if 0 < layer < len(vias) + 1]) - - if via_names: - sentence = f"from {origin_names} to {destination_names} via {via_names}" - else: - sentence = f"from {origin_names} to {destination_names}" - - journey_descriptions.append(sentence) - return { - 'journey': journey_descriptions, - 'entities': entities - } \ No newline at end of file +def recompile_journey_path(instance): + instance.journey_path = compile_journey(instance) + instance.save(update_fields=["journey_path"]) diff --git a/backend/composer/signals.py b/backend/composer/signals.py index 2739ca99..aad972b1 100644 --- a/backend/composer/signals.py +++ b/backend/composer/signals.py @@ -19,6 +19,8 @@ Via, ) from .services.export_services import compute_metrics, ConnectivityStatementStateService +from .services.graph_service import recompile_journey_path + @receiver(post_save, sender=ExportBatch) @@ -101,6 +103,8 @@ def connectivity_statement_origins_changed(sender, instance, action, pk_set, **k pass except ValueError: pass + recompile_journey_path(instance) + # Call `update_from_entities_on_deletion` for each removed entity if action == "post_remove" and pk_set: @@ -118,6 +122,7 @@ def via_anatomical_entities_changed(sender, instance, action, pk_set, **kwargs): pass except ValueError: pass + recompile_journey_path(instance.connectivity_statement) # Call `update_from_entities_on_deletion` for each removed entity if action == "post_remove" and pk_set: @@ -135,6 +140,8 @@ def via_from_entities_changed(sender, instance, action, **kwargs): pass except ValueError: pass + recompile_journey_path(instance.connectivity_statement) + # Signals for Destination anatomical_entities @@ -147,6 +154,8 @@ def destination_anatomical_entities_changed(sender, instance, action, **kwargs): pass except ValueError: pass + recompile_journey_path(instance.connectivity_statement) + # Signals for Destination from_entities @@ -159,6 +168,8 @@ def destination_from_entities_changed(sender, instance, action, **kwargs): pass except ValueError: pass + recompile_journey_path(instance.connectivity_statement) + # Signals for Via model changes diff --git a/backend/tests/test_journey.py b/backend/tests/test_journey.py index 2ce6b9c2..1a0e72e3 100644 --- a/backend/tests/test_journey.py +++ b/backend/tests/test_journey.py @@ -2,7 +2,7 @@ from django.test import TestCase, override_settings from composer.models import Sentence, ConnectivityStatement, AnatomicalEntity, AnatomicalEntityMeta, Via, Destination -from composer.services.graph_service import generate_paths, consolidate_paths +from composer.services.graph_service import generate_paths, consolidate_paths, get_journey_path_from_consolidated_paths @override_settings(DEBUG=True) @@ -58,7 +58,9 @@ def test_journey_simple_graph_with_jump(self): expected_paths.sort() self.assertTrue(all_paths == expected_paths) - consolidated_path, journey_paths = consolidate_paths(all_paths) + consolidated_path = consolidate_paths(all_paths) + journey_paths = get_journey_path_from_consolidated_paths( + consolidated_path) expected_journey = [ [('Oa', 0), ('V1a', 1), ('Da', 2)], [('Ob', 0), ('Da', 2)] @@ -113,7 +115,9 @@ def test_journey_simple_direct_graph(self): expected_paths.sort() self.assertTrue(all_paths == expected_paths) - consolidated_path, journey_paths = consolidate_paths(all_paths) + consolidated_path = consolidate_paths(all_paths) + journey_paths = get_journey_path_from_consolidated_paths( + consolidated_path) expected_consolidated_path = [ [('1\\2', 'Oa\\Ob', 0), ('3', 'Da', 1)], ] @@ -174,7 +178,9 @@ def test_journey_simple_graph_no_jumps(self): expected_consolidated_path = [ [('1\\2', 'Oa\\Ob', 0), ('3', 'V1a', 1), ('4', 'Da', 2)], ] - consolidated_path, journey_paths = consolidate_paths(all_paths) + consolidated_path = consolidate_paths(all_paths) + journey_paths = get_journey_path_from_consolidated_paths( + consolidated_path) self.assertTrue(journey_paths == expected_journey) self.assertTrue(consolidated_path == expected_consolidated_path) @@ -239,7 +245,9 @@ def test_journey_multiple_vias_no_jumps(self): [('1\\2', 'Oa\\Ob', 0), ('3\\4', 'V1a\\V1b', 1), ('5', 'Da', 2)] ] - consolidated_path, journey_paths = consolidate_paths(all_paths) + consolidated_path = consolidate_paths(all_paths) + journey_paths = get_journey_path_from_consolidated_paths( + consolidated_path) self.assertTrue(journey_paths == expected_journey) self.assertTrue(consolidated_path == expected_consolidated_path) @@ -327,7 +335,9 @@ def test_journey_complex_graph(self): [('1\\2', 'Oa\\Ob', 0), ('3', 'V1a', 1), ('5', 'V2b', 2), ('8', 'Da', 5)], [('1', 'Oa', 0), ('6', 'V3a', 3), ('8', 'Da', 5)] ] - consolidated_path, journey_paths = consolidate_paths(all_paths) + consolidated_path = consolidate_paths(all_paths) + journey_paths = get_journey_path_from_consolidated_paths( + consolidated_path) self.assertTrue(journey_paths == expected_journey) self.assertTrue(consolidated_path == expected_consolidated_path) @@ -428,7 +438,9 @@ def test_journey_complex_graph_2(self): [('1\\2', 'Oa\\Ob', 0), ('3', 'V1a', 1), ('6', 'V3a', 3), ('7', 'V4a', 4), ('9', 'V5b', 5), ('11', 'Da', 7)] ] - consolidated_path, journey_paths = consolidate_paths(all_paths) + consolidated_path = consolidate_paths(all_paths) + journey_paths = get_journey_path_from_consolidated_paths( + consolidated_path) journey_paths.sort() expected_journey.sort() expected_consolidated_path.sort() @@ -493,7 +505,9 @@ def test_journey_cycles(self): [('1\\2', 'Oa\\Ob', 0), ('3', 'Da', 2)] ] - consolidated_path, journey_paths = consolidate_paths(all_paths) + consolidated_path = consolidate_paths(all_paths) + journey_paths = get_journey_path_from_consolidated_paths( + consolidated_path) expected_journey.sort() expected_consolidated_path.sort() journey_paths.sort() @@ -546,7 +560,9 @@ def test_journey_nonconsecutive_vias(self): expected_paths.sort() self.assertTrue(all_paths == expected_paths, f"Expected paths {expected_paths}, but found {all_paths}") - consolidated_path, journey_paths = consolidate_paths(all_paths) + consolidated_path = consolidate_paths(all_paths) + journey_paths = get_journey_path_from_consolidated_paths( + consolidated_path) expected_journey = [ [('Oa', 0), ('V1a', 3), ('V2a', 6), ('Da', 7)], ] @@ -603,7 +619,9 @@ def test_journey_implicit_from_entities(self): expected_paths.sort() self.assertTrue(all_paths == expected_paths, f"Expected paths {expected_paths}, but found {all_paths}") - consolidated_path, journey_paths = consolidate_paths(all_paths) + consolidated_path = consolidate_paths(all_paths) + journey_paths = get_journey_path_from_consolidated_paths( + consolidated_path) expected_journey = [ [('Myenteric', 0), ('Longitudinal', 1), ('Serosa', 2), ('lumbar', 3), ('inferior', 4)], ]