Commit 14a76840 authored by Pierre-Henri Wuillemin's avatar Pierre-Henri Wuillemin

refreshing documentation/pyAgrum

parents 81498ff4 4990782a
......@@ -346,7 +346,7 @@ namespace gum {
// however, note that the nodes that received hard evidence do not belong to
// the graph and, therefore, should not be taken into account
const auto& hard_ev_nodes = this->hardEvidenceNodes();
for ( const auto node : this->allTargets() ) {
for ( const auto node : this->targets() ) {
if ( !__graph.exists( node ) && !hard_ev_nodes.exists( node ) ) return true;
}
for ( const auto& nodes : this->jointTargets() ) {
......@@ -412,7 +412,7 @@ namespace gum {
// the BN without altering the inference output
if ( __barren_nodes_type == FindBarrenNodesType::FIND_BARREN_NODES ) {
// identify the barren nodes
NodeSet target_nodes = this->allTargets();
NodeSet target_nodes = this->targets();
for ( const auto& nodeset : this->jointTargets() ) {
target_nodes += nodeset;
}
......@@ -999,7 +999,7 @@ namespace gum {
void ShaferShenoyInference<GUM_SCALAR>::__computeJoinTreeRoots() {
// get the set of cliques in which we can find the targets and joint_targets
NodeSet clique_targets;
for ( const auto node : this->allTargets() ) {
for ( const auto node : this->targets() ) {
try {
clique_targets.insert( __node_to_clique[node] );
} catch ( Exception& ) {
......@@ -1288,7 +1288,7 @@ namespace gum {
template <typename GUM_SCALAR>
INLINE void ShaferShenoyInference<GUM_SCALAR>::_makeInference() {
// collect messages for all single targets
for ( const auto node : this->allTargets() ) {
for ( const auto node : this->targets() ) {
// perform only collects in the join tree for nodes that have
// not received hard evidence (those that received hard evidence were
// not included into the join tree for speed-up reasons)
......
......@@ -385,7 +385,7 @@ namespace gum {
// however, note that the nodes that received hard evidence do not belong to
// the graph and, therefore, should not be taken into account
const auto& hard_ev_nodes = this->hardEvidenceNodes();
for ( const auto node : this->allTargets() ) {
for ( const auto node : this->targets() ) {
if ( !__graph.exists( node ) && !hard_ev_nodes.exists( node ) ) return true;
}
for ( const auto& nodes : this->jointTargets() ) {
......@@ -452,7 +452,7 @@ namespace gum {
// altering the inference output
if ( __barren_nodes_type == FindBarrenNodesType::FIND_BARREN_NODES ) {
// identify the barren nodes
NodeSet target_nodes = this->allTargets();
NodeSet target_nodes = this->targets();
for ( const auto& nodeset : this->jointTargets() ) {
target_nodes += nodeset;
}
......@@ -968,7 +968,7 @@ namespace gum {
void LazyPropagation<GUM_SCALAR>::__computeJoinTreeRoots() {
// get the set of cliques in which we can find the targets and joint_targets
NodeSet clique_targets;
for ( const auto node : this->allTargets() ) {
for ( const auto node : this->targets() ) {
try {
clique_targets.insert( __node_to_clique[node] );
} catch ( Exception& ) {
......@@ -1347,7 +1347,7 @@ namespace gum {
template <typename GUM_SCALAR>
INLINE void LazyPropagation<GUM_SCALAR>::_makeInference() {
// collect messages for all single targets
for ( const auto node : this->allTargets() ) {
for ( const auto node : this->targets() ) {
// perform only collects in the join tree for nodes that have
// not received hard evidence (those that received hard evidence were
// not included into the join tree for speed-up reasons)
......
......@@ -50,21 +50,19 @@ namespace gum {
* the current state of the inference. Note that the MarginalTargetedInference
* is designed to be used in incremental inference engines.
*/
template <typename GUM_SCALAR>
template<typename GUM_SCALAR>
class MarginalTargetedInference : public virtual BayesNetInference<GUM_SCALAR> {
public:
public:
// ############################################################################
/// @name Constructors / Destructors
// ############################################################################
/// @{
/// default constructor
/** @warning By default (when the targets set is empty), all the nodes of
* the Bayes net are considered as targets. Once a first target is added to the
* set, the remaining nodes are not considered as default targets anymore.
/** @warning By default, all the nodes of the Bayes net are targets.
* @warning note that, by aGrUM's rule, the BN is not copied but only
* referenced by the inference algorithm. */
MarginalTargetedInference( const IBayesNet<GUM_SCALAR>* bn );
MarginalTargetedInference(const IBayesNet<GUM_SCALAR> *bn);
/// destructor
virtual ~MarginalTargetedInference();
......@@ -92,7 +90,7 @@ namespace gum {
*
* @throw UndefinedElement if node is not in the set of targets
*/
virtual const Potential<GUM_SCALAR>& posterior( const NodeId node );
virtual const Potential<GUM_SCALAR> &posterior(const NodeId node);
/// Computes and returns the posterior of a node.
/**
......@@ -110,7 +108,7 @@ namespace gum {
*
* @throw UndefinedElement if node is not in the set of targets
*/
virtual const Potential<GUM_SCALAR>& posterior( const std::string& nodeName );
virtual const Potential<GUM_SCALAR> &posterior(const std::string &nodeName);
/// @}
......@@ -121,62 +119,44 @@ namespace gum {
/// @{
/// Clear all previously defined targets
/// @warning this means that every node is now a target by default
virtual void eraseAllTargets();
/// adds all nodes as targets
/// @warning due to the semantic of targets, this function is an alias of
/// eraseAllTargets()
virtual void addAllTargets() final;
/// Add a marginal target to the list of targets
/**
* @throw UndefinedElement if target is not a NodeId in the Bayes net
*/
virtual void addTarget( const NodeId target ) final;
virtual void addTarget(const NodeId target) final;
/// Add a marginal target to the list of targets
/**
* @throw UndefinedElement if target is not a NodeId in the Bayes net
*/
virtual void addTarget( const std::string& nodeName ) final;
virtual void addTarget(const std::string &nodeName) final;
/// removes an existing (marginal) target
/** @warning If the target does not already exist, the method does nothing.
* In particular, it does not raise any exception.
* @warning Erasing the last target implies that every node is now a target by
* default.
* */
virtual void eraseTarget( const NodeId target ) final;
* In particular, it does not raise any exception. */
virtual void eraseTarget(const NodeId target) final;
/// removes an existing (marginal) target
/** @warning If the target does not already exist, the method does nothing.
* In particular, it does not raise any exception.
* @warning Erasing the last target implies that every node is now a target by
* default.*/
virtual void eraseTarget( const std::string& nodeName ) final;
* In particular, it does not raise any exception. */
virtual void eraseTarget(const std::string &nodeName) final;
/// return true if variable is a (marginal) target
virtual bool isTarget( const NodeId variable ) const final;
virtual bool isTarget(const NodeId variable) const final;
/// return true if variable is a (marginal) target
virtual bool isTarget( const std::string& nodeName ) const final;
virtual bool isTarget(const std::string &nodeName) const final;
/// returns the number of marginal targets.
//// @warning if the result is 0, it means that all the nodes are targets by
/// default.
/// returns the number of marginal targets
virtual const Size nbrTargets() const noexcept final;
/// returns the set of marginal targets
//// @warning if the set is empty, it means that all the nodes are targets by
/// default.
virtual const NodeSet& targets() const noexcept final;
/// return all the marginal targets.
/** Particularly, if the targetSet is empty, allTargets will send a copy of the
* nodeSet of the BN
*/
virtual NodeSet allTargets() const noexcept final;
/// returns the list of marginal targets
virtual const NodeSet &targets() const noexcept final;
/// @}
......@@ -189,13 +169,13 @@ namespace gum {
* Compute Shanon's entropy of a node given the observation
* @see http://en.wikipedia.org/wiki/Information_entropy
*/
virtual GUM_SCALAR H( const NodeId X ) final;
virtual GUM_SCALAR H(const NodeId X) final;
/** Entropy
* Compute Shanon's entropy of a node given the observation
* @see http://en.wikipedia.org/wiki/Information_entropy
*/
virtual GUM_SCALAR H( const std::string& nodeName ) final;
virtual GUM_SCALAR H(const std::string &nodeName) final;
///@}
......@@ -211,8 +191,8 @@ namespace gum {
* @param evs the vector of nodeId of the observed variables
* @return a Potential
*/
Potential<GUM_SCALAR> evidenceImpact( NodeId target,
const std::vector<NodeId>& evs );
Potential<GUM_SCALAR> evidenceImpact(NodeId target,
const std::vector<NodeId>& evs);
/**
* Create a gum::Potential for P(target|evs) (for all instanciation of target
......@@ -224,17 +204,17 @@ namespace gum {
* @param evs the nodeId of the observed variable
* @return a Potential
*/
Potential<GUM_SCALAR> evidenceImpact( const std::string& target,
const std::vector<std::string>& evs );
Potential<GUM_SCALAR> evidenceImpact(const std::string& target,
const std::vector<std::string>& evs);
protected:
protected:
/// fired after a new marginal target is inserted
/** @param id The target variable's id. */
virtual void _onMarginalTargetAdded( const NodeId id ) = 0;
virtual void _onMarginalTargetAdded(const NodeId id) = 0;
/// fired before a marginal target is removed
/** @param id The target variable's id. */
virtual void _onMarginalTargetErased( const NodeId id ) = 0;
virtual void _onMarginalTargetErased(const NodeId id) = 0;
/// fired after all the nodes of the BN are added as marginal targets
virtual void _onAllMarginalTargetsAdded() = 0;
......@@ -243,16 +223,19 @@ namespace gum {
virtual void _onAllMarginalTargetsErased() = 0;
/// fired after a new Bayes net has been assigned to the engine
virtual void _onBayesNetChanged( const IBayesNet<GUM_SCALAR>* bn );
virtual void _onBayesNetChanged(const IBayesNet<GUM_SCALAR> *bn);
/// asks derived classes for the posterior of a given variable
/** @param id The variable's id. */
virtual const Potential<GUM_SCALAR>& _posterior( const NodeId id ) = 0;
virtual const Potential<GUM_SCALAR> &_posterior(const NodeId id) = 0;
private:
/// whether the actual targets are default
bool __defaultTargets;
private:
/// the set of marginal targets
NodeSet __targetsSet;
NodeSet __targets;
/// remove all the marginal posteriors computed
......
......@@ -44,23 +44,25 @@ namespace gum_tests {
gum::BayesNet<double>::fastPrototype( "A->B->C->D;A->E->D;F->B;C->H;" );
gum::LazyPropagation<double> lazy( &bn );
TS_ASSERT(lazy.targets() == gum::NodeSet( {} ) );
TS_ASSERT( lazy.targets() == gum::NodeSet( {0, 1, 2, 3, 4, 5, 6} ) );
lazy.addTarget( "A" );
TS_ASSERT(lazy.targets() == gum::NodeSet( {0} ) );
TS_ASSERT( lazy.targets() == gum::NodeSet( {0} ) );
lazy.addTarget( "B" );
TS_ASSERT(lazy.targets() == gum::NodeSet( {0, 1} ) );
TS_ASSERT( lazy.targets() == gum::NodeSet( {0, 1} ) );
gum::ShaferShenoyInference<double> shafer( &bn );
TS_ASSERT( shafer.targets() == gum::NodeSet( {0, 1, 2, 3, 4, 5, 6} ) );
shafer.addTarget( "A" );
TS_ASSERT(shafer.targets() == gum::NodeSet( {0} ) );
TS_ASSERT( shafer.targets() == gum::NodeSet( {0} ) );
shafer.addTarget( "B" );
TS_ASSERT(shafer.targets() == gum::NodeSet( {0, 1} ) );
TS_ASSERT( shafer.targets() == gum::NodeSet( {0, 1} ) );
gum::VariableElimination<double> ve( &bn );
TS_ASSERT( ve.targets() == gum::NodeSet( {0, 1, 2, 3, 4, 5, 6} ) );
ve.addTarget( "A" );
TS_ASSERT(ve.targets() == gum::NodeSet( {0} ) );
TS_ASSERT( ve.targets() == gum::NodeSet( {0} ) );
ve.addTarget( "B" );
TS_ASSERT(ve.targets() == gum::NodeSet( {0, 1} ) );
TS_ASSERT( ve.targets() == gum::NodeSet( {0, 1} ) );
}
};
}
This diff is collapsed.
......@@ -130,7 +130,7 @@ class CSVGenerator:
bn = name_in
seq = bn.topologicalOrder()
writer = csv.writer(open(name_out, 'w'))
writer = csv.writer(open(name_out, 'w',newline=''))
if visible:
sys.stdout.flush()
......
#!/usr/bin/python
# -*- coding: utf-8 -*-
# (c) Copyright by Pierre-Henri Wuillemin, UPMC, 2017 (pierre-henri.wuillemin@lip6.fr)
# (c) Copyright by Pierre-Henri Wuillemin, UPMC, 2017
# (pierre-henri.wuillemin@lip6.fr)
# Permission to use, copy, modify, and distribute this
# software and its documentation for any purpose and
......@@ -88,7 +89,8 @@ def BN2dot(bn, size="4", arcvals=None, vals=None, cmap=None, showValues=None):
minarcs = min(arcvals.values())
maxarcs = max(arcvals.values())
graph = dot.Dot(graph_type='digraph')
graph = dot.Dot(graph_type='digraph',bgcolor="transparent")
for n in bn.names():
if vals is None or n not in vals:
bgcol = "#444444"
......@@ -232,24 +234,24 @@ def BNinference2dot(bn, size="4",engine=None, evs={}, targets={}, format='png',
from tempfile import mkdtemp
temp_dir = mkdtemp("", "tmp", None) # with TemporaryDirectory() as temp_dir:
dotstr = "digraph structs {\n"
dotstr = "digraph structs {\n bgcolor=\"transparent\";"
dotstr += " label=\"Inference in {:6.2f}ms\";\n".format(1000 * (stopTime - startTime))
dotstr += " node [fillcolor=floralwhite, style=filled,color=grey];\n"
for n in bn.ids():
name = bn.variable(n).name()
for nid in bn.ids():
name = bn.variable(nid).name()
if vals is None or name not in vals:
bgcol = "sandybrown" if name in evs else "#FFFFFF"
if vals is None or name not in vals or nid not in vals:
bgcol = "sandybrown" if name in evs or nid in evs else "#FFFFFF"
fgcol = "#000000"
else:
bgcol = _proba2bgcolor(vals[name], cmap)
fgcol = _proba2fgcolor(vals[name], cmap)
colorattribute = 'fillcolor="{}", fontcolor="{}", color="#000000"'.format(bgcol, fgcol)
if len(targets) == 0 or name in targets:
if len(targets) == 0 or name in targets or nid in targets:
filename = temp_dir + name + "." + format
_saveFigProba(ie.posterior(n), filename, format=format)
_saveFigProba(ie.posterior(name), filename, format=format)
fill = ", " + colorattribute
dotstr += ' "{0}" [shape=rectangle,image="{1}",label="" {2}];\n'.format(name, filename, fill)
else:
......
......@@ -144,7 +144,9 @@ def getDot(dotstring, size="4", format="png"):
:param format: render as "png" or "svg"
:return: the HTML representation of the graph
"""
return getGraph(dot.graph_from_dot_data(dotstring), size, format)
g=dot.graph_from_dot_data(dotstring)
g.set_bgcolor("transparent")
return getGraph(g, size, format)
def showJunctionTree(bn, withNames=True, size="4", format="png"):
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment