#!/usr/bin/env python

import numpy as np
import ezdxf 
import random
from math import sqrt
import cv2
import os
from pathfinding.core.diagonal_movement import DiagonalMovement
from pathfinding.core.grid import Grid
from pathfinding.finder.a_star import AStarFinder
import copy
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import cm


class auto_rounter(object):
    """Route traces between pins of a DXF cut pattern on a 1 mm grid.

    The constructor does the whole job:
    1. Loads ``path``/``dxf_file``.
    2. Creates the 'Cut', 'Circuit', 'Label', 'Fold' and 'Pin_temp' layers and
       sorts LINE entities into 'Cut'/'Fold'.
    3. Removes the design-specific wheel drawing and detects pin centres.
    4. Rasterises the cut outline into an obstacle matrix.
    5. Routes the demo connections from :meth:`test_connections` with A*.
    6. Saves the drawing as ``silhouette_ele.dxf``.
    """

    def __init__(self, path, dxf_file):
        # os.path.join also copes with a ``path`` lacking a trailing separator,
        # where plain string concatenation produced a wrong file name.
        self.dwg = ezdxf.readfile(os.path.join(path, dxf_file))
        self.msp = self.dwg.modelspace()
        self.dxf_name = 'silhouette_ele.dxf'  # output file name
        self.fail_flag = False  # set True by get_cost() when a path is not found

        self.create_layer('Cut', 5)
        self.create_layer('Circuit', 1)
        self.create_layer('Label', 3)
        self.create_layer('Fold', 4)
        self.create_layer('Pin_temp', 6)

        self.layer_rearrange()  # move cut and fold lines to their layers
        self.remove_wheels()  # design-specific; not needed for other designs
        self.find_pin()

        # Isolation outlines around the pins (full side length 2.6 mm) are
        # currently disabled:
        # self.draw_on_pin(2.6)

        self.matrix_shape = (270, 210)  # (rows = y, cols = x), 1 cell per mm
        self.read_dxf_as_matrix()

        # Alternative: random connections between detected pins.
        # self.connections_list=np.array([[0,0],[0,0]]).reshape(1,2,2)
        # connection_amount=4
        # for i in range(connection_amount):
        #     s=random.randint(0,len(self.center_arr)-1)
        #     e=random.randint(0,len(self.center_arr)-1)
        #     connections=np.array([[self.center_arr[s],self.center_arr[e]]]).reshape(1,2,2)
        #     self.connections_list=np.append(self.connections_list,connections,axis=0)
        # self.connections_list=np.rint(self.connections_list)[1:]

        # For testing and demo purposes:
        self.connections_list = self.test_connections()

        # Animated variant (requires self.fig to be set up first):
        # self.matrix_temp=copy.deepcopy(self.matrix)
        # self.img=plt.imshow(self.matrix_temp,interpolation='nearest',cmap=cm.Spectral)
        # ani=animation.FuncAnimation(self.fig,self.find_multi_path,interval=10)
        # plt.show()

        self.find_multi_path(1)
        self.dwg.saveas(self.dxf_name)

    def create_layer(self, layer_name, color):
        """Add layer ``layer_name`` with color index ``color`` unless it exists."""
        if layer_name not in self.dwg.layers:
            self.dwg.layers.new(name=layer_name, dxfattribs={'color': color})

    def layer_rearrange(self):
        """Move every LINE with color 5 to 'Cut' and every other LINE to 'Fold'."""
        for e in self.msp.query('LINE'):
            e.dxf.layer = 'Cut' if e.dxf.color == 5 else 'Fold'

    def find_pin(self, pincutsize=1, tolerance=0.05):
        """Detect square pin cut-outs and store their centres in ``self.center_arr``.

        Step 1: every LINE on the 'Cut' layer whose length is within
        ``tolerance`` of ``pincutsize`` (mm) is moved to the 'Pin_temp' layer.
        The horizontal ones among them (start y exactly equal to end y) are
        kept as candidate pin edges.

        Step 2: whenever one edge's start y is exactly ``pincutsize`` above
        another edge's start y, a centre is recorded at the upper edge's start
        point shifted by ``pincutsize / 2`` left and down. The x coordinates of
        the two edges are not compared.

        NOTE(review): the shift to the left assumes the upper edge's start
        point is its right-hand end. TODO confirm against the DXF export. The
        float comparisons are exact, so pins must sit on exact coordinates.

        Result: ``self.center_arr`` is a float array of shape (N, 2) with
        unique rows, sorted lexicographically by ``np.unique``. Its shape is
        (0, 2) when no pin is found.
        """
        pin_edges = []
        for e in self.msp.query('LINE[layer=="Cut"]'):
            start, end = e.dxf.start, e.dxf.end
            length = sqrt((start[0] - end[0]) ** 2 + (start[1] - end[1]) ** 2)
            if pincutsize - tolerance < length < pincutsize + tolerance:
                e.dxf.layer = 'Pin_temp'
                if start[1] == end[1]:  # horizontal edge
                    pin_edges.append(np.array([start, end])[:, :2])

        centers = []
        for i, upper in enumerate(pin_edges):
            for j, lower in enumerate(pin_edges):
                if i != j and upper[0, 1] - lower[0, 1] == pincutsize:
                    half = pincutsize / 2.0
                    centers.append((upper[0, 0] - half, upper[0, 1] - half))

        # Collecting into a list avoids the old [0, 0] sentinel, which after
        # np.unique's sort was not always the first row (negative coordinates)
        # and swallowed a genuine centre at the origin.
        if centers:
            self.center_arr = np.unique(np.array(centers, dtype=float), axis=0)
        else:
            self.center_arr = np.empty((0, 2))

    def remove_wheels(self):
        """Delete the wheel drawing of this specific design.

        Every entity returned by the query whose start x is >= 179 is deleted
        from modelspace. There is no need to call this for other designs.

        NOTE(review): ARC entities have no 'start' DXF attribute in ezdxf. If
        the 'Arc' part of the query ever matches, ``e.dxf.start`` would
        presumably fail. Confirm whether arcs are meant to be removed here.
        """
        for e in self.msp.query('Arc LINE[layer=="Cut"]'):
            if e.dxf.start[0] >= 179:
                self.msp.delete_entity(e)

    def draw_on_pin(self, fulllength):
        """Insert an isolation outline around every detected pin centre.

        The outline is built as block 'ISO_BLK', with every line using
        linetype 'DASHDOT':
        - a square of side ``fulllength`` centred on the origin;
        - two parallel stub lines ``trace_w`` apart, running 5 units to the
          left of the square.

        One reference to 'ISO_BLK' per row of ``self.center_arr`` is added to
        modelspace on layer 'Circuit'.

        NOTE(review): the block name is fixed, so a second call tries to
        create 'ISO_BLK' again. Presumably ezdxf rejects that; confirm.
        """
        iso = self.dwg.blocks.new(name='ISO_BLK')
        isolength = fulllength / 2  # half side length of the square
        trace_w = 0.8
        iso.add_line((-isolength, isolength), (isolength, isolength), dxfattribs={'linetype': 'DASHDOT'})
        iso.add_line((-isolength, -isolength), (isolength, -isolength), dxfattribs={'linetype': 'DASHDOT'})
        iso.add_line((-isolength, -isolength), (-isolength, isolength), dxfattribs={'linetype': 'DASHDOT'})
        iso.add_line((isolength, -isolength), (isolength, isolength), dxfattribs={'linetype': 'DASHDOT'})

        iso.add_line((-isolength, -trace_w / 2), (-(isolength + 5), -trace_w / 2), dxfattribs={'linetype': 'DASHDOT'})
        iso.add_line((-isolength, trace_w / 2), (-(isolength + 5), trace_w / 2), dxfattribs={'linetype': 'DASHDOT'})
        for center_point in self.center_arr:
            self.msp.add_blockref('ISO_BLK', center_point, dxfattribs={'layer': 'Circuit'})

    def read_dxf_as_matrix(self):
        """Rasterise the cut outline into ``self.matrix`` (unit: mm).

        ``self.matrix`` becomes an array of ones of shape
        ``self.matrix_shape``, indexed ``[y, x]``. Every cell covered by a
        LINE is set to 0, provided the line:
        - is not on the 'Fold', 'Label' or 'Pin_temp' layers;
        - is horizontal or vertical after rounding. Any other line is ignored.

        Coordinates are rounded to the nearest integer cell, so sub-millimetre
        accuracy is lost. Parts of a line outside the matrix are clipped.
        Without clipping, a negative index would wrap to the opposite edge and
        an index past the far edge would raise IndexError.
        """
        self.matrix = np.ones(self.matrix_shape)
        rows, cols = self.matrix.shape
        query = 'LINE[layer!="Fold" & layer!="Label" & layer!="Pin_temp"]'
        for line in self.msp.query(query):
            x0, y0 = [int(v) for v in np.rint(line.dxf.start)[:2]]
            x1, y1 = [int(v) for v in np.rint(line.dxf.end)[:2]]
            if y0 == y1:  # horizontal (a zero-length line also lands here)
                lo, hi = max(min(x0, x1), 0), min(max(x0, x1), cols - 1)
                if 0 <= y0 < rows and lo <= hi:
                    self.matrix[y0, lo:hi + 1] = 0
            elif x0 == x1:  # vertical
                lo, hi = max(min(y0, y1), 0), min(max(y0, y1), rows - 1)
                if 0 <= x0 < cols and lo <= hi:
                    self.matrix[lo:hi + 1, x0] = 0

    def find_a_path(self, matrixws, start_point, end_point):
        """Find an A* path between two cells of the grid ``matrixws``.

        :param matrixws: 2-D grid indexed ``[y, x]``, handed to
            ``pathfinding`` as ``Grid(matrix=...)``. Cells set to 0 by this
            class are presumably treated as obstacles by that Grid; TODO
            confirm for the installed version.
        :param start_point: ``(x, y)`` start cell. Coordinates are truncated
            with ``int()``, so pin centres at ``.5`` map to the lower cell.
        :param end_point: ``(x, y)`` end cell, truncated the same way.
        :returns: list of ``(x, y)`` tuples. It is empty when the finder
            returns no path; ``get_cost`` treats that as a failure.
        """
        grid = Grid(matrix=matrixws)
        start = grid.node(int(start_point[0]), int(start_point[1]))
        end = grid.node(int(end_point[0]), int(end_point[1]))
        # 4 is a numeric DiagonalMovement value
        # (see pathfinding.core.diagonal_movement); kept as in the original.
        finder = AStarFinder(diagonal_movement=4)
        path, _runs = finder.find_path(start, end, grid)
        # Some pathfinding releases may return node objects with .x/.y instead
        # of (x, y) tuples; normalise so draw_a_path/add_line get tuples.
        return [(p.x, p.y) if hasattr(p, 'x') else p for p in path]

    def draw_a_path(self, path, matrix, dxf=False):
        """Mark ``path`` as blocked in ``matrix``; optionally draw it into the DXF.

        :param path: sequence of ``(x, y)`` cells; each sets
            ``matrix[y, x] = 0``.
        :param matrix: grid modified in place.
        :param dxf: when True, consecutive points are joined by DASHDOT lines
            on the 'Circuit' layer in modelspace.
        """
        for point in path:
            matrix[point[1], point[0]] = 0
        if dxf:
            for a, b in zip(path, path[1:]):
                self.msp.add_line(a, b, dxfattribs={
                    'layer': 'Circuit',
                    'linetype': 'DASHDOT'})

    def get_cost(self, path, fail_cost=1000):
        """Return the routing cost of ``path``.

        The cost is the number of points in the path. An empty path (not
        found) costs ``fail_cost`` and sets ``self.fail_flag``.
        """
        if len(path) == 0:
            print('one path not found')
            self.fail_flag = True
            return fail_cost
        return len(path)

    def find_multi_path(self, i, episodes=20):
        """Route all connections in random order, keeping the cheapest episode.

        ``self.connections_list`` has shape (N, 2, 2): N connections, each a
        (start, end) pair of (x, y) points.

        Each episode routes the connections one by one on a fresh copy of
        ``self.matrix``. Every routed path is drawn into that copy, so later
        connections must avoid earlier traces. The episode with the lowest
        total cost is kept (ties go to the later episode) in
        ``self.final_solving``, ``self.final_path`` and ``self.final_fail``.
        If every path of that episode was found, its paths are drawn into
        ``self.matrix`` and onto the DXF 'Circuit' layer.

        :param i: unused; kept so the method can serve as a
            ``FuncAnimation`` frame callback.
        :param episodes: number of random orderings to try.
        """
        print("==========Auto rounting start===========")
        # Defaults so these attributes exist even if nothing is routed.
        self.final_solving = np.empty((0, 2, 2))
        self.final_path = []
        self.final_fail = False
        if len(self.connections_list) == 0:
            print('No connections to route')
            return

        for episode in range(episodes):
            print('episode: %d (/ %d )===========' % (episode + 1, episodes))
            self.matrix_temp = copy.deepcopy(self.matrix)
            self.Q = 0
            self.fail_flag = False

            remaining = self.connections_list
            solving = []
            paths = []
            while len(remaining) != 0:
                ix = random.randint(0, len(remaining) - 1)
                conn = remaining[ix]
                remaining = np.delete(remaining, ix, axis=0)

                path = self.find_a_path(self.matrix_temp, conn[0], conn[1])
                self.draw_a_path(path, self.matrix_temp)
                solving.append(conn)
                paths.append(path)
                self.Q += self.get_cost(path)

            if episode == 0 or self.Q <= self.Q_buff:
                self.Q_buff = self.Q
                self.final_solving = np.array(solving).reshape(-1, 2, 2)
                self.final_path = paths
                self.final_fail = self.fail_flag

            print('Current cost: %s Best cost %s' % (self.Q, self.Q_buff))

        if not self.final_fail:
            for path in self.final_path:
                self.draw_a_path(path, self.matrix, True)
        else:
            print('One or more path cannot be solved')

    def test_connections(self, count=10, x=90, y=97, max_center_y=60):
        """Build demo connections from artificial pins to detected pin centres.

        ``count`` artificial pins are placed at ``(x + k, y)`` for
        ``k = 0..count-1``. They are paired in order with the first ``count``
        rows of ``self.center_arr`` whose y is below ``max_center_y``.

        :returns: float array of shape (M, 2, 2), where
            M = min(count, number of eligible centres). Shape is (0, 2, 2)
            when nothing can be paired.
        """
        artificial = [(x + k, y) for k in range(count)]
        eligible = [c for c in self.center_arr if c[1] < max_center_y][:count]
        pairs = [[pin, tuple(center)] for pin, center in zip(artificial, eligible)]
        if not pairs:
            return np.empty((0, 2, 2))
        return np.array(pairs, dtype=float)

def main(path='/home/jingyan/Documents/summer_intern_lemur/roco_electrical/',
         dxf_file='graph-silhouette.dxf'):
    """Run the auto router on ``path``/``dxf_file`` and return the router.

    The router's constructor does all the work, including saving the
    routed drawing.
    """
    router = auto_rounter(path, dxf_file)
    # animation.FuncAnimation(edit.fig,plt.imshow(edit.matrix_temp))
    # plt.show()
    return router


# Guard the entry point so importing this module does not start routing a
# hard-coded file.
if __name__ == '__main__':
    main()