"""StateMachine code component.

Defines a CodeComponent whose data ports and generated AVR C/C++ behavior code
are derived from a state-machine port definition file (``.smv``) and a behavior
file (``.aut``) stored under ``<SVGGEN_DIR>/library/stateMachines``.
"""
__author__ = 'Joseph'

from svggen.api.CodeComponent import CodeComponent
from svggen.api.ports.DataInputPort import DataInputPort
from svggen.api.ports.DataOutputPort import DataOutputPort
from collections import OrderedDict
from svggen import SVGGEN_DIR
import os

class StateMachine(CodeComponent):
  """
  Code component that runs a finite state machine described by two files in
  StateMachine.stateMachineDir, both named after the 'stateMachineName' parameter:

    <name>.smv  port definitions. Declarations look like 'name : type;' and are
                grouped after a line containing 'inputs' and after a line
                containing 'outputs'. Port names are expected to be of the form
                'subcomponent.interface' (see createConnections).
    <name>.aut  behavior. Every line containing '<' describes one state as
                '... <name:value, name:value, ...>'; every line containing
                'successors' lists successor state indices after a ':'.
                States (and successor lines) are assumed to appear in state
                index order, starting at 0.

  Setting 'stateMachineName' creates one DataInputPort per input and one
  DataOutputPort per output, each named exactly like its proposition in the .smv
  file; assemble() then adds the generated C/C++ (AVR PROGMEM) behavior code.
  """
  portFilePostfix = '.smv'
  behaviorFilePostfix = '.aut'
  stateMachineDir = os.path.join(SVGGEN_DIR, 'library', 'stateMachines')

  def define(self):
    CodeComponent.define(self)
    self.addParameter('stateMachineName', '')
    # Note: input and output ports are added when the state machine name is set (see setParameter)

  def setParameter(self, n, v):
    """
    Set parameter n to v.

    The first time 'stateMachineName' is set while the component has no interfaces,
    create one data port per input/output declared in the state machine's .smv file.
    Input ports additionally get their '<name>.autoPoll' parameter set to 'true'.
    """
    CodeComponent.setParameter(self, n, v)
    if n == 'stateMachineName' and len(self.interfaces) == 0:
      inputs = StateMachine.getInputs(v)
      outputs = StateMachine.getOutputs(v)
      for name in inputs:
        self.addInterface(name, DataInputPort(parent=self, name=name))
        self.setParameter(name + '.autoPoll', 'true')
      for name in outputs:
        self.addInterface(name, DataOutputPort(parent=self, name=name))

  def assemble(self):
    """Assemble the base component, then append the generated state machine behavior code."""
    CodeComponent.assemble(self)
    stateMachineName = self.getParameter('stateMachineName')
    if stateMachineName:
      self.addCode(StateMachine.getBehaviorCode(stateMachineName))
    self.setAutoPolls()

  @staticmethod
  def createConnections(component, stateMachineComponentName, stateMachineName):
    """
    Adds input / output connections to the state machine according to the port definition file.

    Each port name 'subcomponent.interface' in the .smv file is connected between that
    subcomponent interface of `component` and the StateMachine port of the same name
    (the names setParameter gives the StateMachine's ports). Port names without a '.'
    cannot be mapped to a subcomponent; they are reported and skipped.

    :param component: The component to which the state machine is being added
    :param stateMachineComponentName: The name of the StateMachine component that will need the connections
    :param stateMachineName: The name of the state machine
    :return: The given component with the new connections added
    """
    inputs = StateMachine.getInputs(stateMachineName)
    outputs = StateMachine.getOutputs(stateMachineName)
    # Add connections to state machine inputs
    for portName in inputs:
      endpoint = StateMachine._splitPortName(portName)
      if endpoint is not None:
        component.addConnection(endpoint, (stateMachineComponentName, portName))
    # Add connections to state machine outputs
    for portName in outputs:
      endpoint = StateMachine._splitPortName(portName)
      if endpoint is not None:
        component.addConnection(endpoint, (stateMachineComponentName, portName))
    return component

  @staticmethod
  def _splitPortName(portName):
    """
    Split 'subcomponent.interface' at its first '.'.
    :return: (subcomponent, interface), or None (after printing a warning) if there is no '.'
    """
    subcomponent, dot, interface = portName.partition('.')
    if not dot:
      print('WARNING: STATE MACHINE PORT <' + portName + '> IS NOT OF THE FORM subcomponent.interface; SKIPPING CONNECTION')
      return None
    return (subcomponent, interface)

  @staticmethod
  def getInputs(stateMachineName):
    """
    :param stateMachineName: the name of the state machine.
      The port definition file should be this name followed by StateMachine.portFilePostfix
    :return: An ordered dictionary of inputs, as {name:type} (empty if the file cannot be opened)
    """
    # The search for the first input declaration gives up once the 'outputs' section starts
    return StateMachine._readPortSection(stateMachineName, 'inputs', 'outputs')

  @staticmethod
  def getOutputs(stateMachineName):
    """
    :param stateMachineName: the name of the state machine.
      The port definition file should be this name followed by StateMachine.portFilePostfix
    :return: An ordered dictionary of outputs, as {name:type} (empty if the file cannot be opened)
    """
    return StateMachine._readPortSection(stateMachineName, 'outputs')

  @staticmethod
  def _readPortSection(stateMachineName, sectionName, stopWord=None):
    """
    Read one section of 'name : type;' declarations from the port definition file.

    Skips to the first line containing sectionName, then to the first line containing ':'
    (giving up on a line containing stopWord, if one is given), then reads consecutive
    lines that contain ':'. The type is the text between ':' and the first ';' (or the
    end of the line if there is no ';'), stripped of whitespace.

    :return: OrderedDict {name: type} in file order; empty if the file cannot be opened
    """
    ports = OrderedDict()
    fin = StateMachine._openStateMachineFile(stateMachineName, StateMachine.portFilePostfix)
    if fin is None:
      return ports
    with fin:
      # Get to the requested section of the file
      nextLine = fin.readline()
      while nextLine and sectionName not in nextLine:
        nextLine = fin.readline()
      # Get to the first variable definition if there is at least one
      while nextLine and ':' not in nextLine and not (stopWord and stopWord in nextLine):
        nextLine = fin.readline()
      # Read all of the consecutive definitions (if there are any)
      while nextLine and ':' in nextLine:
        name, _, rest = nextLine.partition(':')
        ports[name.strip()] = rest.split(';', 1)[0].strip()
        nextLine = fin.readline()
    return ports

  @staticmethod
  def getStateInputs(stateMachineName):
    """
    :param stateMachineName: the name of the state machine. Inputs are read from the port
      definition file (name + StateMachine.portFilePostfix), values from the behavior file
      (name + StateMachine.behaviorFilePostfix)
    :return: A list of inputs for each state (each element is a list of integer input values
      for that state index, ordered like getInputs)
    """
    inputNames = list(StateMachine.getInputs(stateMachineName))
    return StateMachine._readStateValues(stateMachineName, inputNames)

  @staticmethod
  def getStateOutputs(stateMachineName):
    """
    :param stateMachineName: the name of the state machine. Outputs are read from the port
      definition file (name + StateMachine.portFilePostfix), values from the behavior file
      (name + StateMachine.behaviorFilePostfix)
    :return: A list of outputs for each state (each element is a list of integer output values
      for that state index, ordered like getOutputs)
    """
    outputNames = list(StateMachine.getOutputs(stateMachineName))
    return StateMachine._readStateValues(stateMachineName, outputNames)

  @staticmethod
  def getStateSuccessors(stateMachineName):
    """
    :param stateMachineName: the name of the state machine.
      The behavior file should be this name followed by StateMachine.behaviorFilePostfix
    :return: A list of successors for each state (each element is a list of successor state
      indices, read from the comma-separated list after ':' on a line containing 'successors';
      a 'successors' line without ':' yields an empty list). Empty if the file cannot be opened.
    """
    stateSuccessors = []
    fin = StateMachine._openStateMachineFile(stateMachineName, StateMachine.behaviorFilePostfix)
    if fin is None:
      return stateSuccessors
    with fin:
      for line in fin:
        if 'successors' not in line:
          continue
        colon = line.find(':')
        if colon < 0:
          stateSuccessors.append([])
        else:
          stateSuccessors.append([int(value.strip()) for value in line[colon + 1:].split(',')])
    return stateSuccessors

  @staticmethod
  def _readStateValues(stateMachineName, names):
    """
    Read, for every state line of the behavior file (a line containing '<'), the integer
    value of each name in `names`.
    :return: list with one element per state line (file order), each a list of ints
      ordered like `names`; empty if the file cannot be opened
    """
    stateValues = []
    fin = StateMachine._openStateMachineFile(stateMachineName, StateMachine.behaviorFilePostfix)
    if fin is None:
      return stateValues
    with fin:
      for line in fin:
        if '<' in line:
          stateValues.append([StateMachine._readStateValue(line, name) for name in names])
    return stateValues

  @staticmethod
  def _readStateValue(line, name):
    """
    Extract the integer value of `name` from a state line such as '... <a:0, b:1>'.

    The value is the text between the ':' following the name and the next ',' (or the
    closing '>' for the last entry). The search starts after '<' so that names cannot
    match the text preceding the value list.
    NOTE: names are located by substring search, so a name that is contained in an
    earlier name on the line (e.g. 'x.a' inside 'xx.a') would match that earlier entry.

    :raises ValueError: if the name is missing or its value is not an integer
    """
    start = line.find(name, line.find('<') + 1)
    if start < 0:
      raise ValueError('STATE MACHINE VALUE <' + name + '> NOT FOUND IN LINE: ' + line.strip())
    end = line.find(',', start)
    if end < 0:
      end = line.find('>', start)
    return int(line[line.find(':', start) + 1 : end].strip())

  @staticmethod
  def _openStateMachineFile(stateMachineName, postfix):
    """
    Open <stateMachineDir>/<stateMachineName><postfix> for reading.
    :return: the open file (caller closes it), or None after printing a message if it cannot be opened
    """
    fileName = stateMachineName + postfix
    try:
      return open(os.path.join(StateMachine.stateMachineDir, fileName), 'r')
    except (IOError, OSError):
      print('COULD NOT OPEN STATE MACHINE FILE <' + fileName + '>')
      return None

  @staticmethod
  def getBehaviorCode(stateMachineName):
    """
    Generate the C/C++ code (with @@ code-component directives) implementing the state machine.

    The code stores per-state input values, output values and successor lists in PROGMEM
    tables, records incoming data in curInputs (processData), moves to the first successor
    whose stored inputs equal curInputs on each robotLoop (changeState), and serves the
    current state's outputs through getData / setStateOutputs.
    NOTE: the generated code assumes the StateMachine's ports are named exactly like the
    propositions in the state machine's .smv file.
    """
    inputs = list(StateMachine.getInputs(stateMachineName))
    outputs = list(StateMachine.getOutputs(stateMachineName))
    numInputs = len(inputs)
    numOutputs = len(outputs)

    # Successor rows are padded with -1 (see the _arrayToCppStr call below), so the table is
    # signed and getSuccessor() casts to int: with 'unsigned int' the '>= 0' padding check in
    # changeState() would always be true and the loop would read past the real successors.
    code = """@@declare
    #include "avr/pgmspace.h"
    int curState = 0;
    const byte stateInputs[/* stateNum */][/* input */ @numInputs] PROGMEM = @stateInputs;
    const byte stateOutputs[/* stateNum */][/* output */ @numOutputs] PROGMEM = @stateOutputs;
    const int successors[/* curState */][/* nextState */ @numSuccessors] PROGMEM = @stateSuccessors;
    #define getStateInput(i, j) pgm_read_byte(&(stateInputs[i][j]))
    #define getStateOutput(i, j) pgm_read_byte(&(stateOutputs[i][j]))
    #define getSuccessor(i, j) ((int) pgm_read_word_near(&(successors[i][j])))

    int curInputs[@numInputs];
    //bool updatedInputs[@numInputs];

    @@insert<void robotSetup()>
    for(int i = 0; i < @numInputs; i++)
    {
      //updatedInputs[i] = false;
      curInputs[i] = 0;
    }
    curState = @initState;

    @@insert<void robotLoop()><@append>
    changeState();
    """

    code += '\n@@insert<void processData(const char* data, int sourceID, int destID)><@prepend>'
    # NOTE: This assumes that the input ports of the stateMachine are named the same as the input propositions in the stateMachineName file
    for i, inputName in enumerate(inputs):
      code += '\n'
      code += 'if(destID == @dataInputID<' + inputName + '>)'
      code += '\n{'
      code += '\n int inputNum = ' + str(i) + ';'
      code += """
        curInputs[inputNum] = (int) atof(data);
        /*
        updatedInputs[inputNum] = true;
        bool allUpdated = true;
        for(int i = 0; i < @numInputs; i++)
          allUpdated = allUpdated && updatedInputs[i];
        if(allUpdated)
          changeState();
        */
      }
      """
    code += """
    \n@@method<void changeState()>
    void changeState()
    {
      robotPrintDebug("Changing state to ");
      bool foundState = false;
      for(int s = 0; s < @numSuccessors && getSuccessor(curState, s) >= 0 && !foundState; s++)
      {
        foundState = true;
        for(int in = 0; in < @numInputs && foundState; in++)
          foundState = foundState && (getStateInput(getSuccessor(curState, s), in) == curInputs[in]);
        if(foundState)
        {
          curState = getSuccessor(curState, s);
          robotPrintlnDebug(curState);
          setStateOutputs();
          //for(int i = 0; i < @numInputs; i++)
          //  updatedInputs[i] = false;
        }
      }
      if(!foundState)
      {
        robotPrintlnDebug();
      }
    }
    """

    code += """
    \n@@method<void setStateOutputs()>
    void setStateOutputs()
    {
    """
    code += '\n  int outputIDs[] = {'
    # NOTE: This assumes that the output ports of the stateMachine are named the same as the output propositions in the stateMachineName file
    if outputs:
      code += ', '.join('@dataOutputID<' + outputName + '>' for outputName in outputs)
    else:
      # Keep the initializer non-empty (valid C++); the loop below runs 0 times and -1 means "no output"
      code += '-1'
    code += '};'
    code += '\n for(int i = 0; i < ' + str(numOutputs) + '; i++)'
    code += """
      {
        int outputIndex = 0;
        if(outputIDs[i] >= 0)
        {
          for(; outputIndex < NUM_DATA_OUTPUTS && dataOutputIDs[outputIndex] != outputIDs[i]; outputIndex++);
          processData(getData(outputIDs[i]), outputIDs[i], dataMapping[outputIndex], DATA_OUTDEGREE);
        }
      }
    """
    code += '\n}'

    code += '\n@@insert<char* getData(int sourceID, int destID)><@prepend>'
    # NOTE: This assumes that the output ports of the stateMachine are named the same as the output propositions in the stateMachineName file
    for i, outputName in enumerate(outputs):
      code += '\n'
      code += 'if(sourceID == @dataOutputID<' + outputName + '>)'
      code += '{'
      code += ' int outputNum = ' + str(i) + ';'
      code += """
        itoa(getStateOutput(curState, outputNum), outputData, 10);
        validGetData = true;
        return outputData;
      }
      """

    # Replace state machine specific code tags used above
    stateInputs = StateMachine.getStateInputs(stateMachineName)
    stateOutputs = StateMachine.getStateOutputs(stateMachineName)
    stateSuccessors = StateMachine.getStateSuccessors(stateMachineName)
    maxNumSuccessors = max([0] + [len(successors) for successors in stateSuccessors])
    # Rows shorter than the longest one are padded with -1 so each table is rectangular
    code = code.replace('@stateInputs', StateMachine._arrayToCppStr(stateInputs, fillerVal=-1))
    code = code.replace('@stateOutputs', StateMachine._arrayToCppStr(stateOutputs, fillerVal=-1))
    code = code.replace('@stateSuccessors', StateMachine._arrayToCppStr(stateSuccessors, fillerVal=-1))
    code = code.replace('@numInputs', str(numInputs))
    code = code.replace('@numOutputs', str(numOutputs))
    code = code.replace('@numSuccessors', str(maxNumSuccessors))
    code = code.replace('@initState', str(0))

    return code

  @staticmethod
  def _arrayToCppStr(array, fillerVal="\"\"", length=-1):
    """
    Render a (possibly nested) list/tuple as a C/C++ brace initializer.

    Example: _arrayToCppStr([[1, 2], [3]], fillerVal=-1) == '{{1, 2}, {3, -1}}'

    :param array: a scalar (rendered with str(), strings returned as-is) or a list/tuple
    :param fillerVal: value written for positions beyond the end of a shorter list
    :param length: number of elements to emit at this level; negative means len(array).
      Nested lists are always emitted with the length of the longest sibling list.
    :return: the initializer text
    """
    if not isinstance(array, (list, tuple)):
      return array if isinstance(array, str) else str(array)
    maxLength = max([0] + [len(item) for item in array if isinstance(item, (list, tuple))])
    if length < 0:
      length = len(array)
    elements = []
    for i in range(length):
      if i < len(array):
        elements.append(StateMachine._arrayToCppStr(array[i], fillerVal, maxLength))
      else:
        elements.append(str(fillerVal))
    return '{' + ', '.join(elements) + '}'
