[aGrUM] still refactoring approximate inference

parent 889ec795
@@ -30,44 +30,45 @@
 namespace gum {

   /// default constructor
-  template <typename GUM_SCALAR>
-  GibbsSampling<GUM_SCALAR>::GibbsSampling( const IBayesNet<GUM_SCALAR>* BN )
-      : ApproximateInference<GUM_SCALAR>( BN )
-      , GibbsOperator<GUM_SCALAR>( *BN ) {
-    GUM_CONSTRUCTOR( GibbsSampling );
+  template < typename GUM_SCALAR >
+  GibbsSampling< GUM_SCALAR >::GibbsSampling(const IBayesNet< GUM_SCALAR >* BN)
+      : ApproximateInference< GUM_SCALAR >(BN)
+      , GibbsOperator< GUM_SCALAR >(*BN) {
+    GUM_CONSTRUCTOR(GibbsSampling);
   }

   /// destructor
-  template <typename GUM_SCALAR>
-  GibbsSampling<GUM_SCALAR>::~GibbsSampling() {
-    GUM_DESTRUCTOR( GibbsSampling );
+  template < typename GUM_SCALAR >
+  GibbsSampling< GUM_SCALAR >::~GibbsSampling() {
+    GUM_DESTRUCTOR(GibbsSampling);
   }

-  template <typename GUM_SCALAR>
-  Instantiation GibbsSampling<GUM_SCALAR>::_monteCarloSample() {
-    return GibbsOperator<GUM_SCALAR>::monteCarloSample( samplingBN() );
+  template < typename GUM_SCALAR >
+  Instantiation GibbsSampling< GUM_SCALAR >::_monteCarloSample() {
+    return GibbsOperator< GUM_SCALAR >::monteCarloSample();
   }

-  template <typename GUM_SCALAR>
-  Instantiation GibbsSampling<GUM_SCALAR>::_burnIn() {
+  template < typename GUM_SCALAR >
+  Instantiation GibbsSampling< GUM_SCALAR >::_burnIn() {
     gum::Instantiation Ip;
-    if ( this->burnIn() == 0 ) return Ip;
+    if (this->burnIn() == 0) return Ip;

     float w = 1.;
     Ip = _monteCarloSample();
-    for ( Size i = 1; i < this->burnIn(); i++ )
-      Ip = this->_draw( &w, Ip );
+    for (Size i = 1; i < this->burnIn(); i++)
+      Ip = this->_draw(&w, Ip);

     return Ip;
   }

   /// draws next sample for Gibbs sampling
-  template <typename GUM_SCALAR>
-  Instantiation GibbsSampling<GUM_SCALAR>::_draw( float* w, Instantiation prev ) {
-    return this->drawNextInstance( w, prev, samplingBN() );
+  template < typename GUM_SCALAR >
+  Instantiation GibbsSampling< GUM_SCALAR >::_draw(float* w, Instantiation prev) {
+    *w = 1.0;
+    return this->nextSample(prev);
   }
 }
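For context, a minimal usage sketch of the refactored class (my own example, not part of the commit; the network `bn` and the node names are hypothetical, the inference calls are the usual aGrUM targeted-inference API):

    // Build a Gibbs sampler over an existing gum::BayesNet<double> and query a posterior.
    gum::GibbsSampling< double > inf(&bn);
    inf.addEvidence(bn.idFromName("evidenceVar"), 1);  // hypothetical node name
    inf.setEpsilon(1e-2);                              // stopping criterion of the approximation scheme
    inf.makeInference();
    std::cout << inf.posterior(bn.idFromName("queryVar")) << std::endl;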
@@ -19,7 +19,8 @@
  ***************************************************************************/

 /**
  * @file
- * @brief Implementation of Monte Carlo Sampling for inference in Bayesian Networks.
+ * @brief Implementation of Monte Carlo Sampling for inference in Bayesian
+ * Networks.
  *
  * @author Paul ALAM & Pierre-Henri WUILLEMIN
  */
@@ -30,60 +30,54 @@
 namespace gum {

-  /// Default constructor
-  template <typename GUM_SCALAR>
-  MonteCarloSampling<GUM_SCALAR>::MonteCarloSampling(const IBayesNet<GUM_SCALAR>* BN)
-      : ApproximateInference<GUM_SCALAR>(BN) {
+  /// Default constructor
+  template < typename GUM_SCALAR >
+  MonteCarloSampling< GUM_SCALAR >::MonteCarloSampling(
+      const IBayesNet< GUM_SCALAR >* BN)
+      : ApproximateInference< GUM_SCALAR >(BN) {

-    this->setBurnIn(0);
+    this->setBurnIn(0);
     GUM_CONSTRUCTOR(MonteCarloSampling);
   }

   /// Destructor
-  template <typename GUM_SCALAR>
-  MonteCarloSampling<GUM_SCALAR>::~MonteCarloSampling() {
-    GUM_DESTRUCTOR(MonteCarloSampling);
+  template < typename GUM_SCALAR >
+  MonteCarloSampling< GUM_SCALAR >::~MonteCarloSampling() {
+    GUM_DESTRUCTOR(MonteCarloSampling);
   }

   /// no burn in needed for Monte Carlo sampling
-  template <typename GUM_SCALAR>
-  Instantiation MonteCarloSampling<GUM_SCALAR>::_burnIn(){
-    gum::Instantiation I;
-    return I;
+  template < typename GUM_SCALAR >
+  Instantiation MonteCarloSampling< GUM_SCALAR >::_burnIn() {
+    gum::Instantiation I;
+    return I;
   }

-  template <typename GUM_SCALAR>
-  Instantiation MonteCarloSampling<GUM_SCALAR>::_draw(float* w, Instantiation prev, const IBayesNet<GUM_SCALAR>& bn, const NodeSet& hardEvNodes, const NodeProperty<Idx>& hardEv){
-    *w = 1.;
-    bool wrong_value = false;
-    do{
-      wrong_value = false; prev.clear();
-      for (auto nod: this->BN().topologicalOrder()){
-        this->_addVarSample(nod, &prev, this->BN());
-        if (this->hardEvidenceNodes().contains(nod) and prev.val(this->BN().variable(nod)) != this->hardEvidence()[nod]) {
-          wrong_value = true;
-          break;
-        }
-      }
-    } while (wrong_value);
-    return prev;
-  }
+  template < typename GUM_SCALAR >
+  Instantiation
+  MonteCarloSampling< GUM_SCALAR >::_draw(float* w,
+                                          Instantiation prev,
+                                          const IBayesNet< GUM_SCALAR >& bn,
+                                          const NodeSet& hardEvNodes,
+                                          const NodeProperty< Idx >& hardEv) {
+    *w = 1.;
+    bool wrong_value = false;
+    do {
+      wrong_value = false;
+      prev.clear();
+      for (auto nod : this->BN().topologicalOrder()) {
+        this->_addVarSample(nod, &prev);
+        if (this->hardEvidenceNodes().contains(nod) and
+            prev.val(this->BN().variable(nod)) != this->hardEvidence()[nod]) {
+          wrong_value = true;
+          break;
+        }
+      }
+    } while (wrong_value);
+    return prev;
+  }
 }
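The `_draw` above is plain rejection sampling: every node is forward-sampled in topological order, and the whole sample is restarted whenever a drawn value contradicts hard evidence, so the weight stays at 1. A self-contained sketch of the same idea outside aGrUM (two binary nodes, evidence B = 1; the numbers are mine):

    #include <iostream>
    #include <random>

    int main() {
      std::mt19937 gen(42);
      std::bernoulli_distribution drawA(0.3);                 // P(A=1) = 0.3
      auto pB1givenA = [](bool a) { return a ? 0.9 : 0.2; };  // P(B=1 | A)

      int kept = 0, aTrue = 0;
      while (kept < 100000) {
        bool a = drawA(gen);
        bool b = std::bernoulli_distribution(pB1givenA(a))(gen);
        if (!b) continue;  // reject samples inconsistent with the evidence B = 1
        ++kept;
        if (a) ++aTrue;
      }
      // Exact value: P(A=1 | B=1) = 0.27 / 0.41, about 0.659
      std::cout << "P(A=1 | B=1) ~ " << double(aTrue) / kept << std::endl;
    }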
@@ -50,7 +50,6 @@ namespace gum {
    */
   template <typename GUM_SCALAR, template <typename> class APPROX>
   class HybridApproxInference : public APPROX<GUM_SCALAR> {
     public:
@@ -33,37 +33,31 @@
 namespace gum {

-  template <typename GUM_SCALAR, template <typename> class APPROX>
-  HybridApproxInference<GUM_SCALAR, APPROX>::HybridApproxInference(
-      const IBayesNet<GUM_SCALAR>* BN )
-      : APPROX<GUM_SCALAR>( BN ) {
-    GUM_CONSTRUCTOR( HybridApproxInference );
+  template < typename GUM_SCALAR, template < typename > class APPROX >
+  HybridApproxInference< GUM_SCALAR, APPROX >::HybridApproxInference(
+      const IBayesNet< GUM_SCALAR >* BN)
+      : APPROX< GUM_SCALAR >(BN) {
+    GUM_CONSTRUCTOR(HybridApproxInference);
   }

-  template <typename GUM_SCALAR, template <typename> class APPROX>
-  HybridApproxInference<GUM_SCALAR, APPROX>::~HybridApproxInference() {
-    GUM_DESTRUCTOR( HybridApproxInference );
+  template < typename GUM_SCALAR, template < typename > class APPROX >
+  HybridApproxInference< GUM_SCALAR, APPROX >::~HybridApproxInference() {
+    GUM_DESTRUCTOR(HybridApproxInference);
   }

-  template <typename GUM_SCALAR, template <typename> class APPROX>
-  void HybridApproxInference<GUM_SCALAR, APPROX>::_makeInference() {
-    LoopyBeliefPropagation<GUM_SCALAR> lbp( &this->BN() );
-    lbp.setMaxIter( DEFAULT_LBP_MAX_ITER );
+  template < typename GUM_SCALAR, template < typename > class APPROX >
+  void HybridApproxInference< GUM_SCALAR, APPROX >::_makeInference() {
+    LoopyBeliefPropagation< GUM_SCALAR > lbp(&this->BN());
+    lbp.setMaxIter(DEFAULT_LBP_MAX_ITER);
     lbp.makeInference();

-    const auto& bn = this->BN();
-    auto hardEv = this->hardEvidence();
-    auto hardEvNodes = this->hardEvidenceNodes();
-    if ( !this->isContextualized ) this->contextualize();
-    if ( !this->isSetEstimator )
-      this->_setEstimatorFromLBP( &lbp, this->hardEvidenceNodes() );
+    if (!this->isSetEstimator)
+      this->_setEstimatorFromLBP(&lbp);

     this->initApproximationScheme();
     gum::Instantiation Ip;
@@ -74,12 +68,11 @@ namespace gum {
     do {
-      Ip = this->_draw( &w, Ip );
-      this->__estimator.update( Ip, w );
+      Ip = this->_draw(&w, Ip);
+      this->__estimator.update(Ip, w);
       this->updateApproximationScheme();
-    } while (
-        this->continueApproximationScheme( this->__estimator.confidence() ) );
+    } while (this->continueApproximationScheme(this->__estimator.confidence()));

     this->isSetEstimator = false;
   }
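The hybrid scheme above runs loopy belief propagation first and uses its posteriors to seed the estimator before the sampling loop refines it. A hypothetical usage sketch (template arguments follow the class declaration shown earlier; `bn` and the node name are mine):

    // LBP seeds the estimator, then Gibbs sampling refines it until the
    // approximation scheme's confidence criterion is met.
    gum::HybridApproxInference< double, gum::GibbsSampling > inf(&bn);
    inf.setEpsilon(1e-2);
    inf.makeInference();
    std::cout << inf.posterior(bn.idFromName("queryVar")) << std::endl;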
@@ -31,108 +31,97 @@
 namespace gum {

   /// default constructor
-  template <typename GUM_SCALAR>
-  ImportanceSampling<GUM_SCALAR>::ImportanceSampling(
-      const IBayesNet<GUM_SCALAR>* BN )
-      : ApproximateInference<GUM_SCALAR>( BN ) {
-    this->setBurnIn( 0 );
-    GUM_CONSTRUCTOR( ImportanceSampling );
+  template < typename GUM_SCALAR >
+  ImportanceSampling< GUM_SCALAR >::ImportanceSampling(
+      const IBayesNet< GUM_SCALAR >* BN)
+      : ApproximateInference< GUM_SCALAR >(BN) {
+    this->setBurnIn(0);
+    GUM_CONSTRUCTOR(ImportanceSampling);
   }

   /// destructor
-  template <typename GUM_SCALAR>
-  ImportanceSampling<GUM_SCALAR>::~ImportanceSampling() {
-    GUM_DESTRUCTOR( ImportanceSampling );
+  template < typename GUM_SCALAR >
+  ImportanceSampling< GUM_SCALAR >::~ImportanceSampling() {
+    GUM_DESTRUCTOR(ImportanceSampling);
   }

   /// no burn in needed for Importance sampling
-  template <typename GUM_SCALAR>
-  Instantiation ImportanceSampling<GUM_SCALAR>::_burnIn() {
+  template < typename GUM_SCALAR >
+  Instantiation ImportanceSampling< GUM_SCALAR >::_burnIn() {
     Instantiation I;
     return I;
   }

-  template <typename GUM_SCALAR>
-  Instantiation ImportanceSampling<GUM_SCALAR>::_draw( float* w,
-                                                       Instantiation prev ) {
+  template < typename GUM_SCALAR >
+  Instantiation ImportanceSampling< GUM_SCALAR >::_draw(float* w,
+                                                        Instantiation prev) {
     GUM_SCALAR pSurQ = 1.;

     do {
       prev.clear();
       pSurQ = 1.;
-      for ( auto ev : this->hardEvidenceNodes()) {
-        prev.add( BN().variable( ev ) );
-        prev.chgVal( BN().variable( ev ), this->evidence()[ev] );
+      for (auto ev : this->hardEvidenceNodes()) {
+        prev.add(this->BN().variable(ev));
+        prev.chgVal(this->BN().variable(ev), this->hardEvidence()[ev]);
       }

-      for ( auto nod : this->BN().topologicalOrder() ) {
-        this->_addVarSample( nod, &prev);
-        auto probaQ = BN().cpt( nod ).get( prev );
-        auto probaP = _frag->cpt( nod ).get( prev );
-        if ( ( probaP == 0 ) || ( probaQ == 0 ) ) {
+      for (auto nod : this->samplingBN().topologicalOrder()) {
+        this->_addVarSample(nod, &prev);
+        auto probaP = this->BN().cpt(nod).get(prev);
+        auto probaQ = this->samplingBN().cpt(nod).get(prev);
+        if ((probaP == 0) || (probaQ == 0)) {
           pSurQ = 0;
-        } else {
-          pSurQ = probaP / probaQ;
-        }
+          continue;
+        }
+        pSurQ = probaP / probaQ;
       }

-      if ( pSurQ > 0.0 ) {
-        for ( auto ev = hardEvNodes.beginSafe(); ev != hardEvNodes.endSafe();
-              ++ev ) {
-          pSurQ *= bn.cpt( *ev ).get( prev );
+      if (pSurQ > 0.0) {
+        for (auto ev : this->hardEvidenceNodes()) {
+          pSurQ *= this->samplingBN().cpt(ev).get(prev);
         }
       }
-    } while ( pSurQ == 0 );
+    } while (pSurQ == 0);

     *w = pSurQ;
     return prev;
   }

-  template <typename GUM_SCALAR>
-  void
-  ImportanceSampling<GUM_SCALAR>::_unsharpenBN( BayesNetFragment<GUM_SCALAR>* bn,
-                                                float epsilon ) {
-    for ( auto nod : bn->nodes().asNodeSet() ) {
-      Potential<GUM_SCALAR>* p = new Potential<GUM_SCALAR>();
-      *p = bn->cpt( nod ).isNonZeroMap().scale( epsilon ) + bn->cpt( nod );
+  template < typename GUM_SCALAR >
+  void ImportanceSampling< GUM_SCALAR >::_unsharpenBN(
+      BayesNetFragment< GUM_SCALAR >* bn, float epsilon) {
+    for (auto nod : bn->nodes().asNodeSet()) {
+      Potential< GUM_SCALAR >* p = new Potential< GUM_SCALAR >();
+      *p = bn->cpt(nod).isNonZeroMap().scale(epsilon) + bn->cpt(nod);
       p->normalizeAsCPT();
-      bn->installCPT( nod, p );
+      bn->installCPT(nod, p);
     }
   }

-  template <typename GUM_SCALAR>
-  void ImportanceSampling<GUM_SCALAR>::_onContextualize(
-      BayesNetFragment<GUM_SCALAR>* bn ) {
-    /*
-    Sequence<NodeId> sid;
-    for ( NodeSet::iterator ev = hardEvNodes.begin(); ev != hardEvNodes.end();
-          ++ev )
-      sid << *ev;
-    */
-    auto hardEvNodes = this->hardEvidenceNodes();
-    auto hardEv = this->hardEvidence();
-    auto targets = this->targets();
-    GUM_CHECKPOINT;
-    for ( auto ev : hardEvNodes ) {
+  template < typename GUM_SCALAR >
+  void ImportanceSampling< GUM_SCALAR >::_onContextualize(
+      BayesNetFragment< GUM_SCALAR >* bn) {
+    GUM_TRACE_VAR(this->hardEvidenceNodes());
+    for (auto ev : this->hardEvidenceNodes()) {
       GUM_CHECKPOINT;
-      bn->uninstallCPT( ev );
-      GUM_TRACE_VAR( ev );
-      GUM_TRACE_VAR( this->hardEvidenceNodes() );
-      bn->installCPT( ev, new Potential<GUM_SCALAR>( *this->evidence()[ev] ) );
+      bn->uninstallCPT(ev);
+      GUM_TRACE_VAR(ev);
+      GUM_TRACE_VAR(this->hardEvidenceNodes());
+      bn->installCPT(ev, new Potential< GUM_SCALAR >(*this->evidence()[ev]));
       GUM_CHECKPOINT;
       // we keep the variables with hard evidence but alone
       // bn->uninstallNode( sid[i] );
     }
     GUM_CHECKPOINT;
-    for ( auto targ = targets.begin(); targ != targets.end(); ++targ ) {
-      if ( this->BN().dag().exists( *targ ) ) this->addTarget( *targ );
+    for (auto targ : this->targets()) {
+      if (this->BN().dag().exists(targ)) this->addTarget(targ);
     }
     GUM_CHECKPOINT;
@@ -140,6 +129,6 @@ namespace gum {
     auto minAccepted = this->epsilon() / bn->maxVarDomainSize();
     GUM_CHECKPOINT;
-    if ( minParam < minAccepted ) this->_unsharpenBN( bn, minAccepted );
+    if (minParam < minAccepted) this->_unsharpenBN(bn, minAccepted);
   }
 }
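For reference, the weight that `_draw` assembles is the standard importance-sampling ratio between the target distribution P (the original network) and the proposal Q (`samplingBN()`, i.e. the contextualized and possibly unsharpened fragment). In the usual notation (mine, not from the source), for a sampled instantiation x:

    w(x) = P(x) / Q(x),   with   E_{x ~ Q}[ w(x) f(x) ] = E_{x ~ P}[ f(x) ]

Both P and Q factorize over CPTs, which is why the loop compares `probaP` and `probaQ` node by node and re-draws whenever the ratio collapses to 0; `_unsharpenBN` smooths the fragment's CPTs (adding epsilon on their support, then renormalizing) so that the proposal is not sharper than the target.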
@@ -73,14 +73,17 @@ namespace gum {
   }

   template < typename GUM_SCALAR >
-  const IBayesNet< GUM_SCALAR >& ApproximateInference< GUM_SCALAR >::samplingBN() {
+  INLINE const IBayesNet< GUM_SCALAR >&
+  ApproximateInference< GUM_SCALAR >::samplingBN() {
     this->prepareInference();
-    return *__samplingBN;
+    if (__samplingBN == nullptr)
+      return this->BN();
+    else
+      return *__samplingBN;
   }

   template < typename GUM_SCALAR >
   void ApproximateInference< GUM_SCALAR >::_setEstimatorFromBN() {
-    __estimator.setFromBN(__samplingBN, this->hardEvidenceNodes());
+    __estimator.setFromBN(&samplingBN(), this->hardEvidenceNodes());
     this->isSetEstimator = true;
   }
@@ -160,8 +163,6 @@ namespace gum {
     //@todo This should be in __prepareInference
     if (!isContextualized) {
       this->contextualize();
-    } else {
-      __samplingBN = &(this->BN());
     }

     if (!isSetEstimator) this->_setEstimatorFromBN();
@@ -187,11 +188,11 @@ namespace gum {
   template < typename GUM_SCALAR >
   void ApproximateInference< GUM_SCALAR >::_addVarSample(NodeId nod,
                                                          Instantiation* I) {
-    gum::Instantiation Itop = gum::Instantiation(__samplingBN->cpt(nod));
-    Itop.erase(__samplingBN->variable(nod));
+    gum::Instantiation Itop = gum::Instantiation(samplingBN().cpt(nod));
+    Itop.erase(samplingBN().variable(nod));

-    I->add(__samplingBN->variable(nod));
-    I->chgVal(__samplingBN->variable(nod),
-              __samplingBN->cpt(nod).extract(*I).draw());
+    I->add(samplingBN().variable(nod));
+    I->chgVal(samplingBN().variable(nod),
+              samplingBN().cpt(nod).extract(*I).draw());
   }
 }
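`_addVarSample` is the elementary forward-sampling step shared by all the samplers above: it extracts from the node's CPT the slice matching the already-instantiated parents and draws one value from it. The same step in isolation, as a minimal sketch using a plain probability row instead of aGrUM's `Potential` (the helper name is mine):

    #include <random>
    #include <vector>

    // Draw a value for a node, given P(X = k | instantiated parents)
    // as one row of its CPT.
    int drawValue(const std::vector<double>& cptRow, std::mt19937& gen) {
      std::discrete_distribution<int> dist(cptRow.begin(), cptRow.end());
      return dist(gen);
    }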
@@ -90,7 +90,7 @@ namespace gum {
     }

-    this->_addVarSample(nod, &prev, this->BN());
+    this->_addVarSample(nod, &prev);
   }