Commit 0fff7a82 authored by Charles Vernerey's avatar Charles Vernerey
Browse files

Add zero and required items constraints for antecedent and consequent of the rules

parent 165a70f9
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -16,7 +16,8 @@ import io.gitlab.chaver.mining.rules.problems.AssociationRuleMining;
import picocli.CommandLine;
import picocli.CommandLine.Command;

@Command(name = "mine", subcommands = {ClosedSky.class, CpSky.class, AssociationRuleMining.class})
@Command(subcommands = {AssociationRuleMining.class, ClosedSky.class, CpSky.class}, mixinStandardHelpOptions = true,
        version = "1.0.0")
public class MainCommand {

    public static void main(String[] args) {
+31 −1
Original line number Diff line number Diff line
@@ -39,13 +39,16 @@ import picocli.CommandLine.Command;
import picocli.CommandLine.Option;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;

import static io.gitlab.chaver.mining.patterns.util.PatternUtil.findClosedPattern;
import static org.chocosolver.solver.search.strategy.Search.intVarSearch;

@Command(name = "arm", description = "Association rule mining")
@Command(name = "arm", description = "Association rule mining", mixinStandardHelpOptions = true)
public class AssociationRuleMining extends ChocoProblem<AssociationRule, ArMeasuresView> {

    @Option(names = {"-d", "--data"}, required = true, description = "Datafile to use")
@@ -62,10 +65,18 @@ public class AssociationRuleMining extends ChocoProblem<AssociationRule, ArMeasu
    private double minConf;
    @Option(names = "--sky", description = "Skypatterns file (impose constraint)")
    private String skyPath;
    @Option(names = "--0a", description = "Items to exclude in the antecedent (path of a file where each line " +
            "represents an item to exclude)")
    private String zeroItemsAntecedentPath;
    @Option(names = "--0c", description = "Items to exclude in the consequent (path of a file where each line " +
            "represents an item to exclude)")
    private String zeroItemsConsequentPath;

    private Database database;
    private ArMonitor arMonitor;
    private Map<Set<Integer>, Set<Integer>> closedPatterns;
    private int[] zeroItemsAntecedent = new int[0];
    private int[] zeroItemsConsequent = new int[0];

    @Override
    public void parseArgs() throws SetUpException {
@@ -83,6 +94,19 @@ public class AssociationRuleMining extends ChocoProblem<AssociationRule, ArMeasu
            if (skyPath != null) {
                closedPatterns = getClosedPatterns();
            }
            Map<Integer, Integer> itemsMap = database.getItemsMap();
            if (zeroItemsAntecedentPath != null) {
                zeroItemsAntecedent = Files.readAllLines(Paths.get(zeroItemsAntecedentPath), StandardCharsets.UTF_8)
                        .stream()
                        .mapToInt(s -> itemsMap.get(Integer.parseInt(s)))
                        .toArray();
            }
            if (zeroItemsConsequentPath != null) {
                zeroItemsConsequent = Files.readAllLines(Paths.get(zeroItemsConsequentPath), StandardCharsets.UTF_8)
                        .stream()
                        .mapToInt(s -> itemsMap.get(Integer.parseInt(s)))
                        .toArray();
            }
        }
        catch (IOException e) {
            throw new SetUpException(e.getMessage());
@@ -153,10 +177,16 @@ public class AssociationRuleMining extends ChocoProblem<AssociationRule, ArMeasu
        return skyVar;
    }

    private void zeroItemsConstraint(BoolVar[] items, int[] zeroItems) {
        Arrays.stream(zeroItems).forEach(i -> items[i].eq(0).post());
    }

    @Override
    public void buildModel() {
        BoolVar[] x = model.boolVarArray("x", database.getNbItems());
        BoolVar[] y = model.boolVarArray("y", database.getNbItems());
        zeroItemsConstraint(x, zeroItemsAntecedent);
        zeroItemsConstraint(y, zeroItemsConsequent);
        BoolVar[] z = model.boolVarArray("z", database.getNbItems());
        for (int i = 0; i < database.getNbItems(); i++) {
            model.arithm(x[i], "+", y[i], "<=", 1).post();