import numpy as np
import json

from inputs import Input
from MMAE.mmae import MMAE
from System.system_simulator import SystemSimulator

class ParameterEstimationPipeline:
    """Couple a simulated "true" system with an MMAE estimator.

    Each call to :meth:`update` applies one row of a precomputed input signal
    to the simulator, takes the measurement the simulator returns, and passes
    the input together with that measurement to the MMAE, returning whatever
    the MMAE's ``update`` returns.
    """

    def __init__(self, λ, λs, k, b, dt, H, Q, R, x0, true_system_noisy, estimator_noisy, max_time, max_steps, amplitude):
        """Build the simulator, the input signal and the estimator.

        Args:
            λ: Parameter value handed to the simulated true system.
            λs: Candidate parameter values handed to the MMAE.
            k, b, dt, H, Q, R, x0: Forwarded unchanged to both the simulator
                and the MMAE (presumably model constants, time step, output
                matrix, noise covariances and initial state — TODO confirm).
            true_system_noisy: Noise flag forwarded to ``SystemSimulator``.
            estimator_noisy: Noise flag forwarded to ``MMAE``.
            max_time: Forwarded to ``Input`` along with the simulator's model.
            max_steps, amplitude: Forwarded to ``Input.step_function``.
        """
        # The simulator is created first because the input generator reads
        # its ``model`` attribute.
        self.TrueSystem = SystemSimulator(λ, k, b, dt, H, Q, R, x0, true_system_noisy)

        signal_builder = Input(self.TrueSystem.model, max_time)
        self.input_signal = signal_builder.step_function(max_steps, amplitude)

        self.MMAE = MMAE(λs, k, b, dt, H, Q, R, x0, estimator_noisy)

    def update(self, t: int) -> float:
        """Advance simulator and estimator by one step at time index ``t``.

        Row ``t`` of ``input_signal`` is reshaped into a column vector and
        applied to the simulator; the second of the two values the simulator
        returns is used as the measurement for the MMAE update.

        Returns:
            The value returned by ``MMAE.update``.
        """
        control = self.input_signal[t, :].reshape(-1, 1)
        _ignored, measurement = self.TrueSystem.update(control)
        return self.MMAE.update(control, measurement)
    
def load_config(config_path):
    """Load and return the JSON configuration stored at ``config_path``.

    The file is decoded as UTF-8 explicitly so that loading does not depend
    on the platform's locale encoding (JSON text is UTF-8 per RFC 8259).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)

def make_model_variants(start, end, step):
    """Return the grid ``start, start + step, ...`` up to and including ``end``.

    ``np.arange(start, end + step, step)`` with a float step can emit an extra
    value past ``end`` (e.g. start=1.0, end=1.2, step=0.1 yields 1.3) because
    its length is ``ceil((stop - start) / step)`` computed in floating point.
    Here the number of points is counted explicitly with a small tolerance, so
    ``end`` is included when it lies on the grid and is never overshot by more
    than a rounding error.

    Args:
        start: First value of the grid.
        end: Inclusive upper bound of the grid.
        step: Spacing between consecutive values; must be positive.

    Returns:
        A Python list of values; empty when ``end`` lies below ``start``.

    Raises:
        ValueError: If ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError(f"model_variants_step must be positive, got {step}")
    # The 1e-9 slack absorbs rounding such as (1.2 - 1.0) / 0.1 == 1.9999999999999996.
    count = int(np.floor((end - start) / step + 1e-9)) + 1
    count = max(count, 0)
    return (start + step * np.arange(count)).tolist()


if __name__ == "__main__":
    config_path = "config.json"
    config = load_config(config_path)

    # Bank of candidate parameter values handed to the estimator.
    λs = make_model_variants(
        config["model_variants_start"],
        config["model_variants_end"],
        config["model_variants_step"],
    )

    λ = config["true_mass"]
    k = config["k"]
    b = config["b"]
    dt = config["dt"]
    H = np.array(config["H"])
    Q = np.eye(H.shape[1]) * config["Q"]
    R = np.eye(H.shape[0]) * config["R"]
    x0 = np.array(config["initial_state"])
    max_time = config["max_time"]
    # round() rather than int(): int(0.7 / 0.1) == 6 because the quotient is
    # 6.999999999999999, which would silently drop the last step.
    max_steps = int(round(max_time / dt))
    amplitude = config["amplitude"]

    # Bound to its own name so the ParameterEstimationPipeline class is not
    # shadowed by an instance at module level.
    pipeline = ParameterEstimationPipeline(
        λ, λs, k, b, dt, H, Q, R, x0,
        true_system_noisy=True,
        estimator_noisy=False,
        max_time=max_time,
        max_steps=max_steps,
        amplitude=amplitude,
    )

    # NOTE(review): iteration starts at t=1, so input_signal[0] is never
    # applied — confirm this is intended (e.g. t=0 being the initial state).
    for step_counter in range(1, max_steps):
        λ_hat = pipeline.update(step_counter)
        print(f"Step {step_counter}: λ_hat = {λ_hat}")