Commit d1445ff7 authored by Chris Coughlin's avatar Chris Coughlin

Initial support for up/down sampling to address training data class imbalances

parent 9e92a038
......@@ -26,7 +26,7 @@ The [Myriad library](https://gitlab.com/ccoughlin/datareader) must be installed
[Desktop](https://gitlab.com/ccoughlin/MyriadDesktop) isn't required for Trainer. Models created with Trainer can be used in Desktop or in other Myriad-based applications.
## Licensing
Copyright 2017 Emphysic, LLC.
Copyright 2018 Emphysic, LLC.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......
......@@ -95,17 +95,17 @@
<dependency>
<groupId>com.github.haifengl</groupId>
<artifactId>smile-core</artifactId>
<version>1.3.0</version>
<version>1.5.0</version>
</dependency>
<dependency>
<groupId>com.github.haifengl</groupId>
<artifactId>smile-data</artifactId>
<version>1.3.0</version>
<version>1.5.0</version>
</dependency>
<dependency>
<groupId>com.github.haifengl</groupId>
<artifactId>smile-plot</artifactId>
<version>1.3.0</version>
<version>1.5.0</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
......
/*
* com.emphysic.myriadtrainer.controllers.MainAppController
*
* Copyright (c) 2017 Emphysic LLC.
* Copyright (c) 2018 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -105,6 +105,14 @@ public class MainAppController {
SOBEL
};
/**
* Class label balancing options
* NONE - leave as-is (no balancing)
* DOWN - rebalance by downsampling majority label
* UP - rebalance by upsampling (SMOTE) minority label
*/
public enum BalanceClasses{NONE, DOWN, UP};
/**
* Description of available preprocessing operations
*/
......@@ -619,7 +627,6 @@ public class MainAppController {
*
* @param positiveSamples positive samples
* @param negativeSamples negative samples
* @param balance if true, try to make ratio of pos:neg 1:1, otherwise randomly sample
* @param normalize if true, normalize data between 0 and 1 prior to running any additional preprocessing operation
* @return compiled data, or null if no positive/negative samples, couldn't
* determine model, etc.
......@@ -627,7 +634,6 @@ public class MainAppController {
public CrossValidation.Data compileData(
List<Dataset> positiveSamples,
List<Dataset> negativeSamples,
boolean balance,
boolean normalize) {
Random rnd = new Random();
List<Dataset> samples = new ArrayList<>();
......@@ -641,16 +647,19 @@ public class MainAppController {
return null;
}
NormalizeSignalOperation norm = new NormalizeSignalOperation();
if (balance) {
// try to force a 1:1 ratio so we don't overfit on one class or the other
int sampleSize = Math.min(positiveSamples.size(), negativeSamples.size());
while (positiveSamples.size() != sampleSize) {
positiveSamples.remove(rnd.nextInt(positiveSamples.size()));
}
while (negativeSamples.size() != sampleSize) {
negativeSamples.remove(rnd.nextInt(negativeSamples.size()));
}
}
// switch (balance) {
// case DOWN:
// int sampleSize = Math.min(positiveSamples.size(), negativeSamples.size());
// while (positiveSamples.size() != sampleSize) {
// positiveSamples.remove(rnd.nextInt(positiveSamples.size()));
// }
// while (negativeSamples.size() != sampleSize) {
// negativeSamples.remove(rnd.nextInt(negativeSamples.size()));
// }
// break;
// case UP:
//
// }
int posClass = (int)getCurrentModel().getModel().positiveClass();
int negClass = (int)getCurrentModel().getModel().negativeClass();
Collections.shuffle(positiveSamples, rnd);
......
/*
* com.emphysic.myriadtrainer.ui.AboutDialog
*
* Copyright (c) 2017 Emphysic LLC.
* Copyright (c) 2018 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -119,7 +119,7 @@ public class AboutDialog extends javax.swing.JDialog {
File urlFile = new File(url.getPath());
copyright = new String(Files.readAllBytes(urlFile.toPath()));
} catch (IOException ioe) {
copyright = "Copyright (C) 2017 Emphysic LLC.";
copyright = "Copyright (C) 2018 Emphysic LLC.";
}
return copyright;
}
......
/*
* com.emphysic.myriadtrainer.ui.AboutEmphysicDialog
*
* Copyright (c) 2017 Emphysic LLC.
* Copyright (c) 2018 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
/*
* com.emphysic.myriadtrainer.ui.AddDataDialog
*
* Copyright (c) 2017 Emphysic LLC.
* Copyright (c) 2018 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -25,7 +25,7 @@ import javax.swing.UIManager;
/**
* Add folders to Myriad model training.
* Copyright (C) 2017 Emphysic LLC. All rights reserved.
* Copyright (C) 2018 Emphysic LLC. All rights reserved.
* @author ccoughlin
*/
public class AddDataDialog extends javax.swing.JDialog {
......
/*
* com.emphysic.myriadtrainer.ui.DataPanel
*
* Copyright (c) 2017 Emphysic LLC.
* Copyright (c) 2018 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
/*
* com.emphysic.myriadtrainer.ui.LicenseDialog
*
* Copyright (c) 2017 Emphysic LLC.
* Copyright (c) 2018 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
/*
* com.emphysic.myriadtrainer.util.TrainerWorker
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2018 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -44,6 +44,8 @@ public class TrainerWorker extends SwingWorker<MLROIFinder, Void> {
* The initial model to train and test
*/
private MLROIFinder initialModel;
private MonteCarloCV cv;
/**
* Current accuracy of the model
*/
......@@ -54,21 +56,44 @@ public class TrainerWorker extends SwingWorker<MLROIFinder, Void> {
*/
private int iters = 5;
/**
* Whether to attempt to balance minority / majority labels in the training
* set with upsampling
*/
private boolean upSample = true;
/**
* Constructor
* @param model model to train
* @param data data used to train the model
* @param trainRatio proportion of data to use for training, 1-trainRatio is used for testing.
* @param rounds number of rounds of train/test to perform
* @param upSample if true, attempt to balance training set via upsampling
*/
public TrainerWorker(MLROIFinder model,
CrossValidation.Data data,
double trainRatio,
int rounds) {
int rounds,
boolean upSample) {
this.initialModel = model;
this.data = data;
this.trainRatio = trainRatio;
this.iters = rounds;
this.cv = new MonteCarloCV(this.data, trainRatio);
}
/**
* Constructor
* @param model model to train
* @param data data used to train the model
* @param trainRatio proportion of data to use for training, 1-trainRatio is used for testing.
* @param rounds number of rounds of train/test to perform
*/
public TrainerWorker(MLROIFinder model,
CrossValidation.Data data,
double trainRatio,
int rounds) {
this(model, data, trainRatio, rounds, true);
}
/**
......@@ -77,7 +102,7 @@ public class TrainerWorker extends SwingWorker<MLROIFinder, Void> {
* @param data data used to train the model
*/
public TrainerWorker(MLROIFinder model, CrossValidation.Data data) {
this(model, data, 0.75, 5);
this(model, data, 0.75, 5, true);
}
@Override
......@@ -86,10 +111,10 @@ public class TrainerWorker extends SwingWorker<MLROIFinder, Void> {
if (isCancelled()) {
throw new InterruptedException();
}
MonteCarloCV cv = new MonteCarloCV(data, trainRatio);
// Sort of a fudge here - we're not really doing CV just looking for accuracy metrics
Map.Entry<MLROIFinder, Double> bestSGDModel = cv.findBestModel(iters,
new MLROIFinder[] {initialModel}
new MLROIFinder[] {initialModel},
upSample
);
if (isCancelled()) {
throw new InterruptedException();
......@@ -174,4 +199,7 @@ public class TrainerWorker extends SwingWorker<MLROIFinder, Void> {
this.iters = iters;
}
public CrossValidation getCV() {
return cv;
}
}
Copyright 2017 Emphysic LLC.
Copyright 2018 Emphysic LLC.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment