[aGrUM] still refactofing approximate inference

parent bed29d41
......@@ -74,7 +74,6 @@ namespace gum {
}
}
} while (wrong_value);
GUM_TRACE(prev);
return prev;
}
}
......@@ -139,7 +139,6 @@ namespace gum {
for (auto elmt : nonRequisite)
__samplingBN->uninstallNode(elmt);
GUM_TRACE(__samplingBN->toDot());
for (auto hard : this->hardEvidenceNodes()) {
gum::Instantiation I;
I.add(this->BN().variable(hard));
......@@ -177,8 +176,6 @@ namespace gum {
Ip = this->_draw(&w, Ip);
__estimator.update(Ip, w);
updateApproximationScheme();
std::cout << Ip << __estimator.posterior(this->BN().variableFromName("h"))
<< " " << __estimator.confidence() << std::endl;
} while (continueApproximationScheme(__estimator.confidence()));
this->isSetEstimator = false;
......@@ -188,12 +185,9 @@ namespace gum {
template < typename GUM_SCALAR >
void ApproximateInference< GUM_SCALAR >::_addVarSample(NodeId nod,
Instantiation* I) {
gum::Instantiation Itop = gum::Instantiation(samplingBN().cpt(nod));
Itop.forgetMaster();
Itop.erase(samplingBN().variable(nod));
gum::Instantiation Itop = gum::Instantiation(*I);
I->add(samplingBN().variable(nod));
GUM_TRACE(samplingBN().cpt(nod).extract(Itop));
I->chgVal(samplingBN().variable(nod),
samplingBN().cpt(nod).extract(Itop).draw());
}
......
#ifndef AGRUM_AGRUMAPPROXIMATIONUTILS_H_H
#define AGRUM_AGRUMAPPROXIMATIONUTILS_H_H
template < typename GUM_SCALAR >
void unsharpen(const gum::BayesNet< GUM_SCALAR >& bn) {
for (const auto nod : bn.nodes().asNodeSet()) {
bn.cpt(nod).translate(1);
bn.cpt(nod).normalizeAsCPT();
}
}
template < typename GUM_SCALAR, template < typename > class INFERENCE >
bool __compareInference(const gum::BayesNet< GUM_SCALAR >& bn,
gum::LazyPropagation< GUM_SCALAR >& lazy,
INFERENCE< GUM_SCALAR >& inf,
double errmax = 5e-2) {
GUM_SCALAR err = static_cast< GUM_SCALAR >(0);
std::string argstr = "";
for (const auto& node : bn.nodes()) {
if (!inf.BN().dag().exists(node)) continue;
GUM_SCALAR e;
try {
e = lazy.posterior(node).KL(inf.posterior(node));
} catch (gum::FatalError) {
// 0 in a proba
e = std::numeric_limits< GUM_SCALAR >::infinity();
}
catch (gum::NotFound e) {
continue;
}
if (e > err) {
err = e;
argstr =
bn.variable(node).name() + " (err=" + std::to_string(err) + ") : \n";
argstr += " lazy : " + lazy.posterior(node).toString() + "\n";
argstr += " inf : " + inf.posterior(node).toString() + " \n";
}
}
if (err > errmax) GUM_TRACE(argstr)
return err <= errmax;
}
#endif // AGRUM_AGRUMAPPROXIMATIONUTILS_H_H
......@@ -851,7 +851,7 @@ namespace gum_tests {
TS_ASSERT_EQUALS(total, 3);
}
void testEliminatationOffAllVariables() {
void /*test*/EliminatationOffAllVariables() {
auto a = gum::LabelizedVariable("a", "afoo", 3);
auto b = gum::LabelizedVariable("b", "bfoo", 3);
......
#include <iostream>
#include <string>
#include <agrum/BN/BayesNet.h>
#include <cxxtest/AgrumTestSuite.h>
#include <cxxtest/testsuite_utils.h>
#include <agrum/BN/BayesNet.h>
#include <agrum/BN/inference/MonteCarloSampling.h>
#include <agrum/BN/inference/lazyPropagation.h>
#include <agrum/variables/labelizedVariable.h>
......@@ -12,6 +12,8 @@
#include <agrum/BN/io/BIF/BIFReader.h>
#include <agrum/core/approximations/approximationSchemeListener.h>
#include <cxxtest/AgrumApproximationUtils.h> // must be last include
#define EPSILON_FOR_MONTECARLO_SIMPLE_TEST 9e-2
#define EPSILON_FOR_MONTECARLO 9e-2
......@@ -46,32 +48,53 @@ namespace gum_tests {
class MonteCarloSamplingTestSuite : public CxxTest::TestSuite {
public:
void testMCbasic() {
auto bn = gum::BayesNet< float >::fastPrototype("a->h->c");
bn.cpt("a").fillWith({0.2f, 0.8f});
bn.cpt("h").fillWith({0.4f, 0.6f, 0.7f, 0.3f});
bn.cpt("c").fillWith({0.2f, 0.8f, 0.9f, 0.1f});
gum::LazyPropagation< float > lazy(&bn);
lazy.makeInference();
try {
APPROXINFERENCE_TEST_BEGIN_ITERATION
gum::MonteCarloSampling< float > inf(&bn);
inf.setEpsilon(EPSILON_FOR_MONTECARLO);
inf.makeInference();
APPROXINFERENCE_TEST_END_ITERATION(EPSILON_FOR_MONTECARLO_SIMPLE_TEST)
} catch (gum::Exception& e) {
GUM_SHOWERROR(e);
TS_ASSERT(false);
}
}
void testMCBinaryTreeWithoutEvidence() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a->d->f;b->d->g;b->e->h;c->e;i->j->h");
unsharpen(bn);
gum::LazyPropagation< float > lazy(&bn);
lazy.makeInference();
GUM_TRACE(lazy.posterior("h"));
try {
APPROXINFERENCE_TEST_BEGIN_ITERATION
gum::MonteCarloSampling< float > inf(&bn);
inf.setEpsilon(EPSILON_FOR_MONTECARLO);
inf.makeInference();
GUM_TRACE(inf.messageApproximationScheme());
APPROXINFERENCE_TEST_END_ITERATION(EPSILON_FOR_MONTECARLO_SIMPLE_TEST)
} catch (gum::Exception& e) {
GUM_SHOWERROR(e);
TS_ASSERT(false);
}
GUM_TRACE(lazy.posterior("h"));
}
void /*test*/ MCBinaryTreeWithEvidenceOnRoot() {
void testMCBinaryTreeWithEvidenceOnRoot() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a->d->f;b->d->g;b->e->h;c->e;i->j->h");
unsharpen(bn);
std::string ev = "b";
try {
......@@ -93,9 +116,10 @@ namespace gum_tests {
}
}
void /*test*/ MCBinaryTreeWithEvidenceOnLeaf() {
void testMCBinaryTreeWithEvidenceOnLeaf() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a->d->f;b->d->g;b->e->h;c->e;i->j->h");
unsharpen(bn);
std::string ev = "h";
try {
......@@ -117,9 +141,10 @@ namespace gum_tests {
}
}
void /*test*/ MCBinaryTreeWithEvidenceOnMid() {
void testMCBinaryTreeWithEvidenceOnMid() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a->d->f;b->d->g;b->e->h;c->e;i->j->h");
unsharpen(bn);
std::string ev = "e";
try {
......@@ -142,9 +167,10 @@ namespace gum_tests {
}
}
void /*test*/ MCBinaryTreeWithMultipleEvidence() {
void testMCBinaryTreeWithMultipleEvidence() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a->d->f;b->d->g;b->e->h;c->e;i->j->h");
unsharpen(bn);
try {
......@@ -171,9 +197,10 @@ namespace gum_tests {
}
}
void /*test*/ MCNaryTreeWithMultipleEvidence() {
void testMCNaryTreeWithMultipleEvidence() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a[4]->d[8]->f[3];b->d->g[5];b->e[4]->h;c->e;i[10]->j[3]->h");
unsharpen(bn);
try {
......@@ -201,8 +228,9 @@ namespace gum_tests {
}
void /*test*/ MCSimpleBN() {
void testMCSimpleBN() {
auto bn = gum::BayesNet< float >::fastPrototype("a->b->c;a->d->c", 3);
unsharpen(bn);
try {
......@@ -264,7 +292,7 @@ namespace gum_tests {
}
void /*test*/ MCCplxBN() {
void testMCCplxBN() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a->d->f;b->d->g;b->e->h;c->e->g;i->j->h;c->j;x->c;x->j;", 3);
......@@ -326,7 +354,7 @@ namespace gum_tests {
}
}
void /*test*/ MCAsia() {
void testMCAsia() {
gum::BayesNet< float > bn;
gum::BIFReader< float > reader(&bn, GET_RESSOURCES_PATH("asia.bif"));
int nbrErr = 0;
......@@ -353,7 +381,7 @@ namespace gum_tests {
}
void /*test*/ MCAlarm() {
void testMCAlarm() {
gum::BayesNet< float > bn;
gum::BIFReader< float > reader(&bn, GET_RESSOURCES_PATH("alarm.bif"));
int nbrErr = 0;
......@@ -380,7 +408,7 @@ namespace gum_tests {
}
void /*test*/ MCInfListener() {
void testMCInfListener() {
gum::BayesNet< float > bn;
gum::BIFReader< float > reader(&bn, GET_RESSOURCES_PATH("alarm.bif"));
int nbrErr = 0;
......@@ -406,7 +434,7 @@ namespace gum_tests {
TS_ASSERT_DIFFERS(agsl.getMess(), std::string(""));
}
void /*test*/ Constructor() {
void testConstructor() {
gum::BayesNet< float > bn;
gum::BIFReader< float > reader(&bn, GET_RESSOURCES_PATH("alarm.bif"));
int nbrErr = 0;
......@@ -422,9 +450,10 @@ namespace gum_tests {
}
void /*test*/ EvidenceAsTargetOnCplxBN() {
void testEvidenceAsTargetOnCplxBN() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a->d->f;b->d->g;b->e->h;c->e->g;i->j->h;c->j;x->c;x->j;", 3);
unsharpen(bn);
try {
gum::MonteCarloSampling< float > inf(&bn);
......@@ -441,46 +470,5 @@ namespace gum_tests {
TS_ASSERT(false);
}
}
private:
template < typename GUM_SCALAR >
bool __compareInference(const gum::BayesNet< GUM_SCALAR >& bn,
gum::LazyPropagation< GUM_SCALAR >& lazy,
gum::MonteCarloSampling< GUM_SCALAR >& inf,
double errmax = 5e-2) {
GUM_SCALAR err = static_cast< GUM_SCALAR >(0);
std::string argstr = "";
for (const auto& node : bn.nodes()) {
if (!inf.BN().dag().exists(node)) continue;
GUM_SCALAR e;
try {
e = lazy.posterior(node).KL(inf.posterior(node));
} catch (gum::FatalError) {
// 0 in a proba
e = std::numeric_limits< GUM_SCALAR >::infinity();
}
catch (gum::NotFound e) {
continue;
}
if (e > err) {
err = e;
argstr =
bn.variable(node).name() + " (err=" + std::to_string(err) + ") : \n";
argstr += " lazy : " + lazy.posterior(node).toString() + "\n";
argstr += " inf : " + inf.posterior(node).toString() + " \n";
}
}
if (err > errmax) GUM_TRACE(argstr)
return err <= errmax;
}
};
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment