from scipy import log, sqrt, linspace, logspace, meshgrid, zeros, absolute
from scipy.optimize import minimize
import numpy as np


def isstable(model):
    """
    Evaluate the stability criterion of a linearized car-following model.

    :param model: sequence (F, G, H); only the first three entries are read
    :return: 2GH + H^2 - 2F, which is > 0 if stable and < 0 if unstable
    """
    f, g, h = model[0], model[1], model[2]
    return 2 * g * h + h * h - 2 * f


def tf(model, w):
    """
    Return the complex value of the linearized transfer function T(jw).

        T(jw) = (F + j G w) / (F - w^2 + j (G + H) w)

    :param model: (F, G, H) coefficients of the linearized model
    :param w: angular frequency (scalar or numpy array)
    :return: complex T(jw), same shape as w
    """
    # Unpack in the body: tuple parameters in the signature are a
    # SyntaxError on Python 3 (PEP 3113); this form works on both.
    f, g, h = model
    return (g * 1j * w + f) / (-w * w + (g + h) * 1j * w + f)


def tflogmag(model, w):
    """
    Compute log|T(jw)| for the linearized transfer function of `model`.

        log[ |T(jw)| ]
    """
    magnitude = absolute(tf(model, w))
    return log(magnitude)


def htflogmag(model, w):
    """
    Compute log|1 - T(jw)|, the log-magnitude of the linearized headway
    transfer function of `model`.

        log[ |1-T(jw)| ]
    """
    headway = 1 - tf(model, w)
    return log(absolute(headway))


def btf(human, robot, gamma, w):
    """
    Compute the complex linearized transfer function T(jw) of a bilateral
    robot controller [ followed by a human ].

    Derivation:
        Y(robot) = gamma Tr Y(forward) + (1-gamma) Tr Y(behind)
        Y(behind) = Th Y(robot)
        Y(robot) = gamma Tr Y(forward) + (1-gamma) Tr Th Y(robot)
        Y(robot) = [ gamma Tr ] / [ 1 - (1-gamma) Tr Th ] Y(forward)

    :param human: (F, G, H) of the following human
    :param robot: (F, G, H) of the robot
    :param gamma: weight on the car in front (1 - gamma on the car behind)
    :param w: angular frequency
    """
    robot_tf = tf(robot, w)
    human_tf = tf(human, w)
    loop_gain = (1 - gamma) * robot_tf * human_tf
    return gamma * robot_tf / (1 - loop_gain)


def btflogmag(human, robot, gamma, w):
    """
    Compute log|T(jw)| for a bilateral robot controller [ followed by a
    human ], where T is given by `btf`.

        log[ |T(jw)| ]
    """
    magnitude = absolute(btf(human, robot, gamma, w))
    return log(magnitude)


def bhtflogmag(human, robot, gamma, w):
    """
    Compute log|1 - T(jw)|, the headway log-magnitude of a bilateral robot
    controller [ followed by a human ], where T is given by `btf`.

        log[ |1-T(jw)| ]
    """
    headway = 1 - btf(human, robot, gamma, w)
    return log(absolute(headway))


def kstable(human, robot, w, eta=2):
    """
    Maximum number of human cars that a single robot car can string-stabilize
    against an oscillatory perturbation of frequency w:

        -log|Tr(jw)| / log|Th(jw)|

    :param eta: unused; accepted so this function shares the
        (human, robot, w, eta) call signature that `maxkfn` expects
    """
    robot_term = tflogmag(robot, w)
    human_term = tflogmag(human, w)
    return -robot_term / human_term


def ksafe(human, robot, w, eta):
    """
    Maximum number of human cars that a single robot car can "safely" lead
    for an oscillatory perturbation of frequency w, under safety margin eta:

        (log(eta) - log|1 - Tr(jw)|) / log|Th(jw)|

    :param human: (F, G, H) of the human driver model
    :param robot: (F, G, H) of the robot controller
    :param w: angular frequency of the perturbation
    :param eta: safety margin
    :return: the ratio above
    """
    numerator = log(eta) - htflogmag(robot, w)
    return numerator / tflogmag(human, w)


TOL = 1.0e-6  # margin kept between frequency search bounds and their limits


def maxkfn(human, robot, eta, kfn):
    """
    Find the maximum number of human cars that a single robot car can handle
    for any frequency, i.e. the minimum of kfn over w.

    The search interval is (TOL, maxw - TOL) with maxw = sqrt(-isstable(human)).
    From the transfer function in `tf`, |Th(jw)| > 1 exactly when
    0 < w < maxw, so tflogmag(human, w) > 0 over this interval.

    :param human: (F, G, H) of the human; must be unstable (isstable < 0)
    :param robot: (F, G, H) of the robot controller
    :param eta: safety margin forwarded to kfn
    :param kfn: function { ksafe for safety | kstable for string stability }
    :return: scipy OptimizeResult; `.x` is the worst-case frequency and
        `.fun` the corresponding k (a scalar or 1-element array, depending
        on the scipy version)
    :raises ValueError: if the human model is not unstable
    """
    margin = -isstable(human)
    # `not margin > 0` also rejects nan. A stable human would otherwise give
    # a complex/nan maxw and meaningless bounds.
    if not margin > 0:
        raise ValueError("human model must be unstable (isstable(human) < 0);"
                         " got isstable = %r" % (-margin,))
    maxw = sqrt(margin)
    lower, upper = TOL, maxw - TOL
    # Keep the historical starting point w = 1, but clip it into the
    # feasible interval so it is valid even when maxw < 1.
    w0 = min(max(1.0, lower), upper)
    return minimize(lambda w: kfn(human, robot, w, eta), w0,
                    bounds=[(lower, upper)])


def maxk(human, robot, eta):
    """
    Combine the string-stability and safety limits: return whichever of
    maxkfn(..., kstable) and maxkfn(..., ksafe) has the smaller optimum.

    :param human: (F, G, H) of the human driver model
    :param robot: (F, G, H) of the robot controller
    :param eta: safety margin
    :return: the scipy OptimizeResult with the smaller `.fun`; ties go to
        the kstable result
    """
    results = (maxkfn(human, robot, eta, kstable),
               maxkfn(human, robot, eta, ksafe))
    # Compare on the optimum value itself. OptimizeResult objects are dicts,
    # and ordering dicts is a TypeError on Python 3 (and arbitrary on 2).
    # np.ravel handles `fun` as either a scalar or a 1-element array.
    return min(results, key=lambda res: np.ravel(res.fun)[0])


def best_robot_at_w(human, w, optfun, eta=2, fbound=(0, 0.2), gbound=(0, 0.2),
                    hbound=(0, 0.2)):
    """
    Search for the robot (F, G, H) maximizing optfun(human, robot, w, eta)
    at a single frequency w, subject to isstable(robot) >= 0 and the given
    box bounds.

    :return: scipy OptimizeResult; `.fun` is the NEGATED optimum
    """
    def objective(robot):
        return -optfun(human, robot, w, eta)

    def stability_margin(x):
        return isstable(x)

    start = [0.001, 0.001, 0.001]
    box = [fbound, gbound, hbound]
    return minimize(objective, start,
                    constraints=({'type': 'ineq', 'fun': stability_margin},),
                    bounds=box)


def bestrobot(human, eta=2, fbound=(0, 0.2), gbound=(0, 0.2), hbound=(0, 0.2)):
    """
    Find the robot controller (F, G, H) that maximizes `maxk`, the number of
    human cars it can handle over all frequencies. The search is constrained
    to isstable(robot) >= 0 and the given box bounds.

    :param human: (F, G, H) of the human driver model
    :param eta: safety margin
    :param fbound: (low, high) bound on the robot's F
    :param gbound: (low, high) bound on the robot's G
    :param hbound: (low, high) bound on the robot's H
    :return: scipy OptimizeResult; `.x` is the best robot and `.fun` the
        NEGATED maximum k
    """
    def neg_maxk(robot):
        # `.fun` is a 1-element array in older scipy and a scalar in newer
        # releases; np.ravel(...)[0] accepts both.
        return -np.ravel(maxk(human, robot, eta).fun)[0]

    constraints = ({'type': 'ineq', 'fun': lambda x: isstable(x)},)
    return minimize(neg_maxk, [0.001, 0.001, 0.001], constraints=constraints,
                    bounds=[fbound, gbound, hbound])


def plot_kmax_vs_w(human, eta=2):
    """
    For 101 frequencies w in (TOL, maxw - TOL), optimize a separate robot for
    that single frequency (via `best_robot_at_w`). Plot the resulting kstable
    (black) and ksafe (red) values against w.

    The original author noted: "I think this plot is not meaningful (after
    the fact)".

    :param human: (F, G, H) of the human; expected to be unstable
    :param eta: safety margin used for the ksafe curve
    """
    # Import locally: previously `plt` was only bound when this file ran as
    # __main__, so calling this function from an importing module raised
    # NameError.
    from matplotlib import pyplot as plt

    maxw = sqrt(-isstable(human))

    ws = linspace(TOL, maxw - TOL, 101)
    kstablerobots = [-best_robot_at_w(human, w, kstable).fun for w in ws]
    ksaferobots = [-best_robot_at_w(human, w, ksafe, eta).fun for w in ws]
    # Same output as the old Python 2 `print a, b` statement, and valid on
    # Python 3 as well.
    print("%s %s" % (ksaferobots, kstablerobots))
    plt.plot(ws, kstablerobots, 'k.')
    plt.plot(ws, ksaferobots, 'r.')
    plt.axis([0, maxw, 0, 200])
    plt.show()


def plot_kmaxs_vs_FGH(human, robotopt, eta=2, res=100, fbound=(0, 0.2),
                      gbound=(0, 0.2), hbound=(0, 0.3)):
    """
    How do the two optimality conditions (stability and safety) vary as the
    robot parameters deviate from the optimal robot controller?

    Each of F, G, H is swept across its bound in turn, with the other two held
    at `robotopt`. For each sweep, one figure plots maxkfn(kstable) (black,
    clamped at 0) and maxkfn(ksafe) (red), with a vertical line at the
    optimal value.

    :param human: (F, G, H) of the human driver model
    :param robotopt: (F, G, H) for the optimal robot controller given human
    :param eta: safety margin
    :param res: resolution for plotting, how fine to plot along the x-axis
    :param fbound: (low, high) sweep range for F
    :param gbound: (low, high) sweep range for G
    :param hbound: (low, high) sweep range for H
    """
    # Import locally: previously `plt` was only bound when this file ran as
    # __main__, so this function raised NameError when imported elsewhere.
    from matplotlib import pyplot as plt

    # Force a float array. An integer robotopt would otherwise make np.tile
    # produce an int array and truncate the swept values assigned into it.
    robotopt = np.asarray(robotopt, dtype=float)

    robotFs = linspace(fbound[0], fbound[1], res)
    robotGs = linspace(gbound[0], gbound[1], res)
    robotHs = linspace(hbound[0], hbound[1], res)

    for i, robotXs in enumerate([robotFs, robotGs, robotHs]):
        robots = np.tile(robotopt, (res, 1))
        robots[:, i] = robotXs
        # np.ravel(...)[0] works whether `.fun` is a scalar (newer scipy) or
        # a 1-element array (older scipy).
        kstablerobots = np.array(
            [np.ravel(maxkfn(human, robot, eta, kstable).fun)[0]
             for robot in robots])
        ksaferobots = np.array(
            [np.ravel(maxkfn(human, robot, eta, ksafe).fun)[0]
             for robot in robots])
        kstablerobots[kstablerobots < 0] = 0
        plt.plot(robotXs, kstablerobots, 'k.')
        plt.plot(robotXs, ksaferobots, 'r.')
        plt.vlines(robotopt[i], 0, max(max(kstablerobots), max(ksaferobots)))
        plt.axis([0, max(robotXs), 0, 100])
        plt.show()


if __name__ == "__main__":
    # `plt` is bound at module scope here because some plotting helpers in
    # this file may still reference it as a global.
    from matplotlib import pyplot as plt
    from cfm import IDM

    # Build an IDM human driver and read its linearized (f, g, h)
    # coefficients. What go(0.5) does is defined in cfm (not shown here);
    # presumably it computes the linearization. TODO confirm.
    hmodel = IDM(a=0.3, b=3, t=1.5, s0=2, v0=30)
    hmodel.go(0.5)
    human = (hmodel.f, hmodel.g, hmodel.h)

    eta = 2

    robotopt = bestrobot(human, eta)
    # Pass eta explicitly so the plot uses the same safety margin as the
    # optimization, not the function's own default.
    plot_kmaxs_vs_FGH(human, robotopt.x, eta=eta)