Commit 86fb7fb6 authored by Charles Vernerey's avatar Charles Vernerey
Browse files

Add OR constraint for items in the rule

parent ea427915
Loading
Loading
Loading
Loading
+21 −9
Original line number Diff line number Diff line
@@ -74,6 +74,9 @@ public class AssociationRuleMining extends ChocoProblem<AssociationRule, ArMeasu
    @Option(names = "--0c", description = "Items to exclude in the consequent (path of a file where each line " +
            "represents an item to exclude)")
    private String zeroItemsConsequentPath;
    @Option(names = "--or", description = "Items to include in the antecedent or the consequent (path of a file where" +
            "each line represents an item to include")
    private String orItemsPath;
    @Option(names = "--lab", description = "File path with the label of items (each line corresponds to one item)")
    private String labelsPath;
    private String[] labels;
@@ -85,6 +88,12 @@ public class AssociationRuleMining extends ChocoProblem<AssociationRule, ArMeasu
    private Map<Set<Integer>, Set<Integer>> closedPatterns;
    private int[] zeroItemsAntecedent = new int[0];
    private int[] zeroItemsConsequent = new int[0];
    private int[] orItems = new int[0];

    private int[] readItems(String path) throws IOException {
        Map<Integer, Integer> itemsMap = database.getItemsMap();
        return Files.readAllLines(Paths.get(path), StandardCharsets.UTF_8).stream().mapToInt(s -> itemsMap.get(Integer.parseInt(s))).toArray();
    }

    @Override
    public void parseArgs() throws SetUpException {
@@ -102,18 +111,14 @@ public class AssociationRuleMining extends ChocoProblem<AssociationRule, ArMeasu
            if (skyPath != null) {
                closedPatterns = getClosedPatterns();
            }
            Map<Integer, Integer> itemsMap = database.getItemsMap();
            if (zeroItemsAntecedentPath != null) {
                zeroItemsAntecedent = Files.readAllLines(Paths.get(zeroItemsAntecedentPath), StandardCharsets.UTF_8)
                        .stream()
                        .mapToInt(s -> itemsMap.get(Integer.parseInt(s)))
                        .toArray();
                zeroItemsAntecedent = readItems(zeroItemsAntecedentPath);
            }
            if (zeroItemsConsequentPath != null) {
                zeroItemsConsequent = Files.readAllLines(Paths.get(zeroItemsConsequentPath), StandardCharsets.UTF_8)
                        .stream()
                        .mapToInt(s -> itemsMap.get(Integer.parseInt(s)))
                        .toArray();
                zeroItemsConsequent = readItems(zeroItemsConsequentPath);
            }
            if (orItemsPath != null) {
                orItems = readItems(orItemsPath);
            }
            if (labelsPath != null) {
                labels = Files.readAllLines(Paths.get(labelsPath), StandardCharsets.UTF_8).toArray(new String[0]);
@@ -192,6 +197,12 @@ public class AssociationRuleMining extends ChocoProblem<AssociationRule, ArMeasu
        Arrays.stream(zeroItems).forEach(i -> items[i].eq(0).post());
    }

    private void orItemsConstraint(BoolVar[] z) {
        if (orItems.length == 0) return;
        BoolVar[] orItemVars = Arrays.stream(orItems).mapToObj(i -> z[i]).toArray(BoolVar[]::new);
        model.or(orItemVars).post();
    }

    @Override
    public void buildModel() {
        BoolVar[] x = model.boolVarArray("x", database.getNbItems());
@@ -203,6 +214,7 @@ public class AssociationRuleMining extends ChocoProblem<AssociationRule, ArMeasu
            model.arithm(x[i], "+", y[i], "<=", 1).post();
            model.addClausesBoolOrEqVar(x[i], y[i], z[i]);
        }
        orItemsConstraint(z);
        model.addClausesBoolOrArrayEqualTrue(x);
        model.addClausesBoolOrArrayEqualTrue(y);
        IntVar freqZ = model.intVar("freqZ", minFreq, database.getNbTransactions());