import logging
import os

import numpy as np

import rtctools.data.csv as csv
from rtctools._internal.caching import cached
from rtctools.simulation.io_mixin import IOMixin

# Shared rtc-tools logger; all messages from this mixin are prefixed with "CSVMixin:".
logger = logging.getLogger("rtctools")


class CSVMixin(IOMixin):
    """
    Adds reading and writing of CSV timeseries and parameters to your simulation problem.

    During preprocessing, the timeseries file ``<timeseries_import_basename>.csv``
    (e.g. ``timeseries_import.csv``) is read from the input folder, together with the
    optional files ``initial_state.csv`` and ``parameters.csv``.

    During postprocessing, the file ``<timeseries_export_basename>.csv``
    (e.g. ``timeseries_export.csv``) is written to the output folder.

    :cvar csv_delimiter:           Column delimiter used in CSV files.  Default is ``,``.
    :cvar csv_validate_timeseries: Check that time stamps are strictly increasing and
                                   equidistant.  Default is ``True``.
    """

    #: Column delimiter used in CSV files
    csv_delimiter = ','

    #: Check consistency of timeseries
    csv_validate_timeseries = True

    def __init__(self, **kwargs):
        # Call parent class first for default behaviour.
        super().__init__(**kwargs)

    def read(self):
        """
        Read timeseries, parameters and initial state from CSV files.

        The timeseries file is mandatory.  ``parameters.csv`` and ``initial_state.csv``
        are optional: when they cannot be opened (``IOError``) they are skipped.

        :raises Exception: if ``initial_state.csv`` contains more than one data row, or if
                           ``csv_validate_timeseries`` is set and the time stamps are not
                           strictly increasing and equidistant.
        """
        # Call parent class first for default behaviour.
        super().read()

        timeseries_ids, times_sec = self.__read_timeseries()
        self.__read_parameters()
        self.__initial_state = self.__read_initial_state()
        self.__warn_initial_state_conflicts(timeseries_ids)

        if self.csv_validate_timeseries:
            self.__validate_times(times_sec)

    def __read_timeseries(self):
        """
        Load the timeseries import file and register its times and values.

        :returns: A tuple ``(timeseries_ids, times_sec)``: the names of the data columns
                  and the time stamps converted to seconds.
        """
        timeseries = csv.load(
            os.path.join(self._input_folder, self.timeseries_import_basename + '.csv'),
            delimiter=self.csv_delimiter, with_time=True)

        # The first column holds the time stamps; every other column is a timeseries.
        time_column, *timeseries_ids = timeseries.dtype.names
        self.__timeseries_times = timeseries[time_column]

        # The time stamp at the forecast index is passed as the reference time
        # (presumably t = 0 after conversion).
        times_sec = np.array(self.datetime_to_sec(
            self.__timeseries_times,
            self.__timeseries_times[self.get_forecast_index()]
        ))
        self.set_times(times_sec)
        for timeseries_id in timeseries_ids:
            self.set_timeseries_values(
                timeseries_id, np.asarray(timeseries[timeseries_id], dtype=np.float64))

        logger.debug("CSVMixin: Read timeseries.")
        return timeseries_ids, times_sec

    def __read_parameters(self):
        """Load the optional ``parameters.csv`` and register each column as a parameter."""
        try:
            parameters = csv.load(
                os.path.join(self._input_folder, 'parameters.csv'),
                delimiter=self.csv_delimiter)
        except IOError:
            # parameters.csv is optional.
            return

        for key in parameters.dtype.names:
            self.set_parameter(key, float(parameters[key]))
        logger.debug("CSVMixin: Read parameters.")

    def __read_initial_state(self):
        """
        Load the optional ``initial_state.csv``.

        :returns: A dict mapping variable names to initial values; empty when the
                  file cannot be opened.
        :raises Exception: if the file contains more than one data row.
        """
        fname = os.path.join(self._input_folder, 'initial_state.csv')
        try:
            initial_state = csv.load(fname, delimiter=self.csv_delimiter)
        except IOError:
            # initial_state.csv is optional.
            return {}
        logger.debug("CSVMixin: Read initial state.")

        # A non-empty shape means more than one data row; presumably csv.load
        # yields a 0-d array (shape ``()``) for exactly one row.
        if initial_state.shape:
            raise Exception(
                'CSVMixin: Initial state file {} contains more than one row of data. '
                'Please remove the data row(s) that do not describe the initial '
                'state.'.format(fname))

        return {key: float(initial_state[key]) for key in initial_state.dtype.names}

    def __warn_initial_state_conflicts(self, timeseries_ids):
        """Warn about initial state entries whose value differs from the imported timeseries."""
        # Only names that are actually present in the timeseries file can conflict with it;
        # sorted so the warnings come out in a deterministic order.
        for collision in sorted(set(self.__initial_state).intersection(timeseries_ids)):
            if self.__initial_state[collision] != self.get_timeseries_values(collision)[0]:
                logger.warning(
                    'CSVMixin: Entry {} in initial_state.csv conflicts with '
                    'timeseries_import.csv'.format(collision))

    def __validate_times(self, times_sec):
        """
        Check that the time stamps are strictly increasing and equidistant.

        :raises Exception: on the first violation found.
        """
        steps = np.diff(times_sec)

        # Timestamp check
        if np.any(steps <= 0):
            raise Exception(
                'CSVMixin: Time stamps must be strictly increasing.')

        # A single time stamp has no steps and is trivially equidistant.
        if steps.size == 0:
            return

        # Check if the timeseries are truly equidistant
        irregular = np.flatnonzero(steps != steps[0])
        if irregular.size:
            # NOTE(review): no ``equidistant`` option is visible in this class; confirm the
            # setting this message refers to still exists.
            raise Exception(
                'CSVMixin: Expecting equidistant timeseries, the time step '
                'towards {} is not the same as the time step(s) before. '
                'Set equidistant=False if this is intended.'.format(
                    self.__timeseries_times[irregular[0] + 1]))

    def write(self):
        """
        Write all outputs to ``<timeseries_export_basename>.csv`` in the output folder.

        The first column holds the time stamps read by :meth:`read`; the remaining
        columns are the output variables in sorted order.
        """
        # Call parent class first for default behaviour.
        super().write()

        # Use the public ``output`` mapping: ``self.__output`` would be name-mangled to
        # ``self._CSVMixin__output``, which this class never assigns.
        output = self.output

        # Write output
        names = ['time'] + sorted(output.keys())
        formats = ['O'] + (len(names) - 1) * ['f8']
        dtype = {'names': names, 'formats': formats}
        data = np.zeros(len(self.__timeseries_times), dtype=dtype)
        data['time'] = self.__timeseries_times
        for variable, values in output.items():
            data[variable] = values

        fname = os.path.join(self._output_folder, self.timeseries_export_basename + '.csv')
        csv.save(fname, data, delimiter=self.csv_delimiter, with_time=True)

    @cached
    def initial_state(self):
        """
        The initial state. Includes entries from parent classes and initial_state.csv

        Entries from ``initial_state.csv`` are mapped to their canonical variable
        (negated for sign-flipped aliases) and kept only for state and input variables;
        other entries produce a warning.

        :returns: A dictionary of variable names and initial state (t0) values.
        """
        # Call parent class first for default values.
        initial_state = super().initial_state()

        # Set of model vars that are allowed to have an initial state
        valid_model_vars = set(self.get_state_variables()) | set(self.get_input_variables())

        # Load initial states read from initial_state.csv
        for variable, value in self.__initial_state.items():
            # Get the canonical variable and the sign relating it to this alias
            canonical_var, sign = self.alias_relation.canonical_signed(variable)

            # Only store variables that are allowed to have an initial state
            if canonical_var in valid_model_vars:
                initial_state[canonical_var] = value * sign
                logger.debug("CSVMixin: Read initial state %s = %s", variable, value)
            else:
                logger.warning("CSVMixin: In initial_state.csv, {} is not an input or state variable.".format(variable))
        return initial_state