Skip to content

Commit

Permalink
put StabilizingAffineModel2D back in, removed dynamiclambda
Browse files Browse the repository at this point in the history
  • Loading branch information
StephanPreibisch committed Aug 14, 2023
1 parent b6c0236 commit ff8262f
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package org.janelia.render.client.newsolver.solvers;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

import org.janelia.render.client.newsolver.solvers.affine.AffineBlockDataWrapper;

import mpicbg.models.Affine2D;
import mpicbg.models.Model;
import mpicbg.models.Tile;

public class WorkerTools
{
public static class LayerDetails< M extends Model<M> & Affine2D< M >>
{
final public String tileId;
final public int tileCol, tileRow;
final public Tile< M > prevGroupedTile;

public LayerDetails( final String tileId, final int tileCol, final int tileRow, final Tile< M > prevGroupedTile )
{
this.tileId = tileId;
this.tileCol = tileCol;
this.tileRow = tileRow;
this.prevGroupedTile = prevGroupedTile;
}
}

public static < M extends Model<M> & Affine2D< M >> ArrayList< LayerDetails< M > > layerDetails(
final ArrayList< Integer > allZ,
final HashMap< Integer, List<Tile<M>> > zToGroupedTileList,
final AffineBlockDataWrapper< M, ?, ? > solveItem,
final int i )
{
final ArrayList< LayerDetails< M > > prevTiles = new ArrayList<>();

if ( i < 0 || i >= allZ.size() )
return prevTiles;

for ( final Tile< M > prevGroupedTile : zToGroupedTileList.get( allZ.get( i ) ) )
for ( final Tile< M > imageTile : solveItem.groupedTileToTiles().get( prevGroupedTile ) )
{
final String tileId = solveItem.tileToIdMap().get( imageTile );
final int tileCol = solveItem.blockData().idToTileSpec().get( tileId ).getLayout().getImageCol();//.getImageCol();
final int tileRow = solveItem.blockData().idToTileSpec().get( tileId ).getLayout().getImageRow();//

prevTiles.add( new LayerDetails<>(tileId, tileCol, tileRow, prevGroupedTile ) );//new ValuePair<>( new ValuePair<>( tileCol, tileId ), prevGroupedTile ) );
}

return prevTiles;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.janelia.render.client.newsolver.blocksolveparameters.FIBSEMAlignmentParameters;
import org.janelia.render.client.newsolver.blocksolveparameters.FIBSEMAlignmentParameters.PreAlign;
import org.janelia.render.client.newsolver.solvers.Worker;
import org.janelia.render.client.newsolver.solvers.WorkerTools;
import org.janelia.render.client.newsolver.solvers.WorkerTools.LayerDetails;
import org.janelia.render.client.solver.ConstantAffineModel2D;
import org.janelia.render.client.solver.DistributedSolveWorker;
import org.janelia.render.client.solver.Graph;
Expand Down Expand Up @@ -82,9 +84,9 @@ public class AffineAlignBlockWorker< M extends Model< M > & Affine2D< M >, S ext
// only for dynamic lambda stuff that we never used
//final Set<Integer> excludeFromRegularization;

final List<Double> blockOptimizerLambdasRigid, blockOptimizerLambdasTranslation;
final List<Integer> blockOptimizerIterations, blockMaxPlateauWidth;
final double blockMaxAllowedError;
//final List<Double> blockOptimizerLambdasRigid, blockOptimizerLambdasTranslation;
//final List<Integer> blockOptimizerIterations, blockMaxPlateauWidth;
//final double blockMaxAllowedError;

final AffineBlockDataWrapper< M, S, F > inputSolveItem;
private List< AffineBlockDataWrapper< M, S, F > > solveItems;
Expand Down Expand Up @@ -115,12 +117,6 @@ public AffineAlignBlockWorker(

this.inputSolveItem = new AffineBlockDataWrapper<>( blockData );

this.blockOptimizerLambdasRigid = blockData.solveTypeParameters().blockOptimizerLambdasRigid();
this.blockOptimizerLambdasTranslation = blockData.solveTypeParameters().blockOptimizerLambdasTranslation();
this.blockOptimizerIterations = blockData.solveTypeParameters().blockOptimizerIterations();
this.blockMaxPlateauWidth = blockData.solveTypeParameters().blockMaxPlateauWidth();
this.blockMaxAllowedError = blockData.solveTypeParameters().blockMaxAllowedError();

if ( blockData.solveTypeParameters().maxNumMatches() <= 0 )
this.matchFilter = new NoMatchFilter();
else
Expand Down Expand Up @@ -164,6 +160,7 @@ public void run() throws IOException, ExecutionException, InterruptedException,
{
if ( !assignRegularizationModel( solveItem ) )
throw new RuntimeException( "Couldn't regularize. Please check." );

solve( solveItem, zRadiusRestarts, numThreads );
}

Expand Down Expand Up @@ -315,8 +312,8 @@ protected Tile<M> getOrBuildTile( final String id, final TileSpec tileSpec )
inputSolveItem.idToPreviousModel().put(id, pair.getB());
//inputSolveItem.blockData().idToTileSpec().put(id, minimalSpecWrapper); // this is now done ahead of time

if ( tileSpec.hasLabel( "restart" ) )
inputSolveItem.restarts().add((int) Math.round(tileSpec.getZ()));
//if ( tileSpec.hasLabel( "restart" ) )
// inputSolveItem.restarts().add((int) Math.round(tileSpec.getZ()));
}
else
{
Expand All @@ -331,8 +328,13 @@ protected Tile<M> getOrBuildTile( final String id, final TileSpec tileSpec )
* (alternative: top left corner?)
*
* @param solveItem - the solve item
* @param samplesPerDimension - for creating fake matches using ConstantAffineModel2D or StabilizingAffineModel2D
* @param stabilizationRadius - the radius in z that is used for stabilization using StabilizingAffineModel2D
*/
protected boolean assignRegularizationModel( final AffineBlockDataWrapper< M, S, F > solveItem )
protected boolean assignRegularizationModel(
final AffineBlockDataWrapper< M, S, F > solveItem,
final int samplesPerDimension,
final int stabilizationRadius )
{
LOG.info( "Assigning regularization models." );

Expand All @@ -354,7 +356,9 @@ protected boolean assignRegularizationModel( final AffineBlockDataWrapper< M, S,
final ArrayList< Integer > allZ = new ArrayList<>( zToGroupedTileList.keySet() );
Collections.sort( allZ );

if (((InterpolatedAffineModel2D) zToGroupedTileList.get(allZ.get(0)).get(0).getModel()).getB() instanceof ConstantAffineModel2D)
final Model<?> model = ((InterpolatedAffineModel2D<?,?>) zToGroupedTileList.get(allZ.get(0)).get(0).getModel()).getB();

if ( ConstantAffineModel2D.class.isInstance( model ) )
{
//
// it is based on ConstantAffineModels, meaning we extract metadata and use that as regularizer
Expand Down Expand Up @@ -405,14 +409,14 @@ protected boolean assignRegularizationModel( final AffineBlockDataWrapper< M, S,
//LOG.info( "z=" + z + " stitching model: " + stitchingTransform );
//LOG.info( "z=" + z + " metaData model : " + metaDataTransform );

final double sampleWidth = (tileSpec.getWidth() - 1.0) / (SolveItem.samplesPerDimension - 1.0);
final double sampleHeight = (tileSpec.getHeight() - 1.0) / (SolveItem.samplesPerDimension - 1.0);
final double sampleWidth = (tileSpec.getWidth() - 1.0) / (samplesPerDimension - 1.0);
final double sampleHeight = (tileSpec.getHeight() - 1.0) / (samplesPerDimension - 1.0);

// ALTERNATIVELY: ONLY SELECT ONE OF THE TILES
for (int y = 0; y < SolveItem.samplesPerDimension; ++y)
for (int y = 0; y < samplesPerDimension; ++y)
{
final double sampleY = y * sampleHeight;
for (int x = 0; x < SolveItem.samplesPerDimension; ++x)
for (int x = 0; x < samplesPerDimension; ++x)
{
final double[] p = new double[] { x * sampleWidth, sampleY };
final double[] q = new double[] { x * sampleWidth, sampleY };
Expand All @@ -427,23 +431,23 @@ protected boolean assignRegularizationModel( final AffineBlockDataWrapper< M, S,
//final RigidModel2D regularizationModel = new RigidModel2D();
//final TranslationModel2D regularizationModel = new TranslationModel2D();
//final S regularizationModel = solveItem.stitchingSolveModelInstance();

final ConstantAffineModel2D cModel = (ConstantAffineModel2D)((InterpolatedAffineModel2D) groupedTile.getModel()).getB();
final Model< ? > regularizationModel = cModel.getModel();

try
{
regularizationModel.fit( matches );

double sumError = 0;

for ( final PointMatch pm : matches )
{
pm.getP1().apply( regularizationModel );

final double distance = Point.distance(pm.getP1(), pm.getP2() );
sumError += distance;

//LOG.info( "P1: " + Util.printCoordinates( pm.getP1().getW() ) + ", P2: " + Util.printCoordinates( pm.getP2().getW() ) + ", d=" + distance );
}
LOG.info( "Error=" + (sumError / matches.size()) );
Expand All @@ -456,6 +460,111 @@ protected boolean assignRegularizationModel( final AffineBlockDataWrapper< M, S,
}
return true;
}
else if ( StabilizingAffineModel2D.class.isInstance( model ) )
{
//
// it is based on StabilizingAffineModel2Ds, meaning each image wants to sit where its corresponding one in the above layer sits
//
for ( int i = 0; i < allZ.size(); ++i )
{
final int z = allZ.get( i );

// first get all tiles from adjacent layers and the associated grouped tile
final ArrayList< LayerDetails< M > > neighboringTiles = new ArrayList<>();

int from = i, to = i;

for ( int d = 1; d <= stabilizationRadius && i + d < allZ.size(); ++d )
{
//if ( solveItem.restarts().contains( allZ.get( i + d ) ) )
// break;
//else
neighboringTiles.addAll( WorkerTools.layerDetails( allZ, zToGroupedTileList, solveItem, i + d ) );

to = i + d;
}

// if this z section is a restart we only go down from here
// if ( !solveItem.restarts().contains( z ) )
{
for ( int d = 1; d <= stabilizationRadius && i - d >= 0; ++d )
{
// always connect up, even if it is a restart, then break afterwards
neighboringTiles.addAll( WorkerTools.layerDetails( allZ, zToGroupedTileList, solveItem, i - d ) );

from = i - d;

//if ( solveItem.restarts().contains( allZ.get( i - d ) ) )
// break;
}
}

final List< Tile< M > > groupedTiles = zToGroupedTileList.get( z );

//if ( solveItem.restarts().contains( z ) )
// LOG.info( "z=" + z + " is a RESTART" );

LOG.info( "z=" + z + " contains " + groupedTiles.size() + " grouped tiles (StabilizingAffineModel2D), connected from " + allZ.get( from ) + " to " + allZ.get( to ) );

// now go over all tiles of the current z
for ( final Tile< M > groupedTile : groupedTiles )
{
final List< Tile< M > > imageTiles = solveItem.groupedTileToTiles().get( groupedTile );

if ( groupedTiles.size() > 1 )
LOG.info( "z=" + z + " grouped tile [" + groupedTile + "] contains " + imageTiles.size() + " image tiles." );

// create pointmatches from the edges of each image in the grouped tile to the respective edges in the metadata
final List< Pair< List< PointMatch >, Tile< M > > > matchesList = new ArrayList<>();

for ( final Tile< M > imageTile : imageTiles )
{
final String tileId = solveItem.tileToIdMap().get( imageTile );
final TileSpec tileSpec = solveItem.blockData().idToTileSpec().get( tileId );

final int tileCol = tileSpec.getLayout().getImageCol();// tileSpec.getImageCol();
final int tileRow = tileSpec.getLayout().getImageRow();

// if ( tileCol != 0 )
// continue;

final ArrayList< LayerDetails< M > > neighbors = new ArrayList<>();

for ( final LayerDetails< M > neighboringTile : neighboringTiles )
if ( neighboringTile.tileCol == tileCol && neighboringTile.tileRow == tileRow )
neighbors.add( neighboringTile );

if ( neighbors.size() == 0 )
{
// this can happen when number of tiles per layer changes for example
LOG.info( "could not find corresponding tile for: " + tileId );
continue;
}

for ( final LayerDetails< M > neighbor : neighbors )
{
final AffineModel2D stitchingTransform = solveItem.idToStitchingModel().get( tileId );
final AffineModel2D stitchingTransformPrev = solveItem.idToStitchingModel().get( neighbor.tileId );

final List< PointMatch > matches = SolveTools.createFakeMatches(
tileSpec.getWidth(),
tileSpec.getHeight(),
stitchingTransform, // p
stitchingTransformPrev, // q
samplesPerDimension );

matchesList.add( new ValuePair<>( matches, neighbor.prevGroupedTile ) );
}
}

// in every iteration, update q with the current group tile transformation(s), the fit p to q for regularization
final StabilizingAffineModel2D cModel = (StabilizingAffineModel2D)((InterpolatedAffineModel2D) groupedTile.getModel()).getB();

cModel.setFitData( matchesList );
}
}
return true;
}
else
{
LOG.info( "Not using ConstantAffineModel2D for regularization. Nothing to do in assignRegularizationModel()." );
Expand Down Expand Up @@ -803,9 +912,9 @@ else if ( graphs.size() == 1 )
}

// add the restart lookup
for ( final int z : inputSolveItem.restarts() )
if ( z >= solveItem.blockData().minZ() && z <= solveItem.blockData().maxZ() )
solveItem.restarts().add( z );
//for ( final int z : inputSolveItem.restarts() )
// if ( z >= solveItem.blockData().minZ() && z <= solveItem.blockData().maxZ() )
// solveItem.restarts().add( z );

// used for global solve outside
for ( int z = solveItem.blockData().minZ(); z <= solveItem.blockData().maxZ(); ++z )
Expand Down Expand Up @@ -844,28 +953,42 @@ protected void solve(
final int numThreads
) throws InterruptedException, ExecutionException
{
final PreAlign preAlign = solveItem.blockData().solveTypeParameters().preAlign();

//final List<Double> blockOptimizerLambdasRigid, blockOptimizerLambdasTranslation;
//final List<Integer> blockOptimizerIterations, blockMaxPlateauWidth;
//final double blockMaxAllowedError;

final List<Double> blockOptimizerLambdasRigid = solveItem.blockData().solveTypeParameters().blockOptimizerLambdasRigid();
final List<Double> blockOptimizerLambdasTranslation = solveItem.blockData().solveTypeParameters().blockOptimizerLambdasTranslation();
final List<Double> blockOptimizerLambdasRegularization = solveItem.blockData().solveTypeParameters().blockOptimizerLambdasRegularization();
final List<Integer> blockOptimizerIterations = solveItem.blockData().solveTypeParameters().blockOptimizerIterations();
final List<Integer> blockMaxPlateauWidth = solveItem.blockData().solveTypeParameters().blockMaxPlateauWidth();
final double blockMaxAllowedError = blockData.solveTypeParameters().blockMaxAllowedError();

final TileConfiguration tileConfig = new TileConfiguration();

// new HashSet because all tiles link to their common group tile, which is therefore present more than once
tileConfig.addTiles( new HashSet<>( solveItem.tileToGroupedTile().values() ) );

LOG.info("block " + solveItem.blockData().getId() + ": run: optimizing {} tiles", solveItem.groupedTileToTiles().keySet().size() );

final HashMap< Tile< ? >, Double > tileToDynamicLambda = SolveTools.computeMetaDataLambdas( tileConfig.getTiles(), solveItem, zRadiusRestarts, excludeFromRegularization, dynamicLambdaFactor );
//final HashMap< Tile< ? >, Double > tileToDynamicLambda = SolveTools.computeMetaDataLambdas( tileConfig.getTiles(), solveItem, zRadiusRestarts, excludeFromRegularization, dynamicLambdaFactor );

if ( rigidPreAlign )
if ( preAlign == PreAlign.RIGID )
LOG.info( "block " + solveItem.blockData().getId() + ": prealigning with rigid" );
else
else if ( preAlign == PreAlign.TRANSLATION )
LOG.info( "block " + solveItem.blockData().getId() + ": prealigning with translation" );
else
LOG.info( "block " + solveItem.blockData().getId() + ": NO prealignment" );

for (final Tile< ? > tile : tileConfig.getTiles() )
for ( final Tile< ? > tile : tileConfig.getTiles() )
{
// TODO: the prealign should reflect wheter we use translation or rigid as baseline
((InterpolatedAffineModel2D)((InterpolatedAffineModel2D)((InterpolatedAffineModel2D) tile.getModel()).getA()).getA()).setLambda( 1.0 ); // rigid vs affine

if ( rigidPreAlign )
if ( preAlign == PreAlign.RIGID )
((InterpolatedAffineModel2D)((InterpolatedAffineModel2D) tile.getModel()).getA()).setLambda( 0.0 ); // translation vs (rigid vs affine)
else
else if ( preAlign == PreAlign.TRANSLATION )
((InterpolatedAffineModel2D)((InterpolatedAffineModel2D) tile.getModel()).getA()).setLambda( 1.0 ); // translation vs (rigid vs affine)

((InterpolatedAffineModel2D) tile.getModel()).setLambda( 0.0 ); // prealign without regularization
Expand All @@ -876,16 +999,19 @@ protected void solve(
double[] errors = SolveTools.computeErrors( tileConfig.getTiles() );
LOG.info( "errors: " + errors[ 0 ] + "/" + errors[ 1 ] + "/" + errors[ 2 ] );

final Map< Tile< ? >, Integer > tileToZ = new HashMap<>();

for ( final Tile< ? > tile : tileConfig.getTiles() )
tileToZ.put( tile, (int)Math.round( solveItem.idToTileSpec().get( solveItem.tileToIdMap().get( solveItem.groupedTileToTiles().get( tile ).get( 0 ) ) ).getZ() ) );

SolveTools.preAlignByLayerDistance( tileConfig, tileToZ );
//tileConfig.preAlign();

errors = SolveTools.computeErrors( tileConfig.getTiles() );
LOG.info( "errors: " + errors[ 0 ] + "/" + errors[ 1 ] + "/" + errors[ 2 ] );
if ( preAlign != PreAlign.NONE )
{
final Map< Tile< ? >, Integer > tileToZ = new HashMap<>();

for ( final Tile< ? > tile : tileConfig.getTiles() )
tileToZ.put( tile, (int)Math.round( solveItem.blockData().idToTileSpec().get( solveItem.tileToIdMap().get( solveItem.groupedTileToTiles().get( tile ).get( 0 ) ) ).getZ() ) );

SolveTools.preAlignByLayerDistance( tileConfig, tileToZ );
//tileConfig.preAlign();

errors = SolveTools.computeErrors( tileConfig.getTiles() );
LOG.info( "errors: " + errors[ 0 ] + "/" + errors[ 1 ] + "/" + errors[ 2 ] );
}
}
catch (final NotEnoughDataPointsException | IllDefinedDataPointsException e)
{
Expand Down Expand Up @@ -913,6 +1039,7 @@ protected void solve(

final double lambdaRigid = blockOptimizerLambdasRigid.get( s );
final double lambdaTranslation = blockOptimizerLambdasTranslation.get( s );
final double regularization = blockOptimizerLambdasRegularization.get( s );

for (final Tile< ? > tile : tileConfig.getTiles() )
{
Expand Down
Loading

0 comments on commit ff8262f

Please sign in to comment.