Commit 465e9069 authored by Chris Coughlin's avatar Chris Coughlin

ROIFinder serialization bugfixes

parent fae0f2f8
......@@ -296,14 +296,14 @@ public class AdaptiveSGDROIFinder implements MLROIFinder {
public void read(Kryo kryo, Input input) {
try {
AdaptiveSGDROIFinder read = kryo.readObject(input, AdaptiveSGDROIFinder.class);
setNumCategories(read.getNumCategories());
setNumFeatures(read.getNumFeatures());
setPriorFunction(read.getPriorFunction());
setThreadCount(read.getThreadCount());
setPoolSize(read.getPoolSize());
AdaptiveLogisticRegression model = new AdaptiveLogisticRegression();
model.readFields(new DataInputStream(input));
setModel(model);
setNumFeatures(model.numFeatures());
setNumCategories(model.getNumCategories());
} catch (IOException ioe) {
log.error("Error encountered reading model: " + ioe.getMessage());
}
......
......@@ -223,12 +223,12 @@ public class PassiveAggressiveROIFinder implements MLROIFinder {
public void read(Kryo kryo, Input input) {
try {
PassiveAggressiveROIFinder paf = kryo.readObject(input, PassiveAggressiveROIFinder.class);
setNumCategories(paf.getNumCategories());
setLearningRate(paf.getLearningRate());
setNumCategories(paf.getNumFeatures());
PassiveAggressive model = new PassiveAggressive(getNumCategories(), getNumFeatures());
model.readFields(new DataInputStream(input));
setModel(model);
setNumFeatures(model.numFeatures());
setNumCategories(model.numCategories());
} catch (IOException ioe) {
log.error("Error writing model: " + ioe.getMessage());
}
......
......@@ -30,7 +30,6 @@ public interface ROIFinder extends KryoSerializable {
*/
boolean isROI(Dataset dataset);
/**
* Saves a model to disk
* @param outFile destination file
......@@ -54,4 +53,20 @@ public interface ROIFinder extends KryoSerializable {
read(kryo, input);
input.close();
}
/**
* Loads a Region of Interest finder from disk.
* @param inFile input file
* @param clz class of ROIFinder
* @return new ROIFinder
* @throws InstantiationException error instantiating the ROIFinder (abstract, interface, etc.)
* @throws IllegalAccessException constructor isn't accessible
* @throws IOException if an I/O error occurs reading the input file
*/
static ROIFinder fromFile(File inFile, Class<? extends ROIFinder> clz)
throws InstantiationException, IllegalAccessException, IOException {
ROIFinder r = clz.newInstance();
r.load(inFile);
return r;
}
}
......@@ -85,9 +85,28 @@ public class AdaptiveSGDROIFinderTest {
model.save(out);
AdaptiveSGDROIFinder read = new AdaptiveSGDROIFinder();
read.load(out);
assertModelsEqual(model, read);
AdaptiveSGDROIFinder read2 = (AdaptiveSGDROIFinder) ROIFinder.fromFile(out, AdaptiveSGDROIFinder.class);
assertModelsEqual(model, read2);
}
/**
* Ensure that two AdaptiveSGD models are equivalent i.e. make the same predictions and have the same basic
* settings.
* @param expected expected model
* @param actual actual model
*/
private void assertModelsEqual(AdaptiveSGDROIFinder expected, AdaptiveSGDROIFinder actual) {
for (int i = 0; i < X.length; i++) {
assertEquals(model.predict(X[i]), read.predict(X[i]), 0.05 * y[i]);
assertEquals(expected.predict(X[i]), actual.predict(X[i]), 0.05 * y[i]);
}
assertEquals(expected.getNumCategories(), actual.getNumCategories());
assertEquals(expected.getNumFeatures(), actual.getNumFeatures());
assertEquals(expected.getPoolSize(), actual.getPoolSize());
assertEquals(expected.getThreadCount(), actual.getThreadCount());
assertEquals(expected.positiveClass(), actual.positiveClass(), expected.positiveClass() * 0.05);
assertEquals(expected.negativeClass(), actual.negativeClass(), expected.negativeClass() * 0.05);
}
}
\ No newline at end of file
......@@ -98,13 +98,25 @@ public class ExternalROIFinderTest {
ExternalROIFinder reread = new ExternalROIFinder();
reread.load(tmpFile);
assertEquals(erf.getTimeout(), reread.getTimeout());
assertEquals(erf.getTimeoutUnits(), reread.getTimeoutUnits());
ExternalProcess proc = reread.getProcessRunner();
assertEquals(cmd, proc.getCmd());
// ExternalProcess prepends command to list of arguments
assertEquals(cargs, proc.getArgs().subList(1, proc.getArgs().size()));
assertEquals(env, proc.getEnv());
assertModelsEqual(erf, reread);
ExternalROIFinder read2 = (ExternalROIFinder) ROIFinder.fromFile(tmpFile, ExternalROIFinder.class);
assertModelsEqual(erf, read2);
}
/**
* Verify that two external process configurations are equivalent.
* @param expected expected configuration
* @param actual actual configuration
*/
private void assertModelsEqual(ExternalROIFinder expected, ExternalROIFinder actual) {
assertEquals(expected.getTimeout(), actual.getTimeout());
assertEquals(expected.getTimeoutUnits(), actual.getTimeoutUnits());
ExternalProcess expectedProc = expected.getProcessRunner();
ExternalProcess actualProc = actual.getProcessRunner();
assertEquals(expectedProc.getCmd(), actualProc.getCmd());
assertEquals(expectedProc.getArgs(), actualProc.getArgs());
assertEquals(expectedProc.getEnv(), actualProc.getEnv());
}
}
\ No newline at end of file
......@@ -71,8 +71,26 @@ public class PassiveAggressiveROIFinderTest {
model.save(out);
PassiveAggressiveROIFinder read = new PassiveAggressiveROIFinder();
read.load(out);
assertModelsEqual(model, read);
PassiveAggressiveROIFinder read2 = (PassiveAggressiveROIFinder) ROIFinder.fromFile(out, PassiveAggressiveROIFinder.class);
assertModelsEqual(model, read2);
}
/**
* Verify that two PassiveAggressive models are "equal" i.e. make the same predictions and have the same basic
* settings.
* @param expected expected model
* @param actual actual model
*/
public void assertModelsEqual(PassiveAggressiveROIFinder expected, PassiveAggressiveROIFinder actual) {
for (int i = 0; i < X.length; i++) {
assertEquals(model.predict(X[i]), read.predict(X[i]), 0.05 * y[i]);
assertEquals(expected.predict(X[i]), actual.predict(X[i]), 0.05 * y[i]);
}
assertEquals(expected.getLearningRate(), actual.getLearningRate(), 0.05 * expected.getLearningRate());
assertEquals(expected.getNumCategories(), actual.getNumCategories());
assertEquals(expected.getNumFeatures(), actual.getNumFeatures());
assertEquals(expected.positiveClass(), actual.positiveClass(), expected.positiveClass() * 0.05);
assertEquals(expected.negativeClass(), actual.negativeClass(), expected.negativeClass() * 0.05);
}
}
\ No newline at end of file
......@@ -11,7 +11,8 @@ import org.junit.Test;
import java.io.File;
import java.util.Random;
import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
public class SGDROIFinderTest {
private double[][] X;
......@@ -48,11 +49,24 @@ public class SGDROIFinderTest {
sgdFlawFinder.save(out);
SGDROIFinder read = new SGDROIFinder();
read.load(out);
assertNotNull(read.getLearner());
assertNotNull(read.getLinearModel());
int rndIdx = random.nextInt(y.length);
double[] rndData = X[rndIdx];
assertEquals(y[rndIdx] > 0, read.isROI(rndData));
assertModelsEqual(sgdFlawFinder, read);
SGDROIFinder read2 = (SGDROIFinder)ROIFinder.fromFile(out, SGDROIFinder.class);
assertModelsEqual(sgdFlawFinder, read2);
}
/**
* Asserts that two SGD models are "equal" i.e. make the same predictions and have the same basic parameters.
* @param expected expected model
* @param actual actual model
*/
private void assertModelsEqual(SGDROIFinder expected, SGDROIFinder actual) {
assertEquals(expected.getNumFeatures(), actual.getNumFeatures());
assertEquals(expected.positiveClass(), actual.positiveClass(), expected.positiveClass() * 0.05);
assertEquals(expected.negativeClass(), actual.negativeClass(), expected.negativeClass() * 0.05);
for (double[] aX : X) {
assertEquals(expected.isROI(aX), actual.isROI(aX));
}
}
@Test
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment