from matplotlib import pyplot as plt
from scipy import arange, meshgrid, zeros, cos, sin, log10, linspace, logspace, around, exp, absolute
from scipy.optimize import fmin

from idmlin import IDMLin


def maxStringResponse(fn, idm, n, **kwargs):
    """Return the peak of a string response located with ``fmin``, or -1.

    ``fn(idm, n, **kwargs)`` must return a callable of angular frequency
    ``w`` that gives the *negated* response (as ``humanstring`` and
    ``followedrobotstring`` do). Minimising that callable with
    ``scipy.optimize.fmin`` therefore locates the peak response.

    Parameters
    ----------
    fn : callable
        Factory ``fn(idm, n, **kwargs)`` building the negated response.
    idm : object
        Model object forwarded unchanged to ``fn``.
    n : float
        Count forwarded unchanged to ``fn``.
    **kwargs
        Extra keyword arguments forwarded to ``fn``.

    Returns
    -------
    float or int
        The peak value ``-gammainv(wmax)`` as a plain float, or the
        sentinel ``-1`` when that peak is below ``1e-2``.
    """
    gammainv = fn(idm, n, **kwargs)
    # fmin starts its local search at w = 10 and returns a length-1 array;
    # take the element so everything downstream is a scalar.
    wmax = fmin(gammainv, 10)[0]
    # float() gives callers a plain scalar, so storing it into a 2-D grid
    # cell no longer relies on deprecated ndarray-to-scalar conversion.
    gmax = float(-gammainv(wmax))
    return -1 if gmax < 1e-2 else gmax

def human(idm, tau=0):
    """Build the complex frequency response of a single human-driven car.

    The coefficients are read from ``idm.f``, ``idm.g`` and ``idm.h``.
    The returned closure evaluates, with ``s = j*w``::

        (f + g*s) / (f + (g + h)*s + exp(tau*s) * s**2)

    Parameters
    ----------
    idm : object
        Provides the linear coefficients ``f``, ``g`` and ``h``.
    tau : float, optional
        Enters the denominator as ``exp(tau*s)`` on the ``s**2`` term
        (presumably a reaction delay; 0 disables it).

    Returns
    -------
    callable
        ``gamma(w)`` mapping angular frequency ``w`` to a complex gain.
    """
    f_coef = idm.f
    g_coef = idm.g
    g_plus_h = g_coef + idm.h

    def gamma(w):
        s = w * 1j
        numerator = f_coef + g_coef * s
        denominator = f_coef + g_plus_h * s + exp(tau * s) * s * s
        return numerator / denominator

    return gamma

def humanstring(idm, n, tau=0):
    """Build the negated dB response of a string of ``n`` human cars.

    The returned callable gives ``-n * 20*log10(|gamma(w)|)``, where
    ``gamma`` is the single-car response from ``human(idm, tau)``. The
    sign is flipped so a minimiser (see ``maxStringResponse``) finds the
    peak.

    Parameters
    ----------
    idm : object
        Model forwarded to ``human``.
    n : float
        Multiplier applied to the single-car dB gain.
    tau : float, optional
        Forwarded to ``human``.

    Returns
    -------
    callable
        ``dbgammainv(w)`` returning the negated, ``n``-scaled dB gain.
    """
    single_car = human(idm, tau)

    def dbgammainv(w):
        return -n * (20 * log10(absolute(single_car(w))))

    return dbgammainv

def followedrobotstring(idm, n, tau=0, alpha=0.5, taur=0):
    """Build the negated dB response of ``n`` human cars plus one robot car.

    With ``s = j*w`` and ``gamma = human(idm, tau)``, the returned callable
    gives ``-(20*log10|R(w)| + n*20*log10|gamma(w)|)``, where::

        R = (ff + gf*s) /
            (ff + fb + (gf + gb + hr)*s + exp(taur*s)*s**2
             - (fb + gb*s) * gamma)

    The sign is flipped so a minimiser (see ``maxStringResponse``) finds the
    peak.

    Parameters
    ----------
    idm : object
        Provides the coefficients ``f``, ``g`` and ``h``.
    n : float
        Multiplier applied to the human-string dB gain.
    tau : float, optional
        Forwarded to ``human`` / ``humanstring``.
    alpha : float, optional
        Fraction of ``idm.f`` and ``idm.g`` assigned to the ``fb``/``gb``
        terms; the remaining ``1 - alpha`` goes to ``ff``/``gf``
        (presumably back/front couplings of the robot -- confirm).
    taur : float, optional
        Enters as ``exp(taur*s)`` on the robot's ``s**2`` term. The default
        of 0 reproduces the previously hard-coded behaviour.

    Returns
    -------
    callable
        ``dbgammainv(w)`` returning the negated combined dB gain.
    """
    fb = idm.f * alpha
    ff = idm.f * (1 - alpha)
    gb = idm.g * alpha
    gf = idm.g * (1 - alpha)
    hr = idm.h

    gamma = human(idm, tau)
    stringinv = humanstring(idm, n, tau)

    def dbgammainv(w):
        s = w * 1j
        num = ff + gf * s
        # The (fb + gb*s) coupling is weighted by the human response gamma.
        den = (ff + fb + (gf + gb + hr) * s + exp(taur * s) * s * s
               - (fb + gb * s) * gamma(w))
        dbnum = 20 * log10(absolute(num))
        dbden = 20 * log10(absolute(den))
        # dbden - dbnum == -20*log10|num/den|; stringinv adds the negated
        # dB gain of the n human cars.
        return (dbden - dbnum) + stringinv(w)

    return dbgammainv

def teststability(idm, N):
    mr = N * idm.maxresponse()
    return max(mr, 0)

# Model parameters passed to IDMLin (presumably the standard IDM
# a / T / b / s0 parameters -- confirm against idmlin) and the speed v
# later passed to idm.go().
a = 0.45
t = 1.5
b = 2.0
s = 2.0
v = 33.3

idm = IDMLin(t=t, a=a, b=b, s0=s)

# Sweep grid: N log-spaced from 10 to 10**2.5, R_e linearly spaced from
# 0.1 to 0.4. (A previous log-spaced R_e range was dead code, since it was
# overwritten immediately.)
nrange = logspace(1, 2.5, 20)
rerange = linspace(0.1, 0.4, 20)

# One result grid per configuration: rows index nrange, columns rerange.
grid_shape = (len(nrange), len(rerange))
mgrid_h = zeros(grid_shape)
mgrid_r25 = zeros(grid_shape)
mgrid_r50 = zeros(grid_shape)
mgrid_r75 = zeros(grid_shape)

for i, n in enumerate(nrange):
    # Loop variable renamed from `re` to avoid shadowing the stdlib module.
    for j, r_e in enumerate(rerange):
        # Re-linearise the model for this (v, R_e) operating point.
        idm.go(v, r_e)
        mgrid_h[i, j] = maxStringResponse(humanstring, idm, n)
        mgrid_r25[i, j] = maxStringResponse(followedrobotstring, idm, n, alpha=0.25)
        mgrid_r50[i, j] = maxStringResponse(followedrobotstring, idm, n, alpha=0.5)
        mgrid_r75[i, j] = maxStringResponse(followedrobotstring, idm, n, alpha=0.75)

fig, axarr = plt.subplots(2, 3, sharex=True, sharey=True)

# (row, col), image data and title for each populated panel; axarr[1, 1]
# is left empty. Every panel uses the same bwr colour map clipped to
# [-1, 1], so the last image can drive the shared colour bar below.
panels = [
    ((0, 0), mgrid_h, "Human cars only"),
    ((0, 1), mgrid_r50, "Human cars followed by BCM50 robot"),
    ((0, 2), mgrid_r50 - mgrid_h, "Difference"),
    ((1, 0), mgrid_r25, "Human cars followed by BCM25 robot"),
    ((1, 2), mgrid_r75, "Human cars followed by BCM75 robot"),
]
for pos, data, title in panels:
    ax = axarr[pos]
    im = ax.imshow(data, cmap=plt.cm.bwr,
                   interpolation='none',
                   aspect='auto',
                   vmin=-1, vmax=1,
                   )
    ax.set_title(title)

# Axes are shared (sharex/sharey), so ticks set on axarr[0, 0] apply to
# every panel. Label every tick_step-th grid value.
tick_step = 4
axarr[0, 0].set_yticks(range(0, len(nrange), tick_step))
axarr[0, 0].set_xticks(range(0, len(rerange), tick_step))

axarr[0, 0].set_yticklabels(around(nrange[::tick_step], 0))
axarr[0, 0].set_xticklabels(around(rerange[::tick_step], 2))

axarr[0, 0].set_ylabel("N")
axarr[0, 0].set_xlabel("R_e")

# Reserve the right-hand margin for a single colour bar shared by all panels.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)

# Function-call form: valid on both Python 2 and Python 3, same output.
print(nrange)
print(rerange)

plt.show()