Commit 60c53bd5 authored by Chris Coughlin's avatar Chris Coughlin

Cleaned up ROIFinder hierarchy

parent a12cf646
......@@ -29,8 +29,10 @@ import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import java.io.*;
import java.util.HashMap;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Map;
/**
......@@ -144,58 +146,49 @@ public class AdaptiveSGDROIFinder extends MLROIConfFinder {
return isROI(dataset.getData());
}
@Override
public long getSerializationVersion() {
return serialVersionUID;
}
@Override
public int getVersion() {
return VERSION;
}
@Override
public Map<String, Object> getObjectMap() {
HashMap<String, Object> map = new HashMap<>();
map.put("VERSION", VERSION);
Map<String, Object> map = AdaptiveSGDROIFinder.super.getObjectMap();
map.put("categories", numCategories);
map.put("features", numFeatures);
map.put("priorfunction", priorFunction);
map.put("confidence_threshold", confThr);
map.put("thread_count", threadCount);
map.put("pool_size", poolSize);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream daos = new DataOutputStream(baos);
if (model != null) {
try {
model.write(daos);
daos.flush();
map.put("model_fields", baos.toByteArray());
} catch (IOException ioe) {
log.error("Error encountered writing model: " + ioe.getMessage());
}
}
map.put("model_fields", writableWriteToBytes(model));
return map;
}
@Override
public void init(Map<String, Object> objectMap) {
int version = (Integer) objectMap.get("VERSION");
if (version == VERSION) {
// Current version
numCategories = (int)objectMap.getOrDefault("categories", numCategories);
numFeatures = (int) objectMap.getOrDefault("features", numFeatures);
confThr = (double) objectMap.getOrDefault("confidence_threshold", confThr);
threadCount = (int) objectMap.getOrDefault("thread_count", threadCount);
poolSize = (int) objectMap.getOrDefault("pool_size", poolSize);
if (objectMap.containsKey("priorfunction")) {
priorFunction = (PriorFunction) objectMap.get("priorfunction");
}
if (objectMap.containsKey("model_fields")) {
byte[] arr = (byte[]) objectMap.get("model_fields");
ByteArrayInputStream bais = new ByteArrayInputStream(arr);
DataInputStream dis = new DataInputStream(bais);
model = new AdaptiveLogisticRegression();
try {
model.readFields(dis);
} catch (IOException ioe) {
log.error("Error encountered reading model: " + ioe.getMessage());
}
public void initCurrentVersion(Map<String, Object> objectMap) {
numCategories = (int)objectMap.getOrDefault("categories", numCategories);
numFeatures = (int) objectMap.getOrDefault("features", numFeatures);
confThr = (double) objectMap.getOrDefault("confidence_threshold", confThr);
threadCount = (int) objectMap.getOrDefault("thread_count", threadCount);
poolSize = (int) objectMap.getOrDefault("pool_size", poolSize);
if (objectMap.containsKey("priorfunction")) {
priorFunction = (PriorFunction) objectMap.get("priorfunction");
}
if (objectMap.containsKey("model_fields")) {
byte[] arr = (byte[]) objectMap.get("model_fields");
ByteArrayInputStream bais = new ByteArrayInputStream(arr);
DataInputStream dis = new DataInputStream(bais);
model = new AdaptiveLogisticRegression();
try {
model.readFields(dis);
} catch (IOException ioe) {
log.error("Error encountered reading model: " + ioe.getMessage());
}
} else if (version < VERSION) {
// Previous version(s)
} else {
log.error("Can't deserialize version " + version + " of model, current version is " + VERSION);
}
}
......
......@@ -26,7 +26,6 @@ import com.esotericsoftware.kryo.io.Output;
import lombok.extern.slf4j.Slf4j;
import java.io.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
......@@ -158,10 +157,19 @@ public class ExternalROIFinder implements ROIFinder {
return isROI(dataset.getData());
}
@Override
public long getSerializationVersion() {
return serialVersionUID;
}
@Override
public int getVersion() {
return VERSION;
}
@Override
public Map<String, Object> getObjectMap() {
HashMap<String, Object> map = new HashMap<>();
map.put("VERSION", VERSION);
Map<String, Object> map = ROIFinder.super.getObjectMap();
List<String> procArgs = processRunner.getArgs();
map.put("args", procArgs);
Map<String, String> procEnv = processRunner.getEnv();
......@@ -178,29 +186,21 @@ public class ExternalROIFinder implements ROIFinder {
}
@Override
public void init(Map<String, Object> objectMap) {
int version = (Integer) objectMap.get("VERSION");
if (version == VERSION) {
// Current version
File workingFolder = null;
String wdir = (String)objectMap.get("dir");
if (wdir != null) {
workingFolder = new File(wdir);
}
List<String> commandAndArgs = (List<String>) objectMap.get("args");
processRunner = new ExternalProcess(
commandAndArgs.get(0),
commandAndArgs.subList(1, commandAndArgs.size()),
(Map<String, String>) objectMap.get("env"),
workingFolder
);
timeout = (long) objectMap.getOrDefault("timeout", timeout);
timeoutUnits = (TimeUnit) objectMap.getOrDefault("tunits", timeoutUnits);
} else if (version < VERSION) {
// Previous version(s)
} else {
log.error("Can't deserialize version " + version + " of model, current version is " + VERSION);
public void initCurrentVersion(Map<String, Object> objectMap) {
File workingFolder = null;
String wdir = (String)objectMap.get("dir");
if (wdir != null) {
workingFolder = new File(wdir);
}
List<String> commandAndArgs = (List<String>) objectMap.get("args");
processRunner = new ExternalProcess(
commandAndArgs.get(0),
commandAndArgs.subList(1, commandAndArgs.size()),
(Map<String, String>) objectMap.get("env"),
workingFolder
);
timeout = (long) objectMap.getOrDefault("timeout", timeout);
timeoutUnits = (TimeUnit) objectMap.getOrDefault("tunits", timeoutUnits);
}
/**
......
/*
* com.emphysic.myriad.core.data.roi.MLROIConfFinder
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -20,9 +20,14 @@ package com.emphysic.myriad.core.data.roi;
import com.emphysic.myriad.core.data.io.Dataset;
import lombok.extern.slf4j.Slf4j;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
/**
* A Region of Interest (ROI) finder based on machine learning that provides both probabilities of its
* classifications and the ability to define a confidence threshold.
......@@ -91,4 +96,24 @@ public abstract class MLROIConfFinder implements MLROIFinder, ROIProbability {
public double predict(Dataset data) {
return predict(data.getData());
}
/**
* Convenience method for writing a Mahout Writable to a byte array
* @param writable Mahout object that implements the Writable interface
* @return byte array
*/
public static byte[] writableWriteToBytes(Writable writable) {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream daos = new DataOutputStream(baos);
if (writable != null) {
try {
writable.write(daos);
daos.flush();
return baos.toByteArray();
} catch (IOException ioe) {
log.error("Error encountered writing model: " + ioe.getMessage());
}
}
return null;
}
}
......@@ -27,8 +27,10 @@ import org.apache.mahout.classifier.sgd.PassiveAggressive;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import java.io.*;
import java.util.HashMap;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Map;
/**
......@@ -130,53 +132,44 @@ public class PassiveAggressiveROIFinder extends MLROIConfFinder {
return isROI(dataset.getData());
}
@Override
public long getSerializationVersion() {
return serialVersionUID;
}
@Override
public int getVersion() {
return VERSION;
}
@Override
public Map<String, Object> getObjectMap() {
HashMap<String, Object> map = new HashMap<>();
map.put("VERSION", VERSION);
Map<String, Object> map = PassiveAggressiveROIFinder.super.getObjectMap();
map.put("categories", numCategories);
map.put("features", numFeatures);
map.put("learning_rate", learningRate);
map.put("confidence_threshold", confThr);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream daos = new DataOutputStream(baos);
if (model != null) {
try {
model.write(daos);
daos.flush();
map.put("model_fields", baos.toByteArray());
} catch (IOException ioe) {
log.error("Error encountered writing model: " + ioe.getMessage());
}
}
map.put("model_fields", writableWriteToBytes(model));
return map;
}
@Override
public void init(Map<String, Object> objectMap) {
int version = (Integer) objectMap.get("VERSION");
if (version == VERSION) {
// Current version
numCategories = (int)objectMap.getOrDefault("categories", numCategories);
numFeatures = (int) objectMap.getOrDefault("features", numFeatures);
confThr = (double) objectMap.getOrDefault("confidence_threshold", confThr);
learningRate = (double) objectMap.getOrDefault("learning_rate", learningRate);
if (objectMap.containsKey("model_fields")) {
byte[] arr = (byte[]) objectMap.get("model_fields");
ByteArrayInputStream bais = new ByteArrayInputStream(arr);
DataInputStream dis = new DataInputStream(bais);
model = new PassiveAggressive(numCategories, numFeatures);
model.learningRate(learningRate);
try {
model.readFields(dis);
} catch (IOException ioe) {
log.error("Error encountered reading model: " + ioe.getMessage());
}
public void initCurrentVersion(Map<String, Object> objectMap) {
numCategories = (int)objectMap.getOrDefault("categories", numCategories);
numFeatures = (int) objectMap.getOrDefault("features", numFeatures);
confThr = (double) objectMap.getOrDefault("confidence_threshold", confThr);
learningRate = (double) objectMap.getOrDefault("learning_rate", learningRate);
if (objectMap.containsKey("model_fields")) {
byte[] arr = (byte[]) objectMap.get("model_fields");
ByteArrayInputStream bais = new ByteArrayInputStream(arr);
DataInputStream dis = new DataInputStream(bais);
model = new PassiveAggressive(numCategories, numFeatures);
model.learningRate(learningRate);
try {
model.readFields(dis);
} catch (IOException ioe) {
log.error("Error encountered reading model: " + ioe.getMessage());
}
} else if (version < VERSION) {
// Previous version(s)
} else {
log.error("Can't deserialize version " + version + " of model, current version is " + VERSION);
}
}
......
......@@ -155,10 +155,19 @@ public class RESTROIFinder implements ROIFinder, AutoCloseable {
return isROI(dataset.getData());
}
@Override
public long getSerializationVersion() {
return serialVersionUID;
}
@Override
public int getVersion() {
return VERSION;
}
@Override
public Map<String, Object> getObjectMap() {
HashMap<String, Object> map = new HashMap<>();
map.put("VERSION", VERSION);
Map<String, Object> map = ROIFinder.super.getObjectMap();
map.put("url", url);
map.put("header", header);
map.put("timeout", connectionTimeout);
......@@ -172,25 +181,17 @@ public class RESTROIFinder implements ROIFinder, AutoCloseable {
}
@Override
public void init(Map<String, Object> objectMap) {
int version = (Integer) objectMap.get("VERSION");
if (version == VERSION) {
// Current version
url = (String) objectMap.getOrDefault("url", url);
header = (Map<String, String>) objectMap.getOrDefault("header", header);
connectionTimeout = (long) objectMap.getOrDefault("timeout", connectionTimeout);
socketTimeout = (long) objectMap.getOrDefault("socket_timeout", socketTimeout);
if (objectMap.containsKey("save_credentials")) {
saveCredentials = (boolean) objectMap.get("save_credentials");
if (saveCredentials) {
username = (String) objectMap.getOrDefault("username", username);
password = (String) objectMap.getOrDefault("password", password);
}
public void initCurrentVersion(Map<String, Object> objectMap) {
url = (String) objectMap.getOrDefault("url", url);
header = (Map<String, String>) objectMap.getOrDefault("header", header);
connectionTimeout = (long) objectMap.getOrDefault("timeout", connectionTimeout);
socketTimeout = (long) objectMap.getOrDefault("socket_timeout", socketTimeout);
if (objectMap.containsKey("save_credentials")) {
saveCredentials = (boolean) objectMap.get("save_credentials");
if (saveCredentials) {
username = (String) objectMap.getOrDefault("username", username);
password = (String) objectMap.getOrDefault("password", password);
}
} else if (version < VERSION) {
// Previous version(s)
} else {
log.error("Can't deserialize version " + version + " of model, current version is " + VERSION);
}
}
......
......@@ -19,18 +19,12 @@
package com.emphysic.myriad.core.data.roi;
import com.emphysic.myriad.core.data.io.Dataset;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import com.emphysic.myriad.core.data.util.ObjectMap;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Map;
public interface ROIFinder extends KryoSerializable {
public interface ROIFinder extends ObjectMap {
/**
* Examine an array of data and report whether it appears to contain a region of interest (ROI)
* @param data raw data to examine
......@@ -45,60 +39,6 @@ public interface ROIFinder extends KryoSerializable {
*/
boolean isROI(Dataset dataset);
/**
* Creates a map of the important fields for the instance, suitable for serialization.
* @return "object map" of the instance
*/
Map<String, Object> getObjectMap();
/**
* Initializes an instance with an object map
* @param objectMap Map of field:value pairs
*/
void init(Map<String, Object> objectMap);
/**
* Reads a Kryo-serialized object.
* @param kryo Kryo instance
* @param input Input
*/
default void read(Kryo kryo, Input input) {
init((Map<String, Object>) kryo.readClassAndObject(input));
}
/**
* Serializes the current instance as a Kryo object.
* @param kryo Kryo instance
* @param output Output
*/
default void write(Kryo kryo, Output output) {
kryo.writeClassAndObject(output, getObjectMap());
output.flush();
}
/**
* Saves a model to disk
* @param outFile destination file
* @throws IOException if an I/O error occurs (file not found, insufficient permissions, etc.)
*/
default void save(File outFile) throws IOException {
Kryo kryo = new Kryo();
Output output = new Output(new FileOutputStream(outFile));
write(kryo, output);
output.close();
}
/**
* Reads a model from disk
* @param inFile source file
* @throws IOException if an I/O error occurs (file not found, insufficient permissions, etc.)
*/
default void load(File inFile) throws IOException {
Kryo kryo = new Kryo();
Input input = new Input(new FileInputStream(inFile));
read(kryo, input);
}
/**
* Loads a Region of Interest finder from disk.
* @param inFile input file
......
......@@ -27,7 +27,6 @@ import org.sgdtk.*;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
......@@ -126,10 +125,19 @@ public class SGDROIFinder implements MLROIFinder {
return isROI(dataset.getData());
}
@Override
public long getSerializationVersion() {
return serialVersionUID;
}
@Override
public int getVersion() {
return VERSION;
}
@Override
public Map<String, Object> getObjectMap() {
HashMap<String, Object> map = new HashMap<>();
map.put("VERSION", VERSION);
Map<String, Object> map = MLROIFinder.super.getObjectMap();
map.put("features", numFeatures);
map.put("learner", learner);
map.put("model", getLinearModel());
......@@ -137,18 +145,10 @@ public class SGDROIFinder implements MLROIFinder {
}
@Override
public void init(Map<String, Object> objectMap) {
int version = (Integer) objectMap.get("VERSION");
if (version == VERSION) {
// Current version
numFeatures = (int) objectMap.getOrDefault("features", numFeatures);
learner = (Learner) objectMap.getOrDefault("learner", learner);
linearModel = (Model) objectMap.getOrDefault("model", linearModel);
} else if (version < VERSION) {
// Previous version(s)
} else {
log.error("Can't deserialize version " + version + " of model, current version is " + VERSION);
}
public void initCurrentVersion(Map<String, Object> objectMap) {
numFeatures = (int) objectMap.getOrDefault("features", numFeatures);
learner = (Learner) objectMap.getOrDefault("learner", learner);
linearModel = (Model) objectMap.getOrDefault("model", linearModel);
}
/**
......
/*
* com.emphysic.myriad.core.data.util.ObjectMap
*
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.emphysic.myriad.core.data.util;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
/**
* ObjectMap - Kryo serialization of an instance's object graph.
* Created by ccoughlin on 4/4/2017.
*/
public interface ObjectMap extends KryoSerializable {
/**
* Returns the current version of the serialization format. Increments when the serialization format changes, i.e.
* backwards compatibility is broken.
* @return serialization version
*/
long getSerializationVersion();
/**
* Returns the current class version. Class version increments when the class changes but doesn't necessarily
* break compatibility with serialization.
* @return class version
*/
int getVersion();
/**
* Creates a map of the important fields for the instance, suitable for serialization.
* @return "object map" of the instance
*/
default Map<String, Object> getObjectMap() {
HashMap<String, Object> map = new HashMap<>();
map.put("VERSION", getVersion());
map.put("SERVERSION", getSerializationVersion());
return map;
}
/**
* Initializes an instance with an object map.
* @param objectMap Map of field:value pairs
* @throws UnsupportedOperationException if serialization versions don't match
*/
default void init(Map<String, Object> objectMap) {
long serializationVersion = (long) objectMap.get("SERVERSION");
if (serializationVersion != getSerializationVersion()) {
throw new UnsupportedOperationException("Serialization formats do not match, expected "
+ getSerializationVersion() + " read " + serializationVersion);
}
int version = (Integer) objectMap.get("VERSION");
if (version == getVersion()) {
// Current version
initCurrentVersion(objectMap);
} else if (version < getVersion()) {
// Previous version(s)
initPreviousVersion(objectMap, version);
} else {
// Future version, not specified, etc.
initUnknownVersion(objectMap, version);
}
}
/**
* Initializes an instance with a current-version object graph.
* @param objectMap object graph for initialization
*/
void initCurrentVersion(Map<String, Object> objectMap);
/**
* Handles requests to initialize an instance with a previous version of an object graph. Default action is no-op.
* @param objectMap object graph for initialization
* @param version version of object graph
*/
default void initPreviousVersion(Map<String, Object> objectMap, int version) {}
/**
* Handles requests to initialize an instance with an unknown (future, undefined, etc.) version of an object graph.
* Default action is to throw UnsupportedOperationException.
* @param objectMap object graph
* @param version version
* @throws UnsupportedOperationException by default
*/
default void initUnknownVersion(Map<String, Object> objectMap, int version) {
throw new UnsupportedOperationException("Version " + version + " not supported.");
}
/**
* Saves a model to disk
* @param outFile destination file
* @throws IOException if an I/O error occurs (file not found, insufficient permissions, etc.)
*/
default void save(File outFile) throws IOException {
Kryo kryo = new Kryo();
Output output = new Output(new FileOutputStream(outFile));
write(kryo, output);
output.close();
}
/**
* Reads a model from disk
* @param inFile source file
* @throws IOException if an I/O error occurs (file not found, insufficient permissions, etc.)
*/
default void load(File inFile) throws IOException {
Kryo kryo = new Kryo();
Input input = new Input(new FileInputStream(inFile));
read(kryo, input);
}
/**
* Reads a Kryo-serialized object graph.
* @param kryo Kryo instance
* @param input Input
*/
default void read(Kryo kryo, Input input) {
init((Map<String, Object>) kryo.readClassAndObject(input));
}
/**
* Serializes the current instance as a Kryo object.
* @param kryo Kryo instance
* @param output Output
*/
default void write(Kryo kryo, Output output) {
Map<String, Object> objectGraph = getObjectMap();
kryo.writeClassAndObject(output, objectGraph);
output.flush();
}
}
......@@ -28,8 +28,10 @@ import org.apache.mahout.classifier.sgd.GradientMachine;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import java.io.*;
import java.util.HashMap;
import java.io.ByteArrayInputStream;