Skip to content

Commit

Permalink
fixed stitchSectionsAndCreateGroupedTiles() and getOrBuildTile()
Browse files Browse the repository at this point in the history
  • Loading branch information
StephanPreibisch committed Aug 7, 2023
1 parent 88968d8 commit 9088b82
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.concurrent.ExecutionException;
import java.util.function.Function;

import com.google.common.graph.Graph;
import org.janelia.alignment.match.CanvasMatches;
import org.janelia.alignment.spec.ResolvedTileSpecCollection;
import org.janelia.alignment.spec.TileSpec;
Expand All @@ -24,6 +23,7 @@
import org.janelia.render.client.newsolver.solvers.Worker;
import org.janelia.render.client.solver.ConstantAffineModel2D;
import org.janelia.render.client.solver.DistributedSolveWorker;
import org.janelia.render.client.solver.Graph;
import org.janelia.render.client.solver.MinimalTileSpec;
import org.janelia.render.client.solver.SolveTools;
import org.janelia.render.client.solver.SerializableValuePair;
Expand Down Expand Up @@ -95,6 +95,9 @@ public class AffineAlignBlockWorker< M extends Model< M > & Affine2D< M >, S ext
// to filter matches
final MatchFilter matchFilter;

// if stitching first should be done
final boolean stitchFirst;

// for error computation (local)
final ArrayList< CanvasMatches > canvasMatches = new ArrayList<>();

Expand Down Expand Up @@ -141,6 +144,7 @@ public AffineAlignBlockWorker(
this.maxZRangeMatches = blockData.solveTypeParameters().maxZRangeMatches();

// used locally
this.stitchFirst = blockData.solveTypeParameters().minStitchingInliersSupplier() != null;
this.pairs = new ArrayList<>();
this.zToPairs = new HashMap<>();
}
Expand All @@ -159,7 +163,7 @@ public void run() throws IOException, ExecutionException, InterruptedException,
connectGroupedTiles( pairs, inputSolveItem );
this.solveItems = splitSolveItem( inputSolveItem, startId );

for ( final SolveItem< G, M, S > solveItem : solveItems )
for ( final AffineBlockDataWrapper< M, S, F > solveItem : solveItems )
{
if ( !assignRegularizationModel( solveItem ) )
throw new RuntimeException( "Couldn't regularize. Please check." );
Expand Down Expand Up @@ -288,23 +292,28 @@ private List<CanvasMatches> getCanvasMatches(final String sectionId, final int m
return matches;
}

private Tile<M> getOrBuildTile(final String id, final TileSpec tileSpec) {
protected Tile<M> getOrBuildTile( final String id, final TileSpec tileSpec )
{
final Tile<M> tile;
if (!inputSolveItem.idToTileMap().containsKey(id)) {
final Pair<Tile<M>, AffineModel2D> pair = SolveTools.buildTileFromSpec(inputSolveItem.blockSolveModelInstance(), SolveItem.samplesPerDimension, tileSpec);
if (!inputSolveItem.idToTileMap().containsKey(id))
{
final Pair<Tile<M>, AffineModel2D> pair =
SolveTools.buildTileFromSpec( inputSolveItem.blockData().solveTypeParameters().blockSolveModel().copy(), SolveItem.samplesPerDimension, tileSpec );
tile = pair.getA();
inputSolveItem.idToTileMap().put(id, tile);
inputSolveItem.tileToIdMap().put(tile, id);

inputSolveItem.idToPreviousModel().put(id, pair.getB());
final MinimalTileSpec minimalSpecWrapper = new MinimalTileSpecWrapper(tileSpec);
inputSolveItem.idToTileSpec().put(id, minimalSpecWrapper);
//inputSolveItem.blockData().idToTileSpec().put(id, minimalSpecWrapper); // this is now done ahead of time

if (minimalSpecWrapper.isRestart())
inputSolveItem.restarts().add((int) Math.round(minimalSpecWrapper.getZ()));
} else {
if ( tileSpec.hasLabel( "restart" ) )
inputSolveItem.restarts().add((int) Math.round(tileSpec.getZ()));
}
else
{
tile = inputSolveItem.idToTileMap().get(id);
}

return tile;
}

Expand Down Expand Up @@ -580,21 +589,23 @@ protected void connectGroupedTiles(
}

protected void stitchSectionsAndCreateGroupedTiles(
final SolveItem< G,M,S > solveItem,
final AffineBlockDataWrapper< M, S, F > solveItem,
final ArrayList< Pair< Pair< Tile< ? >, Tile< ? > >, List< PointMatch > > > pairs,
final HashMap< Integer, List< Integer > > zToPairs,
final Function< Integer, Integer > minStitchingInliersSupplier,
final boolean stitchFirst,
final int numThreads )
{
//final S model = solveItem.stitchingSolveModelInstance();

// combine tiles per layer that are be stitched first, but iterate over all z's
// (also those only consisting of single tiles, they are connected in z though)
final ArrayList< Integer > zList = new ArrayList<>( solveItem.zToTileId().keySet() );
final ArrayList< Integer > zList = new ArrayList<>( solveItem.blockData().zToTileId().keySet() );
Collections.sort( zList );

for ( final int z : zList )
{
LOG.info( "block " + solveItem.getId() + ": stitching z=" + z );
LOG.info( "block " + solveItem.blockData().getId() + ": stitching z=" + z );

final HashMap< String, Tile< S > > idTotile = new HashMap<>();
final HashMap< Tile< S >, String > tileToId = new HashMap<>();
Expand All @@ -609,7 +620,7 @@ protected void stitchSectionsAndCreateGroupedTiles(
final Pair< Pair< Tile< ? >, Tile< ? > >, List< PointMatch > > pair = pairs.get( index );

// stitching first only works when the stitching is reliable
if ( pair.getB().size() < minStitchingInliers )
if ( pair.getB().size() < minStitchingInliersSupplier.apply( z ) )
continue;

final String pId = solveItem.tileToIdMap().get( pair.getA().getA() );
Expand All @@ -629,7 +640,10 @@ protected void stitchSectionsAndCreateGroupedTiles(
{
//p = new Tile<>( model.copy() );
// since we do preAlign later this seems redundant. However, it makes sure the tiles are more or less at the right global coordinates
p = SolveTools.buildTile( solveItem.idToPreviousModel().get( pId ), solveItem.stitchingSolveModelInstance( z ).copy(), 100, 100, 3 );
p = SolveTools.buildTile(
solveItem.idToPreviousModel().get( pId ),
solveItem.blockData().solveTypeParameters().stitchingSolveModelInstance( z ).copy(),
100, 100, 3 );
idTotile.put( pId, p );
tileToId.put( p, pId );
}
Expand All @@ -641,7 +655,10 @@ protected void stitchSectionsAndCreateGroupedTiles(
if ( !idTotile.containsKey( qId ) )
{
//q = new Tile<>( model.copy() );
q = SolveTools.buildTile( solveItem.idToPreviousModel().get( qId ), solveItem.stitchingSolveModelInstance( z ).copy(), 100, 100, 3 );
q = SolveTools.buildTile(
solveItem.idToPreviousModel().get( qId ),
solveItem.blockData().solveTypeParameters().stitchingSolveModelInstance( z ).copy(),
100, 100, 3 );
idTotile.put( qId, q );
tileToId.put( q, qId );
}
Expand All @@ -657,31 +674,31 @@ protected void stitchSectionsAndCreateGroupedTiles(
}

// add all missing TileIds as unconnected Tiles
for ( final String tileId : solveItem.zToTileId().get( z ) )
for ( final String tileId : solveItem.blockData().zToTileId().get( z ) )
if ( !idTotile.containsKey( tileId ) )
{
LOG.info( "block " + solveItem.getId() + ": unconnected tileId " + tileId );
LOG.info( "block " + solveItem.blockData().getId() + ": unconnected tileId " + tileId );

final Tile< S > tile = new Tile< S >( solveItem.stitchingSolveModelInstance( z ).copy() );
final Tile< S > tile = new Tile< S >( solveItem.blockData().solveTypeParameters().stitchingSolveModelInstance( z ).copy() );
idTotile.put( tileId, tile );
tileToId.put( tile, tileId );
}

// Now identify connected graphs within all tiles
final ArrayList< Set< Tile< ? > > > sets = safelyIdentifyConnectedGraphs( new ArrayList<>(idTotile.values()) );

LOG.info( "block " + solveItem.getId() + ": stitching z=" + z + " #sets=" + sets.size() );
LOG.info( "block " + solveItem.blockData().getId() + ": stitching z=" + z + " #sets=" + sets.size() );

// solve each set (if size > 1)
int setCount = 0;
for ( final Set< Tile< ? > > set : sets )
{
LOG.info( "block " + solveItem.getId() + ": Set=" + setCount++ );
LOG.info( "block " + solveItem.blockData().getId() + ": Set=" + setCount++ );

//
// the grouped tile for this set of one layer
//
final Tile< M > groupedTile = new Tile<>( solveItem.blockSolveModelInstance() );
final Tile< M > groupedTile = new Tile<>( solveItem.blockData().solveTypeParameters().blockSolveModel().copy() );

if ( set.size() > 1 )
{
Expand All @@ -696,14 +713,17 @@ protected void stitchSectionsAndCreateGroupedTiles(
}
catch ( final NotEnoughDataPointsException | IllDefinedDataPointsException e )
{
LOG.info( "block " + solveItem.getId() + ": Could not solve prealign for z=" + z + ", cause: " + e );
LOG.info( "block " + solveItem.blockData().getId() + ": Could not solve prealign for z=" + z + ", cause: " + e );
e.printStackTrace();
}

// test if the graph has cycles, if yes we would need to do a solve
if ( !( ( set.iterator().next().getModel() instanceof TranslationModel2D || set.iterator().next().getModel() instanceof RigidModel2D) && !new Graph(new ArrayList<>(set ) ).isCyclic() ) )
if ( !( (
set.iterator().next().getModel() instanceof TranslationModel2D ||
set.iterator().next().getModel() instanceof RigidModel2D) &&
!new Graph( new ArrayList<>( set ) ).isCyclic() ) )
{
LOG.info( "block " + solveItem.getId() + ": Full solve required for stitching z=" + z );
LOG.info( "block " + solveItem.blockData().getId() + ": Full solve required for stitching z=" + z );

try
{
Expand All @@ -718,11 +738,11 @@ protected void stitchSectionsAndCreateGroupedTiles(
tileConfig.getFixedTiles(),
numThreads );

LOG.info( "block " + solveItem.getId() + ": Solve z=" + z + " avg=" + tileConfig.getError() + ", min=" + tileConfig.getMinError() + ", max=" + tileConfig.getMaxError() );
LOG.info( "block " + solveItem.blockData().getId() + ": Solve z=" + z + " avg=" + tileConfig.getError() + ", min=" + tileConfig.getMinError() + ", max=" + tileConfig.getMaxError() );
}
catch ( final Exception e )
{
LOG.info( "block " + solveItem.getId() + ": Could not solve stitiching for z=" + z + ", cause: " + e );
LOG.info( "block " + solveItem.blockData().getId() + ": Could not solve stitiching for z=" + z + ", cause: " + e );
e.printStackTrace();
}
}
Expand All @@ -741,8 +761,8 @@ protected void stitchSectionsAndCreateGroupedTiles(
solveItem.groupedTileToTiles().putIfAbsent( groupedTile, new ArrayList<>() );
solveItem.groupedTileToTiles().get( groupedTile ).add( solveItem.idToTileMap().get( tileId ) );

LOG.info( "block " + solveItem.getId() + ": TileId " + tileId + " Model= " + affine );
LOG.info( "block " + solveItem.getId() + ": TileId " + tileId + " prev Model=" + solveItem.idToPreviousModel().get( tileId ) );
LOG.info( "block " + solveItem.blockData().getId() + ": TileId " + tileId + " Model= " + affine );
LOG.info( "block " + solveItem.blockData().getId() + ": TileId " + tileId + " prev Model=" + solveItem.idToPreviousModel().get( tileId ) );
}

// Hack: show a section after alignment
Expand All @@ -758,7 +778,7 @@ protected void stitchSectionsAndCreateGroupedTiles(
}

new ImageJ();
final ImagePlus imp1 = VisualizeTools.render(models, solveItem.idToTileSpec(), 0.15 );
final ImagePlus imp1 = VisualizeTools.renderTS(models, solveItem.blockData().idToTileSpec(), 0.15 );
imp1.setTitle( "z=" + z );
}
catch ( final NoninvertibleModelException e )
Expand All @@ -778,7 +798,7 @@ protected void stitchSectionsAndCreateGroupedTiles(
solveItem.groupedTileToTiles().putIfAbsent( groupedTile, new ArrayList<>() );
solveItem.groupedTileToTiles().get( groupedTile ).add( solveItem.idToTileMap().get( tileId ) );

LOG.info( "block " + inputSolveItem.getId() + ": Single TileId " + tileId );
LOG.info( "block " + inputSolveItem.blockData().getId() + ": Single TileId " + tileId );
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ public class AffineBlockDataWrapper< M extends Model< M > & Affine2D< M >, S ext
// which z layers are restarts
final private HashSet< Integer > restarts = new HashSet<Integer>();

// contains the model as loaded from renderer (can go right now except for debugging)
final private HashMap<String, AffineModel2D> idToPreviousModel = new HashMap<>();

// matches for error computation
final List< Pair< Pair< String, String>, Matches > > matches = new ArrayList<>();

Expand All @@ -61,5 +64,6 @@ public AffineBlockDataWrapper( final BlockData< M, FIBSEMAlignmentParameters< M,
public HashMap< Tile< M >, Tile< M > > tileToGroupedTile() { return tileToGroupedTile; }
public HashMap< Tile< M >, List< Tile< M > > > groupedTileToTiles() { return groupedTileToTiles; }
public HashSet< Integer > restarts() { return restarts; }
public HashMap<String, AffineModel2D> idToPreviousModel() { return idToPreviousModel; }

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* @author spreibi
*
*/
class Graph
public class Graph
{
private int v; // No. of vertices
private LinkedList< Integer > adj[]; // Adjacency List Represntation
Expand Down Expand Up @@ -91,7 +91,7 @@ else if ( i != parent )
}

// Returns true if the graph contains a cycle, else false.
boolean isCyclic()
public boolean isCyclic()
{
// Mark all the vertices as not visited and not part of
// recursion stack
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -711,15 +711,15 @@ public int deltaZ( final Tile<?> tile1, final Tile<?> tile2 )
return unAlignedTiles;
}

protected static AffineModel2D createAffine( final Affine2D< ? > model )
public static AffineModel2D createAffine( final Affine2D< ? > model )
{
final AffineModel2D m = new AffineModel2D();
m.set( model.createAffine() );

return m;
}

protected static List< PointMatch > duplicate( List< PointMatch > pms )
public static List< PointMatch > duplicate( List< PointMatch > pms )
{
final List< PointMatch > copy = new ArrayList<>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -372,6 +374,17 @@ public static < T > RandomAccessibleInterval< T > translateIfNecessary( final In
}
}

// TODO: rendering should take all models (including deformations, not only one affine)
public static ImagePlus renderTS( final HashMap<String, AffineModel2D> idToModels, final Map<String, TileSpec> idToTileSpec, final double scale ) throws NoninvertibleModelException
{
final HashMap<String, MinimalTileSpec> idToTileSpecMinimal = new HashMap<>();

for ( final Entry<String, TileSpec> e : idToTileSpec.entrySet() )
idToTileSpecMinimal.put(e.getKey(), new MinimalTileSpec( e.getValue() ) );

return render(idToModels, idToTileSpecMinimal, scale, Integer.MIN_VALUE, Integer.MAX_VALUE );
}

// TODO: rendering should take all models (including deformations, not only one affine)
public static ImagePlus render( final HashMap<String, AffineModel2D> idToModels, final HashMap<String, MinimalTileSpec> idToTileSpec, final double scale ) throws NoninvertibleModelException
{
Expand Down

0 comments on commit 9088b82

Please sign in to comment.