[aGrUM] apparently 2 bugs to fix in BNLearner with DiscretizedVariable

parent 00a69bbb
Pipeline #22855153 passed with stages
in 122 minutes and 25 seconds
......@@ -151,7 +151,7 @@ namespace gum {
DiscretizedVariable< T_TICKS >&
DiscretizedVariable< T_TICKS >::addTick(const T_TICKS& aTick) {
if (isTick(aTick)) {
GUM_ERROR(DefaultInLabel, "Tick already used for this variable ");
GUM_ERROR(DefaultInLabel, "Tick '"<<aTick<<"' already used for variable "<<name());
}
if (__ticks_size == __ticks.size()) { // streching __ticks if necessary
......
......@@ -19,6 +19,8 @@
***************************************************************************/
// floating point env
#include <cfenv>
#include <vector>
#include <string>
#include <cxxtest/AgrumTestSuite.h>
#include <cxxtest/testsuite_utils.h>
......@@ -803,6 +805,151 @@ namespace gum_tests {
TS_ASSERT(nb == 1);
}
};
void testBugDoumenc() {
gum::BayesNet< double > templ;
std::vector< std::string > varBool{"S",
"DEP",
"TM",
"TE",
"TV",
"PSY",
"AL",
"PT",
"HYP",
"FRE",
"PC",
"C",
"MN",
"AM",
"PR",
"AR",
"DFM"}; // les vraibles booléennes du RB
std::vector< std::string > varTer{
"NBC",
"MED",
"DEM",
"SP"}; // les variables pouvant prendre 3 valeurs possibles du RB
std::vector< std::string > varContinuous{
"A", "ADL"}; // les variables continues du RB
std::vector< gum::NodeId > nodeList; // Liste des noeuds du RB
for (auto var : varBool)
nodeList.push_back(templ.add(gum::LabelizedVariable(
var, var, 2))); // Ajout des variables booléennes à la liste des noeuds
for (auto var : varTer)
nodeList.push_back(templ.add(gum::LabelizedVariable(
var, var, 3))); // Ajout des variables ternaires à la liste des noeuds
gum::DiscretizedVariable< double > A("A", "A");
for (int i = 60; i <= 105; i += 5) {
A.addTick(double(i));
}
nodeList.push_back(templ.add(A)); // Ajout de la variable Age allant de 60
// à 100 ans à la liste des noeuds
// Ajout de la variable ADL allant de 0 à 6 à la liste des noeuds
nodeList.push_back(templ.add(gum::RangeVariable("ADL", "ADL", 0, 6)));
// Création du noeud central NRC (niveau de risque de chute)
gum::LabelizedVariable NRC("NRC", "NRC", 0);
NRC.addLabel("faible");
NRC.addLabel("modere");
NRC.addLabel("eleve");
auto iNRC = templ.add(NRC);
// Création des arcs partant du noeud NRC vers les autres noeuds
for (auto node : nodeList) {
templ.addArc(iNRC, node);
}
GUM_TRACE("building learner");
try {
gum::learning::BNLearner< double > learner(
GET_RESSOURCES_PATH("bugDoumenc.csv"), templ);
GUM_TRACE("learning initiated");
auto bn = learner.learnParameters(templ);
GUM_TRACE("learning done");
} catch (gum::Exception& e) { GUM_SHOWERROR(e); }
}
void testBugDoumencWithInt() {
gum::BayesNet< double > templ;
std::vector< std::string > varBool{"S",
"DEP",
"TM",
"TE",
"TV",
"PSY",
"AL",
"PT",
"HYP",
"FRE",
"PC",
"C",
"MN",
"AM",
"PR",
"AR",
"DFM"}; // les vraibles booléennes du RB
std::vector< std::string > varTer{
"NBC",
"MED",
"DEM",
"SP"}; // les variables pouvant prendre 3 valeurs possibles du RB
std::vector< std::string > varContinuous{
"A", "ADL"}; // les variables continues du RB
std::vector< gum::NodeId > nodeList; // Liste des noeuds du RB
for (auto var : varBool)
nodeList.push_back(templ.add(gum::LabelizedVariable(
var, var, 2))); // Ajout des variables booléennes à la liste des noeuds
for (auto var : varTer)
nodeList.push_back(templ.add(gum::LabelizedVariable(
var, var, 3))); // Ajout des variables ternaires à la liste des noeuds
gum::DiscretizedVariable< int > A("A", "A");
for (int i = 60; i <= 105; i += 5) {
A.addTick(i);
}
nodeList.push_back(templ.add(A)); // Ajout de la variable Age allant de 60
// à 100 ans à la liste des noeuds
// Ajout de la variable ADL allant de 0 à 6 à la liste des noeuds
nodeList.push_back(templ.add(gum::RangeVariable("ADL", "ADL", 0, 6)));
// Création du noeud central NRC (niveau de risque de chute)
gum::LabelizedVariable NRC("NRC", "NRC", 0);
NRC.addLabel("faible");
NRC.addLabel("modere");
NRC.addLabel("eleve");
auto iNRC = templ.add(NRC);
// Création des arcs partant du noeud NRC vers les autres noeuds
for (auto node : nodeList) {
templ.addArc(iNRC, node);
}
GUM_TRACE("building learner");
try {
gum::learning::BNLearner< double > learner(
GET_RESSOURCES_PATH("bugDoumenc.csv"), templ);
GUM_TRACE("learning initiated");
auto bn = learner.learnParameters(templ);
GUM_TRACE("learning done");
} catch (gum::Exception& e) { GUM_SHOWERROR(e); }
}
};
} /* namespace gum_tests */
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 3.0.12
* Version 4.0.0
*
* This file is not intended to be easily readable and contains a number of
* coding conventions designed to improve portability and efficiency. Do not make
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment