import cv2
import cv2.aruco as aruco
import numpy as np
import subprocess
import multiprocessing as mp
import threading
from datalink_serial import datalink
import queue
import socket
import rospy
from std_msgs.msg import Int16
from std_msgs.msg import Float32MultiArray

# Load the camera intrinsics saved by the calibration step.
with np.load('/home/nv/Desktop/Aruco_dectction_NX/calibration_data.npz') as data:
    camera_matrix, dist_coeffs = data['camera_matrix'], data['dist_coeffs']

# Pinhole parameters read out of the intrinsic matrix (pixel units).
f_x, f_y = camera_matrix[0, 0], camera_matrix[1, 1]  # horizontal / vertical focal length
cx, cy = camera_matrix[0, 2], camera_matrix[1, 2]    # principal point (x, y)

# ArUco dictionary (6x6 markers, 250 ids) and detector parameters.
# NOTE(review): Dictionary_get / DetectorParameters_create are the older OpenCV
# aruco API; confirm the installed OpenCV version still provides them.
aruco_dict = aruco.Dictionary_get(aruco.DICT_6X6_250)
parameters = aruco.DetectorParameters_create()

# Default output frame size used by gstreamer_pipeline().
W_img, H_img = 960, 540

# Endpoint of the JPEG streaming server ('0.0.0.0' listens on all interfaces).
TCP_IP, TCP_PORT = '0.0.0.0', 5005

# Whether ArUco detection is active; switched by command_handler.
Isdection = False

def command_handler(msg):
    """ROS subscriber callback that switches ArUco detection on or off.

    Command codes carried in ``msg.data``:
        1001, 1002 -> enable detection
        1003       -> disable detection
    Any other code leaves the current state unchanged.
    """
    global Isdection
    code = msg.data
    if code in (1001, 1002):
        Isdection = True
    elif code == 1003:
        Isdection = False

def set_pose(x, y, z, yaw):
    """Publish a correction command through the module-level ``publisher``.

    Sends an 11-element ``Float32MultiArray`` in which slot 2 holds ``x``,
    slot 3 holds ``y``, and every other slot is 0.

    NOTE(review): ``z`` and ``yaw`` are accepted but never written into the
    message. Confirm whether the receiver expects them in some slot.
    """
    payload = [0] * 11
    payload[2] = x
    payload[3] = y
    message = Float32MultiArray()
    message.data = payload
    publisher.publish(message)

def gstreamer_pipeline(
    sensor_id=0,
    capture_width=1920,
    capture_height=1080,
    display_width=W_img,
    display_height=H_img,
    framerate=30,
    flip_method=0,
):
    """Build the GStreamer launch string for a Jetson CSI camera read by OpenCV.

    The pipeline captures from ``nvarguscamerasrc`` at the capture resolution and
    frame rate. ``nvvidconv`` flips and rescales the frames to the display
    resolution, ``videoconvert`` turns them into BGR, and they end in an
    ``appsink`` that ``cv2.VideoCapture`` can consume.

    Args:
        sensor_id: CSI sensor index given to nvarguscamerasrc.
        capture_width, capture_height: Capture resolution.
        display_width, display_height: Resolution of the frames handed to OpenCV.
        framerate: Capture frame rate (frames per second).
        flip_method: nvvidconv ``flip-method`` value.

    Returns:
        The pipeline description, with elements separated by ``" ! "``.
    """
    source = "nvarguscamerasrc sensor-id=%d" % sensor_id
    nvmm_caps = (
        "video/x-raw(memory:NVMM), width=(int)%d, height=(int)%d, framerate=(fraction)%d/1"
        % (capture_width, capture_height, framerate)
    )
    converter = "nvvidconv flip-method=%d" % flip_method
    scaled_caps = (
        "video/x-raw, width=(int)%d, height=(int)%d, format=(string)BGRx"
        % (display_width, display_height)
    )
    stages = [
        source,
        nvmm_caps,
        converter,
        scaled_caps,
        "videoconvert",
        "video/x-raw, format=(string)BGR",
        "appsink",
    ]
    return " ! ".join(stages)

class TranstreamThread(threading.Thread):
    """Background TCP server that streams length-prefixed frames to one client at a time.

    ``run()`` listens on ``(TCP_IP, TCP_PORT)`` and accepts a single client. Each
    frame passed to :meth:`send_frame` is sent to that client as a 4-byte
    big-endian length followed by the frame bytes. When the client disconnects,
    the thread goes back to waiting for the next one. Call :meth:`stop` to shut
    the server down.

    Args:
        TCP_IP: Address the listening socket binds to.
        TCP_PORT: Port the listening socket binds to.
        max_queue_size: Maximum number of frames buffered for sending. When the
            buffer is full the oldest frame is dropped, so a slow client gets
            recent frames instead of an ever-growing backlog. ``0`` (or a
            negative value) means unbounded.
    """

    def __init__(self, TCP_IP, TCP_PORT, max_queue_size=2):
        super().__init__()
        self.TCP_IP = TCP_IP
        self.TCP_PORT = TCP_PORT
        self.sock = None        # listening socket, created in run()
        self.conn = None        # socket of the currently connected client
        self.addr = None        # address of the currently connected client
        self.running = True     # cleared by stop() to end run()
        self.connected = False  # True while a client is attached and frames are accepted
        # Frames waiting to be sent. queue.Queue treats maxsize <= 0 as unbounded.
        self.queue = queue.Queue(maxsize=max_queue_size)

    def run(self):
        """Thread body: accept clients one after another and forward queued frames."""
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # allow quick rebinding of the port
        self.sock.bind((self.TCP_IP, self.TCP_PORT))
        self.sock.listen(1)
        print(f"服务器已启动，监听 {self.TCP_IP}:{self.TCP_PORT}")

        while self.running:
            print("等待客户端连接...")
            try:
                self.conn, self.addr = self.sock.accept()  # one client at a time
                print(f"客户端已连接：{self.addr}")
                # Frames queued for a previous client are stale; never send them to the new one.
                self._drain_queue()
                # Disable Nagle's algorithm so each frame goes out without batching delay.
                self.conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                self.connected = True

                # Serve the current client until it disconnects or stop() is called.
                while self.running and self.connected:
                    try:
                        frame = self.queue.get(timeout=1)
                    except queue.Empty:
                        continue
                    self._send_frame(frame)
            except socket.error as e:
                # After stop() the listening socket is shut down on purpose; don't report that.
                if self.running:
                    print(f"accept() 出错：{e}")
            finally:
                self.connected = False
                self._close_connection()

        # stop() normally closed it already; this covers stop() racing with startup.
        self.sock.close()

    def _send_frame(self, frame):
        """Send one frame as a 4-byte big-endian length prefix followed by the payload.

        On any socket error the client is marked disconnected and its socket closed.
        """
        conn = self.conn
        if conn is None:
            # The connection was closed concurrently (e.g. by stop()).
            self.connected = False
            return
        try:
            conn.sendall(len(frame).to_bytes(4, 'big'))
            conn.sendall(frame)
        except OSError as e:  # includes BrokenPipeError / ConnectionResetError
            print(f"客户端断开连接：{e}")
            self.connected = False
            self._close_connection()

    def send_frame(self, frame):
        """Queue ``frame`` (bytes) for sending without blocking.

        Returns:
            True if a client is connected and the frame was queued, False otherwise.
            When the queue is full, the oldest queued frame is discarded to make room.
        """
        if not self.connected:
            return False
        try:
            self.queue.put_nowait(frame)
        except queue.Full:
            # Drop the oldest frame so the client always receives the newest data.
            try:
                self.queue.get_nowait()
            except queue.Empty:
                pass
            try:
                self.queue.put_nowait(frame)
            except queue.Full:
                pass
        return True

    def _drain_queue(self):
        """Discard every frame currently waiting in the queue."""
        while True:
            try:
                self.queue.get_nowait()
            except queue.Empty:
                return

    def _close_connection(self):
        """Close the client socket if one is open; tolerates being called more than once."""
        conn, self.conn = self.conn, None
        if conn is not None:
            try:
                conn.close()
            except OSError:
                pass

    def stop(self):
        """Ask the thread to exit: close the client connection and the listening socket."""
        self.running = False
        self.connected = False
        self._close_connection()
        if self.sock:
            try:
                # On Linux, shutdown() wakes a thread blocked in accept(); close() alone may not.
                self.sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass  # e.g. not supported on a listening socket on some platforms
            self.sock.close()



def qr_code_detection():
    """Run the ArUco detection loop until ROS shuts down or 'q' is pressed.

    While ``Isdection`` is False, the loop only sleeps for 0.1 s per iteration.
    While it is True, each camera frame is searched for ArUco markers. For every
    marker found:

    * its pose is estimated with ``estimatePoseSingleMarkers``;
    * a proportional height/lateral correction is published via ``set_pose``;
    * a text overlay is drawn on the frame.

    Every processed frame is shown in a local window and sent as JPEG to the
    client of a ``TranstreamThread``.

    The camera, the streaming server and the OpenCV windows are released on exit,
    including when an exception propagates out of the loop.
    """
    aruco_size = 0.05  # physical side length of the ArUco marker, metres
    # Proportional gains for the published corrections.
    kp_y = 0.6
    kp_alt = 0.6

    stream_thread = TranstreamThread(TCP_IP, TCP_PORT)
    # Daemon: the thread blocks in accept() while no client is connected, and a
    # non-daemon thread there would keep the process alive after this loop ends.
    stream_thread.daemon = True
    stream_thread.start()

    video_capture = cv2.VideoCapture(gstreamer_pipeline(flip_method=2), cv2.CAP_GSTREAMER)
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 70]  # JPEG quality for the stream

    try:
        while not rospy.is_shutdown():
            if not Isdection:
                rospy.sleep(0.1)
                continue

            ret, frame = video_capture.read()
            if not ret or frame is None:
                print("No frame received")
                # Back off briefly so a failed camera does not spin the CPU.
                rospy.sleep(0.01)
                continue

            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            corners, ids, _rejected = aruco.detectMarkers(gray, aruco_dict, parameters=parameters)

            if ids is not None:
                # tvecs[i][0] is marker i's translation in the camera frame, in the units of aruco_size.
                _rvecs, tvecs, _ = aruco.estimatePoseSingleMarkers(
                    corners, aruco_size, camera_matrix, dist_coeffs
                )
                # Outline all detected markers once per frame (not once per marker).
                aruco.drawDetectedMarkers(frame, corners)

                for i in range(len(ids)):
                    tvec = tvecs[i][0]
                    dz_m = tvec[2]   # forward distance
                    dx_m = tvec[0]   # left/right offset
                    dy_m = -tvec[1]  # height offset (negated: the camera y axis points down)

                    print("dectct target!!!")

                    # Marker centre = midpoint of diagonal corners 0 and 2.
                    corner_points = corners[i][0]
                    center_x = (corner_points[0][0] + corner_points[2][0]) / 2

                    # Horizontal angle from the principal point to the marker centre (radians).
                    d_yaw = np.arctan((center_x - cx) / f_x)

                    # Only height and lateral corrections are commanded; z and yaw are passed as 0.
                    set_pose(kp_alt * dy_m, kp_y * dx_m, 0, 0)

                    label = f'dx: {dx_m:.2f} m, dy: {dy_m:.2f} m, dz: {dz_m:.2f} m, d_yaw:{d_yaw:.2f} rad'
                    cv2.putText(frame, label, (10, 30 + i * 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            else:
                print("No ArUco marker detected.")

            cv2.imshow('ArUco Code Detection', frame)
            # Encode the annotated frame as JPEG and hand it to the streaming thread (non-blocking).
            ok, img_encoded = cv2.imencode('.jpg', frame, encode_param)
            if ok:
                stream_thread.send_frame(img_encoded.tobytes())

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        video_capture.release()
        stream_thread.stop()
        cv2.destroyAllWindows()

if __name__ == "__main__":
    rospy.init_node("Aruco_detection_node")
    # Command codes 1001/1002/1003 on this topic switch detection on/off via command_handler.
    subscriber = rospy.Subscriber("/fcu_command/command", Int16, command_handler)
    # Must stay module-level: set_pose() publishes through this global name.
    publisher = rospy.Publisher("/mission_follow", Float32MultiArray, queue_size=10)
    try:
        qr_code_detection()
    except rospy.ROSInterruptException:
        # rospy.sleep() raises this when the node shuts down mid-sleep; exit quietly.
        pass
