Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: MFOV layer-prealign #184

Merged
merged 12 commits into from
Apr 27, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package org.janelia.render.client;

import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParametersDelegate;
import mpicbg.models.RigidModel2D;
import org.janelia.alignment.match.CanvasMatches;
import org.janelia.alignment.spec.ResolvedTileSpecCollection;
import org.janelia.alignment.spec.ResolvedTileSpecCollection.TransformApplicationMethod;
import org.janelia.alignment.spec.ResolvedTileSpecsWithMatchPairs;
import org.janelia.alignment.spec.TransformSpec;
import org.janelia.alignment.spec.stack.StackMetaData;
import org.janelia.render.client.newsolver.solvers.affine.MultiSemPreAligner;
import org.janelia.render.client.parameter.CommandLineParameters;
import org.janelia.render.client.parameter.MatchCollectionParameters;
import org.janelia.render.client.parameter.RenderWebServiceParameters;
import org.janelia.render.client.solver.SolveTools;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;

/**
* Coarsely aligns a multi-SEM stack by treating each mFOV layer as a single tile and using a rigid model.
*
* @author Michael Innerberger
*/
public class MultiSemPreAlignClient implements Serializable {

public static class Parameters extends CommandLineParameters {

@ParametersDelegate
public RenderWebServiceParameters renderWeb = new RenderWebServiceParameters();

@ParametersDelegate
public MatchCollectionParameters matches = new MatchCollectionParameters();

@Parameter(
names = "--stack",
description = "Name of source stack",
required = true)
public String stack;

@Parameter(
names = "--targetStack",
description = "Name of target stack",
required = true)
public String targetStack;

@Parameter(
names = "--completeTargetStack",
description = "Complete the target stack after pre-alignment",
arity = 0)
public boolean completeTargetStack = false;

@Parameter(
names = "--maxAllowedError",
description = "Max allowed error (default:10.0)"
)
public Double maxAllowedError = 10.0;

@Parameter(
names = "--maxIterations",
description = "Max iterations (default:1000)"
minnerbe marked this conversation as resolved.
Show resolved Hide resolved
)
public Integer maxIterations = 1000;

@Parameter(
names = "--maxPlateauWidth",
description = "Max plateau width (default:250)"
)
public Integer maxPlateauWidth = 250;

@Parameter(
names = "--maxNumMatches",
description = "Max number of matches between mFOV layers (default:1000)"
)
public Integer maxNumMatches = 1000;

@Parameter(
names = "--numThreads",
description = "Number of threads (default:8)"
)
public Integer numThreads = 8;
}

public static void main(final String[] args) {
final ClientRunner clientRunner = new ClientRunner(args) {
@Override
public void runClient(final String[] args) throws Exception {

final Parameters parameters = new Parameters();
parameters.parse(args);

LOG.info("runClient: entry, parameters={}", parameters);

final MultiSemPreAlignClient client = new MultiSemPreAlignClient(parameters);

client.process();
}
};
clientRunner.run();
}


private final Parameters parameters;
private final RenderDataClient dataClient;

public MultiSemPreAlignClient(final Parameters parameters) {
this.parameters = parameters;
this.dataClient = parameters.renderWeb.getDataClient();
}

private void setUpTargetStack() throws IOException {
final StackMetaData stackMetaData = dataClient.getStackMetaData(parameters.stack);
dataClient.setupDerivedStack(stackMetaData, parameters.targetStack);
}

private void completeTargetStack() throws IOException {
dataClient.setStackState(parameters.targetStack, StackMetaData.StackState.COMPLETE);
}

private void process() throws IOException, ExecutionException, InterruptedException {
setUpTargetStack();

final ResolvedTileSpecsWithMatchPairs tileSpecsWithMatchPairs =
dataClient.getResolvedTilesWithMatchPairs(parameters.stack,
null,
parameters.matches.matchCollection,
null,
null,
true);
minnerbe marked this conversation as resolved.
Show resolved Hide resolved

final ResolvedTileSpecCollection rtsc = tileSpecsWithMatchPairs.getResolvedTileSpecs();
final List<CanvasMatches> matches = tileSpecsWithMatchPairs.getMatchPairs();

final MultiSemPreAligner<RigidModel2D> preAligner = new MultiSemPreAligner<>(
new RigidModel2D(),
parameters.maxAllowedError,
parameters.maxIterations,
parameters.maxPlateauWidth,
parameters.numThreads,
parameters.maxNumMatches
);

final Map<String, RigidModel2D> tileIdToModel = preAligner.preAlign(rtsc, matches);

for (final Map.Entry<String, RigidModel2D> entry : tileIdToModel.entrySet()) {
final String tileId = entry.getKey();
final TransformSpec transformSpec = SolveTools.getTransformSpec(entry.getValue());
rtsc.addTransformSpecToTile(tileId, transformSpec, TransformApplicationMethod.PRE_CONCATENATE_LAST);
}

dataClient.saveResolvedTiles(rtsc, parameters.targetStack, null );
minnerbe marked this conversation as resolved.
Show resolved Hide resolved
if (parameters.completeTargetStack) {
completeTargetStack();
}
}

private static final Logger LOG = LoggerFactory.getLogger(MultiSemPreAlignClient.class);
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.janelia.alignment.util.ScriptUtil;
import org.janelia.render.client.parameter.CommandLineParameters;
import org.janelia.render.client.parameter.RenderWebServiceParameters;
import org.janelia.render.client.solver.SolveTools;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -548,7 +549,7 @@ private void saveTargetStackTiles(final Map<String, Tile<TranslationModel2D>> id

if (tile != null) {
resolvedTiles.addTransformSpecToTile(tileId,
getTransformSpec(tile.getModel()),
SolveTools.getTransformSpec(tile.getModel()),
REPLACE_LAST);
}

Expand All @@ -567,13 +568,6 @@ private void saveTargetStackTiles(final Map<String, Tile<TranslationModel2D>> id
LOG.info("saveTargetStackTiles: exit");
}

private LeafTransformSpec getTransformSpec(final TranslationModel2D forModel) {
final double[] m = new double[6];
forModel.toArray(m);
final String data = String.valueOf(m[0]) + ' ' + m[1] + ' ' + m[2] + ' ' + m[3] + ' ' + m[4] + ' ' + m[5];
return new LeafTransformSpec(mpicbg.trakem2.transform.AffineModel2D.class.getName(), data);
}

private TileSpec getTileSpec(final String sectionId,
final String tileId)
throws IOException {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package org.janelia.render.client.newsolver.solvers.affine;

import mpicbg.models.Affine2D;
import mpicbg.models.ErrorStatistic;
import mpicbg.models.Model;
import mpicbg.models.PointMatch;
import mpicbg.models.Tile;
import mpicbg.models.TileConfiguration;
import mpicbg.models.TileUtil;
import net.imglib2.util.Pair;
import net.imglib2.util.ValuePair;
import org.janelia.alignment.match.CanvasMatchResult;
import org.janelia.alignment.match.CanvasMatches;
import org.janelia.alignment.spec.ResolvedTileSpecCollection;
import org.janelia.alignment.spec.TileSpec;
import org.janelia.render.client.solver.SolveTools;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.DoubleSummaryStatistics;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
* Coarsely aligns a multi-SEM stack by treating each mFOV layer as a single tile.
*
* @param <M> the model type
*
* @author Michael Innerberger
*/
public class MultiSemPreAligner<M extends Model<M> & Affine2D<M>> implements Serializable {

// Note: this assumes a specific tileId format:
// <cut>_<mFov>_<sFov>_<day>_<time>.<z>.0
private static final Pattern TILE_ID_SEPARATOR = Pattern.compile("[_.]");

private final M model;
private final double maxAllowedError;
private final int maxIterations;
private final int maxPlateauWidth;
private final int numThreads;
private final int maxNumMatches;


public MultiSemPreAligner(
final M model,
final double maxAllowedError,
final int maxIterations,
final int maxPlateauWidth,
final int numThreads,
final int maxNumMatches
) {
this.model = model;
this.maxAllowedError = maxAllowedError;
this.maxIterations = maxIterations;
this.maxPlateauWidth = maxPlateauWidth;
this.numThreads = numThreads;
this.maxNumMatches = maxNumMatches;
}

public Map<String, M> preAlign(
final ResolvedTileSpecCollection rtsc,
final List<CanvasMatches> canvasMatches
) throws IOException, ExecutionException, InterruptedException {

// initialize and connect models for each mFOV layer
final Map<String, String> tileIdToMFovLayer = rtsc.getTileSpecs().stream()
.map(TileSpec::getTileId)
.collect(Collectors.toMap(
tileId -> tileId,
MultiSemPreAligner::extractMFovLayerId
));
final Map<String, Tile<M>> mFovLayerToTile = initializeAndConnectTiles(rtsc, tileIdToMFovLayer, canvasMatches);
minnerbe marked this conversation as resolved.
Show resolved Hide resolved

// optimize the models
final TileConfiguration tileConfig = new TileConfiguration();
tileConfig.addTiles(mFovLayerToTile.values());

final DoubleSummaryStatistics errorsBefore = SolveTools.computeErrors(tileConfig.getTiles());
LOG.info("mFOV layer-wise pre-align: errors after connecting, before optimization: {}", errorsBefore);

TileUtil.optimizeConcurrently(
new ErrorStatistic(maxPlateauWidth + 1),
maxAllowedError,
maxIterations,
maxPlateauWidth,
1.0f,
tileConfig,
tileConfig.getTiles(),
tileConfig.getFixedTiles(),
numThreads);

final DoubleSummaryStatistics errorsAfter = SolveTools.computeErrors(tileConfig.getTiles());
LOG.info("mFOV layer-wise pre-align: errors after optimization: {}", errorsAfter);

// distribute models to individual tiles
final Map<String, M> tileIdToModel = new HashMap<>();
for (final Map.Entry<String, String> tileIdAndMFovLayer : tileIdToMFovLayer.entrySet()) {
final String tileId = tileIdAndMFovLayer.getKey();
final String mFovLayer = tileIdAndMFovLayer.getValue();
final Tile<M> tile = mFovLayerToTile.get(mFovLayer);
tileIdToModel.put(tileId, tile.getModel());
}

return tileIdToModel;
}

private Map<String, Tile<M>> initializeAndConnectTiles(
final ResolvedTileSpecCollection rtsc,
final Map<String, String> tileIdToMFovLayer,
final List<CanvasMatches> canvasMatches
) {
// initialize models for each mFOV layer
final Map<String, Tile<M>> mFovLayerToTile = new HashMap<>();
for (final String mFovLayer : tileIdToMFovLayer.values()) {
mFovLayerToTile.put(mFovLayer, new Tile<>(model.copy()));
}

// accumulate matches for each pair of mFOV layers
final Map<Pair<String, String>, List<PointMatch>> pairsToMatches = new HashMap<>();
for (final CanvasMatches canvasMatch : canvasMatches) {
final String mFovLayerP = tileIdToMFovLayer.get(canvasMatch.getpId());
final String mFovLayerQ = tileIdToMFovLayer.get(canvasMatch.getqId());
final boolean tileIsMissing = mFovLayerToTile.get(mFovLayerP) == null || mFovLayerToTile.get(mFovLayerQ) == null;
if (tileIsMissing || mFovLayerP.equals(mFovLayerQ)) {
// only connect tiles from different layers and different mFOVs
continue;
}

final Pair<String, String> pair = new ValuePair<>(mFovLayerP, mFovLayerQ);
final List<PointMatch> layerMatches = pairsToMatches.computeIfAbsent(pair, k -> new ArrayList<>());

final TileSpec p = rtsc.getTileSpec(canvasMatch.getpId());
final TileSpec q = rtsc.getTileSpec(canvasMatch.getqId());
p.getLastTransform().getNewInstance();

final List<PointMatch> tileMatches = CanvasMatchResult.convertMatchesToPointMatchList(canvasMatch.getMatches());
final List<PointMatch> relativeMatches = SolveTools.createRelativePointMatches(tileMatches, p.getLastTransform().getNewInstance(), q.getLastTransform().getNewInstance());
layerMatches.addAll(relativeMatches);
}

// reduce the number of matches to a maximal number (choose randomly)
for (final Map.Entry<Pair<String, String>, List<PointMatch>> entry : pairsToMatches.entrySet()) {
final String mFoVLayerP = entry.getKey().getA();
final String mFoVLayerQ = entry.getKey().getB();
final Tile<M> p = mFovLayerToTile.get(mFoVLayerP);
final Tile<M> q = mFovLayerToTile.get(mFoVLayerQ);

final List<PointMatch> reducedMatches = getRandomElements(entry.getValue(), maxNumMatches);
p.connect(q, reducedMatches);

LOG.info("initializeAndConnectTiles: connected {} and {} with {} of {} matches",
mFoVLayerP, mFoVLayerQ, reducedMatches.size(), entry.getValue().size());
}

return mFovLayerToTile;
}

private <T> List<T> getRandomElements(final List<T> list, final int maxEntries) {
if (list.size() <= maxEntries) {
return list;
} else {
Collections.shuffle(list);
return list.subList(0, maxEntries);
}
}

private static String extractMFovLayerId(final String tileId) {
final String[] components = TILE_ID_SEPARATOR.split(tileId);
final String mFov = components[1];
final String layer = components[5];
return mFov + "_" + layer;
}

private static final Logger LOG = LoggerFactory.getLogger(MultiSemPreAligner.class);
}
Loading
Loading