Commit 0c1dab4d by Chris Coughlin

Initial implementation of SMOTE oversampling and minor unit test fix

parent 6abee726
/*
* com.emphysic.myriad.core.ml.CrossValidation
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -19,6 +19,10 @@
package com.emphysic.myriad.core.ml;
import com.emphysic.myriad.core.data.roi.MLROIFinder;
import smile.math.distance.Distance;
import smile.math.distance.EuclideanDistance;
import smile.neighbor.LinearSearch;
import smile.neighbor.Neighbor;
import smile.validation.Accuracy;
import java.util.*;
......@@ -101,6 +105,153 @@ public abstract class CrossValidation {
}
}
/**
 * Counts how many samples carry each label in a dataset.
 * @param data data whose labels are counted
 * @return Map whose keys are the labels encountered in data and whose values are the number of samples observed
 * with that label.
 */
public static Map<Integer, Integer> labelFrequencies(Data data) {
    HashMap<Integer, Integer> tally = new HashMap<>();
    for (int idx = 0; idx < data.numSamples(); idx++) {
        // Starts a label at 1 the first time it is seen, otherwise adds 1 to its running count.
        tally.merge(data.labels[idx], 1, Integer::sum);
    }
    return tally;
}
/**
 * Finds the minority label in a sample set, i.e. the label for which the fewest observations are found. When
 * several labels tie for the fewest observations, the one encountered first while iterating over the map is
 * returned.
 * @param labelFreqs label frequencies map; must contain at least one label
 * @return label with the fewest samples
 * @throws NoSuchElementException if labelFreqs is empty
 */
public static Integer findMinorityLabel(Map<Integer, Integer> labelFreqs) {
    if (labelFreqs.isEmpty()) {
        throw new NoSuchElementException("Cannot determine the minority label: label frequencies map is empty");
    }
    // Single pass over the entries avoids a second map lookup for every label.
    Iterator<Map.Entry<Integer, Integer>> entries = labelFreqs.entrySet().iterator();
    Map.Entry<Integer, Integer> first = entries.next();
    int minorityLabel = first.getKey();
    int minValue = first.getValue();
    while (entries.hasNext()) {
        Map.Entry<Integer, Integer> entry = entries.next();
        int labelCount = entry.getValue();
        // Strict comparison: on a tie the earlier label is kept.
        if (labelCount < minValue) {
            minValue = labelCount;
            minorityLabel = entry.getKey();
        }
    }
    return minorityLabel;
}
/**
 * Finds the minority label in a sample set, i.e. the label carried by the fewest samples.
 * @param data data to examine
 * @return label for which the fewest samples in data were found
 */
public static Integer findMinorityLabel(Data data) {
    // Tally the labels first, then pick the rarest one from the tally.
    Map<Integer, Integer> tally = labelFrequencies(data);
    return findMinorityLabel(tally);
}
/**
 * Finds the majority label in a sample set, i.e. the label for which the most observations are found. When
 * several labels tie for the most observations, the one encountered first while iterating over the map is
 * returned.
 * @param labelFreqs label frequencies map; must contain at least one label
 * @return label with the most samples
 * @throws NoSuchElementException if labelFreqs is empty
 */
public static Integer findMajorityLabel(Map<Integer, Integer> labelFreqs) {
    if (labelFreqs.isEmpty()) {
        throw new NoSuchElementException("Cannot determine the majority label: label frequencies map is empty");
    }
    // Single pass over the entries avoids a second map lookup for every label.
    Iterator<Map.Entry<Integer, Integer>> entries = labelFreqs.entrySet().iterator();
    Map.Entry<Integer, Integer> first = entries.next();
    int majorityLabel = first.getKey();
    int maxValue = first.getValue();
    while (entries.hasNext()) {
        Map.Entry<Integer, Integer> entry = entries.next();
        int labelCount = entry.getValue();
        // Strict comparison: on a tie the earlier label is kept.
        if (labelCount > maxValue) {
            maxValue = labelCount;
            majorityLabel = entry.getKey();
        }
    }
    return majorityLabel;
}
/**
 * Finds the majority label in a sample set, i.e. the label carried by the most samples.
 * @param data data to examine
 * @return label for which the most samples in data were found
 */
public static Integer findMajorityLabel(Data data) {
    // Tally the labels first, then pick the most common one from the tally.
    Map<Integer, Integer> tally = labelFrequencies(data);
    return findMajorityLabel(tally);
}
/**
 * Balances imbalanced data with the SMOTE (Synthetic Minority Over-sampling TEchnique) as per
 * https://www.cs.cmu.edu/afs/cs/project/jair/pub/volume16/chawla02a-html/node6.html . The data are examined and
 * the majority and minority labels are identified as the label found for the most and fewest samples respectively.
 * For every minority sample, (majority count / minority count, integer division) synthetic samples are generated;
 * each one lies at a random position on the line segment between the minority sample and one of its nearest
 * neighbors chosen at random.
 * <p>
 * As the technique prescribes, neighbors are searched for among the minority samples only, so a synthetic sample
 * is never pulled toward a sample of another class. If fewer than nn other minority samples exist, all of them
 * are used as neighbors; if the minority class consists of a single sample, the synthetic samples are copies of it.
 * @param orig Original (possibly) imbalanced data
 * @param distanceMetric distance metric to use for finding nearest neighbors
 * @param nn number of nearest neighbors to find, must be at least 1
 * @return new data consisting of both the original and synthetic data
 * @throws IllegalArgumentException if nn is less than 1
 */
public static Data balanceUp(Data orig, Distance<double[]> distanceMetric, int nn) {
    if (nn < 1) {
        throw new IllegalArgumentException("Number of nearest neighbors must be at least 1: " + nn);
    }
    Map<Integer, Integer> labelFreqs = labelFrequencies(orig);
    int minorityLabel = findMinorityLabel(labelFreqs);
    int majorityLabel = findMajorityLabel(labelFreqs);
    int minorityCount = labelFreqs.get(minorityLabel);
    int sampleFactor = labelFreqs.get(majorityLabel) / minorityCount;
    int numSamples = orig.numSamples();
    int numFeatures = orig.numFeatures();

    // Gather the minority samples: they are the only valid interpolation targets.
    double[][] minoritySamples = new double[minorityCount][];
    int filled = 0;
    for (int i = 0; i < numSamples; i++) {
        if (orig.labels[i] == minorityLabel) {
            minoritySamples[filled++] = orig.samples[i];
        }
    }
    // A sample cannot be its own neighbor, so at most (minorityCount - 1) neighbors are available.
    int numNeighbors = Math.min(nn, minorityCount - 1);
    LinearSearch<double[]> ls = null;
    if (numNeighbors > 0) {
        ls = new LinearSearch<>(minoritySamples, distanceMetric);
    }

    List<double[]> oversamples = new ArrayList<>(numSamples + minorityCount * sampleFactor);
    List<Integer> oversampleLabels = new ArrayList<>(numSamples + minorityCount * sampleFactor);
    Random rnd = new Random();
    for (int i = 0; i < numSamples; i++) {
        double[] sample = orig.samples[i];
        oversamples.add(sample);
        oversampleLabels.add(orig.labels[i]);
        if (orig.labels[i] != minorityLabel) {
            continue;
        }
        Neighbor<double[], double[]>[] neighbors = null;
        if (ls != null) {
            neighbors = ls.knn(sample, numNeighbors);
        }
        for (int j = 0; j < sampleFactor; j++) {
            // With no other minority sample available the segment collapses onto the sample itself.
            double[] neighbor = sample;
            if (neighbors != null) {
                neighbor = neighbors[rnd.nextInt(neighbors.length)].value;
            }
            double[] synthSample = new double[numFeatures];
            // One multiplier for all features keeps the synthetic sample on the segment sample -> neighbor.
            double featureMult = rnd.nextDouble();
            for (int feature = 0; feature < numFeatures; feature++) {
                double delta = neighbor[feature] - sample[feature];
                synthSample[feature] = sample[feature] + featureMult * delta;
            }
            oversamples.add(synthSample);
            oversampleLabels.add(minorityLabel);
        }
    }
    double[][] samples = oversamples.toArray(new double[0][]);
    int[] labels = oversampleLabels.stream().mapToInt(i -> i).toArray();
    return new Data(samples, labels);
}
/**
 * Balances imbalanced data with the SMOTE (Synthetic Minority Over-sampling TEchnique) as per
 * https://www.cs.cmu.edu/afs/cs/project/jair/pub/volume16/chawla02a-html/node6.html , using Euclidean distance to
 * determine nearest neighbors. The majority and minority labels are the labels found for the most and fewest
 * samples respectively, and the ratio of their sample counts determines how many synthetic samples are generated.
 * @param orig data to balance
 * @param nn number of nearest neighbors to use
 * @return new data consisting of both the original and synthetic data
 */
public static Data balanceUp(Data orig, int nn) {
    Distance<double[]> euclidean = new EuclideanDistance();
    return balanceUp(orig, euclidean, nn);
}
/**
 * Balances imbalanced data with the SMOTE (Synthetic Minority Over-sampling TEchnique) as per
 * https://www.cs.cmu.edu/afs/cs/project/jair/pub/volume16/chawla02a-html/node6.html , using Euclidean distance and,
 * as per the original paper, five nearest neighbors. The majority and minority labels are the labels found for the
 * most and fewest samples respectively, and the ratio of their sample counts determines how many synthetic samples
 * are generated.
 * @param orig data to balance
 * @return new data consisting of both the original and synthetic data
 */
public static Data balanceUp(Data orig) {
    // Neighbor count used in the original SMOTE paper.
    final int paperNeighborCount = 5;
    return balanceUp(orig, paperNeighborCount);
}
/**
* Helper class to partition input data
*/
......
/*
* com.emphysic.myriad.core.ml.MonteCarloCV
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -26,7 +26,7 @@ import java.util.Arrays;
*/
public class MonteCarloCV extends CrossValidation {
private double ratio; // Ratio of training size: test size
Integer[] indices;
private Integer[] indices;
/**
* Create a new Monte Carlo cross validator.
......
/*
* com.emphysic.myriad.core.ml.CrossValidationTest
*
* Copyright (c) 2016 Emphysic LLC.
* Copyright (c) 2017 Emphysic LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -25,6 +25,7 @@ import org.junit.Test;
import org.sgdtk.HingeLoss;
import org.sgdtk.LogLoss;
import org.sgdtk.Loss;
import smile.math.distance.EuclideanDistance;
import java.io.BufferedReader;
import java.io.File;
......@@ -34,7 +35,7 @@ import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
/**
* CrossValidationTest - tests the CrossValidation class.
......@@ -43,8 +44,7 @@ import static org.junit.Assert.assertNotNull;
public class CrossValidationTest {
private CrossValidation cv;
@Before
public void setUp() throws Exception {
public CrossValidation.Data getData() throws Exception{
URL dataURL = Thread.currentThread().getContextClassLoader().getResource("data/ml/enchilada.txt");
File samplesFile = new File(dataURL.getPath());
BufferedReader br = new BufferedReader(new FileReader(samplesFile));
......@@ -61,8 +61,65 @@ public class CrossValidationTest {
double val = Double.parseDouble(tokens[i + 1]);
samples[sampleno][i] = val;
}
sampleno++;
}
return new CrossValidation.Data(samples, labels);
}
@Before
public void setUp() throws Exception {
    // Build the cross validator under test from the shared fixture data.
    CrossValidation.Data fixture = getData();
    cv = new MonteCarloCV(fixture.samples, fixture.labels);
}
/**
 * The fixture data is expected to have label 1 as its minority label.
 */
@Test
public void findMinorityLabel() throws Exception {
    Integer minority = CrossValidation.findMinorityLabel(getData());
    assertTrue(minority == 1);
}
/**
 * The fixture data is expected to have label -1 as its majority label.
 */
@Test
public void findMajorityLabel() throws Exception {
    Integer majority = CrossValidation.findMajorityLabel(getData());
    assertTrue(majority == -1);
}
/**
 * Compares the tallies reported by labelFrequencies against counts of the -1 and 1 labels made by hand.
 */
@Test
public void findLabelFrequencies() throws Exception {
    CrossValidation.Data fixture = getData();
    int negatives = 0;
    int positives = 0;
    for (int idx = 0; idx < fixture.numSamples(); idx++) {
        switch (fixture.labels[idx]) {
            case -1:
                negatives++;
                break;
            case 1:
                positives++;
                break;
            default:
                // Any other label is not part of this check.
                break;
        }
    }
    Map<Integer, Integer> tally = CrossValidation.labelFrequencies(fixture);
    assertTrue(tally.get(-1) == negatives);
    assertTrue(tally.get(1) == positives);
}
/**
 * Checks the bookkeeping of SMOTE oversampling: the upsampled data must contain the original samples plus
 * (majority count / minority count) synthetic samples per minority sample, must use the same set of labels, and
 * must leave the count of every non-minority label unchanged.
 */
@Test
public void balanceUp() throws Exception {
CrossValidation.Data data = getData();
int numNeighbors = 5;
CrossValidation.Data upSampled = CrossValidation.balanceUp(data, new EuclideanDistance(), numNeighbors);
Map<Integer, Integer> origLabelFreqs = CrossValidation.labelFrequencies(data);
int minorityLabel = CrossValidation.findMinorityLabel(origLabelFreqs);
int majorityLabel = CrossValidation.findMajorityLabel(origLabelFreqs);
// Number of synthetic samples expected per minority sample (integer division).
int factor = origLabelFreqs.get(majorityLabel) / origLabelFreqs.get(minorityLabel);
// Total size: every original sample plus the synthetic minority samples.
assertTrue(origLabelFreqs.get(minorityLabel) * factor +
data.numSamples() == upSampled.numSamples());
Map<Integer, Integer> newLabelFreqs = CrossValidation.labelFrequencies(upSampled);
// Oversampling must not introduce or drop labels.
assertTrue(origLabelFreqs.keySet().equals(newLabelFreqs.keySet()));
// Each minority sample is kept and accompanied by `factor` synthetic samples.
assertTrue(origLabelFreqs.get(minorityLabel) * (factor + 1) == newLabelFreqs.get(minorityLabel));
// Every other label keeps its original count.
for (Integer key : origLabelFreqs.keySet()) {
if (key != minorityLabel) {
assertTrue(origLabelFreqs.get(key).equals(newLabelFreqs.get(key)));
}
}
// NOTE(review): `samples` and `labels` are not declared in this method; the next line looks like a removed
// line left over from the diff rendering of the old setUp() - confirm against the committed file.
cv = new MonteCarloCV(samples, labels);
}
@Test
......@@ -80,33 +137,30 @@ public class CrossValidationTest {
@Test
public void eval() throws Exception {
SGDROIFinder[] models = genModels();
Map<MLROIFinder, List<Double>> evaluation = cv.evalModels(2, models);
Map<MLROIFinder, Double> meanEvaluation = cv.eval(2, models);
int numRounds = 3;
Map<MLROIFinder, Double> meanEvaluation = cv.eval(numRounds, models);
assertEquals(meanEvaluation.keySet().size(), models.length);
for (MLROIFinder model : evaluation.keySet()) {
List<Double> accuracies = evaluation.get(model);
Double expected = mean(accuracies);
Double meanAccuracy = meanEvaluation.get(model);
assertEquals(expected, meanAccuracy, 0.05 * expected);
for (int i=0; i<models.length; i++) {
MLROIFinder model = models[i];
assertTrue(meanEvaluation.containsKey(model));
double acc = meanEvaluation.get(model);
assertTrue(acc >= 0 && acc <= 1);
}
}
@Test
public void findBestModel() throws Exception {
SGDROIFinder[] models = genModels();
Map<MLROIFinder, Double> meanEvaluation = cv.eval(2, models);
MLROIFinder best = null;
Double bestAccuracy = Double.MIN_VALUE;
for (MLROIFinder model : meanEvaluation.keySet()) {
Double meanAccuracy = meanEvaluation.get(model);
if (meanAccuracy > bestAccuracy) {
bestAccuracy = meanAccuracy;
best = model;
Map.Entry<MLROIFinder, Double> returnedBest = cv.findBestModel(2, models);
boolean found = false;
for (SGDROIFinder model : models) {
if (model == returnedBest.getKey()) {
found = true;
}
}
assertNotNull(best);
Map.Entry<MLROIFinder, Double> returnedBest = cv.findBestModel(2, models);
assertEquals(bestAccuracy, returnedBest.getValue(), 0.05 * bestAccuracy);
assertTrue(found);
double acc = returnedBest.getValue();
assertTrue(acc >= 0 && acc <= 1);
}
/**
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment