[aGrUM] optimization + new comment in BIF file when saving BN

parent e2ebdad4
......@@ -31,15 +31,15 @@ namespace gum {
/* === GUM_BIF_WRITER === */
/* =========================================================================*/
// Default constructor.
template <typename GUM_SCALAR>
INLINE BIFWriter<GUM_SCALAR>::BIFWriter() {
GUM_CONSTRUCTOR( BIFWriter );
template < typename GUM_SCALAR >
INLINE BIFWriter< GUM_SCALAR >::BIFWriter() {
GUM_CONSTRUCTOR(BIFWriter);
}
// Default destructor.
template <typename GUM_SCALAR>
INLINE BIFWriter<GUM_SCALAR>::~BIFWriter() {
GUM_DESTRUCTOR( BIFWriter );
template < typename GUM_SCALAR >
INLINE BIFWriter< GUM_SCALAR >::~BIFWriter() {
GUM_DESTRUCTOR(BIFWriter);
}
//
......@@ -48,30 +48,30 @@ namespace gum {
// @param ouput The output stream.
// @param bn The Bayesian Network writen in output.
// @throws Raised if an I/O error occurs.
template <typename GUM_SCALAR>
INLINE void BIFWriter<GUM_SCALAR>::write( std::ostream& output,
const IBayesNet<GUM_SCALAR>& bn ) {
if ( !output.good() ) {
GUM_ERROR( IOError, "Stream states flags are not all unset." );
template < typename GUM_SCALAR >
INLINE void BIFWriter< GUM_SCALAR >::write(std::ostream& output,
const IBayesNet< GUM_SCALAR >& bn) {
if (!output.good()) {
GUM_ERROR(IOError, "Stream states flags are not all unset.");
}
output << __header( bn ) << std::endl;
output << __header(bn) << std::endl;
for ( auto node : bn.nodes() ) {
output << __variableBloc( bn.variable( node ) ) << std::endl;
for (auto node : bn.nodes()) {
output << __variableBloc(bn.variable(node)) << std::endl;
}
for ( auto node : bn.nodes() ) {
const Potential<GUM_SCALAR>& proba = bn.cpt( node );
output << __variableCPT( proba );
for (auto node : bn.nodes()) {
const Potential< GUM_SCALAR >& proba = bn.cpt(node);
output << __variableCPT(proba);
}
output << std::endl;
output.flush();
if ( output.fail() ) {
GUM_ERROR( IOError, "Writting in the ostream failed." );
if (output.fail()) {
GUM_ERROR(IOError, "Writting in the ostream failed.");
}
}
......@@ -82,24 +82,24 @@ namespace gum {
// @param filePath The path to the file used to write the Bayesian Network.
// @param bn The Bayesian Network writed in the file.
// @throws Raised if an I/O error occurs.
template <typename GUM_SCALAR>
INLINE void BIFWriter<GUM_SCALAR>::write( std::string filePath,
const IBayesNet<GUM_SCALAR>& bn ) {
std::ofstream output( filePath.c_str(), std::ios_base::trunc );
template < typename GUM_SCALAR >
INLINE void BIFWriter< GUM_SCALAR >::write(std::string filePath,
const IBayesNet< GUM_SCALAR >& bn) {
std::ofstream output(filePath.c_str(), std::ios_base::trunc);
if ( !output.good() ) {
GUM_ERROR( IOError, "Stream states flags are not all unset." );
if (!output.good()) {
GUM_ERROR(IOError, "Stream states flags are not all unset.");
}
output << __header( bn ) << std::endl;
output << __header(bn) << std::endl;
for ( auto node : bn.nodes() ) {
output << __variableBloc( bn.variable( node ) ) << std::endl;
for (auto node : bn.nodes()) {
output << __variableBloc(bn.variable(node)) << std::endl;
}
for ( auto node : bn.nodes() ) {
const Potential<GUM_SCALAR>& proba = bn.cpt( node );
output << __variableCPT( proba );
for (auto node : bn.nodes()) {
const Potential< GUM_SCALAR >& proba = bn.cpt(node);
output << __variableCPT(proba);
}
output << std::endl;
......@@ -107,49 +107,48 @@ namespace gum {
output.flush();
output.close();
if ( output.fail() ) {
GUM_ERROR( IOError, "Writting in the ostream failed." );
if (output.fail()) {
GUM_ERROR(IOError, "Writting in the ostream failed.");
}
}
// Returns a bloc defining a variable's CPT in the BIF format.
template <typename GUM_SCALAR>
template < typename GUM_SCALAR >
INLINE std::string
BIFWriter<GUM_SCALAR>::__variableCPT( const Potential<GUM_SCALAR>& cpt ) {
BIFWriter< GUM_SCALAR >::__variableCPT(const Potential< GUM_SCALAR >& cpt) {
std::stringstream str;
std::string tab = " "; // poor tabulation
if ( cpt.nbrDim() == 1 ) {
Instantiation inst( cpt );
str << "probability (" << cpt.variable( 0 ).name() << ") {" << std::endl;
if (cpt.nbrDim() == 1) {
Instantiation inst(cpt);
str << "probability (" << cpt.variable(0).name() << ") {" << std::endl;
str << tab << "default";
for ( inst.setFirst(); !inst.end(); ++inst ) {
for (inst.setFirst(); !inst.end(); ++inst) {
str << " " << cpt[inst];
}
str << ";" << std::endl << "}" << std::endl;
} else if ( cpt.domainSize() > 1 ) {
Instantiation inst( cpt );
} else if (cpt.domainSize() > 1) {
Instantiation inst(cpt);
Instantiation condVars; // Instantiation on the conditioning variables
const Sequence<const DiscreteVariable*>& varsSeq = cpt.variablesSequence();
str << "probability (" << ( varsSeq[(Idx)0] )->name() << " | ";
const Sequence< const DiscreteVariable* >& varsSeq = cpt.variablesSequence();
str << "probability (" << (varsSeq[(Idx)0])->name() << " | ";
for ( Idx i = 1; i < varsSeq.size() - 1; i++ ) {
for (Idx i = 1; i < varsSeq.size() - 1; i++) {
str << varsSeq[i]->name() << ", ";
condVars << *( varsSeq[i] );
condVars << *(varsSeq[i]);
}
str << varsSeq[varsSeq.size() - 1]->name() << ") {" << std::endl;
condVars << *( varsSeq[varsSeq.size() - 1] );
condVars << *(varsSeq[varsSeq.size() - 1]);
for ( inst.setFirstIn( condVars ); !inst.end(); inst.incIn( condVars ) ) {
str << tab << "(" << __variablesLabels( varsSeq, inst ) << ")";
for (inst.setFirstIn(condVars); !inst.end(); inst.incIn(condVars)) {
str << tab << "(" << __variablesLabels(varsSeq, inst) << ")";
// Writing the probabilities of the variable
for ( inst.setFirstOut( condVars ); !inst.end();
inst.incOut( condVars ) ) {
for (inst.setFirstOut(condVars); !inst.end(); inst.incOut(condVars)) {
str << " " << cpt[inst];
}
......@@ -165,56 +164,53 @@ namespace gum {
}
// Returns the header of the BIF file.
template <typename GUM_SCALAR>
template < typename GUM_SCALAR >
INLINE std::string
BIFWriter<GUM_SCALAR>::__header( const IBayesNet<GUM_SCALAR>& bn ) {
BIFWriter< GUM_SCALAR >::__header(const IBayesNet< GUM_SCALAR >& bn) {
std::stringstream str;
std::string tab = " "; // poor tabulation
str << std::endl
<< "network \"" << bn.propertyWithDefault( "name", "unnamedBN" ) << "\" {"
str << "network \"" << bn.propertyWithDefault("name", "unnamedBN") << "\" {"
<< std::endl;
str << tab << "property"
<< " software aGrUM"
<< ";" << std::endl;
str << "// written by aGrUM " << GUM_VERSION << std::endl;
str << "}" << std::endl;
return str.str();
}
// Returns a bloc defining a variable in the BIF format.
template <typename GUM_SCALAR>
template < typename GUM_SCALAR >
INLINE std::string
BIFWriter<GUM_SCALAR>::__variableBloc( const DiscreteVariable& var ) {
BIFWriter< GUM_SCALAR >::__variableBloc(const DiscreteVariable& var) {
std::stringstream str;
std::string tab = " "; // poor tabulation
str << "variable " << var.name() << " {" << std::endl;
str << tab << "type discrete[" << var.domainSize() << "] {";
for ( Idx i = 0; i < var.domainSize() - 1; i++ ) {
str << var.label( i ) << ", ";
for (Idx i = 0; i < var.domainSize() - 1; i++) {
str << var.label(i) << ", ";
}
str << var.label( var.domainSize() - 1 ) << "};" << std::endl;
str << var.label(var.domainSize() - 1) << "};" << std::endl;
str << "}" << std::endl;
return str.str();
}
// Returns the modalities labels of the variables in varsSeq
template <typename GUM_SCALAR>
INLINE std::string BIFWriter<GUM_SCALAR>::__variablesLabels(
const Sequence<const DiscreteVariable*>& varsSeq,
const Instantiation& inst ) {
template < typename GUM_SCALAR >
INLINE std::string BIFWriter< GUM_SCALAR >::__variablesLabels(
const Sequence< const DiscreteVariable* >& varsSeq,
const Instantiation& inst) {
std::stringstream str;
const DiscreteVariable* varPtr = nullptr;
for ( Idx i = 1; i < varsSeq.size() - 1; i++ ) {
for (Idx i = 1; i < varsSeq.size() - 1; i++) {
varPtr = varsSeq[i];
str << varPtr->label( inst.val( *varPtr ) ) << ", ";
str << varPtr->label(inst.val(*varPtr)) << ", ";
}
varPtr = varsSeq[varsSeq.size() - 1];
str << varPtr->label( inst.val( *varPtr ) );
str << varPtr->label(inst.val(*varPtr));
return str.str();
}
......
......@@ -62,7 +62,7 @@ namespace std {
bool hasUniqueElts(std::vector< T > const& x) {
if (x.size() <= 1) return true;
if (x.size() == 2) return x[0] != x[1];
auto refless = [](T const* l, T const* r) { return *l < *r; };
auto refeq = [](T const* l, T const* r) { return *l == *r; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment