[aGrUM] fixing cross-validation example. Updating apps/CMakemList.txt. adding example.stdout

parent eef6fdd9
...@@ -4,7 +4,6 @@ cmake_minimum_required(VERSION 2.8) ...@@ -4,7 +4,6 @@ cmake_minimum_required(VERSION 2.8)
# -DCMAKE_BUILD_TYPE=DEBUG|RELEASE # -DCMAKE_BUILD_TYPE=DEBUG|RELEASE
# -G "MinGW Makefiles" # -G "MinGW Makefiles"
# cmake -DCMAKE_BUILD_TYPE=DEBUG # cmake -DCMAKE_BUILD_TYPE=DEBUG
# or # or
# cmake -DCMAKE_BUILD_TYPE=RELEASE # cmake -DCMAKE_BUILD_TYPE=RELEASE
...@@ -32,12 +31,12 @@ endif (aGrUM_FOUND) ...@@ -32,12 +31,12 @@ endif (aGrUM_FOUND)
file(GLOB HOWTO_BUILD_SIMPLE_BN_SOURCE ${HOWTO_BUILD_SIMPLE_BN_SOURCE_DIR}/*.cpp) file(GLOB HOWTO_BUILD_SIMPLE_BN_SOURCE ${HOWTO_BUILD_SIMPLE_BN_SOURCE_DIR}/*.cpp)
file(GLOB HOWTO_BUILD_SIMPLE_BN_INCLUDE ${HOWTO_BUILD_SIMPLE_BN_SOURCE_DIR}/*.h) file(GLOB HOWTO_BUILD_SIMPLE_BN_INCLUDE ${HOWTO_BUILD_SIMPLE_BN_SOURCE_DIR}/*.h)
add_executable (buildSimpleBN ${HOWTO_BUILD_SIMPLE_BN_SOURCE}) add_executable (example ${HOWTO_BUILD_SIMPLE_BN_SOURCE})
if (${CMAKE_BUILD_TYPE} STREQUAL "RELEASE") # release : act install release if (${CMAKE_BUILD_TYPE} STREQUAL "RELEASE") # release : act install release
target_link_libraries(buildSimpleBN agrum) target_link_libraries(example agrum)
else() # debug : act install debug else() # debug : act install debug
target_link_libraries(buildSimpleBN agrum-dbg) target_link_libraries(example agrum-dbg)
endif() endif()
install (TARGETS buildSimpleBN DESTINATION bin) install (TARGETS example DESTINATION bin)
== Created BN ==
================
BN{nodes: 3, arcs: 2, domainSize: 8, parameters: 12, compression ratio: -50% }
== dot version of the BN ==
===========================
digraph "no_name" {
graph [bgcolor=transparent,label="no_name"];
node [style=filled fillcolor="#ffffaa"];
"foo" [comment="0:foo<0,1>"];
"bar" [comment="1:bar<0,1>"];
"qux" [comment="2:qux<0,1>"];
"foo" -> "bar";
"qux" -> "bar";
}
== cpt for 'bar' variable ==
============================
<bar:0|foo:0|qux:0> :: 0.2 /<bar:1|foo:0|qux:0> :: 0.8 /<bar:0|foo:1|qux:0> :: 0.3 /<bar:1|foo:1|qux:0> :: 0.7 /<bar:0|foo:0|qux:1> :: 0.9 /<bar:1|foo:0|qux:1> :: 0.1 /<bar:0|foo:1|qux:1> :: 0.5 /<bar:1|foo:1|qux:1> :: 0.5
...@@ -33,12 +33,12 @@ endif (aGrUM_FOUND) ...@@ -33,12 +33,12 @@ endif (aGrUM_FOUND)
file(GLOB EXAMPLE_SOURCE ${EXAMPLE_SOURCE_DIR}/*.cpp) file(GLOB EXAMPLE_SOURCE ${EXAMPLE_SOURCE_DIR}/*.cpp)
file(GLOB EXAMPLE_INCLUDE ${EXAMPLE_SOURCE_DIR}/*.h) file(GLOB EXAMPLE_INCLUDE ${EXAMPLE_SOURCE_DIR}/*.h)
add_executable (crossvalid ${EXAMPLE_SOURCE}) add_executable (example ${EXAMPLE_SOURCE})
if (${CMAKE_BUILD_TYPE} STREQUAL "RELEASE") # release : act install release if (${CMAKE_BUILD_TYPE} STREQUAL "RELEASE") # release : act install release
target_link_libraries(crossvalid agrum) target_link_libraries(example agrum)
else() # debug : act install debug else() # debug : act install debug
target_link_libraries(crossvalid agrum-dbg) target_link_libraries(example agrum-dbg)
endif() endif()
install (TARGETS crossvalid DESTINATION bin) install (TARGETS example DESTINATION bin)
#include <iostream>
#include <agrum/BN/BayesNet.h> #include <agrum/BN/BayesNet.h>
#include <iostream>
#include <agrum/learning/database/databaseVectInRAM.h>
#include <agrum/learning/database/databaseFromCSV.h> #include <agrum/learning/database/databaseFromCSV.h>
#include <agrum/learning/database/databaseVectInRAM.h>
#include <agrum/learning/database/DBCellTranslator.h>
#include <agrum/learning/database/DBCellTranslators/cellTranslatorCompactIntId.h> #include <agrum/learning/database/DBCellTranslators/cellTranslatorCompactIntId.h>
#include <agrum/learning/database/DBRowTranslatorSet.h>
#include <agrum/learning/database/filteredRowGenerators/rowGeneratorIdentity.h> #include <agrum/learning/database/filteredRowGenerators/rowGeneratorIdentity.h>
...@@ -12,89 +14,100 @@ ...@@ -12,89 +14,100 @@
#include <agrum/learning/aprioris/aprioriSmoothing.h> #include <agrum/learning/aprioris/aprioriSmoothing.h>
#include <agrum/learning/constraints/structuralConstraintDiGraph.h>
#include <agrum/learning/constraints/structuralConstraintDAG.h> #include <agrum/learning/constraints/structuralConstraintDAG.h>
#include <agrum/learning/constraints/structuralConstraintDiGraph.h>
#include <agrum/learning/structureUtils/graphChangesSelector4DiGraph.h>
#include <agrum/learning/structureUtils/graphChangesGenerator4DiGraph.h> #include <agrum/learning/structureUtils/graphChangesGenerator4DiGraph.h>
#include <agrum/learning/structureUtils/graphChangesSelector4DiGraph.h>
#include <agrum/learning/paramUtils/paramEstimatorML.h>
#include <agrum/learning/greedyHillClimbing.h> #include <agrum/learning/greedyHillClimbing.h>
#include <agrum/learning/paramUtils/paramEstimatorML.h>
int main(int argc, char *argv[]) { int main(int argc, char* argv[]) {
std::cout<<"Simple K-Cross-Validation with aGrUM" std::cout << "Simple K-Cross-Validation with aGrUM" << std::endl << std::endl;
<<std::endl<<std::endl;
if (argc<2) { gum::Idx k;
std::cout<<"Call : example K" if (argc < 2) {
<<std::endl<<std::endl; std::cout << "Please type 'example K' for a K-fold cross validation"
exit(0); << std::endl;
std::cout << "...using K=3 by default" << std::endl << std::endl;
k = 3;
} else {
k = atoi(argv[1]);
} }
int k = atoi(argv[1]); std::string csvfilename("../asia.csv");
std::string csvfilename("../asia.csv");
gum::learning::DatabaseFromCSV database(csvfilename); gum::learning::DatabaseFromCSV database(csvfilename);
// K-fold Cross Validation Start // K-fold Cross Validation Start
int n = database.content().size(); int n = database.content().size();
std::cout<<" K="<<k<<" on "<<csvfilename<<" (size:"<<n<<")" std::cout << " K=" << k << " on " << csvfilename << " (size:" << n << ")"
<<std::endl<<std::endl; << std::endl
<< std::endl;
// K-fold Cross Validation // K-fold Cross Validation
try { try {
// Structure Learning // Structure Learning
const int nbCol = 8; // <-- has to be changed for each csv file (nb of variables in the csv) const int nbCol =
8; // <-- has to be changed for each csv file (nb of variables in the csv)
// will parse the database once // will parse the database once
auto translators = gum::learning::make_translators( gum::learning::DBRowTranslatorSet< gum::learning::CellTranslatorCompactIntId >
gum::learning::Create<gum::learning::CellTranslatorCompactIntId, translators;
gum::learning::Col<0>, nbCol>()); translators.insertTranslator(0, nbCol);
auto generators =
gum::learning::make_generators(gum::learning::RowGeneratorIdentity()); gum::learning::FilteredRowGeneratorSet< gum::learning::RowGeneratorIdentity >
generators;
generators.insertGenerator();
auto filter = auto filter =
gum::learning::make_DB_row_filter(database, translators, generators); gum::learning::make_DB_row_filter(database, translators, generators);
std::vector<gum::Size> modalities = filter.modalities();
std::vector< gum::Size > modalities = filter.modalities();
gum::learning::AprioriSmoothing<> apriori; gum::learning::AprioriSmoothing<> apriori;
gum::learning::StructuralConstraintSetStatic< gum::learning::StructuralConstraintSetStatic<
gum::learning::StructuralConstraintDAG> struct_constraint; gum::learning::StructuralConstraintDAG >
struct_constraint;
gum::learning::GraphChangesGenerator4DiGraph<decltype(struct_constraint)> gum::learning::GraphChangesGenerator4DiGraph< decltype(struct_constraint) >
op_set(struct_constraint); op_set(struct_constraint);
gum::learning::GreedyHillClimbing search; gum::learning::GreedyHillClimbing search;
int foldSize = database.content().size() / k; int foldSize = database.content().size() / k;
for (int fold = 0; fold < k; fold++) { for (int fold = 0; fold < k; fold++) {
Idx fold_deb = fold * foldSize; gum::Idx fold_deb = fold * foldSize;
Idx fold_end = fold_deb + foldSize - 1; gum::Idx fold_end = fold_deb + foldSize - 1;
std::cout << "+ LEARNING on [" << fold_deb << "," << fold_end << "] : "; std::cout << "+ LEARNING on [" << fold_deb << "," << fold_end << "] : ";
// LEARNING // LEARNING
filter.handler().setRange(fold_deb, fold_end); filter.handler().setRange(fold_deb, fold_end);
gum::learning::ScoreBDeu<> score(filter, modalities, apriori); gum::learning::ScoreBDeu<> score(filter, modalities, apriori);
gum::learning::ParamEstimatorML<> estimator(filter, modalities, apriori, gum::learning::ParamEstimatorML<> estimator(
score.internalApriori()); filter, modalities, apriori, score.internalApriori());
gum::learning::GraphChangesSelector4DiGraph< gum::learning::GraphChangesSelector4DiGraph< decltype(score),
decltype(score), decltype(struct_constraint), decltype(op_set)> decltype(struct_constraint),
selector(score, struct_constraint, op_set); decltype(op_set) >
gum::Timer timer; selector(score, struct_constraint, op_set);
gum::BayesNet<float> bn = gum::Timer timer;
search.learnBN(selector, estimator, database.variableNames(), gum::BayesNet< double > bn = search.learnBN(selector,
modalities, filter.translatorSet()); estimator,
database.variableNames(),
modalities,
filter.translatorSet());
std::cout << timer.step() << "s "; std::cout << timer.step() << "s ";
std::cout << bn.arcs().size() << " arcs" << std::endl; std::cout << bn.arcs().size() << " arcs" << std::endl;
// TESTING // TESTING
gum::Instantiation I; gum::Instantiation I;
for (auto &name : filter.variableNames()) { for (auto& name : filter.variableNames()) {
I.add(bn.variableFromName(name)); I.add(bn.variableFromName(name));
} }
...@@ -117,8 +130,8 @@ int main(int argc, char *argv[]) { ...@@ -117,8 +130,8 @@ int main(int argc, char *argv[]) {
} }
if (fold_end + 1 < database.content().size() - 1) { if (fold_end + 1 < database.content().size() - 1) {
if (LL!=0.0) { if (LL != 0.0) {
std::cout<<" U "; std::cout << " U ";
} }
std::cout << "[" << fold_end + 1 << "," << database.content().size() - 1 std::cout << "[" << fold_end + 1 << "," << database.content().size() - 1
...@@ -138,7 +151,7 @@ int main(int argc, char *argv[]) { ...@@ -138,7 +151,7 @@ int main(int argc, char *argv[]) {
std::cout << " : LL=" << LL << std::endl << std::endl; std::cout << " : LL=" << LL << std::endl << std::endl;
} }
} catch (const gum::Exception &ex) { } catch (const gum::Exception& ex) {
std::cout << ex.errorContent() << std::endl; std::cout << ex.errorContent() << std::endl;
} }
} }
Simple K-Cross-Validation with aGrUM
Please call 'example K' for a K-fold cross validation
...using K=3 by default
K=3 on ../asia.csv (size:10000)
+ LEARNING on [0,3332] : 0.107717s 11 arcs
TESTING on [3333,9999] : LL=-21656.8
+ LEARNING on [3333,6665] : 0.0877568s 11 arcs
TESTING on [0,3332] U [6666,9999] : LL=-21795.6
+ LEARNING on [6666,9998] : 0.0881687s 11 arcs
TESTING on [0,6665] : LL=-21617.2
...@@ -31,12 +31,12 @@ endif (aGrUM_FOUND) ...@@ -31,12 +31,12 @@ endif (aGrUM_FOUND)
file(GLOB OPERATIONS_WITH_POTENTIAL_SOURCE ${OPERATIONS_WITH_POTENTIAL_SOURCE_DIR}/*.cpp) file(GLOB OPERATIONS_WITH_POTENTIAL_SOURCE ${OPERATIONS_WITH_POTENTIAL_SOURCE_DIR}/*.cpp)
file(GLOB OPERATIONS_WITH_POTENTIAL_INCLUDE ${OPERATIONS_WITH_POTENTIAL_SOURCE_DIR}/*.h) file(GLOB OPERATIONS_WITH_POTENTIAL_INCLUDE ${OPERATIONS_WITH_POTENTIAL_SOURCE_DIR}/*.h)
add_executable (ops_potentials ${OPERATIONS_WITH_POTENTIAL_SOURCE}) add_executable (example ${OPERATIONS_WITH_POTENTIAL_SOURCE})
if (${CMAKE_BUILD_TYPE} STREQUAL "RELEASE") # release : act install release if (${CMAKE_BUILD_TYPE} STREQUAL "RELEASE") # release : act install release
target_link_libraries(ops_potentials agrum) target_link_libraries(example agrum)
else() # debug : act install debug else() # debug : act install debug
target_link_libraries(ops_potentials agrum-dbg) target_link_libraries(example agrum-dbg)
endif() endif()
install (TARGETS ops_potentials DESTINATION bin) install (TARGETS example DESTINATION bin)
p1: <a:0|b:0> :: 0.1 /<a:1|b:0> :: 0.2 /<a:0|b:1> :: 0.3 /<a:1|b:1> :: 0.4
p2: <b:0|c:0> :: 0.3 /<b:1|c:0> :: 0.4 /<b:0|c:1> :: 0.1 /<b:1|c:1> :: 0.2
p1*p2: <b:0|c:0|a:0> :: 0.03 /<b:1|c:0|a:0> :: 0.12 /<b:0|c:1|a:0> :: 0.01 /<b:1|c:1|a:0> :: 0.06 /<b:0|c:0|a:1> :: 0.06 /<b:1|c:0|a:1> :: 0.16 /<b:0|c:1|a:1> :: 0.02 /<b:1|c:1|a:1> :: 0.08
p1+p2: <b:0|c:0|a:0> :: 0.4 /<b:1|c:0|a:0> :: 0.7 /<b:0|c:1|a:0> :: 0.2 /<b:1|c:1|a:0> :: 0.5 /<b:0|c:0|a:1> :: 0.5 /<b:1|c:0|a:1> :: 0.8 /<b:0|c:1|a:1> :: 0.3 /<b:1|c:1|a:1> :: 0.6
p1/p2: <b:0|c:0|a:0> :: 0.333333 /<b:1|c:0|a:0> :: 0.75 /<b:0|c:1|a:0> :: 1 /<b:1|c:1|a:0> :: 1.5 /<b:0|c:0|a:1> :: 0.666667 /<b:1|c:0|a:1> :: 1 /<b:0|c:1|a:1> :: 2 /<b:1|c:1|a:1> :: 2
p1-p2: <b:0|c:0|a:0> :: -0.2 /<b:1|c:0|a:0> :: -0.1 /<b:0|c:1|a:0> :: 0 /<b:1|c:1|a:0> :: 0.1 /<b:0|c:0|a:1> :: -0.1 /<b:1|c:0|a:1> :: 0 /<b:0|c:1|a:1> :: 0.1 /<b:1|c:1|a:1> :: 0.2
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment