[pyAgrum] Adding all sampling classes

parent 9b8584cb
@@ -32,12 +32,12 @@ namespace gum {
   /// default constructor
   template <typename GUM_SCALAR>
-  ImportanceSampling<GUM_SCALAR>::ImportanceSampling(const IBayesNet<GUM_SCALAR>* BN)
-      : ApproximateInference<GUM_SCALAR>(BN) {
-    this->setBurnIn(0);
-    GUM_CONSTRUCTOR(ImportanceSampling);
+  ImportanceSampling<GUM_SCALAR>::ImportanceSampling(
+      const IBayesNet<GUM_SCALAR>* BN )
+      : ApproximateInference<GUM_SCALAR>( BN ) {
+    this->setBurnIn( 0 );
+    GUM_CONSTRUCTOR( ImportanceSampling );
   }
@@ -45,93 +45,100 @@ namespace gum {
   template <typename GUM_SCALAR>
   ImportanceSampling<GUM_SCALAR>::~ImportanceSampling() {
-    GUM_DESTRUCTOR(ImportanceSampling);
+    GUM_DESTRUCTOR( ImportanceSampling );
   }

   /// no burn in needed for Importance sampling
   template <typename GUM_SCALAR>
-  Instantiation ImportanceSampling<GUM_SCALAR>::_burnIn(){
-    Instantiation I;
-    return I;
+  Instantiation ImportanceSampling<GUM_SCALAR>::_burnIn() {
+    Instantiation I;
+    return I;
   }
   template <typename GUM_SCALAR>
-  Instantiation ImportanceSampling<GUM_SCALAR>::_draw(float* w, Instantiation prev, const IBayesNet<GUM_SCALAR>& bn, const NodeSet& hardEvNodes, const NodeProperty<Idx>& hardEv){
-    float probaP = 1.; float probaQ = 1.;
-    do {
-      prev.clear(); probaP = 1.; probaQ = 1.;
-      for (auto ev = hardEvNodes.beginSafe(); ev != hardEvNodes.endSafe(); ++ev) {
-        prev.add(bn.variable(*ev));
-        prev.chgVal(bn.variable(*ev), hardEv[*ev]);
-      }
-      for (auto nod: this->BN().topologicalOrder()){
-        this->_addVarSample(nod, &prev, this->BN());
-        probaQ *= this->BN().cpt(nod).get(prev);
-        probaP *= bn.cpt(nod).get(prev);
-      }
-      for (auto ev = hardEvNodes.beginSafe(); ev != hardEvNodes.endSafe(); ++ev)
-        probaP *= bn.cpt(*ev).get(prev);
-    } while (probaP == 0);
-    *w = probaP / probaQ;
-    return prev;
-  }
+  Instantiation
+  ImportanceSampling<GUM_SCALAR>::_draw( float* w,
+                                         Instantiation prev,
+                                         const IBayesNet<GUM_SCALAR>& bn,
+                                         const NodeSet& hardEvNodes,
+                                         const NodeProperty<Idx>& hardEv ) {
+    GUM_SCALAR pSurQ = 1.;
+    do {
+      prev.clear();
+      pSurQ = 1.;
+      for ( auto ev = hardEvNodes.beginSafe(); ev != hardEvNodes.endSafe();
+            ++ev ) {
+        prev.add( bn.variable( *ev ) );
+        prev.chgVal( bn.variable( *ev ), hardEv[*ev] );
+      }
+      for ( auto nod : this->BN().topologicalOrder() ) {
+        this->_addVarSample( nod, &prev, this->BN() );
+        auto probaQ = this->BN().cpt( nod ).get( prev );
+        auto probaP = bn.cpt( nod ).get( prev );
+        if ( ( probaP == 0 ) || ( probaQ == 0 ) ) {
+          pSurQ = 0;
+          break;  // zero-weight sample: reject it and redraw
+        }
+        pSurQ *= probaP / probaQ;  // accumulate P(x)/Q(x) over the sampled nodes
+      }
+      if ( pSurQ > 0.0 ) {
+        for ( auto ev = hardEvNodes.beginSafe(); ev != hardEvNodes.endSafe();
+              ++ev ) {
+          pSurQ *= bn.cpt( *ev ).get( prev );  // fold in P(e | parents)
+        }
+      }
+    } while ( pSurQ == 0 );
+    *w = pSurQ;
+    return prev;
+  }
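Review note: `_draw` samples every node in topological order from the proposal network (the contextualized fragment `this->BN()`, playing the role of Q) and weights the draw by P(x, e)/Q(x) computed against the original network `bn`; zero-weight samples are rejected and redrawn. A runnable toy sketch of the same weighting scheme, with binary variables, hand-written CPTs and no evidence handling — illustrative Python, not the pyAgrum API:

```python
import random

# Networks are encoded as {node: (parents, {parent_values: [P(0), P(1)]})}.
def sample_and_weight(p_net, q_net, order):
    while True:
        inst, p, q = {}, 1.0, 1.0
        for node in order:                               # topological order
            parents, cpt = q_net[node]
            probs = cpt[tuple(inst[pa] for pa in parents)]
            val = random.choices([0, 1], weights=probs)[0]
            inst[node] = val
            q *= probs[val]                              # accumulate Q(x)
            p_parents, p_cpt = p_net[node]
            p *= p_cpt[tuple(inst[pa] for pa in p_parents)][val]  # accumulate P(x)
        if p > 0:                                        # reject zero-weight draws
            return inst, p / q

# P is sharp (P(A=0) = 0); Q is an "unsharpened" proposal with full support.
P = {"A": ((), {(): [0.0, 1.0]}), "B": (("A",), {(0,): [0.5, 0.5], (1,): [0.9, 0.1]})}
Q = {"A": ((), {(): [0.1, 0.9]}), "B": (("A",), {(0,): [0.5, 0.5], (1,): [0.8, 0.2]})}
print(sample_and_weight(P, Q, ["A", "B"]))
```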
   template <typename GUM_SCALAR>
-  void ImportanceSampling<GUM_SCALAR>::_unsharpenBN(BayesNetFragment<GUM_SCALAR>* bn, float epsilon){
-    for (auto nod: bn->nodes().asNodeSet()) {
-      Potential<GUM_SCALAR>* p = new Potential<GUM_SCALAR>();
-      *p = bn->cpt(nod).isNonZeroMap().scale(epsilon) + bn->cpt(nod);
-      p->normalizeAsCPT();
-      bn->installCPT(nod, p);
-    }
-  }
+  void
+  ImportanceSampling<GUM_SCALAR>::_unsharpenBN( BayesNetFragment<GUM_SCALAR>* bn,
+                                                float epsilon ) {
+    GUM_CHECKPOINT;
+    for ( auto nod : bn->nodes().asNodeSet() ) {
+      Potential<GUM_SCALAR>* p = new Potential<GUM_SCALAR>();
+      *p = bn->cpt( nod ).isNonZeroMap().scale( epsilon ) + bn->cpt( nod );
+      p->normalizeAsCPT();
+      bn->installCPT( nod, p );
+    }
+  }
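Review note: `_unsharpenBN` adds `epsilon` to every nonzero CPT entry (structural zeros stay zero) and renormalizes, pulling each conditional distribution toward uniform over its support so that no importance weight is dominated by a vanishingly small parameter. A numpy sketch of the transform — illustrative, not the `Potential` API:

```python
import numpy as np

def unsharpen(cpt, epsilon):
    """Mirror of isNonZeroMap().scale(epsilon) + cpt followed by
    normalizeAsCPT(): lift nonzero entries by epsilon, then renormalize
    each conditional distribution (here, the last axis)."""
    lifted = cpt + epsilon * (cpt != 0)
    return lifted / lifted.sum(axis=-1, keepdims=True)

# P(B | A): the A=0 row is sharp; unsharpening raises its smallest nonzero entry.
print(unsharpen(np.array([[0.999, 0.001], [0.3, 0.7]]), 0.05))
```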
   template <typename GUM_SCALAR>
-  void ImportanceSampling<GUM_SCALAR>::_onContextualize(BayesNetFragment<GUM_SCALAR>* bn, const NodeSet& targets, const NodeSet& hardEvNodes, const NodeProperty<Idx>& hardEv){
-    Sequence<NodeId> sid;
-    for (NodeSet::iterator ev = hardEvNodes.begin(); ev != hardEvNodes.end(); ++ev)
-      sid << *ev;
-    for (Size i = 0; i < sid.size(); i++){
-      bn->uninstallCPT(sid[i]);
-      bn->uninstallNode(sid[i]);
-    }
-    for (auto targ = targets.begin(); targ != targets.end(); ++targ) {
-      if (this->BN().dag().exists(*targ))
-        this->addTarget(*targ);
-    }
-    auto minParam = bn->minNonZeroParam();
-    auto minAccepted = this->epsilon() / bn->maxVarDomainSize();
-    if (minParam < minAccepted)
-      this->_unsharpenBN(bn, minAccepted);
-  }
+  void ImportanceSampling<GUM_SCALAR>::_onContextualize(
+      BayesNetFragment<GUM_SCALAR>* bn,
+      const NodeSet& targets,
+      const NodeSet& hardEvNodes,
+      const NodeProperty<Idx>& hardEv ) {
+    GUM_CHECKPOINT;
+    Sequence<NodeId> sid;
+    for ( NodeSet::iterator ev = hardEvNodes.begin(); ev != hardEvNodes.end();
+          ++ev )
+      sid << *ev;
+    for ( Size i = 0; i < sid.size(); i++ ) {
+      bn->uninstallCPT( sid[i] );
+      bn->uninstallNode( sid[i] );
+    }
+    for ( auto targ = targets.begin(); targ != targets.end(); ++targ ) {
+      if ( this->BN().dag().exists( *targ ) ) this->addTarget( *targ );
+    }
+    auto minParam = bn->minNonZeroParam();
+    auto minAccepted = this->epsilon() / bn->maxVarDomainSize();
+    if ( minParam < minAccepted ) this->_unsharpenBN( bn, minAccepted );
+  }
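Review note: `_onContextualize` thus (1) removes the hard-evidence nodes from the sampling fragment, since `_draw` clamps their values directly, (2) re-registers the surviving targets, and (3) unsharpens the fragment whenever its smallest nonzero parameter falls below `epsilon() / maxVarDomainSize()` — a heuristic threshold tying the proposal's support to the requested precision.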
@@ -386,6 +386,28 @@ namespace gum_tests {
       }
     }

+    void testImportanceDiabetes() {
+      gum::BayesNet<double>  bn;
+      gum::BIFReader<double> reader( &bn, GET_RESSOURCES_PATH( "Diabetes.bif" ) );
+      int nbrErr = 0;
+      TS_GUM_ASSERT_THROWS_NOTHING( nbrErr = reader.proceed() );
+      TS_ASSERT( nbrErr == 0 );
+
+      try {
+        gum::ImportanceSampling<double> inf( &bn );
+        inf.setVerbosity( false );
+        inf.setMaxTime( 5 );
+        inf.setEpsilon( EPSILON_FOR_IMPORTANCE );
+        inf.makeInference();
+      } catch ( gum::Exception& e ) {
+        GUM_SHOWERROR( e );
+        TS_ASSERT( false );
+      }
+    }
+
     void testImportanceInfListener() {
       gum::BayesNet<float> bn;
@@ -395,7 +417,7 @@
       TS_ASSERT( nbrErr == 0 );

       gum::ImportanceSampling<float> inf( &bn );
-      aSimpleImportanceListener agsl( inf );
+      aSimpleImportanceListener agsl( inf );
       inf.setVerbosity( true );

       try {
@@ -413,10 +435,10 @@
     private:
     template <typename GUM_SCALAR>
-    bool __compareInference( const gum::BayesNet<GUM_SCALAR>& bn,
-                             gum::LazyPropagation<GUM_SCALAR>& lazy,
-                             gum::ImportanceSampling<GUM_SCALAR>& inf,
-                             double errmax = 5e-2 ) {
+    bool __compareInference( const gum::BayesNet<GUM_SCALAR>& bn,
+                             gum::LazyPropagation<GUM_SCALAR>& lazy,
+                             gum::ImportanceSampling<GUM_SCALAR>& inf,
+                             double errmax = 5e-2 ) {
       GUM_SCALAR err = static_cast<GUM_SCALAR>( 0 );
       std::string argstr = "";
......
@@ -44,7 +44,7 @@ from .pyAgrum import DiscretizedVariable, LabelizedVariable, RangeVariable, DiscreteVariable
 from .pyAgrum import Potential, Instantiation, UtilityTable
 from .pyAgrum import BruteForceKL, GibbsSampling
 from .pyAgrum import LazyPropagation, ShaferShenoyInference, VariableElimination
-from .pyAgrum import LoopyBeliefPropagation, GibbsSampling
+from .pyAgrum import LoopyBeliefPropagation, GibbsSampling, MonteCarloSampling, ImportanceSampling, WeightedSampling
 from .pyAgrum import PythonApproximationListener, PythonBNListener, PythonLoadListener
 from .pyAgrum import BNGenerator, IDGenerator, JTGenerator
 from .pyAgrum import BNLearner
@@ -67,7 +67,8 @@ __all__=[
   'DiscretizedVariable','LabelizedVariable','RangeVariable','DiscreteVariable',
   'Potential','Instantiation','UtilityTable',
   'BruteForceKL','GibbsKL',
-  'LoopyBeliefPropagation','GibbsSampling','LazyPropagation','ShaferShenoyInference','VariableElimination',
+  'LoopyBeliefPropagation','GibbsSampling','MonteCarloSampling','ImportanceSampling','WeightedSampling',
+  'LazyPropagation','ShaferShenoyInference','VariableElimination',
   'PythonApproximationListener','PythonBNListener','PythonLoadListener',
   'BNGenerator','IDGenerator','JTGenerator',
   'BNLearner',
......
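With the exports above in place, the three new samplers can be driven from Python exactly like `GibbsSampling`. A hedged usage sketch (the file name and the variable "Age" are assumptions for illustration):

```python
import pyAgrum as gum

bn = gum.loadBN("Diabetes.bif")            # any BIF/DSL network works the same way

ie = gum.ImportanceSampling(bn)
ie.setEpsilon(1e-2)                        # stop once posteriors move less than 1e-2
ie.setMaxTime(5)                           # ... or after 5 seconds, as in the new test
ie.makeInference()
print(ie.posterior(bn.idFromName("Age")))  # "Age" is assumed to exist in the file
```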
@@ -52,6 +52,9 @@ IMPROVE_INFERENCE_API(LazyPropagation)
 IMPROVE_INFERENCE_API(ShaferShenoyInference)
 IMPROVE_INFERENCE_API(VariableElimination)
 IMPROVE_INFERENCE_API(GibbsSampling)
+IMPROVE_INFERENCE_API(ImportanceSampling)
+IMPROVE_INFERENCE_API(WeightedSampling)
+IMPROVE_INFERENCE_API(MonteCarloSampling)
 IMPROVE_INFERENCE_API(LoopyBeliefPropagation)
@@ -113,4 +116,7 @@ IMPROVE_EXACT_INFERENCE_API(VariableElimination)
 %}
 %enddef

 IMPROVE_APPROX_INFERENCE_API(GibbsSampling)
+IMPROVE_APPROX_INFERENCE_API(ImportanceSampling)
+IMPROVE_APPROX_INFERENCE_API(WeightedSampling)
+IMPROVE_APPROX_INFERENCE_API(MonteCarloSampling)
 IMPROVE_APPROX_INFERENCE_API(LoopyBeliefPropagation)
@@ -82,6 +82,10 @@ ShaferShenoyInference = ShaferShenoyInference_double
 VariableElimination = VariableElimination_double
 GibbsSampling = GibbsSampling_double
+ImportanceSampling = ImportanceSampling_double
+WeightedSampling = WeightedSampling_double
+MonteCarloSampling = MonteCarloSampling_double
 LoopyBeliefPropagation = LoopyBeliefPropagation_double
 BruteForceKL = BruteForceKL_double
......
@@ -87,6 +87,10 @@
 #include <agrum/BN/inference/variableElimination.h>
 #include <agrum/BN/inference/GibbsSampling.h>
+#include <agrum/BN/inference/importanceSampling.h>
+#include <agrum/BN/inference/weightedSampling.h>
+#include <agrum/BN/inference/MonteCarloSampling.h>
 #include <agrum/BN/inference/loopyBeliefPropagation.h>
 #include <agrum/BN/algorithms/divergence/KL.h>
@@ -171,8 +175,14 @@ namespace std {
 }
 %enddef

 ADD_APPROXIMATIONSCHEME_API(gum::ApproximationScheme,gum::GibbsSampling<double>)
+ADD_APPROXIMATIONSCHEME_API(gum::ApproximationScheme,gum::ImportanceSampling<double>)
+ADD_APPROXIMATIONSCHEME_API(gum::ApproximationScheme,gum::WeightedSampling<double>)
+ADD_APPROXIMATIONSCHEME_API(gum::ApproximationScheme,gum::MonteCarloSampling<double>)
 ADD_APPROXIMATIONSCHEME_API(gum::ApproximationScheme,gum::LoopyBeliefPropagation<double>)
 ADD_APPROXIMATIONSCHEME_API(gum::ApproximationScheme,gum::GibbsKL<double>)
 ADD_APPROXIMATIONSCHEME_API(gum::ApproximationScheme,gum::credal::CNMonteCarloSampling<double>)
 ADD_APPROXIMATIONSCHEME_API(gum::ApproximationScheme,gum::credal::CNLoopyPropagation<double>)
 ADD_APPROXIMATIONSCHEME_API(gum::learning::genericBNLearner,gum::learning::BNLearner<double>)
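ADD_APPROXIMATIONSCHEME_API is what surfaces the generic stopping-criterion accessors (epsilon, max iterations, max time, verbosity, history) on each wrapped scheme. A brief sketch of those knobs on one of the new samplers (the generator parameters are arbitrary):

```python
import pyAgrum as gum

bn = gum.BNGenerator().generate(10, 15, 3)   # random BN: 10 nodes, 15 arcs, small domains

ie = gum.WeightedSampling(bn)
ie.setVerbosity(True)                        # record the convergence history
ie.setMaxIter(10 ** 5)
ie.setEpsilon(5e-3)
ie.makeInference()
print(ie.nbrIterations(), ie.messageApproximationScheme())
```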
@@ -303,6 +313,10 @@ ADD_APPROXIMATIONSCHEME_API(gum::learning::genericBNLearner,gum::learning::BNLearner<double>)
 %include <agrum/BN/inference/variableElimination.h>
 %include <agrum/BN/inference/GibbsSampling.h>
+%include <agrum/BN/inference/importanceSampling.h>
+%include <agrum/BN/inference/weightedSampling.h>
+%include <agrum/BN/inference/MonteCarloSampling.h>
 %include <agrum/BN/inference/loopyBeliefPropagation.h>
 %import <agrum/BN/algorithms/divergence/KL.h>
@@ -352,6 +366,10 @@ ADD_APPROXIMATIONSCHEME_API(gum::learning::genericBNLearner,gum::learning::BNLearner<double>)
 %template ( VariableElimination_double ) gum::VariableElimination<double>;
 %template ( GibbsSampling_double ) gum::GibbsSampling<double>;
+%template ( ImportanceSampling_double ) gum::ImportanceSampling<double>;
+%template ( WeightedSampling_double ) gum::WeightedSampling<double>;
+%template ( MonteCarloSampling_double ) gum::MonteCarloSampling<double>;
 %template ( LoopyBeliefPropagation_double ) gum::LoopyBeliefPropagation<double>;
 %template ( BruteForceKL_double ) gum::BruteForceKL<double>;
......
@@ -345,6 +345,9 @@ ADD_INFERENCE_API(gum::LazyPropagation<double>)
 ADD_INFERENCE_API(gum::ShaferShenoyInference<double>)
 ADD_INFERENCE_API(gum::VariableElimination<double>)
 ADD_INFERENCE_API(gum::GibbsSampling<double>)
+ADD_INFERENCE_API(gum::MonteCarloSampling<double>)
+ADD_INFERENCE_API(gum::WeightedSampling<double>)
+ADD_INFERENCE_API(gum::ImportanceSampling<double>)
 ADD_INFERENCE_API(gum::LoopyBeliefPropagation<double>)

 %define ADD_JOINT_INFERENCE_API(classname)
......