From 5b0148261f4761b545d456fc4ce856446e608713 Mon Sep 17 00:00:00 2001 From: John Bogovic Date: Thu, 9 May 2024 10:24:31 -0400 Subject: [PATCH] feat: axis utils more methods for permutation --- .../n5/universe/metadata/axes/AxisUtils.java | 205 ++++++++++++++++-- 1 file changed, 184 insertions(+), 21 deletions(-) diff --git a/src/main/java/org/janelia/saalfeldlab/n5/universe/metadata/axes/AxisUtils.java b/src/main/java/org/janelia/saalfeldlab/n5/universe/metadata/axes/AxisUtils.java index 1a8727a..78ec960 100644 --- a/src/main/java/org/janelia/saalfeldlab/n5/universe/metadata/axes/AxisUtils.java +++ b/src/main/java/org/janelia/saalfeldlab/n5/universe/metadata/axes/AxisUtils.java @@ -3,15 +3,27 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.Iterator; import java.util.List; +import java.util.OptionalInt; import java.util.Set; +import java.util.TreeSet; +import java.util.function.Predicate; import java.util.stream.IntStream; import java.util.stream.Stream; +import org.janelia.saalfeldlab.n5.universe.metadata.N5Metadata; import org.janelia.saalfeldlab.n5.universe.metadata.N5SpatialDatasetMetadata; +import org.janelia.saalfeldlab.n5.universe.metadata.SpatialMetadata; +import org.janelia.saalfeldlab.n5.universe.metadata.SpatialModifiable; import net.imglib2.RandomAccessibleInterval; +import net.imglib2.realtransform.AffineGet; +import net.imglib2.realtransform.AffineTransform; +import net.imglib2.realtransform.AffineTransform3D; import net.imglib2.transform.integer.MixedTransform; +import net.imglib2.util.Pair; +import net.imglib2.util.ValuePair; import net.imglib2.view.IntervalView; import net.imglib2.view.MixedTransformView; import net.imglib2.view.Views; @@ -34,6 +46,7 @@ public class AxisUtils { public static String SPACE_UNIT = "um"; public static String TIME_UNIT = "s"; + /** * Finds and returns a permutation p such that source[p[i]] equals target[i] * @@ -70,6 +83,53 @@ public static List permute(final List in, final int[] p) { return out; } + /** + * Permutes an array in place. + * + * @param + * the type + * @param in + * the array + * @param p + * the permutation + */ + public static void permute(final T[] in, final T[] dest, final int[] p) { + + final ArrayList tmp = new ArrayList(in.length); + for (int i = 0; i < in.length; i++) + tmp.add(in[i]); + + for (int i = 0; i < p.length; i++) + dest[i] = tmp.get(p[i]); + } + + public static long[] permute(final long[] in, final int[] p) { + + final long[] out = new long[p.length]; + for (int i = 0; i < p.length; i++) + out[i] = in[p[i]]; + + return out; + } + + public static int[] permute(final int[] in, final int[] p) { + + final int[] out = new int[p.length]; + for (int i = 0; i < p.length; i++) + out[i] = in[p[i]]; + + return out; + } + + public static double[] permute(final double[] in, final int[] p) { + + final double[] out = new double[p.length]; + for (int i = 0; i < p.length; i++) + out[i] = in[p[i]]; + + return out; + } + public static Axis[] buildAxes( final String... labels ) { return Arrays.stream(labels).map( x -> { @@ -86,7 +146,10 @@ public static Axis[] buildAxes( final String... labels ) * @return the permutation */ public static int[] findImagePlusPermutation( final AxisMetadata axisMetadata ) { - return findImagePlusPermutation( axisMetadata.getAxisLabels()); + + // TODO should use axis types, not just labels. + // and should consider what to do if an unknown label exists + return findImagePlusPermutation(axisMetadata.getAxisLabels()); } /** @@ -95,7 +158,7 @@ public static int[] findImagePlusPermutation( final Ax * @param axisLabels the axis labels * @return the permutation array */ - public static int[] findImagePlusPermutation( final String[] axisLabels ) { + public static int[] findImagePlusPermutation(final String[] axisLabels) { final int[] p = new int[ 5 ]; p[0] = indexOf( axisLabels, "x" ); @@ -106,21 +169,72 @@ public static int[] findImagePlusPermutation( final String[] axisLabels ) { return p; } + public static int[] findImagePlusSpatialPermutation(final int[] p) { + + final OptionalInt minOpt = Arrays.stream(p).min(); + if (minOpt.isPresent()) { + final int min = minOpt.getAsInt(); + return Arrays.stream(p).map(x -> x - min).toArray(); + } else + return p; + } + + /** + * Converters an array of integers to a normalized array of integers such + * that the smallest integer is mapped to 0, the second smallest to 1 ... + * and the largest is mapped to N-1, where N is the number of unique + * integers in the array. + * + * @param p + * indexes + * @return normalized indexes + */ + public static int[] normalizeIndexes(final int[] indexes) { + + final TreeSet set = new TreeSet(); + for (final int i : indexes) + set.add(i); + + // can't get index from a tree set, use this sad, not scalable workaround for now + final int[] sortedUniqueIndexes = new int[set.size()]; + final Iterator it = set.iterator(); + int i = 0; + while ( it.hasNext()) + sortedUniqueIndexes[i++] = it.next(); + + final int[] out = new int[indexes.length]; + for (i = 0; i < out.length; i++ ) + out[i] = Arrays.binarySearch(sortedUniqueIndexes, indexes[i]); + + return out; + } + /** * Replaces "-1"s in the input permutation array * with the largest value. * * @param p the permutation */ - public static void fillPermutation( final int[] p ) { + public static void fillPermutation(final int[] p) { int j = Arrays.stream(p).max().getAsInt() + 1; for (int i = 0; i < p.length; i++) if (p[i] < 0) p[i] = j++; } - public static boolean isIdentityPermutation( final int[] p ) - { + public static AffineGet axisPermutationTransform(final int[] p) { + + final int N = p.length; + final int[] normalP = normalizeIndexes(p); + final double[] affineParams = new double[N * (N + 1)]; + for (int i = 0; i < normalP.length; i++) + affineParams[normalP[i] + (N + 1) * i] = 1.0; + + return new AffineTransform(affineParams); + } + + public static boolean isIdentityPermutation( final int[] p ) { + for( int i = 0; i < p.length; i++ ) if( p[i] != i ) return false; @@ -128,9 +242,9 @@ public static boolean isIdentityPermutation( final int[] p ) return true; } - public static RandomAccessibleInterval permuteForImagePlus( + public static RandomAccessibleInterval permuteForImagePlus( final RandomAccessibleInterval img, - final AxisMetadata meta ) { + final M meta) { final int[] p = findImagePlusPermutation( meta ); fillPermutation( p ); @@ -147,10 +261,55 @@ public static RandomAccessibleInterval permuteForImagePlus( return permute(imgTmp, invertPermutation(p)); } - public static RandomAccessibleInterval reverseDimensions(final RandomAccessibleInterval img) { + public static M permuteForImagePlus(int[] spatialPermutation, final M meta) { + + if (isIdentityPermutation(spatialPermutation)) + return meta; + + if (meta instanceof SpatialMetadata && meta instanceof SpatialModifiable) { + + final AffineTransform3D tform = ((SpatialMetadata)meta).spatialTransform3d().copy(); + final AffineTransform3D tformInv = ((SpatialMetadata)meta).spatialTransform3d().inverse().copy(); + + final AffineGet permTform = AxisUtils.axisPermutationTransform(spatialPermutation); + tform.concatenate(permTform).preConcatenate(permTform.inverse()); // exchange rows and + tform.concatenate(tformInv); + + final M out = (M)(((SpatialModifiable)meta).modifySpatialTransform(meta.getPath(), tform)); + return out; + } + + return meta; + } + + public static Pair, M> permuteImageAndMetadataForImagePlus( + final RandomAccessibleInterval img, final M meta) { + + if (meta != null && meta instanceof AxisMetadata) { + + final int[] p = AxisUtils.findImagePlusPermutation((AxisMetadata)meta); + AxisUtils.fillPermutation(p); + + RandomAccessibleInterval imgTmp = img; + while (imgTmp.numDimensions() < 5) + imgTmp = Views.addDimension(imgTmp, 0, 0); + + if (AxisUtils.isIdentityPermutation(p)) + return new ValuePair<>(imgTmp, meta); + + // do the permutation + final RandomAccessibleInterval imgOut = permute(imgTmp, invertPermutation(p)); + final int[] spatialPermutation = new int[]{p[0], p[1], p[3]}; + @SuppressWarnings("unchecked") + final M permutedMeta = (M)permuteForImagePlus(spatialPermutation, (A)meta); + + return new ValuePair<>(imgOut, permutedMeta); + } - // final int[] p = IntStream.range(0, img.numDimensions()).toArray(); - // ArrayUtils.reverse(p); + return new ValuePair<>(img, meta); + } + + public static RandomAccessibleInterval reverseDimensions(final RandomAccessibleInterval img) { final int nd = img.numDimensions(); final int[] p = IntStream.iterate(nd - 1, x -> x - 1).limit(nd).toArray(); @@ -175,22 +334,21 @@ private static final int indexOf(final T[] arr, final T tgt) { * @param p the permutation * @return the permuted source */ - public static final < T > IntervalView< T > permute( final RandomAccessibleInterval< T > source, final int[] p ) - { + public static final IntervalView permute(final RandomAccessibleInterval source, final int[] p) { + final int n = source.numDimensions(); - final long[] min = new long[ n ]; - final long[] max = new long[ n ]; - for ( int i = 0; i < n; ++i ) - { - min[ p[ i ] ] = source.min( i ); - max[ p[ i ] ] = source.max( i ); + final long[] min = new long[n]; + final long[] max = new long[n]; + for (int i = 0; i < n; ++i) { + min[p[i]] = source.min(i); + max[p[i]] = source.max(i); } - final MixedTransform t = new MixedTransform( n, n ); - t.setComponentMapping( p ); + final MixedTransform t = new MixedTransform(n, n); + t.setComponentMapping(p); - final IntervalView out = Views.interval( new MixedTransformView< T >( source, t ), min, max ); + final IntervalView out = Views.interval(new MixedTransformView(source, t), min, max); return out; } @@ -203,6 +361,11 @@ public static int[] invertPermutation( final int[] p ) return inv; } + public static int[] indexes(final Axis[] axes, final Predicate predicate) { + + return IntStream.range(0, axes.length).filter(i -> predicate.test(axes[i])).toArray(); + } + public static Axis[] defaultAxes( final int N ) { return IntStream.range(0, N).mapToObj(i -> {