Commit 9fb4d3f7 authored by Chris Coughlin's avatar Chris Coughlin

Workaround for storing confidence thresholds

parent 6876b9d8
......@@ -281,6 +281,7 @@ public class AdaptiveSGDROIFinder extends MLROIConfFinder {
public void write(Kryo kryo, Output output) {
try {
kryo.writeObject(output, this);
kryo.writeObject(output, new Double(getConfidenceThreshold()));
if (model != null) {
model.write(new DataOutputStream(output));
}
......@@ -293,6 +294,7 @@ public class AdaptiveSGDROIFinder extends MLROIConfFinder {
public void read(Kryo kryo, Input input) {
try {
AdaptiveSGDROIFinder read = kryo.readObject(input, AdaptiveSGDROIFinder.class);
setConfidenceThreshold(kryo.readObject(input, Double.class));
setPriorFunction(read.getPriorFunction());
setThreadCount(read.getThreadCount());
setPoolSize(read.getPoolSize());
......
......@@ -181,6 +181,8 @@ public class PassiveAggressiveROIFinder extends MLROIConfFinder {
public void write(Kryo kryo, Output output) {
try {
kryo.writeObject(output, this);
Double ct = new Double(getConfidenceThreshold());
kryo.writeObject(output, ct);
if (model != null) {
model.write(new DataOutputStream(output));
}
......@@ -193,6 +195,8 @@ public class PassiveAggressiveROIFinder extends MLROIConfFinder {
public void read(Kryo kryo, Input input) {
try {
PassiveAggressiveROIFinder paf = kryo.readObject(input, PassiveAggressiveROIFinder.class);
Double ct = kryo.readObject(input, Double.class);
setConfidenceThreshold(ct);
setLearningRate(paf.getLearningRate());
PassiveAggressive model = new PassiveAggressive(getNumCategories(), getNumFeatures());
model.readFields(new DataInputStream(input));
......
......@@ -8,6 +8,7 @@ import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.util.Random;
import java.util.stream.DoubleStream;
import static org.junit.Assert.assertEquals;
......@@ -116,6 +117,8 @@ public class AdaptiveSGDROIFinderTest {
public void serialize() throws Exception {
File out = File.createTempFile("tmp", "dat");
train();
Random random = new Random();
model.setConfidenceThreshold(random.nextDouble());
model.save(out);
AdaptiveSGDROIFinder read = new AdaptiveSGDROIFinder();
read.load(out);
......@@ -135,6 +138,8 @@ public class AdaptiveSGDROIFinderTest {
for (int i = 0; i < X.length; i++) {
assertEquals(expected.predict(X[i]), actual.predict(X[i]), 0.05 * y[i]);
}
assertEquals(expected.getConfidenceThreshold(), actual.getConfidenceThreshold(),
0.05 * expected.getConfidenceThreshold());
assertEquals(expected.getNumCategories(), actual.getNumCategories());
assertEquals(expected.getNumFeatures(), actual.getNumFeatures());
assertEquals(expected.getPoolSize(), actual.getPoolSize());
......
......@@ -7,6 +7,7 @@ import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.util.Random;
import java.util.stream.DoubleStream;
import static org.junit.Assert.assertEquals;
......@@ -102,6 +103,8 @@ public class PassiveAggressiveROIFinderTest {
public void serialize() throws Exception {
File out = File.createTempFile("tmp_pa", "dat");
train();
Random random = new Random();
model.setConfidenceThreshold(random.nextDouble());
model.save(out);
PassiveAggressiveROIFinder read = new PassiveAggressiveROIFinder();
read.load(out);
......@@ -121,6 +124,8 @@ public class PassiveAggressiveROIFinderTest {
for (int i = 0; i < X.length; i++) {
assertEquals(expected.predict(X[i]), actual.predict(X[i]), 0.05 * y[i]);
}
assertEquals(expected.getConfidenceThreshold(), actual.getConfidenceThreshold(),
0.05 * expected.getConfidenceThreshold());
assertEquals(expected.getLearningRate(), actual.getLearningRate(), 0.05 * expected.getLearningRate());
assertEquals(expected.getNumCategories(), actual.getNumCategories());
assertEquals(expected.getNumFeatures(), actual.getNumFeatures());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment