#!/usr/bin/env python3
"""
EcoOS Wayland Display Agent (eco-vdagent)

A Wayland-native replacement for spice-vdagent that uses swaymsg/wlr-output-management
instead of xrandr to configure displays.

Listens on the SPICE virtio-serial port for VD_AGENT_MONITORS_CONFIG messages
and applies the configuration to Sway outputs.
"""

import os
import sys
import struct
import subprocess
import json
import time
import signal
import logging
from pathlib import Path

# Logging setup: INFO and above, each record timestamped and tagged with the
# agent name so it is easy to pick out of a shared journal.
logging.basicConfig(format='%(asctime)s - eco-vdagent - %(levelname)s - %(message)s', level=logging.INFO)

# Module-wide logger used by every method of the agent.
log = logging.getLogger('eco-vdagent')

# SPICE VDAgent Protocol Constants
# Value written into (and expected in) the `protocol` field of every
# VDAgentMessage header; read_message() drops messages carrying another value.
VD_AGENT_PROTOCOL = 1

# Message types (the `type` field of the VDAgentMessage header).
# handle_message() dispatches on MONITORS_CONFIG, ANNOUNCE_CAPABILITIES,
# DISPLAY_CONFIG and CLIENT_DISCONNECTED; REPLY and ANNOUNCE_CAPABILITIES are
# also sent by this agent. The remaining types are not referenced by the code
# in this file.
VD_AGENT_MOUSE_STATE = 1
VD_AGENT_MONITORS_CONFIG = 2
VD_AGENT_REPLY = 3
VD_AGENT_CLIPBOARD = 4
VD_AGENT_DISPLAY_CONFIG = 5
VD_AGENT_ANNOUNCE_CAPABILITIES = 6
VD_AGENT_CLIPBOARD_GRAB = 7
VD_AGENT_CLIPBOARD_REQUEST = 8
VD_AGENT_CLIPBOARD_RELEASE = 9
VD_AGENT_FILE_XFER_START = 10
VD_AGENT_FILE_XFER_STATUS = 11
VD_AGENT_FILE_XFER_DATA = 12
VD_AGENT_CLIENT_DISCONNECTED = 13
VD_AGENT_MAX_CLIPBOARD = 14
VD_AGENT_AUDIO_VOLUME_SYNC = 15
VD_AGENT_GRAPHICS_DEVICE_INFO = 16

# Reply error codes (second uint32 of the VD_AGENT_REPLY payload built by
# send_reply()).
VD_AGENT_SUCCESS = 1
VD_AGENT_ERROR = 2

# Capability bits.
# These are bit *positions*, not masks: announce_capabilities() builds the
# capability word with `1 << VD_AGENT_CAP_...`. Only MONITORS_CONFIG, REPLY,
# SPARSE_MONITORS_CONFIG and MONITORS_CONFIG_POSITION are announced by this
# agent.
VD_AGENT_CAP_MOUSE_STATE = 0
VD_AGENT_CAP_MONITORS_CONFIG = 1
VD_AGENT_CAP_REPLY = 2
VD_AGENT_CAP_CLIPBOARD = 3
VD_AGENT_CAP_DISPLAY_CONFIG = 4
VD_AGENT_CAP_CLIPBOARD_BY_DEMAND = 5
VD_AGENT_CAP_CLIPBOARD_SELECTION = 6
VD_AGENT_CAP_SPARSE_MONITORS_CONFIG = 7
VD_AGENT_CAP_GUEST_LINEEND_LF = 8
VD_AGENT_CAP_GUEST_LINEEND_CRLF = 9
VD_AGENT_CAP_MAX_CLIPBOARD = 10
VD_AGENT_CAP_AUDIO_VOLUME_SYNC = 11
VD_AGENT_CAP_MONITORS_CONFIG_POSITION = 12
VD_AGENT_CAP_FILE_XFER_DISABLED = 13
VD_AGENT_CAP_FILE_XFER_DETAILED_ERRORS = 14
VD_AGENT_CAP_GRAPHICS_DEVICE_INFO = 15
VD_AGENT_CAP_CLIPBOARD_NO_RELEASE_ON_REGRAB = 16
VD_AGENT_CAP_CLIPBOARD_GRAB_SERIAL = 17
# Virtio serial port path
VIRTIO_PORT = '/dev/virtio-ports/com.redhat.spice.0'

# Wire formats. Every format is little-endian ('<'), which also turns off
# struct's native alignment padding, so struct.calcsize() gives the exact
# on-wire size. Sizes are derived from the formats so the two cannot drift.

# VDI Chunk header: port(4) + size(4) = 8 bytes
VDI_CHUNK_HEADER_FMT = '<II'  # port, size
VDI_CHUNK_HEADER_SIZE = struct.calcsize(VDI_CHUNK_HEADER_FMT)

# VDI Port constants (values of the chunk header's `port` field)
VDP_CLIENT_PORT = 1
VDP_SERVER_PORT = 2

# VDAgentMessage header: protocol(4) + type(4) + opaque(8) + size(4) = 20 bytes
VDAGENT_MSG_HEADER_FMT = '<IIQI'  # little-endian: uint32, uint32, uint64, uint32
VDAGENT_MSG_HEADER_SIZE = struct.calcsize(VDAGENT_MSG_HEADER_FMT)

# VDAgentMonitorsConfig header: num_of_monitors(4) + flags(4) = 8 bytes
MONITORS_CONFIG_HEADER_FMT = '<II'
MONITORS_CONFIG_HEADER_SIZE = struct.calcsize(MONITORS_CONFIG_HEADER_FMT)

# VDAgentMonConfig: height(4) + width(4) + depth(4) + x(4) + y(4) = 20 bytes
MON_CONFIG_FMT = '<IIIii'  # height, width, depth, x, y (x,y are signed)
MON_CONFIG_SIZE = struct.calcsize(MON_CONFIG_FMT)


class EcoVDAgent:
    def __init__(self):
        """Create an idle agent: nothing is opened or located until run()."""
        # Path of the Sway IPC socket; run() fills it in once one is found.
        self.sway_socket = None
        # File descriptor of the virtio-serial port; None while it is not open.
        self.port_fd = None
        # Main-loop flag; signal_handler() clears it to request shutdown.
        self.running = True

    def find_sway_socket(self):
        """Locate the Sway IPC socket.

        Lookup order:
          1. ``$SWAYSOCK``, when set and non-empty.
          2. The fixed path ``<runtime_dir>/sway-ipc.sock`` (set by eco-daemon).
          3. Standard Sway names ``sway-ipc.*.sock`` in the runtime dir, then
             in any ``/run/user/*`` directory.

        When several sockets match a pattern the most recently modified one
        is returned, so the result does not depend on directory listing order
        when older socket files are still present.

        Returns:
            The socket path as a string, or None when nothing was found.
        """
        import glob

        # An empty SWAYSOCK is treated as unset rather than returned as a path.
        env_socket = os.environ.get('SWAYSOCK')
        if env_socket:
            return env_socket

        runtime_dir = os.environ.get('XDG_RUNTIME_DIR', '/run/user/1000')

        # Check for fixed socket path first (set by eco-daemon)
        fixed_socket = f'{runtime_dir}/sway-ipc.sock'
        if os.path.exists(fixed_socket):
            return fixed_socket

        def newest_first_key(path):
            # A socket may vanish between glob() and stat(); rank it last.
            try:
                return (os.path.getmtime(path), path)
            except OSError:
                return (0.0, path)

        # Escape the runtime dir so glob metacharacters in it are taken literally.
        patterns = (
            f'{glob.escape(runtime_dir)}/sway-ipc.*.sock',
            '/run/user/*/sway-ipc.*.sock',
        )
        for pattern in patterns:
            sockets = glob.glob(pattern)
            if sockets:
                return max(sockets, key=newest_first_key)

        return None

    def run_swaymsg(self, *args):
        """Invoke ``swaymsg`` with *args* and report the outcome.

        When a Sway socket path is known it is passed with ``-s``. The call
        is limited to 5 seconds.

        Returns:
            A ``(ok, stdout)`` tuple. ``ok`` is True only for exit status 0;
            if the command could not be run at all the result is
            ``(False, "")``.
        """
        socket_opts = ['-s', self.sway_socket] if self.sway_socket else []
        argv = ['swaymsg', *socket_opts, *args]

        try:
            proc = subprocess.run(argv, capture_output=True, text=True, timeout=5)
        except Exception as e:
            # Covers a missing binary, a timeout and any other launch failure.
            log.error(f"Failed to run swaymsg: {e}")
            return False, ""

        succeeded = proc.returncode == 0
        if not succeeded:
            log.warning(f"swaymsg failed: {proc.stderr}")
        return succeeded, proc.stdout

    def get_outputs(self):
        """Query Sway for its current outputs.

        Returns:
            The list decoded from ``swaymsg -t get_outputs -r``, or an empty
            list when the command fails, its output is not valid JSON, or the
            decoded value is not a list.
        """
        success, output = self.run_swaymsg('-t', 'get_outputs', '-r')
        if not success:
            return []

        try:
            outputs = json.loads(output)
        except json.JSONDecodeError as e:
            # Report the problem instead of silently pretending there are no outputs.
            log.warning(f"Could not parse swaymsg get_outputs output: {e}")
            return []

        # Callers iterate the result as a list of dicts; anything else
        # (for example an error object) must not leak through.
        if not isinstance(outputs, list):
            log.warning(f"Unexpected get_outputs reply of type {type(outputs).__name__}")
            return []
        return outputs

    def configure_output(self, name, width, height, x, y, enable=True):
        """Enable and place, or disable, one Sway output.

        Args:
            name: Sway output name.
            width, height: requested mode; ignored when disabling.
            x, y: requested position; ignored when disabling.
            enable: False to switch the output off instead.

        Returns:
            True when the swaymsg command that was finally issued succeeded.
        """
        if not enable:
            disabled, _ = self.run_swaymsg(f'output {name} disable')
            return disabled

        # Preferred path: explicit mode together with the position.
        applied, _ = self.run_swaymsg(f'output {name} mode {width}x{height} position {x} {y} enable')
        if applied:
            return applied

        # The mode was rejected: fall back to positioning only, which lets
        # Sway keep whatever mode it prefers for this output.
        applied, _ = self.run_swaymsg(f'output {name} position {x} {y} enable')
        return applied

    def apply_monitors_config(self, monitors):
        """Apply a parsed monitor layout to the Sway outputs.

        Monitors are ordered by their x position and paired, in that order,
        with the outputs reported by get_outputs(). Outputs beyond the number
        of monitors are disabled.

        This agent announces VD_AGENT_CAP_SPARSE_MONITORS_CONFIG, so an entry
        with zero width or height is treated as "monitor disabled": the
        output paired with it is switched off rather than given a 0x0 mode.
        A configuration that enables no monitor at all is ignored, because
        applying it would switch every output off.

        Args:
            monitors: list of dicts with 'width', 'height', 'x' and 'y' keys,
                as produced by parse_monitors_config().
        """
        log.info(f"Applying configuration for {len(monitors)} monitors")

        def is_enabled(mon):
            return mon['width'] > 0 and mon['height'] > 0

        if not any(is_enabled(mon) for mon in monitors):
            log.warning("Monitors config enables no monitor; leaving outputs unchanged")
            return

        # Get current outputs
        outputs = self.get_outputs()
        output_names = [o.get('name') for o in outputs]
        log.info(f"Available outputs: {output_names}")

        # Sort monitors by x position to match with outputs. sorted() is
        # stable, so monitors sharing an x keep their original order.
        monitors_sorted = sorted(enumerate(monitors), key=lambda m: m[1]['x'])

        # Match monitors to outputs
        for i, (mon_idx, mon) in enumerate(monitors_sorted):
            if i >= len(output_names):
                log.warning(f"No output available for monitor {mon_idx}")
                continue

            name = output_names[i]
            if not is_enabled(mon):
                log.info(f"Disabling {name}: monitor {mon_idx} is disabled")
                self.configure_output(name, 0, 0, 0, 0, enable=False)
                continue

            log.info(f"Configuring {name}: {mon['width']}x{mon['height']} at ({mon['x']}, {mon['y']})")
            self.configure_output(
                name,
                mon['width'],
                mon['height'],
                mon['x'],
                mon['y'],
                enable=True
            )

        # Disable extra outputs
        for name in output_names[len(monitors):]:
            log.info(f"Disabling unused output: {name}")
            self.configure_output(name, 0, 0, 0, 0, enable=False)

    def parse_monitors_config(self, data):
        """Parse the payload of a VD_AGENT_MONITORS_CONFIG message.

        The payload is a VDAgentMonitorsConfig header (monitor count and
        flags) followed by one VDAgentMonConfig record per monitor. Bytes
        after the last record are ignored.

        Args:
            data: raw message payload (bytes).

        Returns:
            A list with one dict per monitor ('width', 'height', 'depth',
            'x', 'y'), or None when the payload is too short for its header
            or for the number of monitors it declares. A truncated payload is
            rejected as a whole so that a partial layout is never applied.
        """
        if len(data) < MONITORS_CONFIG_HEADER_SIZE:
            log.error("Monitors config data too short")
            return None

        num_monitors, flags = struct.unpack_from(MONITORS_CONFIG_HEADER_FMT, data, 0)
        log.info(f"Monitors config: {num_monitors} monitors, flags={flags}")

        # Validate the declared count before looping: it comes straight off
        # the wire and may be far larger than the data that follows.
        expected = MONITORS_CONFIG_HEADER_SIZE + num_monitors * MON_CONFIG_SIZE
        if len(data) < expected:
            log.error(f"Truncated monitors config: need {expected} bytes, got {len(data)}")
            return None

        monitors = []
        for i in range(num_monitors):
            offset = MONITORS_CONFIG_HEADER_SIZE + i * MON_CONFIG_SIZE
            # Wire order is height, width, depth, x, y.
            height, width, depth, x, y = struct.unpack_from(MON_CONFIG_FMT, data, offset)

            monitors.append({
                'width': width,
                'height': height,
                'depth': depth,
                'x': x,
                'y': y
            })
            log.info(f"  Monitor {i}: {width}x{height}+{x}+{y} depth={depth}")

        return monitors

    def send_reply(self, msg_type, error_code):
        """Send a VD_AGENT_REPLY for a previously received message.

        Args:
            msg_type: type of the message being answered.
            error_code: VD_AGENT_SUCCESS or VD_AGENT_ERROR.
        """
        # Reply payload: type(4) + error(4) = 8 bytes, little-endian.
        payload = struct.pack('<II', msg_type, error_code)

        delivered = self.send_message(VD_AGENT_REPLY, payload)
        if not delivered:
            log.error(f"Failed to send reply for type {msg_type}")
            return
        log.debug(f"Sent reply for type {msg_type}: {'success' if error_code == VD_AGENT_SUCCESS else 'error'}")

    def send_message(self, msg_type, data):
        """Send one VDAgent message, wrapped in a VDI chunk, to the port.

        The bytes written are: chunk header (port, size) + VDAgentMessage
        header (protocol, type, opaque=0, payload size) + payload.

        The port is opened non-blocking, so a write may accept only part of
        the buffer or fail with EAGAIN. The remainder is retried until
        everything is written; up to 10 EAGAIN failures are tolerated, with a
        0.1 s pause after each.

        Args:
            msg_type: VD_AGENT_* message type.
            data: message payload (bytes).

        Returns:
            True when the whole message was written, False otherwise
            (port not open, write error, or too many EAGAIN failures).
        """
        # Compare against None: 0 is a valid file descriptor.
        if self.port_fd is None:
            return False

        # Build VDAgentMessage header
        msg_header = struct.pack(
            VDAGENT_MSG_HEADER_FMT,
            VD_AGENT_PROTOCOL,
            msg_type,
            0,  # opaque
            len(data)
        )

        # Full message = header + data
        full_msg = msg_header + data

        # Build VDI chunk header (port=SERVER, size=message size)
        # NOTE(review): this file sends with VDP_SERVER_PORT; the reference
        # spice-vdagent is believed to address its messages to
        # VDP_CLIENT_PORT - confirm against the SPICE protocol before relying
        # on these messages reaching the client.
        chunk_header = struct.pack(
            VDI_CHUNK_HEADER_FMT,
            VDP_SERVER_PORT,
            len(full_msg)
        )

        # memoryview lets us drop the already-written prefix without copying.
        pending = memoryview(chunk_header + full_msg)
        retries = 10
        while pending:
            try:
                written = os.write(self.port_fd, pending)
            except BlockingIOError:
                # EAGAIN - resource temporarily unavailable
                retries -= 1
                if retries <= 0:
                    log.error(f"Failed to send message type {msg_type}: EAGAIN after retries")
                    return False
                time.sleep(0.1)
                continue
            except OSError as e:
                log.error(f"Failed to send message type {msg_type}: {e}")
                return False
            # A short write is not an error; keep going with what is left.
            pending = pending[written:]
        return True

    def announce_capabilities(self):
        """Send VD_AGENT_ANNOUNCE_CAPABILITIES to register with SPICE server"""
        # Build capability bits - we support monitors config
        caps = 0
        caps |= (1 << VD_AGENT_CAP_MONITORS_CONFIG)
        caps |= (1 << VD_AGENT_CAP_REPLY)
        caps |= (1 << VD_AGENT_CAP_SPARSE_MONITORS_CONFIG)
        caps |= (1 << VD_AGENT_CAP_MONITORS_CONFIG_POSITION)

        # VDAgentAnnounceCapabilities: request(4) + caps(4) = 8 bytes
        # request=1 means we want the server to send us its capabilities
        announce_data = struct.pack('<II', 1, caps)

        if self.send_message(VD_AGENT_ANNOUNCE_CAPABILITIES, announce_data):
            log.info("Announced capabilities to SPICE server")
        else:
            log.error("Failed to announce capabilities")

    def handle_message(self, msg_type, data):
        """Handle a VDAgent message"""
        if msg_type == VD_AGENT_MONITORS_CONFIG:
            log.info("Received VD_AGENT_MONITORS_CONFIG")
            monitors = self.parse_monitors_config(data)
            if monitors:
                self.apply_monitors_config(monitors)
                self.send_reply(VD_AGENT_MONITORS_CONFIG, VD_AGENT_SUCCESS)
            else:
                self.send_reply(VD_AGENT_MONITORS_CONFIG, VD_AGENT_ERROR)

        elif msg_type == VD_AGENT_ANNOUNCE_CAPABILITIES:
            log.info("Received VD_AGENT_ANNOUNCE_CAPABILITIES")
            # We could respond with our capabilities here
            # For now, just acknowledge

        elif msg_type == VD_AGENT_DISPLAY_CONFIG:
            log.info("Received VD_AGENT_DISPLAY_CONFIG")
            # Display config for disabling client display changes

        elif msg_type == VD_AGENT_CLIENT_DISCONNECTED:
            log.info("Client disconnected")

        else:
            log.debug(f"Unhandled message type: {msg_type}")

    def _read_exact(self, count, wait_for_first=True):
        """Read exactly *count* bytes from the non-blocking port.

        EAGAIN is retried every 10 ms until 2 s pass without any new data.

        Args:
            count: number of bytes wanted; 0 yields ``b''`` immediately.
            wait_for_first: when False, an EAGAIN before the first byte
                returns None at once (used to poll for a new message).

        Returns:
            The bytes read, or None when nothing was available, the port
            reported end-of-file, or the idle timeout expired.

        Raises:
            OSError: for read errors other than EAGAIN.
        """
        idle_timeout = 2.0
        buf = bytearray()
        deadline = time.monotonic() + idle_timeout
        while len(buf) < count:
            try:
                piece = os.read(self.port_fd, count - len(buf))
            except BlockingIOError:
                if not buf and not wait_for_first:
                    return None
                if time.monotonic() >= deadline:
                    log.warning(f"Timed out reading from port: got {len(buf)} of {count} bytes")
                    return None
                time.sleep(0.01)
                continue
            if not piece:
                # Zero-length read: nothing more will arrive for now.
                if buf:
                    log.warning(f"Short read: got {len(buf)} of {count} bytes")
                return None
            buf += piece
            deadline = time.monotonic() + idle_timeout
        return bytes(buf)

    def read_message(self):
        """Read a single VDAgent message from the port (with chunk header).

        A message starts with a VDI chunk header, followed by the
        VDAgentMessage header and the payload. When the payload is larger
        than what the first chunk carries, the following chunks (each with
        its own chunk header) are read and their payloads concatenated.
        Headers and payload are always read completely, so a short read or a
        momentary EAGAIN in the middle of a message does not desynchronise
        the stream.

        Returns:
            ``(msg_type, data)`` for a complete message, or None when no
            message is pending or the message was malformed or incomplete.

        Raises:
            OSError: for read errors other than EAGAIN.
        """
        # Poll for a new chunk; None here simply means "nothing to read".
        chunk_header_data = self._read_exact(VDI_CHUNK_HEADER_SIZE, wait_for_first=False)
        if chunk_header_data is None:
            return None

        port, chunk_left = struct.unpack(VDI_CHUNK_HEADER_FMT, chunk_header_data)
        log.debug(f"Chunk header: port={port}, size={chunk_left}")

        if chunk_left < VDAGENT_MSG_HEADER_SIZE:
            log.warning(f"Chunk size too small: {chunk_left}")
            # Discard the chunk body so the next read starts at a chunk header.
            self._read_exact(chunk_left)
            return None

        # Read VDAgent message header
        header_data = self._read_exact(VDAGENT_MSG_HEADER_SIZE)
        if header_data is None:
            return None
        chunk_left -= VDAGENT_MSG_HEADER_SIZE

        protocol, msg_type, _opaque, size = struct.unpack(VDAGENT_MSG_HEADER_FMT, header_data)

        if protocol != VD_AGENT_PROTOCOL:
            log.warning(f"Unknown protocol: {protocol}")
            self._read_exact(chunk_left)
            return None

        # Read message data, following continuation chunks when needed.
        parts = []
        needed = size
        while needed > 0:
            if chunk_left == 0:
                next_header = self._read_exact(VDI_CHUNK_HEADER_SIZE)
                if next_header is None:
                    return None
                port, chunk_left = struct.unpack(VDI_CHUNK_HEADER_FMT, next_header)
                log.debug(f"Chunk header: port={port}, size={chunk_left}")
                if chunk_left == 0:
                    log.warning("Empty continuation chunk")
                    return None
            take = min(chunk_left, needed)
            piece = self._read_exact(take)
            if piece is None:
                return None
            parts.append(piece)
            chunk_left -= take
            needed -= take

        # Bytes left in the chunk beyond the declared message size are not
        # part of this message; drop them to stay aligned.
        if chunk_left:
            self._read_exact(chunk_left)

        return msg_type, b''.join(parts)

    def signal_handler(self, signum, frame):
        """Signal callback: ask the main loop to stop.

        Installed by run() for SIGTERM and SIGINT. It only clears the
        `running` flag; the loops in run() test that flag and exit.

        Args:
            signum: number of the signal that was delivered.
            frame: interrupted stack frame (unused).
        """
        self.running = False
        log.info(f"Received signal {signum}, shutting down...")

    def run(self):
        """Run the agent until SIGTERM or SIGINT is received.

        Steps:
          1. Install the signal handlers.
          2. Look for the Sway IPC socket (retried later if not found).
          3. Wait for the virtio-serial port to appear, then open it
             non-blocking.
          4. Announce capabilities and poll for messages, handing each one to
             handle_message(). Errors inside the loop are logged and the loop
             continues after a 1 s pause.

        The port descriptor is always closed on the way out, including when
        the loop is left by an exception that is not caught here.
        """
        # Set up signal handlers
        signal.signal(signal.SIGTERM, self.signal_handler)
        signal.signal(signal.SIGINT, self.signal_handler)

        # Find Sway socket
        self.sway_socket = self.find_sway_socket()
        if self.sway_socket:
            log.info(f"Using Sway socket: {self.sway_socket}")
        else:
            log.warning("No Sway socket found, will retry...")

        # Wait for virtio port
        log.info(f"Waiting for virtio port: {VIRTIO_PORT}")
        while self.running and not Path(VIRTIO_PORT).exists():
            time.sleep(1)

        if not self.running:
            return

        log.info("Opening virtio port...")
        try:
            self.port_fd = os.open(VIRTIO_PORT, os.O_RDWR | os.O_NONBLOCK)
        except OSError as e:
            log.error(f"Failed to open virtio port: {e}")
            return

        try:
            log.info("eco-vdagent started, announcing capabilities...")

            # Announce our capabilities to the SPICE server
            self.announce_capabilities()

            log.info("Listening for SPICE agent messages...")

            # Main loop
            while self.running:
                try:
                    # (Re)discover the Sway socket when none is known yet or
                    # the known path has disappeared, e.g. after the
                    # compositor was restarted.
                    if not self.sway_socket or not os.path.exists(self.sway_socket):
                        found = self.find_sway_socket()
                        if found and found != self.sway_socket:
                            log.info(f"Using Sway socket: {found}")
                        self.sway_socket = found

                    result = self.read_message()
                    if result:
                        msg_type, data = result
                        self.handle_message(msg_type, data)
                    else:
                        time.sleep(0.1)
                except Exception as e:
                    log.error(f"Error in main loop: {e}")
                    time.sleep(1)
        finally:
            # port_fd is a valid descriptor here (possibly 0), so close it
            # unconditionally and mark the port as closed.
            os.close(self.port_fd)
            self.port_fd = None

        log.info("eco-vdagent stopped")


if __name__ == '__main__':
    # Script entry point: build the agent and hand control to its main loop.
    EcoVDAgent().run()
