[pyAgrum] add MarkovBlanket and tests

parent 3a386589
......@@ -32,7 +32,7 @@
namespace gum {
MarkovBlanket::MarkovBlanket( const DAGmodel& m, NodeId id )
: __dag( m )
: __model( m )
, __node( id ) {
__buildMarkovBlanket();
}
......@@ -43,24 +43,37 @@ namespace gum {
MarkovBlanket::~MarkovBlanket() {}
void MarkovBlanket::__buildMarkovBlanket() {
if ( !__dag.nodes().exists( __node ) )
if ( !__model.nodes().exists( __node ) )
GUM_ERROR( InvalidArgument, "Node " << __node << " does not exist." );
__mb.addNode( __node );
for ( const auto& parent : __dag.dag().parents( __node ) ) {
for ( const auto& parent : __model.dag().parents( __node ) ) {
__mb.addNode( parent );
__mb.addArc( parent, __node );
}
for ( const auto& child : __dag.dag().children( __node ) ) {
for ( const auto& child : __model.dag().children( __node ) ) {
__mb.addNode( child );
__mb.addArc( __node, child );
for ( const auto& opar : __dag.dag().parents( child ) ) {
for ( const auto& opar : __model.dag().parents( child ) ) {
if ( opar != __node ) {
if ( !__mb.nodes().exists( opar ) ) __mb.addNode( opar );
__mb.addArc( opar, child );
}
}
}
// we add now some arcs that are between the nodes in __mb but are not part of
// the last ones.
// For instance, an arc between a parent and a parent of children
for ( const auto node : __mb.nodes() ) {
for ( const auto child : __model.dag().children( node ) ) {
if ( __mb.existsNode( child ) && !__mb.existsArc( Arc( node, child ) ) ) {
__mb.addArc( node, child );
__specialArcs.insert( Arc( node, child ) );
}
}
}
}
bool MarkovBlanket::hasSameStructure( const DAGmodel& other ) {
......@@ -70,7 +83,7 @@ namespace gum {
for ( const auto& nid : nodes() ) {
try {
other.idFromName( __dag.variable( nid ).name() );
other.idFromName( __model.variable( nid ).name() );
} catch ( NotFound ) {
return false;
}
......@@ -78,8 +91,8 @@ namespace gum {
for ( const auto& arc : arcs() ) {
if ( !other.arcs().exists(
Arc( other.idFromName( __dag.variable( arc.tail() ).name() ),
other.idFromName( __dag.variable( arc.head() ).name() ) ) ) )
Arc( other.idFromName( __model.variable( arc.tail() ).name() ),
other.idFromName( __model.variable( arc.head() ).name() ) ) ) )
return false;
}
......@@ -89,7 +102,7 @@ namespace gum {
std::string MarkovBlanket::toDot( void ) const {
std::stringstream output;
std::stringstream nodeStream;
std::stringstream edgeStream;
std::stringstream arcStream;
List<NodeId> treatedNodes;
output << "digraph \""
<< "no_name\" {" << std::endl;
......@@ -98,19 +111,24 @@ namespace gum {
for ( const auto node : __mb.nodes() ) {
nodeStream << tab << node << "[label=\"" << __dag.variable( node ).name()
nodeStream << tab << node << "[label=\"" << __model.variable( node ).name()
<< "\"";
if ( node == __node ) {
nodeStream << ", color=red";
}
nodeStream << "];";
nodeStream << "];" << std::endl;
for ( const auto chi : __mb.children( node ) )
edgeStream << tab << node << " -> " << chi << ";" << std::endl;
for ( const auto chi : __mb.children( node ) ) {
arcStream << tab << node << " -> " << chi;
if ( __specialArcs.exists( Arc( node, chi ) ) ) {
arcStream << " [color=grey]";
}
arcStream << ";" << std::endl;
}
}
output << nodeStream.str() << std::endl
<< edgeStream.str() << std::endl
<< arcStream.str() << std::endl
<< "}" << std::endl;
return output.str();
......
......@@ -53,7 +53,10 @@ namespace gum {
/// @return a copy of the mixed graph
DiGraph mb();
/// @return a dot representation of this essentialGraph
// @return a dot representation of this MarkovBlanket
// node of interest is in red
// special arcs (not used during the construction of the Markov Blanket) are in
// grey
std::string toDot( void ) const;
/// wrapping @ref DiGraph::parents(id)
......@@ -62,9 +65,6 @@ namespace gum {
/// wrapping @ref DiGraph::parents(id)
const NodeSet& children( const NodeId id ) const;
/// wrapping @ref DiGraph::parents(id)
const NodeSet& neighbours( const NodeId id ) const;
/// wrapping @ref DiGraph::sizeArcs()
Size sizeArcs() const;
......@@ -73,6 +73,7 @@ namespace gum {
/// wrapping @ref DiGraph::sizeNodes()
Size sizeNodes() const;
/// wrapping @ref DiGraph::size()
Size size() const;
......@@ -86,9 +87,10 @@ namespace gum {
private:
void __buildMarkovBlanket();
const DAGmodel& __dag;
const DAGmodel& __model;
DiGraph __mb;
const NodeId __node;
ArcSet __specialArcs;
};
} // namespace gum
......
......@@ -97,6 +97,14 @@ namespace gum_tests {
TS_ASSERT( gum::MarkovBlanket( bn, "i" ).hasSameStructure(
gum::BayesNet<int>::fastPrototype( "d->g;h->i->g;;" ) ) );
}
void testMarkovBlanketSpecialArcs() {
const auto bn = gum::BayesNet<int>::fastPrototype(
"aa->bb->cc->dd->ee;ff->dd->gg;hh->ii->gg;ff->ii;ff->gg" );
const auto mb = gum::BayesNet<int>::fastPrototype(
"cc->dd->ee;ff->dd->gg;ff->gg;ff->ii->gg" );
TS_ASSERT( gum::MarkovBlanket( bn, "dd" ).hasSameStructure( mb ) );
}
};
} // gum_tests
......@@ -38,7 +38,7 @@ from .functions import *
# selection of imports extracted from dir(.pyAgrum)
from .pyAgrum import statsObj
from .pyAgrum import Arc, Edge, DiGraph, UndiGraph, MixedGraph, DAG, CliqueGraph
from .pyAgrum import BayesNet,EssentialGraph
from .pyAgrum import BayesNet, EssentialGraph, MarkovBlanket
from .pyAgrum import DiscretizedVariable, LabelizedVariable, RangeVariable, DiscreteVariable
from .pyAgrum import Potential, Instantiation, UtilityTable
from .pyAgrum import BruteForceKL, GibbsKL
......
......@@ -7915,6 +7915,18 @@ class DAGmodel(_object):
"""
return _pyAgrum.DAGmodel_log10DomainSize(self)
def hasSameStructure(self, other):
"""
hasSameStructure(DAGmodel self, DAGmodel other) -> bool
Parameters
----------
other: gum::DAGmodel const &
"""
return _pyAgrum.DAGmodel_hasSameStructure(self, other)
DAGmodel_swigregister = _pyAgrum.DAGmodel_swigregister
DAGmodel_swigregister(DAGmodel)
cvar = _pyAgrum.cvar
......@@ -8081,6 +8093,150 @@ class EssentialGraph(_object):
EssentialGraph_swigregister = _pyAgrum.EssentialGraph_swigregister
EssentialGraph_swigregister(EssentialGraph)
class MarkovBlanket(_object):
"""Proxy of C++ gum::MarkovBlanket class."""
__swig_setmethods__ = {}
__setattr__ = lambda self, name, value: _swig_setattr(self, MarkovBlanket, name, value)
__swig_getmethods__ = {}
__getattr__ = lambda self, name: _swig_getattr(self, MarkovBlanket, name)
__repr__ = _swig_repr
def __init__(self, *args):
"""
__init__(gum::MarkovBlanket self, DAGmodel m, gum::NodeId n) -> MarkovBlanket
Parameters
----------
m: gum::DAGmodel const &
n: gum::NodeId
__init__(gum::MarkovBlanket self, DAGmodel m, std::string const & name) -> MarkovBlanket
Parameters
----------
m: gum::DAGmodel const &
name: std::string const &
"""
this = _pyAgrum.new_MarkovBlanket(*args)
try:
self.this.append(this)
except __builtin__.Exception:
self.this = this
__swig_destroy__ = _pyAgrum.delete_MarkovBlanket
__del__ = lambda self: None
def mb(self):
"""
mb(MarkovBlanket self) -> DiGraph
Parameters
----------
self: gum::MarkovBlanket *
"""
return _pyAgrum.MarkovBlanket_mb(self)
def toDot(self):
"""
toDot(MarkovBlanket self) -> std::string
Parameters
----------
self: gum::MarkovBlanket const *
"""
return _pyAgrum.MarkovBlanket_toDot(self)
def parents(self, id):
"""
parents(MarkovBlanket self, gum::NodeId const id) -> gum::NodeSet const &
Parameters
----------
id: gum::NodeId const
"""
return _pyAgrum.MarkovBlanket_parents(self, id)
def children(self, id):
"""
children(MarkovBlanket self, gum::NodeId const id) -> gum::NodeSet const &
Parameters
----------
id: gum::NodeId const
"""
return _pyAgrum.MarkovBlanket_children(self, id)
def sizeArcs(self):
"""
sizeArcs(MarkovBlanket self) -> gum::Size
Parameters
----------
self: gum::MarkovBlanket const *
"""
return _pyAgrum.MarkovBlanket_sizeArcs(self)
def arcs(self):
"""
arcs(MarkovBlanket self) -> gum::ArcSet const &
Parameters
----------
self: gum::MarkovBlanket const *
"""
return _pyAgrum.MarkovBlanket_arcs(self)
def sizeNodes(self):
"""
sizeNodes(MarkovBlanket self) -> gum::Size
Parameters
----------
self: gum::MarkovBlanket const *
"""
return _pyAgrum.MarkovBlanket_sizeNodes(self)
def size(self):
"""
size(MarkovBlanket self) -> gum::Size
Parameters
----------
self: gum::MarkovBlanket const *
"""
return _pyAgrum.MarkovBlanket_size(self)
def hasSameStructure(self, other):
"""
hasSameStructure(MarkovBlanket self, DAGmodel other) -> bool
Parameters
----------
other: gum::DAGmodel const &
"""
return _pyAgrum.MarkovBlanket_hasSameStructure(self, other)
MarkovBlanket_swigregister = _pyAgrum.MarkovBlanket_swigregister
MarkovBlanket_swigregister(MarkovBlanket)
class ApproximationScheme(_object):
"""Proxy of C++ gum::ApproximationScheme class."""
......
......@@ -7915,6 +7915,18 @@ class DAGmodel(_object):
"""
return _pyAgrum.DAGmodel_log10DomainSize(self)
def hasSameStructure(self, other: 'DAGmodel') -> "bool":
"""
hasSameStructure(DAGmodel self, DAGmodel other) -> bool
Parameters
----------
other: gum::DAGmodel const &
"""
return _pyAgrum.DAGmodel_hasSameStructure(self, other)
DAGmodel_swigregister = _pyAgrum.DAGmodel_swigregister
DAGmodel_swigregister(DAGmodel)
cvar = _pyAgrum.cvar
......@@ -8081,6 +8093,150 @@ class EssentialGraph(_object):
EssentialGraph_swigregister = _pyAgrum.EssentialGraph_swigregister
EssentialGraph_swigregister(EssentialGraph)
class MarkovBlanket(_object):
"""Proxy of C++ gum::MarkovBlanket class."""
__swig_setmethods__ = {}
__setattr__ = lambda self, name, value: _swig_setattr(self, MarkovBlanket, name, value)
__swig_getmethods__ = {}
__getattr__ = lambda self, name: _swig_getattr(self, MarkovBlanket, name)
__repr__ = _swig_repr
def __init__(self, *args):
"""
__init__(gum::MarkovBlanket self, DAGmodel m, gum::NodeId n) -> MarkovBlanket
Parameters
----------
m: gum::DAGmodel const &
n: gum::NodeId
__init__(gum::MarkovBlanket self, DAGmodel m, std::string const & name) -> MarkovBlanket
Parameters
----------
m: gum::DAGmodel const &
name: std::string const &
"""
this = _pyAgrum.new_MarkovBlanket(*args)
try:
self.this.append(this)
except __builtin__.Exception:
self.this = this
__swig_destroy__ = _pyAgrum.delete_MarkovBlanket
__del__ = lambda self: None
def mb(self) -> "gum::DiGraph":
"""
mb(MarkovBlanket self) -> DiGraph
Parameters
----------
self: gum::MarkovBlanket *
"""
return _pyAgrum.MarkovBlanket_mb(self)
def toDot(self) -> "std::string":
"""
toDot(MarkovBlanket self) -> std::string
Parameters
----------
self: gum::MarkovBlanket const *
"""
return _pyAgrum.MarkovBlanket_toDot(self)
def parents(self, id: 'gum::NodeId const') -> "gum::NodeSet const &":
"""
parents(MarkovBlanket self, gum::NodeId const id) -> gum::NodeSet const &
Parameters
----------
id: gum::NodeId const
"""
return _pyAgrum.MarkovBlanket_parents(self, id)
def children(self, id: 'gum::NodeId const') -> "gum::NodeSet const &":
"""
children(MarkovBlanket self, gum::NodeId const id) -> gum::NodeSet const &
Parameters
----------
id: gum::NodeId const
"""
return _pyAgrum.MarkovBlanket_children(self, id)
def sizeArcs(self) -> "gum::Size":
"""
sizeArcs(MarkovBlanket self) -> gum::Size
Parameters
----------
self: gum::MarkovBlanket const *
"""
return _pyAgrum.MarkovBlanket_sizeArcs(self)
def arcs(self) -> "gum::ArcSet const &":
"""
arcs(MarkovBlanket self) -> gum::ArcSet const &
Parameters
----------
self: gum::MarkovBlanket const *
"""
return _pyAgrum.MarkovBlanket_arcs(self)
def sizeNodes(self) -> "gum::Size":
"""
sizeNodes(MarkovBlanket self) -> gum::Size
Parameters
----------
self: gum::MarkovBlanket const *
"""
return _pyAgrum.MarkovBlanket_sizeNodes(self)
def size(self) -> "gum::Size":
"""
size(MarkovBlanket self) -> gum::Size
Parameters
----------
self: gum::MarkovBlanket const *
"""
return _pyAgrum.MarkovBlanket_size(self)
def hasSameStructure(self, other: 'DAGmodel') -> "bool":
"""
hasSameStructure(MarkovBlanket self, DAGmodel other) -> bool
Parameters
----------
other: gum::DAGmodel const &
"""
return _pyAgrum.MarkovBlanket_hasSameStructure(self, other)
MarkovBlanket_swigregister = _pyAgrum.MarkovBlanket_swigregister
MarkovBlanket_swigregister(MarkovBlanket)
class ApproximationScheme(_object):
"""Proxy of C++ gum::ApproximationScheme class."""
......
......@@ -673,3 +673,4 @@ gum.MixedGraph._repr_html_ = lambda self: getDot(self.toDot())
gum.DAG._repr_html_ = lambda self: getDot(self.toDot())
gum.CliqueGraph._repr_html_ = lambda self: getDot(self.toDot())
gum.EssentialGraph._repr_html_ = lambda self: getDot(self.toDot())
gum.MarkovBlanket._repr_html_ = lambda self: getDot(self.toDot())
......@@ -17,6 +17,7 @@ import pyAgrum as gum
import unittest
from tests import MarkovBlanket
from tests import EssentialGraphTestSuite
from tests import VariablesTestSuite
from tests import BayesNetTestSuite
......@@ -36,6 +37,7 @@ from tests import LoopyBeliefPropagationTestSuite
import time
tests = list()
tests.append(MarkovBlanket.ts)
tests.append(EssentialGraphTestSuite.ts)
tests.append(VariablesTestSuite.ts)
tests.append(BayesNetTestSuite.ts)
......
# -*- encoding: UTF-8 -*-
import unittest
import pyAgrum as gum
from pyAgrumTestSuite import pyAgrumTestCase, addTests
class TestMarkovBlanket(pyAgrumTestCase):
def testChain(self):
bn = gum.fastBN("a->b->c")
eg = gum.MarkovBlanket(bn, "a")
eg = gum.MarkovBlanket(bn, 1)
def testMarkovBlanketSpecialArcs(self):
bn = gum.fastBN("aa->bb->cc->dd->ee;ff->dd->gg;hh->ii->gg;ff->ii;ff->gg")
mb = gum.fastBN("cc->dd->ee;ff->dd->gg;ff->gg;ff->ii->gg")
self.assertTrue(gum.MarkovBlanket(bn, "dd").hasSameStructure(mb))
def testMarkovBlanketStructure(self):
bn = gum.fastBN("a->b->c->d->e;f->d->g;h->i->g");
self.assertFalse(gum.MarkovBlanket(bn, "a").hasSameStructure(
gum.fastBN("b->a")))
self.assertTrue(gum.MarkovBlanket(bn, "a").hasSameStructure(
gum.fastBN("a->b")))
self.assertTrue(gum.MarkovBlanket(bn, "b").hasSameStructure(
gum.fastBN("a->b->c")))
self.assertTrue(gum.MarkovBlanket(bn, "c").hasSameStructure(
gum.fastBN("b->c->d;f->d")))
self.assertTrue(gum.MarkovBlanket(bn, "d").hasSameStructure(
gum.fastBN("c->d->e;f->d->g;i->g")))
self.assertTrue(gum.MarkovBlanket(bn, "e").hasSameStructure(
gum.fastBN("d->e")))
self.assertTrue(gum.MarkovBlanket(bn, "f").hasSameStructure(
gum.fastBN("c->d;f->d;")))
self.assertTrue(gum.MarkovBlanket(bn, "g").hasSameStructure(
gum.fastBN("d->g;i->g;")))
self.assertTrue(gum.MarkovBlanket(bn, "h").hasSameStructure(
gum.fastBN("h->i;")))
self.assertTrue(gum.MarkovBlanket(bn, "i").hasSameStructure(
gum.fastBN("d->g;h->i->g;;")))
ts = unittest.TestSuite()
addTests(ts, TestMarkovBlanket)
......@@ -62,6 +62,7 @@
#include <agrum/BN/BayesNet.h>
#include <agrum/BN/algorithms/essentialGraph.h>
#include <agrum/BN/algorithms/MarkovBlanket.h>
#include <agrum/BN/io/BIF/BIFReader.h>
#include <agrum/BN/io/BIF/BIFWriter.h>
......@@ -331,6 +332,7 @@ ADD_APPROXIMATIONSCHEME_API(gum::learning::genericBNLearner,gum::learning::BNLea
%include <agrum/BN/BayesNet.h>
%include <agrum/BN/algorithms/essentialGraph.h>
%include <agrum/BN/algorithms/MarkovBlanket.h>
%import <agrum/core/approximations/IApproximationSchemeConfiguration.h>
%include <agrum/core/approximations/approximationScheme.h>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment