[aGrUM] fixing cross-validation example. Updating apps/CMakemList.txt. adding example.stdout

parent eef6fdd9
......@@ -4,7 +4,6 @@ cmake_minimum_required(VERSION 2.8)
# -DCMAKE_BUILD_TYPE=DEBUG|RELEASE
# -G "MinGW Makefiles"
# cmake -DCMAKE_BUILD_TYPE=DEBUG
# or
# cmake -DCMAKE_BUILD_TYPE=RELEASE
......@@ -32,12 +31,12 @@ endif (aGrUM_FOUND)
file(GLOB HOWTO_BUILD_SIMPLE_BN_SOURCE ${HOWTO_BUILD_SIMPLE_BN_SOURCE_DIR}/*.cpp)
file(GLOB HOWTO_BUILD_SIMPLE_BN_INCLUDE ${HOWTO_BUILD_SIMPLE_BN_SOURCE_DIR}/*.h)
add_executable (buildSimpleBN ${HOWTO_BUILD_SIMPLE_BN_SOURCE})
add_executable (example ${HOWTO_BUILD_SIMPLE_BN_SOURCE})
if (${CMAKE_BUILD_TYPE} STREQUAL "RELEASE") # release : act install release
target_link_libraries(buildSimpleBN agrum)
target_link_libraries(example agrum)
else() # debug : act install debug
target_link_libraries(buildSimpleBN agrum-dbg)
target_link_libraries(example agrum-dbg)
endif()
install (TARGETS buildSimpleBN DESTINATION bin)
install (TARGETS example DESTINATION bin)
== Created BN ==
================
BN{nodes: 3, arcs: 2, domainSize: 8, parameters: 12, compression ratio: -50% }
== dot version of the BN ==
===========================
digraph "no_name" {
graph [bgcolor=transparent,label="no_name"];
node [style=filled fillcolor="#ffffaa"];
"foo" [comment="0:foo<0,1>"];
"bar" [comment="1:bar<0,1>"];
"qux" [comment="2:qux<0,1>"];
"foo" -> "bar";
"qux" -> "bar";
}
== cpt for 'bar' variable ==
============================
<bar:0|foo:0|qux:0> :: 0.2 /<bar:1|foo:0|qux:0> :: 0.8 /<bar:0|foo:1|qux:0> :: 0.3 /<bar:1|foo:1|qux:0> :: 0.7 /<bar:0|foo:0|qux:1> :: 0.9 /<bar:1|foo:0|qux:1> :: 0.1 /<bar:0|foo:1|qux:1> :: 0.5 /<bar:1|foo:1|qux:1> :: 0.5
......@@ -33,12 +33,12 @@ endif (aGrUM_FOUND)
file(GLOB EXAMPLE_SOURCE ${EXAMPLE_SOURCE_DIR}/*.cpp)
file(GLOB EXAMPLE_INCLUDE ${EXAMPLE_SOURCE_DIR}/*.h)
add_executable (crossvalid ${EXAMPLE_SOURCE})
add_executable (example ${EXAMPLE_SOURCE})
if (${CMAKE_BUILD_TYPE} STREQUAL "RELEASE") # release : act install release
target_link_libraries(crossvalid agrum)
target_link_libraries(example agrum)
else() # debug : act install debug
target_link_libraries(crossvalid agrum-dbg)
target_link_libraries(example agrum-dbg)
endif()
install (TARGETS crossvalid DESTINATION bin)
install (TARGETS example DESTINATION bin)
#include <iostream>
#include <agrum/BN/BayesNet.h>
#include <iostream>
#include <agrum/learning/database/databaseVectInRAM.h>
#include <agrum/learning/database/databaseFromCSV.h>
#include <agrum/learning/database/databaseVectInRAM.h>
#include <agrum/learning/database/DBCellTranslator.h>
#include <agrum/learning/database/DBCellTranslators/cellTranslatorCompactIntId.h>
#include <agrum/learning/database/DBRowTranslatorSet.h>
#include <agrum/learning/database/filteredRowGenerators/rowGeneratorIdentity.h>
......@@ -12,89 +14,100 @@
#include <agrum/learning/aprioris/aprioriSmoothing.h>
#include <agrum/learning/constraints/structuralConstraintDiGraph.h>
#include <agrum/learning/constraints/structuralConstraintDAG.h>
#include <agrum/learning/constraints/structuralConstraintDiGraph.h>
#include <agrum/learning/structureUtils/graphChangesSelector4DiGraph.h>
#include <agrum/learning/structureUtils/graphChangesGenerator4DiGraph.h>
#include <agrum/learning/structureUtils/graphChangesSelector4DiGraph.h>
#include <agrum/learning/paramUtils/paramEstimatorML.h>
#include <agrum/learning/greedyHillClimbing.h>
#include <agrum/learning/paramUtils/paramEstimatorML.h>
int main(int argc, char *argv[]) {
std::cout<<"Simple K-Cross-Validation with aGrUM"
<<std::endl<<std::endl;
int main(int argc, char* argv[]) {
std::cout << "Simple K-Cross-Validation with aGrUM" << std::endl << std::endl;
if (argc<2) {
std::cout<<"Call : example K"
<<std::endl<<std::endl;
exit(0);
gum::Idx k;
if (argc < 2) {
std::cout << "Please type 'example K' for a K-fold cross validation"
<< std::endl;
std::cout << "...using K=3 by default" << std::endl << std::endl;
k = 3;
} else {
k = atoi(argv[1]);
}
int k = atoi(argv[1]);
std::string csvfilename("../asia.csv");
std::string csvfilename("../asia.csv");
gum::learning::DatabaseFromCSV database(csvfilename);
// K-fold Cross Validation Start
int n = database.content().size();
std::cout<<" K="<<k<<" on "<<csvfilename<<" (size:"<<n<<")"
<<std::endl<<std::endl;
std::cout << " K=" << k << " on " << csvfilename << " (size:" << n << ")"
<< std::endl
<< std::endl;
// K-fold Cross Validation
try {
// Structure Learning
const int nbCol = 8; // <-- has to be changed for each csv file (nb of variables in the csv)
const int nbCol =
8; // <-- has to be changed for each csv file (nb of variables in the csv)
// will parse the database once
auto translators = gum::learning::make_translators(
gum::learning::Create<gum::learning::CellTranslatorCompactIntId,
gum::learning::Col<0>, nbCol>());
auto generators =
gum::learning::make_generators(gum::learning::RowGeneratorIdentity());
gum::learning::DBRowTranslatorSet< gum::learning::CellTranslatorCompactIntId >
translators;
translators.insertTranslator(0, nbCol);
gum::learning::FilteredRowGeneratorSet< gum::learning::RowGeneratorIdentity >
generators;
generators.insertGenerator();
auto filter =
gum::learning::make_DB_row_filter(database, translators, generators);
std::vector<gum::Size> modalities = filter.modalities();
gum::learning::make_DB_row_filter(database, translators, generators);
std::vector< gum::Size > modalities = filter.modalities();
gum::learning::AprioriSmoothing<> apriori;
gum::learning::StructuralConstraintSetStatic<
gum::learning::StructuralConstraintDAG> struct_constraint;
gum::learning::StructuralConstraintDAG >
struct_constraint;
gum::learning::GraphChangesGenerator4DiGraph<decltype(struct_constraint)>
op_set(struct_constraint);
gum::learning::GraphChangesGenerator4DiGraph< decltype(struct_constraint) >
op_set(struct_constraint);
gum::learning::GreedyHillClimbing search;
int foldSize = database.content().size() / k;
for (int fold = 0; fold < k; fold++) {
Idx fold_deb = fold * foldSize;
Idx fold_end = fold_deb + foldSize - 1;
gum::Idx fold_deb = fold * foldSize;
gum::Idx fold_end = fold_deb + foldSize - 1;
std::cout << "+ LEARNING on [" << fold_deb << "," << fold_end << "] : ";
// LEARNING
filter.handler().setRange(fold_deb, fold_end);
gum::learning::ScoreBDeu<> score(filter, modalities, apriori);
gum::learning::ParamEstimatorML<> estimator(filter, modalities, apriori,
score.internalApriori());
gum::learning::GraphChangesSelector4DiGraph<
decltype(score), decltype(struct_constraint), decltype(op_set)>
selector(score, struct_constraint, op_set);
gum::Timer timer;
gum::BayesNet<float> bn =
search.learnBN(selector, estimator, database.variableNames(),
modalities, filter.translatorSet());
gum::learning::ScoreBDeu<> score(filter, modalities, apriori);
gum::learning::ParamEstimatorML<> estimator(
filter, modalities, apriori, score.internalApriori());
gum::learning::GraphChangesSelector4DiGraph< decltype(score),
decltype(struct_constraint),
decltype(op_set) >
selector(score, struct_constraint, op_set);
gum::Timer timer;
gum::BayesNet< double > bn = search.learnBN(selector,
estimator,
database.variableNames(),
modalities,
filter.translatorSet());
std::cout << timer.step() << "s ";
std::cout << bn.arcs().size() << " arcs" << std::endl;
// TESTING
gum::Instantiation I;
for (auto &name : filter.variableNames()) {
for (auto& name : filter.variableNames()) {
I.add(bn.variableFromName(name));
}
......@@ -117,8 +130,8 @@ int main(int argc, char *argv[]) {
}
if (fold_end + 1 < database.content().size() - 1) {
if (LL!=0.0) {
std::cout<<" U ";
if (LL != 0.0) {
std::cout << " U ";
}
std::cout << "[" << fold_end + 1 << "," << database.content().size() - 1
......@@ -138,7 +151,7 @@ int main(int argc, char *argv[]) {
std::cout << " : LL=" << LL << std::endl << std::endl;
}
} catch (const gum::Exception &ex) {
} catch (const gum::Exception& ex) {
std::cout << ex.errorContent() << std::endl;
}
}
Simple K-Cross-Validation with aGrUM
Please call 'example K' for a K-fold cross validation
...using K=3 by default
K=3 on ../asia.csv (size:10000)
+ LEARNING on [0,3332] : 0.107717s 11 arcs
TESTING on [3333,9999] : LL=-21656.8
+ LEARNING on [3333,6665] : 0.0877568s 11 arcs
TESTING on [0,3332] U [6666,9999] : LL=-21795.6
+ LEARNING on [6666,9998] : 0.0881687s 11 arcs
TESTING on [0,6665] : LL=-21617.2
......@@ -31,12 +31,12 @@ endif (aGrUM_FOUND)
file(GLOB OPERATIONS_WITH_POTENTIAL_SOURCE ${OPERATIONS_WITH_POTENTIAL_SOURCE_DIR}/*.cpp)
file(GLOB OPERATIONS_WITH_POTENTIAL_INCLUDE ${OPERATIONS_WITH_POTENTIAL_SOURCE_DIR}/*.h)
add_executable (ops_potentials ${OPERATIONS_WITH_POTENTIAL_SOURCE})
add_executable (example ${OPERATIONS_WITH_POTENTIAL_SOURCE})
if (${CMAKE_BUILD_TYPE} STREQUAL "RELEASE") # release : act install release
target_link_libraries(ops_potentials agrum)
target_link_libraries(example agrum)
else() # debug : act install debug
target_link_libraries(ops_potentials agrum-dbg)
target_link_libraries(example agrum-dbg)
endif()
install (TARGETS ops_potentials DESTINATION bin)
install (TARGETS example DESTINATION bin)
p1: <a:0|b:0> :: 0.1 /<a:1|b:0> :: 0.2 /<a:0|b:1> :: 0.3 /<a:1|b:1> :: 0.4
p2: <b:0|c:0> :: 0.3 /<b:1|c:0> :: 0.4 /<b:0|c:1> :: 0.1 /<b:1|c:1> :: 0.2
p1*p2: <b:0|c:0|a:0> :: 0.03 /<b:1|c:0|a:0> :: 0.12 /<b:0|c:1|a:0> :: 0.01 /<b:1|c:1|a:0> :: 0.06 /<b:0|c:0|a:1> :: 0.06 /<b:1|c:0|a:1> :: 0.16 /<b:0|c:1|a:1> :: 0.02 /<b:1|c:1|a:1> :: 0.08
p1+p2: <b:0|c:0|a:0> :: 0.4 /<b:1|c:0|a:0> :: 0.7 /<b:0|c:1|a:0> :: 0.2 /<b:1|c:1|a:0> :: 0.5 /<b:0|c:0|a:1> :: 0.5 /<b:1|c:0|a:1> :: 0.8 /<b:0|c:1|a:1> :: 0.3 /<b:1|c:1|a:1> :: 0.6
p1/p2: <b:0|c:0|a:0> :: 0.333333 /<b:1|c:0|a:0> :: 0.75 /<b:0|c:1|a:0> :: 1 /<b:1|c:1|a:0> :: 1.5 /<b:0|c:0|a:1> :: 0.666667 /<b:1|c:0|a:1> :: 1 /<b:0|c:1|a:1> :: 2 /<b:1|c:1|a:1> :: 2
p1-p2: <b:0|c:0|a:0> :: -0.2 /<b:1|c:0|a:0> :: -0.1 /<b:0|c:1|a:0> :: 0 /<b:1|c:1|a:0> :: 0.1 /<b:0|c:0|a:1> :: -0.1 /<b:1|c:0|a:1> :: 0 /<b:0|c:1|a:1> :: 0.1 /<b:1|c:1|a:1> :: 0.2
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment