Commit ca69d3ac authored by Chris Coughlin's avatar Chris Coughlin

Minor tweak to isROI methodology

parent 9fb4d3f7
......@@ -115,8 +115,8 @@ public class AdaptiveSGDROIFinder extends MLROIConfFinder {
double[] probabilities = new double[1 + scores.size()];
// Mahout doesn't return the probability of the first class (class 0) - need to calculate
probabilities[0] = 1 - scores.zSum();
for (int i = 1; i < scores.size() + 1; i++) {
probabilities[i] = scores.get(i - 1);
for (int i = 0; i < scores.size(); i++) {
probabilities[i + 1] = scores.get(i);
}
return probabilities;
}
......@@ -139,7 +139,7 @@ public class AdaptiveSGDROIFinder extends MLROIConfFinder {
*/
@Override
public boolean isROI(double[] data) {
return (int)predict(data) == (int)positiveClass();
return (int)predict(data) != (int)negativeClass();
}
/**
......
......@@ -5,6 +5,7 @@
package com.emphysic.myriad.core.data.roi;
import com.emphysic.myriad.core.data.io.Dataset;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import java.util.stream.DoubleStream;
......@@ -14,6 +15,7 @@ import java.util.stream.DoubleStream;
* classifications and the ability to define a confidence threshold.
* Created by ccoughlin on 9/21/2016.
*/
@Slf4j
public abstract class MLROIConfFinder implements MLROIFinder, ROIProbability {
/**
* A confidence threshold between 0 and 1 to be met for labelling a sample as containing ROI. The higher the
......@@ -52,9 +54,16 @@ public abstract class MLROIConfFinder implements MLROIFinder, ROIProbability {
double[] labelProbs = predict_proba(data);
double maxConf = DoubleStream.of(labelProbs).max().getAsDouble();
int idx = ArrayUtils.indexOf(labelProbs, maxConf);
if (idx > 0 && maxConf < confThr) {
// Only report ROI if the model's confidence > threshold
idx = 0;
if (idx > 0) {
StringBuilder sb = new StringBuilder("Predicted ROI class " + idx + " with probability " + maxConf);
if (maxConf < confThr) {
// Only report ROI if the model's confidence > threshold
sb.append(": below " + confThr + ", setting to " + negativeClass());
return negativeClass();
} else {
sb.append(": above " + confThr + ", no changes made");
}
log.info(sb.toString());
}
return idx;
}
......
......@@ -100,7 +100,7 @@ public class PassiveAggressiveROIFinder extends MLROIConfFinder {
*/
@Override
public boolean isROI(double[] data) {
return (int)predict(data) == (int)positiveClass();
return (int)predict(data) != (int)negativeClass();
}
/**
......@@ -222,8 +222,8 @@ public class PassiveAggressiveROIFinder extends MLROIConfFinder {
double[] probabilities = new double[1 + scores.size()];
// Mahout doesn't return the probability of the first class (class 0) - need to calculate
probabilities[0] = 1 - scores.zSum();
for (int i = 1; i < scores.size() + 1; i++) {
probabilities[i] = scores.get(i - 1);
for (int i = 0; i < scores.size(); i++) {
probabilities[i + 1] = scores.get(i);
}
return probabilities;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment