Commit e4ae6536 authored by Chris Coughlin's avatar Chris Coughlin

Initial implementation of probabilities and confidences

parent 9cb5fb58
......@@ -9,6 +9,7 @@ import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.ElasticBandPrior;
import org.apache.mahout.classifier.sgd.PriorFunction;
......@@ -18,13 +19,14 @@ import org.apache.mahout.math.Vector;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.stream.DoubleStream;
/**
* AdaptiveSGDROIFinder - machine learning ROI detector that locates regions of interest by maintaining a pool of SGD
* classifiers.
*/
@Slf4j
public class AdaptiveSGDROIFinder implements MLROIFinder {
public class AdaptiveSGDROIFinder implements MLROIFinder, ROIProbability {
/**
* Number of categories of ROI e.g. 2 for is/isn't a region of interest.
*/
......@@ -49,6 +51,11 @@ public class AdaptiveSGDROIFinder implements MLROIFinder {
* The number of SGD learners to use. Defaults to 20.
*/
private int poolSize = 20;
/**
* A confidence threshold between 0 and 1 to be met for labelling a sample as containing ROI. The higher the
* threshold the fewer ROI will be reported. Defaults to 0 i.e. no confidence thresholding is performed.
*/
private double confThr = 0;
/**
* Creates a new AdaptiveSGDROIFinder with the specified number of categories and the specified
......@@ -109,20 +116,14 @@ public class AdaptiveSGDROIFinder implements MLROIFinder {
*/
@Override
public double predict(double[] data) {
DenseVector d = new DenseVector(data);
Vector scores = model.getBest().getPayload().getLearner().classify(d);
// Mahout doesn't return the probability of the first class (class 0) - need to calculate
double cat0score = 1 - scores.zSum();
double pred;
// TODO: think about implementing a confidence level for Mahout classifiers, only label positive if
// exceeds threshold e.g. cat0score > scores.maxValue() || scores.maxValue() < 0.75
if (cat0score > scores.maxValue()) {
pred = negativeClass();
} else {
// Mahout prediction vectors start with prob. of category 1
pred = scores.maxValueIndex() + 1;
double[] labelProbs = predict_proba(data);
double maxConf = DoubleStream.of(labelProbs).max().getAsDouble();
int idx = ArrayUtils.indexOf(labelProbs, maxConf);
if (idx > 0 && maxConf < confThr) {
// Only report ROI if the model's confidence > threshold
return 0;
}
return pred;
return idx;
}
/**
......@@ -135,6 +136,36 @@ public class AdaptiveSGDROIFinder implements MLROIFinder {
return predict(data.getData());
}
/**
* Returns the probability of the sample being in each of the classes recognized by the current model.
* Element 0 is the probability of the "negative" class i.e. no ROI.
* @param data sample data to classify
* @return array of probabilities between 0 and 1 of the sample being in the specified class.
*/
@Override
public double[] predict_proba(double[] data) {
DenseVector d = new DenseVector(data);
Vector scores = model.getBest().getPayload().getLearner().classify(d);
double[] probabilities = new double[1 + scores.size()];
// Mahout doesn't return the probability of the first class (class 0) - need to calculate
probabilities[0] = 1 - scores.zSum();
for (int i = 1; i < scores.size() + 1; i++) {
probabilities[i] = scores.get(i - 1);
}
return probabilities;
}
/**
* Returns the probability of the sample being in each of the classes recognized by the current model.
* Element 0 is the probability of the "negative" class i.e. no ROI.
* @param dataset sample data to classify
* @return array of probabilities between 0 and 1 of the sample being in the specified class.
*/
@Override
public double[] predict_proba(Dataset dataset) {
return predict_proba(dataset.getData());
}
/**
* Predict if a sample contains a region of interest.
* @param data raw data to examine
......@@ -308,4 +339,25 @@ public class AdaptiveSGDROIFinder implements MLROIFinder {
log.error("Error encountered reading model: " + ioe.getMessage());
}
}
/**
* Returns the confidence threshold for labelling ROI.
* @return confidence threshold between 0 and 1
*/
public double getConfidenceThreshold() {
return confThr;
}
/**
* Sets the confidence threshold for labelling ROI samples - the model must be (confThr * 100)% confident a
* sample contains ROI for it to be labelled as such.
* @param confThr new confidence threshold
* @throws IllegalArgumentException if the confidence threshold is not in the range 0-1 inclusive.
*/
public void setConfidenceThreshold(double confThr) throws IllegalArgumentException {
if (confThr < 0 || confThr > 1) {
throw new IllegalArgumentException("Confidence threshold must be between 0 and 1 inclusive.");
}
this.confThr = confThr;
}
}
/*
* Copyright (c) 2016 Emphysic LLC. All rights reserved.
*/
package com.emphysic.myriad.core.data.roi;
import com.emphysic.myriad.core.data.io.Dataset;
/**
* ROIProbability - the implementing class can return probabilities of a given sample being in a given class.
* Created by ccoughlin on 9/21/2016.
*/
public interface ROIProbability {
/**
* Returns the probability of the sample being in each of the classes recognized by the current model.
* Element 0 is the probability of the "negative" class i.e. no ROI.
* @param data sample data to classify
* @return array of probabilities between 0 and 1 of the sample being in the specified class.
*/
double[] predict_proba(double[] data);
/**
* Returns the probability of the sample being in each of the classes recognized by the current model.
* Element 0 is the probability of the "negative" class i.e. no ROI.
* @param dataset sample data to classify
* @return array of probabilities between 0 and 1 of the sample being in the specified class.
*/
double[] predict_proba(Dataset dataset);
}
package com.emphysic.myriad.core.data.roi;
import com.emphysic.myriad.core.data.io.Dataset;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.mahout.classifier.sgd.UniformPrior;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.util.stream.DoubleStream;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* AdaptiveSGDROIFinderTest - tests the AdaptiveSGDROIFinder.
......@@ -78,6 +82,36 @@ public class AdaptiveSGDROIFinderTest {
}
}
@Test
public void confidenceThreshold() throws Exception {
train();
// Verify throwing exception for bad confidence levels
double[] badThresh = {-1, 1.01, -0.1, 75};
for (double bad : badThresh) {
try {
model.setConfidenceThreshold(bad);
Assert.fail("Expected exception to be thrown");
} catch (Exception e) {
assertTrue(e instanceof IllegalArgumentException);
}
}
// Verify reporting ROI iff confidence threshold is met
double[] confThresh = {0, 0.25, 0.5, 0.75, 1};
for (double t : confThresh) {
model.setConfidenceThreshold(t);
assertEquals(t, model.getConfidenceThreshold(), 0.05 * t);
for (double[] aX : X) {
double[] probs = model.predict_proba(aX);
double maxConf = DoubleStream.of(probs).max().getAsDouble();
int idx = ArrayUtils.indexOf(probs, maxConf);
if (idx > 0 && maxConf < t) {
idx = 0;
}
assertEquals(idx == model.positiveClass(), model.isROI(aX));
}
}
}
@Test
public void serialize() throws Exception {
File out = File.createTempFile("tmp", "dat");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment