from matplotlib import pyplot as plt
from scipy import arange, meshgrid, zeros, cos, sin, log10, linspace, logspace, around, exp, absolute
from scipy.optimize import fmin

from idmlin import IDMLin


def maxStringResponse(idm, n, tau = 0):
    """Return the largest value of the n-fold dB gain over frequency.

    humanstring() supplies the *negated* response, so minimising it with
    fmin locates the peak; the sign is flipped back before returning.

    idm -- model object, forwarded unchanged to humanstring()
    n   -- multiplier applied to the dB gain (see humanstring)
    tau -- forwarded unchanged to humanstring()

    The result keeps the shape of fmin's solution, i.e. a one-element
    array rather than a bare float.
    """
    negated_response = humanstring(idm, n, tau)
    # Simplex search started at w = 10; only a local optimum is guaranteed.
    w_peak = fmin(negated_response, 10)
    return -negated_response(w_peak)

def human(idm, tau=0):
    """Build the complex frequency response gamma(w) for one follower.

    With s = j*w the returned function evaluates

        gamma(w) = (f + g*s) / (f + (g + h)*s + exp(tau*s) * s**2)

    idm -- object exposing the coefficients f, g and h as attributes;
           they are read once, when human() is called
    tau -- coefficient of s inside the exponential (presumably a
           reaction delay; 0 makes the factor equal to 1)
    """
    f = idm.f
    g = idm.g
    k = g + idm.h

    def gamma(w):
        s = w * 1j
        # Second-order term, multiplied by the exp(tau*s) factor.
        delayed_s2 = exp(tau * s) * s * s
        return (f + g * s) / (f + k * s + delayed_s2)

    return gamma

def humanstring(idm, n, tau = 0):
    """Return a function giving minus n times the dB gain of human().

    The returned callable maps a frequency w to

        -n * 20 * log10(|gamma(w)|)

    where gamma = human(idm, tau).  The sign is negated so that a
    minimiser can be used to find the peak gain (as maxStringResponse
    does).

    idm, tau -- forwarded unchanged to human()
    n        -- multiplier on the dB gain (presumably the number of
                identical vehicles in the string)
    """
    single_gain = human(idm, tau)

    def dbgammainv(w):
        magnitude = absolute(single_gain(w))
        return -n * (20 * log10(magnitude))

    return dbgammainv

def followedrobot(idm, ff=None, fb=None, gf=None, gb=None, hr=None,
                  taur=0, tau=0):
    """Return the inverse dB gain of a response that also uses gamma(w).

    With s = j*w and gamma = human(idm, tau), the returned function
    evaluates

        num = ff + gf*s
        den = ff + fb + (gf + gb + hr)*s + exp(taur*s)*s**2
              - (fb + gb*s) * gamma(w)

    and returns 20*log10(|den|) - 20*log10(|num|), i.e. minus the gain
    of num/den in dB.

    idm    -- object exposing f, g and h; also passed on to human()
    ff, fb -- constant gains; each defaults to idm.f/2
    gf, gb -- gains multiplying s; each defaults to idm.g/2
    hr     -- extra gain multiplying s in the denominator; defaults
              to idm.h
    taur   -- coefficient of s in this block's exponential factor
    tau    -- forwarded to human() for the gamma(w) term

    The f/b suffixes presumably stand for forward/backward coupling and
    the r suffix for "robot"; confirm against the model derivation.
    Calling followedrobot(idm) with no extra arguments gives exactly the
    previously hard-coded configuration.
    """
    # None means "derive from idm", matching the former fixed values.
    if ff is None:
        ff = idm.f/2
    if fb is None:
        fb = idm.f/2
    if gf is None:
        gf = idm.g/2
    if gb is None:
        gb = idm.g/2
    if hr is None:
        hr = idm.h

    gamma = human(idm, tau)
    def dbgammainv(w):
        s = w * 1j
        num = ff + gf * s
        den = ff + fb + (gf + gb + hr) * s + exp(taur * s) * s * s - (fb + gb * s) * gamma(w)
        dbnum = 20 * log10(absolute(num))
        dbden = 20 * log10(absolute(den))
        return (dbden-dbnum)
    return dbgammainv

def teststability(idm, N):
    mr = N * idm.maxresponse()
    return max(mr, 0)

# Model parameters: a, t, b and s are handed to IDMLin (s as s0); v is the
# first argument of every idm.go() call in the sweep below.
a, t, b, s, v = 0.45, 1.5, 2.0, 2.0, 33.3

idm = IDMLin(t=t, a=a, b=b, s0=s)

# Log-spaced sweep axes: 20 values of n from 10**1 to 10**2.5 and 20 values
# of the second idm.go() argument from 10**-2 to 10**-0.2.
nrange = logspace(1, 2.5, 20)
rerange = logspace(-2, -0.2, 20)

# Rows follow nrange, columns follow rerange.
grid_shape = (len(nrange), len(rerange))
mgrid = zeros(grid_shape)    # filled from teststability()
mgrid2 = zeros(grid_shape)   # filled from maxStringResponse()

for row, n in enumerate(nrange):
    for col, r_e in enumerate(rerange):
        # idm.go is defined in idmlin (not visible here); presumably it
        # updates the model for this (v, r_e) pair before both measures
        # are evaluated on it.
        idm.go(v, r_e)
        mgrid[row, col] = teststability(idm, n)
        mgrid2[row, col] = maxStringResponse(idm, n)

fig, axarr = plt.subplots(1, 3, sharex=True, sharey=True)

# One heat map per panel, drawn right to left: the teststability() grid,
# the maxStringResponse() grid, then their difference scaled by 100.
heatmaps = (
    (axarr[2], mgrid),
    (axarr[1], mgrid2),
    (axarr[0], (mgrid2 - mgrid) * 100),
)
for panel, values in heatmaps:
    im = panel.imshow(
        values,
        cmap=plt.cm.Reds,
        interpolation='none',
        aspect='auto',
        vmin=0,
        vmax=1,
    )

# Only the first two panels get explicit ticks and labels; the axes are
# shared across the row.
labelled = (axarr[0], axarr[1])

for panel in labelled:
    panel.set_yticks(range(len(nrange)))
axarr[1].set_xticks(range(len(rerange)))

for panel in labelled:
    panel.set_yticklabels(nrange)
axarr[1].set_xticklabels(around(rerange, 2))

for panel in labelled:
    panel.set_ylabel("N")
axarr[1].set_xlabel("R_e")

# Reserve a strip on the right for the colour bar.  `im` is the last image
# drawn (axarr[0]); every panel uses the same vmin/vmax, so one bar serves
# all three.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)

plt.show()