Commit 6876b9d8 authored by Chris Coughlin's avatar Chris Coughlin

Refactored ROI probabilities, added support for same in Passive Aggressive ROI finders

parent e4ae6536
......@@ -9,7 +9,6 @@ import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.ElasticBandPrior;
import org.apache.mahout.classifier.sgd.PriorFunction;
......@@ -19,14 +18,13 @@ import org.apache.mahout.math.Vector;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.stream.DoubleStream;
/**
* AdaptiveSGDROIFinder - machine learning ROI detector that locates regions of interest by maintaining a pool of SGD
* classifiers.
*/
@Slf4j
public class AdaptiveSGDROIFinder implements MLROIFinder, ROIProbability {
public class AdaptiveSGDROIFinder extends MLROIConfFinder {
/**
* Number of categories of ROI e.g. 2 for is/isn't a region of interest.
*/
......@@ -51,11 +49,6 @@ public class AdaptiveSGDROIFinder implements MLROIFinder, ROIProbability {
* The number of SGD learners to use. Defaults to 20.
*/
private int poolSize = 20;
/**
* A confidence threshold between 0 and 1 to be met for labelling a sample as containing ROI. The higher the
* threshold the fewer ROI will be reported. Defaults to 0 i.e. no confidence thresholding is performed.
*/
private double confThr = 0;
/**
* Creates a new AdaptiveSGDROIFinder with the specified number of categories and the specified
......@@ -109,33 +102,6 @@ public class AdaptiveSGDROIFinder implements MLROIFinder, ROIProbability {
}
}
/**
* Predicts the label for a sample. Prediction is the category index with the maximum value.
* @param data sample to predict
* @return integer index with the maximum value
*/
@Override
public double predict(double[] data) {
double[] labelProbs = predict_proba(data);
double maxConf = DoubleStream.of(labelProbs).max().getAsDouble();
int idx = ArrayUtils.indexOf(labelProbs, maxConf);
if (idx > 0 && maxConf < confThr) {
// Only report ROI if the model's confidence > threshold
return 0;
}
return idx;
}
/**
* Predicts the label for a sample. Prediction is the category index with the maximum value.
* @param data sample to predict
* @return integer index with the maximum value
*/
@Override
public double predict(Dataset data) {
return predict(data.getData());
}
/**
* Returns the probability of the sample being in each of the classes recognized by the current model.
* Element 0 is the probability of the "negative" class i.e. no ROI.
......@@ -339,25 +305,4 @@ public class AdaptiveSGDROIFinder implements MLROIFinder, ROIProbability {
log.error("Error encountered reading model: " + ioe.getMessage());
}
}
/**
* Returns the confidence threshold for labelling ROI.
* @return confidence threshold between 0 and 1
*/
public double getConfidenceThreshold() {
return confThr;
}
/**
* Sets the confidence threshold for labelling ROI samples - the model must be (confThr * 100)% confident a
* sample contains ROI for it to be labelled as such.
* @param confThr new confidence threshold
* @throws IllegalArgumentException if the confidence threshold is not in the range 0-1 inclusive.
*/
public void setConfidenceThreshold(double confThr) throws IllegalArgumentException {
if (confThr < 0 || confThr > 1) {
throw new IllegalArgumentException("Confidence threshold must be between 0 and 1 inclusive.");
}
this.confThr = confThr;
}
}
/*
* Copyright (c) 2016 Emphysic LLC. All rights reserved.
*/
package com.emphysic.myriad.core.data.roi;
import com.emphysic.myriad.core.data.io.Dataset;
import org.apache.commons.lang3.ArrayUtils;
import java.util.stream.DoubleStream;
/**
* A Region of Interest (ROI) finder based on machine learning that provides both probabilities of its
* classifications and the ability to define a confidence threshold.
* Created by ccoughlin on 9/21/2016.
*/
public abstract class MLROIConfFinder implements MLROIFinder, ROIProbability {
/**
* A confidence threshold between 0 and 1 to be met for labelling a sample as containing ROI. The higher the
* threshold the fewer ROI will be reported. Defaults to 0 i.e. no confidence thresholding is performed.
*/
protected double confThr = 0;
/**
* Returns the confidence threshold for labelling ROI.
* @return confidence threshold between 0 and 1
*/
public double getConfidenceThreshold() {
return confThr;
}
/**
* Sets the confidence threshold for labelling ROI samples - the model must be (confThr * 100)% confident a
* sample contains ROI for it to be labelled as such.
* @param confThr new confidence threshold
* @throws IllegalArgumentException if the confidence threshold is not in the range 0-1 inclusive.
*/
public void setConfidenceThreshold(double confThr) throws IllegalArgumentException {
if (confThr < 0 || confThr > 1) {
throw new IllegalArgumentException("Confidence threshold must be between 0 and 1 inclusive.");
}
this.confThr = confThr;
}
/**
* Predicts the label for a sample. Prediction is the category index with the maximum value.
* @param data sample to predict
* @return integer index with the maximum value
*/
@Override
public double predict(double[] data) {
double[] labelProbs = predict_proba(data);
double maxConf = DoubleStream.of(labelProbs).max().getAsDouble();
int idx = ArrayUtils.indexOf(labelProbs, maxConf);
if (idx > 0 && maxConf < confThr) {
// Only report ROI if the model's confidence > threshold
idx = 0;
}
return idx;
}
/**
* Predicts the label for a sample. Prediction is the category index with the maximum value.
* @param data sample to predict
* @return integer index with the maximum value
*/
@Override
public double predict(Dataset data) {
return predict(data.getData());
}
}
......@@ -22,7 +22,7 @@ import java.io.IOException;
* "Online Passive-Aggressive Algorithms" K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR 7 (2006).
*/
@Slf4j
public class PassiveAggressiveROIFinder implements MLROIFinder {
public class PassiveAggressiveROIFinder extends MLROIConfFinder {
/**
* Number of categories to learn (default 2)
*/
......@@ -75,36 +75,6 @@ public class PassiveAggressiveROIFinder implements MLROIFinder {
}
}
/**
* Predicts the label for a sample.
* @param data sample to predict
* @return integer index with the maximum value
*/
@Override
public double predict(double[] data) {
Vector scores = model.classify(new DenseVector(data));
// Mahout doesn't return the probability of the first class (class 0) - need to calculate
double cat0score = 1 - scores.zSum();
double pred;
if (cat0score > scores.maxValue()) {
pred = negativeClass();
} else {
// Mahout prediction vectors start with prob. of category 1
pred = scores.maxValueIndex() + 1;
}
return pred;
}
/**
* Predicts the label for a sample. Prediction is the category index with the maximum value.
* @param data sample to predict
* @return integer index with the maximum value
*/
@Override
public double predict(Dataset data) {
return predict(data.getData());
}
/**
* Returns the numeric value of the positive class of a two-category model.
* @return positive class label
......@@ -233,4 +203,36 @@ public class PassiveAggressiveROIFinder implements MLROIFinder {
log.error("Error reading model: " + ioe.getMessage());
}
}
/**
* Returns the probability of the sample being in each of the classes recognized by the current model.
* Element 0 is the probability of the "negative" class i.e. no ROI.
*
* @param data sample data to classify
* @return array of probabilities between 0 and 1 of the sample being in the specified class.
*/
@Override
public double[] predict_proba(double[] data) {
DenseVector d = new DenseVector(data);
Vector scores = model.classify(d);
double[] probabilities = new double[1 + scores.size()];
// Mahout doesn't return the probability of the first class (class 0) - need to calculate
probabilities[0] = 1 - scores.zSum();
for (int i = 1; i < scores.size() + 1; i++) {
probabilities[i] = scores.get(i - 1);
}
return probabilities;
}
/**
* Returns the probability of the sample being in each of the classes recognized by the current model.
* Element 0 is the probability of the "negative" class i.e. no ROI.
*
* @param dataset sample data to classify
* @return array of probabilities between 0 and 1 of the sample being in the specified class.
*/
@Override
public double[] predict_proba(Dataset dataset) {
return predict_proba(dataset.getData());
}
}
package com.emphysic.myriad.core.data.roi;
import com.emphysic.myriad.core.data.io.Dataset;
import org.apache.commons.lang3.ArrayUtils;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.util.stream.DoubleStream;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* PassiveAggressiveROIFinderTest - unit tests for PassiveAggressiveROIFinder.
......@@ -64,6 +68,36 @@ public class PassiveAggressiveROIFinderTest {
}
}
@Test
public void confidenceThreshold() throws Exception {
train();
// Verify throwing exception for bad confidence levels
double[] badThresh = {-1, 1.01, -0.1, 75};
for (double bad : badThresh) {
try {
model.setConfidenceThreshold(bad);
Assert.fail("Expected exception to be thrown");
} catch (Exception e) {
assertTrue(e instanceof IllegalArgumentException);
}
}
// Verify reporting ROI iff confidence threshold is met
double[] confThresh = {0, 0.25, 0.5, 0.75, 1};
for (double t : confThresh) {
model.setConfidenceThreshold(t);
assertEquals(t, model.getConfidenceThreshold(), 0.05 * t);
for (double[] aX : X) {
double[] probs = model.predict_proba(aX);
double maxConf = DoubleStream.of(probs).max().getAsDouble();
int idx = ArrayUtils.indexOf(probs, maxConf);
if (idx > 0 && maxConf < t) {
idx = 0;
}
assertEquals(idx == model.positiveClass(), model.isROI(aX));
}
}
}
@Test
public void serialize() throws Exception {
File out = File.createTempFile("tmp_pa", "dat");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment