from scipy import log, sqrt, linspace, logspace, meshgrid, zeros, absolute
from scipy.optimize import minimize

def isstable(model):
    """
    Return the string-stability margin 2*g*h + h**2 - 2*f of a linearized
    car-following model given as the sequence (f, g, h).

    For the transfer function computed by tf(), |T(jw)| <= 1 at every w
    exactly when this margin is >= 0; when it is negative, |T(jw)| > 1 for
    0 < w < sqrt(-margin).
    """
    f = model[0]
    g = model[1]
    h = model[2]
    return 2 * g * h + h * h - 2 * f

def tf(model, w):
    """
    Return the complex value of the linearized transfer function T(jw)

        T(jw) = (f + j*g*w) / (f - w**2 + j*(g + h)*w)

    for the model (f, g, h) at frequency w (a scalar or an array).

    model may be any length-3 sequence (tuple, list or array). It is unpacked
    in the body rather than in the signature because tuple parameters are a
    SyntaxError in Python 3 (PEP 3113).
    """
    f, g, h = model
    return (g*1j*w + f)/(-w*w + (g+h)*1j*w + f)

def tflogmag(model, w):
    """
    Log-magnitude of the linearized transfer function at frequency w:
        log[ |T(jw)| ]
    Positive where the model amplifies a perturbation (|T| > 1), negative
    where it damps it.
    """
    gain = absolute(tf(model, w))
    return log(gain)

def htflogmag(model, w):
    """
    Log-magnitude of the linearized headway transfer function at frequency w:
        log[ |1-T(jw)| ]
    """
    headway_response = 1 - tf(model, w)
    return log(absolute(headway_response))

def btf(human, robot, gamma, w):
    """
    Complex closed-loop transfer function T(jw) of a bilateral robot car
    that is followed by a human car.

    The robot weights the car ahead by gamma and the car behind by (1-gamma):
        Y(robot)  = gamma Tr Y(forward) + (1-gamma) Tr Y(behind)
        Y(behind) = Th Y(robot)
    Substituting the second equation into the first and solving for Y(robot):
        Y(robot) = [ gamma Tr ] / [ 1 - (1-gamma) Tr Th ] Y(forward)
    where Tr = tf(robot, w) and Th = tf(human, w).
    """
    t_robot = tf(robot, w)
    t_human = tf(human, w)
    forward_gain = gamma * t_robot
    loop_gain = (1 - gamma) * t_robot * t_human
    return forward_gain / (1 - loop_gain)

def btflogmag(human, robot, gamma, w):
    """
    Log-magnitude log[ |T(jw)| ] of the closed-loop transfer function btf()
    of a bilateral robot controller [ followed by a human ], where gamma
    weights the car ahead of the robot.
    """
    closed_loop = btf(human, robot, gamma, w)
    return log(absolute(closed_loop))

def bhtflogmag(human, robot, gamma, w):
    """
    Log-magnitude log[ |1-T(jw)| ] of the headway transfer function derived
    from btf(), for a bilateral robot controller [ followed by a human ].
    """
    closed_loop = btf(human, robot, gamma, w)
    headway = 1 - closed_loop
    return log(absolute(headway))

def kstable(human, robot, w, eta=2):
    """
    Return the maximum number of human cars that a single robot car can
    string-stabilize against an oscillatory perturbation of frequency w,
    i.e. the k solving |Tr(jw)| * |Th(jw)|**k = 1:
        k = -log|Tr(jw)| / log|Th(jw)|

    eta is unused. It is accepted so that kstable and ksafe share one
    signature and either can be passed as kfn to maxkfn.
    """
    robot_log_gain = tflogmag(robot, w)
    human_log_gain = tflogmag(human, w)
    return -robot_log_gain / human_log_gain

def ksafe(human, robot, w, eta=2):
    """
    Return the maximum number of human cars
        that can be "safely" followed by a single robot car
        for an oscillatory perturbation of frequency w,
        i.e. the k solving |1 - Tr(jw)| * |Th(jw)|**k = eta:
            k = (log(eta) - log|1 - Tr(jw)|) / log|Th(jw)|

    eta is the allowed bound on |1 - Tr| * |Th|**k. It defaults to 2,
    matching kstable, so the two functions share one signature.
    Where log|Th(jw)| > 0 (w below sqrt(-isstable(human))), any k up to the
    returned value keeps that product <= eta.
    """
    return (log(eta) - htflogmag(robot, w)) / tflogmag(human, w)

# Margin that keeps maxkfn's frequency search strictly inside (0, maxw):
# its bounds are (TOL, maxw - TOL).
TOL = 1e-6

def maxkfn(human, robot, eta, kfn):
    """
    Minimize kfn(human, robot, w, eta) over the band of frequencies in which
    the human model amplifies perturbations, TOL <= w <= sqrt(-isstable(human)) - TOL.

    The minimum over w is the worst case: the largest k that holds at every
    frequency in the band. This is a bounded local search (SciPy picks
    L-BFGS-B when only bounds are given), so it may miss the global minimum.

    Returns the scipy.optimize.OptimizeResult: ``fun`` holds the minimum
    and ``x`` the minimizing frequency.

    Raises ValueError if the band is empty (isstable(human) is not negative
    enough). Otherwise the bounds would be NaN, complex or inverted.
    """
    margin = -isstable(human)
    # `not (... > ...)` also rejects NaN margins.
    if not margin > (2 * TOL) ** 2:
        raise ValueError(
            "maxkfn needs isstable(human) < -(2*TOL)**2 so that the search "
            "band (TOL, sqrt(-isstable(human)) - TOL) is non-empty; got %r"
            % (-margin,))
    maxw = sqrt(margin)
    # Start in the middle of the band: a fixed guess of 1 lies outside the
    # bounds whenever maxw < 1.
    return minimize(lambda w: kfn(human, robot, w, eta), 0.5 * maxw,
                    bounds=[(TOL, maxw - TOL)])

def maxk(human, robot, eta):
    """
    Return the more restrictive of the two limits maxkfn(..., kstable) and
    maxkfn(..., ksafe): the scipy OptimizeResult whose ``fun`` (the
    worst-case k over frequency) is smaller. Ties return the kstable result.

    The results are compared by ``fun`` explicitly. OptimizeResult is a dict
    subclass, so a bare min() would not compare the k values: Python 2
    orders dicts by their contents, and Python 3 raises TypeError.
    """
    stable = maxkfn(human, robot, eta, kstable)
    safe = maxkfn(human, robot, eta, ksafe)
    return min(stable, safe, key=lambda result: result.fun)

# NOTE(review): disabled code kept as a string literal, so it never runs.
# It calls maxkstable/maxksafe, which are not defined anywhere visible here;
# see bestrobot() for the active search over robot parameters.
'''
def bestrobotstable(human):
    cons = ({'type': 'ineq', 'fun': lambda x: 2*x[1]*x[2] + x[2]*x[2] - 2*x[0] },)
    return minimize(lambda robot : -maxkstable(human, robot), 
            (0.1, 0, 0.1), 
            bounds=[(0, 0.2), (0, 0), (0, 0.2)],
            constraints=cons)

def bestrobotsafe(human, eta):
    cons = ({'type': 'ineq', 'fun': lambda x: 2*x[1]*x[2] + x[2]*x[2] - 2*x[0] },)
    return minimize(lambda robot : -maxksafe(human, robot, eta), 
            (0.1, 0, 0.1), 
            bounds=[(0, 0.2), (0, 0), (0, 0.2)],
            constraints=cons)
'''

def bestrobot(human, eta):
    """
    Search for the robot model (f, g, h) that maximizes the k reported by
    maxk(human, robot, eta).

    g is pinned to 0 by its (0, 0) bound; f and h are searched in [0, 0.2],
    starting from (0.001, 0, 0.001). Returns the scipy OptimizeResult of
    the search: ``x`` is the best robot found and ``-fun`` its k. The search
    is local, so the result may depend on the starting point.
    """
    import numpy as np

    def neg_k(robot):
        # maxk's ``fun`` is a 1-element array in older SciPy but a scalar in
        # newer versions. ravel() handles both; the former ``fun[0]`` raised
        # TypeError on a scalar.
        return -float(np.ravel(maxk(human, robot, eta).fun)[0])

    # The string-stability constraint on the robot is currently disabled:
    # cons = ({'type': 'ineq', 'fun': lambda x: 2*x[1]*x[2] + x[2]*x[2] - 2*x[0] },)
    return minimize(neg_k,
            [0.001, 0, 0.001],
            #constraints=cons,
            bounds=[(0, 0.2), (0, 0), (0, 0.2)])

if __name__ == "__main__":
    # Demo: build an IDM human model with the project-local cfm module, take
    # its (f, g, h) coefficients, and search for the best robot controller.
    from mpl_toolkits.mplot3d import Axes3D
    from matplotlib import pyplot as plt
    from cfm import IDM, CFM

    hmodel = IDM(a=0.3,b=3,t=1.5,s0=2,v0=30)
    hmodel.go(0.5)
    human = (hmodel.f, hmodel.g, hmodel.h)

    # Upper edge of the band in which the human amplifies perturbations;
    # only the disabled plotting code below uses it.
    maxw = sqrt(-isstable(human))

    eta=2

    # print() with a single argument prints the same in Python 2 and, unlike
    # the print statement, is also valid Python 3 syntax.
    print(bestrobot(human, eta))

    # NOTE(review): disabled exploration code kept as a string literal. It
    # refers to names not defined here (robot, bestrobotstable, bestrobotsafe)
    # and treats maxk's OptimizeResult as a number; update before re-enabling.
    '''
    ws = linspace(TOL, maxw-TOL, 101)
    kst = kstable(human, robot, ws)
    ksf = ksafe(human, robot, ws, eta)
    # print "kstable = ", kstable(human, robot, ws)
    plt.plot(ws, kst)
    plt.plot(ws, ksf)
    plt.axis([0, maxw, 0, 100])
    plt.show()

    frng = linspace(0, 0.1, 21)
    hrng = linspace(0, 0.1, 21)
    fs, hs = meshgrid(frng, hrng)
    ks = zeros(fs.shape)
    for fi, f in enumerate(frng):
        for hi, h in enumerate(hrng):
            ks[fi, hi] = maxk(human, (f, 0, h), eta)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(fs, hs, ks)
    plt.show()

    # print maxkstable(human, robot)
    # print maxksafe(human, robot, eta)
    print maxk(human, robot, eta)

    print 
    print "***"
    print 

    print bestrobotstable(human)
    print bestrobotsafe(human, eta)
    # print bestrobot(human, eta)
    '''