[agrum] updating API for marginal targets

parent 9f8d92c5
...@@ -346,7 +346,7 @@ namespace gum { ...@@ -346,7 +346,7 @@ namespace gum {
// however, note that the nodes that received hard evidence do not belong to // however, note that the nodes that received hard evidence do not belong to
// the graph and, therefore, should not be taken into account // the graph and, therefore, should not be taken into account
const auto& hard_ev_nodes = this->hardEvidenceNodes(); const auto& hard_ev_nodes = this->hardEvidenceNodes();
for ( const auto node : this->allTargets() ) { for ( const auto node : this->targets() ) {
if ( !__graph.exists( node ) && !hard_ev_nodes.exists( node ) ) return true; if ( !__graph.exists( node ) && !hard_ev_nodes.exists( node ) ) return true;
} }
for ( const auto& nodes : this->jointTargets() ) { for ( const auto& nodes : this->jointTargets() ) {
...@@ -412,7 +412,7 @@ namespace gum { ...@@ -412,7 +412,7 @@ namespace gum {
// the BN without altering the inference output // the BN without altering the inference output
if ( __barren_nodes_type == FindBarrenNodesType::FIND_BARREN_NODES ) { if ( __barren_nodes_type == FindBarrenNodesType::FIND_BARREN_NODES ) {
// identify the barren nodes // identify the barren nodes
NodeSet target_nodes = this->allTargets(); NodeSet target_nodes = this->targets();
for ( const auto& nodeset : this->jointTargets() ) { for ( const auto& nodeset : this->jointTargets() ) {
target_nodes += nodeset; target_nodes += nodeset;
} }
...@@ -999,7 +999,7 @@ namespace gum { ...@@ -999,7 +999,7 @@ namespace gum {
void ShaferShenoyInference<GUM_SCALAR>::__computeJoinTreeRoots() { void ShaferShenoyInference<GUM_SCALAR>::__computeJoinTreeRoots() {
// get the set of cliques in which we can find the targets and joint_targets // get the set of cliques in which we can find the targets and joint_targets
NodeSet clique_targets; NodeSet clique_targets;
for ( const auto node : this->allTargets() ) { for ( const auto node : this->targets() ) {
try { try {
clique_targets.insert( __node_to_clique[node] ); clique_targets.insert( __node_to_clique[node] );
} catch ( Exception& ) { } catch ( Exception& ) {
...@@ -1288,7 +1288,7 @@ namespace gum { ...@@ -1288,7 +1288,7 @@ namespace gum {
template <typename GUM_SCALAR> template <typename GUM_SCALAR>
INLINE void ShaferShenoyInference<GUM_SCALAR>::_makeInference() { INLINE void ShaferShenoyInference<GUM_SCALAR>::_makeInference() {
// collect messages for all single targets // collect messages for all single targets
for ( const auto node : this->allTargets() ) { for ( const auto node : this->targets() ) {
// perform only collects in the join tree for nodes that have // perform only collects in the join tree for nodes that have
// not received hard evidence (those that received hard evidence were // not received hard evidence (those that received hard evidence were
// not included into the join tree for speed-up reasons) // not included into the join tree for speed-up reasons)
......
...@@ -385,7 +385,7 @@ namespace gum { ...@@ -385,7 +385,7 @@ namespace gum {
// however, note that the nodes that received hard evidence do not belong to // however, note that the nodes that received hard evidence do not belong to
// the graph and, therefore, should not be taken into account // the graph and, therefore, should not be taken into account
const auto& hard_ev_nodes = this->hardEvidenceNodes(); const auto& hard_ev_nodes = this->hardEvidenceNodes();
for ( const auto node : this->allTargets() ) { for ( const auto node : this->targets() ) {
if ( !__graph.exists( node ) && !hard_ev_nodes.exists( node ) ) return true; if ( !__graph.exists( node ) && !hard_ev_nodes.exists( node ) ) return true;
} }
for ( const auto& nodes : this->jointTargets() ) { for ( const auto& nodes : this->jointTargets() ) {
...@@ -452,7 +452,7 @@ namespace gum { ...@@ -452,7 +452,7 @@ namespace gum {
// altering the inference output // altering the inference output
if ( __barren_nodes_type == FindBarrenNodesType::FIND_BARREN_NODES ) { if ( __barren_nodes_type == FindBarrenNodesType::FIND_BARREN_NODES ) {
// identify the barren nodes // identify the barren nodes
NodeSet target_nodes = this->allTargets(); NodeSet target_nodes = this->targets();
for ( const auto& nodeset : this->jointTargets() ) { for ( const auto& nodeset : this->jointTargets() ) {
target_nodes += nodeset; target_nodes += nodeset;
} }
...@@ -968,7 +968,7 @@ namespace gum { ...@@ -968,7 +968,7 @@ namespace gum {
void LazyPropagation<GUM_SCALAR>::__computeJoinTreeRoots() { void LazyPropagation<GUM_SCALAR>::__computeJoinTreeRoots() {
// get the set of cliques in which we can find the targets and joint_targets // get the set of cliques in which we can find the targets and joint_targets
NodeSet clique_targets; NodeSet clique_targets;
for ( const auto node : this->allTargets() ) { for ( const auto node : this->targets() ) {
try { try {
clique_targets.insert( __node_to_clique[node] ); clique_targets.insert( __node_to_clique[node] );
} catch ( Exception& ) { } catch ( Exception& ) {
...@@ -1347,7 +1347,7 @@ namespace gum { ...@@ -1347,7 +1347,7 @@ namespace gum {
template <typename GUM_SCALAR> template <typename GUM_SCALAR>
INLINE void LazyPropagation<GUM_SCALAR>::_makeInference() { INLINE void LazyPropagation<GUM_SCALAR>::_makeInference() {
// collect messages for all single targets // collect messages for all single targets
for ( const auto node : this->allTargets() ) { for ( const auto node : this->targets() ) {
// perform only collects in the join tree for nodes that have // perform only collects in the join tree for nodes that have
// not received hard evidence (those that received hard evidence were // not received hard evidence (those that received hard evidence were
// not included into the join tree for speed-up reasons) // not included into the join tree for speed-up reasons)
......
...@@ -50,21 +50,19 @@ namespace gum { ...@@ -50,21 +50,19 @@ namespace gum {
* the current state of the inference. Note that the MarginalTargetedInference * the current state of the inference. Note that the MarginalTargetedInference
* is designed to be used in incremental inference engines. * is designed to be used in incremental inference engines.
*/ */
template <typename GUM_SCALAR> template<typename GUM_SCALAR>
class MarginalTargetedInference : public virtual BayesNetInference<GUM_SCALAR> { class MarginalTargetedInference : public virtual BayesNetInference<GUM_SCALAR> {
public: public:
// ############################################################################ // ############################################################################
/// @name Constructors / Destructors /// @name Constructors / Destructors
// ############################################################################ // ############################################################################
/// @{ /// @{
/// default constructor /// default constructor
/** @warning By default (when the targets set is empty), all the nodes of /** @warning By default, all the nodes of the Bayes net are targets.
* the Bayes net are considered as targets. Once a first target is added to the
* set, the remaining nodes are not considered as default targets anymore.
* @warning note that, by aGrUM's rule, the BN is not copied but only * @warning note that, by aGrUM's rule, the BN is not copied but only
* referenced by the inference algorithm. */ * referenced by the inference algorithm. */
MarginalTargetedInference( const IBayesNet<GUM_SCALAR>* bn ); MarginalTargetedInference(const IBayesNet<GUM_SCALAR> *bn);
/// destructor /// destructor
virtual ~MarginalTargetedInference(); virtual ~MarginalTargetedInference();
...@@ -92,7 +90,7 @@ namespace gum { ...@@ -92,7 +90,7 @@ namespace gum {
* *
* @throw UndefinedElement if node is not in the set of targets * @throw UndefinedElement if node is not in the set of targets
*/ */
virtual const Potential<GUM_SCALAR>& posterior( const NodeId node ); virtual const Potential<GUM_SCALAR> &posterior(const NodeId node);
/// Computes and returns the posterior of a node. /// Computes and returns the posterior of a node.
/** /**
...@@ -110,7 +108,7 @@ namespace gum { ...@@ -110,7 +108,7 @@ namespace gum {
* *
* @throw UndefinedElement if node is not in the set of targets * @throw UndefinedElement if node is not in the set of targets
*/ */
virtual const Potential<GUM_SCALAR>& posterior( const std::string& nodeName ); virtual const Potential<GUM_SCALAR> &posterior(const std::string &nodeName);
/// @} /// @}
...@@ -121,62 +119,44 @@ namespace gum { ...@@ -121,62 +119,44 @@ namespace gum {
/// @{ /// @{
/// Clear all previously defined targets /// Clear all previously defined targets
/// @warning this means that every node is now a target by default
virtual void eraseAllTargets(); virtual void eraseAllTargets();
/// adds all nodes as targets /// adds all nodes as targets
/// @warning due to the semantic of targets, this function is an alias of
/// eraseAllTargets()
virtual void addAllTargets() final; virtual void addAllTargets() final;
/// Add a marginal target to the list of targets /// Add a marginal target to the list of targets
/** /**
* @throw UndefinedElement if target is not a NodeId in the Bayes net * @throw UndefinedElement if target is not a NodeId in the Bayes net
*/ */
virtual void addTarget( const NodeId target ) final; virtual void addTarget(const NodeId target) final;
/// Add a marginal target to the list of targets /// Add a marginal target to the list of targets
/** /**
* @throw UndefinedElement if target is not a NodeId in the Bayes net * @throw UndefinedElement if target is not a NodeId in the Bayes net
*/ */
virtual void addTarget( const std::string& nodeName ) final; virtual void addTarget(const std::string &nodeName) final;
/// removes an existing (marginal) target /// removes an existing (marginal) target
/** @warning If the target does not already exist, the method does nothing. /** @warning If the target does not already exist, the method does nothing.
* In particular, it does not raise any exception. * In particular, it does not raise any exception. */
* @warning Erasing the last target implies that every node is now a target by virtual void eraseTarget(const NodeId target) final;
* default.
* */
virtual void eraseTarget( const NodeId target ) final;
/// removes an existing (marginal) target /// removes an existing (marginal) target
/** @warning If the target does not already exist, the method does nothing. /** @warning If the target does not already exist, the method does nothing.
* In particular, it does not raise any exception. * In particular, it does not raise any exception. */
* @warning Erasing the last target implies that every node is now a target by virtual void eraseTarget(const std::string &nodeName) final;
* default.*/
virtual void eraseTarget( const std::string& nodeName ) final;
/// return true if variable is a (marginal) target /// return true if variable is a (marginal) target
virtual bool isTarget( const NodeId variable ) const final; virtual bool isTarget(const NodeId variable) const final;
/// return true if variable is a (marginal) target /// return true if variable is a (marginal) target
virtual bool isTarget( const std::string& nodeName ) const final; virtual bool isTarget(const std::string &nodeName) const final;
/// returns the number of marginal targets. /// returns the number of marginal targets
//// @warning if the result is 0, it means that all the nodes are targets by
/// default.
virtual const Size nbrTargets() const noexcept final; virtual const Size nbrTargets() const noexcept final;
/// returns the set of marginal targets /// returns the list of marginal targets
//// @warning if the set is empty, it means that all the nodes are targets by virtual const NodeSet &targets() const noexcept final;
/// default.
virtual const NodeSet& targets() const noexcept final;
/// return all the marginal targets.
/** Particularly, if the targetSet is empty, allTargets will send a copy of the
* nodeSet of the BN
*/
virtual NodeSet allTargets() const noexcept final;
/// @} /// @}
...@@ -189,13 +169,13 @@ namespace gum { ...@@ -189,13 +169,13 @@ namespace gum {
* Compute Shanon's entropy of a node given the observation * Compute Shanon's entropy of a node given the observation
* @see http://en.wikipedia.org/wiki/Information_entropy * @see http://en.wikipedia.org/wiki/Information_entropy
*/ */
virtual GUM_SCALAR H( const NodeId X ) final; virtual GUM_SCALAR H(const NodeId X) final;
/** Entropy /** Entropy
* Compute Shanon's entropy of a node given the observation * Compute Shanon's entropy of a node given the observation
* @see http://en.wikipedia.org/wiki/Information_entropy * @see http://en.wikipedia.org/wiki/Information_entropy
*/ */
virtual GUM_SCALAR H( const std::string& nodeName ) final; virtual GUM_SCALAR H(const std::string &nodeName) final;
///@} ///@}
...@@ -211,8 +191,8 @@ namespace gum { ...@@ -211,8 +191,8 @@ namespace gum {
* @param evs the vector of nodeId of the observed variables * @param evs the vector of nodeId of the observed variables
* @return a Potential * @return a Potential
*/ */
Potential<GUM_SCALAR> evidenceImpact( NodeId target, Potential<GUM_SCALAR> evidenceImpact(NodeId target,
const std::vector<NodeId>& evs ); const std::vector<NodeId>& evs);
/** /**
* Create a gum::Potential for P(target|evs) (for all instanciation of target * Create a gum::Potential for P(target|evs) (for all instanciation of target
...@@ -224,17 +204,17 @@ namespace gum { ...@@ -224,17 +204,17 @@ namespace gum {
* @param evs the nodeId of the observed variable * @param evs the nodeId of the observed variable
* @return a Potential * @return a Potential
*/ */
Potential<GUM_SCALAR> evidenceImpact( const std::string& target, Potential<GUM_SCALAR> evidenceImpact(const std::string& target,
const std::vector<std::string>& evs ); const std::vector<std::string>& evs);
protected: protected:
/// fired after a new marginal target is inserted /// fired after a new marginal target is inserted
/** @param id The target variable's id. */ /** @param id The target variable's id. */
virtual void _onMarginalTargetAdded( const NodeId id ) = 0; virtual void _onMarginalTargetAdded(const NodeId id) = 0;
/// fired before a marginal target is removed /// fired before a marginal target is removed
/** @param id The target variable's id. */ /** @param id The target variable's id. */
virtual void _onMarginalTargetErased( const NodeId id ) = 0; virtual void _onMarginalTargetErased(const NodeId id) = 0;
/// fired after all the nodes of the BN are added as marginal targets /// fired after all the nodes of the BN are added as marginal targets
virtual void _onAllMarginalTargetsAdded() = 0; virtual void _onAllMarginalTargetsAdded() = 0;
...@@ -243,16 +223,19 @@ namespace gum { ...@@ -243,16 +223,19 @@ namespace gum {
virtual void _onAllMarginalTargetsErased() = 0; virtual void _onAllMarginalTargetsErased() = 0;
/// fired after a new Bayes net has been assigned to the engine /// fired after a new Bayes net has been assigned to the engine
virtual void _onBayesNetChanged( const IBayesNet<GUM_SCALAR>* bn ); virtual void _onBayesNetChanged(const IBayesNet<GUM_SCALAR> *bn);
/// asks derived classes for the posterior of a given variable /// asks derived classes for the posterior of a given variable
/** @param id The variable's id. */ /** @param id The variable's id. */
virtual const Potential<GUM_SCALAR>& _posterior( const NodeId id ) = 0; virtual const Potential<GUM_SCALAR> &_posterior(const NodeId id) = 0;
private:
/// whether the actual targets are default
bool __defaultTargets;
private:
/// the set of marginal targets /// the set of marginal targets
NodeSet __targetsSet; NodeSet __targets;
/// remove all the marginal posteriors computed /// remove all the marginal posteriors computed
......
...@@ -44,23 +44,25 @@ namespace gum_tests { ...@@ -44,23 +44,25 @@ namespace gum_tests {
gum::BayesNet<double>::fastPrototype( "A->B->C->D;A->E->D;F->B;C->H;" ); gum::BayesNet<double>::fastPrototype( "A->B->C->D;A->E->D;F->B;C->H;" );
gum::LazyPropagation<double> lazy( &bn ); gum::LazyPropagation<double> lazy( &bn );
TS_ASSERT(lazy.targets() == gum::NodeSet( {} ) ); TS_ASSERT( lazy.targets() == gum::NodeSet( {0, 1, 2, 3, 4, 5, 6} ) );
lazy.addTarget( "A" ); lazy.addTarget( "A" );
TS_ASSERT(lazy.targets() == gum::NodeSet( {0} ) ); TS_ASSERT( lazy.targets() == gum::NodeSet( {0} ) );
lazy.addTarget( "B" ); lazy.addTarget( "B" );
TS_ASSERT(lazy.targets() == gum::NodeSet( {0, 1} ) ); TS_ASSERT( lazy.targets() == gum::NodeSet( {0, 1} ) );
gum::ShaferShenoyInference<double> shafer( &bn ); gum::ShaferShenoyInference<double> shafer( &bn );
TS_ASSERT( shafer.targets() == gum::NodeSet( {0, 1, 2, 3, 4, 5, 6} ) );
shafer.addTarget( "A" ); shafer.addTarget( "A" );
TS_ASSERT(shafer.targets() == gum::NodeSet( {0} ) ); TS_ASSERT( shafer.targets() == gum::NodeSet( {0} ) );
shafer.addTarget( "B" ); shafer.addTarget( "B" );
TS_ASSERT(shafer.targets() == gum::NodeSet( {0, 1} ) ); TS_ASSERT( shafer.targets() == gum::NodeSet( {0, 1} ) );
gum::VariableElimination<double> ve( &bn ); gum::VariableElimination<double> ve( &bn );
TS_ASSERT( ve.targets() == gum::NodeSet( {0, 1, 2, 3, 4, 5, 6} ) );
ve.addTarget( "A" ); ve.addTarget( "A" );
TS_ASSERT(ve.targets() == gum::NodeSet( {0} ) ); TS_ASSERT( ve.targets() == gum::NodeSet( {0} ) );
ve.addTarget( "B" ); ve.addTarget( "B" );
TS_ASSERT(ve.targets() == gum::NodeSet( {0, 1} ) ); TS_ASSERT( ve.targets() == gum::NodeSet( {0, 1} ) );
} }
}; };
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment