[pyAgrum] testing sampling inference leads to many changes

parent 09293318
......@@ -23,6 +23,7 @@
* BayesNetInference.
*/
#include <agrum/BN/inference/tools/BayesNetInference.h>
namespace gum {
......
......@@ -33,7 +33,7 @@
namespace gum {
template <typename GUM_SCALAR>
template < typename GUM_SCALAR >
class Estimator {
public:
......@@ -52,7 +52,7 @@ namespace gum {
/**
* Constructor with Bayesian Network
*/
Estimator( const IBayesNet<GUM_SCALAR>* bn );
Estimator(const IBayesNet< GUM_SCALAR >* bn);
/* Destructor */
~Estimator();
......@@ -65,14 +65,14 @@ namespace gum {
* sets the estimator object with 0-filled vectors corresponding
* to each non-evidence node
*/
void setFromBN( const IBayesNet<GUM_SCALAR>* bn, const NodeSet& hardEvidence );
void setFromBN(const IBayesNet< GUM_SCALAR >* bn, const NodeSet& hardEvidence);
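/* Minimal sketch of the initialisation described above (an assumption, not
 * the verified body): one zero-filled slot per label of every non-evidence
 * node, using aGrUM's usual accessors, e.g.
 *   for (const auto node : bn->nodes())
 *     if (!hardEvidence.exists(node))
 *       _estimator.insert(bn->variable(node).name(),
 *                         std::vector< GUM_SCALAR >(
 *                           bn->variable(node).domainSize(), GUM_SCALAR(0)));
 */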
/**
* sets the estimator object with posteriors obtained by LoopyBeliefPropagation
*/
void setFromLBP( LoopyBeliefPropagation<GUM_SCALAR>* lbp,
const NodeSet& hardEvidence,
GUM_SCALAR virtualLBPSize );
void setFromLBP(LoopyBeliefPropagation< GUM_SCALAR >* lbp,
const NodeSet& hardEvidence,
GUM_SCALAR virtualLBPSize);
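/* As described above, the LBP result is treated as if it were
 * virtualLBPSize pre-drawn samples; presumably, for each non-evidence node
 * and each value val:
 *   _estimator[name][val] = posterior(node)[val] * virtualLBPSize
 */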
/** @} */
/// computes the maximum length of the confidence interval for each possible value
......@@ -89,7 +89,7 @@ namespace gum {
*
* adds the sample weight to each node's given value in the estimator
*/
void update( Instantiation I, GUM_SCALAR w );
void update(Instantiation I, GUM_SCALAR w);
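/* Minimal sketch of the accumulation described above (assuming I carries a
 * value for every estimated variable; not the verified body):
 *   for each name in _estimator:
 *     _estimator[name][I.val(variable named name)] += w;
 *   _wtotal += w;  ++_ntotal;
 */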
/// returns the posterior of a node
/**
......@@ -101,12 +101,18 @@ namespace gum {
*
* @throw NotFound if the variable's node is not in the estimator.
*/
const Potential<GUM_SCALAR>& posterior( const DiscreteVariable& var );
const Potential< GUM_SCALAR >& posterior(const DiscreteVariable& var);
/// resets the estimator to an empty state
/** this function removes all the statistics in order to restart the
* computations.
*/
void clear();
private:
/// estimator represented by a hashtable mapping each variable name to a
/// vector of cumulative sample weights
HashTable<std::string, std::vector<GUM_SCALAR>> _estimator;
HashTable< std::string, std::vector< GUM_SCALAR > > _estimator;
/// cumulative weight of all samples
GUM_SCALAR _wtotal;
......@@ -115,7 +121,7 @@ namespace gum {
Size _ntotal;
/// Bayesian network on which the approximation is done
const IBayesNet<GUM_SCALAR>* _bn;
const IBayesNet< GUM_SCALAR >* _bn;
/// returns the expected value of a Bernoulli variable (referred to by its
/// name) for a given parameter
......@@ -128,7 +134,7 @@ namespace gum {
* computes the cumulative weight for parameter val divided by the total
* cumulative weight
*/
GUM_SCALAR EV( std::string name, int val );
GUM_SCALAR EV(std::string name, int val);
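/* i.e., presumably: EV(name, val) = _estimator[name][val] / _wtotal */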
/// returns the variance of a Bernoulli variable (referred to by its name)
/// for a given parameter
......@@ -140,16 +146,16 @@ namespace gum {
*
* computes the variance of the Bernoulli law using EV(name, val)
*/
GUM_SCALAR variance( std::string name, int val ); // corrected variance
GUM_SCALAR variance(std::string name, int val); // corrected variance
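/* Presumed formula, with p = EV(name, val) estimated from _ntotal samples:
 *   variance(name, val) = p * (1 - p) * _ntotal / (_ntotal - 1)
 */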
private:
/// the set of single posteriors computed during the last inference
/** these posteriors are owned by the estimator and deleted in clear(). */
HashTable<std::string, Potential<GUM_SCALAR>*> __target_posteriors;
HashTable< std::string, Potential< GUM_SCALAR >* > __target_posteriors;
};
extern template class Estimator<float>;
extern template class Estimator<double>;
extern template class Estimator< float >;
extern template class Estimator< double >;
}
#include <agrum/BN/inference/tools/estimator_tpl.h>
......
......@@ -58,8 +58,7 @@ namespace gum {
INLINE Estimator< GUM_SCALAR >::~Estimator() {
GUM_DESTRUCTOR(Estimator);
// remove all the posteriors computed
for (const auto& pot : __target_posteriors)
delete pot.second;
clear();
}
......@@ -88,8 +87,8 @@ namespace gum {
}
}
/// we multiply the posteriors obtained by LoopyBeliefPropagation by its
/// number of iterations
// we multiply the posteriors obtained by LoopyBeliefPropagation by its
// number of iterations
template < typename GUM_SCALAR >
void
Estimator< GUM_SCALAR >::setFromLBP(LoopyBeliefPropagation< GUM_SCALAR >* lbp,
......@@ -191,4 +190,14 @@ namespace gum {
return ic_max;
}
template < typename GUM_SCALAR >
void Estimator< GUM_SCALAR >::clear() {
_estimator.clear();
_wtotal = (GUM_SCALAR)0;
_ntotal = Size(0);
for (const auto& pot : __target_posteriors)
delete pot.second;
__target_posteriors.clear();
}
}
......@@ -232,6 +232,12 @@ namespace gum {
virtual void _onAllMarginalTargetsErased(){};
virtual void _onStateChanged() {
if (this->isInferenceReady()) {
__estimator.clear();
}
};
private:
BayesNetFragment< GUM_SCALAR >* __samplingBN;
};
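/* Usage sketch of the reuse pattern this change enables (engine and epsilon
 * names are taken from the tests below, not a guaranteed API contract):
 *   gum::GibbsSampling< float > inf(&bn);
 *   inf.addEvidence(bn.idFromName("d"), 0);
 *   inf.setEpsilon(EPSILON_FOR_GIBBS);
 *   inf.makeInference();      // fills the estimator
 *   inf.eraseAllEvidence();   // state change -> __estimator.clear()
 *   inf.addEvidence(bn.idFromName("d"), 0);
 *   inf.makeInference();      // restarts from empty statistics
 */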
......
......@@ -419,5 +419,29 @@ namespace gum_tests {
TS_ASSERT(false);
}
}
void testMultipleInferenceWithSameEngine() {
auto bn = gum::BayesNet< float >::fastPrototype("a->b->c;a->d->c", 3);
unsharpen(bn);
try {
gum::GibbsSampling< float > inf(&bn);
inf.addEvidence(bn.idFromName("d"), 0);
inf.setVerbosity(false);
inf.setEpsilon(EPSILON_FOR_GIBBS);
inf.makeInference();
inf.eraseAllEvidence();
inf.addEvidence(bn.idFromName("d"), 0);
inf.setVerbosity(false);
inf.setEpsilon(EPSILON_FOR_GIBBS);
inf.makeInference();
} catch (gum::Exception& e) {
GUM_SHOWERROR(e);
TS_ASSERT(false);
}
}
};
} // namespace gum_tests
......@@ -24,9 +24,9 @@
#include <cxxtest/testsuite_utils.h>
#include <agrum/BN/BayesNet.h>
#include <agrum/BN/inference/loopySamplingInference.h>
#include <agrum/BN/inference/lazyPropagation.h>
#include <agrum/BN/inference/loopyBeliefPropagation.h>
#include <agrum/BN/inference/loopySamplingInference.h>
#include <agrum/multidim/multiDimArray.h>
#include <agrum/variables/labelizedVariable.h>
......@@ -67,7 +67,7 @@ namespace gum_tests {
class loopySamplingInferenceTestSuite : public CxxTest::TestSuite {
public:
void /*test*/HybridBinaryTreeWithoutEvidence() {
void /*test*/ HybridBinaryTreeWithoutEvidence() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a->d->f;b->d->g;b->e->h;c->e;i->j->h");
......@@ -92,7 +92,7 @@ namespace gum_tests {
}
void /*test*/HybridBinaryTreeWithEvidenceOnRoot() {
void /*test*/ HybridBinaryTreeWithEvidenceOnRoot() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a->d->f;b->d->g;b->e->h;c->e;i->j->h");
......@@ -117,7 +117,7 @@ namespace gum_tests {
}
}
void /*test*/HybridBinaryTreeWithEvidenceOnLeaf() {
void /*test*/ HybridBinaryTreeWithEvidenceOnLeaf() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a->d->f;b->d->g;b->e->h;c->e;i->j->h");
std::string ev = "h";
......@@ -141,7 +141,7 @@ namespace gum_tests {
}
}
void /*test*/HybridBinaryTreeWithEvidenceOnMid() {
void /*test*/ HybridBinaryTreeWithEvidenceOnMid() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a->d->f;b->d->g;b->e->h;c->e;i->j->h");
std::string ev = "e";
......@@ -166,7 +166,7 @@ namespace gum_tests {
}
}
void /*test*/HybridBinaryTreeWithMultipleEvidence() {
void /*test*/ HybridBinaryTreeWithMultipleEvidence() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a->d->f;b->d->g;b->e->h;c->e;i->j->h");
......@@ -196,7 +196,7 @@ namespace gum_tests {
}
void /*test*/HybridNaryTreeWithMultipleEvidence() {
void /*test*/ HybridNaryTreeWithMultipleEvidence() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a[4]->d[8]->f[3];b->d->g[5];b->e[4]->h;c->e;i[10]->j[3]->h");
......@@ -226,7 +226,7 @@ namespace gum_tests {
}
void /*test*/HybridSimpleBN() {
void /*test*/ HybridSimpleBN() {
auto bn = gum::BayesNet< float >::fastPrototype("a->b->c;a->d->c", 3);
try {
......@@ -289,7 +289,7 @@ namespace gum_tests {
}
void /*test*/HybridCplxBN() {
void /*test*/ HybridCplxBN() {
auto bn = gum::BayesNet< float >::fastPrototype(
"a->d->f;b->d->g;b->e->h;c->e->g;i->j->h;c->j;x->c;x->j;", 3);
......@@ -352,7 +352,7 @@ namespace gum_tests {
}
void /*test*/HybridAsia() {
void /*test*/ HybridAsia() {
gum::BayesNet< float > bn;
gum::BIFReader< float > reader(&bn, GET_RESSOURCES_PATH("asia.bif"));
......@@ -381,7 +381,7 @@ namespace gum_tests {
}
void /*test*/HybridAlarm() {
void /*test*/ HybridAlarm() {
gum::BayesNet< float > bn;
gum::BIFReader< float > reader(&bn, GET_RESSOURCES_PATH("alarm.bif"));
......@@ -409,29 +409,29 @@ namespace gum_tests {
}
}
void testMultipleInferenceWithSameEngine() {
auto bn = gum::BayesNet< float >::fastPrototype("a->b->c;a->d->c", 3);
unsharpen(bn);
void testMultipleInferenceWithSameEngine() {
auto bn = gum::BayesNet< float >::fastPrototype("a->b->c;a->d->c", 3);
unsharpen(bn);
try {
gum::LoopyGibbsSampling< float > inf(&bn);
inf.addEvidence(bn.idFromName("d"), 0);
inf.setVerbosity(false);
inf.setEpsilon(EPSILON_FOR_GIBBS);
inf.makeInference();
try {
gum::LoopySamplingInference< float, gum::WeightedSampling > inf(&bn);
inf.addEvidence(bn.idFromName("d"), 0);
inf.setVerbosity(false);
inf.setEpsilon(EPSILON_FOR_HARD_HYBRID);
inf.makeInference();
inf.eraseAllEvidence();
inf.addEvidence(bn.idFromName("d"), 0);
inf.setVerbosity(false);
inf.setEpsilon(EPSILON_FOR_GIBBS);
inf.makeInference();
inf.eraseAllEvidence();
inf.addEvidence(bn.idFromName("d"), 0);
inf.setVerbosity(false);
inf.setEpsilon(EPSILON_FOR_HARD_HYBRID);
inf.makeInference();
} catch (gum::Exception& e) {
} catch (gum::Exception& e) {
GUM_SHOWERROR(e);
TS_ASSERT(false);
}
GUM_SHOWERROR(e);
TS_ASSERT(false);
}
}
};
} // namespace gum_tests
......@@ -6375,7 +6375,7 @@ class BayesNetInference_double(_object):
__repr__ = _swig_repr
StateOfInference_OutdatedBNStructure = _pyAgrum.BayesNetInference_double_StateOfInference_OutdatedBNStructure
StateOfInference_OutdatedBNPotentials = _pyAgrum.BayesNetInference_double_StateOfInference_OutdatedBNPotentials
StateOfInference_InferenceReady = _pyAgrum.BayesNetInference_double_StateOfInference_InferenceReady
StateOfInference_ReadyForInference = _pyAgrum.BayesNetInference_double_StateOfInference_ReadyForInference
StateOfInference_Done = _pyAgrum.BayesNetInference_double_StateOfInference_Done
__swig_destroy__ = _pyAgrum.delete_BayesNetInference_double
def __del__(self):
......@@ -6401,6 +6401,21 @@ class BayesNetInference_double(_object):
return _pyAgrum.BayesNetInference_double_isInferenceReady(self)
def isInferenceOutdatedBNStructure(self) -> "bool":
"""isInferenceOutdatedBNStructure(self) -> bool"""
return _pyAgrum.BayesNetInference_double_isInferenceOutdatedBNStructure(self)
def isInferenceOutdatedBNPotentials(self) -> "bool":
"""isInferenceOutdatedBNPotentials(self) -> bool"""
return _pyAgrum.BayesNetInference_double_isInferenceOutdatedBNPotentials(self)
def isInferenceDone(self) -> "bool":
"""isInferenceDone(self) -> bool"""
return _pyAgrum.BayesNetInference_double_isInferenceDone(self)
def isDone(self) -> "bool":
"""isDone(self) -> bool"""
return _pyAgrum.BayesNetInference_double_isDone(self)
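# Hedged usage sketch of the predicates wrapped above ('ie' is assumed to be
# any pyAgrum inference engine, e.g. gum.LazyPropagation(bn)):
#   ie.makeInference()
#   assert ie.isInferenceDone()
#   ie.eraseAllEvidence()        # back to the ReadyForInference state
#   assert ie.isInferenceReady()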
......
......@@ -6,17 +6,19 @@ from pyAgrumTestSuite import pyAgrumTestCase, addTests
class SamplingTestCase(pyAgrumTestCase):
def iterTest(self, proto, ie, target, evs, seuil=0.1, nbr=10):
def iterTest(self, goalPotential, inferenceEngine, target, evs, seuil=0.1, nbr=10):
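# reruns the (stochastic) inference up to nbr times and returns None as soon
# as the posterior is within seuil of goalPotential; otherwise it presumably
# (past this excerpt) returns a failure message for the caller to report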
min = 1000
for i in range(nbr):
ie.eraseAllEvidence()
ie.setEvidence(evs)
ie.makeInference()
result = ie.posterior(target)
diff = (proto - result).abs().max()
inferenceEngine.eraseAllEvidence()
inferenceEngine.setEvidence(evs)
inferenceEngine.makeInference()
result = inferenceEngine.posterior(target)
diff = (goalPotential - result).abs().max()
if diff <= seuil:
return ""
return None
else:
print("!", end="")
if min > diff:
min = diff
......@@ -28,11 +30,11 @@ class SamplingTestCase(pyAgrumTestCase):
def setUp(self):
self.bn = gum.fastBN("c->s{no|yes}->w{no|yes};c->r->w")
GibbsTestCase.unsharpen(self.bn)
SamplingTestCase.unsharpen(self.bn)
self.c, self.s, self.w, self.r = [self.bn.idFromName(s) for s in "cswr"]
self.bn2 = gum.fastBN("r2->s2->w2;r2->w2")
GibbsTestCase.unsharpen(self.bn2)
SamplingTestCase.unsharpen(self.bn2)
self.r2, self.s2, self.w2 = [self.bn2.idFromName(s) for s in ["s2", "w2", "r2"]]
......@@ -48,13 +50,17 @@ class TestDictFeature(SamplingTestCase):
ie.setVerbosity(False)
ie.setEpsilon(0.05)
ie.setMinEpsilonRate(0.001)
self.iterTest(proto, ie, self.r, {'s': [0, 1], 'w': (1, 0)})
msg = self.iterTest(proto, ie, self.r, {'s': [0, 1], 'w': (1, 0)})
if msg is not None:
self.fail(msg)
ie = gum.LoopyImportanceSampling(self.bn)
ie.setVerbosity(False)
ie.setEpsilon(0.05)
ie.setMinEpsilonRate(0.001)
self.iterTest(proto, ie, self.r, ({'s': 1, 'w': 0}))
msg = self.iterTest(proto, ie, self.r, ({'s': 1, 'w': 0}))
if msg is not None:
self.fail(msg)
def testDictOfLabels(self):
protoie = gum.LazyPropagation(self.bn)
......@@ -67,19 +73,17 @@ class TestDictFeature(SamplingTestCase):
ie.setVerbosity(False)
ie.setEpsilon(0.05)
ie.setMinEpsilonRate(0.001)
ie.setEvidence({'s': 0, 'w': 1})
ie.makeInference()
result = ie.posterior(self.r)
self.assertGreaterEqual(0.1, (proto - result).abs().max())
msg = self.iterTest(proto, ie, self.r, {'s': 0, 'w': 1})
if msg is not None:
self.fail(msg)
ie2 = gum.LoopyGibbsSampling(self.bn)
ie2.setVerbosity(False)
ie2.setEpsilon(0.05)
ie2.setMinEpsilonRate(0.001)
ie2.setEvidence({'s': 'no', 'w': 'yes'})
ie2.makeInference()
result2 = ie2.posterior(self.r)
self.assertGreaterEqual(0.1, (proto - result2).abs().max())
ie = gum.LoopyGibbsSampling(self.bn)
ie.setVerbosity(False)
ie.setEpsilon(0.05)
ie.setMinEpsilonRate(0.001)
msg = self.iterTest(proto, ie, self.r, {'s': 'no', 'w': 'yes'})
if msg is not None:
self.fail(msg)
def testDictOfLabelsWithId(self):
protoie = gum.LazyPropagation(self.bn)
......@@ -92,47 +96,47 @@ class TestDictFeature(SamplingTestCase):
ie.setVerbosity(False)
ie.setEpsilon(0.05)
ie.setMinEpsilonRate(0.01)
ie.setEvidence({self.s: 0, self.w: 1})
ie.makeInference()
result = ie.posterior(self.r)
self.assertGreaterEqual(0.1, (proto - result).abs().max())
msg = self.iterTest(proto, ie, self.r, {self.s: 0, self.w: 1})
if msg is not None:
self.fail(msg)
ie2 = gum.LoopyGibbsSampling(self.bn)
ie2.setVerbosity(False)
ie2.setEpsilon(0.05)
ie2.setMinEpsilonRate(0.01)
ie2.setEvidence({self.s: 'no', self.w: 'yes'})
ie2.makeInference()
result2 = ie2.posterior(self.r)
self.assertGreaterEqual(0.1, (proto - result2).abs().max())
ie = gum.LoopyGibbsSampling(self.bn)
ie.setVerbosity(False)
ie.setEpsilon(0.05)
ie.setMinEpsilonRate(0.01)
msg = self.iterTest(proto, ie, self.r, {self.s: 'no', self.w: 'yes'})
if msg is not None:
self.fail(msg)
def testWithDifferentVariables(self):
protoie = gum.LazyPropagation(self.bn)
protoie.addEvidence('s', 0)
protoie.addEvidence('w', 1)
protoie.makeInference()
proto = protoie.posterior(self.r)
proto = protoie.posterior(self.s)
ie = gum.LoopyWeightedSampling(self.bn)
ie.setVerbosity(False)
ie.setEpsilon(0.1)
ie.setMinEpsilonRate(0.01)
ie.setEvidence({'r': [0, 1], 'w': (1, 0)})
ie.makeInference()
result = ie.posterior(self.s)
self.assertGreaterEqual(0.1, (proto - result).abs().max())
# msg = self.iterTest(proto, ie, self.s, {'r': [0, 1], 'w': (1, 0)})
# if msg is not None:
# self.fail(msg)
ie = gum.LoopyWeightedSampling(self.bn)
ie.setVerbosity(False)
ie.setEpsilon(0.1)
ie.setMinEpsilonRate(0.01)
ie.setEvidence({'r': 1, 'w': 0})
ie.makeInference()
result2 = ie.posterior(self.s)
self.assertGreaterEqual(0.1, (proto - result2).abs().max())
ie2 = gum.LoopyGibbsSampling(self.bn)
ie2.setVerbosity(False)
ie2.setEpsilon(0.1)
ie2.setMinEpsilonRate(0.01)
ie2.setEvidence({'r': 1, 'w': 0})
ie2.makeInference()
print(ie.posterior('s'))
msg = self.iterTest(proto, ie2, self.s, {'r': 1, 'w': 0})
if msg is not None:
self.fail(msg)
class TestInferenceResults(GibbsTestCase):
class TestInferenceResults(SamplingTestCase):
def testOpenBayesSiteExamples(self):
protoie = gum.LazyPropagation(self.bn)
protoie.makeInference()
......