Commit 552a6723 by Chris Coughlin

The triumphal return of the SVMROIFinder, now with serialization!

parent 1d65fd37
/*
* com.emphysic.myriad.core.data.roi.MLROIFinder
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -25,7 +25,7 @@ import com.emphysic.myriad.core.data.io.Dataset;
*/
public interface MLROIFinder extends ROIFinder {
/**
* Trains the flaw finder on new data
* Trains the Region Of Interest finder on new data
* @param X N examples with M features per example
* @param y N labels for the N examples in X
* @throws Exception if an error occurs
......
/*
* com.emphysic.myriad.core.data.roi.SVMROIFinder
*
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.emphysic.myriad.core.data.roi;
import com.emphysic.myriad.core.data.io.Dataset;
import lombok.extern.slf4j.Slf4j;
import smile.classification.SVM;
import smile.math.Math;
import smile.math.kernel.LinearKernel;
import smile.math.kernel.MercerKernel;
import java.io.*;
import java.util.Arrays;
import java.util.Map;
/**
* SVMROIFinder - a Region Of Interest (ROI) finder based on an online Support Vector Machine (SVM).
* Created by ccoughlin on 5/28/17.
*/
@Slf4j
public class SVMROIFinder implements MLROIFinder {
private static final long serialVersionUID = 1L; // try never to change - indicates backwards compatibility is broken
private static final int VERSION = 1; // current implementation version
private SVM<double[]> model;
/**
* Number of features in the model
*/
private int numFeatures = 0;
/**
* Label of positive (i.e. has ROI) class
*/
private int posClass = 0;
/**
* Label of negative (i.e. does not have ROI) class
*/
private int negClass = 0;
/**
* Soft margin penalty parameters for positive, negative, and all samples respectively. Defined as a number between
* 0 and 1, where the larger the number the harder the model works to avoid mislabelling samples.
*
* Note that these penalties can only be set during initialization, and will default to 0 if not otherwise configured.
*/
private double Cp;
private double Cn;
private double C;
/**
* Constructor.
* @param kernel kernel function
* @param Cp soft margin penalty parameter for positive instances (0-1)
* @param Cn soft margin penalty parameter for negative instances (0-1)
*/
public SVMROIFinder(MercerKernel<double[]> kernel, double Cp, double Cn) {
model = new SVM<>(kernel, Cp, Cn);
this.Cp = Cp;
this.Cn = Cn;
}
/**
* Constructor.
* @param kernel kernel function
* @param C soft margin penalty parameter (0-1). The larger the penalty the more the model will try to avoid
* misclassification.
*/
public SVMROIFinder(MercerKernel<double[]> kernel, double C) {
model = new SVM<>(kernel, C);
this.C = C;
this.Cp = C;
this.Cn = C;
}
/**
* Constructor for a linear kernel SVM.
* @param Cp soft margin penalty parameter for positive instances (0-1)
* @param Cn soft margin penalty parameter for negative instances (0-1)
*/
public SVMROIFinder(double Cp, double Cn) {
this(new LinearKernel(), Cp, Cn);
}
/**
* Constructor for a linear kernel SVM.
* @param C soft margin penalty parameter (0-1). The larger the penalty the more the model will try to avoid
* misclassification.
*/
public SVMROIFinder(double C) {
this(new LinearKernel(), C);
}
/**
* Default constructor. Creates a linear kernel SVM with a soft margin penalty parameter of 0.1.
*/
public SVMROIFinder() {
this(new LinearKernel(), 0.1);
}
/**
* Constructor
* @param model SVM model
*/
public SVMROIFinder(SVM model) {
this.model = model;
}
/**
* Trains the ROI finder on new data. If the positive and negative labels have not been set and two unique values
* are found in y, the positive class is assumed to be the smaller of the two and the negative as the larger.
* @param X N examples with M features per example
* @param y N labels for the N examples in X
*/
@Override
public void train(double[][] X, int[] y) {
int features = X[0].length;
if (numFeatures == 0) {
numFeatures = features;
} else if (numFeatures != features) {
log.error("Wrong number of features in training set got ", features, " expected ", numFeatures);
return;
}
if (posClass == negClass) {
// Try to grab positive and negative labels
int[] categories = Math.unique(y);
if (categories.length == 2) {
Arrays.sort(categories);
posClass = categories[0];
negClass = categories[1];
log.info("Set positive and negative class labels to ", posClass, " and ", negClass, " respectively");
}
}
model.learn(X, y);
model.finish();
}
@Override
public double predict(double[] data) {
return model.predict(data);
}
@Override
public double predict(Dataset data) {
return predict(data.getData());
}
/**
* Sets the positive (i.e. has ROI) label. If not set, will attempt to discover during training by finding the
* unique labels in the training set and taking the smallest value.
* @param posClass label of positive class
*/
public void positiveClass(int posClass) {
this.posClass = posClass;
}
@Override
public double positiveClass() {
return posClass;
}
/**
* Sets the negative (i.e. does not have ROI) label. If not set, will attempt to discover during training by finding the
* unique labels in the training set and taking the second smallest value.
* @param negClass label of negative class.
*/
public void negativeClass(int negClass) {
this.negClass = negClass;
}
@Override
public double negativeClass() {
return negClass;
}
@Override
public long getSerializationVersion() {
return serialVersionUID;
}
@Override
public int getVersion() {
return VERSION;
}
@Override
public Map<String, Object> getObjectMap() {
Map<String, Object> map = MLROIFinder.super.getObjectMap();
map.put("features", numFeatures);
map.put("positive_class", posClass);
map.put("negative_class", negClass);
map.put("cp", Cp);
map.put("cn", Cn);
map.put("c", C);
if (model != null) {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try {
ObjectOutputStream out = new ObjectOutputStream(baos);
out.writeObject(model);
out.flush();
byte[] forestBytes = baos.toByteArray();
map.put("model", forestBytes);
} catch (IOException ex) {
log.error("Error encountered serializing SVM model: ", ex);
} finally {
try {
baos.close();
} catch (IOException ex) {
log.info("Caught an IOException attempting to close ByteArrayOutputStream, ignoring");
}
}
}
return map;
}
@Override
public void initCurrentVersion(Map<String, Object> objectMap) {
numFeatures = (int) objectMap.getOrDefault("features", numFeatures);
posClass = (int) objectMap.getOrDefault("positive_class", posClass);
negClass = (int) objectMap.getOrDefault("negative_class", negClass);
C = (double) objectMap.getOrDefault("c", C);
Cp = (double) objectMap.getOrDefault("cp", Cp);
Cn = (double) objectMap.getOrDefault("cn", Cn);
if (objectMap.containsKey("model")) {
ObjectInputStream in = null;
try {
ByteArrayInputStream bis = new ByteArrayInputStream((byte[]) objectMap.get("model"));
in = new ObjectInputStream(bis);
model = (SVM<double[]>) in.readObject();
} catch (IOException | ClassNotFoundException e) {
log.error("Encountered an error deserializing the SVM model: ", e);
} finally {
try {
if (in != null) {
in.close();
}
} catch (IOException ex) {
log.info("Caught an IOException attempting to close ObjectInputStream, ignoring");
}
}
}
}
@Override
public boolean isROI(double[] data) {
return (int)predict(data) != (int)negativeClass();
}
@Override
public boolean isROI(Dataset dataset) {
return isROI(dataset.getData());
}
/**
* Returns the number of features
* @return number of features in the model's feature space
*/
public int getNumFeatures() {
return numFeatures;
}
/**
* Retrieves the soft penalty parameter for positive samples.
* @return soft penalty parameter, or 0 if not set.
*/
public double getCp() {
return Cp;
}
/**
* Retrieves the soft penalty parameter for negative samples.
* @return soft penalty parameter, or 0 if not set.
*/
public double getCn() {
return Cn;
}
/**
* Retrieves the soft penalty parameter.
* @return soft penalty parameter, or 0 if not set.
*/
public double getC() {
return C;
}
}
/*
* com.emphysic.myriad.core.data.roi.SVMROIFinderTest
*
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.emphysic.myriad.core.data.roi;
import com.emphysic.myriad.core.data.io.Dataset;
import org.junit.Before;
import org.junit.Test;
import java.io.File;
import static org.junit.Assert.assertEquals;
/**
* SVMROIFinderTest - tests the SVMROIFinder class.
* Created by ccoughlin on 5/13/17.
*/
public class SVMROIFinderTest {
private double[][] X = {{-1.0, -1.0}, {-2.0, -1.0}, {1.0, 1.0}, {2.0, 1.0}};
private int[] y = {0, 0, 1, 1};
private SVMROIFinder model;
@Before
public void setUp() throws Exception {
model = new SVMROIFinder(0.123);
model.train(X, y);
}
@Test
public void predict() throws Exception {
for (int i = 0; i < X.length; i++) {
double prediction = model.predict(X[i]);
double expected = y[i];
assertEquals(expected, prediction, 0.05 * expected);
}
}
@Test
public void predict1() throws Exception {
for (int i=0; i<X.length; i++) {
Dataset d = new Dataset(X[i], X[i].length, 1);
double prediction = model.predict(d);
assertEquals(y[i], prediction, 0.05 * y[i]);
}
}
@Test
public void isROI() throws Exception {
for (int i = 0; i < X.length; i++) {
boolean prediction = model.isROI(X[i]);
assertEquals(y[i] == (int)model.positiveClass(), prediction);
}
}
@Test
public void isROI1() throws Exception {
for (int i = 0; i < X.length; i++) {
Dataset d = new Dataset(X[i], X[i].length, 1);
boolean prediction = model.isROI(d);
assertEquals(y[i] == (int)model.positiveClass(), prediction);
}
}
@Test
public void serialize() throws Exception {
File out = File.createTempFile("tmp_svm", "dat");
model.save(out);
SVMROIFinder read = new SVMROIFinder();
read.load(out);
assertModelsEqual(model, read);
SVMROIFinder read2 = (SVMROIFinder) ROIFinder.fromFile(out, SVMROIFinder.class);
assertModelsEqual(model, read2);
}
/**
* Verify that two SVM models are "equal" i.e. make the same predictions and have the same basic
* settings.
* @param expected expected model
* @param actual actual model
*/
public void assertModelsEqual(SVMROIFinder expected, SVMROIFinder actual) {
assertEquals(expected.getNumFeatures(), actual.getNumFeatures());
assertEquals(expected.positiveClass(), actual.positiveClass(), expected.positiveClass() * 0.05);
assertEquals(expected.negativeClass(), actual.negativeClass(), expected.negativeClass() * 0.05);
assertEquals(expected.getC(), actual.getC(), expected.getC() * 0.05);
assertEquals(expected.getCp(), actual.getCp(), expected.getCp() * 0.05);
assertEquals(expected.getCn(), actual.getCn(), expected.getCn() * 0.05);
for (int i = 0; i < X.length; i++) {
assertEquals(expected.predict(X[i]), actual.predict(X[i]), 0.05 * y[i]);
}
}
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment