-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add individual SegmentationComponent
- Loading branch information
Showing
8 changed files
with
636 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
84 changes: 84 additions & 0 deletions
84
src/main/java/com/indago/tr2d/plugins/seg/MySegmentationItem.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
|
||
package com.indago.tr2d.plugins.seg; | ||
|
||
import net.imglib2.RandomAccessibleInterval; | ||
import net.imglib2.img.cell.CellImgFactory; | ||
import net.imglib2.labkit.labeling.Labeling; | ||
import net.imglib2.labkit.models.SegmentationModel; | ||
import net.imglib2.labkit.models.DefaultHolder; | ||
import net.imglib2.labkit.models.Holder; | ||
import net.imglib2.labkit.models.SegmentationItem; | ||
import net.imglib2.labkit.segmentation.Segmenter; | ||
import net.imglib2.type.NativeType; | ||
import net.imglib2.type.numeric.IntegerType; | ||
import net.imglib2.type.numeric.real.DoubleType; | ||
import net.imglib2.util.ConstantUtils; | ||
import net.imglib2.view.Views; | ||
|
||
import java.util.Arrays; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.function.DoubleToIntFunction; | ||
|
||
public class MySegmentationItem extends SegmentationItem { | ||
|
||
private final Labeling labeling; | ||
|
||
private final SegmentationModel model; | ||
|
||
private final Holder<List<Double>> thresholds = new DefaultHolder<>( | ||
Collections.singletonList(0.5)); | ||
|
||
public MySegmentationItem(SegmentationModel model, Segmenter segmenter) { | ||
super(model, segmenter); | ||
this.model = model; | ||
labeling = new Labeling(Arrays.asList("background", "foreground"), model | ||
.image()); | ||
} | ||
|
||
public <T extends IntegerType<T> & NativeType<T>> RandomAccessibleInterval<T> | ||
apply(T type) | ||
{ | ||
RandomAccessibleInterval<?> image = model.image(); | ||
RandomAccessibleInterval<T> labels = new CellImgFactory<>(type).create( | ||
image); | ||
RandomAccessibleInterval<DoubleType> dummy = ConstantUtils | ||
.constantRandomAccessibleInterval(new DoubleType(), image.numDimensions(), | ||
image); | ||
RandomAccessibleInterval<DoubleType> probability = new CellImgFactory<>( | ||
new DoubleType()).create(image); | ||
RandomAccessibleInterval<DoubleType> probabilities = Views.stack(dummy, | ||
probability); | ||
segmenter().predict(image, probabilities); | ||
DoubleToIntFunction thresholdFunction = thresholdFunction(); | ||
Views.interval(Views.pair(probability, labels), labels).forEach(pair -> pair | ||
.getB().setInteger(thresholdFunction.applyAsInt(pair.getA().get()))); | ||
return labels; | ||
} | ||
|
||
public DoubleToIntFunction thresholdFunction() { | ||
double[] thresholds = thresholds().get().stream().mapToDouble(x -> x) | ||
.toArray(); | ||
Arrays.sort(thresholds); | ||
return (p) -> { | ||
int result = 0; | ||
for (double threshold : thresholds) | ||
if (p < threshold) break; | ||
else result++; | ||
return result; | ||
}; | ||
} | ||
|
||
public Labeling labeling() { | ||
return labeling; | ||
} | ||
|
||
public void train() { | ||
segmenter().train(Collections.singletonList(model.image()), Collections | ||
.singletonList(labeling)); | ||
} | ||
|
||
public Holder<List<Double>> thresholds() { | ||
return thresholds; | ||
} | ||
} |
144 changes: 144 additions & 0 deletions
144
src/main/java/com/indago/tr2d/plugins/seg/PredictionLayer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
|
||
package com.indago.tr2d.plugins.seg; | ||
|
||
import bdv.util.volatiles.SharedQueue; | ||
import bdv.util.volatiles.VolatileViews; | ||
import net.imglib2.RandomAccessible; | ||
import net.imglib2.RandomAccessibleInterval; | ||
import net.imglib2.converter.Converter; | ||
import net.imglib2.converter.Converters; | ||
import net.imglib2.labkit.bdv.BdvLayer; | ||
import net.imglib2.labkit.bdv.BdvShowable; | ||
import net.imglib2.labkit.models.Holder; | ||
import net.imglib2.labkit.models.SegmentationItem; | ||
import net.imglib2.labkit.models.SegmentationResultsModel; | ||
import net.imglib2.labkit.segmentation.Segmenter; | ||
import net.imglib2.labkit.utils.Notifier; | ||
import net.imglib2.labkit.utils.RandomAccessibleContainer; | ||
import net.imglib2.realtransform.AffineTransform3D; | ||
import net.imglib2.type.numeric.ARGBType; | ||
import net.imglib2.type.numeric.NumericType; | ||
import net.imglib2.type.volatiles.VolatileARGBType; | ||
import net.imglib2.type.volatiles.VolatileFloatType; | ||
import net.imglib2.util.ConstantUtils; | ||
import net.imglib2.view.Views; | ||
|
||
import java.util.Collections; | ||
import java.util.Set; | ||
import java.util.WeakHashMap; | ||
import java.util.function.DoubleToIntFunction; | ||
import java.util.function.IntUnaryOperator; | ||
import java.util.stream.IntStream; | ||
|
||
public class PredictionLayer implements BdvLayer { | ||
|
||
private final Holder<? extends MySegmentationItem> model; | ||
private final RandomAccessibleContainer<VolatileARGBType> segmentationContainer; | ||
private final SharedQueue queue = new SharedQueue(Runtime.getRuntime() | ||
.availableProcessors()); | ||
private Notifier<Runnable> listeners = new Notifier<>(); | ||
private RandomAccessibleInterval<? extends NumericType<?>> view; | ||
private AffineTransform3D transformation; | ||
private Set<MySegmentationItem> alreadyRegistered = Collections.newSetFromMap( | ||
new WeakHashMap<>()); | ||
|
||
public PredictionLayer(Holder<? extends MySegmentationItem> model) { | ||
this.model = model; | ||
SegmentationResultsModel selected = model.get().results(); | ||
this.segmentationContainer = new RandomAccessibleContainer<>( | ||
getEmptyPrediction(selected)); | ||
this.transformation = selected.transformation(); | ||
this.view = Views.interval(segmentationContainer, selected.interval()); | ||
model.notifier().add(ignore -> classifierChanged()); | ||
registerListener(model.get()); | ||
} | ||
|
||
private void registerListener(MySegmentationItem segmenter) { | ||
if (alreadyRegistered.contains(segmenter)) return; | ||
alreadyRegistered.add(segmenter); | ||
segmenter.segmenter().listeners().add(this::onTrainingFinished); | ||
segmenter.thresholds().notifier().add(ignore -> onTrainingFinished(segmenter | ||
.segmenter())); | ||
} | ||
|
||
private void onTrainingFinished(Segmenter segmenter) { | ||
if (model.get().segmenter() == segmenter) classifierChanged(); | ||
} | ||
|
||
private RandomAccessible<VolatileARGBType> getEmptyPrediction( | ||
SegmentationResultsModel selected) | ||
{ | ||
return ConstantUtils.constantRandomAccessible(new VolatileARGBType(0), | ||
selected.interval().numDimensions()); | ||
} | ||
|
||
private void classifierChanged() { | ||
MySegmentationItem segmentationItem = model.get(); | ||
registerListener(segmentationItem); | ||
SegmentationResultsModel selected = segmentationItem.results(); | ||
RandomAccessible<VolatileARGBType> source = selected.hasResults() ? Views | ||
.extendValue(coloredVolatileView(segmentationItem), new VolatileARGBType( | ||
0)) : getEmptyPrediction(selected); | ||
segmentationContainer.setSource(source); | ||
listeners.forEach(Runnable::run); | ||
} | ||
|
||
private RandomAccessibleInterval<VolatileARGBType> coloredVolatileView( | ||
MySegmentationItem segmentationItem) | ||
{ | ||
SegmentationResultsModel selected = segmentationItem.results(); | ||
DoubleToIntFunction thresholdFunction = segmentationItem | ||
.thresholdFunction(); | ||
ARGBType[] colors = setupColors(segmentationItem.thresholds().get().size(), | ||
selected.colors().get(0), selected.colors().get(1)); | ||
final Converter<VolatileFloatType, VolatileARGBType> conv = (input, | ||
output) -> { | ||
final boolean isValid = input.isValid(); | ||
output.setValid(isValid); | ||
if (isValid) output.set(colors[thresholdFunction.applyAsInt(input.get() | ||
.get())].get()); | ||
}; | ||
|
||
RandomAccessibleInterval<VolatileFloatType> source = Views.hyperSlice( | ||
VolatileViews.wrapAsVolatile(selected.prediction(), queue), 3, 1); | ||
return Converters.convert(source, conv, new VolatileARGBType()); | ||
} | ||
|
||
private ARGBType[] setupColors(int size, ARGBType background, | ||
ARGBType foreground) | ||
{ | ||
return IntStream.rangeClosed(0, size).mapToObj(i -> blend((double) i / | ||
(double) size, background, foreground)).toArray(ARGBType[]::new); | ||
} | ||
|
||
private ARGBType blend(double alpha, ARGBType background, | ||
ARGBType foreground) | ||
{ | ||
int r = blend(alpha, background, foreground, ARGBType::red); | ||
int g = blend(alpha, background, foreground, ARGBType::green); | ||
int b = blend(alpha, background, foreground, ARGBType::blue); | ||
return new ARGBType(ARGBType.rgba(r, g, b, 255)); | ||
} | ||
|
||
private int blend(double alpha, ARGBType background, ARGBType foreground, | ||
IntUnaryOperator channel) | ||
{ | ||
return (int) (alpha * channel.applyAsInt(foreground.get()) + (1 - alpha) * | ||
channel.applyAsInt(background.get())); | ||
} | ||
|
||
@Override | ||
public BdvShowable image() { | ||
return BdvShowable.wrap(view, transformation); | ||
} | ||
|
||
@Override | ||
public Notifier<Runnable> listeners() { | ||
return listeners; | ||
} | ||
|
||
@Override | ||
public String title() { | ||
return "Segmentation"; | ||
} | ||
} |
Oops, something went wrong.