Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tentative fix for handling time-lapses. #1

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,16 @@ public Sequence predict(Sequence inputSequence) {
@Override
public Sequence predict(Sequence inputSequence, int time) {
Sequence normalized = normalize(inputSequence, time);
inputSequence.setImage(0, 0, normalized.getImage(0, 0));
return super.predict(normalized, time);
// Because the normalized sequence has only 1 time-point, we have to process its first time-point, the number 0.
return super.predict(normalized, 0);
}

/**
* Normalizes the time-point of the specified sequence, and returns the results in a <b>1 time-point</b> sequence.
* @param input the sequence to normalize.
* @param time the time-point to normalize.
* @return a new sequence, with only one time-point.
*/
private Sequence normalize(Sequence input, int time) {
double[] values = new double[input.getWidth() * input.getHeight()];
for (int x = 0; x < input.getWidth(); x++) {
Expand All @@ -115,8 +121,7 @@ private Sequence normalize(Sequence input, int time) {
valuesOut[i] = (float) ((values[i] - minVal) * factor);
}
Sequence res = new Sequence(resImg);
res.setDataXY(time, 0, 0, valuesOut);
res.setDataXY(0, 0, 0, valuesOut);
return res;
}

}
39 changes: 29 additions & 10 deletions src/main/java/plugins/stardist/StarDist2DPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import plugins.adufour.ezplug.EzGroup;
import plugins.adufour.ezplug.EzLabel;
import plugins.adufour.ezplug.EzPlug;
import plugins.adufour.ezplug.EzStoppable;
import plugins.adufour.ezplug.EzVarBoolean;
import plugins.adufour.ezplug.EzVarDouble;
import plugins.adufour.ezplug.EzVarEnum;
Expand All @@ -31,7 +32,7 @@
import java.lang.reflect.InvocationTargetException;
import java.util.Map;

public class StarDist2DPlugin extends EzPlug
public class StarDist2DPlugin extends EzPlug implements EzStoppable
{

private final String msgTitle = "<html>" +
Expand All @@ -56,6 +57,8 @@ public class StarDist2DPlugin extends EzPlug
private EzVarBoolean normalizeInput;
private EzVarEnum<AvailableModels.Model2D> modelChoice;

private boolean wasCanceled;

@Override
protected void initialize()
{
Expand Down Expand Up @@ -108,19 +111,26 @@ public void clean()
{
// Nothing to do
}

@Override
public void stopExecution()
{
this.wasCanceled = true;
}

@Override
protected void execute()
{

// Load in a separate thread.
ThreadUtil.bgRun(() -> {
predict(input.getValue());
predict();
});
}

private void predict(Sequence sequence) {
Sequence inputSequence = input.getValue();
private void predict() {
wasCanceled = false;
Sequence sequence = input.getValue();
ModelPrediction prediction = new TensorFlowModelPrediction();
// if (roiPosition.equals(Opt.ROI_POSITION_AUTO))
// roiPositionActive = input.numDimensions() > 3 && !input.isRGBMerged() ? Opt.ROI_POSITION_HYPERSTACK : Opt.ROI_POSITION_STACK;
Expand Down Expand Up @@ -163,16 +173,23 @@ private void predict(Sequence sequence) {
// paramsNMS.put("verbose", verbose);

// final LinkedHashSet<AxisType> inputAxes = Utils.orderedAxesSet(input);
final boolean isTimelapse = inputSequence.getSizeT() > 1;
final boolean isTimelapse = sequence.getSizeT() > 1;

// TODO: option to normalize image/timelapse channel by channel or all channels jointly

if (isTimelapse) {
// TODO: option to normalize timelapse frame by frame (currently) or jointly
final long numFrames = inputSequence.getSizeT();
final long numFrames = sequence.getSizeT();
for (int t = 0; t < numFrames; t++) {
Sequence predictionResult = prediction.predict(inputSequence, t);
Candidates polygons = nms.run(predictionResult, t);
if (wasCanceled)
break;
getStatus().setMessage("Processing time-point " + t + " of " + numFrames);
Sequence predictionResult = prediction.predict(sequence, t);
/*
* Because the prediction results will only have 1 time-point, we need to run
* the NMS prediction on its first time-point, the number 0.
*/
Candidates polygons = nms.run(predictionResult, 0);
display(sequence, polygons, t);
getStatus().setCompletion((float)(1+t) / (float)numFrames);
}
Expand All @@ -181,7 +198,7 @@ private void predict(Sequence sequence) {
// - joint normalization of all frames
// - requires more memory to store intermediate results (prob and dist) of all frames
// - allows showing prob and dist easily
Sequence predictionResult = prediction.predict(inputSequence);
Sequence predictionResult = prediction.predict(sequence);
Candidates polygons = nms.run(predictionResult, 0);
display(sequence, polygons, 0);
}
Expand All @@ -195,6 +212,7 @@ private void predict(Sequence sequence) {
e.printStackTrace();
}
}
getStatus().done();
}

private void display(Sequence sequence, Candidates polygons, int time) {
Expand Down Expand Up @@ -232,7 +250,8 @@ public static void main( final String[] args ) {
* Programmatically launch a plugin, as if the user had clicked its
* button.
*/
String imagePath = "samples/blobs.png";
// String imagePath = "samples/blobs.png";
String imagePath = "C:/Users/tinevez/Development/TrackMateWS/TrackMate-StarDist/samples/P31-crop.tif";
final Sequence sequence = Loader.loadSequence( imagePath, 0, true );
display(sequence);
PluginLauncher.start( PluginLoader.getPlugin( StarDist2DPlugin.class.getName() ) );
Expand Down