Commit 9634a38d authored by Christophe Gonzales

added some thread support to databases + removed a method in BNLearner that…

added some thread support to databases + removed a method in BNLearner that could lead to potential bugs
parent af7814e2
Pipeline #23798268 failed with stages
in 26 minutes and 5 seconds
......@@ -122,7 +122,7 @@ namespace gum {
__refcount = 0;
__val = from;
} catch (std::bad_alloc &) {
} catch (std::bad_alloc&) {
if (*old_refcount == 1) {
__val = from;
delete old_val;
......
......@@ -147,8 +147,8 @@ namespace gum {
BayesNet< GUM_SCALAR > learnParameters(const DAG& dag,
bool take_into_account_score = true);
/// learns a BN (its parameters) when its structure is known
/**
// learns a BN (its parameters) when its structure is known
/*
* @param bn the structure of the Bayesian network
* @param take_into_account_score The dag passed in argument may have
* been learnt from a structure learning. In this case, if the score used
......@@ -165,8 +165,7 @@ namespace gum {
* @throw UnknownLabelInDatabase if a label is found in the database that
* does not correspond to the variable.
*/
BayesNet< GUM_SCALAR > learnParameters(const BayesNet< GUM_SCALAR >& bn,
bool take_into_account_score = true);
//BayesNet< GUM_SCALAR > learnParameters(bool take_into_account_score = true);
private:
/// read the first line of a file to find column names
......
......@@ -139,10 +139,10 @@ namespace gum {
}
/// learns a BN (its parameters) when its structure is known
/*
template < typename GUM_SCALAR >
BayesNet< GUM_SCALAR >
BNLearner< GUM_SCALAR >::learnParameters(const BayesNet< GUM_SCALAR >& bn,
bool take_into_account_score) {
BNLearner< GUM_SCALAR >::learnParameters(bool take_into_account_score) {
// create the apriori and the estimator
__createApriori();
__createParamEstimator(take_into_account_score);
......@@ -152,6 +152,7 @@ namespace gum {
NodeProperty< NodeId > mapIds(bn.size());
auto mods = modalities();
for (auto node : bn.nodes()) {
const NodeId new_id = idFromName(bn.variable(node).name());
......@@ -168,6 +169,7 @@ namespace gum {
newDAG.addArc(mapIds[arc.tail()], mapIds[arc.head()]);
}
return DAG2BNLearner::
createBN< GUM_SCALAR, ParamEstimator<>, DBTranslatorSet<> >(
*__param_estimator,
......@@ -176,6 +178,7 @@ namespace gum {
__score_database.modalities(),
__score_database.databaseTable().translatorSet());
}
*/
template < typename GUM_SCALAR >
NodeProperty< Sequence< std::string > >
......
......@@ -223,32 +223,31 @@ namespace gum {
return false;
}
// raises an appropriate exception when encountering a type error
std::string DBCell::__typeErrorMsg ( const std::string& true_type ) const {
std::string DBCell::__typeErrorMsg(const std::string& true_type) const {
std::stringstream str;
switch ( __type ) {
case EltType::REAL:
str << "The DBCell contains a real number instead of " << true_type;
break;
case EltType::INTEGER:
str << "The DBCell contains an integer instead of " << true_type;
break;
case EltType::STRING:
str << "The DBCell contains a string instead of " << true_type;
break;
case EltType::MISSING:
str << "The DBCell contains a missing value instead of " << true_type;
break;
default:
GUM_ERROR(NotImplementedYet, "DBCell type not implemented yet");
switch (__type) {
case EltType::REAL:
str << "The DBCell contains a real number instead of " << true_type;
break;
case EltType::INTEGER:
str << "The DBCell contains an integer instead of " << true_type;
break;
case EltType::STRING:
str << "The DBCell contains a string instead of " << true_type;
break;
case EltType::MISSING:
str << "The DBCell contains a missing value instead of " << true_type;
break;
default: GUM_ERROR(NotImplementedYet, "DBCell type not implemented yet");
}
return str.str ();
return str.str();
}
} /* namespace learning */
......
......@@ -273,8 +273,8 @@ namespace gum {
typename std::conditional< sizeof(int) < sizeof(float), float, int >::type;
// raises an appropriate exception when encountering a type error
std::string __typeErrorMsg ( const std::string& real_type ) const;
std::string __typeErrorMsg(const std::string& real_type) const;
// a bijection assigning to each string index its corresponding string
static Bijection< std::string, int >& __strings();
......
......@@ -157,13 +157,13 @@ namespace gum {
/// returns the current type of the DBCell
/** @return the EltType value currently stored in the __type member. The
 * accessor is const and declared noexcept. */
INLINE DBCell::EltType DBCell::type() const noexcept { return __type; }
/// returns the DBcell as a float
INLINE float DBCell::real() const {
if (__type == EltType::REAL)
return __val_real;
else
GUM_ERROR(TypeError, __typeErrorMsg ("a real number") );
GUM_ERROR(TypeError, __typeErrorMsg("a real number"));
}
......@@ -188,7 +188,7 @@ namespace gum {
if (__type == EltType::INTEGER)
return __val_integer;
else
GUM_ERROR(TypeError, __typeErrorMsg ("an integer") );
GUM_ERROR(TypeError, __typeErrorMsg("an integer"));
}
......@@ -213,7 +213,7 @@ namespace gum {
if (__type == EltType::STRING)
return __strings().first(__val_index);
else
GUM_ERROR(TypeError, __typeErrorMsg ("a string") );
GUM_ERROR(TypeError, __typeErrorMsg("a string"));
}
......@@ -222,7 +222,7 @@ namespace gum {
if (__type == EltType::STRING)
return __val_index;
else
GUM_ERROR(TypeError, __typeErrorMsg ("a string") );
GUM_ERROR(TypeError, __typeErrorMsg("a string"));
}
......
......@@ -56,42 +56,42 @@ namespace gum {
deb_index = connection_string.find(delimiter, 0);
if (deb_index == std::string::npos)
GUM_ERROR(DatabaseError,
"could not determine the datasource from string " <<
connection_string);
"could not determine the datasource from string "
<< connection_string);
deb_index += std::size_t(1);
end_index = connection_string.find(delimiter, deb_index);
if (end_index == std::string::npos)
GUM_ERROR(DatabaseError,
"could not determine the datasource from string " <<
connection_string);
"could not determine the datasource from string "
<< connection_string);
std::string dataSource =
connection_string.substr(deb_index, end_index - deb_index);
deb_index = connection_string.find(delimiter, end_index + std::size_t(1));
if (deb_index == std::string::npos)
GUM_ERROR(DatabaseError,
"could not determine the database login from string " <<
connection_string);
"could not determine the database login from string "
<< connection_string);
deb_index += std::size_t(1);
end_index = connection_string.find(delimiter, deb_index);
if (end_index == std::string::npos)
GUM_ERROR(DatabaseError,
"could not determine the database login from string " <<
connection_string);
"could not determine the database login from string "
<< connection_string);
std::string login =
connection_string.substr(deb_index, end_index - deb_index);
deb_index = connection_string.find(delimiter, end_index + std::size_t(1));
if (deb_index == std::string::npos)
GUM_ERROR(DatabaseError,
"could not determine the database password from string " <<
connection_string);
"could not determine the database password from string "
<< connection_string);
deb_index += std::size_t(1);
end_index = connection_string.find(delimiter, deb_index);
if (end_index == std::string::npos)
GUM_ERROR(DatabaseError,
"could not determine the database password from string " <<
connection_string);
"could not determine the database password from string "
<< connection_string);
std::string password =
connection_string.substr(deb_index, end_index - deb_index);
......
......@@ -351,7 +351,7 @@ namespace gum {
bool isMissingValue(const DBTranslatedValue& val) const;
/// returns the translation of a missing value
virtual DBTranslatedValue missingValue () const = 0;
virtual DBTranslatedValue missingValue() const = 0;
/// @}
......
......@@ -367,7 +367,7 @@ namespace gum {
virtual const IContinuousVariable* variable() const final;
/// returns the translation of a missing value
virtual DBTranslatedValue missingValue () const final;
virtual DBTranslatedValue missingValue() const final;
/// @}
......
......@@ -373,8 +373,10 @@ namespace gum {
if (this->isMissingSymbol(str)) {
return DBTranslatedValue{std::numeric_limits< float >::max()};
} else
GUM_ERROR(TypeError, "String \"" << str <<
"\" cannot be translated because it is not a number");
GUM_ERROR(TypeError,
"String \""
<< str
<< "\" cannot be translated because it is not a number");
}
// here we know that the string is a number
......@@ -394,8 +396,9 @@ namespace gum {
// check if we are allowed to update the domain of the variable
if (!__fit_range) {
GUM_ERROR(UnknownLabelInDatabase,
"String \"" << str << "\" cannot be translated because it is "
"out of the domain of the continuous variable");
"String \"" << str
<< "\" cannot be translated because it is "
"out of the domain of the continuous variable");
}
// now, we can try to add str as a new bound of the range variable
......@@ -425,9 +428,10 @@ namespace gum {
const float miss_val = std::stof(missing.first);
if ((miss_val >= number) && (miss_val <= upper_bound)) {
GUM_ERROR(OperationNotAllowed,
"String \"" << str << "\" cannot be translated because " <<
"it would induce a new domain containing an already " <<
"translated missing symbol");
"String \""
<< str << "\" cannot be translated because "
<< "it would induce a new domain containing an already "
<< "translated missing symbol");
}
}
}
......@@ -458,9 +462,10 @@ namespace gum {
const float miss_val = std::stof(missing.first);
if ((miss_val >= lower_bound) && (miss_val <= number)) {
GUM_ERROR(OperationNotAllowed,
"String \"" << str << "\" cannot be translated because " <<
"it would induce a new domain containing an already " <<
"translated missing symbol");
"String \""
<< str << "\" cannot be translated because "
<< "it would induce a new domain containing an already "
<< "translated missing symbol");
}
}
}
......@@ -500,9 +505,10 @@ namespace gum {
if ((translated_val.cont_val < __variable.lowerBound())
|| (translated_val.cont_val > __variable.upperBound())) {
GUM_ERROR(UnknownLabelInDatabase,
"The back translation of " << translated_val.cont_val <<
" could not be found because the value is outside the " <<
"domain of the continuous variable");
"The back translation of "
<< translated_val.cont_val
<< " could not be found because the value is outside the "
<< "domain of the continuous variable");
}
char buffer[100];
......@@ -547,11 +553,11 @@ namespace gum {
return __real_variable;
}
/// returns the translation of a missing value
template < template < typename > class ALLOC >
INLINE DBTranslatedValue
DBTranslator4ContinuousVariable< ALLOC >::missingValue () const {
DBTranslator4ContinuousVariable< ALLOC >::missingValue() const {
return DBTranslatedValue{std::numeric_limits< float >::max()};
}
......
......@@ -316,7 +316,7 @@ namespace gum {
virtual const IDiscretizedVariable* variable() const final;
/// returns the translation of a missing value
virtual DBTranslatedValue missingValue () const final;
virtual DBTranslatedValue missingValue() const final;
/// @}
......
......@@ -381,8 +381,10 @@ namespace gum {
return DBTranslatedValue{this->_back_dico.first(str)};
} catch (gum::Exception&) {
if (!DBCell::isReal(str)) {
GUM_ERROR(TypeError, "String \"" << str <<
"\" cannot be translated because it is not a number");
GUM_ERROR(TypeError,
"String \""
<< str
<< "\" cannot be translated because it is not a number");
} else {
GUM_ERROR(UnknownLabelInDatabase,
"The translation of \"" << str << "\" could not be found");
......@@ -405,8 +407,8 @@ namespace gum {
return *(this->_missing_symbols.begin());
else
GUM_ERROR(UnknownLabelInDatabase,
"The back translation of \"" << translated_val.discr_val <<
"\" could not be found");
"The back translation of \"" << translated_val.discr_val
<< "\" could not be found");
}
}
......@@ -463,7 +465,7 @@ namespace gum {
/// returns the translation of a missing value
template < template < typename > class ALLOC >
INLINE DBTranslatedValue
DBTranslator4DiscretizedVariable< ALLOC >::missingValue () const {
DBTranslator4DiscretizedVariable< ALLOC >::missingValue() const {
return DBTranslatedValue{std::numeric_limits< std::size_t >::max()};
}
......
......@@ -371,7 +371,7 @@ namespace gum {
virtual const LabelizedVariable* variable() const final;
/// returns the translation of a missing value
virtual DBTranslatedValue missingValue () const final;
virtual DBTranslatedValue missingValue() const final;
/// @}
......
......@@ -260,8 +260,8 @@ namespace gum {
const std::size_t size = __variable.domainSize();
if (size >= this->_max_dico_entries)
GUM_ERROR(SizeError,
"String \"" << str << "\" cannot be translated " <<
"because the dictionary is already full" );
"String \"" << str << "\" cannot be translated "
<< "because the dictionary is already full");
__variable.addLabel(str);
this->_back_dico.insert(size, str);
return DBTranslatedValue{size};
......@@ -285,8 +285,8 @@ namespace gum {
return *(this->_missing_symbols.begin());
else
GUM_ERROR(UnknownLabelInDatabase,
"The back translation of \"" << translated_val.discr_val <<
"\" could not be found");
"The back translation of \"" << translated_val.discr_val
<< "\" could not be found");
}
}
......@@ -418,11 +418,11 @@ namespace gum {
return &__variable;
}
/// returns the translation of a missing value
template < template < typename > class ALLOC >
INLINE DBTranslatedValue
DBTranslator4LabelizedVariable< ALLOC >::missingValue () const {
DBTranslator4LabelizedVariable< ALLOC >::missingValue() const {
return DBTranslatedValue{std::numeric_limits< std::size_t >::max()};
}
......
......@@ -367,7 +367,7 @@ namespace gum {
virtual const RangeVariable* variable() const final;
/// returns the translation of a missing value
virtual DBTranslatedValue missingValue () const final;
virtual DBTranslatedValue missingValue() const final;
/// @}
......
......@@ -323,15 +323,15 @@ namespace gum {
// check if we are allowed to update the range variable
if (!this->hasEditableDictionary()) {
GUM_ERROR(UnknownLabelInDatabase,
"The translation of String \"" << str <<
"\" could not be found");
"The translation of String \"" << str
<< "\" could not be found");
}
// check if str could correspond to a bound of the range variable
if (!DBCell::isInteger(str)) {
GUM_ERROR(TypeError,
"String \"" << str << "\" cannot be translated because " <<
"it cannot be converted into an integer");
"String \"" << str << "\" cannot be translated because "
<< "it cannot be converted into an integer");
}
const long new_value = std::stol(str);
......@@ -339,8 +339,9 @@ namespace gum {
// translated, raise an exception
if (__translated_int_missing_symbols.exists(new_value)) {
GUM_ERROR(OperationNotAllowed,
"String \"" << str << "\" cannot be translated because " <<
"it corresponds to an already translated missing symbol");
"String \""
<< str << "\" cannot be translated because "
<< "it corresponds to an already translated missing symbol");
}
// now, we can try to add str as a new bound of the range variable
......@@ -353,8 +354,8 @@ namespace gum {
if (__variable.minVal() > __variable.maxVal()) {
if (this->_max_dico_entries == 0) {
GUM_ERROR(SizeError,
"String \"" << str << "\" cannot be translated because " <<
"the dictionary is already full");
"String \"" << str << "\" cannot be translated because "
<< "the dictionary is already full");
}
__variable.setMinVal(new_value);
__variable.setMaxVal(new_value);
......@@ -374,17 +375,18 @@ namespace gum {
if (new_value < __variable.minVal()) {
if (std::size_t(upper_bound - new_value + 1) > this->_max_dico_entries)
GUM_ERROR(SizeError,
"String \"" << str << "\" cannot be translated because " <<
"the dictionary is already full");
"String \"" << str << "\" cannot be translated because "
<< "the dictionary is already full");
// check that there does not already exist a translated missing
// value within the new bounds of the range variable
for (const auto& missing : __translated_int_missing_symbols) {
if ((missing >= new_value) && (missing <= upper_bound)) {
GUM_ERROR(OperationNotAllowed,
"String \"" << str << "\" cannot be translated " <<
"because it would induce a new range containing " <<
"an already translated missing symbol");
"String \""
<< str << "\" cannot be translated "
<< "because it would induce a new range containing "
<< "an already translated missing symbol");
}
}
......@@ -414,17 +416,18 @@ namespace gum {
} else {
if (std::size_t(new_value - lower_bound + 1) > this->_max_dico_entries)
GUM_ERROR(SizeError,
"String \"" << str << "\" cannot be translated because " <<
"the dictionary is already full");
"String \"" << str << "\" cannot be translated because "
<< "the dictionary is already full");
// check that there does not already exist a translated missing
// value within the new bounds of the range variable
for (const auto& missing : __translated_int_missing_symbols) {
if ((missing <= new_value) && (missing >= lower_bound)) {
GUM_ERROR(OperationNotAllowed,
"String \"" << str << "\" cannot be translated " <<
"because it would induce a new range containing " <<
"an already translated missing symbol");
"String \""
<< str << "\" cannot be translated "
<< "because it would induce a new range containing "
<< "an already translated missing symbol");
}
}
......@@ -471,8 +474,8 @@ namespace gum {
}
GUM_ERROR(UnknownLabelInDatabase,
"The back translation of \"" << translated_val.discr_val <<
"\" could not be found");
"The back translation of \"" << translated_val.discr_val
<< "\" could not be found");
}
}
......@@ -552,11 +555,11 @@ namespace gum {
return &__variable;
}
/// returns the translation of a missing value
template < template < typename > class ALLOC >
INLINE DBTranslatedValue
DBTranslator4RangeVariable< ALLOC >::missingValue () const {
DBTranslator4RangeVariable< ALLOC >::missingValue() const {
return DBTranslatedValue{std::numeric_limits< std::size_t >::max()};
}
......
......@@ -195,7 +195,7 @@ namespace gum {
template < template < template < typename > class > class Translator >
std::size_t insertTranslator(const Translator< ALLOC >& translator,
const std::size_t column,
const bool unique_column=true);
const bool unique_column = true);
/** @brief inserts a new translator for a given variable at the end of
* the translator set
......@@ -216,7 +216,7 @@ namespace gum {
const Variable& var,
const std::size_t column,
const std::vector< std::string, XALLOC< std::string > >& missing_symbols,
const bool unique_column=true);
const bool unique_column = true);
/** @brief inserts a new translator for a given variable at the end of
* the translator set
......@@ -232,7 +232,7 @@ namespace gum {
*/
std::size_t insertTranslator(const Variable& var,
const std::size_t column,
const bool unique_column=true);
const bool unique_column = true);
/** @brief erases either the kth translator or those parsing the kth
* column of the input database
......@@ -248,8 +248,7 @@ namespace gum {
* translators parse the column k, all of them are removed).
* @warning if the translator does not exist, nothing is done. In
* particular, no exception is raised. */
void eraseTranslator(const std::size_t k,
const bool k_is_input_col = false);
void eraseTranslator(const std::size_t k, const bool k_is_input_col = false);
/// returns the kth translator
/** @warning this method assumes that there are at least k translators.
......
......@@ -212,21 +212,21 @@ namespace gum {
template < template < template < typename > class > class Translator >
std::size_t DBTranslatorSet< ALLOC >::insertTranslator(
const Translator< ALLOC >& translator,
const std::size_t column,
const bool unique_column) {
const std::size_t column,
const bool unique_column) {
// if the unique_column parameter is set to true and there exists already
// another translator that parses the column, raise a DuplicateElement
// exception
const std::size_t size = __translators.size();
if ( unique_column ) {
if (unique_column) {
for (std::size_t i = std::size_t(0); i < size; ++i) {
if (__columns[i] == column)
GUM_ERROR(DuplicateElement,
"There already exists a DBTranslator that parses Column"
<< column);
<< column);
}
}
// reserve some place for the new translator
__translators.reserve(size + 1);
__columns.reserve(size + 1);
......@@ -287,20 +287,20 @@ namespace gum {
return insertTranslator(translator, column, unique_column);
}
default: GUM_ERROR(NotImplementedYet,
"The insertion of the translator for Variable " <<
var.name() << " is impossible because a translator "
"for such variable is not implemented yet");
default:
GUM_ERROR(NotImplementedYet,
"The insertion of the translator for Variable "
<< var.name()
<< " is impossible because a translator "
"for such variable is not implemented yet");
}
}
/// inserts a new translator for a given variable in the translator set
template < template < typename > class ALLOC >
INLINE std::size_t
DBTranslatorSet< ALLOC >::insertTranslator(const Variable& var,
const std::size_t column,
const bool unique_column) {
INLINE std::size_t DBTranslatorSet< ALLOC >::insertTranslator(
const Variable& var, const std::size_t column, const bool unique_column) {
const std::vector< std::string, ALLOC< std::string > > missing;
return this->insertTranslator(var, column, missing, unique_column);
}
......@@ -311,10 +311,10 @@ namespace gum {
void DBTranslatorSet< ALLOC >::eraseTranslator(const std::size_t k,
const bool k_is_input_col) {
ALLOC< DBTranslator< ALLOC > > allocator(this->getAllocator());
const std::size_t nb_trans = __translators.size();
if ( ! k_is_input_col ) {
if ( nb_trans < k ) return;
const std::size_t nb_trans = __translators.size();
if (!k_is_input_col) {
if (nb_trans < k) return;
// remove the translator and its corresponding column
allocator.destroy(__translators[k]);
......@@ -326,37 +326,34 @@ namespace gum {
// if the highest column index corresponded to the kth translator,
// we must recompute it
if ( __highest_column == colk ) {
if (__highest_column == colk) {
__highest_column = std::size_t(0);
for ( const auto col : __columns )
if ( __highest_column < col )
__highest_column = col;
}
}
else {
for (const auto col : __columns)
if (__highest_column < col) __highest_column = col;
}
} else {
// remove all the translators parsing the kth column
auto iter_trans = __translators.rbegin ();
auto iter_trans = __translators.rbegin();
bool translator_found = false;
for ( auto iter_col = __columns.rbegin();
iter_col != __columns.rend(); ++iter_col, ++iter_trans ) {
if ( *iter_col == k ) {
for (auto iter_col = __columns.rbegin(); iter_col != __columns.rend();
++iter_col, ++iter_trans) {
if (*iter_col == k) {
// remove the translator and its corresponding column
allocator.destroy( *iter_trans );
allocator.deallocate( *iter_trans, 1);