Skip to content

Commit

Permalink
Merge branch 'newsolver' of github.com:saalfeldlab/render into newsolver
Browse files Browse the repository at this point in the history
  • Loading branch information
minnerbe committed Aug 16, 2023
2 parents 12cb4f6 + e2cb2da commit c9928ab
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,38 @@
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Function;

import org.janelia.alignment.spec.ResolvedTileSpecCollection.TransformApplicationMethod;
import org.janelia.render.client.newsolver.blockfactories.ZBlockFactory;
import org.janelia.render.client.newsolver.blocksolveparameters.FIBSEMAlignmentParameters;
import org.janelia.render.client.newsolver.setup.AffineSolverSetup;
import org.janelia.render.client.newsolver.setup.RenderSetup;
import org.janelia.render.client.newsolver.solvers.Worker;
import org.janelia.render.client.newsolver.solvers.WorkerTools;
import org.janelia.render.client.newsolver.solvers.affine.AffineAlignBlockWorker;
import org.janelia.render.client.solver.DistributedSolveDeSerialize;
import org.janelia.render.client.solver.DistributedSolveWorker;
import org.janelia.render.client.solver.MinimalTileSpec;
import org.janelia.render.client.solver.SolveItemData;
import org.janelia.render.client.solver.SolveTools;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import mpicbg.models.Affine2D;
import mpicbg.models.AffineModel2D;
import mpicbg.models.Model;
import mpicbg.models.NoninvertibleModelException;
import mpicbg.spim.io.IOFunctions;

public class AffineDistributedSolver
{
Expand Down Expand Up @@ -86,35 +102,89 @@ public static void main( final String[] args ) throws IOException
final BlockCollection< ?, AffineModel2D, ?, ZBlockFactory > blockCollection =
solverSetup.setupSolve( cmdLineSetup.blockModel(), cmdLineSetup.stitchingModel() );

final ExecutorService taskExecutor = Executors.newFixedThreadPool( cmdLineSetup.threadsGlobal );
//
// multi-threaded solve
//
LOG.info( "Multithreading with thread num=" + cmdLineSetup.threadsGlobal );

taskExecutor.submit( () ->
blockCollection.allBlocks().parallelStream().forEach( block ->
final ArrayList< Callable< List< BlockData<?, AffineModel2D, ?, ZBlockFactory> > > > workers = new ArrayList<>();

blockCollection.allBlocks().forEach( block ->
{
workers.add( () ->
{
final Worker<?, AffineModel2D, ?, ZBlockFactory> worker = block.createWorker(
solverSetup.col.maxId() + 1,
cmdLineSetup.threadsWorker );

worker.run();

return new ArrayList<>( worker.getBlockDataList() );
} );
} );

final ArrayList< BlockData<?, AffineModel2D, ?, ZBlockFactory> > allItems = new ArrayList<>();

try
{
final ExecutorService taskExecutor = Executors.newFixedThreadPool( cmdLineSetup.threadsGlobal );

taskExecutor.invokeAll( workers ).forEach( future ->
{
try
{
Worker<?, AffineModel2D, ?, ZBlockFactory> worker = block.solveTypeParameters().createWorker( block, solverSetup.col.maxId() + 1, cmdLineSetup.threadsWorker );
worker.run();
allItems.addAll( future.get() );
}
catch (IOException | ExecutionException | InterruptedException | NoninvertibleModelException e)
catch (InterruptedException | ExecutionException e)
{
LOG.error( "Failed to compute alignments: " + e );
e.printStackTrace();
System.exit( 1 );
return;
}
}));
} );

taskExecutor.shutdown();
/*
for ( final Worker<?, ?, ZBlockFactory > worker : workers )
taskExecutor.shutdown();
}
catch (InterruptedException e1)
{
ArrayList< ? extends BlockData< ?, AffineModel2D, ?, ? > > blockData = worker.getBlockDataList();
LOG.error( "Failed to compute alignments: " + e1 );
e1.printStackTrace();
return;
}

// avoid duplicate id assigned while splitting solveitems in the workers
// but do keep ids that are smaller or equal to the maxId of the initial solveset
final int maxId = SolveTools.fixIds( this.allItems, solverSetup.col.maxId() );
final int maxId = WorkerTools.fixIds( allItems, solverSetup.col.maxId() );

LOG.info( "computed " + allItems.size() + " blocks, maxId=" + maxId);

System.out.println( workers.get( 0 ).getBlockDataList().size() );
/*
//
// Saving the result
//
LOG.info( "Saving targetstack=" + cmdLineSetup.targetStack );
//
// save the re-aligned part
//
final HashSet< Double > zToSaveSet = new HashSet<>();
for ( final TileSpec ts : solve.idToTileSpecGlobal.values() )
zToSaveSet.add( ts.getZ() );
List< Double > zToSave = new ArrayList<>( zToSaveSet );
Collections.sort( zToSave );
LOG.info("Saving from " + zToSave.get( 0 ) + " to " + zToSave.get( zToSave.size() - 1 ) );
SolveTools.saveTargetStackTiles( parameters.stack, parameters.targetStack, runParams, solve.idToFinalModelGlobal, null, zToSave, TransformApplicationMethod.REPLACE_LAST );
if ( parameters.completeTargetStack )
{
LOG.info( "Completing targetstack=" + parameters.targetStack );
SolveTools.completeStack( parameters.targetStack, runParams );
}
*/
}

Expand Down Expand Up @@ -212,4 +282,6 @@ protected < M extends Model< M > & Affine2D< M >, S extends Model< S > & Affine2

return workers;
}

private static final Logger LOG = LoggerFactory.getLogger(AffineDistributedSolver.class);
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.janelia.alignment.spec.TileSpec;
import org.janelia.render.client.newsolver.blockfactories.BlockFactory;
import org.janelia.render.client.newsolver.blocksolveparameters.BlockDataSolveParameters;
import org.janelia.render.client.newsolver.solvers.Worker;

import mpicbg.models.CoordinateTransform;
import mpicbg.models.Model;
Expand Down Expand Up @@ -104,6 +105,12 @@ public ArrayList< Function< Double, Double > > createWeightFunctions()

public void assignUpdatedId( final int id ) { this.id = id; }

public Worker< M, R, P, F > createWorker( final int startId, final int threadsWorker )
{
return solveTypeParameters().createWorker( this , startId, threadsWorker );
}


/**
* Fetches basic data for all TileSpecs
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package org.janelia.render.client.newsolver.assembly;

import java.util.List;

import org.janelia.render.client.newsolver.BlockData;
import org.janelia.render.client.newsolver.blockfactories.ZBlockFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import mpicbg.models.AffineModel2D;

public class Affine2DZBlockAssembler extends Assembler< AffineModel2D, ZBlockFactory >
{

public Affine2DZBlockAssembler( final List<BlockData<?, AffineModel2D, ?, ZBlockFactory>> blocks, final int startId)
{
super( blocks, startId );
}

@Override
public void assemble( )
{
}

private static final Logger LOG = LoggerFactory.getLogger(Affine2DZBlockAssembler.class);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package org.janelia.render.client.newsolver.assembly;

import java.util.HashSet;
import java.util.List;

import org.janelia.render.client.newsolver.BlockData;
import org.janelia.render.client.newsolver.blockfactories.BlockFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import mpicbg.models.CoordinateTransform;

public abstract class Assembler< R extends CoordinateTransform, F extends BlockFactory< F > >
{
final List<BlockData<?, R, ?, F> > blocks;

int id;

public Assembler( final List<BlockData<?, R, ?, F> > blocks, final int startId )
{
this.id = startId;
this.blocks = blocks;
}

public AssemblyMaps< R > getAssembly()
{
AssemblyMaps< R > am;

// the trivial case of a single block, would crash with the code below
if ( ( am = handleTrivialCase() ) != null )
return am;
else
am = new AssemblyMaps< R >();

assemble();

return null;
}

public abstract void assemble();

/**
*
* @return - the result of the trivial case if it was a single block
*/
protected AssemblyMaps< R > handleTrivialCase()
{
if ( blocks.size() == 1 )
{
LOG.info( "Assembler: only a single block, no solve across blocks necessary." );

final AssemblyMaps< R > am = new AssemblyMaps< R >();

final BlockData< ?, R, ?, F > solveItem = blocks.get( 0 );

for ( int z = solveItem.minZ(); z <= solveItem.maxZ(); ++z )
{
// there is no overlap with any other solveItem (should be beginning or end of the entire stack)
final HashSet< String > tileIds = solveItem.zToTileId().get( z );

// if there are none, we continue with the next
if ( tileIds.size() == 0 )
continue;

am.zToTileIdGlobal.putIfAbsent( z, new HashSet<>() );

for ( final String tileId : tileIds )
{
am.zToTileIdGlobal.get( z ).add( tileId );
am.idToTileSpecGlobal.put( tileId, solveItem.rtsc().getTileSpec( tileId ) );
am.idToFinalModelGlobal.put( tileId, solveItem.idToNewModel().get( tileId ) );
}
}

return am;
}
else
{
return null;
}
}

/*
public static ResolvedTileSpecCollection combineAllTileSpecs( final List< BlockData<?, ?, ?, ?> > allItems )
{
final ResolvedTileSpecCollection rtsc = new ResolvedTileSpecCollection();
// TODO trautmane - improve this
// this should automatically get rid of duplicates due to the common tileId
for ( BlockData<?, ?, ?, ?> block : allItems )
block.rtsc().getTileSpecs().forEach( tileSpec -> rtsc.addTileSpecToCollection( tileSpec ) );
return rtsc;
}*/

private static final Logger LOG = LoggerFactory.getLogger(Assembler.class);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.janelia.render.client.newsolver.assembly;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;

import org.janelia.alignment.spec.TileSpec;

import mpicbg.models.CoordinateTransform;
import net.imglib2.util.Pair;

public class AssemblyMaps< R extends CoordinateTransform >
{
final public HashMap< String, R > idToFinalModelGlobal = new HashMap<>();
final public HashMap< String, TileSpec > idToTileSpecGlobal = new HashMap<>();
final public HashMap<Integer, HashSet<String> > zToTileIdGlobal = new HashMap<>();
final public HashMap< String, List< Pair< String, Double > > > idToErrorMapGlobal = new HashMap<>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@
import mpicbg.models.CoordinateTransform;
import mpicbg.models.Model;

public abstract class BlockFactory< F extends BlockFactory< F > > implements Serializable
public interface BlockFactory< F extends BlockFactory< F > > extends Serializable
{
private static final long serialVersionUID = 5919345114414922447L;

public abstract <M extends Model< M >, R extends CoordinateTransform, P extends BlockDataSolveParameters< M, R, P > > BlockCollection< M, R, P, F > defineBlockCollection(
public <M extends Model< M >, R extends CoordinateTransform, P extends BlockDataSolveParameters< M, R, P > > BlockCollection< M, R, P, F > defineBlockCollection(
final ParameterProvider< M, R, P > blockSolveParameterProvider );

public abstract ArrayList< Function< Double, Double > > createWeightFunctions( final BlockData< ?, ?, ?, F > block );
public ArrayList< Function< Double, Double > > createWeightFunctions( final BlockData< ?, ?, ?, F > block );
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import mpicbg.models.CoordinateTransform;
import mpicbg.models.Model;

public class ZBlockFactory extends BlockFactory< ZBlockFactory > implements Serializable
public class ZBlockFactory implements BlockFactory< ZBlockFactory >, Serializable
{
private static final long serialVersionUID = 4169473785487008894L;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*
* @param <M> - the result model type
*/
public abstract class BlockDataSolveParameters< M extends Model< M >, R extends CoordinateTransform, B extends BlockDataSolveParameters< M, R, B > > implements Serializable
public abstract class BlockDataSolveParameters< M extends Model< M >, R extends CoordinateTransform, P extends BlockDataSolveParameters< M, R, P > > implements Serializable
{
private static final long serialVersionUID = -813404780882760053L;

Expand Down Expand Up @@ -47,8 +47,8 @@ public BlockDataSolveParameters(

public M blockSolveModel() { return blockSolveModel; }

public abstract < F extends BlockFactory< F > > Worker< M, R, B, F > createWorker(
final BlockData< M, R, B, F > blockData,
public abstract < F extends BlockFactory< F > > Worker< M, R, P, F > createWorker(
final BlockData< M, R, P, F > blockData,
final int startId,
final int threadsWorker );
}

0 comments on commit c9928ab

Please sign in to comment.