Commit 00dbba91 authored by Jorn Baayen's avatar Jorn Baayen

Clean up casadi helpers.

parent 425b6925
from casadi import MX, Function, jacobian, vertcat, reshape, mtimes, substitute, interpolant, transpose, repmat
import casadi as ca
import numpy as np
import logging
......@@ -6,7 +6,7 @@ logger = logging.getLogger("rtctools")
def is_affine(e, v):
Af = Function('f', [v], [jacobian(e, v)])
Af = ca.Function('f', [v], [ca.jacobian(e, v)])
return (Af.sparsity_jac(0, 0).nnz() == 0)
......@@ -15,9 +15,9 @@ def nullvertcat(*L):
Like vertcat, but creates an MX with consistent dimensions even if L is empty.
"""
if len(L) == 0:
return MX(0, 1)
return ca.MX(0, 1)
else:
return vertcat(*L)
return ca.vertcat(*L)
def reduce_matvec(e, v):
......@@ -26,13 +26,13 @@ def reduce_matvec(e, v):
This reduces the number of nodes required to represent the linear operations.
"""
Af = Function('Af', [MX()], [jacobian(e, v)])
A = Af(MX())
return reshape(mtimes(A, v), e.shape)
Af = ca.Function('Af', [ca.MX()], [ca.jacobian(e, v)])
A = Af(ca.MX())
return ca.reshape(ca.mtimes(A, v), e.shape)
def substitute_in_external(expr, symbols, values):
f = Function('f', symbols, expr)
f = ca.Function('f', symbols, expr)
return f.call(values, True, False)
......@@ -40,18 +40,18 @@ def interpolate(ts, xs, t, equidistant, mode=0):
if False: # TODO mode == 0:
print(ts)
print(xs)
return interpolant('interpolant', 'linear', [ts], xs, {'lookup_mode': 'exact'})(t)
return ca.interpolant('interpolant', 'linear', [ts], xs, {'lookup_mode': 'exact'})(t)
else:
if mode == 1:
xs = xs[:-1] # block-forward
else:
xs = xs[1:] # block-backward
t = MX(t)
t = ca.MX(t)
if t.size1() > 1:
t_ = MX.sym('t')
xs_ = MX.sym('xs', xs.size1())
f = Function('interpolant', [t_, xs_], [mtimes(transpose((t_ >= ts[:-1]) * (t_ < ts[1:])), xs_)])
t_ = ca.MX.sym('t')
xs_ = ca.MX.sym('xs', xs.size1())
f = ca.Function('interpolant', [t_, xs_], [ca.mtimes(ca.transpose((t_ >= ts[:-1]) * (t_ < ts[1:])), xs_)])
f = f.map(t.size1(), 'serial')
return transpose(f(transpose(t), repmat(xs, 1, t.size1())))
return ca.transpose(f(ca.transpose(t), ca.repmat(xs, 1, t.size1())))
else:
return mtimes(transpose((t >= ts[:-1]) * (t < ts[1:])), xs)
\ No newline at end of file
return ca.mtimes(ca.transpose((t >= ts[:-1]) * (t < ts[1:])), xs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment