diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/AffineDistributedSolver.java b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/AffineDistributedSolver.java index b213764e5..1b35272c8 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/AffineDistributedSolver.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/AffineDistributedSolver.java @@ -3,22 +3,38 @@ import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.function.Function; +import org.janelia.alignment.spec.ResolvedTileSpecCollection.TransformApplicationMethod; import org.janelia.render.client.newsolver.blockfactories.ZBlockFactory; import org.janelia.render.client.newsolver.blocksolveparameters.FIBSEMAlignmentParameters; import org.janelia.render.client.newsolver.setup.AffineSolverSetup; import org.janelia.render.client.newsolver.setup.RenderSetup; import org.janelia.render.client.newsolver.solvers.Worker; +import org.janelia.render.client.newsolver.solvers.WorkerTools; import org.janelia.render.client.newsolver.solvers.affine.AffineAlignBlockWorker; +import org.janelia.render.client.solver.DistributedSolveDeSerialize; +import org.janelia.render.client.solver.DistributedSolveWorker; +import org.janelia.render.client.solver.MinimalTileSpec; +import org.janelia.render.client.solver.SolveItemData; +import org.janelia.render.client.solver.SolveTools; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import mpicbg.models.Affine2D; import mpicbg.models.AffineModel2D; import mpicbg.models.Model; import mpicbg.models.NoninvertibleModelException; +import mpicbg.spim.io.IOFunctions; public class AffineDistributedSolver { @@ -86,35 +102,89 @@ public static void main( final String[] args ) throws IOException final BlockCollection< ?, AffineModel2D, ?, ZBlockFactory > blockCollection = solverSetup.setupSolve( cmdLineSetup.blockModel(), cmdLineSetup.stitchingModel() ); - final ExecutorService taskExecutor = Executors.newFixedThreadPool( cmdLineSetup.threadsGlobal ); + // + // multi-threaded solve + // + LOG.info( "Multithreading with thread num=" + cmdLineSetup.threadsGlobal ); - taskExecutor.submit( () -> - blockCollection.allBlocks().parallelStream().forEach( block -> + final ArrayList< Callable< List< BlockData > > > workers = new ArrayList<>(); + + blockCollection.allBlocks().forEach( block -> + { + workers.add( () -> + { + final Worker worker = block.createWorker( + solverSetup.col.maxId() + 1, + cmdLineSetup.threadsWorker ); + + worker.run(); + + return new ArrayList<>( worker.getBlockDataList() ); + } ); + } ); + + final ArrayList< BlockData > allItems = new ArrayList<>(); + + try + { + final ExecutorService taskExecutor = Executors.newFixedThreadPool( cmdLineSetup.threadsGlobal ); + + taskExecutor.invokeAll( workers ).forEach( future -> { try { - Worker worker = block.solveTypeParameters().createWorker( block, solverSetup.col.maxId() + 1, cmdLineSetup.threadsWorker ); - worker.run(); + allItems.addAll( future.get() ); } - catch (IOException | ExecutionException | InterruptedException | NoninvertibleModelException e) + catch (InterruptedException | ExecutionException e) { + LOG.error( "Failed to compute alignments: " + e ); e.printStackTrace(); - System.exit( 1 ); + return; } - })); + } ); - taskExecutor.shutdown(); - /* - for ( final Worker worker : workers ) + taskExecutor.shutdown(); + } + catch (InterruptedException e1) { - ArrayList< ? extends BlockData< ?, AffineModel2D, ?, ? > > blockData = worker.getBlockDataList(); + LOG.error( "Failed to compute alignments: " + e1 ); + e1.printStackTrace(); + return; } // avoid duplicate id assigned while splitting solveitems in the workers // but do keep ids that are smaller or equal to the maxId of the initial solveset - final int maxId = SolveTools.fixIds( this.allItems, solverSetup.col.maxId() ); + final int maxId = WorkerTools.fixIds( allItems, solverSetup.col.maxId() ); + + LOG.info( "computed " + allItems.size() + " blocks, maxId=" + maxId); - System.out.println( workers.get( 0 ).getBlockDataList().size() ); + /* + // + // Saving the result + // + LOG.info( "Saving targetstack=" + cmdLineSetup.targetStack ); + + // + // save the re-aligned part + // + final HashSet< Double > zToSaveSet = new HashSet<>(); + + for ( final TileSpec ts : solve.idToTileSpecGlobal.values() ) + zToSaveSet.add( ts.getZ() ); + + List< Double > zToSave = new ArrayList<>( zToSaveSet ); + Collections.sort( zToSave ); + + LOG.info("Saving from " + zToSave.get( 0 ) + " to " + zToSave.get( zToSave.size() - 1 ) ); + + SolveTools.saveTargetStackTiles( parameters.stack, parameters.targetStack, runParams, solve.idToFinalModelGlobal, null, zToSave, TransformApplicationMethod.REPLACE_LAST ); + + if ( parameters.completeTargetStack ) + { + LOG.info( "Completing targetstack=" + parameters.targetStack ); + + SolveTools.completeStack( parameters.targetStack, runParams ); + } */ } @@ -212,4 +282,6 @@ protected < M extends Model< M > & Affine2D< M >, S extends Model< S > & Affine2 return workers; } + + private static final Logger LOG = LoggerFactory.getLogger(AffineDistributedSolver.class); } diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/Assembler.java b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/Assembler.java deleted file mode 100644 index 43ca89064..000000000 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/Assembler.java +++ /dev/null @@ -1,8 +0,0 @@ -package org.janelia.render.client.newsolver; - -import java.util.List; - -public abstract class Assembler -{ - public abstract void assemble( List< BlockData > blockDataList ); -} diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/BlockData.java b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/BlockData.java index 29c83d0ee..7f77547da 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/BlockData.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/BlockData.java @@ -13,6 +13,7 @@ import org.janelia.alignment.spec.TileSpec; import org.janelia.render.client.newsolver.blockfactories.BlockFactory; import org.janelia.render.client.newsolver.blocksolveparameters.BlockDataSolveParameters; +import org.janelia.render.client.newsolver.solvers.Worker; import mpicbg.models.CoordinateTransform; import mpicbg.models.Model; @@ -104,6 +105,12 @@ public ArrayList< Function< Double, Double > > createWeightFunctions() public void assignUpdatedId( final int id ) { this.id = id; } + public Worker< M, R, P, F > createWorker( final int startId, final int threadsWorker ) + { + return solveTypeParameters().createWorker( this , startId, threadsWorker ); + } + + /** * Fetches basic data for all TileSpecs * diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/assembly/Affine2DZBlockAssembler.java b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/assembly/Affine2DZBlockAssembler.java new file mode 100644 index 000000000..25c0d7dcf --- /dev/null +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/assembly/Affine2DZBlockAssembler.java @@ -0,0 +1,26 @@ +package org.janelia.render.client.newsolver.assembly; + +import java.util.List; + +import org.janelia.render.client.newsolver.BlockData; +import org.janelia.render.client.newsolver.blockfactories.ZBlockFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import mpicbg.models.AffineModel2D; + +public class Affine2DZBlockAssembler extends Assembler< AffineModel2D, ZBlockFactory > +{ + + public Affine2DZBlockAssembler( final List> blocks, final int startId) + { + super( blocks, startId ); + } + + @Override + public void assemble( ) + { + } + + private static final Logger LOG = LoggerFactory.getLogger(Affine2DZBlockAssembler.class); +} diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/assembly/Assembler.java b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/assembly/Assembler.java new file mode 100644 index 000000000..d1e40df63 --- /dev/null +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/assembly/Assembler.java @@ -0,0 +1,97 @@ +package org.janelia.render.client.newsolver.assembly; + +import java.util.HashSet; +import java.util.List; + +import org.janelia.render.client.newsolver.BlockData; +import org.janelia.render.client.newsolver.blockfactories.BlockFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import mpicbg.models.CoordinateTransform; + +public abstract class Assembler< R extends CoordinateTransform, F extends BlockFactory< F > > +{ + final List > blocks; + + int id; + + public Assembler( final List > blocks, final int startId ) + { + this.id = startId; + this.blocks = blocks; + } + + public AssemblyMaps< R > getAssembly() + { + AssemblyMaps< R > am; + + // the trivial case of a single block, would crash with the code below + if ( ( am = handleTrivialCase() ) != null ) + return am; + else + am = new AssemblyMaps< R >(); + + assemble(); + + return null; + } + + public abstract void assemble(); + + /** + * + * @return - the result of the trivial case if it was a single block + */ + protected AssemblyMaps< R > handleTrivialCase() + { + if ( blocks.size() == 1 ) + { + LOG.info( "Assembler: only a single block, no solve across blocks necessary." ); + + final AssemblyMaps< R > am = new AssemblyMaps< R >(); + + final BlockData< ?, R, ?, F > solveItem = blocks.get( 0 ); + + for ( int z = solveItem.minZ(); z <= solveItem.maxZ(); ++z ) + { + // there is no overlap with any other solveItem (should be beginning or end of the entire stack) + final HashSet< String > tileIds = solveItem.zToTileId().get( z ); + + // if there are none, we continue with the next + if ( tileIds.size() == 0 ) + continue; + + am.zToTileIdGlobal.putIfAbsent( z, new HashSet<>() ); + + for ( final String tileId : tileIds ) + { + am.zToTileIdGlobal.get( z ).add( tileId ); + am.idToTileSpecGlobal.put( tileId, solveItem.rtsc().getTileSpec( tileId ) ); + am.idToFinalModelGlobal.put( tileId, solveItem.idToNewModel().get( tileId ) ); + } + } + + return am; + } + else + { + return null; + } + } + + /* + public static ResolvedTileSpecCollection combineAllTileSpecs( final List< BlockData > allItems ) + { + final ResolvedTileSpecCollection rtsc = new ResolvedTileSpecCollection(); + + // TODO trautmane - improve this + // this should automatically get rid of duplicates due to the common tileId + for ( BlockData block : allItems ) + block.rtsc().getTileSpecs().forEach( tileSpec -> rtsc.addTileSpecToCollection( tileSpec ) ); + + return rtsc; + }*/ + + private static final Logger LOG = LoggerFactory.getLogger(Assembler.class); +} diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/assembly/AssemblyMaps.java b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/assembly/AssemblyMaps.java new file mode 100644 index 000000000..fa31fe8c5 --- /dev/null +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/assembly/AssemblyMaps.java @@ -0,0 +1,18 @@ +package org.janelia.render.client.newsolver.assembly; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; + +import org.janelia.alignment.spec.TileSpec; + +import mpicbg.models.CoordinateTransform; +import net.imglib2.util.Pair; + +public class AssemblyMaps< R extends CoordinateTransform > +{ + final public HashMap< String, R > idToFinalModelGlobal = new HashMap<>(); + final public HashMap< String, TileSpec > idToTileSpecGlobal = new HashMap<>(); + final public HashMap > zToTileIdGlobal = new HashMap<>(); + final public HashMap< String, List< Pair< String, Double > > > idToErrorMapGlobal = new HashMap<>(); +} diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/blockfactories/BlockFactory.java b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/blockfactories/BlockFactory.java index 5d3d95017..97ba21521 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/blockfactories/BlockFactory.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/blockfactories/BlockFactory.java @@ -11,12 +11,10 @@ import mpicbg.models.CoordinateTransform; import mpicbg.models.Model; -public abstract class BlockFactory< F extends BlockFactory< F > > implements Serializable +public interface BlockFactory< F extends BlockFactory< F > > extends Serializable { - private static final long serialVersionUID = 5919345114414922447L; - - public abstract , R extends CoordinateTransform, P extends BlockDataSolveParameters< M, R, P > > BlockCollection< M, R, P, F > defineBlockCollection( + public , R extends CoordinateTransform, P extends BlockDataSolveParameters< M, R, P > > BlockCollection< M, R, P, F > defineBlockCollection( final ParameterProvider< M, R, P > blockSolveParameterProvider ); - public abstract ArrayList< Function< Double, Double > > createWeightFunctions( final BlockData< ?, ?, ?, F > block ); + public ArrayList< Function< Double, Double > > createWeightFunctions( final BlockData< ?, ?, ?, F > block ); } diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/blockfactories/ZBlockFactory.java b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/blockfactories/ZBlockFactory.java index f4892c3dc..40baacf13 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/blockfactories/ZBlockFactory.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/blockfactories/ZBlockFactory.java @@ -16,7 +16,7 @@ import mpicbg.models.CoordinateTransform; import mpicbg.models.Model; -public class ZBlockFactory extends BlockFactory< ZBlockFactory > implements Serializable +public class ZBlockFactory implements BlockFactory< ZBlockFactory >, Serializable { private static final long serialVersionUID = 4169473785487008894L; diff --git a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/blocksolveparameters/BlockDataSolveParameters.java b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/blocksolveparameters/BlockDataSolveParameters.java index 0d7289737..cd80d5d61 100644 --- a/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/blocksolveparameters/BlockDataSolveParameters.java +++ b/render-ws-java-client/src/main/java/org/janelia/render/client/newsolver/blocksolveparameters/BlockDataSolveParameters.java @@ -15,7 +15,7 @@ * * @param - the result model type */ -public abstract class BlockDataSolveParameters< M extends Model< M >, R extends CoordinateTransform, B extends BlockDataSolveParameters< M, R, B > > implements Serializable +public abstract class BlockDataSolveParameters< M extends Model< M >, R extends CoordinateTransform, P extends BlockDataSolveParameters< M, R, P > > implements Serializable { private static final long serialVersionUID = -813404780882760053L; @@ -47,8 +47,8 @@ public BlockDataSolveParameters( public M blockSolveModel() { return blockSolveModel; } - public abstract < F extends BlockFactory< F > > Worker< M, R, B, F > createWorker( - final BlockData< M, R, B, F > blockData, + public abstract < F extends BlockFactory< F > > Worker< M, R, P, F > createWorker( + final BlockData< M, R, P, F > blockData, final int startId, final int threadsWorker ); }