Commit 17ba3ef0 authored by Chris Coughlin's avatar Chris Coughlin

Implemented SMOTE balancing in training data

parent e7e3cfd3
/* /*
* com.emphysic.myriad.core.ml.CrossValidation * com.emphysic.myriad.core.ml.CrossValidation
* *
* Copyright (c) 2017 Emphysic LLC. * Copyright (c) 2018 Emphysic LLC.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -274,16 +274,20 @@ public abstract class CrossValidation { ...@@ -274,16 +274,20 @@ public abstract class CrossValidation {
* *
* @param numRounds number of rounds of testing * @param numRounds number of rounds of testing
* @param models models to test * @param models models to test
* @param upSample if true, upsample training data to try to balance minority / majority classes
* @return the accuracy measurements for each model for each of the rounds of testing * @return the accuracy measurements for each model for each of the rounds of testing
* @throws Exception if an error occurs training the models * @throws Exception if an error occurs training the models
*/ */
public Map<MLROIFinder, List<Double>> evalModels(int numRounds, MLROIFinder[] models) throws Exception { public Map<MLROIFinder, List<Double>> evalModels(int numRounds, MLROIFinder[] models, boolean upSample) throws Exception {
Map<MLROIFinder, List<Double>> modelAccuracies = new HashMap<>(); Map<MLROIFinder, List<Double>> modelAccuracies = new HashMap<>();
for (MLROIFinder model : models) { for (MLROIFinder model : models) {
modelAccuracies.put(model, new ArrayList<>()); modelAccuracies.put(model, new ArrayList<>());
} }
for (int round = 0; round < numRounds; round++) { for (int round = 0; round < numRounds; round++) {
TrainTestSubsets data = getTrainTestSubset(); TrainTestSubsets data = getTrainTestSubset();
if (upSample) {
data.training = balanceUp(data.training);
}
for (MLROIFinder model : models) { for (MLROIFinder model : models) {
model.train(data.training.samples, data.training.labels); model.train(data.training.samples, data.training.labels);
int[] truths = data.testing.labels; int[] truths = data.testing.labels;
...@@ -297,16 +301,31 @@ public abstract class CrossValidation { ...@@ -297,16 +301,31 @@ public abstract class CrossValidation {
return modelAccuracies; return modelAccuracies;
} }
/**
* Evaluates the specified models for accuracy i.e. the proportion of true results
* (both true positives and true negatives) in the overall results.
*
* @param numRounds number of rounds of testing
* @param models models to test
* @return the accuracy measurements for each model for each of the rounds of testing
* @throws Exception if an error occurs training the models
*/
public Map<MLROIFinder, List<Double>> evalModels(int numRounds, MLROIFinder[] models) throws Exception {
return evalModels(numRounds, models, false);
}
/** /**
* Evaluates the specified models for mean accuracy (ratio of true results / total results). * Evaluates the specified models for mean accuracy (ratio of true results / total results).
* *
* @param numRounds number of rounds of testing * @param numRounds number of rounds of testing
* @param models models to test * @param models models to test
* @param upSample if true, upsample training data to try to balance minority / majority classes
* @return average(mean) accuracy for each model * @return average(mean) accuracy for each model
* @throws Exception if an error occurs * @throws Exception if an error occurs
*/ */
public Map<MLROIFinder, Double> eval(int numRounds, MLROIFinder[] models) throws Exception { public Map<MLROIFinder, Double> eval(int numRounds, MLROIFinder[] models, boolean upSample) throws Exception {
Map<MLROIFinder, List<Double>> accuracies = evalModels(numRounds, models); Map<MLROIFinder, List<Double>> accuracies = evalModels(numRounds, models, upSample);
Map<MLROIFinder, Double> meanAccuracies = new HashMap<>(); Map<MLROIFinder, Double> meanAccuracies = new HashMap<>();
for (MLROIFinder model : accuracies.keySet()) { for (MLROIFinder model : accuracies.keySet()) {
double mean = 0; double mean = 0;
...@@ -318,16 +337,29 @@ public abstract class CrossValidation { ...@@ -318,16 +337,29 @@ public abstract class CrossValidation {
return meanAccuracies; return meanAccuracies;
} }
/**
* Evaluates the specified models for mean accuracy (ratio of true results / total results).
*
* @param numRounds number of rounds of testing
* @param models models to test
* @return average(mean) accuracy for each model
* @throws Exception if an error occurs
*/
public Map<MLROIFinder, Double> eval(int numRounds, MLROIFinder[] models) throws Exception {
return eval(numRounds, models, false);
}
/** /**
* Evaluates the specified models and returns the most accurate (ratio of true results / total results) model. * Evaluates the specified models and returns the most accurate (ratio of true results / total results) model.
* *
* @param numRounds number of rounds of testing * @param numRounds number of rounds of testing
* @param models models to test * @param models models to test
* @param upSample if true, upsample training data to try to balance minority / majority classes
* @return the most accurate model found and its mean accuracy score. * @return the most accurate model found and its mean accuracy score.
* @throws Exception if an error occurs * @throws Exception if an error occurs
*/ */
public Map.Entry<MLROIFinder, Double> findBestModel(int numRounds, MLROIFinder[] models) throws Exception { public Map.Entry<MLROIFinder, Double> findBestModel(int numRounds, MLROIFinder[] models, boolean upSample) throws Exception {
Map<MLROIFinder, Double> meanModelAccuracies = eval(numRounds, models); Map<MLROIFinder, Double> meanModelAccuracies = eval(numRounds, models, upSample);
MLROIFinder bestModel = null; MLROIFinder bestModel = null;
Double bestAccuracy = Double.MIN_VALUE; Double bestAccuracy = Double.MIN_VALUE;
for (MLROIFinder model : meanModelAccuracies.keySet()) { for (MLROIFinder model : meanModelAccuracies.keySet()) {
...@@ -339,4 +371,16 @@ public abstract class CrossValidation { ...@@ -339,4 +371,16 @@ public abstract class CrossValidation {
} }
return new AbstractMap.SimpleEntry<>(bestModel, bestAccuracy); return new AbstractMap.SimpleEntry<>(bestModel, bestAccuracy);
} }
/**
* Evaluates the specified models and returns the most accurate (ratio of true results / total results) model.
*
* @param numRounds number of rounds of testing
* @param models models to test
* @return the most accurate model found and its mean accuracy score.
* @throws Exception if an error occurs
*/
public Map.Entry<MLROIFinder, Double> findBestModel(int numRounds, MLROIFinder[] models) throws Exception {
return findBestModel(numRounds, models, false);
}
} }
/* /*
* com.emphysic.myriad.network.DataIngestorPool * com.emphysic.myriad.network.DataIngestorPool
* *
* Copyright (c) 2016 Emphysic LLC. * Copyright (c) 2018 Emphysic LLC.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -129,6 +129,6 @@ public class DataIngestorPool extends LinkedWorkerPool { ...@@ -129,6 +129,6 @@ public class DataIngestorPool extends LinkedWorkerPool {
} catch (Exception e) { } catch (Exception e) {
log.info(e.getMessage()); log.info(e.getMessage());
} }
ingestorDemo.shutdown(); ingestorDemo.terminate();
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment