Commit 0b4c7cf9 authored by Jesse VanderWees's avatar Jesse VanderWees 🐘 Committed by Tjerk Vreeken

Properly handle LookupTable __call__() parameters

The handling of LookupTable __call__() parameters was somewhat sloppy,
without clear error codes if parameters were not supported. Furthermore,
calling with np.nan would expose undefined behaviour. This commit
improves error messages and does explicit handling of np.nan parameters.
parent 9ceda84d
......@@ -3,7 +3,7 @@ import glob
import logging
import os
import pickle
from typing import List, Tuple, Union
from typing import Iterable, List, Tuple, Union
import casadi as ca
......@@ -79,7 +79,18 @@ class LookupTable:
"""
return self.__function
def __call__(self, *args: List[Union[float, Timeseries]]) -> Union[float, Timeseries]:
@property
@cached
def __numeric_function_evaluator(self):
return np.vectorize(
lambda *args: np.nan
if np.any(np.isnan(args))
else np.float(self.function(*args))
)
def __call__(
self, *args: Union[float, Iterable, Timeseries]
) -> Union[float, np.ndarray, Timeseries]:
"""
Evaluate the lookup table.
......@@ -93,42 +104,95 @@ class LookupTable:
[y1, y2] = lookup_table([1.0, 2.0])
"""
if isinstance(args[0], Timeseries):
return Timeseries(args[0].times, self(args[0].values))
evaluator = self.__numeric_function_evaluator
if len(args) == 1:
arg = args[0]
if isinstance(arg, Timeseries):
return Timeseries(arg.times, self(arg.values))
else:
if hasattr(arg, "__iter__"):
arg = np.fromiter(arg, dtype=float)
return evaluator(arg)
else:
arg = float(arg)
return evaluator(arg).item()
else:
if hasattr(args[0], '__iter__'):
evaluator = np.vectorize(
lambda v: float(self.function(v)))
return evaluator(args[0])
if any(isinstance(arg, Timeseries) for arg in args):
raise TypeError(
"Higher-order LookupTable calls do not yet support Timeseries parameters"
)
elif any(hasattr(arg, "__iter__") for arg in args):
raise TypeError(
"Higher-order LookupTable calls do not yet support vector parameters"
)
else:
return float(self.function(*args))
def reverse_call(self, y, domain=(None, None), detect_range_error=True):
"""
use scipy brentq optimizer to do an inverted call to this lookuptable
args = np.fromiter(args, dtype=float)
return evaluator(*args)
def reverse_call(
self,
y: Union[float, Iterable, Timeseries],
domain: Tuple[float, float] = (None, None),
detect_range_error: bool = True,
) -> Union[float, np.ndarray, Timeseries]:
"""Do an inverted call on this LookupTable
Uses SciPy brentq optimizer to simulate a reversed call.
Note: Method does not work with higher-order LookupTables
"""
if isinstance(y, Timeseries):
# Recurse and return
return Timeseries(y.times, self.reverse_call(y.values))
# Get domain information
l_d, u_d = domain
if l_d is None:
l_d = self.domain[0]
if u_d is None:
u_d = self.domain[1]
if detect_range_error:
l_r, u_r = self.range
# Cast y to array of float
if hasattr(y, "__iter__"):
y_array = np.fromiter(y, dtype=float)
else:
y_array = np.array([y], dtype=float)
def function(y_target):
if detect_range_error and (y_target < l_r or y_target > u_r):
raise ValueError('Value {} is not in lookup table range ({}, {})'.format(y_target, l_r, u_r))
return brentq(lambda x: self(x) - y_target, l_d, u_d)
# Find not np.nan
is_not_nan = ~np.isnan(y_array)
y_array_not_nan = y_array[is_not_nan]
if isinstance(y, Timeseries):
return Timeseries(y.times, self.reverse_call(y.values))
# Detect if there is a range violation
if detect_range_error:
l_r, u_r = self.range
lb_viol = y_array_not_nan < l_r
ub_viol = y_array_not_nan > u_r
all_viol = y_array_not_nan[lb_viol | ub_viol]
if all_viol:
raise ValueError(
"Values {} are not in lookup table range ({}, {})".format(
all_viol, l_r, u_r
)
)
# Construct function to do inverse evaluation
evaluator = self.__numeric_function_evaluator
def inv_evaluator(y_target):
"""inverse evaluator function"""
return brentq(lambda x: evaluator(x) - y_target, l_d, u_d)
inv_evaluator = np.vectorize(inv_evaluator)
# Calculate x_array
x_array = np.full_like(y_array, np.nan, dtype=float)
if y_array_not_nan.size != 0:
x_array[is_not_nan] = inv_evaluator(y_array_not_nan)
# Return x
if hasattr(y, "__iter__"):
return x_array
else:
if hasattr(y, '__iter__'):
evaluator = np.vectorize(function)
return evaluator(y)
else:
return function(y)
return x_array.item()
class CSVLookupTableMixin(OptimizationProblem):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment