from matplotlib import pyplot as plt
from scipy import arange, meshgrid, zeros, cos, sin, log10, linspace, logspace, around
from scipy.optimize import fmin

from idmlin import IDMLin


def maxStringResponse(idm, n, tau=0):
    """Return the peak string response for ``n`` vehicles, in dB.

    Builds ``gammainv = stringResponse(idm, n, tau)``, an inverted (negated)
    dB response, and minimises it with ``scipy.optimize.fmin``, starting
    from w = 10. The negated minimum is the maximum response.

    ``fmin`` is a local Nelder-Mead search, so the result is the optimum it
    reaches from that starting point. It is not guaranteed to be the global
    maximum over all frequencies.

    Parameters
    ----------
    idm : object
        Must expose the ``f``, ``g`` and ``h`` attributes read by
        ``stringResponse``.
    n : number
        Multiplier applied to the dB response (see ``stringResponse``).
    tau : float, optional
        Delay passed through to ``stringResponse`` (default 0).

    Returns
    -------
    scalar
        ``-gammainv(w*)`` at the frequency ``w*`` found by ``fmin``.
    """
    gammainv = stringResponse(idm, n, tau)
    # disp=False: fmin otherwise prints a convergence report on every call,
    # and this function is called once per grid point.
    wmax = fmin(gammainv, 10, disp=False)
    # fmin returns a length-1 array. Evaluate at the scalar so callers get a
    # scalar that can be stored directly into a grid cell.
    return -gammainv(wmax[0])

def stringResponse(idm, n, tau=0):
    """Build the inverted ``n``-fold dB response as a function of frequency.

    Reads ``idm.f``, ``idm.g`` and ``idm.h``, with ``k = g + h``. The
    returned ``gammainv(w)`` equals ``10 * n * log10(den / num)``, where

        num = |f + i*g*w|**2
        den = |f + i*k*w - w**2 * exp(i*tau*w)|**2

    That is, ``-20 * n * log10|H(w)|`` with ``H(w) = (f + i g w) /
    (f + i k w - w^2 e^{i tau w})``. Minimising ``gammainv`` therefore
    maximises ``n`` times the dB magnitude of ``H``.

    Parameters
    ----------
    idm : object
        Provides the ``f``, ``g`` and ``h`` coefficients.
    n : number
        Multiplier on the dB magnitude. Integers are accepted.
    tau : float, optional
        Delay entering through ``cos(tau*w)`` / ``sin(tau*w)`` (default 0).

    Returns
    -------
    callable
        ``gammainv(w)``; ``w`` may be a scalar or an array.
    """
    f, g, h = idm.f, idm.g, idm.h
    k = g + h

    f2 = f**2
    g2 = g**2
    k2 = k**2

    # Float divisor: under Python 2, ``n/2`` floor-divides an integer n
    # (n=1 -> 0, n=5 -> 2) and silently corrupts the result.
    half_n = n / 2.0

    def gammainv(w):
        w2 = w**2
        w4 = w**4
        c = cos(tau * w)
        s = sin(tau * w)
        # Squared magnitudes; hence the /2 on n to turn 20*log10(|.|^2)
        # into n times the usual 20*log10(|.|).
        num = f2 + g2 * w2
        den = f2 + k2 * w2 + w4 - w2 * (2*f*c + 2*k*w*s)
        dbn = 20 * log10(num)
        dbd = 20 * log10(den)
        return (dbd - dbn) * half_n
    return gammainv

def teststability(idm, N):
    mr = N * idm.maxresponse()
    return max(mr, 0)


# Model parameters, passed to IDMLin as t, a, b and s0. These are presumably
# the usual IDM time headway, maximum acceleration, comfortable deceleration
# and minimum gap; units are not shown here (confirm in IDMLin).
a = 0.45
t = 1.5
b = 2.0
s = 2.0
# Speed passed to idm.go() at every grid point.
v = 33.3

idm = IDMLin(t=t, a=a, b=b, s0=s)

# Sweep grid, both axes log-spaced with 20 points:
#   N   from 10**1  = 10   to 10**2.5  ~ 316
#   R_e from 10**-2 = 0.01 to 10**-0.2 ~ 0.631
nrange = logspace(1, 2.5, 20)
rerange = logspace(-2, -0.2, 20)

# Result grids, indexed [N index, R_e index].
mgrid = zeros((len(nrange), len(rerange)))
mgrid2 = zeros((len(nrange), len(rerange)))

# Two complete sweeps over the (N, R_e) grid, in this order:
#   1. teststability     -> mgrid
#   2. maxStringResponse -> mgrid2
# idm.go(v, r_e) is called immediately before every single evaluation.
# Presumably this sets the model's operating point; confirm in IDMLin.
sweeps = ((mgrid, teststability), (mgrid2, maxStringResponse))
for grid, metric in sweeps:
    for row, n in enumerate(nrange):
        for col, r_e in enumerate(rerange):
            idm.go(v, r_e)
            grid[row, col] = metric(idm, n)

# Three side-by-side heat maps over the (N, R_e) grid. All panels share both
# axes and one colour scale (Reds, vmin=0, vmax=1):
#   axarr[0]: 100 * (mgrid2 - mgrid)
#   axarr[1]: mgrid2 (maxStringResponse)
#   axarr[2]: mgrid  (teststability)
# Rows index nrange; columns index rerange.
fig, axarr = plt.subplots(1, 3, sharex=True, sharey=True)

panels = (
    ((mgrid2 - mgrid) * 100, "(string - stability) x100"),
    (mgrid2, "maxStringResponse"),
    (mgrid, "teststability"),
)
for ax, (data, title) in zip(axarr, panels):
    im = ax.imshow(data, cmap=plt.cm.Reds,
                   interpolation='none',
                   aspect='auto',
                   vmin=0, vmax=1,
                   )
    ax.set_title(title, fontsize="small")
    ax.set_xlabel("R_e")
    # 20 narrow columns: rotate the R_e labels so they do not overlap.
    ax.tick_params(axis="x", labelrotation=90)

# Axes are shared, so ticks and labels set on one panel apply to all three.
# sharey also hides y tick labels on the inner panels, so "N" goes only on
# the first one.
axarr[0].set_yticks(range(len(nrange)))
axarr[0].set_xticks(range(len(rerange)))
# The grid values are arbitrary floats; print 3 significant digits rather
# than full float reprs (N) or 2-decimal rounding that collapses
# neighbouring values into duplicates (R_e).
axarr[0].set_yticklabels(["%.3g" % value for value in nrange])
axarr[0].set_xticklabels(["%.3g" % value for value in rerange])
axarr[0].set_ylabel("N")

# All images use the same cmap and limits, so one colour bar serves them all.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)

plt.show()