Commit 8f743998 authored by Anne Hommelberg's avatar Anne Hommelberg

Add io accessor for all DataStore access

To avoid method naming conflicts, all access methods for
the internal data store have been put into an self.io accessor.
For example, call self.io.get_times to get the times stored in the
internal data store.
parent f2aae576
......@@ -10,10 +10,11 @@ from rtctools._internal.alias_tools import AliasDict, AliasRelation
logger = logging.getLogger("rtctools")
class DataStore(metaclass=ABCMeta):
class DataStoreAccessor(metaclass=ABCMeta):
"""
Base class for all problems.
Adds an internal data store where which timeseries, parameters and initial states can be stored and read from.
Adds an internal data store where timeseries and parameters can be stored.
Access to the internal data store is always done through the io accessor.
:cvar timeseries_import_basename:
Import file basename. Default is ``timeseries_import``.
......@@ -35,6 +36,33 @@ class DataStore(metaclass=ABCMeta):
logger.debug("Expecting input files to be located in '" + self._input_folder + "'.")
logger.debug("Writing output files to '" + self._output_folder + "'.")
self.io = DataStore(self)
@property
@abstractmethod
def alias_relation(self) -> AliasRelation:
raise NotImplementedError
@property
def initial_time(self) -> float:
"""
The initial time in seconds.
"""
times = self.io.get_times()
if times is None:
raise RuntimeError("Attempting to access initial_time before setting times")
return times[self.io.get_forecast_index()]
class DataStore(metaclass=ABCMeta):
"""
DataStore class used by the DataStoreAccessor.
Contains all methods needed to access the internal data store.
"""
def __init__(self, accessor):
self.__accessor = accessor
# Should all be set by subclass via setters
self.__forecast_index = 0
self.__timeseries_times_sec = None
......@@ -86,7 +114,7 @@ class DataStore(metaclass=ABCMeta):
.format(len(values), len(self.__timeseries_times_sec)))
while ensemble_member >= len(self.__timeseries_values):
self.__timeseries_values.append(AliasDict(self.alias_relation))
self.__timeseries_values.append(AliasDict(self.__accessor.alias_relation))
if check_duplicates and variable in self.__timeseries_values[ensemble_member].keys():
logger.warning("Attempting to set time series values for ensemble member {} and variable {} twice. "
......@@ -151,7 +179,7 @@ class DataStore(metaclass=ABCMeta):
If False, existing values can be silently overwritten with new values.
"""
while ensemble_member >= len(self.__parameters):
self.__parameters.append(AliasDict(self.alias_relation))
self.__parameters.append(AliasDict(self.__accessor.alias_relation))
if check_duplicates and parameter_name in self.__parameters[ensemble_member].keys():
logger.warning("Attempting to set parameter value for ensemble member {} and name {} twice. "
......@@ -178,15 +206,6 @@ class DataStore(metaclass=ABCMeta):
return set()
return self.__parameters[ensemble_member].keys()
@property
def initial_time(self) -> float:
"""
The initial time in seconds.
"""
if self.__timeseries_times_sec is None:
raise RuntimeError("Attempting to access initial_time before setting times")
return self.__timeseries_times_sec[self.__forecast_index]
@staticmethod
def datetime_to_sec(d: Union[Iterable[datetime], datetime], t0: datetime) -> Union[Iterable[float], float]:
"""
......@@ -212,8 +231,3 @@ class DataStore(metaclass=ABCMeta):
return [t0 + timedelta(seconds=t) for t in s]
else:
return t0 + timedelta(seconds=s)
@property
@abstractmethod
def alias_relation(self) -> AliasRelation:
raise NotImplementedError
......@@ -88,14 +88,14 @@ class CSVMixin(IOMixin):
with_time=True,
)
self.__timeseries_times = _timeseries[_timeseries.dtype.names[0]]
self.set_times(
self.datetime_to_sec(
self.io.set_times(
self.io.datetime_to_sec(
self.__timeseries_times,
self.__timeseries_times[self.get_forecast_index()]
self.__timeseries_times[self.io.get_forecast_index()]
)
)
for key in _timeseries.dtype.names[1:]:
self.set_timeseries_values(
self.io.set_timeseries_values(
key,
np.asarray(_timeseries[key], dtype=np.float64),
ensemble_member_index
......@@ -107,7 +107,7 @@ class CSVMixin(IOMixin):
_parameters = csv.load(os.path.join(
self._input_folder, ensemble_member_name, 'parameters.csv'), delimiter=self.csv_delimiter)
for key in _parameters.dtype.names:
self.set_parameter(key, float(_parameters[key]), ensemble_member_index)
self.io.set_parameter(key, float(_parameters[key]), ensemble_member_index)
except IOError:
pass
logger.debug("CSVMixin: Read parameters.")
......@@ -131,14 +131,14 @@ class CSVMixin(IOMixin):
with_time=True,
)
self.__timeseries_times = _timeseries[_timeseries.dtype.names[0]]
self.set_times(
self.datetime_to_sec(
self.io.set_times(
self.io.datetime_to_sec(
self.__timeseries_times,
self.__timeseries_times[self.get_forecast_index()]
self.__timeseries_times[self.io.get_forecast_index()]
)
)
for key in _timeseries.dtype.names[1:]:
self.set_timeseries_values(key, np.asarray(_timeseries[key], dtype=np.float64))
self.io.set_timeseries_values(key, np.asarray(_timeseries[key], dtype=np.float64))
logger.debug("CSVMixin: Read timeseries.")
try:
......@@ -146,7 +146,7 @@ class CSVMixin(IOMixin):
self._input_folder, 'parameters.csv'), delimiter=self.csv_delimiter)
logger.debug("CSVMixin: Read parameters.")
for key in _parameters.dtype.names:
self.set_parameter(key, float(_parameters[key]))
self.io.set_parameter(key, float(_parameters[key]))
except IOError:
pass
......@@ -160,7 +160,7 @@ class CSVMixin(IOMixin):
_initial_state = {}
self.__initial_state.append(AliasDict(self.alias_relation, _initial_state))
timeseries_times_sec = self.get_times()
timeseries_times_sec = self.io.get_times()
# Timestamp check
if self.csv_validate_timeseries:
......@@ -206,7 +206,7 @@ class CSVMixin(IOMixin):
for parameter in self.dae_variables['parameters']:
parameter = parameter.name()
try:
parameters[parameter] = self.get_parameter(parameter, ensemble_member)
parameters[parameter] = self.io.get_parameter(parameter, ensemble_member)
except KeyError:
pass
else:
......@@ -244,7 +244,7 @@ class CSVMixin(IOMixin):
formats = ['O'] + (len(names) - 1) * ['f8']
dtype = {'names': names, 'formats': formats}
data = np.zeros(len(times), dtype=dtype)
data['time'] = [self.__timeseries_times[self.get_forecast_index()] + timedelta(seconds=s) for s in times]
data['time'] = [self.__timeseries_times[self.io.get_forecast_index()] + timedelta(seconds=s) for s in times]
for output_variable in self.output_variables:
output_variable = output_variable.name()
try:
......
......@@ -60,10 +60,10 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
:param variable:
"""
return self.get_times()[self.get_forecast_index():]
return self.io.get_times()[self.io.get_forecast_index():]
def get_timeseries(self, variable: str, ensemble_member: int = 0) -> Timeseries:
return Timeseries(self.get_times(), self.get_timeseries_values(variable, ensemble_member))
return Timeseries(self.io.get_times(), self.io.get_timeseries_values(variable, ensemble_member))
def set_timeseries(
self,
......@@ -75,7 +75,7 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
def stretch_values(values, t_pos):
# Construct a values range with preceding and possibly following nans
new_values = np.full_like(self.__timeseries_times_sec, np.nan)
new_values = np.full_like(self.io.__timeseries_times_sec, np.nan)
new_values[t_pos:] = values
return new_values
......@@ -88,7 +88,9 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
'different length (lengths of {} and {}, respectively).'
.format(variable, len(timeseries.times), len(timeseries.values)))
if not np.array_equal(self.__timeseries_times_sec, timeseries.times):
timeseries_times_sec = self.io.get_times()
if not np.array_equal(timeseries_times_sec, timeseries.times):
if check_consistency:
raise ValueError(
'IOMixin: Trying to set timeseries {} with different times '
......@@ -101,7 +103,7 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
# import times. For this we assume that both time ranges are ordered,
# and that the times of the added series is a subset of the import
# times.
t_pos = bisect.bisect_left(self.__timeseries_times_sec, timeseries.times[0])
t_pos = bisect.bisect_left(timeseries_times_sec, timeseries.times[0])
# Construct a new values range with length of self.__timeseries_times_sec
values = stretch_values(timeseries.values, t_pos)
......@@ -121,12 +123,12 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
# If times is not supplied with the timeseries, we add the
# forecast times range to a new Timeseries object. Hereby
# we assume that the supplied values stretch from T0 to end.
t_pos = self.get_forecast_index()
t_pos = self.io.get_forecast_index()
# Construct a new values range with length of self.__timeseries_times_sec
values = stretch_values(timeseries, t_pos)
self.set_timeseries_values(variable, values, ensemble_member)
self.io.set_timeseries_values(variable, values, ensemble_member)
def min_timeseries_id(self, variable: str) -> str:
"""
......@@ -149,7 +151,7 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
# Call parent class first for default values.
bounds = super().bounds()
forecast_index = self.get_forecast_index()
forecast_index = self.io.get_forecast_index()
# Load bounds from timeseries
for variable in self.dae_variables['free_variables']:
......@@ -159,7 +161,7 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
timeseries_id = self.min_timeseries_id(variable_name)
try:
m = self.get_timeseries_values(timeseries_id, 0)[forecast_index:]
m = self.io.get_timeseries_values(timeseries_id, 0)[forecast_index:]
except KeyError:
pass
else:
......@@ -168,7 +170,7 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
timeseries_id = self.max_timeseries_id(variable_name)
try:
M = self.get_timeseries_values(timeseries_id, 0)[forecast_index:]
M = self.io.get_timeseries_values(timeseries_id, 0)[forecast_index:]
except KeyError:
pass
else:
......@@ -178,10 +180,10 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
# Replace NaN with +/- inf, and create Timeseries objects
if m is not None:
m[np.isnan(m)] = np.finfo(m.dtype).min
m = Timeseries(self.get_times()[forecast_index:], m)
m = Timeseries(self.io.get_times()[forecast_index:], m)
if M is not None:
M[np.isnan(M)] = np.finfo(M.dtype).max
M = Timeseries(self.get_times()[forecast_index:], M)
M = Timeseries(self.io.get_times()[forecast_index:], M)
# Store
if m is not None or M is not None:
......@@ -193,7 +195,7 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
# Load history
history = AliasDict(self.alias_relation)
end_index = self.get_forecast_index() + 1
end_index = self.io.get_forecast_index() + 1
variable_list = self.dae_variables['states'] + self.dae_variables['algebraics'] + \
self.dae_variables['control_inputs'] + self.dae_variables['constant_inputs']
......@@ -201,8 +203,8 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
variable = variable.name()
try:
history[variable] = Timeseries(
self.get_times()[:end_index],
self.get_timeseries_values(variable, ensemble_member)[:end_index])
self.io.get_times()[:end_index],
self.io.get_timeseries_values(variable, ensemble_member)[:end_index])
except KeyError:
pass
else:
......@@ -220,8 +222,8 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
variable = variable.name()
try:
s = Timeseries(
self.get_times(),
self.get_timeseries_values(variable, ensemble_member)
self.io.get_times(),
self.io.get_timeseries_values(variable, ensemble_member)
)
except KeyError:
pass
......@@ -243,13 +245,13 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
variable = variable.name()
try:
timeseries = Timeseries(
self.get_times(),
self.get_timeseries_values(variable, ensemble_member)
self.io.get_times(),
self.io.get_timeseries_values(variable, ensemble_member)
)
except KeyError:
pass
else:
if np.any(np.isnan(timeseries.values[self.get_forecast_index():])):
if np.any(np.isnan(timeseries.values[self.io.get_forecast_index():])):
raise Exception("IOMixin: Constant input {} contains NaN".format(variable))
constant_inputs[variable] = timeseries
if logger.getEffectiveLevel() == logging.DEBUG:
......@@ -257,7 +259,7 @@ class IOMixin(OptimizationProblem, metaclass=ABCMeta):
return constant_inputs
def timeseries_at(self, variable, t, ensemble_member=0):
return self.interpolate(t, self.get_times(), self.get_timeseries_values(variable, ensemble_member))
return self.interpolate(t, self.io.get_times(), self.io.get_timeseries_values(variable, ensemble_member))
@property
def output_variables(self):
......
......@@ -7,14 +7,14 @@ import casadi as ca
import numpy as np
from rtctools._internal.alias_tools import AliasDict
from rtctools.data.storage import DataStore
from rtctools.data.storage import DataStoreAccessor
from .timeseries import Timeseries
logger = logging.getLogger("rtctools")
class OptimizationProblem(DataStore, metaclass=ABCMeta):
class OptimizationProblem(DataStoreAccessor, metaclass=ABCMeta):
"""
Base class for all optimization problems.
"""
......
import logging
import warnings
from datetime import timedelta
import numpy as np
......@@ -86,13 +87,13 @@ class PIMixin(IOMixin):
parameter = parameter_id
if self.pi_check_for_duplicate_parameters:
if parameter in self.get_parameter_names():
if parameter in self.io.get_parameter_names():
logger.warning(
'PIMixin: parameter {} defined in file {} was already '
'present in another or this parameterConfig file. Using value {}.'.format(
parameter, parameter_config.path, value))
self.set_parameter(parameter, value)
self.io.set_parameter(parameter, value)
try:
self.__timeseries_import = pi.Timeseries(
......@@ -107,11 +108,11 @@ class PIMixin(IOMixin):
binary=self.pi_binary_timeseries, pi_validate_times=False, make_new_file=True)
# Convert timeseries timestamps to seconds since t0 for internal use
timeseries_import_times = np.asarray(self.datetime_to_sec(
timeseries_import_times = np.asarray(self.io.datetime_to_sec(
self.__timeseries_import.times,
self.__timeseries_import.forecast_datetime
))
self.set_times(timeseries_import_times)
self.io.set_times(timeseries_import_times)
# Timestamp check
if self.pi_validate_timeseries:
......@@ -135,10 +136,10 @@ class PIMixin(IOMixin):
# Offer input timeseries to IOMixin
for ensemble_member in range(self.ensemble_size):
for variable, values in self.__timeseries_import.items(ensemble_member):
self.set_timeseries_values(variable, values, ensemble_member)
self.io.set_timeseries_values(variable, values, ensemble_member)
# Set the forecast index to the read index
self.set_forecast_index(self.__timeseries_import.forecast_index)
self.io.set_forecast_index(self.__timeseries_import.forecast_index)
@property
def equidistant(self):
......@@ -167,8 +168,8 @@ class PIMixin(IOMixin):
# Call parent class first for default values.
parameters = super().parameters(ensemble_member)
for parameter in self.get_parameter_names():
parameters[parameter] = self.get_parameter(parameter)
for parameter in self.io.get_parameter_names():
parameters[parameter] = self.io.get_parameter(parameter)
# Done
return parameters
......@@ -269,7 +270,7 @@ class PIMixin(IOMixin):
The time stamps are in seconds since t0, and may be negative.
"""
# todo replace usage of this property with calls to get_times()?
return self.get_times()
return self.io.get_times()
@property
def timeseries_export(self):
......@@ -277,3 +278,11 @@ class PIMixin(IOMixin):
:class:`pi.Timeseries` object for holding the output data.
"""
return self.__timeseries_export
def get_forecast_index(self):
"""
Deprecated, use io.get_forecast_index instead.
"""
warnings.warn('get_forecast_index() is deprecated and will be removed in the future, '
'use io.get_forecast_index() instead.', FutureWarning)
return self.io.get_forecast_index()
......@@ -44,7 +44,7 @@ class IOMixin(SimulationProblem, metaclass=ABCMeta):
def initialize(self, config_file=None):
# Set up experiment
timeseries_import_times = self.get_times()
timeseries_import_times = self.io.get_times()
self.__dt = timeseries_import_times[1] - timeseries_import_times[0]
self.setup_experiment(0, timeseries_import_times[-1], self.__dt)
......@@ -52,9 +52,9 @@ class IOMixin(SimulationProblem, metaclass=ABCMeta):
logger.debug("Model parameters are {}".format(parameter_variables))
for parameter in self.get_parameter_names():
for parameter in self.io.get_parameter_names():
if parameter in parameter_variables:
value = self.get_parameter(parameter)
value = self.io.get_parameter(parameter)
logger.debug("IOMixin: Setting parameter {} = {}".format(parameter, value))
self.set_var(parameter, value)
......@@ -62,13 +62,13 @@ class IOMixin(SimulationProblem, metaclass=ABCMeta):
self.__input_variables = set(self.get_input_variables().keys())
# Set input values
self.__set_input_variables(self.get_forecast_index())
self.__set_input_variables(self.io.get_forecast_index())
logger.debug("Model inputs are {}".format(self.__input_variables))
# Empty output
self.__output_variables = self.get_output_variables()
n_times = len(self.get_times())
n_times = len(self.io.get_times())
self.__output = AliasDict(self.alias_relation)
self.__output.update({variable: np.full(n_times, np.nan) for variable in self.__output_variables})
......@@ -77,12 +77,12 @@ class IOMixin(SimulationProblem, metaclass=ABCMeta):
# Extract consistent t0 values
for variable in self.__output_variables:
self.__output[variable][self.get_forecast_index()] = self.get_var(variable)
self.__output[variable][self.io.get_forecast_index()] = self.get_var(variable)
def __set_input_variables(self, t_idx):
for variable in self.get_variables():
if variable in self.__input_variables:
value = self.get_timeseries_values(variable)[t_idx]
value = self.io.get_timeseries_values(variable)[t_idx]
if np.isfinite(value):
self.set_var(variable, value)
else:
......@@ -98,7 +98,7 @@ class IOMixin(SimulationProblem, metaclass=ABCMeta):
t = self.get_current_time()
# Get current time index
t_idx = bisect.bisect_left(self.get_times(), t + dt)
t_idx = bisect.bisect_left(self.io.get_times(), t + dt)
# Set input values
self.__set_input_variables(t_idx)
......@@ -129,11 +129,11 @@ class IOMixin(SimulationProblem, metaclass=ABCMeta):
parameters = super().parameters()
# Load parameters from input files (stored in internal data store)
for parameter_name in self.get_parameter_names():
parameters[parameter_name] = self.get_parameter(parameter_name)
for parameter_name in self.io.get_parameter_names():
parameters[parameter_name] = self.io.get_parameter(parameter_name)
if logger.getEffectiveLevel() == logging.DEBUG:
for parameter_name in self.get_parameter_names():
for parameter_name in self.io.get_parameter_names():
logger.debug("IOMixin: Read parameter {}".format(parameter_name))
return parameters
......@@ -146,7 +146,7 @@ class IOMixin(SimulationProblem, metaclass=ABCMeta):
:returns: List of all the timesteps in seconds.
"""
return self.get_times()[self.get_forecast_index():]
return self.io.get_times()[self.io.get_forecast_index():]
def timeseries_at(self, variable, t):
"""
......@@ -159,8 +159,8 @@ class IOMixin(SimulationProblem, metaclass=ABCMeta):
:raises: KeyError
"""
values = self.get_timeseries_values(variable)
timeseries_times_sec = self.get_times()
values = self.io.get_timeseries_values(variable)
timeseries_times_sec = self.io.get_times()
t_idx = bisect.bisect_left(timeseries_times_sec, t)
if timeseries_times_sec[t_idx] == t:
return values[t_idx]
......
......@@ -69,7 +69,7 @@ class PIMixin(IOMixin):
parameter = self.__data_config.parameter(parameter_id, location_id, model_id)
except KeyError:
parameter = parameter_id
self.set_parameter(parameter, value)
self.io.set_parameter(parameter, value)
try:
self.__timeseries_import = pi.Timeseries(
......@@ -84,12 +84,12 @@ class PIMixin(IOMixin):
binary=self.pi_binary_timeseries, pi_validate_times=False, make_new_file=True)
# Convert timeseries timestamps to seconds since t0 for internal use
self.set_forecast_index(self.__timeseries_import.forecast_index)
timeseries_import_times = np.asarray(self.datetime_to_sec(
self.io.set_forecast_index(self.__timeseries_import.forecast_index)
timeseries_import_times = np.asarray(self.io.datetime_to_sec(
self.__timeseries_import.times,
self.__timeseries_import.forecast_datetime
))
self.set_times(timeseries_import_times)
self.io.set_times(timeseries_import_times)
# Timestamp check
if self.pi_validate_timeseries:
......@@ -112,7 +112,7 @@ class PIMixin(IOMixin):
# Stick timeseries into an AliasDict
debug = logger.getEffectiveLevel() == logging.DEBUG
for variable, values in self.__timeseries_import.items():
self.set_timeseries_values(variable, values, check_duplicates=False)
self.io.set_timeseries_values(variable, values, check_duplicates=False)
if debug and variable in self.get_variables():
logger.debug('PIMixin: Timeseries {} replaced another aliased timeseries.'.format(variable))
......@@ -122,7 +122,7 @@ class PIMixin(IOMixin):
# Start of write output
# Write the time range for the export file.
self.__timeseries_export.times = self.__timeseries_import.times[self.get_forecast_index():]
self.__timeseries_export.times = self.__timeseries_import.times[self.io.get_forecast_index():]
# Write other time settings
self.__timeseries_export.forecast_datetime = self.__timeseries_import.forecast_datetime
......@@ -166,7 +166,7 @@ class PIMixin(IOMixin):
The time stamps are in seconds since t0, and may be negative.
"""
return self.get_times()
return self.io.get_times()
@property
def timeseries_export(self):
......@@ -198,10 +198,10 @@ class PIMixin(IOMixin):
self.__timeseries_export.set(variable, values, unit=unit)
self.__timeseries_import.set(variable, values, unit=unit)
self.set_timeseries_values(variable, values)
self.io.set_timeseries_values(variable, values)
def get_timeseries(self, variable):
return self.get_timeseries_values(variable)
return self.io.get_timeseries_values(variable)
def extract_results(self):
"""
......
......@@ -15,12 +15,12 @@ import pymoca.backends.casadi.api
from rtctools._internal.alias_tools import AliasDict, AliasRelation
from rtctools._internal.caching import cached
from rtctools.data.storage import DataStore
from rtctools.data.storage import DataStoreAccessor
logger = logging.getLogger("rtctools")
class SimulationProblem(DataStore):
class SimulationProblem(DataStoreAccessor):
"""
Implements the `BMI <http://csdms.colorado.edu/wiki/BMI_Description>`_ Interface.
......
......@@ -5,19 +5,13 @@ import numpy as np
from pymoca.backends.casadi.alias_relation import AliasRelation
from rtctools.data.storage import DataStore
from rtctools.data.storage import DataStoreAccessor
logger = logging.getLogger("rtctools")
logger.setLevel(logging.WARNING)
class DummyDataStore(DataStore):
def read(self):
pass
def write(self):
pass
class DummyDataStore(DataStoreAccessor):
@property
def alias_relation(self):
return AliasRelation()
......@@ -31,88 +25,88 @@ class TestDummyDataStore(TestCase):
def test_times(self):
expected_times = np.array([-7200, -3600, 0, 3600, 7200, 9800])
self.datastore.set_times(expected_times)
actual_times = self.datastore.get_times()
self.datastore.io.set_times(expected_times)
actual_times = self.datastore.io.get_times()
self.assertTrue(np.array_equal(actual_times, expected_times))
def test_forecast_index(self):
forecast_index = self.datastore.get_forecast_index()
forecast_index = self.datastore.io.get_forecast_index()
self.assertEqual(forecast_index, 0) # default forecast_index should be 0
times = np.array([-7200, -3600, 0, 3600, 7200, 9800])
self.datastore.set_times(times)
self.datastore.io.set_times(times)
initial_time = self.datastore.initial_time
self.assertEqual(initial_time, -7200)
self.datastore.set_forecast_index(3)
self.assertEqual(self.datastore.get_forecast_index(), 3)
self.datastore.io.set_forecast_index(3)
self.assertEqual(self.datastore.io.get_forecast_index(), 3)
self.assertEqual(self.datastore.initial_time, 3600)
def test_timeseries(self):
# expect a KeyError when getting a timeseries that has not been set
with self.assertRaises(KeyError):
self.datastore.get_timeseries_values('someNoneExistentVariable')
self.datastore.io.get_timeseries_values('someNoneExistentVariable')
# expect a RunTimeError when setting timeseries values before setting times
with self.assertRaises(RuntimeError):
self.datastore.set_timeseries_values('myNewVariable', np.array([3.1, 2.4, 2.5]))
self.datastore.io.set_timeseries_values('myNewVariable', np.array([3.1, 2.4, 2.5]))
self.datastore.set_times(np.array([-3600, 0, 7200]))
self.datastore.io.set_times(np.array([-3600, 0, 7200]))
expected_values = np.array([3.1, 2.4, 2.5])
self.datastore.set_timeseries_values('myNewVariable', expected_values)
actual_values = self.datastore.get_timeseries_values('myNewVariable')
self.datastore.io.set_timeseries_values('myNewVariable', expected_values)
actual_values = self.datastore.io.get_timeseries_values('myNewVariable')
self.assertTrue(np.array_equal(actual_values, expected_values))
# expect a KeyError when getting timeseries for an ensemble member that doesn't exist
with self.assertRaises(KeyError):
self.datastore.get_timeseries_values('myNewVariable', 1)
self.datastore.io.get_timeseries_values('myNewVariable', 1)
expected_values = np.array([1.1, 1.4, 1.5])
self.datastore.set_timeseries_values('ensembleVariable', expected_values, ensemble_member=1)
self.datastore.io.set_timeseries_values('ensembleVariable', expected_values, ensemble_member=1)
with self.assertRaises(KeyError):
self.datastore.get_timeseries_values('ensembleVariable', 0)
self.assertTrue(np.array_equal(self.datastore.get_timeseries_values('ensembleVariable', 1), expected_values))
self.datastore.io.get_timeseries_values('ensembleVariable', 0)
self.assertTrue(np.array_equal(self.datastore.io.get_timeseries_values('ensembleVariable', 1), expected_values))
# expect a warning when overwriting a timeseries with check_duplicates=True (default)
new_values = np.array([2.1, 1.1, 0.1])
with self.assertLogs(logger, level='WARN') as cm:
self.datastore.set_timeseries_values('myNewVariable', new_values)
self.datastore.io.set_timeseries_values('myNewVariable', new_values)
self.assertEqual(cm.output,
['WARNING:rtctools:Attempting to set time series values for ensemble member 0 '
'and variable myNewVariable twice. Ignoring second set of values.'])
self.assertFalse(np.array_equal(self.datastore.get_timeseries_values('myNewVariable'), new_values))
self.assertFalse(np.array_equal(self.datastore.io.get_timeseries_values('myNewVariable'), new_values))
# disable check to allow overwriting old values
self.datastore.set_timeseries_values('myNewVariable', new_values, check_duplicates=False)
self.assertTrue(np.array_equal(self.datastore.get_timeseries_values('myNewVariable'), new_values))
self.datastore.io.set_timeseries_values('myNewVariable', new_values, check_duplicates=False)
self.assertTrue(np.array_equal(self.datastore.io.get_timeseries_values('myNewVariable'), new_values))
def test_parameters(self):
# expect a KeyError when getting a parameter that has not been set
with self.assertRaises(KeyError):
self.datastore.get_parameter('someNoneExistentParameter')
self.datastore.io.get_parameter('someNoneExistentParameter')
self.datastore.set_parameter('myNewParameter', 1.4)
self.assertEqual(self.datastore.get_parameter('myNewParameter'), 1.4)
self.datastore.io.set_parameter('myNewParameter', 1.4)
self.assertEqual(self.datastore.io.get_parameter('myNewParameter'), 1.4)
# expect a KeyError when getting parameters for an ensemble member that doesn't exist
with self.assertRaises(KeyError):
self.datastore.get_parameter('myNewParameter', 1)
self.datastore.io.get_parameter('myNewParameter', 1)
self.datastore.set_parameter('ensembleParameter', 1.2, ensemble_member=1)
self.datastore.io.set_parameter('ensembleParameter', 1.2, ensemble_member=1)
with self.assertRaises(KeyError):
self.datastore.get_parameter('ensembleParameter', 0)
self.assertEqual(self.datastore.get_parameter('ensembleParameter', 1), 1.2)