...
  View open merge request
Commits (6)
  • Anne Hommelberg's avatar
    Add general DataStore and IOMixin classes · 00c7de3f
    Anne Hommelberg authored
    Large refactoring of all IO mixins (PIMixin and CSVMixin).
    Adds a general DataStore class extended by OptimizationProblem
    and SimulationProblem, which is used to store all data read by
    the IO mixins.
    Adds an optimization IOMixin class which contains methods that
    were previously duplicated in the optimization PIMixin and
    CSVMixin.
    Adds a simulation IOMixin class which does the same for the
    simulation PIMixin and CSVMixin.
    00c7de3f
  • Anne Hommelberg's avatar
    Fix unit tests · f2aae576
    Anne Hommelberg authored
    The DataStore should allow the input and output folder to not be
    specified, in case the user chooses to not use any of the provided
    IOMixins.
    f2aae576
  • Anne Hommelberg's avatar
    Add io accessor for all DataStore access · 8f743998
    Anne Hommelberg authored
    To avoid method naming conflicts, all access methods for
    the internal data store have been put into a self.io accessor.
    For example, call self.io.get_times to get the times stored in the
    internal data store.
    8f743998
  • Anne Hommelberg's avatar
    Fix duplicate checks in DataStore · bc0b8320
    Anne Hommelberg authored
    To keep backwards compatibility with the old duplicate parameters
    check in PIMixin, the DataStore should overwrite the old values with
    new values when duplicates occur. If check_duplicates is True, a
    warning will be given each time this happens.
    bc0b8320
  • Anne Hommelberg's avatar
    Add unit tests for new io classes · eb5e4b7f
    Anne Hommelberg authored
    Adds unit tests for the two IOMixin and DataStore classes.
    Also fixes some minor bugs in the set_timeseries method.
    Also adds the get_parameter_ensemble_size method to make parameter
    access in the DataStore consistent with access to the stored time
    series.
    eb5e4b7f
  • Anne Hommelberg's avatar
    Add NetCDFMixin · b8da58b3
    Anne Hommelberg authored
    Adds a NetCDFMixin to import and export data to
    and from NetCDF files.
    b8da58b3
import os
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import Iterable, List, Set, Union
from netCDF4 import Dataset, Variable, chartostring, num2date
import numpy as np
class Stations:
    """
    Station (location) data read from a NetCDF dataset.

    For every station id found in the station variable, the values of all
    variables defined solely on the station dimension are collected as that
    station's attributes.
    """

    def __init__(self, dataset: Dataset, station_variable: Variable):
        """
        :param dataset: The Dataset the station data is read from.
        :param station_variable: The Variable containing the station ids.
        """
        self.__station_variable = station_variable

        station_dimension = station_variable.dimensions[0]

        # Collect the variables describing per-station attributes, i.e. all
        # variables whose only dimension is the station dimension.
        # todo make this a bit smarter, right now variables like station_name would be forgotten
        self.__attribute_variables = {}
        for variable_name in dataset.variables:
            variable = dataset.variables[variable_name]
            if variable != station_variable and variable.dimensions == (station_dimension,):
                self.__attribute_variables[variable_name] = variable

        # Map each station id to a dict of its attribute values.
        # The loop variable is named station_id to avoid shadowing the builtin "id".
        self.__attributes = OrderedDict()
        for i in range(station_variable.shape[0]):
            station_id = str(chartostring(station_variable[i]))
            self.__attributes[station_id] = {
                variable_name: dataset.variables[variable_name][i]
                for variable_name in self.__attribute_variables
            }

    @property
    def station_ids(self) -> Iterable:
        """
        :return: An ordered iterable of the station ids (location ids) for which station data is available.
        """
        return self.__attributes.keys()

    @property
    def attributes(self) -> OrderedDict:
        """
        :return: An OrderedDict containing dicts containing the values for all station attributes of the input dataset.
        """
        return self.__attributes

    @property
    def attribute_variables(self) -> dict:
        """
        :return: A dict containing the station attribute variables of the input dataset.
        """
        return self.__attribute_variables
class ImportDataset:
    """
    A class used to open and import the data from a NetCDF file.
    Uses the NetCDF4 library. Contains various methods for reading the data in the file.
    """

    def __init__(self, folder: str, basename: str):
        """
        :param folder: Folder the file is located in.
        :param basename: Basename of the file, extension ".nc" will be appended to this
        """
        # Open the NetCDF file as a Dataset.
        self.__filename = os.path.join(folder, basename + ".nc")
        self.__dataset = Dataset(self.__filename)

        # Locate the mandatory time and station id variables, failing fast
        # with a descriptive message when either one is missing.
        self.__time_variable = self.__find_time_variable()
        if self.__time_variable is None:
            raise Exception('No time variable found in file ' + self.__filename + '. '
                            'Please ensure the file contains a time variable with standard_name "time" and axis "T".')

        self.__station_variable = self.__find_station_variable()
        if self.__station_variable is None:
            raise Exception('No station variable found in file ' + self.__filename + '. '
                            'Please ensure the file contains a variable with cf_role "timeseries_id".')

    def __str__(self):
        return self.__filename

    def __find_time_variable(self) -> Union[Variable, None]:
        """
        Find the variable containing the times in the opened dataset.

        :return: a netCDF4.Variable object of the time variable (or None if none found)
        """
        for candidate in self.__dataset.variables.values():
            attribute_names = candidate.ncattrs()
            if 'standard_name' not in attribute_names or 'axis' not in attribute_names:
                continue
            if candidate.standard_name == 'time' and candidate.axis == 'T':
                return candidate
        return None

    def __find_station_variable(self) -> Union[Variable, None]:
        """
        Find the variable containing station id's (location id's) in the opened dataset.

        :return: a netCDF4.Variable object of the station id variable (or None if none found)
        """
        for candidate in self.__dataset.variables.values():
            if 'cf_role' in candidate.ncattrs() and candidate.cf_role == 'timeseries_id':
                return candidate
        return None

    def read_import_times(self) -> np.ndarray:
        """
        Reads the import times in the time variable of the dataset.

        :return: an array containing the input times as datetime objects
        """
        time_values = self.__time_variable[:]
        time_unit = self.__time_variable.units
        # Default to the gregorian calendar when the variable specifies none.
        time_calendar = getattr(self.__time_variable, 'calendar', u'gregorian')
        return num2date(time_values, units=time_unit, calendar=time_calendar)

    def read_station_data(self) -> Stations:
        """
        :return: a Stations object holding the station data of the dataset.
        """
        return Stations(self.__dataset, self.__station_variable)

    def find_timeseries_variables(self) -> List[str]:
        """
        Find the keys of all 2d variables with dimensions (station, time) or (time, station),
        where station is the dimension of the station variable and time the dimension of the
        time variable.

        :return: a list of strings containing all keys found.
        """
        station_dim = self.__station_variable.dimensions[0]
        time_dim = self.__time_variable.dimensions[0]
        expected_dims = [(station_dim, time_dim), (time_dim, station_dim)]

        return [var_key
                for var_key, variable in self.__dataset.variables.items()
                if variable.dimensions in expected_dims]

    def read_timeseries_values(self, station_index: int, variable_name: str) -> np.ndarray:
        """
        Reads the specified timeseries from the input file.

        :param station_index: The index of the station for which the values should be read
        :param variable_name: The name of the variable for which the values should be read
        :return: an array of values
        """
        station_dim = self.__station_variable.dimensions[0]

        timeseries_variable = self.__dataset.variables[variable_name]
        assert timeseries_variable.dimensions[0] == station_dim or timeseries_variable.dimensions[1] == station_dim

        # Pick the slice orientation matching the variable's dimension order.
        if timeseries_variable.dimensions[0] == station_dim:
            values = timeseries_variable[station_index, :]
        else:
            values = timeseries_variable[:, station_index]

        # NetCDF4 reads the values as a numpy masked array,
        # convert to a normal array with nan where mask == True
        return np.ma.filled(values, np.nan)

    @property
    def time_variable(self):
        # The time variable found in the dataset.
        return self.__time_variable

    @property
    def station_variable(self):
        # The station id variable found in the dataset.
        return self.__station_variable
class ExportDataset:
    """
    A class used to write data to a NetCDF file.

    Creates a new file or overwrites an old file. The file metadata will be written upon initialization. Data such
    as times, station data and timeseries data should be presented to the ExportDataset through the various methods.
    When all data has been written, the close method must be called to flush the changes from local memory to the
    actual file on disk.
    """

    def __init__(self, folder: str, basename: str):
        """
        :param folder: Folder the file will be located in.
        :param basename: Basename of the file, extension ".nc" will be appended to this
        """
        # Create the file and open a Dataset to access it.
        self.__filename = os.path.join(folder, basename + ".nc")

        # use same write format as FEWS
        self.__dataset = Dataset(self.__filename, mode='w', format='NETCDF3_CLASSIC')

        # write metadata to the file
        self.__dataset.title = 'RTC-Tools Output Data'
        self.__dataset.institution = 'Deltares'
        self.__dataset.source = 'RTC-Tools'
        self.__dataset.history = 'Generated on {}'.format(datetime.now())
        self.__dataset.Conventions = 'CF-1.6'
        self.__dataset.featureType = 'timeseries'

        # dimensions are created when writing times and station data, must be created before writing variables
        self.__time_dim = None
        self.__station_dim = None
        self.__station_id_to_index_mapping = None
        self.__timeseries_variables = {}

    def __str__(self):
        return self.__filename

    def write_times(self, times: np.ndarray, forecast_time: float, forecast_date: datetime) -> None:
        """
        Writes a time variable to the dataset.

        :param times: The times that are to be written in seconds.
        :param forecast_time: The forecast time in seconds corresponding to the forecast date
        :param forecast_date: The datetime corresponding with time in seconds at the forecast index.
        """
        # In a NetCDF file times are written with respect to a reference date.
        # The written values for the times may never be negative, so shift the
        # reference date back to the earliest time when needed.
        minimum_time = np.min(times)
        reference_date = forecast_date
        if minimum_time < 0:
            times = times - minimum_time
            reference_date = reference_date - timedelta(seconds=forecast_time - minimum_time)

        self.__time_dim = self.__dataset.createDimension('time', None)

        time_var = self.__dataset.createVariable('time', 'f8', ('time',))
        time_var.standard_name = 'time'
        time_var.units = 'seconds since {}'.format(reference_date)
        time_var.axis = 'T'
        time_var[:] = times

    def write_station_data(self, stations: Stations, output_station_ids: Set[str]) -> None:
        """
        Writes the station ids and additional station information to the dataset.

        :param stations: The stations data read from the input file.
        :param output_station_ids: The set of station ids for which output will be written.
        """
        self.__station_dim = self.__dataset.createDimension('station', len(output_station_ids))

        # First write the station ids themselves.
        max_id_length = max(len(station_id) for station_id in output_station_ids)
        self.__dataset.createDimension('char_leng_id', max_id_length)

        station_id_var = self.__dataset.createVariable('station_id', 'c', ('station', 'char_leng_id'))
        station_id_var.long_name = 'station identification code'
        station_id_var.cf_role = 'timeseries_id'

        # We must store the index used for each station id, to be able to
        # write the data at the correct index later.
        self.__station_id_to_index_mapping = {}
        for i, station_id in enumerate(output_station_ids):
            station_id_var[i, :] = list(station_id)
            self.__station_id_to_index_mapping[station_id] = i

        # Now write the stored attributes.
        for var_name, attr_var in stations.attribute_variables.items():
            variable = self.__dataset.createVariable(var_name, attr_var.datatype, ('station',))
            # copy all attributes from the original input variable
            variable.setncatts(attr_var.__dict__)
            for station_id in output_station_ids:
                if station_id in stations.attributes:
                    station_index = self.__station_id_to_index_mapping[station_id]
                    variable[station_index] = stations.attributes[station_id][var_name]

    def create_variables(self, variable_names: Set[str]) -> None:
        """
        Creates variables in the dataset for each of the provided parameter ids.

        The write_times and write_station_data methods must be called first, to ensure the necessary dimensions have
        already been created in the output NetCDF file.

        :param variable_names: The parameter ids for which variables must be created.
        """
        assert self.__time_dim is not None, 'First call write_times to ensure the time dimension has been created.'
        assert self.__station_dim is not None, 'First call write_station_data to ensure ' \
                                               'the station dimension has been created'
        assert self.__station_id_to_index_mapping is not None  # should also be created in write_station_data

        for variable_name in variable_names:
            self.__dataset.createVariable(variable_name, 'f8', ('time', 'station'), fill_value=np.nan)

    def write_output_values(self, station_id: str, variable_name: str, values: np.ndarray) -> None:
        """
        Writes the given data to the dataset. The variable must have already been created through the
        create_variables method. After all calls to write_output_values, the close method must be called to flush all
        changes.

        :param station_id: The id of the station the data is written for.
        :param variable_name: The name of the variable the data is written to (must have already been created).
        :param values: The values that are to be written to the file
        """
        assert self.__station_id_to_index_mapping is not None, 'First call write_station_data and create_variables.'

        station_index = self.__station_id_to_index_mapping[station_id]
        self.__dataset.variables[variable_name][:, station_index] = values

    def close(self) -> None:
        """
        Closes the NetCDF4 Dataset to ensure all changes made are written to the file.
        This method must be called after writing all data through the various write methods.
        """
        self.__dataset.close()
import logging
from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta
from typing import Iterable, Set, Union
import numpy as np
from rtctools._internal.alias_tools import AliasDict, AliasRelation
logger = logging.getLogger("rtctools")
class DataStoreAccessor(metaclass=ABCMeta):
    """
    Base class for all problems.
    Adds an internal data store where timeseries and parameters can be stored.
    Access to the internal data store is always done through the io accessor.

    :cvar timeseries_import_basename:
        Import file basename. Default is ``timeseries_import``.
    :cvar timeseries_export_basename:
        Export file basename. Default is ``timeseries_export``.
    """

    #: Import file basename
    timeseries_import_basename = 'timeseries_import'
    #: Export file basename
    timeseries_export_basename = 'timeseries_export'

    def __init__(self, **kwargs):
        """
        :param input_folder: Folder input files are read from. Default is ``input``.
        :param output_folder: Folder output files are written to. Default is ``output``.
        """
        # Save arguments. dict.get with a default replaces the verbose
        # "kwargs[k] if k in kwargs else default" conditional expressions.
        self._input_folder = kwargs.get('input_folder', 'input')
        self._output_folder = kwargs.get('output_folder', 'output')

        if logger.getEffectiveLevel() == logging.DEBUG:
            logger.debug("Expecting input files to be located in '" + self._input_folder + "'.")
            logger.debug("Writing output files to '" + self._output_folder + "'.")

        self.io = DataStore(self)

    @property
    @abstractmethod
    def alias_relation(self) -> AliasRelation:
        # Subclasses must provide the alias relation used by the data store.
        raise NotImplementedError

    @property
    def initial_time(self) -> float:
        """
        The initial time in seconds.

        :raises RuntimeError: If no times have been set in the internal data store yet.
        """
        times = self.io.get_times()
        if times is None:
            raise RuntimeError("Attempting to access initial_time before setting times")
        return times[self.io.get_forecast_index()]
class DataStore(metaclass=ABCMeta):
    """
    DataStore class used by the DataStoreAccessor.
    Contains all methods needed to access the internal data store.
    """

    def __init__(self, accessor):
        self.__accessor = accessor

        # All of the following are filled in through the setter methods below.
        self.__forecast_index = 0
        self.__timeseries_times_sec = None
        self.__timeseries_values = []
        self.__parameters = []
        # todo add support for storing initial states
        # self.__initial_state = []

    def get_times(self) -> np.ndarray:
        """
        Returns the timeseries times in seconds.

        :return: timeseries times in seconds, or None if there has been no call to set_times
        """
        return self.__timeseries_times_sec

    def set_times(self, times_in_sec: np.ndarray) -> None:
        """
        Sets the timeseries times in seconds in the internal data store.

        Must be called in .read() to store the times in the IOMixin before calling
        set_timeseries_values to store the values for an input timeseries.

        :param times_in_sec: np.ndarray containing the times in seconds
        """
        stored_times = self.__timeseries_times_sec
        if stored_times is not None and not np.array_equal(times_in_sec, stored_times):
            raise RuntimeError("Attempting to overwrite the input time series times with different values. "
                               "Please ensure all input time series have the same times.")
        self.__timeseries_times_sec = times_in_sec

    def set_timeseries_values(self,
                              variable: str,
                              values: np.ndarray,
                              ensemble_member: int = 0,
                              check_duplicates: bool = True) -> None:
        """
        Stores input time series values in the internal data store.

        :param variable: Variable name.
        :param values: The values to be stored.
        :param ensemble_member: The ensemble member index.
        :param check_duplicates: If True, a warning will be given when overwriting values.
                                 If False, existing values can be silently overwritten with new values.
        """
        if self.__timeseries_times_sec is None:
            raise RuntimeError("First call set_times before calling set_timeseries_values")
        if len(self.__timeseries_times_sec) != len(values):
            raise ValueError("Length of values ({}) must be the same as length of times ({})"
                             .format(len(values), len(self.__timeseries_times_sec)))

        # Grow the per-ensemble-member storage on demand.
        for _ in range(ensemble_member + 1 - len(self.__timeseries_values)):
            self.__timeseries_values.append(AliasDict(self.__accessor.alias_relation))

        member_values = self.__timeseries_values[ensemble_member]
        if check_duplicates and variable in member_values.keys():
            logger.warning("Time series values for ensemble member {} and variable {} set twice. "
                           "Overwriting old values.".format(ensemble_member, variable))
        member_values[variable] = values

    def get_timeseries_values(self, variable: str, ensemble_member: int = 0) -> np.ndarray:
        """
        Looks up the time series values in the internal data store.

        :param variable: Variable name.
        :param ensemble_member: The ensemble member index.
        :raises KeyError: If the ensemble member or variable does not exist.
        """
        if ensemble_member >= len(self.__timeseries_values):
            raise KeyError("ensemble_member {} does not exist".format(ensemble_member))
        return self.__timeseries_values[ensemble_member][variable]

    def get_variables(self, ensemble_member: int = 0) -> Set:
        """
        Returns a set of variables for which timeseries values are stored in the internal data store.

        :param ensemble_member: The ensemble member index.
        """
        if ensemble_member >= len(self.__timeseries_values):
            return set()
        return self.__timeseries_values[ensemble_member].keys()

    def get_ensemble_size(self) -> int:
        """
        Returns the number of ensemble members for which timeseries are stored in the internal data store.
        """
        return len(self.__timeseries_values)

    def get_forecast_index(self) -> int:
        """
        Looks up the forecast index from the internal data store.

        :return: Current forecast index, values before this index will be considered "history".
        """
        return self.__forecast_index

    def set_forecast_index(self, forecast_index: int) -> None:
        """
        Sets the forecast index in the internal data store.
        Values (and times) before this index will be considered "history".

        :param forecast_index: New forecast index.
        """
        self.__forecast_index = forecast_index

    def set_parameter(self,
                      parameter_name: str,
                      value: float,
                      ensemble_member: int = 0,
                      check_duplicates: bool = True) -> None:
        """
        Stores the parameter value in the internal data store.

        :param parameter_name: Parameter name.
        :param value: The values to be stored.
        :param ensemble_member: The ensemble member index.
        :param check_duplicates: If True, a warning will be given when overwriting values.
                                 If False, existing values can be silently overwritten with new values.
        """
        # Grow the per-ensemble-member storage on demand.
        for _ in range(ensemble_member + 1 - len(self.__parameters)):
            self.__parameters.append(AliasDict(self.__accessor.alias_relation))

        member_parameters = self.__parameters[ensemble_member]
        if check_duplicates and parameter_name in member_parameters.keys():
            logger.warning("Attempting to set parameter value for ensemble member {} and name {} twice. "
                           "Using new value of {}.".format(ensemble_member, parameter_name, value))
        member_parameters[parameter_name] = value

    def get_parameter(self, parameter_name: str, ensemble_member: int = 0) -> float:
        """
        Looks up the parameter value in the internal data store.

        :param parameter_name: Parameter name.
        :param ensemble_member: The ensemble member index.
        :raises KeyError: If the ensemble member or parameter does not exist.
        """
        if ensemble_member >= len(self.__parameters):
            raise KeyError("ensemble_member {} does not exist".format(ensemble_member))
        return self.__parameters[ensemble_member][parameter_name]

    def get_parameter_names(self, ensemble_member: int = 0) -> Set:
        """
        Returns a set of parameter names for which values are stored in the internal data store.

        :param ensemble_member: The ensemble member index.
        """
        if ensemble_member >= len(self.__parameters):
            return set()
        return self.__parameters[ensemble_member].keys()

    def get_parameter_ensemble_size(self) -> int:
        """
        Returns the number of ensemble members for which parameters are stored in the internal data store.
        """
        return len(self.__parameters)

    @staticmethod
    def datetime_to_sec(d: Union[Iterable[datetime], datetime], t0: datetime) -> Union[Iterable[float], float]:
        """
        Returns the date/timestamps in seconds since t0.

        :param d: Iterable of datetimes or a single datetime object.
        :param t0: Reference datetime.
        """
        if not hasattr(d, '__iter__'):
            return (d - t0).total_seconds()
        return np.array([(t - t0).total_seconds() for t in d])

    @staticmethod
    def sec_to_datetime(s: Union[Iterable[float], float], t0: datetime) -> Union[Iterable[datetime], datetime]:
        """
        Return the date/timestamps in seconds since t0 as datetime objects.

        :param s: Iterable of ints or a single int (number of seconds before or after t0).
        :param t0: Reference datetime.
        """
        if not hasattr(s, '__iter__'):
            return t0 + timedelta(seconds=s)
        return [t0 + timedelta(seconds=t) for t in s]
import itertools
import logging
from abc import ABCMeta, abstractmethod
from abc import ABCMeta
import casadi as ca
......@@ -78,17 +78,6 @@ class CollocatedIntegratedOptimizationProblem(OptimizationProblem, metaclass=ABC
# Call super
super().__init__(**kwargs)
@abstractmethod
def times(self, variable=None):
"""
List of time stamps for variable.
:param variable: Variable name.
:returns: A list of time stamps for the given variable.
"""
pass
def interpolation_method(self, variable=None):
"""
Interpolation method for variable.
......
......@@ -2,21 +2,17 @@ import logging
import os
from datetime import timedelta
import casadi as ca
import numpy as np
import rtctools.data.csv as csv
from rtctools._internal.alias_tools import AliasDict
from rtctools._internal.caching import cached
from .optimization_problem import OptimizationProblem
from .timeseries import Timeseries
from rtctools.optimization.io_mixin import IOMixin
logger = logging.getLogger("rtctools")
class CSVMixin(OptimizationProblem):
class CSVMixin(IOMixin):
"""
Adds reading and writing of CSV timeseries and parameters to your optimization problem.
......@@ -38,10 +34,6 @@ class CSVMixin(OptimizationProblem):
Whether or not to use ensembles. Default is ``False``.
:cvar csv_validate_timeseries:
Check consistency of timeseries. Default is ``True``.
:cvar timeseries_import_basename:
Import file basename. Default is ``timeseries_import``.
:cvar timeseries_export_basename:
Export file basename. Default is ``timeseries_export``.
"""
#: Column delimiter used in CSV files
......@@ -56,29 +48,13 @@ class CSVMixin(OptimizationProblem):
#: Check consistency of timeseries
csv_validate_timeseries = True
#: Import file basename
timeseries_import_basename = "timeseries_import"
#: Export file basename
timeseries_export_basename = "timeseries_export"
def __init__(self, **kwargs):
# Check arguments
assert('input_folder' in kwargs)
assert('output_folder' in kwargs)
# Save arguments
self.__input_folder = kwargs['input_folder']
self.__output_folder = kwargs['output_folder']
# Additional output variables
self.__output_timeseries = set()
# Call parent class first for default behaviour.
super().__init__(**kwargs)
def pre(self):
def read(self):
# Call parent class first for default behaviour.
super().pre()
super().read()
# Helper function to check if initial state array actually defines
# only the initial state
......@@ -90,23 +66,21 @@ class CSVMixin(OptimizationProblem):
raise Exception(
'CSVMixin: Initial state file {} contains more than one row of data. '
'Please remove the data row(s) that do not describe the initial state.'.format(
os.path.join(self.__input_folder, 'initial_state.csv')))
os.path.join(self._input_folder, 'initial_state.csv')))
# Read CSV files
self.__timeseries = []
self.__parameters = []
self.__initial_state = []
if self.csv_ensemble_mode:
self.__ensemble = np.genfromtxt(
os.path.join(self.__input_folder, 'ensemble.csv'),
os.path.join(self._input_folder, 'ensemble.csv'),
delimiter=",", deletechars='', dtype=None, names=True, encoding=None)
logger.debug("CSVMixin: Read ensemble description")
for ensemble_member_name in self.__ensemble['name']:
for ensemble_member_index, ensemble_member_name in enumerate(self.__ensemble['name']):
_timeseries = csv.load(
os.path.join(
self.__input_folder,
self._input_folder,
ensemble_member_name,
self.timeseries_import_basename + ".csv",
),
......@@ -114,26 +88,34 @@ class CSVMixin(OptimizationProblem):
with_time=True,
)
self.__timeseries_times = _timeseries[_timeseries.dtype.names[0]]
self.__timeseries.append(
AliasDict(
self.alias_relation,
{key: np.asarray(_timeseries[key], dtype=np.float64) for key in _timeseries.dtype.names[1:]}))
self.io.set_times(
self.io.datetime_to_sec(
self.__timeseries_times,
self.__timeseries_times[self.io.get_forecast_index()]
)
)
for key in _timeseries.dtype.names[1:]:
self.io.set_timeseries_values(
key,
np.asarray(_timeseries[key], dtype=np.float64),
ensemble_member_index
)
logger.debug("CSVMixin: Read timeseries")
for ensemble_member_name in self.__ensemble['name']:
for ensemble_member_index, ensemble_member_name in enumerate(self.__ensemble['name']):
try:
_parameters = csv.load(os.path.join(
self.__input_folder, ensemble_member_name, 'parameters.csv'), delimiter=self.csv_delimiter)
_parameters = {key: float(_parameters[key]) for key in _parameters.dtype.names}
self._input_folder, ensemble_member_name, 'parameters.csv'), delimiter=self.csv_delimiter)
for key in _parameters.dtype.names:
self.io.set_parameter(key, float(_parameters[key]), ensemble_member_index)
except IOError:
_parameters = {}
self.__parameters.append(AliasDict(self.alias_relation, _parameters))
pass
logger.debug("CSVMixin: Read parameters.")
for ensemble_member_name in self.__ensemble['name']:
try:
_initial_state = csv.load(os.path.join(
self.__input_folder, ensemble_member_name, 'initial_state.csv'), delimiter=self.csv_delimiter)
self._input_folder, ensemble_member_name, 'initial_state.csv'), delimiter=self.csv_delimiter)
check_initial_state_array(_initial_state)
_initial_state = {key: float(_initial_state[key]) for key in _initial_state.dtype.names}
except IOError:
......@@ -143,30 +125,34 @@ class CSVMixin(OptimizationProblem):
else:
_timeseries = csv.load(
os.path.join(
self.__input_folder, self.timeseries_import_basename + ".csv"
self._input_folder, self.timeseries_import_basename + ".csv"
),
delimiter=self.csv_delimiter,
with_time=True,
)
self.__timeseries_times = _timeseries[_timeseries.dtype.names[0]]
self.__timeseries.append(
AliasDict(
self.alias_relation,
{key: np.asarray(_timeseries[key], dtype=np.float64) for key in _timeseries.dtype.names[1:]}))
self.io.set_times(
self.io.datetime_to_sec(
self.__timeseries_times,
self.__timeseries_times[self.io.get_forecast_index()]
)
)
for key in _timeseries.dtype.names[1:]:
self.io.set_timeseries_values(key, np.asarray(_timeseries[key], dtype=np.float64))
logger.debug("CSVMixin: Read timeseries.")
try:
_parameters = csv.load(os.path.join(
self.__input_folder, 'parameters.csv'), delimiter=self.csv_delimiter)
self._input_folder, 'parameters.csv'), delimiter=self.csv_delimiter)
logger.debug("CSVMixin: Read parameters.")
_parameters = {key: float(_parameters[key]) for key in _parameters.dtype.names}
for key in _parameters.dtype.names:
self.io.set_parameter(key, float(_parameters[key]))
except IOError:
_parameters = {}
self.__parameters.append(AliasDict(self.alias_relation, _parameters))
pass
try:
_initial_state = csv.load(os.path.join(
self.__input_folder, 'initial_state.csv'), delimiter=self.csv_delimiter)
self._input_folder, 'initial_state.csv'), delimiter=self.csv_delimiter)
logger.debug("CSVMixin: Read initial state.")
check_initial_state_array(_initial_state)
_initial_state = {key: float(_initial_state[key]) for key in _initial_state.dtype.names}
......@@ -174,31 +160,26 @@ class CSVMixin(OptimizationProblem):
_initial_state = {}
self.__initial_state.append(AliasDict(self.alias_relation, _initial_state))
self.__timeseries_times_sec = self.__datetime_to_sec(
self.__timeseries_times)
timeseries_times_sec = self.io.get_times()
# Timestamp check
if self.csv_validate_timeseries:
for i in range(len(self.__timeseries_times_sec) - 1):
if self.__timeseries_times_sec[i] >= self.__timeseries_times_sec[i + 1]:
for i in range(len(timeseries_times_sec) - 1):
if timeseries_times_sec[i] >= timeseries_times_sec[i + 1]:
raise Exception(
'CSVMixin: Time stamps must be strictly increasing.')
if self.csv_equidistant:
# Check if the timeseries are truly equidistant
if self.csv_validate_timeseries:
dt = self.__timeseries_times_sec[
1] - self.__timeseries_times_sec[0]
for i in range(len(self.__timeseries_times_sec) - 1):
if self.__timeseries_times_sec[i + 1] - self.__timeseries_times_sec[i] != dt:
dt = timeseries_times_sec[1] - timeseries_times_sec[0]
for i in range(len(timeseries_times_sec) - 1):
if timeseries_times_sec[i + 1] - timeseries_times_sec[i] != dt:
raise Exception(
'CSVMixin: Expecting equidistant timeseries, the time step towards '
'{} is not the same as the time step(s) before. Set csv_equidistant = False '
'if this is intended.'.format(self.__timeseries_times[i + 1]))
def times(self, variable=None):
return self.__timeseries_times_sec
@property
def equidistant(self):
return self.csv_equidistant
......@@ -225,7 +206,7 @@ class CSVMixin(OptimizationProblem):
for parameter in self.dae_variables['parameters']:
parameter = parameter.name()
try:
parameters[parameter] = self.__parameters[ensemble_member][parameter]
parameters[parameter] = self.io.get_parameter(parameter, ensemble_member)
except KeyError:
pass
else:
......@@ -233,72 +214,6 @@ class CSVMixin(OptimizationProblem):
logger.debug("CSVMixin: Read parameter {} ".format(parameter))
return parameters
@cached
def constant_inputs(self, ensemble_member):
# Call parent class first for default values.
constant_inputs = super(
CSVMixin, self).constant_inputs(ensemble_member)
# Load bounds from timeseries
for variable in self.dae_variables['constant_inputs']:
variable = variable.name()
try:
constant_inputs[variable] = Timeseries(
self.__timeseries_times_sec, self.__timeseries[ensemble_member][variable])
except (KeyError, ValueError):
pass
else:
if logger.getEffectiveLevel() == logging.DEBUG:
logger.debug("CSVMixin: Read constant input {}".format(variable))
return constant_inputs
@cached
def bounds(self):
# Call parent class first for default values.
bounds = super().bounds()
# Load bounds from timeseries
for variable in self.dae_variables['free_variables']:
variable = variable.name()
m, M = None, None
timeseries_id = self.min_timeseries_id(variable)
try:
m = self.__timeseries[0][timeseries_id]
except (KeyError, ValueError):
pass
else:
if logger.getEffectiveLevel() == logging.DEBUG:
logger.debug("CSVMixin: Read lower bound for variable {}".format(variable))
timeseries_id = self.max_timeseries_id(variable)
try:
M = self.__timeseries[0][timeseries_id]
except (KeyError, ValueError):
pass
else:
if logger.getEffectiveLevel() == logging.DEBUG:
logger.debug("CSVMixin: Read upper bound for variable {}".format(variable))
# Replace NaN with +/- inf, and create Timeseries objects
if m is not None:
m[np.isnan(m)] = np.finfo(m.dtype).min
m = Timeseries(self.__timeseries_times_sec, m)
if M is not None:
M[np.isnan(M)] = np.finfo(M.dtype).max
M = Timeseries(self.__timeseries_times_sec, M)
# Store
if m is not None or M is not None:
bounds[variable] = (m, M)
return bounds
@property
def initial_time(self):
    """Initial time of the horizon in seconds; always 0.0 for CSVMixin (t0 of the imported timeseries)."""
    return 0.0
@cached
def initial_state(self, ensemble_member):
# Call parent class first for default values.
......@@ -316,29 +231,9 @@ class CSVMixin(OptimizationProblem):
logger.debug("CSVMixin: Read initial state {}".format(variable))
return initial_state
@cached
def seed(self, ensemble_member):
    """
    Return seed values for the free variables, overriding the parent's
    defaults with any matching series from the imported CSV data.

    :param ensemble_member: Ensemble member index.
    """
    seed = super().seed(ensemble_member)
    for symbol in self.dae_variables['free_variables']:
        name = symbol.name()
        try:
            series = Timeseries(
                self.__timeseries_times_sec, self.__timeseries[ensemble_member][name])
        except (KeyError, ValueError):
            # No imported series for this variable; keep the default seed.
            continue
        if logger.getEffectiveLevel() == logging.DEBUG:
            logger.debug("CSVMixin: Seeded free variable {}".format(name))
        # A seeding of NaN means no seeding
        series.values[np.isnan(series.values)] = 0.0
        seed[name] = series
    return seed
def post(self):
def write(self):
# Call parent class first for default behaviour.
super().post()
super().write()
# Write output
times = self.times()
......@@ -349,7 +244,7 @@ class CSVMixin(OptimizationProblem):
formats = ['O'] + (len(names) - 1) * ['f8']
dtype = {'names': names, 'formats': formats}
data = np.zeros(len(times), dtype=dtype)
data['time'] = [self.__timeseries_times[0] + timedelta(seconds=s) for s in times]
data['time'] = [self.__timeseries_times[self.io.get_forecast_index()] + timedelta(seconds=s) for s in times]
for output_variable in self.output_variables:
output_variable = output_variable.name()
try:
......@@ -378,66 +273,6 @@ class CSVMixin(OptimizationProblem):
if self.csv_ensemble_mode:
for ensemble_member, ensemble_member_name in enumerate(self.__ensemble['name']):
write_output(ensemble_member, os.path.join(
self.__output_folder, ensemble_member_name))
else:
write_output(0, self.__output_folder)
def __datetime_to_sec(self, d):
    """Convert a datetime, or an iterable of datetimes, to seconds since t0 (the first import time)."""
    t0 = self.__timeseries_times[0]
    if not hasattr(d, '__iter__'):
        return (d - t0).total_seconds()
    return np.array([(t - t0).total_seconds() for t in d])
def __sec_to_datetime(self, s):
    """Convert seconds since t0 (scalar or iterable) back to datetime objects."""
    t0 = self.__timeseries_times[0]
    if not hasattr(s, '__iter__'):
        return t0 + timedelta(seconds=s)
    return [t0 + timedelta(seconds=t) for t in s]
def get_timeseries(self, variable, ensemble_member=0):
    """Return the imported timeseries (times in seconds since t0) for the given variable."""
    return Timeseries(self.__timeseries_times_sec, self.__timeseries[ensemble_member][variable])
def set_timeseries(self, variable, timeseries, ensemble_member=0, output=True, check_consistency=True):
if output:
self.__output_timeseries.add(variable)
if isinstance(timeseries, Timeseries):
# TODO: add better check on timeseries.times?
if check_consistency:
if not np.array_equal(self.times(), timeseries.times):
raise Exception(
'CSV: Trying to set/append timeseries {} with different times '
'(in seconds) than the imported timeseries. Please make sure the '
'timeseries covers startDate through endData of the longest '
'imported timeseries.'.format(variable))
self._output_folder, ensemble_member_name))
else:
timeseries = Timeseries(self.times(), timeseries)
assert(len(timeseries.times) == len(timeseries.values))
self.__timeseries[ensemble_member][variable] = timeseries.values
def timeseries_at(self, variable, t, ensemble_member=0):
    """Return the value of the stored timeseries for `variable`, interpolated at time t (seconds)."""
    return self.interpolate(t, self.__timeseries_times_sec, self.__timeseries[ensemble_member][variable])
@property
def output_variables(self):
    """Parent's output variables, plus a symbol for every variable registered via set_timeseries(output=True)."""
    variables = super().output_variables
    for name in self.__output_timeseries:
        variables.append(ca.MX.sym(name))
    return variables
def min_timeseries_id(self, variable: str) -> str:
    """
    Returns the name of the lower bound timeseries for the specified variable.

    :param variable: Variable name.
    """
    return variable + '_Min'
def max_timeseries_id(self, variable: str) -> str:
    """
    Returns the name of the upper bound timeseries for the specified variable.

    :param variable: Variable name.
    """
    return variable + '_Max'
write_output(0, self._output_folder)
import bisect
import logging
from abc import ABCMeta, abstractmethod
import casadi as ca
import numpy as np
from rtctools._internal.alias_tools import AliasDict
from rtctools._internal.caching import cached
from rtctools.optimization.optimization_problem import OptimizationProblem
from rtctools.optimization.timeseries import Timeseries
logger = logging.getLogger("rtctools")
class IOMixin(OptimizationProblem, metaclass=ABCMeta):
    """
    Base class for all IO methods of optimization problems.

    Subclasses implement :meth:`read` and :meth:`write`; this mixin calls
    them from ``pre()`` and ``post()`` respectively, and exposes the data
    stored in the internal data store (``self.io``) through the usual
    timeseries accessors (times, bounds, history, seed, constant inputs).
    """

    def __init__(self, **kwargs):
        # Call parent class first for default behaviour.
        super().__init__(**kwargs)
        # Additional output variables (registered via set_timeseries(output=True)).
        self.__output_timeseries = set()

    def pre(self) -> None:
        # Call parent class first for default behaviour.
        super().pre()
        # Call read method to read all input
        self.read()

    @abstractmethod
    def read(self) -> None:
        """
        Reads input data from files into the internal data store.
        """
        pass

    def post(self) -> None:
        # Call parent class first for default behaviour.
        super().post()
        # Call write method to write all output
        self.write()

    @abstractmethod
    def write(self) -> None:
        """
        Writes output data to files.
        """
        pass

    def times(self, variable=None) -> np.ndarray:
        """
        Returns the times in seconds from the forecast index and onwards.

        :param variable: Variable name; unused here, all variables share the
            same time axis.
        """
        return self.io.get_times()[self.io.get_forecast_index():]

    def get_timeseries(self, variable: str, ensemble_member: int = 0) -> Timeseries:
        """
        Returns the full stored timeseries for ``variable``, including any
        history before the forecast index.
        """
        return Timeseries(self.io.get_times(), self.io.get_timeseries_values(variable, ensemble_member))

    def set_timeseries(
            self,
            variable: str,
            timeseries: Timeseries,
            ensemble_member: int = 0,
            output: bool = True,
            check_consistency: bool = True):
        """
        Stores a timeseries in the internal data store.

        :param variable: Variable name.
        :param timeseries: A :class:`Timeseries`, or a plain values sequence
            which is assumed to span the forecast times (T0 through end).
        :param ensemble_member: Ensemble member index.
        :param output: Whether to include this variable in the output.
        :param check_consistency: Whether to verify the supplied times/length
            against the imported time axis.
        :raises ValueError: On inconsistent lengths or times.
        """

        def stretch_values(values, t_pos):
            # Construct a values range with preceding and possibly following nans
            new_values = np.full(self.io.get_times().shape, np.nan)
            new_values[t_pos:] = values
            return new_values

        if output:
            self.__output_timeseries.add(variable)
        if isinstance(timeseries, Timeseries):
            if len(timeseries.values) != len(timeseries.times):
                raise ValueError('IOMixin: Trying to set timeseries {} with times and values that are of '
                                 'different length (lengths of {} and {}, respectively).'
                                 .format(variable, len(timeseries.times), len(timeseries.values)))
            timeseries_times_sec = self.io.get_times()
            if not np.array_equal(timeseries_times_sec, timeseries.times):
                if check_consistency:
                    raise ValueError(
                        'IOMixin: Trying to set timeseries {} with different times '
                        '(in seconds) than the imported timeseries. Please make sure the '
                        'timeseries covers all timesteps of the longest '
                        'imported timeseries.'.format(variable)
                    )
                # Determine position of first times of added timeseries within the
                # import times. For this we assume that both time ranges are ordered,
                # and that the times of the added series is a subset of the import
                # times.
                t_pos = bisect.bisect_left(timeseries_times_sec, timeseries.times[0])
                # Construct a new values range with length of self.io.get_times()
                values = stretch_values(timeseries.values, t_pos)
            else:
                values = timeseries.values
        else:
            if check_consistency and len(self.times()) != len(timeseries):
                raise ValueError('IOMixin: Trying to set values for {} with a different '
                                 'length ({}) than the forecast length. Please make sure the '
                                 'values covers all timesteps of the longest imported timeseries (length {}).'
                                 .format(variable, len(timeseries), len(self.times())))
            # If times is not supplied with the timeseries, we add the
            # forecast times range to a new Timeseries object. Hereby
            # we assume that the supplied values stretch from T0 to end.
            t_pos = self.io.get_forecast_index()
            # Construct a new values range with length of self.io.get_times()
            values = stretch_values(timeseries, t_pos)
        self.io.set_timeseries_values(variable, values, ensemble_member)

    def min_timeseries_id(self, variable: str) -> str:
        """
        Returns the name of the lower bound timeseries for the specified variable.

        :param variable: Variable name.
        """
        return '_'.join((variable, 'Min'))

    def max_timeseries_id(self, variable: str) -> str:
        """
        Returns the name of the upper bound timeseries for the specified variable.

        :param variable: Variable name.
        """
        return '_'.join((variable, 'Max'))

    @cached
    def bounds(self):
        """
        Returns variable bounds, extending the parent's defaults with any
        '<variable>_Min'/'<variable>_Max' timeseries found in the data store.
        """
        # Call parent class first for default values.
        bounds = super().bounds()
        forecast_index = self.io.get_forecast_index()
        # Load bounds from timeseries
        for variable in self.dae_variables['free_variables']:
            variable_name = variable.name()
            m, M = None, None
            timeseries_id = self.min_timeseries_id(variable_name)
            try:
                m = self.io.get_timeseries_values(timeseries_id, 0)[forecast_index:]
            except KeyError:
                pass
            else:
                if logger.getEffectiveLevel() == logging.DEBUG:
                    logger.debug("Read lower bound for variable {}".format(variable_name))
            timeseries_id = self.max_timeseries_id(variable_name)
            try:
                M = self.io.get_timeseries_values(timeseries_id, 0)[forecast_index:]
            except KeyError:
                pass
            else:
                if logger.getEffectiveLevel() == logging.DEBUG:
                    logger.debug("Read upper bound for variable {}".format(variable_name))
            # Replace NaN ("unbounded") entries with the dtype's extreme
            # values, and create Timeseries objects
            if m is not None:
                m[np.isnan(m)] = np.finfo(m.dtype).min
                m = Timeseries(self.io.get_times()[forecast_index:], m)
            if M is not None:
                M[np.isnan(M)] = np.finfo(M.dtype).max
                M = Timeseries(self.io.get_times()[forecast_index:], M)
            # Store (only when at least one bound was found)
            if m is not None or M is not None:
                bounds[variable_name] = (m, M)
        return bounds

    @cached
    def history(self, ensemble_member):
        """
        Returns the history (stored timeseries up to and including the
        forecast index) for all model variables present in the data store.
        """
        # Load history
        history = AliasDict(self.alias_relation)
        end_index = self.io.get_forecast_index() + 1
        variable_list = self.dae_variables['states'] + self.dae_variables['algebraics'] + \
            self.dae_variables['control_inputs'] + self.dae_variables['constant_inputs']
        for variable in variable_list:
            variable = variable.name()
            try:
                history[variable] = Timeseries(
                    self.io.get_times()[:end_index],
                    self.io.get_timeseries_values(variable, ensemble_member)[:end_index])
            except KeyError:
                pass
            else:
                if logger.getEffectiveLevel() == logging.DEBUG:
                    logger.debug("IOMixin: Read history for state {}".format(variable))
        return history

    @cached
    def seed(self, ensemble_member):
        """
        Returns seed values for the free variables, overriding the parent's
        defaults with matching timeseries from the data store.
        """
        # Call parent class first for default values.
        seed = super().seed(ensemble_member)
        # Load seeds
        for variable in self.dae_variables['free_variables']:
            variable = variable.name()
            try:
                s = Timeseries(
                    self.io.get_times(),
                    self.io.get_timeseries_values(variable, ensemble_member)
                )
            except KeyError:
                pass
            else:
                if logger.getEffectiveLevel() == logging.DEBUG:
                    logger.debug("IOMixin: Seeded free variable {}".format(variable))
                # A seeding of NaN means no seeding
                # NOTE(review): this writes 0.0 into the array returned by the
                # data store — confirm get_timeseries_values returns a copy.
                s.values[np.isnan(s.values)] = 0.0
                seed[variable] = s
        return seed

    @cached
    def constant_inputs(self, ensemble_member):
        """
        Returns the constant inputs, extending the parent's defaults with
        matching timeseries from the data store.

        :raises Exception: When a constant input contains NaN on or after
            the forecast index.
        """
        # Call parent class first for default values.
        constant_inputs = super().constant_inputs(ensemble_member)
        # Load inputs from timeseries
        for variable in self.dae_variables['constant_inputs']:
            variable = variable.name()
            try:
                timeseries = Timeseries(
                    self.io.get_times(),
                    self.io.get_timeseries_values(variable, ensemble_member)
                )
            except KeyError:
                pass
            else:
                if np.any(np.isnan(timeseries.values[self.io.get_forecast_index():])):
                    raise Exception("IOMixin: Constant input {} contains NaN".format(variable))
                constant_inputs[variable] = timeseries
                if logger.getEffectiveLevel() == logging.DEBUG:
                    logger.debug("IOMixin: Read constant input {}".format(variable))
        return constant_inputs

    def timeseries_at(self, variable, t, ensemble_member=0):
        """Return the stored timeseries for `variable`, interpolated at time t (seconds)."""
        return self.interpolate(t, self.io.get_times(), self.io.get_timeseries_values(variable, ensemble_member))

    @property
    def output_variables(self):
        # Parent's output variables plus a symbol for each extra output timeseries.
        variables = super().output_variables
        variables.extend([ca.MX.sym(variable) for variable in self.__output_timeseries])
        return variables
......@@ -279,8 +279,14 @@ class ModelicaMixin(OptimizationProblem):
M_ = float(M_)
# We take the intersection of all provided bounds
m = max(m, m_)
M = min(M, M_)
def intersect(old_bound, new_bound, intersecter):
if isinstance(old_bound, Timeseries):
return Timeseries(old_bound.times, intersecter(old_bound.values, new_bound))
else:
return intersecter(old_bound, new_bound)
m = intersect(m, m_, np.maximum)
M = intersect(M, M_, np.minimum)
bounds[sym_name] = (m, M)
......@@ -297,7 +303,7 @@ class ModelicaMixin(OptimizationProblem):
# Load seeds
for var in itertools.chain(self.__pymoca_model.states, self.__pymoca_model.alg_states):
if var.fixed:
if var.fixed or var.symbol.name() in seed.keys():
# Values will be set from import timeseries
continue
......
import logging
import os
import rtctools.data.netcdf as netcdf
from rtctools.data import rtc
from rtctools.optimization.io_mixin import IOMixin
logger = logging.getLogger("rtctools")
# todo add support for ensembles
class NetCDFMixin(IOMixin):
    """
    Adds NetCDF I/O to your optimization problem.

    During preprocessing, a file named timeseries_import.nc is read from the ``input`` subfolder.
    During postprocessing a file named timeseries_export.nc is written to the ``output`` subfolder.

    Both the input and output nc files are expected to follow the FEWS format for scalar data in a Netcdf file, i.e.:

    - They must contain a variable with the station id's (location id's) which can be recognized by the attribute
      'cf_role' set to 'timeseries_id'.
    - They must contain a time variable with attributes 'standard_name' = 'time' and 'axis' = 'T'

    From the input file, all 2d variables with dimensions equal to the station id's and time variable are read.

    To determine the rtc-tools variable name, the NetCDF mixin uses the station id (location id) and name of the
    timeseries variable in the file (parameter). An rtcDataConfig.xml file can be given in the input folder to
    configure variable names for specific location and parameter combinations. If this file is present, and contains
    a configured variable name for a read timeseries, this variable name will be used. If the file is present, but does
    not contain a configured variable name, a default variable name is constructed and a warning is given to alert the
    user that the current rtcDataConfig may contain a mistake. To suppress this warning if this is intentional, set the
    check_missing_variable_names attribute to False. Finally, if no file is present, the default variable name will
    always be used, and no warnings will be given. With debug logging enabled, the NetCDF mixin will report the chosen
    variable name for each location and parameter combination.

    To construct the default variable name, the station id is concatenated with the name of the variable in the NetCDF
    file, separated by the location_parameter_delimiter (set to a double underscore - '__' - by default). For example,
    if a NetCDF file contains two stations 'loc_1' and 'loc_2', and a timeseries variable called 'water_level', this
    will result in two rtc-tools variables called 'loc_1__water_level' and 'loc_2__water_level' (with the default
    location_parameter_delimiter of '__').

    :cvar location_parameter_delimiter:
        Delimiter used between location and parameter id when constructing the variable name.
    :cvar check_missing_variable_names:
        Warn if an rtcDataConfig.xml file is given but does not contain a variable name for a read timeseries.
        Default is ``True``
    :cvar netcdf_validate_timeseries:
        Check consistency of timeseries. Default is ``True``
    """

    #: Delimiter used between location and parameter id when constructing the variable name.
    location_parameter_delimiter = '__'

    #: Warn if an rtcDataConfig.xml file is given but does not contain a variable name for a read timeseries.
    check_missing_variable_names = True

    #: Check consistency of timeseries.
    netcdf_validate_timeseries = True

    def __init__(self, **kwargs):
        # call parent class for default behaviour
        super().__init__(**kwargs)

        # Only load rtcDataConfig when the file is actually present;
        # otherwise __data_config stays None and default names are used.
        path = os.path.join(self._input_folder, "rtcDataConfig.xml")
        self.__data_config = rtc.DataConfig(self._input_folder) if os.path.isfile(path) else None

    def read(self):
        """
        Reads the timeseries_import.nc file from the input folder into the
        internal data store.
        """
        # Call parent class first for default behaviour
        super().read()

        dataset = netcdf.ImportDataset(self._input_folder, self.timeseries_import_basename)

        # convert and store the import times
        self.__import_datetimes = dataset.read_import_times()
        times = self.io.datetime_to_sec(self.__import_datetimes, self.__import_datetimes[self.io.get_forecast_index()])
        self.io.set_times(times)

        if self.netcdf_validate_timeseries:
            # check if strictly increasing
            for i in range(len(times) - 1):
                if times[i] >= times[i + 1]:
                    raise Exception('NetCDFMixin: Time stamps must be strictly increasing.')

        # __dt holds the constant timestep, or None when the import times
        # are not equidistant (see the equidistant property).
        self.__dt = times[1] - times[0] if len(times) >= 2 else 0
        for i in range(len(times) - 1):
            if times[i + 1] - times[i] != self.__dt:
                self.__dt = None
                break

        # store the station data for later use
        self.__stations = dataset.read_station_data()

        # read all available timeseries from the dataset
        timeseries_var_keys = dataset.find_timeseries_variables()

        # todo add support for ensembles
        for parameter in timeseries_var_keys:
            for i, location_id in enumerate(self.__stations.station_ids):
                default_name = location_id + self.location_parameter_delimiter + parameter
                if self.__data_config is not None:
                    try:
                        name = self.__data_config.parameter(parameter, location_id)
                    except KeyError:
                        if self.check_missing_variable_names:
                            logger.warning('No configured variable name found in rtcDataConfig.xml for location id "{}"'
                                           ' and parameter id "{}", using default variable name "{}" instead. '
                                           '(To suppress this warning set check_missing_variable_names to False.)'
                                           .format(location_id, parameter, default_name))
                        name = default_name
                else:
                    name = default_name

                values = dataset.read_timeseries_values(i, parameter)
                self.io.set_timeseries_values(name, values)
                logger.debug('Read timeseries data for location id "{}" and parameter "{}", '
                             'stored under variable name "{}"'
                             .format(location_id, parameter, name))

        logger.debug("NetCDFMixin: Read timeseries")

    def write(self):
        """
        Writes the results (and any requested stored timeseries) to a
        timeseries_export NetCDF file in the output folder.
        """
        dataset = netcdf.ExportDataset(self._output_folder, self.timeseries_export_basename)

        times = self.times()
        forecast_index = self.io.get_forecast_index()
        dataset.write_times(times, self.initial_time, self.__import_datetimes[forecast_index])

        output_variables = [sym.name() for sym in self.output_variables]
        output_location_parameter_ids = {var_name: self.extract_station_id(var_name) for var_name in output_variables}
        output_station_ids = {loc_par[0] for loc_par in output_location_parameter_ids.values()}
        dataset.write_station_data(self.__stations, output_station_ids)

        output_parameter_ids = {loc_par[1] for loc_par in output_location_parameter_ids.values()}
        dataset.create_variables(output_parameter_ids)

        for ensemble_member in range(self.ensemble_size):
            results = self.extract_results(ensemble_member)

            for var_name in output_variables:
                # determine the output values
                try:
                    values = results[var_name]
                    if len(values) != len(times):
                        values = self.interpolate(
                            times, self.times(var_name), values, self.interpolation_method(var_name))
                except KeyError:
                    # Not in the results; fall back to a stored timeseries.
                    try:
                        ts = self.get_timeseries(var_name, ensemble_member)
                        if len(ts.times) != len(times):
                            values = self.interpolate(
                                times, ts.times, ts.values)
                        else:
                            values = ts.values
                    except KeyError:
                        logger.error(
                            'NetCDFMixin: Output requested for non-existent variable {}. '
                            'Will not be in output file.'.format(var_name))
                        continue

                # determine where to put this output
                location_parameter_id = output_location_parameter_ids[var_name]
                location_id = location_parameter_id[0]
                parameter_id = location_parameter_id[1]
                dataset.write_output_values(location_id, parameter_id, values)

        dataset.close()

    def extract_station_id(self, variable_name: str) -> tuple:
        """
        Returns the station (location) id and parameter id corresponding to
        the given RTC-Tools variable name.

        :param variable_name: The name of the RTC-Tools variable
        :return: tuple of (station id, parameter id)
        """
        # Only consult rtcDataConfig when one was actually read. Previously
        # this dereferenced self.__data_config unconditionally, raising
        # AttributeError (on None) whenever no rtcDataConfig.xml was present.
        if self.__data_config is not None:
            try:
                return self.__data_config.pi_variable_ids(variable_name)[:2]
            except KeyError:
                pass
        return tuple(variable_name.split(self.location_parameter_delimiter))

    @property
    def equidistant(self):
        # True when the import times have a single constant timestep (set in read()).
        return self.__dt is not None
......@@ -6,19 +6,23 @@ import casadi as ca
import numpy as np
from rtctools._internal.alias_tools import AliasDict, AliasRelation
from rtctools._internal.alias_tools import AliasDict
from rtctools.data.storage import DataStoreAccessor
from .timeseries import Timeseries
logger = logging.getLogger("rtctools")
class OptimizationProblem(metaclass=ABCMeta):
class OptimizationProblem(DataStoreAccessor, metaclass=ABCMeta):
"""
Base class for all optimization problems.
"""
def __init__(self, **kwargs):
# Call parent class first for default behaviour.
super().__init__(**kwargs)
self.__mixed_integer = False
def optimize(self, preprocessing: bool = True, postprocessing: bool = True,
......@@ -406,10 +410,6 @@ class OptimizationProblem(metaclass=ABCMeta):
{variable: Timeseries(np.array([self.initial_time]), np.array([state]))
for variable, state in initial_state.items()})
@abstractproperty
def alias_relation(self) -> AliasRelation:
raise NotImplementedError
def variable_is_discrete(self, variable: str) -> bool: