Commit 17ba3ef0 authored by Chris Coughlin's avatar Chris Coughlin

Implemented SMOTE balancing in training data

parent e7e3cfd3
/*
* com.emphysic.myriad.core.ml.CrossValidation
*
* Copyright (c) 2017 Emphysic LLC.
* Copyright (c) 2018 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -274,16 +274,20 @@ public abstract class CrossValidation {
*
* @param numRounds number of rounds of testing
* @param models models to test
* @param upSample if true, upsample training data to try to balance minority / majority classes
* @return the accuracy measurements for each model for each of the rounds of testing
* @throws Exception if an error occurs training the models
*/
public Map<MLROIFinder, List<Double>> evalModels(int numRounds, MLROIFinder[] models) throws Exception {
public Map<MLROIFinder, List<Double>> evalModels(int numRounds, MLROIFinder[] models, boolean upSample) throws Exception {
Map<MLROIFinder, List<Double>> modelAccuracies = new HashMap<>();
for (MLROIFinder model : models) {
modelAccuracies.put(model, new ArrayList<>());
}
for (int round = 0; round < numRounds; round++) {
TrainTestSubsets data = getTrainTestSubset();
if (upSample) {
data.training = balanceUp(data.training);
}
for (MLROIFinder model : models) {
model.train(data.training.samples, data.training.labels);
int[] truths = data.testing.labels;
......@@ -297,16 +301,31 @@ public abstract class CrossValidation {
return modelAccuracies;
}
/**
* Evaluates the specified models for accuracy i.e. the proportion of true results
* (both true positives and true negatives) in the overall results.
*
* @param numRounds number of rounds of testing
* @param models models to test
* @return the accuracy measurements for each model for each of the rounds of testing
* @throws Exception if an error occurs training the models
*/
public Map<MLROIFinder, List<Double>> evalModels(int numRounds, MLROIFinder[] models) throws Exception {
return evalModels(numRounds, models, false);
}
/**
* Evaluates the specified models for mean accuracy (ratio of true results / total results).
*
* @param numRounds number of rounds of testing
* @param models models to test
* @param upSample if true, upsample training data to try to balance minority / majority classes
* @return average(mean) accuracy for each model
* @throws Exception if an error occurs
*/
public Map<MLROIFinder, Double> eval(int numRounds, MLROIFinder[] models) throws Exception {
Map<MLROIFinder, List<Double>> accuracies = evalModels(numRounds, models);
public Map<MLROIFinder, Double> eval(int numRounds, MLROIFinder[] models, boolean upSample) throws Exception {
Map<MLROIFinder, List<Double>> accuracies = evalModels(numRounds, models, upSample);
Map<MLROIFinder, Double> meanAccuracies = new HashMap<>();
for (MLROIFinder model : accuracies.keySet()) {
double mean = 0;
......@@ -318,16 +337,29 @@ public abstract class CrossValidation {
return meanAccuracies;
}
/**
* Evaluates the specified models for mean accuracy (ratio of true results / total results).
*
* @param numRounds number of rounds of testing
* @param models models to test
* @return average(mean) accuracy for each model
* @throws Exception if an error occurs
*/
public Map<MLROIFinder, Double> eval(int numRounds, MLROIFinder[] models) throws Exception {
return eval(numRounds, models, false);
}
/**
* Evaluates the specified models and returns the most accurate (ratio of true results / total results) model.
*
* @param numRounds number of rounds of testing
* @param models models to test
* @param upSample if true, upsample training data to try to balance minority / majority classes
* @return the most accurate model found and its mean accuracy score.
* @throws Exception if an error occurs
*/
public Map.Entry<MLROIFinder, Double> findBestModel(int numRounds, MLROIFinder[] models) throws Exception {
Map<MLROIFinder, Double> meanModelAccuracies = eval(numRounds, models);
public Map.Entry<MLROIFinder, Double> findBestModel(int numRounds, MLROIFinder[] models, boolean upSample) throws Exception {
Map<MLROIFinder, Double> meanModelAccuracies = eval(numRounds, models, upSample);
MLROIFinder bestModel = null;
Double bestAccuracy = Double.MIN_VALUE;
for (MLROIFinder model : meanModelAccuracies.keySet()) {
......@@ -339,4 +371,16 @@ public abstract class CrossValidation {
}
return new AbstractMap.SimpleEntry<>(bestModel, bestAccuracy);
}
/**
* Evaluates the specified models and returns the most accurate (ratio of true results / total results) model.
*
* @param numRounds number of rounds of testing
* @param models models to test
* @return the most accurate model found and its mean accuracy score.
* @throws Exception if an error occurs
*/
public Map.Entry<MLROIFinder, Double> findBestModel(int numRounds, MLROIFinder[] models) throws Exception {
return findBestModel(numRounds, models, false);
}
}
/*
* com.emphysic.myriad.network.DataIngestorPool
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2018 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -129,6 +129,6 @@ public class DataIngestorPool extends LinkedWorkerPool {
} catch (Exception e) {
log.info(e.getMessage());
}
ingestorDemo.shutdown();
ingestorDemo.terminate();
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment