Commit cc4bff96 authored by Chris Coughlin's avatar Chris Coughlin

Towards a saner model serialization

parent 7e3af6fe
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright (c) 2016 Emphysic LLC.
~ C:/Users/chris/IdeaProjects/myriad/core/pom.xml
~
~ Copyright (c) 2017 Emphysic LLC.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
......@@ -21,7 +23,7 @@
<parent>
<artifactId>myriad</artifactId>
<groupId>com.emphysic</groupId>
<version>1.0-SNAPSHOT</version>
<version>2.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<properties>
......
/*
* com.emphysic.myriad.core.data.roi.AdaptiveSGDROIFinder
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -29,9 +29,9 @@ import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
/**
* AdaptiveSGDROIFinder - machine learning ROI detector that locates regions of interest by maintaining a pool of SGD
......@@ -39,6 +39,9 @@ import java.io.IOException;
*/
@Slf4j
public class AdaptiveSGDROIFinder extends MLROIConfFinder {
private static final long serialVersionUID = 1L; // try never to change - indicates backwards compatibility is broken
private static final int VERSION = 1; // current implementation version
/**
* Number of categories of ROI e.g. 2 for is/isn't a region of interest.
*/
......@@ -141,6 +144,61 @@ public class AdaptiveSGDROIFinder extends MLROIConfFinder {
return isROI(dataset.getData());
}
/**
 * Builds the "object map" for this finder: the implementation VERSION, the configuration fields
 * (categories, features, prior function, confidence threshold, thread count, pool size) and, when a
 * model is present, the bytes written by the model's own {@code write} method under "model_fields".
 * <p>
 * If writing the model fails the error is logged and the returned map simply has no "model_fields" entry.
 *
 * @return object map of this instance
 */
@Override
public Map<String, Object> getObjectMap() {
    HashMap<String, Object> map = new HashMap<>();
    map.put("VERSION", VERSION);
    map.put("categories", numCategories);
    map.put("features", numFeatures);
    map.put("priorfunction", priorFunction);
    map.put("confidence_threshold", confThr);
    map.put("thread_count", threadCount);
    map.put("pool_size", poolSize);
    if (model != null) {
        // Buffers are only needed when there is a model to write
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (DataOutputStream daos = new DataOutputStream(baos)) {
            model.write(daos);
            daos.flush();
            map.put("model_fields", baos.toByteArray());
        } catch (IOException ioe) {
            // Pass the exception itself so the stack trace is not lost
            log.error("Error encountered writing model: " + ioe.getMessage(), ioe);
        }
    }
    return map;
}
/**
 * Restores this finder's configuration and model from an object map such as the one produced by
 * {@link #getObjectMap()}.
 * <p>
 * The map is only applied when its "VERSION" entry equals the current VERSION. Older versions are
 * currently accepted but ignored, and newer versions are logged as an error; in both of those cases
 * no field is modified.
 *
 * @param objectMap Map of field:value pairs
 * @throws IllegalArgumentException if the map has no Integer "VERSION" entry
 */
@Override
public void init(Map<String, Object> objectMap) {
    Object storedVersion = objectMap.get("VERSION");
    if (!(storedVersion instanceof Integer)) {
        // Fail with a descriptive message rather than an opaque NullPointerException / ClassCastException
        throw new IllegalArgumentException(
                "Object map has a missing or non-integer VERSION entry: " + storedVersion);
    }
    int version = (Integer) storedVersion;
    if (version == VERSION) {
        // Current version
        numCategories = (int) objectMap.getOrDefault("categories", numCategories);
        numFeatures = (int) objectMap.getOrDefault("features", numFeatures);
        confThr = (double) objectMap.getOrDefault("confidence_threshold", confThr);
        threadCount = (int) objectMap.getOrDefault("thread_count", threadCount);
        poolSize = (int) objectMap.getOrDefault("pool_size", poolSize);
        if (objectMap.containsKey("priorfunction")) {
            priorFunction = (PriorFunction) objectMap.get("priorfunction");
        }
        if (objectMap.containsKey("model_fields")) {
            byte[] arr = (byte[]) objectMap.get("model_fields");
            ByteArrayInputStream bais = new ByteArrayInputStream(arr);
            DataInputStream dis = new DataInputStream(bais);
            // The model is replaced before reading; on a read failure it stays as constructed/partially read
            model = new AdaptiveLogisticRegression();
            try {
                model.readFields(dis);
            } catch (IOException ioe) {
                // Pass the exception itself so the stack trace is not lost
                log.error("Error encountered reading model: " + ioe.getMessage(), ioe);
            }
        }
    } else if (version < VERSION) {
        // Previous version(s)
    } else {
        log.error("Can't deserialize version " + version + " of model, current version is " + VERSION);
    }
}
/**
* Return the number of categories for this model. Typically 2 for does / does not contain a region of interest.
* @return number of categories
......@@ -266,8 +324,7 @@ public class AdaptiveSGDROIFinder extends MLROIConfFinder {
this.poolSize = poolSize;
}
@Override
public void write(Kryo kryo, Output output) {
public void legacyWrite(Kryo kryo, Output output) {
try {
kryo.writeObject(output, this);
kryo.writeObject(output, new Double(getConfidenceThreshold()));
......@@ -279,8 +336,7 @@ public class AdaptiveSGDROIFinder extends MLROIConfFinder {
}
}
@Override
public void read(Kryo kryo, Input input) {
public void legacyRead(Kryo kryo, Input input) {
try {
AdaptiveSGDROIFinder read = kryo.readObject(input, AdaptiveSGDROIFinder.class);
setConfidenceThreshold(kryo.readObject(input, Double.class));
......
/*
* com.emphysic.myriad.core.data.roi.ExternalROIFinder
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -26,6 +26,7 @@ import com.esotericsoftware.kryo.io.Output;
import lombok.extern.slf4j.Slf4j;
import java.io.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
......@@ -36,6 +37,9 @@ import java.util.concurrent.TimeUnit;
*/
@Slf4j
public class ExternalROIFinder implements ROIFinder {
private static final long serialVersionUID = 1L; // try never to change - indicates backwards compatibility is broken
private static final int VERSION = 1; // current implementation version
/**
* Manages the external application process
*/
......@@ -154,6 +158,51 @@ public class ExternalROIFinder implements ROIFinder {
return isROI(dataset.getData());
}
/**
 * Builds the "object map" for this finder: the implementation VERSION, the external process'
 * arguments, environment and working folder (as an absolute path, or null when no folder is set),
 * plus the timeout and its units.
 *
 * @return object map of this instance
 */
@Override
public Map<String, Object> getObjectMap() {
    HashMap<String, Object> fields = new HashMap<>();
    fields.put("VERSION", VERSION);
    fields.put("args", processRunner.getArgs());
    fields.put("env", processRunner.getEnv());
    File cwd = processRunner.getWorkingFolder();
    // "dir" is always present in the map; its value is null when there is no working folder
    fields.put("dir", cwd == null ? null : cwd.getAbsolutePath());
    fields.put("timeout", timeout);
    fields.put("tunits", timeoutUnits);
    return fields;
}
/**
 * Restores this finder from an object map such as the one produced by {@link #getObjectMap()}:
 * rebuilds the external process runner from the "args", "env" and "dir" entries and restores the
 * timeout and its units.
 * <p>
 * The map is only applied when its "VERSION" entry equals the current VERSION. Older versions are
 * currently accepted but ignored, and newer versions are logged as an error.
 *
 * @param objectMap Map of field:value pairs
 * @throws IllegalArgumentException if the map has no Integer "VERSION" entry, or if a current-version
 *                                  map has no non-empty "args" entry
 */
@Override
public void init(Map<String, Object> objectMap) {
    Object storedVersion = objectMap.get("VERSION");
    if (!(storedVersion instanceof Integer)) {
        // Fail with a descriptive message rather than an opaque NullPointerException / ClassCastException
        throw new IllegalArgumentException(
                "Object map has a missing or non-integer VERSION entry: " + storedVersion);
    }
    int version = (Integer) storedVersion;
    if (version == VERSION) {
        // Current version
        File workingFolder = null;
        String wdir = (String) objectMap.get("dir");
        if (wdir != null) {
            workingFolder = new File(wdir);
        }
        @SuppressWarnings("unchecked") // getObjectMap() stores the runner's List<String> under "args"
        List<String> commandAndArgs = (List<String>) objectMap.get("args");
        if (commandAndArgs == null || commandAndArgs.isEmpty()) {
            // The first element is the command itself, so an absent or empty list cannot be run
            throw new IllegalArgumentException("Object map has no \"args\" entry naming the command to run");
        }
        @SuppressWarnings("unchecked") // getObjectMap() stores the runner's Map<String, String> under "env"
        Map<String, String> env = (Map<String, String>) objectMap.get("env");
        processRunner = new ExternalProcess(
                commandAndArgs.get(0),
                commandAndArgs.subList(1, commandAndArgs.size()),
                env,
                workingFolder
        );
        timeout = (long) objectMap.getOrDefault("timeout", timeout);
        timeoutUnits = (TimeUnit) objectMap.getOrDefault("tunits", timeoutUnits);
    } else if (version < VERSION) {
        // Previous version(s)
    } else {
        log.error("Can't deserialize version " + version + " of model, current version is " + VERSION);
    }
}
/**
* Convenience method for building an external ROI finder
*
......@@ -201,16 +250,13 @@ public class ExternalROIFinder implements ROIFinder {
return new ExternalROIFinder(rp);
}
@Override
public void write(Kryo kryo, Output output) {
/**
 * Writes this finder in the legacy Kryo layout: the process runner, the timeout (as a Long) and the
 * timeout units, in that order, then flushes the output.
 *
 * @param kryo Kryo instance
 * @param output Output
 */
public void legacyWrite(Kryo kryo, Output output) {
    kryo.writeObject(output, processRunner);
    // Long.valueOf replaces the deprecated boxing constructor; the written object is still a Long
    kryo.writeObject(output, Long.valueOf(timeout));
    kryo.writeObject(output, timeoutUnits);
    output.flush();
}
@Override
public void read(Kryo kryo, Input input) {
/**
 * Reads this finder from the legacy Kryo layout: the process runner first, then the timeout (as a
 * Long) followed by the timeout units.
 *
 * @param kryo Kryo instance
 * @param input Input
 */
public void legacyRead(Kryo kryo, Input input) {
    ExternalProcess storedRunner = kryo.readObject(input, ExternalProcess.class);
    this.setProcessRunner(storedRunner);
    // Read order matters: the timeout value precedes its units in the stream
    Long storedTimeout = kryo.readObject(input, Long.class);
    TimeUnit storedUnits = kryo.readObject(input, TimeUnit.class);
    this.setTimeout(storedTimeout, storedUnits);
}
......
/*
* com.emphysic.myriad.core.data.roi.PassiveAggressiveROIFinder
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -27,9 +27,9 @@ import org.apache.mahout.classifier.sgd.PassiveAggressive;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
/**
* PassiveAggressiveROIFinder - machine learning ROI detector that uses the Passive-Aggressive algorithm outlined in
......@@ -37,6 +37,9 @@ import java.io.IOException;
*/
@Slf4j
public class PassiveAggressiveROIFinder extends MLROIConfFinder {
private static final long serialVersionUID = 1L; // try never to change - indicates backwards compatibility is broken
private static final int VERSION = 1; // current implementation version
/**
* Number of categories to learn (default 2)
*/
......@@ -127,6 +130,56 @@ public class PassiveAggressiveROIFinder extends MLROIConfFinder {
return isROI(dataset.getData());
}
/**
 * Builds the "object map" for this finder: the implementation VERSION, the configuration fields
 * (categories, features, learning rate, confidence threshold) and, when a model is present, the
 * bytes written by the model's own {@code write} method under "model_fields".
 * <p>
 * If writing the model fails the error is logged and the returned map simply has no "model_fields" entry.
 *
 * @return object map of this instance
 */
@Override
public Map<String, Object> getObjectMap() {
    HashMap<String, Object> map = new HashMap<>();
    map.put("VERSION", VERSION);
    map.put("categories", numCategories);
    map.put("features", numFeatures);
    map.put("learning_rate", learningRate);
    map.put("confidence_threshold", confThr);
    if (model != null) {
        // Buffers are only needed when there is a model to write
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (DataOutputStream daos = new DataOutputStream(baos)) {
            model.write(daos);
            daos.flush();
            map.put("model_fields", baos.toByteArray());
        } catch (IOException ioe) {
            // Pass the exception itself so the stack trace is not lost
            log.error("Error encountered writing model: " + ioe.getMessage(), ioe);
        }
    }
    return map;
}
/**
 * Restores this finder's configuration and model from an object map such as the one produced by
 * {@link #getObjectMap()}.
 * <p>
 * The map is only applied when its "VERSION" entry equals the current VERSION. Older versions are
 * currently accepted but ignored, and newer versions are logged as an error; in both of those cases
 * no field is modified.
 *
 * @param objectMap Map of field:value pairs
 * @throws IllegalArgumentException if the map has no Integer "VERSION" entry
 */
@Override
public void init(Map<String, Object> objectMap) {
    Object storedVersion = objectMap.get("VERSION");
    if (!(storedVersion instanceof Integer)) {
        // Fail with a descriptive message rather than an opaque NullPointerException / ClassCastException
        throw new IllegalArgumentException(
                "Object map has a missing or non-integer VERSION entry: " + storedVersion);
    }
    int version = (Integer) storedVersion;
    if (version == VERSION) {
        // Current version
        numCategories = (int) objectMap.getOrDefault("categories", numCategories);
        numFeatures = (int) objectMap.getOrDefault("features", numFeatures);
        confThr = (double) objectMap.getOrDefault("confidence_threshold", confThr);
        learningRate = (double) objectMap.getOrDefault("learning_rate", learningRate);
        if (objectMap.containsKey("model_fields")) {
            byte[] arr = (byte[]) objectMap.get("model_fields");
            ByteArrayInputStream bais = new ByteArrayInputStream(arr);
            DataInputStream dis = new DataInputStream(bais);
            // Built from the values restored above, then overwritten with the stored fields
            model = new PassiveAggressive(numCategories, numFeatures);
            model.learningRate(learningRate);
            try {
                model.readFields(dis);
            } catch (IOException ioe) {
                // Pass the exception itself so the stack trace is not lost
                log.error("Error encountered reading model: " + ioe.getMessage(), ioe);
            }
        }
    } else if (version < VERSION) {
        // Previous version(s)
    } else {
        log.error("Can't deserialize version " + version + " of model, current version is " + VERSION);
    }
}
/**
* Returns the number of labels known by the current model.
* @return number of categories
......@@ -191,8 +244,7 @@ public class PassiveAggressiveROIFinder extends MLROIConfFinder {
this.learningRate = learningRate;
}
@Override
public void write(Kryo kryo, Output output) {
public void legacyWrite(Kryo kryo, Output output) {
try {
kryo.writeObject(output, this);
Double ct = new Double(getConfidenceThreshold());
......@@ -205,8 +257,7 @@ public class PassiveAggressiveROIFinder extends MLROIConfFinder {
}
}
@Override
public void read(Kryo kryo, Input input) {
public void legacyRead(Kryo kryo, Input input) {
try {
PassiveAggressiveROIFinder paf = kryo.readObject(input, PassiveAggressiveROIFinder.class);
Double ct = kryo.readObject(input, Double.class);
......
/*
* com.emphysic.myriad.core.data.roi.RESTROIFinder
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -54,6 +54,9 @@ import java.util.Map;
*/
@Slf4j
public class RESTROIFinder implements ROIFinder, AutoCloseable {
private static final long serialVersionUID = 1L; // try never to change - indicates backwards compatibility is broken
private static final int VERSION = 1; // current implementation version
/**
* Server isROI API endpoint e.g. http://127.0.0.1:8080/api/isroi
*/
......@@ -153,7 +156,45 @@ public class RESTROIFinder implements ROIFinder, AutoCloseable {
}
@Override
public void write(Kryo kryo, Output output) {
/**
 * Builds the "object map" for this finder: the implementation VERSION, the endpoint URL, the header
 * map, both timeouts and the save-credentials flag. The username and password are only included
 * when the save-credentials flag is set.
 *
 * @return object map of this instance
 */
public Map<String, Object> getObjectMap() {
    HashMap<String, Object> fields = new HashMap<>();
    fields.put("VERSION", VERSION);
    fields.put("url", url);
    fields.put("header", header);
    fields.put("timeout", connectionTimeout);
    fields.put("socket_timeout", socketTimeout);
    fields.put("save_credentials", saveCredentials);
    if (!saveCredentials) {
        return fields;
    }
    // NOTE(review): the credentials are placed in the map exactly as held in the fields
    fields.put("username", username);
    fields.put("password", password);
    return fields;
}
/**
 * Restores this finder's connection settings from an object map such as the one produced by
 * {@link #getObjectMap()}. Entries missing from the map leave the corresponding field unchanged;
 * the username and password are only read when the stored save-credentials flag is true.
 * <p>
 * The map is only applied when its "VERSION" entry equals the current VERSION. Older versions are
 * currently accepted but ignored, and newer versions are logged as an error.
 *
 * @param objectMap Map of field:value pairs
 * @throws IllegalArgumentException if the map has no Integer "VERSION" entry
 */
@Override
public void init(Map<String, Object> objectMap) {
    Object storedVersion = objectMap.get("VERSION");
    if (!(storedVersion instanceof Integer)) {
        // Fail with a descriptive message rather than an opaque NullPointerException / ClassCastException
        throw new IllegalArgumentException(
                "Object map has a missing or non-integer VERSION entry: " + storedVersion);
    }
    int version = (Integer) storedVersion;
    if (version == VERSION) {
        // Current version
        url = (String) objectMap.getOrDefault("url", url);
        header = (Map<String, String>) objectMap.getOrDefault("header", header);
        connectionTimeout = (long) objectMap.getOrDefault("timeout", connectionTimeout);
        socketTimeout = (long) objectMap.getOrDefault("socket_timeout", socketTimeout);
        if (objectMap.containsKey("save_credentials")) {
            saveCredentials = (boolean) objectMap.get("save_credentials");
            if (saveCredentials) {
                username = (String) objectMap.getOrDefault("username", username);
                password = (String) objectMap.getOrDefault("password", password);
            }
        }
    } else if (version < VERSION) {
        // Previous version(s)
    } else {
        log.error("Can't deserialize version " + version + " of model, current version is " + VERSION);
    }
}
public void legacyWrite(Kryo kryo, Output output) {
kryo.writeObject(output, getUrl());
kryo.writeObject(output, getConnectionTimeout());
kryo.writeObject(output, getSocketTimeout());
......@@ -164,8 +205,7 @@ public class RESTROIFinder implements ROIFinder, AutoCloseable {
}
}
@Override
public void read(Kryo kryo, Input input) {
public void legacyRead(Kryo kryo, Input input) {
setUrl(kryo.readObject(input, String.class));
setConnectionTimeout(kryo.readObject(input, long.class));
setSocketTimeout(kryo.readObject(input, long.class));
......
/*
* com.emphysic.myriad.core.data.roi.ROIFinder
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -28,6 +28,7 @@ import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Map;
public interface ROIFinder extends KryoSerializable {
/**
......@@ -44,6 +45,37 @@ public interface ROIFinder extends KryoSerializable {
*/
boolean isROI(Dataset dataset);
/**
* Creates a map of the important fields for the instance, suitable for serialization.
* @return "object map" of the instance
*/
Map<String, Object> getObjectMap();
/**
* Initializes an instance with an object map
* @param objectMap Map of field:value pairs
*/
void init(Map<String, Object> objectMap);
/**
 * Restores this instance from Kryo-serialized data: reads one class-tagged object from the input,
 * treats it as an object map and hands it to {@link #init(Map)}.
 * @param kryo Kryo instance
 * @param input Input
 */
default void read(Kryo kryo, Input input) {
    Object deserialized = kryo.readClassAndObject(input);
    init((Map<String, Object>) deserialized);
}
/**
 * Serializes the current instance with Kryo by writing its object map (see {@link #getObjectMap()})
 * as a class-tagged object and then flushing the output.
 * @param kryo Kryo instance
 * @param output Output
 */
default void write(Kryo kryo, Output output) {
    Map<String, Object> objectMap = getObjectMap();
    kryo.writeClassAndObject(output, objectMap);
    output.flush();
}
/**
* Saves a model to disk
* @param outFile destination file
......@@ -52,7 +84,6 @@ public interface ROIFinder extends KryoSerializable {
default void save(File outFile) throws IOException {
    Kryo kryo = new Kryo();
    // try-with-resources closes the output (and the file stream it wraps) even if serialization throws
    try (Output output = new Output(new FileOutputStream(outFile))) {
        // The class name is written first, ahead of the serialized instance
        kryo.writeObject(output, this.getClass().toString());
        write(kryo, output);
    }
}
......@@ -65,12 +96,7 @@ public interface ROIFinder extends KryoSerializable {
default void load(File inFile) throws IOException {
    Kryo kryo = new Kryo();
    // try-with-resources closes the input (and the file stream it wraps) on every exit path,
    // including the "Unrecognized class" failure below
    try (Input input = new Input(new FileInputStream(inFile))) {
        // The file starts with the class name of the instance that was saved
        String clz = kryo.readObject(input, String.class);
        if (!this.getClass().toString().equals(clz)) {
            throw new IOException("Unrecognized class " + clz);
        }
        read(kryo, input);
    }
}
/**
......
/*
* com.emphysic.myriad.core.data.roi.SGDROIFinder
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -27,7 +27,9 @@ import org.sgdtk.*;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* SGDROIFinder - machine learning region of interest detection based on the Stochastic Gradient Descent (SGD)
......@@ -35,6 +37,9 @@ import java.util.List;
*/
@Slf4j
public class SGDROIFinder implements MLROIFinder {
private static final long serialVersionUID = 1L; // try never to change - indicates backwards compatibility is broken
private static final int VERSION = 1; // current implementation version
/**
* Trains the SGD model
*/
......@@ -121,6 +126,31 @@ public class SGDROIFinder implements MLROIFinder {
return isROI(dataset.getData());
}
/**
 * Builds the "object map" for this finder: the implementation VERSION, the number of features, the
 * learner and the linear model returned by {@code getLinearModel()}.
 *
 * @return object map of this instance
 */
@Override
public Map<String, Object> getObjectMap() {
    HashMap<String, Object> fields = new HashMap<>();
    fields.put("VERSION", VERSION);
    fields.put("features", numFeatures);
    fields.put("learner", learner);
    // The model is fetched through its getter rather than read from the field directly
    fields.put("model", getLinearModel());
    return fields;
}
/**
 * Restores the number of features, the learner and the linear model from an object map such as the
 * one produced by {@link #getObjectMap()}. Entries missing from the map leave the corresponding
 * field unchanged.
 * <p>
 * The map is only applied when its "VERSION" entry equals the current VERSION. Older versions are
 * currently accepted but ignored, and newer versions are logged as an error.
 *
 * @param objectMap Map of field:value pairs
 * @throws IllegalArgumentException if the map has no Integer "VERSION" entry
 */
@Override
public void init(Map<String, Object> objectMap) {
    Object storedVersion = objectMap.get("VERSION");
    if (!(storedVersion instanceof Integer)) {
        // Fail with a descriptive message rather than an opaque NullPointerException / ClassCastException
        throw new IllegalArgumentException(
                "Object map has a missing or non-integer VERSION entry: " + storedVersion);
    }
    int version = (Integer) storedVersion;
    if (version == VERSION) {
        // Current version
        numFeatures = (int) objectMap.getOrDefault("features", numFeatures);
        learner = (Learner) objectMap.getOrDefault("learner", learner);
        linearModel = (Model) objectMap.getOrDefault("model", linearModel);
    } else if (version < VERSION) {
        // Previous version(s)
    } else {
        log.error("Can't deserialize version " + version + " of model, current version is " + VERSION);
    }
}
/**
* Returns the prediction of the specified sample.
*
......@@ -213,16 +243,12 @@ public class SGDROIFinder implements MLROIFinder {
return -1;
}
@Override
public void write(Kryo kryo, Output output) {
//TODO: think about Learner serialization - need it or not?
//TODO: why doesn't output.write(numFeatures) work?
/**
 * Writes this finder in the legacy Kryo layout: the number of features (as an Integer) followed by
 * the linear model.
 *
 * @param kryo Kryo instance
 * @param output Output
 */
public void legacyWrite(Kryo kryo, Output output) {
    // Integer.valueOf replaces the deprecated boxing constructor; the written object is still an Integer
    kryo.writeObject(output, Integer.valueOf(numFeatures));
    kryo.writeObject(output, linearModel);
}
@Override
public void read(Kryo kryo, Input input) {
/**
 * Reads this finder from the legacy Kryo layout: the number of features (as an Integer) followed by
 * the linear model, which is applied through {@code setLinearModel}.
 *
 * @param kryo Kryo instance
 * @param input Input
 */
public void legacyRead(Kryo kryo, Input input) {
    Integer storedFeatures = kryo.readObject(input, Integer.class);
    this.numFeatures = storedFeatures;
    LinearModel storedModel = kryo.readObject(input, LinearModel.class);
    setLinearModel(storedModel);
}
......
/*
* com.emphysic.myriad.core.experimental.roi.GradMachineROIFinder
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -28,9 +28,9 @@ import org.apache.mahout.classifier.sgd.GradientMachine;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
/**
......@@ -40,6 +40,9 @@ import java.util.Random;
*/
@Slf4j
public class GradMachineROIFinder extends MLROIConfFinder {
private static final long serialVersionUID = 1L; // try never to change - indicates backwards compatibility is broken
private static final int VERSION = 1; // current implementation version
/**
* The ROI detection model
*/
......@@ -169,7 +172,59 @@ public class GradMachineROIFinder extends MLROIConfFinder {
}
@Override
public void write(Kryo kryo, Output output) {
/**
 * Builds the "object map" for this finder: the implementation VERSION, the configuration fields
 * (categories, features, hidden count, learning rate, regularization, confidence threshold) and,
 * when a model is present, the bytes written by the model's own {@code write} method under
 * "model_fields".
 * <p>
 * If writing the model fails the error is logged and the returned map simply has no "model_fields" entry.
 *
 * @return object map of this instance
 */
public Map<String, Object> getObjectMap() {
    HashMap<String, Object> map = new HashMap<>();
    map.put("VERSION", VERSION);
    map.put("categories", numCategories);
    map.put("features", numFeatures);
    map.put("hidden", numHidden);
    map.put("learning_rate", learningRate);
    map.put("regularization", regularization);
    map.put("confidence_threshold", confThr);
    if (model != null) {
        // Buffers are only needed when there is a model to write
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (DataOutputStream daos = new DataOutputStream(baos)) {
            model.write(daos);
            daos.flush();
            map.put("model_fields", baos.toByteArray());
        } catch (IOException ioe) {
            // Pass the exception itself so the stack trace is not lost
            log.error("Error encountered writing model: " + ioe.getMessage(), ioe);
        }
    }
    return map;
}
/**
 * Restores this finder's configuration and model from an object map such as the one produced by
 * {@link #getObjectMap()}.
 * <p>
 * The map is only applied when its "VERSION" entry equals the current VERSION. Older versions are
 * currently accepted but ignored, and newer versions are logged as an error; in both of those cases
 * no field is modified.
 *
 * @param objectMap Map of field:value pairs
 * @throws IllegalArgumentException if the map has no Integer "VERSION" entry
 */
@Override
public void init(Map<String, Object> objectMap) {
    Object storedVersion = objectMap.get("VERSION");
    if (!(storedVersion instanceof Integer)) {
        // Fail with a descriptive message rather than an opaque NullPointerException / ClassCastException
        throw new IllegalArgumentException(
                "Object map has a missing or non-integer VERSION entry: " + storedVersion);
    }
    int version = (Integer) storedVersion;
    if (version == VERSION) {
        // Current version
        numCategories = (int) objectMap.getOrDefault("categories", numCategories);
        numFeatures = (int) objectMap.getOrDefault("features", numFeatures);
        numHidden = (int) objectMap.getOrDefault("hidden", numHidden);
        learningRate = (double) objectMap.getOrDefault("learning_rate", learningRate);
        regularization = (double) objectMap.getOrDefault("regularization", regularization);
        confThr = (double) objectMap.getOrDefault("confidence_threshold", confThr);
        if (objectMap.containsKey("model_fields")) {
            byte[] arr = (byte[]) objectMap.get("model_fields");
            ByteArrayInputStream bais = new ByteArrayInputStream(arr);
            DataInputStream dis = new DataInputStream(bais);
            // Built from the sizes restored above, then overwritten with the stored fields
            model = new GradientMachine(numFeatures, numHidden, numCategories);
            try {
                model.readFields(dis);
            } catch (IOException ioe) {
                // Pass the exception itself so the stack trace is not lost
                log.error("Error encountered reading model: " + ioe.getMessage(), ioe);
            }
        }
    } else if (version < VERSION) {
        // Previous version(s)
    } else {
        log.error("Can't deserialize version " + version + " of model, current version is " + VERSION);
    }
}
public void legacyWrite(Kryo kryo, Output output) {
try {
kryo.writeObject(output, new Integer(numFeatures));
kryo.writeObject(output, new Integer(numCategories));
......@@ -184,8 +239,7 @@ public class GradMachineROIFinder extends MLROIConfFinder {
}
}
@Override
public void read(Kryo kryo, Input input) {
public void legacyRead(Kryo kryo, Input input) {
try {
numFeatures = kryo.readObject(input, Integer.class);
numCategories = kryo.readObject(input, Integer.class);
......
/*
* com.emphysic.myriad.core.data.roi.AdaptiveSGDROIFinderTest
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -19,6 +19,7 @@
package com.emphysic.myriad.core.data.roi;
import com.emphysic.myriad.core.data.io.Dataset;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.UniformPrior