from matplotlib import pyplot as plt
from scipy import signal
from scipy import sqrt, linspace, exp, pi, polymul, sinh, cosh, arange, ones, size, array, diff, insert

from idmlin import IDMLin


def linfty(re, n=25, tmax=100, dt=0.1, timeplot=False):
    """Return a closure that cascades the linearised IDM response ``n`` times.

    The returned ``go(t, a, b, s, v, fmt=None)`` builds
    ``IDMLin(t=t, a=a, b=b, s0=s)``, calls ``idm.go(v, re)``, takes the
    system from ``idm.sys()`` and drives it with a unit step using
    ``scipy.signal.lsim``. The output of stage ``i`` becomes the input of
    stage ``i + 1`` (presumably successive following vehicles -- NOTE(review):
    confirm), and the maximum of every stage's output is recorded.

    Parameters
    ----------
    re : forwarded unchanged to ``IDMLin.go`` as its second argument.
    n : number of cascaded stages.
    tmax : end time of the simulation grid, which starts at 0.
    dt : grid spacing; must be positive. The spacing is exactly ``dt`` when
        ``tmax`` is a multiple of ``dt`` and the nearest fit otherwise.
    timeplot : if True, ``go`` returns ``(T, Us)`` where column ``i`` of
        ``Us`` is the time response of stage ``i``; otherwise it returns
        ``(arange(n) + 1, A)`` where ``A[i]`` is the maximum of stage ``i``.

    When ``fmt`` is passed to ``go`` and is not None it is appended to the
    returned tuple, so the result can be splatted into ``plt.plot``.

    Raises
    ------
    ValueError
        If ``dt`` is not positive.
    """
    if not dt > 0:
        raise ValueError("dt must be positive, got %r" % (dt,))
    # linspace requires an integer sample count (a float raises TypeError on
    # current NumPy). The +1 makes consecutive samples dt apart rather than
    # tmax / (tmax/dt - 1).
    num_samples = int(round(tmax / float(dt))) + 1

    def go(t, a, b, s, v, fmt=None):
        idm = IDMLin(t=t, a=a, b=b, s0=s)
        idm.go(v, re)
        lti = idm.sys()

        T = linspace(0, tmax, num=num_samples)
        U = ones(size(T))              # unit-step input to the first stage
        A = [0] * n
        Us = ones((len(T), n))

        for i in range(n):
            # Each stage is fed the previous stage's output.
            _, U, _ = signal.lsim(lti, U, T)
            A[i] = max(U)
            Us[:, i] = U

        if timeplot:
            out = (T, Us)
        else:
            out = (arange(n) + 1, A)
        return out if fmt is None else out + (fmt,)
    return go


def idmgo(t, a, b, s, v, fmt=None):
    """Evaluate ``IDMLin.maxresponse()`` over ``re`` = 0.00, 0.01, ..., 0.99.

    A single ``IDMLin(t=t, a=a, b=b, s0=s)`` is reused. For every grid value
    ``re`` it is re-linearised with ``go(v, re)`` and its ``maxresponse()`` is
    stored. What ``re`` represents is defined by ``IDMLin.go``
    (NOTE(review): not visible here).

    Returns ``(grid, peaks)``, or ``(grid, peaks, fmt)`` when ``fmt`` is not
    None, so the result can be splatted straight into ``plt.plot``.
    """
    model = IDMLin(t=t, a=a, b=b, s0=s)
    grid = arange(100) / 100.
    peaks = ones(size(grid))

    for k in range(len(grid)):
        model.go(v, grid[k])
        peaks[k] = model.maxresponse()

    result = (grid, peaks)
    if fmt is not None:
        result += (fmt,)
    return result

def sweep(fn, rngs, axis=None, separate=True, logx=False):
    """Plot ``fn`` while varying one parameter at a time around its default.

    Parameters
    ----------
    fn : callable
        Called as ``fn(*params, fmt=...)``. It must return a tuple that can be
        splatted into ``plt.plot`` / ``plt.semilogx``.
    rngs : list of dict
        One entry per positional parameter of ``fn``, in order. Every entry
        needs ``"default"``. Entries to be swept also carry ``"values"`` (not
        None) and ``"name"``, and optionally ``"color"`` (default ``""``) and
        ``"dots"`` (one line style per value, default ``"-"`` for each).
        Values beyond ``len(dots)`` are not plotted (``zip`` truncates).
    axis : optional limits passed to ``plt.axis`` before each ``plt.show()``.
    separate : if True, show one figure per swept parameter, each including
        the baseline (all defaults, black) curve. Otherwise draw everything
        on one figure shown at the end.
    logx : use a logarithmic x axis.
    """
    plot = plt.semilogx if logx else plt.plot
    defaults = [r["default"] for r in rngs]

    if not separate:
        plot(*fn(*defaults, fmt='k'))

    for i, r in enumerate(rngs):
        values = r.get("values")
        if values is None:
            continue

        if separate:
            plot(*fn(*defaults, fmt='k'))

        # One pre-formatted string prints the same under Python 2's print
        # statement and Python 3's print() function.
        print("%s %s %s" % (r["name"], defaults[i], values))

        color = r.get("color", "")
        dots = r.get("dots")
        if dots is None:
            dots = ["-"] * len(values)

        args = list(defaults)
        for dot, arg in zip(dots, values):
            args[i] = arg
            plot(*fn(*args, fmt=color + dot))

        if separate:
            if axis is not None:
                plt.axis(axis)
            plt.show()

    if not separate:
        # Applied here so the limits take effect even when no range is swept.
        if axis is not None:
            plt.axis(axis)
        plt.show()

if __name__ == "__main__":

    # Sweep the IDM parameters a and b one at a time (each over ten values in
    # [0.6, 1.5]) with the remaining parameters held at their defaults, and
    # plot idmgo's maximum response against re for every setting.
    # ( enhance )

    sweep_a = linspace(0.6, 1.5, 10)
    sweep_b = linspace(0.6, 1.5, 10)
    solid = ["-" for _ in sweep_a]  # one solid line style per swept value

    # Entry order must match idmgo's positional parameters (t, a, b, s, v).
    rngs = [
        {"name": "T",   "default": 1.2},
        {"name": "a",   "color": "", "dots": solid, "default": 0.5, "values": sweep_a},
        {"name": "b",   "color": "", "dots": solid, "default": 0.5, "values": sweep_b},
        {"name": "s_0", "default": 1.5},
        {"name": "v_0", "default": 16.1},
    ]

    sweep(idmgo, rngs)  # , axis=[0,0.1,-.05,.45])
    # go = linfty(0.5, n=50, tmax=500)
    # sweep(go, rngs)

    '''
    ### What happens when you vary IDM parameters?
    dots = (':', '--', '-', '.-')

    trng = {"name": "T",   "color": "r", "dots": dots, "default": 1.5, "values": (0.5, 1, 2, 2.5) }
    arng = {"name": "a",   "color": "g", "dots": dots, "default": 1.0, "values": (0.3, 0.5, 0.8, 1.5) }
    brng = {"name": "b",   "color": "b", "dots": dots, "default": 3.0, "values": (1, 2, 4, 5) }
    srng = {"name": "s_0", "color": "m", "dots": dots, "default": 2.0, "values": (0.5, 1, 2.5, 3) }
    vrng = {"name": "v_0", "color": "c", "dots": dots, "default": 30., "values": (10, 20, 40, 50) }

    rngs = [trng, arng, brng, srng, vrng]

    sweep(idmgo, rngs, axis=[0,1,-.2,.2])
    '''

    '''
    ### What happens when you vary IDM parameters?
    dots = (':', '--', '-', '.-')

    trng = {"name": "T",   "color": "r", "dots": dots, "default": 1.5, "values": (0.8, 0.9, 1, 1.1) }
    arng = {"name": "a",   "color": "g", "dots": dots, "default": 1.0, "values": (0.3, 0.4, 0.5, 0.6) }
    brng = {"name": "b",   "color": "b", "dots": dots, "default": 3.0, "values": (1, 5, 10, 20) }
    srng = {"name": "s_0", "color": "m", "dots": dots, "default": 2.0, "values": (3, 5, 10, 20) }
    vrng = {"name": "v_0", "color": "c", "dots": dots, "default": 30., "values": (10, 20, 40, 50) }

    rngs = [trng, arng, brng, srng, vrng]

    go = linfty(0.5, n=50, tmax=500)
    sweep(go, rngs)
    '''

    '''
    go = linfty(0.05, n=80, tmax=500, timeplot=True)
    T, Us = go(1.5, 0.3, 3, 2, 30)

    f, axarr = plt.subplots(2, sharex=True)
    axarr[0].plot(T, Us)
    axarr[1].plot(T, diff(insert(Us, 0, 1, axis=1)))
    plt.show()
    '''

    '''
    go = linfty(0.5, n=500, tmax=2000, dt=1, timeplot=False)
    plt.plot(*go(1.5, 0.57, 3, 2, 30))
    plt.show()
    '''