Skip to content

Commit b2b3d92

Browse files
committed
Parallelize computation
1 parent 62c551c commit b2b3d92

4 files changed

Lines changed: 128 additions & 101 deletions

File tree

scijava-ops-flim/pom.xml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
<artifactId>scijava-ops-flim</artifactId>
1313

14-
<name>SciJava Ops Flim</name>
14+
<name>SciJava Ops FLIM</name>
1515
<description>A Fluoresence lifetime analysis library for SciJava Ops.</description>
1616
<url>https://github.com/scijava/scijava</url>
1717
<inceptionYear>2024</inceptionYear>
@@ -120,6 +120,12 @@
120120
</dependency>
121121

122122
<!-- SciJava dependencies -->
123+
<dependency>
124+
<groupId>org.scijava</groupId>
125+
<artifactId>scijava-concurrent</artifactId>
126+
<version>${project.version}</version>
127+
<scope>compile</scope>
128+
</dependency>
123129
<dependency>
124130
<groupId>org.scijava</groupId>
125131
<artifactId>scijava-function</artifactId>

scijava-ops-flim/src/main/java/org/scijava/ops/flim/AbstractFitRAI.java

Lines changed: 52 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -46,31 +46,18 @@ public abstract class AbstractFitRAI<I extends RealType<I>, K extends RealType<K
4646
@OpDependency(name = "filter.convolve")
4747
private Functions.Arity3<RandomAccessibleInterval<I>, RandomAccessibleInterval<K>, I, RandomAccessibleInterval<I>> convolveOp;
4848

49-
private RandomAccessibleInterval<K> kernel;
50-
51-
private RealMask roi;
52-
53-
private FitWorker<I> fitWorker;
54-
55-
private int lifetimeAxis;
56-
57-
private ParamEstimator<I> est;
58-
59-
private List<int[]> roiPos;
60-
61-
private FitParams<I> params;
62-
63-
private FitResults rslts;
64-
65-
public void assertConformity(final FitParams<I> in) {
49+
private void assertConformity( //
50+
final FitParams<I> in, //
51+
final RealMask roi, //
52+
final RandomAccessibleInterval<K> kernel //
53+
) {
6654
// requires a 3D image
6755
if (in.transMap.numDimensions() != 3) {
6856
throw new IllegalArgumentException(
6957
"Fitting requires 3-dimensional input");
7058
}
7159
// lifetime axis must be valid
72-
lifetimeAxis = in.ltAxis;
73-
if (lifetimeAxis < 0 || lifetimeAxis >= in.transMap.numDimensions()) {
60+
if (in.ltAxis < 0 || in.ltAxis >= in.transMap.numDimensions()) {
7461
throw new IllegalArgumentException("Lifetime axis must be 0, 1, or 2");
7562
}
7663

@@ -85,70 +72,69 @@ public void assertConformity(final FitParams<I> in) {
8572
}
8673
}
8774

88-
public void initialize(FitParams<I> in) {
89-
90-
// dimension doesn't really matter
91-
if (roi == null) {
92-
roi = Masks.allRealMask(0);
93-
}
94-
95-
// So that we bin the correct axis
96-
if (kernel != null) {
97-
kernel = Views.permute(kernel, 2, lifetimeAxis);
98-
}
99-
100-
params = in.copy();
101-
initParam();
102-
rslts = new FitResults();
103-
fitWorker = createWorker(params, rslts);
104-
initRslt();
105-
}
106-
10775
/**
10876
* @param params the {@link FitParams} used for fitting
10977
* @param mask a {@link RealMask} defining the areas to fit
11078
* @param kernel kernel used in an optional convolution preprocessing step
111-
* @param handler
79+
* @param handler a {@link FitWorker.FitEventHandler} allowing for callback
80+
* once computation has completed
11281
* @return the results of fitting
11382
*/
11483
@Override
115-
public FitResults apply(FitParams<I> params, @Nullable RealMask mask,
116-
@Nullable RandomAccessibleInterval<K> kernel,
117-
@Nullable FitWorker.FitEventHandler<I> handler)
118-
{
119-
this.kernel = kernel;
120-
this.roi = mask;
121-
assertConformity(params);
122-
initialize(params);
123-
fitWorker.fitBatch(roiPos, handler);
124-
return rslts;
125-
}
126-
127-
/**
128-
* Generates a worker for the actual fit.
129-
*
130-
* @return A {@link FitWorker}.
131-
*/
132-
public abstract FitWorker<I> createWorker(FitParams<I> params,
133-
FitResults results);
84+
public FitResults apply( //
85+
FitParams<I> params, //
86+
@Nullable RealMask mask, //
87+
@Nullable RandomAccessibleInterval<K> kernel, //
88+
@Nullable FitWorker.FitEventHandler<I> handler //
89+
) {
90+
assertConformity(params, mask, kernel);
91+
92+
// Assign reasonable defaults for nullable params
93+
if (mask == null) {
94+
mask = Masks.allRealMask(0);
95+
}
96+
if (kernel != null) {
97+
kernel = Views.permute(kernel, 2, params.ltAxis);
98+
}
13499

135-
private void initParam() {
100+
// -- Initialize -- //
101+
params = params.copy(); // TODO: Is this necessary
136102
// convolve the image if necessary
137103
if (kernel != null) {
138104
params.transMap = convolveOp.apply( //
139105
params.transMap, //
140106
kernel, //
141107
Util.getTypeFromInterval(params.transMap));
142108
}
109+
List<int[]> roiPos = getRoiPositions(mask, params.ltAxis, params.transMap);
143110

144-
roiPos = getRoiPositions(params.transMap);
145-
146-
est = new ParamEstimator<>(params, roiPos);
111+
ParamEstimator<I> est = new ParamEstimator<>(params, roiPos);
147112
est.estimateStartEnd();
148113
est.estimateIThreshold();
114+
FitResults rslts = new FitResults();
115+
FitWorker<I> fitWorker = createWorker(params, rslts);
116+
initRslt(params, fitWorker, est, rslts);
117+
118+
// -- Run -- //
119+
fitWorker.fitBatch(roiPos, handler);
120+
return rslts;
149121
}
150122

151-
private void initRslt() {
123+
/**
124+
* Generates a worker for the actual fit.
125+
*
126+
* @return A {@link FitWorker}.
127+
*/
128+
public abstract FitWorker<I> createWorker(FitParams<I> params,
129+
FitResults results);
130+
131+
private void initRslt( //
132+
FitParams<I> params, //
133+
FitWorker<I> fitWorker, //
134+
ParamEstimator<I> est, //
135+
FitResults rslts //
136+
) {
137+
int lifetimeAxis = params.ltAxis;
152138
// get dimensions and replace time axis with decay parameters
153139
long[] dimFit = new long[params.transMap.numDimensions()];
154140
params.transMap.dimensions(dimFit);
@@ -177,7 +163,9 @@ private void initRslt() {
177163
rslts.intensityMap = est.getIntensityMap();
178164
}
179165

180-
private List<int[]> getRoiPositions(RandomAccessibleInterval<I> trans) {
166+
private List<int[]> getRoiPositions(RealMask roi, int lifetimeAxis,
167+
RandomAccessibleInterval<I> trans)
168+
{
181169
final List<int[]> interested = new ArrayList<>();
182170
final IntervalView<I> xyPlane = Views.hyperSlice(trans, lifetimeAxis, 0);
183171
final Cursor<I> xyCursor = xyPlane.localizingCursor();

scijava-ops-flim/src/main/java/org/scijava/ops/flim/impl/AbstractSingleFitWorker.java

Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,16 @@
2323
package org.scijava.ops.flim.impl;
2424

2525
import net.imglib2.type.numeric.RealType;
26+
import org.scijava.concurrent.Parallelization;
2627
import org.scijava.ops.flim.FitParams;
2728
import org.scijava.ops.flim.FitResults;
2829
import org.scijava.ops.flim.util.RAHelper;
2930

31+
import java.util.ArrayList;
3032
import java.util.List;
33+
import java.util.function.Consumer;
34+
import java.util.stream.Collectors;
35+
import java.util.stream.IntStream;
3136

3237
public abstract class AbstractSingleFitWorker<I extends RealType<I>> extends
3338
AbstractFitWorker<I>
@@ -117,48 +122,78 @@ protected void onThreadInit() {}
117122
@Override
118123
public void fitBatch(List<int[]> pos, FitEventHandler<I> handler) {
119124
final AbstractSingleFitWorker<I> thisWorker = this;
120-
// TODO: Re-implement parallel behavior
121-
122-
// thread-local reusable read/write buffers
123-
final FitParams<I> lParams;
124-
final FitResults lResults;
125-
final AbstractSingleFitWorker<I> fitWorker;
126-
// don't make copy in single thread mode
127-
if (!params.multithread || pos.size() == 1) {
128-
lParams = params;
129-
lResults = results;
130-
fitWorker = thisWorker;
131-
}
132-
else {
133-
lParams = params.copy();
134-
lResults = results.copy();
135-
// grab your own buffer
136-
lParams.param = lParams.trans = lResults.param = lResults.fitted =
137-
lResults.residuals = null;
138-
fitWorker = duplicate(lParams, lResults);
139-
}
140-
fitWorker.onThreadInit();
141125

142-
final RAHelper<I> helper = new RAHelper<>(params, results);
126+
Consumer<int[]> worker = (data) -> {
127+
int start = data[0];
128+
int size = data[1];
129+
if (!params.multithread) {
130+
// let the first fitting thread do all the work
131+
if (start != 0) {
132+
return;
133+
}
134+
size = pos.size();
135+
}
143136

144-
for (int[] xytPos : pos) {
145-
if (!helper.loadData(fitWorker.transBuffer, fitWorker.paramBuffer, params,
146-
xytPos)) lResults.retCode = FitResults.RET_INTENSITY_BELOW_THRESH;
137+
// thread-local reusable read/write buffers
138+
final FitParams<I> lParams;
139+
final FitResults lResults;
140+
final AbstractSingleFitWorker<I> fitWorker;
141+
// don't make copy in single thread mode
142+
if (!params.multithread || pos.size() == 1) {
143+
lParams = params;
144+
lResults = results;
145+
fitWorker = thisWorker;
146+
}
147147
else {
148-
fitWorker.fitSingle();
149-
150-
// invalidate fit if chisq is insane
151-
final float chisq = lResults.chisq;
152-
if (params.dropBad && lResults.retCode == FitResults.RET_OK &&
153-
(chisq < 0 || chisq > 1E5 || Float.isNaN(chisq))) lResults.retCode =
154-
FitResults.RET_BAD_FIT_CHISQ_OUT_OF_RANGE;
148+
lParams = params.copy();
149+
lResults = results.copy();
150+
// grab your own buffer
151+
lParams.param = lParams.trans = lResults.param = lResults.fitted =
152+
lResults.residuals = null;
153+
fitWorker = duplicate(lParams, lResults);
155154
}
155+
fitWorker.onThreadInit();
156+
157+
final RAHelper<I> helper = new RAHelper<>(params, results);
156158

157-
helper.commitRslts(lParams, lResults, xytPos);
159+
for (int i = start; i < start + size; i++) {
160+
final int[] xytPos = pos.get(i);
158161

159-
if (handler != null) handler.onSingleComplete(xytPos, params, results);
162+
if (!helper.loadData(fitWorker.transBuffer, fitWorker.paramBuffer,
163+
params, xytPos)) lResults.retCode =
164+
FitResults.RET_INTENSITY_BELOW_THRESH;
165+
else {
166+
fitWorker.fitSingle();
167+
168+
// invalidate fit if chisq is insane
169+
final float chisq = lResults.chisq;
170+
if (params.dropBad && lResults.retCode == FitResults.RET_OK &&
171+
(chisq < 0 || chisq > 1E5 || Float.isNaN(chisq))) lResults.retCode =
172+
FitResults.RET_BAD_FIT_CHISQ_OUT_OF_RANGE;
173+
}
174+
175+
helper.commitRslts(lParams, lResults, xytPos);
176+
177+
if (handler != null) handler.onSingleComplete(xytPos, params, results);
178+
}
179+
};
180+
181+
int n = Parallelization.getTaskExecutor().suggestNumberOfTasks();
182+
int s = pos.size() / n;
183+
int r = pos.size() % n;
184+
185+
List<int[]> list = new ArrayList<>(n);
186+
int start = 0;
187+
int size = s + 1;
188+
for (int i = 0; i < n; i++) {
189+
list.add(new int[] { start, size });
190+
start += size;
191+
if (i == r - 1) {
192+
size--;
193+
}
160194
}
161195

196+
Parallelization.getTaskExecutor().forEach(list, worker);
162197
if (handler != null) handler.onComplete(params, results);
163198
}
164199
}

scijava-ops-flim/src/test/java/org/scijava/ops/flim/FitTest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ public void testBinning() {
136136
.input(param, roi, kernel) //
137137
.outType(FitResults.class) //
138138
.apply();
139-
// FitResults out = (FitResults) ops.run("flim.fitRLD", param, FlimOps.SQUARE_KERNEL_3, roi);
140139
System.out.println("RLD with binning finished in " + (System
141140
.currentTimeMillis() - ms) + " ms");
142141

@@ -167,7 +166,6 @@ public void testLMAFitImg() {
167166
public void testBayesFitImg() {
168167
// estimation using RLD
169168
param.getChisqMap = true;
170-
// param.multithread = false;
171169
long ms = System.currentTimeMillis();
172170
FitResults out = ops.binary("flim.fitBayes").input(param, roi).outType(
173171
FitResults.class).apply();

0 commit comments

Comments
 (0)