Commit 31b33c78 authored by Jorn Baayen's avatar Jorn Baayen

Vector support in extract_states/extract_controls

parent c78f3298
......@@ -1006,11 +1006,13 @@ class CollocatedIntegratedOptimizationProblem(OptimizationProblem, metaclass = A
# Extract control inputs
results = {}
offset = 0
for variable in self.controls:
n_times = len(self.times(variable))
results[variable] = np.array(self.variable_nominal(
variable) * X[offset:offset + n_times, 0]).ravel()
offset += n_times
for variable in self.dae_variables['control_inputs']:
variable_name = variable.name()
variable_numel = variable.numel()
n_times = len(self.times(variable_name))
results[variable_name] = np.array(self.variable_nominal(
variable_name) * X[offset:offset + n_times * variable_numel, 0]).reshape((n_times, variable_numel))
offset += n_times * variable_numel
# Done
return results
......@@ -1293,38 +1295,44 @@ class CollocatedIntegratedOptimizationProblem(OptimizationProblem, metaclass = A
# Extract collocated variables
offset = control_size + ensemble_member * ensemble_member_size
for variable in itertools.chain(self.differentiated_states, self.algebraic_states):
if variable in self.integrated_states:
offset += 1
for variable in itertools.chain(self.dae_variables['states'], self.dae_variables['algebraics']):
variable_name = variable.name()
variable_numel = variable.numel()
if variable_name in self.integrated_states:
offset += variable_numel
else:
n_times = len(self.times(variable))
results[variable] = np.array(self.variable_nominal(
variable) * X[offset:offset + n_times, 0]).ravel()
offset += n_times
n_times = len(self.times(variable_name))
results[variable_name] = np.array(self.variable_nominal(
variable_name) * X[offset:offset + n_times * variable_numel, 0]).reshape((n_times, variable_numel))
offset += n_times * variable_numel
# Extract constant input aliases
constant_inputs = self.constant_inputs(ensemble_member)
for variable in self.dae_variables['constant_inputs']:
variable = variable.name()
variable_name = variable.name()
variable_numel = variable.numel()
try:
constant_input = constant_inputs[variable]
constant_input = constant_inputs[variable_name]
except KeyError:
pass
else:
results[variable] = np.interp(self.times(variable), constant_input.times, constant_input.values)
results[variable_name] = np.interp(self.times(variable_name), constant_input.times, constant_input.values)
# Extract path variables
n_collocation_times = len(self.times())
for variable in self.path_variables:
variable = variable.name()
results[variable] = np.array(
X[offset:offset + n_collocation_times, 0]).ravel()
offset += n_collocation_times
variable_name = variable.name()
variable_numel = variable.numel()
results[variable_name] = np.array(
X[offset:offset + n_collocation_times * variable_numel, 0]).reshape((n_collocation_times, variable_numel))
offset += n_collocation_times * variable_numel
# Extract extra variables
for k in range(len(self.extra_variables)):
variable = self.extra_variables[k].name()
results[variable] = np.array(X[offset + k, 0]).ravel()
for variable in self.extra_variables:
variable_name = variable.name()
variable_numel = variable.numel()
results[variable_name] = np.array(X[offset:offset + variable_numel, 0]).reshape((1, variable_numel))
offset += variable_numel
# Done
return results
......
......@@ -239,9 +239,12 @@ class ControlTreeMixin(OptimizationProblem):
# Extract control inputs
results = {}
for variable in self.controls:
results[variable] = np.array(self.variable_nominal(
variable) * X[self._control_indices[variable][ensemble_member, :], 0]).ravel()
for variable in self.dae_variables['control_inputs']:
variable_name = variable.name()
variable_numel = variable.numel()
n_times = len(self.times(variable_name))
results[variable_name] = np.array(self.variable_nominal(
variable_name) * X[self._control_indices[variable][ensemble_member, :], 0]).reshape((n_times, variable_numel))
# Done
return results
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment