Commit 8df58161 by Christophe Gonzales

adding missing values support in BNLearner

parent 43563102
Pipeline #21673148 passed with stages
in 77 minutes 25 seconds
......@@ -459,6 +459,14 @@ namespace gum {
*/
class UnknownLabelInDatabase;
/**
* @class gum::MissingValueInDatabase exceptions.h <agrum/core/exceptions.h>
* @extends gum::LearningError
* Error: The database contains some missing values
*/
class MissingValueInDatabase;
/**
* @class gum::SyntaxError exceptions.h <agrum/core/exceptions.h>
* @extends gum::IOError
......@@ -531,6 +539,9 @@ namespace gum {
GUM_MAKE_ERROR(MissingVariableInDatabase,
LearningError,
"Missing variable name in database")
GUM_MAKE_ERROR(MissingValueInDatabase,
LearningError,
"The database contains some missing values")
GUM_MAKE_ERROR(UnknownLabelInDatabase,
LearningError,
"Unknown label found in database")
......
......@@ -43,7 +43,7 @@ namespace gum {
genericBNLearner::Database::Database(const DatabaseTable<>& db)
genericBNLearner::Database::Database( const DatabaseTable<>& db)
: __database(db) {
// get the variables names
const auto& var_names = __database.variableNames ();
......@@ -61,8 +61,10 @@ namespace gum {
}
genericBNLearner::Database::Database(const std::string& filename)
: Database(genericBNLearner::__readFile(filename)) {}
// Builds the Database from a file: the file is first read by __readFile,
// which is given the symbols to be treated as missing values, and the
// resulting table is then forwarded to the DatabaseTable-based constructor.
genericBNLearner::Database::Database(
const std::string& filename,
const std::vector<std::string>& missing_symbols)
: Database(genericBNLearner::__readFile(filename, missing_symbols) ) {}
......@@ -97,9 +99,11 @@ namespace gum {
*/
genericBNLearner::Database::Database(const std::string& filename,
Database& apriori_database)
: __database(genericBNLearner::__readFile(filename)) {
genericBNLearner::Database::Database(
const std::string& filename,
Database& apriori_database,
const std::vector<std::string>& missing_symbols)
: __database(genericBNLearner::__readFile(filename,missing_symbols)) {
// check that there are at least as many variables in the a priori
// database as those in the score_database
if (__database.nbVariables() < apriori_database.__database.nbVariables()) {
......@@ -199,8 +203,10 @@ namespace gum {
// ===========================================================================
genericBNLearner::genericBNLearner(const std::string& filename)
: __score_database ( filename ) {
// Constructs the learner from a database file; filename and missing_symbols
// are simply forwarded to the constructor of the score database.
genericBNLearner::genericBNLearner(
const std::string& filename,
const std::vector<std::string>& missing_symbols)
: __score_database ( filename, missing_symbols ) {
// for debugging purposes
GUM_CONSTRUCTOR(genericBNLearner);
}
......@@ -464,7 +470,9 @@ namespace gum {
DatabaseTable<> genericBNLearner::__readFile(const std::string& filename) {
DatabaseTable<>
genericBNLearner::__readFile(const std::string& filename,
const std::vector<std::string>& missing_symbols) {
// get the extension of the file
Size filename_size = Size(filename.size());
......@@ -491,15 +499,21 @@ namespace gum {
const std::size_t nb_vars = var_names.size ();
DBTranslatorSet<> translator_set;
DBTranslator4LabelizedVariable<> translator;
DBTranslator4LabelizedVariable<> translator ( missing_symbols );
for ( std::size_t i = 0; i < nb_vars; ++i ) {
translator_set.insertTranslator ( translator, i );
}
DatabaseTable<> database ( translator_set );
DatabaseTable<> database ( missing_symbols, translator_set );
database.setVariableNames( initializer.variableNames () );
initializer.fillDatabase ( database );
// The learner cannot exploit missing values yet: reject the database as
// soon as at least one of its rows contains a missing value.
if ( database.hasMissingValues () ) {
  GUM_ERROR ( MissingValueInDatabase,
              "For the moment, the BNLearner is unable to cope "
              "with missing values in databases" );
}
database.reorder ();
return database;
......@@ -527,7 +541,8 @@ namespace gum {
}
if (__user_modalities.empty()) {
__apriori_database = new Database(__apriori_dbname, __score_database);
__apriori_database = new Database(__apriori_dbname, __score_database,
__score_database.missingSymbols () );
} else {
GUM_ERROR(OperationNotAllowed, "not implemented" );
//__apriori_database =
......
......@@ -126,7 +126,8 @@ namespace gum {
/// @{
/// default constructor
explicit Database(const std::string& file);
explicit Database(const std::string& file,
const std::vector<std::string>& missing_symbols );
explicit Database(const DatabaseTable<>& db);
/// default constructor with defined modalities for some variables
......@@ -154,7 +155,9 @@ namespace gum {
* apriori database is the same as in the score/parameter database
* read before creating the apriori. This is compulsory to have
* aprioris that make sense. */
Database(const std::string& filename, Database& score_database);
Database(const std::string& filename,
Database& score_database,
const std::vector<std::string>& missing_symbols );
/// default constructor for the aprioris
/** We must ensure that, when reading the apriori database, if the
......@@ -175,12 +178,14 @@ namespace gum {
*/
template < typename GUM_SCALAR >
Database( const std::string& filename,
const gum::BayesNet< GUM_SCALAR >& bn);
const gum::BayesNet< GUM_SCALAR >& bn,
const std::vector<std::string>& missing_symbols );
template < typename GUM_SCALAR >
Database( const std::string& filename,
Database& score_database,
const gum::BayesNet< GUM_SCALAR >& bn);
const gum::BayesNet< GUM_SCALAR >& bn,
const std::vector<std::string>& missing_symbols );
/// copy constructor
Database(const Database& from);
......@@ -229,6 +234,9 @@ namespace gum {
/// returns the internal database table
const DatabaseTable<>& databaseTable () const;
/// returns the set of missing symbols taken into account
const std::vector<std::string>& missingSymbols () const;
/// @}
protected:
......@@ -279,7 +287,8 @@ namespace gum {
* read the database file for the score / parameter estimation and var
* names
*/
genericBNLearner(const std::string& db);
genericBNLearner(const std::string& filename,
const std::vector<std::string>& missing_symbols );
genericBNLearner(const DatabaseTable<>& db);
/**
......@@ -302,8 +311,9 @@ namespace gum {
* if we find other values in the database, an exception will be raised
* during learning). */
template < typename GUM_SCALAR >
genericBNLearner(const std::string& filename,
const gum::BayesNet< GUM_SCALAR >& src );
genericBNLearner(const std::string& filename,
const gum::BayesNet< GUM_SCALAR >& src,
const std::vector<std::string>& missing_symbols );
/// copy constructor
genericBNLearner(const genericBNLearner&);
......@@ -582,7 +592,9 @@ namespace gum {
const ApproximationScheme* __current_algorithm{nullptr};
/// reads a file and returns a databaseVectInRam
static DatabaseTable<> __readFile(const std::string& filename);
static DatabaseTable<>
__readFile(const std::string& filename,
const std::vector<std::string>& missing_symbols);
/// checks whether the extension of a CSV filename is correct
static void __checkFileName(const std::string& filename);
......
......@@ -73,6 +73,14 @@ namespace gum {
return __database;
}
/// gives read-only access to the missing symbols known to the wrapped table
/** @return the vector of strings that the underlying DatabaseTable reports
 * as its missing symbols. */
INLINE const std::vector<std::string>&
genericBNLearner::Database::missingSymbols () const {
  const std::vector<std::string>& symbols = __database.missingSymbols ();
  return symbols;
}
// ===========================================================================
......
......@@ -5,8 +5,9 @@ namespace gum {
template < typename GUM_SCALAR >
genericBNLearner::Database::Database(
const std::string& filename,
const BayesNet< GUM_SCALAR >& bn ) {
const std::string& filename,
const BayesNet< GUM_SCALAR >& bn,
const std::vector<std::string>& missing_symbols ) {
// assign to each column name in the database its position
genericBNLearner::__checkFileName( filename );
DBInitializerFromCSV<> initializer ( filename );
......@@ -20,7 +21,8 @@ namespace gum {
try {
for ( auto node : bn.dag () ) {
const Variable& var = bn.variable(node);
__database.insertTranslator ( var, var_names[var.name()] );
__database.insertTranslator ( var, var_names[var.name()],
missing_symbols );
}
}
catch ( NotFound& ) {
......@@ -30,6 +32,12 @@ namespace gum {
// fill the database
initializer.fillDatabase ( __database );
// The learner cannot exploit missing values yet: reject the database as
// soon as at least one of its rows contains a missing value.
if ( __database.hasMissingValues () ) {
  GUM_ERROR ( MissingValueInDatabase,
              "For the moment, the BNLearner is unable to cope "
              "with missing values in databases" );
}
// get the domain sizes of the variables
for ( auto dom : __database.domainSizes () )
......@@ -50,8 +58,9 @@ namespace gum {
genericBNLearner::Database::Database(
const std::string& filename,
Database& score_database,
const BayesNet< GUM_SCALAR >& bn)
: __database(genericBNLearner::__readFile(filename,bn)) {
const BayesNet< GUM_SCALAR >& bn,
const std::vector<std::string>& missing_symbols)
: __database(genericBNLearner::__readFile(filename,bn,missing_symbols)) {
}
......@@ -72,8 +81,9 @@ namespace gum {
template < typename GUM_SCALAR >
genericBNLearner::genericBNLearner(
const std::string& filename,
const gum::BayesNet< GUM_SCALAR >& bn )
: __score_database ( filename, bn ) {
const gum::BayesNet< GUM_SCALAR >& bn,
const std::vector<std::string>& missing_symbols )
: __score_database ( filename, bn, missing_symbols ) {
GUM_CONSTRUCTOR(genericBNLearner);
}
......
......@@ -65,7 +65,8 @@ namespace gum {
* read the database file for the score / parameter estimation and var
* names
*/
BNLearner(const std::string& filename);
BNLearner(const std::string& filename,
const std::vector<std::string>& missing_symbols = { "?" } );
BNLearner(const DatabaseTable<>& db);
/**
......@@ -102,7 +103,8 @@ namespace gum {
* to find those modalities and nodeids.
**/
BNLearner(const std::string& filename,
const gum::BayesNet< GUM_SCALAR >& src );
const gum::BayesNet< GUM_SCALAR >& src,
const std::vector<std::string>& missing_symbols = { "?" } );
/// copy constructor
BNLearner(const BNLearner&);
......
......@@ -38,8 +38,10 @@ namespace gum {
namespace learning {
template < typename GUM_SCALAR >
BNLearner< GUM_SCALAR >::BNLearner(const std::string& filename)
: genericBNLearner(filename) {
BNLearner< GUM_SCALAR >::BNLearner(
// file containing the database to learn from
const std::string& filename,
// strings that, when found in the file, denote a missing value
const std::vector<std::string>& missing_symbols)
: genericBNLearner(filename,missing_symbols) {
// the reading of the file is delegated to genericBNLearner; here we only
// register the construction of the object
GUM_CONSTRUCTOR(BNLearner);
}
......@@ -50,9 +52,11 @@ namespace gum {
}
template < typename GUM_SCALAR >
BNLearner< GUM_SCALAR >::BNLearner(const std::string& filename,
const gum::BayesNet< GUM_SCALAR >& bn )
: genericBNLearner(filename, bn ) {
BNLearner< GUM_SCALAR >::BNLearner(
  // file containing the database to learn from
  const std::string& filename,
  // Bayesian network forwarded, with the file name, to genericBNLearner
  const gum::BayesNet< GUM_SCALAR >& bn,
  // strings that, when found in the file, denote a missing value
  const std::vector<std::string>& missing_symbols )
  : genericBNLearner(filename, bn, missing_symbols ) {
  // the reading of the file is delegated to genericBNLearner; here we only
  // register the construction of the object (terminated by a semicolon,
  // like the other constructors of the learners)
  GUM_CONSTRUCTOR (BNLearner);
}
......
......@@ -276,7 +276,7 @@ namespace gum {
using MissingValType = std::vector<std::string,XALLOC<std::string>>;
enum IsMissing : char { True, False };
enum IsMissing : char { False, True };
/** @class Handler
......@@ -977,6 +977,13 @@ namespace gum {
/** @brief returns the set of columns of the original dataset that are
* present in the IDatabaseTable */
virtual const DBVector<std::size_t> inputColumns () const = 0;
/// indicates whether the database contains some missing values
bool hasMissingValues () const;
/// indicates whether the kth row contains some missing values
bool hasMissingValues ( const std::size_t k ) const;
using IDatabaseTableInsert4DBCell<ALLOC,
!std::is_same<T_DATA,DBCell>::value>::insertRow;
......@@ -1076,6 +1083,9 @@ namespace gum {
/// returns the allocator of the database
ALLOC<T_DATA> getAllocator () const;
/// returns the set of missing symbols
const DBVector<std::string>& missingSymbols () const;
/// @}
......@@ -1086,6 +1096,9 @@ namespace gum {
/// returns the content of the database
Matrix<T_DATA>& _content() noexcept;
/// returns the vector indicating whether a row contains missing values
DBVector<IsMissing>& _hasRowMissingVal () noexcept;
/// returns the set of symbols for the missing values
const DBVector<std::string>& _missingSymbols () const;
......
......@@ -742,6 +742,32 @@ namespace gum {
return __data;
}
/// gives write access to the per-row "contains a missing value" flags
/** @return a mutable reference to the vector whose ith entry is
 * IsMissing::True when the ith row is flagged as containing a missing
 * value. */
template <typename T_DATA, template<typename> class ALLOC>
INLINE typename IDatabaseTable<T_DATA,ALLOC>::template
DBVector<typename IDatabaseTable<T_DATA,ALLOC>::IsMissing>&
IDatabaseTable<T_DATA,ALLOC>::_hasRowMissingVal () noexcept {
  auto& row_flags = this->__has_row_missing_val;
  return row_flags;
}
/// tells whether at least one row of the database has a missing value
/** The per-row flags are scanned in order and the scan stops at the first
 * row flagged IsMissing::True.
 * @return true if such a row exists, false otherwise (in particular when
 * there is no flag at all). */
template <typename T_DATA, template<typename> class ALLOC>
bool IDatabaseTable<T_DATA,ALLOC>::hasMissingValues () const {
  const std::size_t nb_flags = __has_row_missing_val.size ();
  for ( std::size_t i = std::size_t(0); i < nb_flags; ++i ) {
    if ( __has_row_missing_val[i] == IsMissing::True ) {
      return true;
    }
  }
  return false;
}
/// tells whether the kth row is flagged as containing a missing value
/** @param k the index of the row to examine.
 * @return true if the flag of row k is IsMissing::True.
 * @warning k is not range-checked: the flag vector is accessed with a plain
 * operator[], so k must be the index of an existing row. */
template <typename T_DATA, template<typename> class ALLOC>
INLINE bool
IDatabaseTable<T_DATA,ALLOC>::hasMissingValues ( const std::size_t k ) const {
  const IsMissing row_status = __has_row_missing_val[k];
  return row_status == IsMissing::True;
}
// returns the variable names for all the columns
template <typename T_DATA, template<typename> class ALLOC>
......@@ -1149,6 +1175,14 @@ namespace gum {
}
/// gives read-only access to the symbols standing for missing values
/** @return a const reference to the vector of missing symbols stored in
 * the database table. */
template <typename T_DATA, template<typename> class ALLOC>
INLINE const std::vector<std::string,ALLOC<std::string>>&
IDatabaseTable<T_DATA,ALLOC>::missingSymbols () const {
  const auto& symbols = this->__missing_symbols;
  return symbols;
}
/// insert new rows at the end of the database
template <template<typename> class ALLOC>
void IDatabaseTableInsert4DBCell<ALLOC,true>::insertRows (
......
......@@ -436,17 +436,34 @@ namespace gum {
// as well as the resulting column in __data and the _variable_names.
// Note that if there remains no more variable left, __data should
// become empty
__translators.eraseTranslator ( pos, false );
this->_variable_names.erase ( this->_variable_names.begin() + pos );
if ( this->_variable_names.empty () ) {
IDatabaseTable<DBTranslatedValue,ALLOC>::eraseAllRows();
__translators.eraseTranslator ( pos, false );
}
else {
for ( auto& xrow :
IDatabaseTable<DBTranslatedValue,ALLOC>::_content() ) {
auto& row = xrow.row ();
auto& rows =
IDatabaseTable<DBTranslatedValue,ALLOC>::_content();
auto& has_row_missing_val =
IDatabaseTable<DBTranslatedValue,ALLOC>::_hasRowMissingVal ();
const std::size_t nb_trans = __translators.size ();
const std::size_t nb_rows = rows.size ();
for ( std::size_t i = std::size_t(0); i < nb_rows; ++i ) {
auto& row = rows[i].row ();
if ( has_row_missing_val[i] == IsMissing::True ) {
bool has_missing_val = false;
for ( std::size_t j = std::size_t (0); j < nb_trans; ++j ) {
if ( ( j != pos ) && __translators.isMissingValue(row[j], j) ) {
has_missing_val = true;
break;
}
}
if ( ! has_missing_val )
has_row_missing_val[i] = IsMissing::False;
}
row.erase ( row.begin() + pos );
}
__translators.eraseTranslator ( pos, false );
}
}
}
......@@ -643,7 +660,9 @@ namespace gum {
for ( std::size_t i = std::size_t(0); i < row_size; ++i ) {
switch ( translators[i]->getValType () ) {
case DBTranslatedValueType::DISCRETE:
if ( row[i].discr_val >= translators[i]->domainSize () ) return false;
if ( ( row[i].discr_val >= translators[i]->domainSize () ) &&
( row[i].discr_val != std::numeric_limits<std::size_t>::max() ) )
return false;
break;
case DBTranslatedValueType::CONTINUOUS:
......@@ -651,8 +670,10 @@ namespace gum {
const ContinuousVariable<float>* var =
static_cast<const ContinuousVariable<float>*>
( translators[i]->variable () );
if ( ( var->lowerBound () > row[i].cont_val ) ||
( var->upperBound () < row[i].cont_val ) ) return false;
if ( ( ( var->lowerBound () > row[i].cont_val ) ||
( var->upperBound () < row[i].cont_val ) ) &&
( row[i].cont_val != std::numeric_limits<float>::max() ) )
return false;
break;
}
......
......@@ -261,9 +261,28 @@ namespace gum {
IDatabaseTable<DBCell,ALLOC>::eraseAllRows ();
}
else {
for ( auto& xrow : IDatabaseTable<DBCell,ALLOC>::_content () ) {
auto& row = xrow.row ();
row.erase ( row.begin() + col );
auto& rows =
IDatabaseTable<DBCell,ALLOC>::_content();
auto& has_row_missing_val =
IDatabaseTable<DBCell,ALLOC>::_hasRowMissingVal ();
const std::size_t nb_rows = rows.size ();
if ( nb_rows != std::size_t (0) ) {
const std::size_t nb_cols = rows[0].size ();
for ( std::size_t i = std::size_t(0); i < nb_rows; ++i ) {
auto& row = rows[i].row ();
if ( has_row_missing_val[i] == IsMissing::True ) {
bool has_missing_val = false;
for ( std::size_t j = std::size_t (0); j < nb_cols; ++j ) {
if ( ( j != col ) && row[j].isMissing () ) {
has_missing_val = true;
break;
}
}
if ( ! has_missing_val )
has_row_missing_val[i] = IsMissing::False;
}
row.erase ( row.begin() + col );
}
}
}
}
......
......@@ -820,6 +820,23 @@ namespace gum_tests {
TS_ASSERT_EQUALS(p1[I1], p3[I3]); // same probabilities
}
}
// Checks that building a learner on a database containing missing values
// raises gum::MissingValueInDatabase. "BEURK" is declared as the missing
// symbol (presumably the marker used in asia3-faulty.csv -- to be confirmed
// against the resource file).
void test_asia_with_missing_values() {
  bool missing_value_reported = false;

  try {
    const std::vector<std::string> missing_symbols{"BEURK"};
    gum::learning::BNLearner< double >
      learner(GET_RESSOURCES_PATH("asia3-faulty.csv"), missing_symbols);
    learner.useK2(std::vector< gum::NodeId >{1, 5, 2, 6, 0, 3, 4, 7});
  }
  catch (gum::MissingValueInDatabase&) {
    // the only outcome accepted by this test
    missing_value_reported = true;
  }

  TS_ASSERT ( missing_value_reported );
}
};
} /* namespace gum_tests */
......@@ -1107,7 +1107,95 @@ namespace gum_tests {
}
// Checks the bookkeeping of missing values in a DatabaseTable: both the
// global and the per-row hasMissingValues() answers, and how they evolve
// when columns are ignored one after the other.
void test_missing_vals () {
  std::vector<std::string> missing { "?", "N/A", "???" };

  gum::learning::DatabaseTable<> database;
  gum::LabelizedVariable var ( "var0", "", 0 );
  var.addLabel ( "L1" );
  var.addLabel ( "L2" );
  var.addLabel ( "L0" );
  database.insertTranslator<> ( var, 0, missing );

  gum::LabelizedVariable var1 ( "var1", "", 0 );
  var1.addLabel ( "L0" );
  var1.addLabel ( "L1" );
  var1.addLabel ( "L2" );
  database.insertTranslator<> ( var1, 1, missing );

  var.setName ( "var2" );
  database.insertTranslator<> ( var, 2, missing );
  var.setName ( "var3" );
  database.insertTranslator<> ( var, 3, missing );

  const auto& vnames = database.variableNames();
  TS_ASSERT( vnames.size() == 4 );
  TS_ASSERT( vnames[0] == "var0" );
  TS_ASSERT( vnames[1] == "var1" );
  TS_ASSERT( vnames[2] == "var2" );
  TS_ASSERT( vnames[3] == "var3" );

  // content of the database (a missing symbol marks a missing value):
  //   row 0:  L0   L1   L2   L0
  //   row 1:  ?    L1   L2   L0
  //   row 2:  L0   ?    L2   L0
  //   row 3:  L0   ?    N/A  L0
  //   row 4:  ???  ?    N/A  L0
  //   row 5:  L0   L0   L0   L0
  std::vector<std::string> row { "L0", "L1", "L2", "L0" };
  database.insertRow( row );
  row[0] = "?";
  database.insertRow( row );
  row[0] = "L0";
  row[1] = "?";
  database.insertRow( row );
  row[2] = "N/A";
  database.insertRow( row );
  row[0] = "???";
  database.insertRow( row );
  row[0] = "L0";
  row[1] = "L0";
  row[2] = "L0";
  // this sixth row must really be stored: the assertions below query row 5,
  // and hasMissingValues(k) does not range-check its index
  database.insertRow( row );

  TS_ASSERT ( database.hasMissingValues () );
  TS_ASSERT ( database.hasMissingValues ( 0 ) == false );
  TS_ASSERT ( database.hasMissingValues ( 1 ) == true );
  TS_ASSERT ( database.hasMissingValues ( 2 ) == true );
  TS_ASSERT ( database.hasMissingValues ( 3 ) == true );
  TS_ASSERT ( database.hasMissingValues ( 4 ) == true );
  TS_ASSERT ( database.hasMissingValues ( 5 ) == false );

  // without column 1, row 2 no longer contains any missing value
  database.ignoreColumn ( 1 );
  TS_ASSERT ( database.hasMissingValues () );
  TS_ASSERT ( database.hasMissingValues ( 0 ) == false );
  TS_ASSERT ( database.hasMissingValues ( 1 ) == true );
  TS_ASSERT ( database.hasMissingValues ( 2 ) == false );
  TS_ASSERT ( database.hasMissingValues ( 3 ) == true );
  TS_ASSERT ( database.hasMissingValues ( 4 ) == true );
  TS_ASSERT ( database.hasMissingValues ( 5 ) == false );

  // without column 2 as well, row 3 becomes complete
  database.ignoreColumn ( 2 );
  TS_ASSERT ( database.hasMissingValues () );
  TS_ASSERT ( database.hasMissingValues ( 0 ) == false );
  TS_ASSERT ( database.hasMissingValues ( 1 ) == true );
  TS_ASSERT ( database.hasMissingValues ( 2 ) == false );
  TS_ASSERT ( database.hasMissingValues ( 3 ) == false );
  TS_ASSERT ( database.hasMissingValues ( 4 ) == true );
  TS_ASSERT ( database.hasMissingValues ( 5 ) == false );

  // column 3 never holds a missing symbol: once column 0 is ignored too,
  // no row is flagged anymore
  database.ignoreColumn ( 0 );
  TS_ASSERT ( database.hasMissingValues () == false );
  TS_ASSERT ( database.hasMissingValues ( 0 ) == false );
  TS_ASSERT ( database.hasMissingValues ( 1 ) == false );
  TS_ASSERT ( database.hasMissingValues ( 2 ) == false );
  TS_ASSERT ( database.hasMissingValues ( 3 ) == false );
  TS_ASSERT ( database.hasMissingValues ( 4 ) == false );
  TS_ASSERT ( database.hasMissingValues ( 5 ) == false );

  database.ignoreColumn ( 3 );
  TS_ASSERT ( database.hasMissingValues () == false );
}
private:
......
......@@ -825,6 +825,81 @@ namespace gum_tests {
}
void test_missing_vals () {
std::vector<std::string> missing { "?", "N/A", "???" };
gum::learning::RawDatabaseTable<> database ( missing );
TS_ASSERT( database.content().size() == 0 );
TS_ASSERT( database.variableNames().size() == 0 );
std::vector<std::string> vect { "v0", "v1", "v2", "v3" };
database.setVariableNames( vect );
TS_ASSERT( database.variableNames().size() == 4 );
TS_ASSERT( database.nbVariables() == 4 );
std::vector<std::string> row { "L0", "L1", "L2", "L0" };
database.insertRow( row );
row[0] = "?";
database.insertRow( row );
row[0] = "L0";
row[1] = "?";
database.insertRow( row );
row[2] = "N/A";
database.insertRow( row );
row[0] = "???";
database.insertRow( row );
row[0] = "L0";
row[1] = "L0";
row[2] = "L0";
TS_ASSERT ( database.hasMissingValues () );
TS_ASSERT ( database.hasMissingValues ( 0 ) == false );
TS_ASSERT ( database.hasMissingValues ( 1 ) == true );
TS_ASSERT ( database.hasMissingValues ( 2 ) == true );
TS_ASSERT ( database.hasMissingValues ( 3 ) == true );
TS_ASSERT ( database.hasMissingValues ( 4 ) == true );
TS_ASSERT ( database.hasMissingValues ( 5 ) == false );
database.ignoreColumn ( 1 );
TS_ASSERT ( database.hasMissingValues () );
TS_ASSERT ( database.hasMissingValues ( 0 ) == false );
TS_ASSERT ( database.hasMissingValues ( 1 ) == true );
TS_ASSERT ( database.hasMissingValues ( 2 ) == false );
TS_ASSERT ( database.hasMissingValues ( 3 ) == true );
TS_ASSERT ( database.hasMissingValues ( 4 ) == true );
TS_ASSERT ( database.hasMissingValues ( 5 ) == false );
database.ignoreColumn ( 2 );
TS_ASSERT ( database.hasMissingValues () );
TS_ASSERT ( database.hasMissingValues ( 0 ) == false );
TS_ASSERT ( database.hasMissingValues ( 1 ) == true );