Commit ea427915 authored by Charles Vernerey
Browse files

Add enum for rule measures

parent 1ff1b7e6
Loading
Loading
Loading
Loading
+25 −38
Original line number Diff line number Diff line
@@ -10,10 +10,12 @@
package io.gitlab.chaver.mining.rules.io;

import io.gitlab.chaver.mining.patterns.io.Database;
import io.gitlab.chaver.mining.rules.measure.RuleMeasure;
import lombok.AllArgsConstructor;
import lombok.Getter;

import java.text.DecimalFormat;
import java.util.List;
import java.util.Map;

/**
@@ -22,59 +24,31 @@ import java.util.Map;
@AllArgsConstructor
public class AssociationRule {

    public static DecimalFormat df = new DecimalFormat("0.000");

    /**
     * Antecedent of the rule
     * Items in the antecedent of the rule
     */
    private @Getter int[] x;

    /**
     * Conclusion of the rule
     * Items in the consequence of the rule
     */
    private @Getter int[] y;

    /**
     * Frequency of x
     * Frequency of x (antecedent of the rule)
     */
    private @Getter int freqX;

    /**
     * Frequency of y
     * Frequency of y (consequence of the rule)
     */
    private @Getter int freqY;

    /**
     * Frequency of z
     * Frequency of z (union between the antecedent and the consequence of the rule)
     */
    private @Getter int freqZ;

    /**
     * Confidence of the rule, i.e. the frequency of z divided by the frequency of x.
     * @return {@code freqZ / freqX} evaluated in floating point
     */
    public double conf() {
        // Widen the numerator first so the division is done on doubles.
        double jointFrequency = freqZ;
        return jointFrequency / freqX;
    }

    /**
     * Relative support of the rule, i.e. the frequency of z divided by the number of transactions.
     * @param nbTransactions number of transactions in the database
     * @return {@code freqZ / nbTransactions} evaluated in floating point
     */
    public double support(int nbTransactions) {
        // Widen the numerator first so the division is done on doubles.
        double jointFrequency = freqZ;
        return jointFrequency / nbTransactions;
    }

    /**
     * Lift of the rule, i.e. its confidence multiplied by the number of transactions
     * and divided by the frequency of y.
     * @param nbTransactions number of transactions in the database
     * @return {@code conf() * nbTransactions / freqY}
     */
    public double lift(int nbTransactions) {
        double confidence = conf();
        return confidence * nbTransactions / freqY;
    }

    private String convertToString(int[] pattern) {
        if (pattern.length == 0) return "{}";
        StringBuilder str = new StringBuilder("{").append(pattern[0]);
@@ -107,20 +81,33 @@ public class AssociationRule {
                '}';
    }

    /**
     * Render the given measures of this rule as {@code {name1=value1, name2=value2, ...}}.
     * An empty list yields {@code {}}.
     * @param measures measures to evaluate on this rule, rendered in iteration order
     * @param nbTransactions number of transactions in the database, passed to each measure
     * @param measureFormat format applied to each computed value
     * @return the formatted list of measures
     */
    private String computeMeasures(List<RuleMeasure> measures, int nbTransactions, DecimalFormat measureFormat) {
        StringBuilder rendered = new StringBuilder("{");
        // Empty before the first entry, ", " afterwards.
        String separator = "";
        for (RuleMeasure current : measures) {
            rendered.append(separator)
                    .append(current.getName())
                    .append("=")
                    .append(measureFormat.format(current.compute(this, nbTransactions)));
            separator = ", ";
        }
        return rendered.append("}").toString();
    }

    /**
     * Convert rule to String
     * @param database database to consider
     * @param labels label of each item
     * @return corresponding string
     */
    public String toString(Database database, String[] labels) {
    public String toString(Database database, String[] labels, List<RuleMeasure> measures, DecimalFormat measureFormat) {
        int nbTransactions = database.getNbTransactions();
        if (labels == null) {
            return convertToString(x) + " => " + convertToString(y) + ", supZ=" + df.format(support(nbTransactions)) +
                    ", conf=" + df.format(conf()) + ", lift=" + df.format(lift(nbTransactions));
            return convertToString(x) + " => " + convertToString(y) + ", measures=" +
                    computeMeasures(measures, nbTransactions, measureFormat);
        }
        return convertToString(x, labels, database) + " => " + convertToString(y, labels, database) +
                ", supZ=" + df.format(support(nbTransactions)) +
                ", conf=" + df.format(conf()) + ", lift=" + df.format(lift(nbTransactions));
                ", measures=" + computeMeasures(measures, nbTransactions, measureFormat);
    }
}
+32 −0
Original line number Diff line number Diff line
/*
 * This file is part of io.gitlab.chaver:data-mining (https://gitlab.com/chaver/data-mining)
 *
 * Copyright (c) 2022, IMT Atlantique
 *
 * Licensed under the MIT license.
 *
 * See LICENSE file in the project root for full license information.
 */
package io.gitlab.chaver.mining.rules.measure;

import io.gitlab.chaver.mining.rules.io.AssociationRule;

/**
 * A named measure that can be evaluated on an {@link AssociationRule}.
 * <p>
 * An implementation provides a name for the measure and a way to compute its
 * value from a rule and the number of transactions in the database.
 */
public interface RuleMeasure {

    /**
     * Get the name of the measure.
     * @return the name of the measure
     */
    String getName();

    /**
     * Compute the value of the measure for a given rule.
     * @param rule association rule on which the measure is evaluated
     * @param nbTransactions number of transactions in the database
     * @return the value of the measure for this association rule
     */
    double compute(AssociationRule rule, int nbTransactions);
}
+62 −0
Original line number Diff line number Diff line
/*
 * This file is part of io.gitlab.chaver:data-mining (https://gitlab.com/chaver/data-mining)
 *
 * Copyright (c) 2022, IMT Atlantique
 *
 * Licensed under the MIT license.
 *
 * See LICENSE file in the project root for full license information.
 */
package io.gitlab.chaver.mining.rules.measure;

import io.gitlab.chaver.mining.rules.io.AssociationRule;

/**
 * Classic measures of an association rule: support, relative support, confidence and lift.
 * <p>
 * Each constant carries its display name and supplies its own {@link #compute} implementation.
 */
public enum SimpleRuleMeasures implements RuleMeasure {

    /** Absolute support: the frequency of z as reported by the rule. */
    sup("support") {
        @Override
        public double compute(AssociationRule rule, int nbTransactions) {
            return rule.getFreqZ();
        }
    },

    /** Relative support: the frequency of z divided by the number of transactions. */
    rsup("relative support") {
        @Override
        public double compute(AssociationRule rule, int nbTransactions) {
            double jointFrequency = rule.getFreqZ();
            return jointFrequency / nbTransactions;
        }
    },

    /** Confidence: the frequency of z divided by the frequency of x. */
    conf("confidence") {
        @Override
        public double compute(AssociationRule rule, int nbTransactions) {
            double jointFrequency = rule.getFreqZ();
            return jointFrequency / rule.getFreqX();
        }
    },

    /** Lift: confidence multiplied by the number of transactions, divided by the frequency of y. */
    lift("lift") {
        @Override
        public double compute(AssociationRule rule, int nbTransactions) {
            double confidence = conf.compute(rule, nbTransactions);
            return confidence * nbTransactions / rule.getFreqY();
        }
    };

    /** Name returned by {@link #getName()}. */
    private final String label;

    SimpleRuleMeasures(String label) {
        this.label = label;
    }

    @Override
    public String getName() {
        return label;
    }
}
+6 −1
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ import io.gitlab.chaver.mining.patterns.io.PatternProblemProperties;
import io.gitlab.chaver.mining.rules.io.ArMeasuresView;
import io.gitlab.chaver.mining.rules.io.AssociationRule;
import io.gitlab.chaver.mining.rules.io.RuleType;
import io.gitlab.chaver.mining.rules.measure.RuleMeasure;
import io.gitlab.chaver.mining.rules.search.loop.monitors.ArMonitor;
import org.chocosolver.solver.Model;
import org.chocosolver.solver.Settings;
@@ -42,10 +43,12 @@ import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.text.DecimalFormat;
import java.util.*;
import java.util.stream.Collectors;

import static io.gitlab.chaver.mining.patterns.util.PatternUtil.findClosedPattern;
import static io.gitlab.chaver.mining.rules.measure.SimpleRuleMeasures.*;
import static org.chocosolver.solver.search.strategy.Search.intVarSearch;

@Command(name = "arm", description = "Association rule mining", mixinStandardHelpOptions = true)
@@ -74,6 +77,8 @@ public class AssociationRuleMining extends ChocoProblem<AssociationRule, ArMeasu
    @Option(names = "--lab", description = "File path with the label of items (each line corresponds to one item)")
    private String labelsPath;
    private String[] labels;
    private List<RuleMeasure> measures = Arrays.asList(sup, rsup, conf, lift);
    private DecimalFormat measureFormat = new DecimalFormat("0.000");

    private Database database;
    private ArMonitor arMonitor;
@@ -242,7 +247,7 @@ public class AssociationRuleMining extends ChocoProblem<AssociationRule, ArMeasu

    @Override
    protected void printSolutions() {
        getSolutions().forEach(s -> System.out.println(s.toString(database, labels)));
        getSolutions().forEach(s -> System.out.println(s.toString(database, labels, measures, measureFormat)));
    }

    public static void main(String[] args) throws Exception {