Skip to content

Commit

Permalink
feat: axis utils more methods for permutation
Browse files Browse the repository at this point in the history
  • Loading branch information
bogovicj committed May 9, 2024
1 parent fd1679a commit 5b01482
Showing 1 changed file with 184 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,27 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.OptionalInt;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Predicate;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.janelia.saalfeldlab.n5.universe.metadata.N5Metadata;
import org.janelia.saalfeldlab.n5.universe.metadata.N5SpatialDatasetMetadata;
import org.janelia.saalfeldlab.n5.universe.metadata.SpatialMetadata;
import org.janelia.saalfeldlab.n5.universe.metadata.SpatialModifiable;

import net.imglib2.RandomAccessibleInterval;
import net.imglib2.realtransform.AffineGet;
import net.imglib2.realtransform.AffineTransform;
import net.imglib2.realtransform.AffineTransform3D;
import net.imglib2.transform.integer.MixedTransform;
import net.imglib2.util.Pair;
import net.imglib2.util.ValuePair;
import net.imglib2.view.IntervalView;
import net.imglib2.view.MixedTransformView;
import net.imglib2.view.Views;
Expand All @@ -34,6 +46,7 @@ public class AxisUtils {
public static String SPACE_UNIT = "um";
public static String TIME_UNIT = "s";


/**
* Finds and returns a permutation p such that source[p[i]] equals target[i]
*
Expand Down Expand Up @@ -70,6 +83,53 @@ public static <T> List<T> permute(final List<T> in, final int[] p) {
return out;
}

/**
* Permutes an array in place.
*
* @param <T>
* the type
* @param in
* the array
* @param p
* the permutation
*/
public static <T> void permute(final T[] in, final T[] dest, final int[] p) {

final ArrayList<T> tmp = new ArrayList<T>(in.length);
for (int i = 0; i < in.length; i++)
tmp.add(in[i]);

for (int i = 0; i < p.length; i++)
dest[i] = tmp.get(p[i]);
}

public static long[] permute(final long[] in, final int[] p) {

final long[] out = new long[p.length];
for (int i = 0; i < p.length; i++)
out[i] = in[p[i]];

return out;
}

public static int[] permute(final int[] in, final int[] p) {

final int[] out = new int[p.length];
for (int i = 0; i < p.length; i++)
out[i] = in[p[i]];

return out;
}

public static double[] permute(final double[] in, final int[] p) {

final double[] out = new double[p.length];
for (int i = 0; i < p.length; i++)
out[i] = in[p[i]];

return out;
}

public static Axis[] buildAxes( final String... labels )
{
return Arrays.stream(labels).map( x -> {
Expand All @@ -86,7 +146,10 @@ public static Axis[] buildAxes( final String... labels )
* @return the permutation
*/
public static <A extends AxisMetadata > int[] findImagePlusPermutation( final AxisMetadata axisMetadata ) {
return findImagePlusPermutation( axisMetadata.getAxisLabels());

// TODO should use axis types, not just labels.
// and should consider what to do if an unknown label exists
return findImagePlusPermutation(axisMetadata.getAxisLabels());
}

/**
Expand All @@ -95,7 +158,7 @@ public static <A extends AxisMetadata > int[] findImagePlusPermutation( final Ax
* @param axisLabels the axis labels
* @return the permutation array
*/
public static int[] findImagePlusPermutation( final String[] axisLabels ) {
public static int[] findImagePlusPermutation(final String[] axisLabels) {

final int[] p = new int[ 5 ];
p[0] = indexOf( axisLabels, "x" );
Expand All @@ -106,31 +169,82 @@ public static int[] findImagePlusPermutation( final String[] axisLabels ) {
return p;
}

public static int[] findImagePlusSpatialPermutation(final int[] p) {

final OptionalInt minOpt = Arrays.stream(p).min();
if (minOpt.isPresent()) {
final int min = minOpt.getAsInt();
return Arrays.stream(p).map(x -> x - min).toArray();
} else
return p;
}

/**
* Converters an array of integers to a normalized array of integers such
* that the smallest integer is mapped to 0, the second smallest to 1 ...
* and the largest is mapped to N-1, where N is the number of unique
* integers in the array.
*
* @param p
* indexes
* @return normalized indexes
*/
public static int[] normalizeIndexes(final int[] indexes) {

final TreeSet<Integer> set = new TreeSet<Integer>();
for (final int i : indexes)
set.add(i);

// can't get index from a tree set, use this sad, not scalable workaround for now
final int[] sortedUniqueIndexes = new int[set.size()];
final Iterator<Integer> it = set.iterator();
int i = 0;
while ( it.hasNext())
sortedUniqueIndexes[i++] = it.next();

final int[] out = new int[indexes.length];
for (i = 0; i < out.length; i++ )
out[i] = Arrays.binarySearch(sortedUniqueIndexes, indexes[i]);

return out;
}

/**
* Replaces "-1"s in the input permutation array
* with the largest value.
*
* @param p the permutation
*/
public static void fillPermutation( final int[] p ) {
public static void fillPermutation(final int[] p) {
int j = Arrays.stream(p).max().getAsInt() + 1;
for (int i = 0; i < p.length; i++)
if (p[i] < 0)
p[i] = j++;
}

public static boolean isIdentityPermutation( final int[] p )
{
public static AffineGet axisPermutationTransform(final int[] p) {

final int N = p.length;
final int[] normalP = normalizeIndexes(p);
final double[] affineParams = new double[N * (N + 1)];
for (int i = 0; i < normalP.length; i++)
affineParams[normalP[i] + (N + 1) * i] = 1.0;

return new AffineTransform(affineParams);
}

public static boolean isIdentityPermutation( final int[] p ) {

for( int i = 0; i < p.length; i++ )
if( p[i] != i )
return false;

return true;
}

public static <T> RandomAccessibleInterval<T> permuteForImagePlus(
public static <T, M extends AxisMetadata & N5Metadata> RandomAccessibleInterval<T> permuteForImagePlus(
final RandomAccessibleInterval<T> img,
final AxisMetadata meta ) {
final M meta) {

final int[] p = findImagePlusPermutation( meta );
fillPermutation( p );
Expand All @@ -147,10 +261,55 @@ public static <T> RandomAccessibleInterval<T> permuteForImagePlus(
return permute(imgTmp, invertPermutation(p));
}

public static <T> RandomAccessibleInterval<T> reverseDimensions(final RandomAccessibleInterval<T> img) {
public static <M extends AxisMetadata & N5Metadata> M permuteForImagePlus(int[] spatialPermutation, final M meta) {

if (isIdentityPermutation(spatialPermutation))
return meta;

if (meta instanceof SpatialMetadata && meta instanceof SpatialModifiable) {

final AffineTransform3D tform = ((SpatialMetadata)meta).spatialTransform3d().copy();
final AffineTransform3D tformInv = ((SpatialMetadata)meta).spatialTransform3d().inverse().copy();

final AffineGet permTform = AxisUtils.axisPermutationTransform(spatialPermutation);
tform.concatenate(permTform).preConcatenate(permTform.inverse()); // exchange rows and
tform.concatenate(tformInv);

final M out = (M)(((SpatialModifiable)meta).modifySpatialTransform(meta.getPath(), tform));
return out;
}

return meta;
}

public static <T, M extends N5Metadata, A extends AxisMetadata & N5Metadata> Pair<RandomAccessibleInterval<T>, M> permuteImageAndMetadataForImagePlus(
final RandomAccessibleInterval<T> img, final M meta) {

if (meta != null && meta instanceof AxisMetadata) {

final int[] p = AxisUtils.findImagePlusPermutation((AxisMetadata)meta);
AxisUtils.fillPermutation(p);

RandomAccessibleInterval<T> imgTmp = img;
while (imgTmp.numDimensions() < 5)
imgTmp = Views.addDimension(imgTmp, 0, 0);

if (AxisUtils.isIdentityPermutation(p))
return new ValuePair<>(imgTmp, meta);

// do the permutation
final RandomAccessibleInterval<T> imgOut = permute(imgTmp, invertPermutation(p));
final int[] spatialPermutation = new int[]{p[0], p[1], p[3]};
@SuppressWarnings("unchecked")
final M permutedMeta = (M)permuteForImagePlus(spatialPermutation, (A)meta);

return new ValuePair<>(imgOut, permutedMeta);
}

// final int[] p = IntStream.range(0, img.numDimensions()).toArray();
// ArrayUtils.reverse(p);
return new ValuePair<>(img, meta);
}

public static <T> RandomAccessibleInterval<T> reverseDimensions(final RandomAccessibleInterval<T> img) {

final int nd = img.numDimensions();
final int[] p = IntStream.iterate(nd - 1, x -> x - 1).limit(nd).toArray();
Expand All @@ -175,22 +334,21 @@ private static final <T> int indexOf(final T[] arr, final T tgt) {
* @param p the permutation
* @return the permuted source
*/
public static final < T > IntervalView< T > permute( final RandomAccessibleInterval< T > source, final int[] p )
{
public static final <T> IntervalView<T> permute(final RandomAccessibleInterval<T> source, final int[] p) {

final int n = source.numDimensions();

final long[] min = new long[ n ];
final long[] max = new long[ n ];
for ( int i = 0; i < n; ++i )
{
min[ p[ i ] ] = source.min( i );
max[ p[ i ] ] = source.max( i );
final long[] min = new long[n];
final long[] max = new long[n];
for (int i = 0; i < n; ++i) {
min[p[i]] = source.min(i);
max[p[i]] = source.max(i);
}

final MixedTransform t = new MixedTransform( n, n );
t.setComponentMapping( p );
final MixedTransform t = new MixedTransform(n, n);
t.setComponentMapping(p);

final IntervalView<T> out = Views.interval( new MixedTransformView< T >( source, t ), min, max );
final IntervalView<T> out = Views.interval(new MixedTransformView<T>(source, t), min, max);
return out;
}

Expand All @@ -203,6 +361,11 @@ public static int[] invertPermutation( final int[] p )
return inv;
}

public static int[] indexes(final Axis[] axes, final Predicate<Axis> predicate) {

return IntStream.range(0, axes.length).filter(i -> predicate.test(axes[i])).toArray();
}

public static Axis[] defaultAxes( final int N ) {

return IntStream.range(0, N).mapToObj(i -> {
Expand Down

0 comments on commit 5b01482

Please sign in to comment.