Skip to content

Commit

Permalink
refactor(transforms): extract detection and transform enable/disable …
Browse files Browse the repository at this point in the history
…logic
  • Loading branch information
PaulHax committed Sep 12, 2024
1 parent 3bbb260 commit ca03a2e
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 98 deletions.
4 changes: 0 additions & 4 deletions src/nrtk_explorer/app/trame_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,3 @@ def on_state(**kwargs):
return callback

return decorator


def boolean_turned_true(old, new):
return old is False and new
142 changes: 57 additions & 85 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
"""

import logging
from typing import Dict
from typing import Dict, Callable

from trame.ui.quasar import QLayout
from trame.widgets import quasar
from trame.widgets import html
from trame.app import get_server, asynchronous
from trame_server import Server

import nrtk_explorer.library.transforms as trans
import nrtk_explorer.library.nrtk_transforms as nrtk_trans
Expand All @@ -28,7 +29,6 @@
SetStateAsync,
change_checker,
delete_state,
boolean_turned_true,
)
from nrtk_explorer.app.images.image_ids import (
dataset_id_to_image_id,
Expand All @@ -52,6 +52,45 @@
logger.setLevel(logging.INFO)


class ProcessingStep:
def __init__(
self,
server: Server,
feature_enabled_state_key: str,
gui_switch_key: str,
column_name: str,
enabled_callback: Callable,
):
self.state = server.state
self.feature_enabled_state_key = feature_enabled_state_key
self.gui_switch_key = gui_switch_key
self.enabled_callback = enabled_callback
self.column_name = column_name
self.state.change(self.gui_switch_key)(self.on_gui_switch)
self.update_feature_enabled_state()
self.state.change("visible_columns", self.gui_switch_key)(
self.update_feature_enabled_state
)
self.state.change(self.feature_enabled_state_key)(self.on_change_feature_enabled)

def on_gui_switch(self, **kwargs):
if self.state[self.gui_switch_key]:
self.state.visible_columns = list(set([*self.state.visible_columns, self.column_name]))
else:
self.state.visible_columns = [
col for col in self.state.visible_columns if col != self.column_name
]

def update_feature_enabled_state(self, **kwargs):
self.state[self.feature_enabled_state_key] = (
self.column_name in self.state.visible_columns and self.state[self.gui_switch_key]
)

def on_change_feature_enabled(self, **kwargs):
if self.state[self.feature_enabled_state_key]:
self.enabled_callback()


class TransformsApp(Applet):
def __init__(
self,
Expand Down Expand Up @@ -124,90 +163,23 @@ def delete_meta_state(old_ids, new_ids):
self.state.transforms = [k for k in self._transforms.keys()]
self.state.current_transform = self.state.transforms[0]

# Annotations enabled control via predictions_original_images_enabled ###
def update_prediction_original_images_enabled(**kwargs):
self.state.predictions_original_images_enabled = (
"original" in self.state.visible_columns and self.state.annotations_enabled_switch
)

update_prediction_original_images_enabled()
self.state.change("visible_columns", "annotations_enabled_switch")(
update_prediction_original_images_enabled
)

def on_change_predictions_original_images_enabled(**kwargs):
if self.state.predictions_original_images_enabled:
# run whole pipeline, so possibly compute transforms too, as transforms compute scores based on original images
self._start_update_images(self.visible_ids)

self.state.change("predictions_original_images_enabled")(
on_change_predictions_original_images_enabled
# On annotations enabled, run whole pipeline to possibly compute transforms. Why: Transforms compute scores based on original images
self.annotations_enable_control = ProcessingStep(
server,
feature_enabled_state_key="predictions_original_images_enabled",
gui_switch_key="annotations_enabled_switch",
column_name=ORIGINAL_COLUMNS[0],
enabled_callback=self._start_update_images,
)

def turn_on_original_columns(_=None, __=None):
if any(col not in ORIGINAL_COLUMNS for col in self.state.visible_columns):
self.state.visible_columns = list(
set([*self.state.visible_columns, *ORIGINAL_COLUMNS])
)

change_checker(self.state, "annotations_enabled_switch", boolean_turned_true)(
turn_on_original_columns
self.transform_enable_control = ProcessingStep(
server,
feature_enabled_state_key="transform_enabled",
gui_switch_key="transform_enabled_switch",
column_name=TRANSFORM_COLUMNS[0],
enabled_callback=self.schedule_transformed_images,
)

def update_original_related_visible_columns(**kwargs):
if self.state.predictions_original_images_enabled:
turn_on_original_columns()
else:
if "original" in self.state.visible_columns:
self.state.visible_columns = [
col for col in self.state.visible_columns if col != "original"
]

update_original_related_visible_columns()
self.state.change("predictions_original_images_enabled")(
update_original_related_visible_columns
)
# end annotations enabled control ###

# Transform enabled control via transform_enabled ###
def update_transform_enabled(**kwargs):
self.state.transform_enabled = (
"transformed" in self.state.visible_columns and self.state.transform_enabled_switch
)

update_transform_enabled()
self.state.change("visible_columns", "transform_enabled_switch")(update_transform_enabled)

def transform_became_enabled(old, new):
return old is False and new

change_checker(self.state, "transform_enabled", transform_became_enabled)(
self.schedule_transformed_images
)

def turn_on_transform_columns(_=None, __=None):
if any(col not in TRANSFORM_COLUMNS for col in self.state.visible_columns):
self.state.visible_columns = list(
set([*self.state.visible_columns, *TRANSFORM_COLUMNS])
)

change_checker(self.state, "transform_enabled_switch", boolean_turned_true)(
turn_on_transform_columns
)

def update_transform_related_visible_columns(**kwargs):
if self.state.transform_enabled:
turn_on_transform_columns()
else:
if "transformed" in self.state.visible_columns:
self.state.visible_columns = [
col for col in self.state.visible_columns if col != "transformed"
]

update_transform_related_visible_columns()
self.state.change("transform_enabled")(update_transform_related_visible_columns)
# End transform enabled control ###

self.server.controller.add("on_server_ready")(self.on_server_ready)
self.server.controller.apply_transform.add(self.schedule_transformed_images)
self._on_hover_fn = None
Expand Down Expand Up @@ -352,17 +324,17 @@ async def _update_images(self, dataset_ids):
async with SetStateAsync(self.state):
await self.update_transformed_images(dataset_ids)

def _start_update_images(self, priority_ids):
def _start_update_images(self):
if hasattr(self, "_update_task"):
self._update_task.cancel()
self._update_task = asynchronous.create_task(self._update_images(priority_ids))
self._update_task = asynchronous.create_task(self._update_images(self.visible_ids))

def _updating_images(self):
return hasattr(self, "_update_task") and not self._update_task.done()

def on_scroll(self, visible_ids):
self.visible_ids = visible_ids
self._start_update_images(self.visible_ids)
self._start_update_images()

def on_image_hovered(self, id):
self.state.hovered_id = id
Expand Down
16 changes: 11 additions & 5 deletions src/nrtk_explorer/app/ui/image_list.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from trame.widgets import html, quasar, client
from trame.app import get_server
from nrtk_explorer.app.trame_utils import change_checker
from nrtk_explorer.widgets.nrtk_explorer import ImageDetection
from nrtk_explorer.app.images.image_ids import get_image_state_keys

Expand Down Expand Up @@ -44,18 +45,23 @@
state.visible_columns = [col["name"] for col in COLUMNS]


def make_dependent_columns_handler(columns):
def make_dependent_columns_handler(state, columns):
toggle_column = columns[0]
dependent_columns = columns[1:]

def column_toggler(**kwargs):
def column_toggler(old_columns, new_columns):
dependant_columns_visible = any(col in state.visible_columns for col in dependent_columns)
if toggle_column not in state.visible_columns and dependant_columns_visible:
state.visible_columns = [
col for col in state.visible_columns if col not in dependent_columns
]
return

return column_toggler
toggle_column_turned_on = toggle_column in new_columns and toggle_column not in old_columns
if toggle_column_turned_on:
state.visible_columns = list(set([*state.visible_columns, *dependent_columns]))

change_checker(state, "visible_columns")(column_toggler)


ORIGINAL_COLUMNS = [
Expand All @@ -64,7 +70,7 @@ def column_toggler(**kwargs):
]


state.change("visible_columns")(make_dependent_columns_handler(ORIGINAL_COLUMNS))
make_dependent_columns_handler(state, ORIGINAL_COLUMNS)


TRANSFORM_COLUMNS = [
Expand All @@ -73,7 +79,7 @@ def column_toggler(**kwargs):
"original_detection_to_transformed_detection_score",
]

state.change("visible_columns")(make_dependent_columns_handler(TRANSFORM_COLUMNS))
make_dependent_columns_handler(state, TRANSFORM_COLUMNS)


state.client_only("image_list_ids")
Expand Down
4 changes: 0 additions & 4 deletions src/nrtk_explorer/library/embeddings_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
IMAGE_MODEL_RESOLUTION = (224, 224)


def prepare_for_model(img):
"""Prepare image for model input"""


# Create a dataset for images
class ImagesDataset(Dataset):
def __init__(self, images):
Expand Down

0 comments on commit ca03a2e

Please sign in to comment.