Skip to content

Commit

Permalink
Add individual SegmentationComponent
Browse files Browse the repository at this point in the history
  • Loading branch information
maarzt committed Jun 20, 2018
1 parent ca8ec54 commit 16e93f2
Show file tree
Hide file tree
Showing 8 changed files with 636 additions and 12 deletions.
16 changes: 13 additions & 3 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<parent>
<groupId>sc.fiji</groupId>
<artifactId>pom-indago</artifactId>
<version>2.2.6</version>
<version>2.2.7</version>
</parent>

<properties>
Expand Down Expand Up @@ -40,11 +40,17 @@
<dependency>
<groupId>com.indago</groupId>
<artifactId>tr2d</artifactId>
<exclusions>
<exclusion>
<groupId>com.miglayout</groupId>
<artifactId>miglayout</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>net.imglib2</groupId>
<artifactId>imglib2-labkit</artifactId>
<version>0.1.7</version>
<version>0.1.10</version>
<exclusions>
<exclusion>
<groupId>hr.irb.fastRandomForest</groupId>
Expand All @@ -55,7 +61,11 @@
<artifactId>miglayout</artifactId>
</exclusion>
</exclusions>
<optional>true</optional>
</dependency>
<dependency>
<groupId>com.miglayout</groupId>
<artifactId>miglayout</artifactId>
<classifier>swing</classifier>
</dependency>
</dependencies>

Expand Down
14 changes: 5 additions & 9 deletions src/main/java/com/indago/tr2d/plugins/seg/LabkitPanel.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import com.indago.tr2d.ui.model.Tr2dModel;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.labkit.SegmentationComponent;
import net.imglib2.type.numeric.integer.IntType;
import org.scijava.Context;
import org.scijava.log.Logger;
Expand All @@ -21,16 +20,14 @@ public class LabkitPanel {

public LabkitPanel(Context context, Tr2dModel model, Logger log) {
this.log = log;
boolean isTimeSeries = true;
segmentation = createSegmentationComponent(context, model, isTimeSeries);
segmentation = createSegmentationComponent(context, model);
}

private SegmentationComponent createSegmentationComponent(Context context,
Tr2dModel model, boolean isTimeSeries)
Tr2dModel model)
{
try {
return new SegmentationComponent(context, null, model.getRawData(),
isTimeSeries);
return new SegmentationComponent(context, model.getRawData());
}
catch (NoClassDefFoundError e) {
return null;
Expand All @@ -48,9 +45,8 @@ public JPanel getPanel() {
}

private void calculateOutputs() {
outputs = isUsable() && segmentation.isTrained() ? Collections
.singletonList(segmentation.getSegmentation(new IntType())) : Collections
.emptyList();
outputs = isUsable() ? segmentation.getSegmentations(new IntType())
: Collections.emptyList();
}

public boolean isUsable() {
Expand Down
84 changes: 84 additions & 0 deletions src/main/java/com/indago/tr2d/plugins/seg/MySegmentationItem.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@

package com.indago.tr2d.plugins.seg;

import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.cell.CellImgFactory;
import net.imglib2.labkit.labeling.Labeling;
import net.imglib2.labkit.models.SegmentationModel;
import net.imglib2.labkit.models.DefaultHolder;
import net.imglib2.labkit.models.Holder;
import net.imglib2.labkit.models.SegmentationItem;
import net.imglib2.labkit.segmentation.Segmenter;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.util.ConstantUtils;
import net.imglib2.view.Views;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.DoubleToIntFunction;

public class MySegmentationItem extends SegmentationItem {

private final Labeling labeling;

private final SegmentationModel model;

private final Holder<List<Double>> thresholds = new DefaultHolder<>(
Collections.singletonList(0.5));

public MySegmentationItem(SegmentationModel model, Segmenter segmenter) {
super(model, segmenter);
this.model = model;
labeling = new Labeling(Arrays.asList("background", "foreground"), model
.image());
}

public <T extends IntegerType<T> & NativeType<T>> RandomAccessibleInterval<T>
apply(T type)
{
RandomAccessibleInterval<?> image = model.image();
RandomAccessibleInterval<T> labels = new CellImgFactory<>(type).create(
image);
RandomAccessibleInterval<DoubleType> dummy = ConstantUtils
.constantRandomAccessibleInterval(new DoubleType(), image.numDimensions(),
image);
RandomAccessibleInterval<DoubleType> probability = new CellImgFactory<>(
new DoubleType()).create(image);
RandomAccessibleInterval<DoubleType> probabilities = Views.stack(dummy,
probability);
segmenter().predict(image, probabilities);
DoubleToIntFunction thresholdFunction = thresholdFunction();
Views.interval(Views.pair(probability, labels), labels).forEach(pair -> pair
.getB().setInteger(thresholdFunction.applyAsInt(pair.getA().get())));
return labels;
}

public DoubleToIntFunction thresholdFunction() {
double[] thresholds = thresholds().get().stream().mapToDouble(x -> x)
.toArray();
Arrays.sort(thresholds);
return (p) -> {
int result = 0;
for (double threshold : thresholds)
if (p < threshold) break;
else result++;
return result;
};
}

public Labeling labeling() {
return labeling;
}

public void train() {
segmenter().train(Collections.singletonList(model.image()), Collections
.singletonList(labeling));
}

public Holder<List<Double>> thresholds() {
return thresholds;
}
}
144 changes: 144 additions & 0 deletions src/main/java/com/indago/tr2d/plugins/seg/PredictionLayer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@

package com.indago.tr2d.plugins.seg;

import bdv.util.volatiles.SharedQueue;
import bdv.util.volatiles.VolatileViews;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.converter.Converter;
import net.imglib2.converter.Converters;
import net.imglib2.labkit.bdv.BdvLayer;
import net.imglib2.labkit.bdv.BdvShowable;
import net.imglib2.labkit.models.Holder;
import net.imglib2.labkit.models.SegmentationItem;
import net.imglib2.labkit.models.SegmentationResultsModel;
import net.imglib2.labkit.segmentation.Segmenter;
import net.imglib2.labkit.utils.Notifier;
import net.imglib2.labkit.utils.RandomAccessibleContainer;
import net.imglib2.realtransform.AffineTransform3D;
import net.imglib2.type.numeric.ARGBType;
import net.imglib2.type.numeric.NumericType;
import net.imglib2.type.volatiles.VolatileARGBType;
import net.imglib2.type.volatiles.VolatileFloatType;
import net.imglib2.util.ConstantUtils;
import net.imglib2.view.Views;

import java.util.Collections;
import java.util.Set;
import java.util.WeakHashMap;
import java.util.function.DoubleToIntFunction;
import java.util.function.IntUnaryOperator;
import java.util.stream.IntStream;

public class PredictionLayer implements BdvLayer {

private final Holder<? extends MySegmentationItem> model;
private final RandomAccessibleContainer<VolatileARGBType> segmentationContainer;
private final SharedQueue queue = new SharedQueue(Runtime.getRuntime()
.availableProcessors());
private Notifier<Runnable> listeners = new Notifier<>();
private RandomAccessibleInterval<? extends NumericType<?>> view;
private AffineTransform3D transformation;
private Set<MySegmentationItem> alreadyRegistered = Collections.newSetFromMap(
new WeakHashMap<>());

public PredictionLayer(Holder<? extends MySegmentationItem> model) {
this.model = model;
SegmentationResultsModel selected = model.get().results();
this.segmentationContainer = new RandomAccessibleContainer<>(
getEmptyPrediction(selected));
this.transformation = selected.transformation();
this.view = Views.interval(segmentationContainer, selected.interval());
model.notifier().add(ignore -> classifierChanged());
registerListener(model.get());
}

private void registerListener(MySegmentationItem segmenter) {
if (alreadyRegistered.contains(segmenter)) return;
alreadyRegistered.add(segmenter);
segmenter.segmenter().listeners().add(this::onTrainingFinished);
segmenter.thresholds().notifier().add(ignore -> onTrainingFinished(segmenter
.segmenter()));
}

private void onTrainingFinished(Segmenter segmenter) {
if (model.get().segmenter() == segmenter) classifierChanged();
}

private RandomAccessible<VolatileARGBType> getEmptyPrediction(
SegmentationResultsModel selected)
{
return ConstantUtils.constantRandomAccessible(new VolatileARGBType(0),
selected.interval().numDimensions());
}

private void classifierChanged() {
MySegmentationItem segmentationItem = model.get();
registerListener(segmentationItem);
SegmentationResultsModel selected = segmentationItem.results();
RandomAccessible<VolatileARGBType> source = selected.hasResults() ? Views
.extendValue(coloredVolatileView(segmentationItem), new VolatileARGBType(
0)) : getEmptyPrediction(selected);
segmentationContainer.setSource(source);
listeners.forEach(Runnable::run);
}

private RandomAccessibleInterval<VolatileARGBType> coloredVolatileView(
MySegmentationItem segmentationItem)
{
SegmentationResultsModel selected = segmentationItem.results();
DoubleToIntFunction thresholdFunction = segmentationItem
.thresholdFunction();
ARGBType[] colors = setupColors(segmentationItem.thresholds().get().size(),
selected.colors().get(0), selected.colors().get(1));
final Converter<VolatileFloatType, VolatileARGBType> conv = (input,
output) -> {
final boolean isValid = input.isValid();
output.setValid(isValid);
if (isValid) output.set(colors[thresholdFunction.applyAsInt(input.get()
.get())].get());
};

RandomAccessibleInterval<VolatileFloatType> source = Views.hyperSlice(
VolatileViews.wrapAsVolatile(selected.prediction(), queue), 3, 1);
return Converters.convert(source, conv, new VolatileARGBType());
}

private ARGBType[] setupColors(int size, ARGBType background,
ARGBType foreground)
{
return IntStream.rangeClosed(0, size).mapToObj(i -> blend((double) i /
(double) size, background, foreground)).toArray(ARGBType[]::new);
}

private ARGBType blend(double alpha, ARGBType background,
ARGBType foreground)
{
int r = blend(alpha, background, foreground, ARGBType::red);
int g = blend(alpha, background, foreground, ARGBType::green);
int b = blend(alpha, background, foreground, ARGBType::blue);
return new ARGBType(ARGBType.rgba(r, g, b, 255));
}

private int blend(double alpha, ARGBType background, ARGBType foreground,
IntUnaryOperator channel)
{
return (int) (alpha * channel.applyAsInt(foreground.get()) + (1 - alpha) *
channel.applyAsInt(background.get()));
}

@Override
public BdvShowable image() {
return BdvShowable.wrap(view, transformation);
}

@Override
public Notifier<Runnable> listeners() {
return listeners;
}

@Override
public String title() {
return "Segmentation";
}
}
Loading

0 comments on commit 16e93f2

Please sign in to comment.