Commit 140fe7ce authored by Chris Coughlin's avatar Chris Coughlin

Initial implementation of Gradient Machine ROI finder (experimental)

parent 878cac86
......@@ -116,34 +116,9 @@ public class AdaptiveSGDROIFinder extends MLROIConfFinder {
}
}
/**
* Returns the probability of the sample being in each of the classes recognized by the current model.
* Element 0 is the probability of the "negative" class i.e. no ROI.
* @param data sample data to classify
* @return array of probabilities between 0 and 1 of the sample being in the specified class.
*/
@Override
public double[] predict_proba(double[] data) {
DenseVector d = new DenseVector(data);
Vector scores = model.getBest().getPayload().getLearner().classify(d);
double[] probabilities = new double[1 + scores.size()];
// Mahout doesn't return the probability of the first class (class 0) - need to calculate
probabilities[0] = 1 - scores.zSum();
for (int i = 0; i < scores.size(); i++) {
probabilities[i + 1] = scores.get(i);
}
return probabilities;
}
/**
* Returns the probability of the sample being in each of the classes recognized by the current model.
* Element 0 is the probability of the "negative" class i.e. no ROI.
* @param dataset sample data to classify
* @return array of probabilities between 0 and 1 of the sample being in the specified class.
*/
@Override
public double[] predict_proba(Dataset dataset) {
return predict_proba(dataset.getData());
public Vector classify(DenseVector d) {
return model.getBest().getPayload().getLearner().classify(d);
}
/**
......
......@@ -201,7 +201,7 @@ public class PassiveAggressiveROIFinder extends MLROIConfFinder {
model.write(new DataOutputStream(output));
}
} catch (IOException ioe) {
log.error("Error writing model: " + ioe.getMessage());
log.error("Error writing model: {}", ioe.getMessage());
}
}
......@@ -222,35 +222,8 @@ public class PassiveAggressiveROIFinder extends MLROIConfFinder {
}
}
/**
* Returns the probability of the sample being in each of the classes recognized by the current model.
* Element 0 is the probability of the "negative" class i.e. no ROI.
*
* @param data sample data to classify
* @return array of probabilities between 0 and 1 of the sample being in the specified class.
*/
@Override
public double[] predict_proba(double[] data) {
DenseVector d = new DenseVector(data);
Vector scores = model.classify(d);
double[] probabilities = new double[1 + scores.size()];
// Mahout doesn't return the probability of the first class (class 0) - need to calculate
probabilities[0] = 1 - scores.zSum();
for (int i = 0; i < scores.size(); i++) {
probabilities[i + 1] = scores.get(i);
}
return probabilities;
}
/**
* Returns the probability of the sample being in each of the classes recognized by the current model.
* Element 0 is the probability of the "negative" class i.e. no ROI.
*
* @param dataset sample data to classify
* @return array of probabilities between 0 and 1 of the sample being in the specified class.
*/
@Override
public double[] predict_proba(Dataset dataset) {
return predict_proba(dataset.getData());
public Vector classify(DenseVector d) {
return model.classify(d);
}
}
......@@ -19,6 +19,8 @@
package com.emphysic.myriad.core.data.roi;
import com.emphysic.myriad.core.data.io.Dataset;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
/**
* ROIProbability - the implementing class can return probabilities of a given sample being in a given class.
......@@ -29,16 +31,32 @@ public interface ROIProbability {
/**
* Returns the probability of the sample being in each of the classes recognized by the current model.
* Element 0 is the probability of the "negative" class i.e. no ROI.
*
* @param data sample data to classify
* @return array of probabilities between 0 and 1 of the sample being in the specified class.
*/
double[] predict_proba(double[] data);
default double[] predict_proba(double[] data) {
DenseVector d = new DenseVector(data);
Vector scores = classify(d);
double[] probabilities = new double[1 + scores.size()];
// Mahout doesn't return the probability of the first class (class 0) - need to calculate
probabilities[0] = 1 - scores.zSum();
for (int i = 0; i < scores.size(); i++) {
probabilities[i + 1] = scores.get(i);
}
return probabilities;
}
/**
* Returns the probability of the sample being in each of the classes recognized by the current model.
* Element 0 is the probability of the "negative" class i.e. no ROI.
*
* @param dataset sample data to classify
* @return array of probabilities between 0 and 1 of the sample being in the specified class.
*/
double[] predict_proba(Dataset dataset);
default double[] predict_proba(Dataset dataset) {
return predict_proba(dataset.getData());
}
Vector classify(DenseVector d);
}
/*
* com.emphysic.myriad.core.experimental.roi.GradMachineROIFinder
*
* Copyright (c) 2016 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.emphysic.myriad.core.experimental.roi;
import com.emphysic.myriad.core.data.io.Dataset;
import com.emphysic.myriad.core.data.roi.MLROIConfFinder;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import lombok.extern.slf4j.Slf4j;
import org.apache.mahout.classifier.sgd.GradientMachine;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Random;
/**
* Region Of Interest (ROI) finder based on Mahout's GradientMachine, an online gradient machine with one hidden
* sigmoid layer that minimizes the hinge loss.
* Created by ccoughlin on 11/2/2016.
*/
@Slf4j
public class GradMachineROIFinder extends MLROIConfFinder {
/**
* The ROI detection model
*/
private GradientMachine model;
/**
* Number of features in the feature space
*/
private int numFeatures;
/**
* Number of classes (defaults to 2, i.e. ROI and notROI)
*/
private int numCategories = 2;
/**
* Number of nodes in the hidden layer (default 100)
*/
private int numHidden = 100;
/**
* Learning rate (default 0.1)
*/
private double learningRate = 0.1;
/**
* Sparsity - a positive number between 0-1 that controls the sparsity of the hidden layer (default 0.1)
*/
private double sparsity = 0.1;
/**
* Regularization parameter - controls the size of the weight vector (default 0.1)
*/
private double regularization = 0.1;
/**
* Default no-arg constructor for serialization
*/
public GradMachineROIFinder() {}
/**
* Constructor.
* @param numCats number of categories in the data
* @param numHidden number of nodes in the hidden layer
* @param learningRate learning rate (0-1)
* @param sparsity sparsity of hidden layer (0-1)
* @param regularization reqularization of weight vector
*/
public GradMachineROIFinder(int numCats, int numHidden, double learningRate, double sparsity, double regularization) {
this.numCategories = numCats;
this.numHidden = numHidden;
this.learningRate = learningRate;
this.sparsity = sparsity;
this.regularization = regularization;
}
/**
* Constructor.
* @param numCats number of categories in the data
*/
public GradMachineROIFinder(int numCats) {
this.numCategories = numCats;
}
/**
* Trains the flaw finder on new data. Initializes the model if required and sets the number of features.
*
* @param X N examples with M features per example
* @param y N labels for the N examples in X
* @throws Exception if an error occurs
*/
@Override
public void train(double[][] X, int[] y) throws Exception {
if (model == null) {
numFeatures = X[0].length;
initModel();
}
for (int i = 0; i < y.length; i++) {
model.train(y[i], new DenseVector(X[i]));
}
}
/**
* Initializes the model.
*/
private void initModel() {
log.info("Initializing model");
model = new GradientMachine(numFeatures, numHidden, numCategories);
model.initWeights(new Random());
model.learningRate(learningRate).sparsity(sparsity).regularization(regularization);
}
/**
* The numeric label assigned to positive samples i.e. samples that contain ROI
*
* @return the value assigned to positive samples
*/
@Override
public double positiveClass() {
return 1;
}
/**
* The numeric label assigned to negative samples i.e. samples that do not contain ROI
*
* @return the value assigned to negative samples
*/
@Override
public double negativeClass() {
return 0;
}
/**
* Examine an array of data and report whether it appears to contain a region of interest (ROI)
*
* @param data raw data to examine
* @return true if data appears to contain an ROI, false otherwise
*/
@Override
public boolean isROI(double[] data) {
return false;
}
/**
* Examine a dataset and return whether or not it seems to contain a region of interest (ROI)
*
* @param dataset data to examine
* @return true if data appears to contain an ROI, false otherwise
*/
@Override
public boolean isROI(Dataset dataset) {
return false;
}
@Override
public void write(Kryo kryo, Output output) {
try {
kryo.writeObject(output, new Integer(numFeatures));
kryo.writeObject(output, new Integer(numCategories));
kryo.writeObject(output, new Integer(numHidden));
kryo.writeObject(output, new Double(learningRate));
kryo.writeObject(output, new Double(sparsity));
kryo.writeObject(output, new Double(regularization));
kryo.writeObject(output, new Double(getConfidenceThreshold()));
model.write(new DataOutputStream(output));
} catch (IOException ioe) {
log.error("Unable to serialize: {}", ioe);
}
}
@Override
public void read(Kryo kryo, Input input) {
try {
numFeatures = kryo.readObject(input, Integer.class);
numCategories = kryo.readObject(input, Integer.class);
numHidden = kryo.readObject(input, Integer.class);
learningRate = kryo.readObject(input, Double.class);
sparsity = kryo.readObject(input, Double.class);
regularization = kryo.readObject(input, Double.class);
confThr = kryo.readObject(input, Double.class);
initModel();
model.readFields(new DataInputStream(input));
} catch (IOException ioe) {
log.error("Unable to deserialize: {}", ioe);
}
}
public Vector classify(DenseVector d) {
return model.classify(d);
}
public GradientMachine getModel() {
return model;
}
public void setModel(GradientMachine model) {
this.model = model;
}
public int getNumFeatures() {
return numFeatures;
}
public int getNumCategories() {
return numCategories;
}
public int getNumHidden() {
return numHidden;
}
public double getLearningRate() {
return learningRate;
}
public double getSparsity() {
return sparsity;
}
public double getRegularization() {
return regularization;
}
}
/*
* com.emphysic.myriad.core.experimental.roi.GradMachineROIFinderTest
*
* Copyright (c) 2016 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.emphysic.myriad.core.experimental.roi;
import com.emphysic.myriad.core.data.io.Dataset;
import com.emphysic.myriad.core.data.roi.ROIFinder;
import org.apache.commons.lang.ArrayUtils;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.util.Random;
/**
 * Tests the experimental GradMachineROIFinder implementation.
 * Created by ccoughlin on 11/8/2016.
 */
public class GradMachineROIFinderTest {
    /**
     * Training samples: two features each, the first two in the negative quadrant and the last two in the
     * positive quadrant.
     */
    private final double[][] X = {{-1.0, -1.0}, {-2.0, -1.0}, {1.0, 1.0}, {2.0, 1.0}};
    /**
     * Labels for the samples in X (0 = no ROI, 1 = ROI)
     */
    private final int[] y = {0, 0, 1, 1};
    /**
     * Finder under test, retrained before every test
     */
    private GradMachineROIFinder model;

    /**
     * Builds a two-category finder and trains it on the fixture data.
     */
    @Before
    public void setUp() throws Exception {
        model = new GradMachineROIFinder(2);
        model.train(X, y);
    }

    /**
     * Predictions on raw arrays must be one of the two class labels.
     */
    @Test
    public void predict() throws Exception {
        int[] acceptablePredictions = {(int) model.positiveClass(), (int) model.negativeClass()};
        for (double[] aX : X) {
            int prediction = (int) model.predict(aX);
            Assert.assertTrue(ArrayUtils.contains(acceptablePredictions, prediction));
        }
    }

    /**
     * Predictions on Datasets must be one of the two class labels.
     */
    @Test
    public void predict1() throws Exception {
        int[] acceptablePredictions = {(int) model.positiveClass(), (int) model.negativeClass()};
        for (double[] aX : X) {
            Dataset d = new Dataset(aX, aX.length, 1);
            int prediction = (int) model.predict(d);
            Assert.assertTrue(ArrayUtils.contains(acceptablePredictions, prediction));
        }
    }

    /**
     * A saved finder must be restored with equivalent fields and predictions, both through load() and through
     * ROIFinder.fromFile().
     */
    @Test
    public void serialize() throws Exception {
        File out = File.createTempFile("tmp_gm", "dat");
        // Do not leave the scratch file behind in the temp directory after the test run.
        out.deleteOnExit();
        Random random = new Random();
        model.setConfidenceThreshold(random.nextDouble());
        model.save(out);
        GradMachineROIFinder read = new GradMachineROIFinder();
        read.load(out);
        assertModelsEqual(model, read);
        GradMachineROIFinder read2 = (GradMachineROIFinder) ROIFinder.fromFile(out, GradMachineROIFinder.class);
        assertModelsEqual(model, read2);
    }

    /**
     * Convenience method to verify two GradMachine ROI Finders are "equal" i.e. equivalent fields.
     * Floating point fields are compared with a tolerance of 5% of the expected value.
     * @param expected expected model
     * @param actual actual model
     */
    public void assertModelsEqual(GradMachineROIFinder expected, GradMachineROIFinder actual) {
        Assert.assertEquals(expected.getNumCategories(), actual.getNumCategories());
        Assert.assertEquals(expected.getLearningRate(), actual.getLearningRate(), expected.getLearningRate() * 0.05);
        Assert.assertEquals(expected.getNumFeatures(), actual.getNumFeatures());
        Assert.assertEquals(expected.getNumHidden(), actual.getNumHidden());
        Assert.assertEquals(expected.getConfidenceThreshold(), actual.getConfidenceThreshold(),
                expected.getConfidenceThreshold() * 0.05);
        Assert.assertEquals(expected.getRegularization(), actual.getRegularization(), expected.getRegularization() * 0.05);
        Assert.assertEquals(expected.getSparsity(), actual.getSparsity(), expected.getSparsity() * 0.05);
        for (int i = 0; i < X.length; i++) {
            // Tolerance scales with the label, so samples labelled 0 must match exactly.
            Assert.assertEquals(expected.predict(X[i]), actual.predict(X[i]), 0.05 * y[i]);
        }
    }
}
\ No newline at end of file
......@@ -45,8 +45,8 @@ public class CrossValidationTest {
@Before
public void setUp() throws Exception {
URL testJpgURL = Thread.currentThread().getContextClassLoader().getResource("data/ml/enchilada.txt");
File samplesFile = new File(testJpgURL.getPath());
URL dataURL = Thread.currentThread().getContextClassLoader().getResource("data/ml/enchilada.txt");
File samplesFile = new File(dataURL.getPath());
BufferedReader br = new BufferedReader(new FileReader(samplesFile));
String line;
int numFeatures = 225;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment