[agrum] updating API for marginal targets

parent 9f8d92c5
......@@ -346,7 +346,7 @@ namespace gum {
// however, note that the nodes that received hard evidence do not belong to
// the graph and, therefore, should not be taken into account
const auto& hard_ev_nodes = this->hardEvidenceNodes();
for ( const auto node : this->allTargets() ) {
for ( const auto node : this->targets() ) {
if ( !__graph.exists( node ) && !hard_ev_nodes.exists( node ) ) return true;
}
for ( const auto& nodes : this->jointTargets() ) {
......@@ -412,7 +412,7 @@ namespace gum {
// the BN without altering the inference output
if ( __barren_nodes_type == FindBarrenNodesType::FIND_BARREN_NODES ) {
// identify the barren nodes
NodeSet target_nodes = this->allTargets();
NodeSet target_nodes = this->targets();
for ( const auto& nodeset : this->jointTargets() ) {
target_nodes += nodeset;
}
......@@ -999,7 +999,7 @@ namespace gum {
void ShaferShenoyInference<GUM_SCALAR>::__computeJoinTreeRoots() {
// get the set of cliques in which we can find the targets and joint_targets
NodeSet clique_targets;
for ( const auto node : this->allTargets() ) {
for ( const auto node : this->targets() ) {
try {
clique_targets.insert( __node_to_clique[node] );
} catch ( Exception& ) {
......@@ -1288,7 +1288,7 @@ namespace gum {
template <typename GUM_SCALAR>
INLINE void ShaferShenoyInference<GUM_SCALAR>::_makeInference() {
// collect messages for all single targets
for ( const auto node : this->allTargets() ) {
for ( const auto node : this->targets() ) {
// perform only collects in the join tree for nodes that have
// not received hard evidence (those that received hard evidence were
// not included into the join tree for speed-up reasons)
......
......@@ -385,7 +385,7 @@ namespace gum {
// however, note that the nodes that received hard evidence do not belong to
// the graph and, therefore, should not be taken into account
const auto& hard_ev_nodes = this->hardEvidenceNodes();
for ( const auto node : this->allTargets() ) {
for ( const auto node : this->targets() ) {
if ( !__graph.exists( node ) && !hard_ev_nodes.exists( node ) ) return true;
}
for ( const auto& nodes : this->jointTargets() ) {
......@@ -452,7 +452,7 @@ namespace gum {
// altering the inference output
if ( __barren_nodes_type == FindBarrenNodesType::FIND_BARREN_NODES ) {
// identify the barren nodes
NodeSet target_nodes = this->allTargets();
NodeSet target_nodes = this->targets();
for ( const auto& nodeset : this->jointTargets() ) {
target_nodes += nodeset;
}
......@@ -968,7 +968,7 @@ namespace gum {
void LazyPropagation<GUM_SCALAR>::__computeJoinTreeRoots() {
// get the set of cliques in which we can find the targets and joint_targets
NodeSet clique_targets;
for ( const auto node : this->allTargets() ) {
for ( const auto node : this->targets() ) {
try {
clique_targets.insert( __node_to_clique[node] );
} catch ( Exception& ) {
......@@ -1347,7 +1347,7 @@ namespace gum {
template <typename GUM_SCALAR>
INLINE void LazyPropagation<GUM_SCALAR>::_makeInference() {
// collect messages for all single targets
for ( const auto node : this->allTargets() ) {
for ( const auto node : this->targets() ) {
// perform only collects in the join tree for nodes that have
// not received hard evidence (those that received hard evidence were
// not included into the join tree for speed-up reasons)
......
......@@ -50,21 +50,19 @@ namespace gum {
* the current state of the inference. Note that the MarginalTargetedInference
* is designed to be used in incremental inference engines.
*/
template <typename GUM_SCALAR>
template<typename GUM_SCALAR>
class MarginalTargetedInference : public virtual BayesNetInference<GUM_SCALAR> {
public:
public:
// ############################################################################
/// @name Constructors / Destructors
// ############################################################################
/// @{
/// default constructor
/** @warning By default (when the targets set is empty), all the nodes of
* the Bayes net are considered as targets. Once a first target is added to the
* set, the remaining nodes are not considered as default targets anymore.
/** @warning By default, all the nodes of the Bayes net are targets.
* @warning note that, by aGrUM's rule, the BN is not copied but only
* referenced by the inference algorithm. */
MarginalTargetedInference( const IBayesNet<GUM_SCALAR>* bn );
MarginalTargetedInference(const IBayesNet<GUM_SCALAR> *bn);
/// destructor
virtual ~MarginalTargetedInference();
......@@ -92,7 +90,7 @@ namespace gum {
*
* @throw UndefinedElement if node is not in the set of targets
*/
virtual const Potential<GUM_SCALAR>& posterior( const NodeId node );
virtual const Potential<GUM_SCALAR> &posterior(const NodeId node);
/// Computes and returns the posterior of a node.
/**
......@@ -110,7 +108,7 @@ namespace gum {
*
* @throw UndefinedElement if node is not in the set of targets
*/
virtual const Potential<GUM_SCALAR>& posterior( const std::string& nodeName );
virtual const Potential<GUM_SCALAR> &posterior(const std::string &nodeName);
/// @}
......@@ -121,62 +119,44 @@ namespace gum {
/// @{
/// Clear all previously defined targets
/// @warning this means that every node is now a target by default
virtual void eraseAllTargets();
/// adds all nodes as targets
/// @warning due to the semantic of targets, this function is an alias of
/// eraseAllTargets()
virtual void addAllTargets() final;
/// Add a marginal target to the list of targets
/**
* @throw UndefinedElement if target is not a NodeId in the Bayes net
*/
virtual void addTarget( const NodeId target ) final;
virtual void addTarget(const NodeId target) final;
/// Add a marginal target to the list of targets
/**
* @throw UndefinedElement if target is not a NodeId in the Bayes net
*/
virtual void addTarget( const std::string& nodeName ) final;
virtual void addTarget(const std::string &nodeName) final;
/// removes an existing (marginal) target
/** @warning If the target does not already exist, the method does nothing.
* In particular, it does not raise any exception.
* @warning Erasing the last target implies that every node is now a target by
* default.
* */
virtual void eraseTarget( const NodeId target ) final;
* In particular, it does not raise any exception. */
virtual void eraseTarget(const NodeId target) final;
/// removes an existing (marginal) target
/** @warning If the target does not already exist, the method does nothing.
* In particular, it does not raise any exception.
* @warning Erasing the last target implies that every node is now a target by
* default.*/
virtual void eraseTarget( const std::string& nodeName ) final;
* In particular, it does not raise any exception. */
virtual void eraseTarget(const std::string &nodeName) final;
/// return true if variable is a (marginal) target
virtual bool isTarget( const NodeId variable ) const final;
virtual bool isTarget(const NodeId variable) const final;
/// return true if variable is a (marginal) target
virtual bool isTarget( const std::string& nodeName ) const final;
virtual bool isTarget(const std::string &nodeName) const final;
/// returns the number of marginal targets.
//// @warning if the result is 0, it means that all the nodes are targets by
/// default.
/// returns the number of marginal targets
virtual const Size nbrTargets() const noexcept final;
/// returns the set of marginal targets
//// @warning if the set is empty, it means that all the nodes are targets by
/// default.
virtual const NodeSet& targets() const noexcept final;
/// return all the marginal targets.
/** Particularly, if the targetSet is empty, allTargets will send a copy of the
* nodeSet of the BN
*/
virtual NodeSet allTargets() const noexcept final;
/// returns the list of marginal targets
virtual const NodeSet &targets() const noexcept final;
/// @}
......@@ -189,13 +169,13 @@ namespace gum {
* Compute Shanon's entropy of a node given the observation
* @see http://en.wikipedia.org/wiki/Information_entropy
*/
virtual GUM_SCALAR H( const NodeId X ) final;
virtual GUM_SCALAR H(const NodeId X) final;
/** Entropy
* Compute Shanon's entropy of a node given the observation
* @see http://en.wikipedia.org/wiki/Information_entropy
*/
virtual GUM_SCALAR H( const std::string& nodeName ) final;
virtual GUM_SCALAR H(const std::string &nodeName) final;
///@}
......@@ -211,8 +191,8 @@ namespace gum {
* @param evs the vector of nodeId of the observed variables
* @return a Potential
*/
Potential<GUM_SCALAR> evidenceImpact( NodeId target,
const std::vector<NodeId>& evs );
Potential<GUM_SCALAR> evidenceImpact(NodeId target,
const std::vector<NodeId>& evs);
/**
* Create a gum::Potential for P(target|evs) (for all instanciation of target
......@@ -224,17 +204,17 @@ namespace gum {
* @param evs the nodeId of the observed variable
* @return a Potential
*/
Potential<GUM_SCALAR> evidenceImpact( const std::string& target,
const std::vector<std::string>& evs );
Potential<GUM_SCALAR> evidenceImpact(const std::string& target,
const std::vector<std::string>& evs);
protected:
protected:
/// fired after a new marginal target is inserted
/** @param id The target variable's id. */
virtual void _onMarginalTargetAdded( const NodeId id ) = 0;
virtual void _onMarginalTargetAdded(const NodeId id) = 0;
/// fired before a marginal target is removed
/** @param id The target variable's id. */
virtual void _onMarginalTargetErased( const NodeId id ) = 0;
virtual void _onMarginalTargetErased(const NodeId id) = 0;
/// fired after all the nodes of the BN are added as marginal targets
virtual void _onAllMarginalTargetsAdded() = 0;
......@@ -243,16 +223,19 @@ namespace gum {
virtual void _onAllMarginalTargetsErased() = 0;
/// fired after a new Bayes net has been assigned to the engine
virtual void _onBayesNetChanged( const IBayesNet<GUM_SCALAR>* bn );
virtual void _onBayesNetChanged(const IBayesNet<GUM_SCALAR> *bn);
/// asks derived classes for the posterior of a given variable
/** @param id The variable's id. */
virtual const Potential<GUM_SCALAR>& _posterior( const NodeId id ) = 0;
virtual const Potential<GUM_SCALAR> &_posterior(const NodeId id) = 0;
private:
/// whether the actual targets are default
bool __defaultTargets;
private:
/// the set of marginal targets
NodeSet __targetsSet;
NodeSet __targets;
/// remove all the marginal posteriors computed
......
......@@ -44,23 +44,25 @@ namespace gum_tests {
gum::BayesNet<double>::fastPrototype( "A->B->C->D;A->E->D;F->B;C->H;" );
gum::LazyPropagation<double> lazy( &bn );
TS_ASSERT(lazy.targets() == gum::NodeSet( {} ) );
TS_ASSERT( lazy.targets() == gum::NodeSet( {0, 1, 2, 3, 4, 5, 6} ) );
lazy.addTarget( "A" );
TS_ASSERT(lazy.targets() == gum::NodeSet( {0} ) );
TS_ASSERT( lazy.targets() == gum::NodeSet( {0} ) );
lazy.addTarget( "B" );
TS_ASSERT(lazy.targets() == gum::NodeSet( {0, 1} ) );
TS_ASSERT( lazy.targets() == gum::NodeSet( {0, 1} ) );
gum::ShaferShenoyInference<double> shafer( &bn );
TS_ASSERT( shafer.targets() == gum::NodeSet( {0, 1, 2, 3, 4, 5, 6} ) );
shafer.addTarget( "A" );
TS_ASSERT(shafer.targets() == gum::NodeSet( {0} ) );
TS_ASSERT( shafer.targets() == gum::NodeSet( {0} ) );
shafer.addTarget( "B" );
TS_ASSERT(shafer.targets() == gum::NodeSet( {0, 1} ) );
TS_ASSERT( shafer.targets() == gum::NodeSet( {0, 1} ) );
gum::VariableElimination<double> ve( &bn );
TS_ASSERT( ve.targets() == gum::NodeSet( {0, 1, 2, 3, 4, 5, 6} ) );
ve.addTarget( "A" );
TS_ASSERT(ve.targets() == gum::NodeSet( {0} ) );
TS_ASSERT( ve.targets() == gum::NodeSet( {0} ) );
ve.addTarget( "B" );
TS_ASSERT(ve.targets() == gum::NodeSet( {0, 1} ) );
TS_ASSERT( ve.targets() == gum::NodeSet( {0, 1} ) );
}
};
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment