diff --git a/pom.xml b/pom.xml index dda02c4..673c99f 100755 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ sc.fiji pom-indago - 2.2.6 + 2.2.7 @@ -40,11 +40,17 @@ com.indago tr2d + + + com.miglayout + miglayout + + net.imglib2 imglib2-labkit - 0.1.7 + 0.1.10 hr.irb.fastRandomForest @@ -55,7 +61,11 @@ miglayout - true + + + com.miglayout + miglayout + swing diff --git a/src/main/java/com/indago/tr2d/plugins/seg/LabkitPanel.java b/src/main/java/com/indago/tr2d/plugins/seg/LabkitPanel.java index 71f8741..b1d7bc0 100644 --- a/src/main/java/com/indago/tr2d/plugins/seg/LabkitPanel.java +++ b/src/main/java/com/indago/tr2d/plugins/seg/LabkitPanel.java @@ -3,7 +3,6 @@ import com.indago.tr2d.ui.model.Tr2dModel; import net.imglib2.RandomAccessibleInterval; -import net.imglib2.labkit.SegmentationComponent; import net.imglib2.type.numeric.integer.IntType; import org.scijava.Context; import org.scijava.log.Logger; @@ -21,16 +20,14 @@ public class LabkitPanel { public LabkitPanel(Context context, Tr2dModel model, Logger log) { this.log = log; - boolean isTimeSeries = true; - segmentation = createSegmentationComponent(context, model, isTimeSeries); + segmentation = createSegmentationComponent(context, model); } private SegmentationComponent createSegmentationComponent(Context context, - Tr2dModel model, boolean isTimeSeries) + Tr2dModel model) { try { - return new SegmentationComponent(context, null, model.getRawData(), - isTimeSeries); + return new SegmentationComponent(context, model.getRawData()); } catch (NoClassDefFoundError e) { return null; @@ -48,9 +45,8 @@ public JPanel getPanel() { } private void calculateOutputs() { - outputs = isUsable() && segmentation.isTrained() ? Collections - .singletonList(segmentation.getSegmentation(new IntType())) : Collections - .emptyList(); + outputs = isUsable() ? segmentation.getSegmentations(new IntType()) + : Collections.emptyList(); } public boolean isUsable() { diff --git a/src/main/java/com/indago/tr2d/plugins/seg/MySegmentationItem.java b/src/main/java/com/indago/tr2d/plugins/seg/MySegmentationItem.java new file mode 100644 index 0000000..3e9ed54 --- /dev/null +++ b/src/main/java/com/indago/tr2d/plugins/seg/MySegmentationItem.java @@ -0,0 +1,84 @@ + +package com.indago.tr2d.plugins.seg; + +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.img.cell.CellImgFactory; +import net.imglib2.labkit.labeling.Labeling; +import net.imglib2.labkit.models.SegmentationModel; +import net.imglib2.labkit.models.DefaultHolder; +import net.imglib2.labkit.models.Holder; +import net.imglib2.labkit.models.SegmentationItem; +import net.imglib2.labkit.segmentation.Segmenter; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.IntegerType; +import net.imglib2.type.numeric.real.DoubleType; +import net.imglib2.util.ConstantUtils; +import net.imglib2.view.Views; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.DoubleToIntFunction; + +public class MySegmentationItem extends SegmentationItem { + + private final Labeling labeling; + + private final SegmentationModel model; + + private final Holder> thresholds = new DefaultHolder<>( + Collections.singletonList(0.5)); + + public MySegmentationItem(SegmentationModel model, Segmenter segmenter) { + super(model, segmenter); + this.model = model; + labeling = new Labeling(Arrays.asList("background", "foreground"), model + .image()); + } + + public & NativeType> RandomAccessibleInterval + apply(T type) + { + RandomAccessibleInterval image = model.image(); + RandomAccessibleInterval labels = new CellImgFactory<>(type).create( + image); + RandomAccessibleInterval dummy = ConstantUtils + .constantRandomAccessibleInterval(new DoubleType(), image.numDimensions(), + image); + RandomAccessibleInterval probability = new CellImgFactory<>( + new DoubleType()).create(image); + RandomAccessibleInterval probabilities = Views.stack(dummy, + probability); + segmenter().predict(image, probabilities); + DoubleToIntFunction thresholdFunction = thresholdFunction(); + Views.interval(Views.pair(probability, labels), labels).forEach(pair -> pair + .getB().setInteger(thresholdFunction.applyAsInt(pair.getA().get()))); + return labels; + } + + public DoubleToIntFunction thresholdFunction() { + double[] thresholds = thresholds().get().stream().mapToDouble(x -> x) + .toArray(); + Arrays.sort(thresholds); + return (p) -> { + int result = 0; + for (double threshold : thresholds) + if (p < threshold) break; + else result++; + return result; + }; + } + + public Labeling labeling() { + return labeling; + } + + public void train() { + segmenter().train(Collections.singletonList(model.image()), Collections + .singletonList(labeling)); + } + + public Holder> thresholds() { + return thresholds; + } +} diff --git a/src/main/java/com/indago/tr2d/plugins/seg/PredictionLayer.java b/src/main/java/com/indago/tr2d/plugins/seg/PredictionLayer.java new file mode 100644 index 0000000..c8b4cc8 --- /dev/null +++ b/src/main/java/com/indago/tr2d/plugins/seg/PredictionLayer.java @@ -0,0 +1,144 @@ + +package com.indago.tr2d.plugins.seg; + +import bdv.util.volatiles.SharedQueue; +import bdv.util.volatiles.VolatileViews; +import net.imglib2.RandomAccessible; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.converter.Converter; +import net.imglib2.converter.Converters; +import net.imglib2.labkit.bdv.BdvLayer; +import net.imglib2.labkit.bdv.BdvShowable; +import net.imglib2.labkit.models.Holder; +import net.imglib2.labkit.models.SegmentationItem; +import net.imglib2.labkit.models.SegmentationResultsModel; +import net.imglib2.labkit.segmentation.Segmenter; +import net.imglib2.labkit.utils.Notifier; +import net.imglib2.labkit.utils.RandomAccessibleContainer; +import net.imglib2.realtransform.AffineTransform3D; +import net.imglib2.type.numeric.ARGBType; +import net.imglib2.type.numeric.NumericType; +import net.imglib2.type.volatiles.VolatileARGBType; +import net.imglib2.type.volatiles.VolatileFloatType; +import net.imglib2.util.ConstantUtils; +import net.imglib2.view.Views; + +import java.util.Collections; +import java.util.Set; +import java.util.WeakHashMap; +import java.util.function.DoubleToIntFunction; +import java.util.function.IntUnaryOperator; +import java.util.stream.IntStream; + +public class PredictionLayer implements BdvLayer { + + private final Holder model; + private final RandomAccessibleContainer segmentationContainer; + private final SharedQueue queue = new SharedQueue(Runtime.getRuntime() + .availableProcessors()); + private Notifier listeners = new Notifier<>(); + private RandomAccessibleInterval> view; + private AffineTransform3D transformation; + private Set alreadyRegistered = Collections.newSetFromMap( + new WeakHashMap<>()); + + public PredictionLayer(Holder model) { + this.model = model; + SegmentationResultsModel selected = model.get().results(); + this.segmentationContainer = new RandomAccessibleContainer<>( + getEmptyPrediction(selected)); + this.transformation = selected.transformation(); + this.view = Views.interval(segmentationContainer, selected.interval()); + model.notifier().add(ignore -> classifierChanged()); + registerListener(model.get()); + } + + private void registerListener(MySegmentationItem segmenter) { + if (alreadyRegistered.contains(segmenter)) return; + alreadyRegistered.add(segmenter); + segmenter.segmenter().listeners().add(this::onTrainingFinished); + segmenter.thresholds().notifier().add(ignore -> onTrainingFinished(segmenter + .segmenter())); + } + + private void onTrainingFinished(Segmenter segmenter) { + if (model.get().segmenter() == segmenter) classifierChanged(); + } + + private RandomAccessible getEmptyPrediction( + SegmentationResultsModel selected) + { + return ConstantUtils.constantRandomAccessible(new VolatileARGBType(0), + selected.interval().numDimensions()); + } + + private void classifierChanged() { + MySegmentationItem segmentationItem = model.get(); + registerListener(segmentationItem); + SegmentationResultsModel selected = segmentationItem.results(); + RandomAccessible source = selected.hasResults() ? Views + .extendValue(coloredVolatileView(segmentationItem), new VolatileARGBType( + 0)) : getEmptyPrediction(selected); + segmentationContainer.setSource(source); + listeners.forEach(Runnable::run); + } + + private RandomAccessibleInterval coloredVolatileView( + MySegmentationItem segmentationItem) + { + SegmentationResultsModel selected = segmentationItem.results(); + DoubleToIntFunction thresholdFunction = segmentationItem + .thresholdFunction(); + ARGBType[] colors = setupColors(segmentationItem.thresholds().get().size(), + selected.colors().get(0), selected.colors().get(1)); + final Converter conv = (input, + output) -> { + final boolean isValid = input.isValid(); + output.setValid(isValid); + if (isValid) output.set(colors[thresholdFunction.applyAsInt(input.get() + .get())].get()); + }; + + RandomAccessibleInterval source = Views.hyperSlice( + VolatileViews.wrapAsVolatile(selected.prediction(), queue), 3, 1); + return Converters.convert(source, conv, new VolatileARGBType()); + } + + private ARGBType[] setupColors(int size, ARGBType background, + ARGBType foreground) + { + return IntStream.rangeClosed(0, size).mapToObj(i -> blend((double) i / + (double) size, background, foreground)).toArray(ARGBType[]::new); + } + + private ARGBType blend(double alpha, ARGBType background, + ARGBType foreground) + { + int r = blend(alpha, background, foreground, ARGBType::red); + int g = blend(alpha, background, foreground, ARGBType::green); + int b = blend(alpha, background, foreground, ARGBType::blue); + return new ARGBType(ARGBType.rgba(r, g, b, 255)); + } + + private int blend(double alpha, ARGBType background, ARGBType foreground, + IntUnaryOperator channel) + { + return (int) (alpha * channel.applyAsInt(foreground.get()) + (1 - alpha) * + channel.applyAsInt(background.get())); + } + + @Override + public BdvShowable image() { + return BdvShowable.wrap(view, transformation); + } + + @Override + public Notifier listeners() { + return listeners; + } + + @Override + public String title() { + return "Segmentation"; + } +} diff --git a/src/main/java/com/indago/tr2d/plugins/seg/SegmentationComponent.java b/src/main/java/com/indago/tr2d/plugins/seg/SegmentationComponent.java new file mode 100644 index 0000000..9d758de --- /dev/null +++ b/src/main/java/com/indago/tr2d/plugins/seg/SegmentationComponent.java @@ -0,0 +1,110 @@ + +package com.indago.tr2d.plugins.seg; + +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.labkit.DefaultExtensible; +import net.imglib2.labkit.Extensible; +import net.imglib2.labkit.BasicLabelingComponent; +import net.imglib2.labkit.actions.SelectClassifier; +import net.imglib2.labkit.inputimage.DefaultInputImage; +import net.imglib2.labkit.models.ColoredLabelsModel; +import net.imglib2.labkit.panel.GuiUtils; +import net.imglib2.labkit.panel.LabelPanel; +import net.imglib2.labkit.panel.SegmenterPanel; +import net.imglib2.labkit.segmentation.TrainClassifier; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.IntegerType; +import net.imglib2.type.numeric.NumericType; +import net.miginfocom.swing.MigLayout; +import org.scijava.Context; + +import javax.swing.*; +import java.util.List; + +public class SegmentationComponent implements AutoCloseable { + + private final JSplitPane panel; + + private final JFrame dialogBoxOwner = null; + + private BasicLabelingComponent labelingComponent; + + private final Context context; + + private SegmentationModel segmentationModel; + + public SegmentationComponent(Context context, + RandomAccessibleInterval> image) + { + this.context = context; + segmentationModel = new SegmentationModel(initInputImage(image, true), + context); + labelingComponent = new BasicLabelingComponent(dialogBoxOwner, + segmentationModel.imageLabelingModel()); + labelingComponent.addBdvLayer(new PredictionLayer(segmentationModel + .selectedSegmenter())); + initActions(); + JPanel leftPanel = initLeftPanel(); + this.panel = initPanel(leftPanel, labelingComponent.getComponent()); + } + + private static DefaultInputImage initInputImage( + RandomAccessibleInterval> image, + boolean isTimeSeries) + { + DefaultInputImage defaultInputImage = new DefaultInputImage(image); + defaultInputImage.setTimeSeries(isTimeSeries); + return defaultInputImage; + } + + private void initActions() { + Extensible extensible = new DefaultExtensible(context, dialogBoxOwner, + labelingComponent); + new TrainClassifier(extensible, segmentationModel); + new SelectClassifier(extensible, segmentationModel.selectedSegmenter()); + } + + private JPanel initLeftPanel() { + JPanel panel = new JPanel(); + panel.setLayout(new MigLayout("", "[grow]", "[][grow][grow][]")); + ActionMap actions = getActions(); + panel.add(GuiUtils.createCheckboxGroupedPanel(actions.get("Image"), GuiUtils + .createDimensionsInfo(segmentationModel.image())), "grow, wrap"); + panel.add(GuiUtils.createCheckboxGroupedPanel(actions.get("Labeling"), + new LabelPanel(dialogBoxOwner, new ColoredLabelsModel(segmentationModel + .imageLabelingModel()), true).getComponent()), "grow, wrap"); + panel.add(GuiUtils.createCheckboxGroupedPanel(actions.get("Segmentation"), + new SegmenterPanel(segmentationModel, actions).getComponent()), + "grow, wrap"); + panel.add(new ThresholdButton(segmentationModel).getComponent(), "grow"); + return panel; + } + + private JSplitPane initPanel(JComponent left, JComponent right) { + JSplitPane panel = new JSplitPane(); + panel.setSize(100, 100); + panel.setOneTouchExpandable(true); + panel.setLeftComponent(left); + panel.setRightComponent(right); + return panel; + } + + public JComponent getComponent() { + return panel; + } + + private ActionMap getActions() { + return labelingComponent.getActions(); + } + + public & NativeType> + List> getSegmentations(T type) + { + return segmentationModel.getSegmentations(type); + } + + @Override + public void close() { + labelingComponent.close(); + } +} diff --git a/src/main/java/com/indago/tr2d/plugins/seg/SegmentationModel.java b/src/main/java/com/indago/tr2d/plugins/seg/SegmentationModel.java new file mode 100644 index 0000000..72f95c3 --- /dev/null +++ b/src/main/java/com/indago/tr2d/plugins/seg/SegmentationModel.java @@ -0,0 +1,130 @@ + +package com.indago.tr2d.plugins.seg; + +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.img.cell.CellGrid; +import net.imglib2.labkit.color.ColorMap; +import net.imglib2.labkit.inputimage.InputImage; +import net.imglib2.labkit.labeling.Labeling; +import net.imglib2.labkit.models.DefaultHolder; +import net.imglib2.labkit.models.Holder; +import net.imglib2.labkit.models.ImageLabelingModel; +import net.imglib2.labkit.models.SegmenterListModel; +import net.imglib2.labkit.segmentation.Segmenter; +import net.imglib2.labkit.segmentation.weka.TimeSeriesSegmenter; +import net.imglib2.labkit.segmentation.weka.TrainableSegmentationSegmenter; +import net.imglib2.labkit.utils.LabkitUtils; +import net.imglib2.realtransform.AffineTransform3D; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.IntegerType; +import net.imglib2.type.numeric.NumericType; +import org.scijava.Context; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Serves as a model for PredictionLayer and TrainClassifierAction + */ +public class SegmentationModel implements + net.imglib2.labkit.models.SegmentationModel, SegmenterListModel +{ + + private final ImageLabelingModel imageLabelingModel; + private final Holder selectedSegmenter; + private final InputImage inputImage; + private List segmenters = new ArrayList<>(); + private final RandomAccessibleInterval> compatibleImage; + private final CellGrid grid; + + private Context context; + + public SegmentationModel(InputImage image, Context context) { + this.context = context; + this.inputImage = image; + this.compatibleImage = image.imageForSegmentation(); + this.grid = LabkitUtils.suggestGrid(compatibleImage, image.isTimeSeries()); + MySegmentationItem segmentationItem = addSegmenter(); + this.selectedSegmenter = new DefaultHolder<>(segmentationItem); + this.selectedSegmenter.notifier().add(this::selectedSegmenterChanged); + this.imageLabelingModel = new ImageLabelingModel(image.showable(), + segmentationItem.labeling(), true); + } + + private void selectedSegmenterChanged(MySegmentationItem segmentationItem) { + imageLabelingModel.labeling().set(segmentationItem.labeling()); + } + + @Override + public Labeling labeling() { + return imageLabelingModel.labeling().get(); + } + + @Override + public RandomAccessibleInterval image() { + return compatibleImage; + } + + @Override + public CellGrid grid() { + return grid; + } + + @Override + public List segmenters() { + return segmenters; + } + + @Override + public Holder selectedSegmenter() { + return selectedSegmenter; + } + + @Override + public ColorMap colorMap() { + return imageLabelingModel.colorMapProvider().colorMap(); + } + + @Override + public AffineTransform3D labelTransformation() { + return imageLabelingModel.labelTransformation(); + } + + @Override + public MySegmentationItem addSegmenter() { + MySegmentationItem segmentationItem = new MySegmentationItem(this, + initClassifier()); + this.segmenters.add(segmentationItem); + return segmentationItem; + } + + private Segmenter initClassifier() { + TrainableSegmentationSegmenter classifier1 = + new TrainableSegmentationSegmenter(context, inputImage); + return inputImage.isTimeSeries() ? new TimeSeriesSegmenter(classifier1) + : classifier1; + } + + @Override + public void trainSegmenter() { + selectedSegmenter().get().train(); + } + + public & NativeType> + List> getSegmentations(T type) + { + Stream trainedSegmenters = getTrainedSegmenters(); + return trainedSegmenters.map(segmenter -> segmenter.apply(type)).collect( + Collectors.toList()); + } + + private Stream getTrainedSegmenters() { + return segmenters().stream().filter(x -> x.segmenter().isTrained()); + } + + public ImageLabelingModel imageLabelingModel() { + return imageLabelingModel; + } +} diff --git a/src/main/java/com/indago/tr2d/plugins/seg/ThresholdButton.java b/src/main/java/com/indago/tr2d/plugins/seg/ThresholdButton.java new file mode 100644 index 0000000..44a599c --- /dev/null +++ b/src/main/java/com/indago/tr2d/plugins/seg/ThresholdButton.java @@ -0,0 +1,89 @@ + +package com.indago.tr2d.plugins.seg; + +import net.imglib2.img.array.ArrayImgs; +import net.imglib2.labkit.inputimage.DefaultInputImage; +import net.imglib2.labkit.panel.SegmenterPanel; +import org.scijava.Context; + +import javax.swing.*; +import java.awt.*; +import java.util.Arrays; +import java.util.List; +import java.util.StringJoiner; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class ThresholdButton { + + private final SegmentationModel segmentationModel; + + private final JButton button; + + public ThresholdButton(SegmentationModel segmentationModel) { + this.segmentationModel = segmentationModel; + segmentationModel.selectedSegmenter().notifier().add( + ignore -> updateThresholds()); + button = new JButton("Thresholds ..."); + button.addActionListener(a -> { + String text = JOptionPane.showInputDialog(null, "Enter thresholds", + "Thresholds ...", JOptionPane.PLAIN_MESSAGE); + if (text != null) try { + MySegmentationItem segmentationItem = this.segmentationModel + .selectedSegmenter().get(); + segmentationItem.thresholds().set(new ListOfDoubleFormatter() + .stringToValue(text)); + } + catch (NumberFormatException ignore) {} + updateThresholds(); + }); + } + + public JComponent getComponent() { + return button; + } + + private void updateThresholds() { + List doubles = segmentationModel.selectedSegmenter().get() + .thresholds().get(); + button.setText("Thresholds: " + new ListOfDoubleFormatter().valueToString( + doubles)); + } + + private static class ListOfDoubleFormatter extends + JFormattedTextField.AbstractFormatter + { + + @Override + public List stringToValue(String text) { + return Stream.of(text.split(";")).map(Double::new).collect(Collectors + .toList()); + } + + @Override + public String valueToString(Object value) { + if (value == null) return ""; + @SuppressWarnings("unchecked") + List list = (List) value; + StringJoiner joiner = new StringJoiner("; "); + list.stream().map(Object::toString).forEach(joiner::add); + return joiner.toString(); + } + } + + public static void main(String... args) { + SegmentationModel segmentationModel = new SegmentationModel( + new DefaultInputImage(ArrayImgs.unsignedBytes(100, 100, 100)), + new Context()); + JFrame frame = new JFrame(); + frame.add(new SegmenterPanel(segmentationModel, new ActionMap()) + .getComponent()); + frame.add(new ThresholdButton(segmentationModel).getComponent(), + BorderLayout.PAGE_END); + frame.setSize(500, 500); + frame.setVisible(true); + segmentationModel.selectedSegmenter().notifier().add(System.out::println); + segmentationModel.selectedSegmenter().get().thresholds().notifier().add( + x -> System.out.println(Arrays.toString(x.toArray()))); + } +} diff --git a/src/test/java/com/indago/tr2d/plugins/seg/SegmentationComponentDemo.java b/src/test/java/com/indago/tr2d/plugins/seg/SegmentationComponentDemo.java new file mode 100644 index 0000000..b4d9f60 --- /dev/null +++ b/src/test/java/com/indago/tr2d/plugins/seg/SegmentationComponentDemo.java @@ -0,0 +1,61 @@ + +package com.indago.tr2d.plugins.seg; + +import ij.ImagePlus; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.img.Img; +import net.imglib2.img.display.imagej.ImageJFunctions; +import net.imglib2.type.numeric.NumericType; +import net.imglib2.type.numeric.integer.UnsignedByteType; +import net.imglib2.view.Views; +import net.miginfocom.swing.MigLayout; +import org.scijava.Context; +import org.scijava.ui.behaviour.util.RunnableAction; + +import javax.swing.*; +import java.awt.*; + +public class SegmentationComponentDemo { + + private final SegmentationComponent segmenter; + + public static void main(String... args) { + new SegmentationComponentDemo(); + } + + private SegmentationComponentDemo() { + JFrame frame = setupFrame(); + Img> image = ImageJFunctions.wrap(new ImagePlus( + "/home/arzt/Documents/Notes/Tr2d/ProjectFiles/raw.tif")); + Context context = new Context(); + segmenter = new SegmentationComponent(context, image); + frame.add(segmenter.getComponent()); + frame.add(getBottomPanel(), BorderLayout.PAGE_END); + frame.setVisible(true); + } + + private JPanel getBottomPanel() { + JButton segmentation = new JButton(new RunnableAction("Show Result", + this::showSegmentation)); + JPanel panel = new JPanel(); + panel.setLayout(new MigLayout()); + panel.add(segmentation); + return panel; + } + + private void showSegmentation() { + for (RandomAccessibleInterval segmentation : segmenter + .getSegmentations(new UnsignedByteType())) + { + Views.iterable(segmentation).forEach(x -> x.mul(50)); + ImageJFunctions.show(segmentation); + } + } + + private static JFrame setupFrame() { + JFrame frame = new JFrame(); + frame.setSize(500, 500); + frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); + return frame; + } +}