import { Buffer } from 'buffer';
import { 
  TlsRecordType,
  TlsHandshakeType,
  TlsExtensionType,
  TlsUtils
} from '../utils/tls-utils.js';

/**
 * Callback that receives diagnostic messages from the parser.
 *
 * Every ClientHelloParser method that takes a logger treats it as optional and
 * falls back to a no-op function when it is omitted.
 */
export type LoggerFunction = (message: string) => void;

/**
 * Result of a session resumption check (ClientHelloParser.hasSessionResumption).
 */
export interface SessionResumptionResult {
  /**
   * True if any resumption indicator was found: a session_ticket, pre_shared_key or
   * early_data extension, or a non-empty session ID.
   */
  isResumption: boolean;
  /** True if the ClientHello carried at least one SNI host name. */
  hasSNI: boolean;
}

/**
 * A single TLS extension as found in a ClientHello.
 */
export interface ExtensionInfo {
  /** Extension type code, read from the 2-byte extension header. */
  type: number;
  /** Length of the extension data in bytes, as declared in the extension header. */
  length: number;
  /** Extension payload; a view into the parsed buffer rather than a copy. */
  data: Buffer;
}

/**
 * Result of a ClientHello parse operation.
 *
 * Fields are filled in as parsing progresses, so a failed parse (isValid === false)
 * may still carry the fields that were read before the failure.
 */
export interface ClientHelloParseResult {
  /** True only when parsing completed without error. */
  isValid: boolean;
  /** [major, minor] version bytes from the TLS *record* header (not the handshake's client_version). */
  version?: [number, number];
  /** The 32-byte client random. */
  random?: Buffer;
  /** Session ID bytes; only set when the session ID is non-empty. */
  sessionId?: Buffer;
  /** True when the session ID length byte is non-zero. */
  hasSessionId: boolean;
  /** Raw cipher suite list bytes. */
  cipherSuites?: Buffer;
  /** Raw compression method bytes. */
  compressionMethods?: Buffer;
  /** Every extension in the order it appears in the message. */
  extensions: ExtensionInfo[];
  /** Host names from the server_name (SNI) extension; only set when at least one was found. */
  serverNameList?: string[];
  /** True when a session_ticket extension is present, even if its data is empty. */
  hasSessionTicket: boolean;
  /** True when a pre_shared_key extension is present. */
  hasPsk: boolean;
  /** True when an early_data extension is present. */
  hasEarlyData: boolean;
  /** Human-readable reason for a failed parse. */
  error?: string;
}

/**
 * Per-connection state for a TLS record that is still being reassembled.
 */
export interface FragmentTrackingInfo {
  /** All bytes received so far for the connection. */
  buffer: Buffer;
  /** Date.now() value of the last update; entries older than the timeout are discarded. */
  timestamp: number;
  /** Connection identifier under which this entry is stored. */
  connectionId: string;
}

/**
 * Parses TLS ClientHello messages and reassembles ClientHello records that arrive
 * split across several TCP reads.
 *
 * All state lives in static fields: fragment buffers are shared by every caller in
 * the process and are keyed only by the caller-supplied connection id.
 */
export class ClientHelloParser {
  // Partially received TLS records, keyed by connection id
  private static fragmentedBuffers: Map<string, FragmentTrackingInfo> = new Map();
  // Milliseconds without new data after which a partial record is discarded
  private static fragmentTimeout: number = 1000;

  /**
   * Removes fragment buffers that have not been updated within `fragmentTimeout` ms.
   */
  private static cleanupExpiredFragments(): void {
    const now = Date.now();
    for (const [connectionId, info] of this.fragmentedBuffers.entries()) {
      if (now - info.timestamp > this.fragmentTimeout) {
        // Deleting the current key while iterating a Map is safe in JavaScript
        this.fragmentedBuffers.delete(connectionId);
      }
    }
  }

  /**
   * Returns the total size (5-byte header + payload) of the TLS record that starts at
   * offset 0, or undefined when the 5-byte record header is not complete yet.
   */
  private static getTotalRecordLength(buffer: Buffer): number | undefined {
    if (buffer.length < 5) {
      return undefined;
    }
    // Record payload length is a big-endian uint16 at bytes 3-4
    return buffer.readUInt16BE(3) + 5;
  }

  /**
   * Buffers TLS data for a connection until its first TLS record is complete, so a
   * ClientHello split across several TCP reads can be reassembled.
   *
   * - First call for a connection: if the data already holds a complete record it is
   *   returned unchanged (its handshake type is not checked) and nothing is tracked;
   *   otherwise the data is stored and undefined is returned.
   * - Later calls: the data is appended. Once the first record is complete and is a
   *   ClientHello, that record alone is returned. If the first record is some other
   *   message but the next record starts a ClientHello, the whole buffer is returned.
   *
   * Tracking for the connection ends whenever data is returned. Entries that receive
   * no data for `fragmentTimeout` ms are discarded (cleanup runs on every call).
   *
   * @param buffer The newly received data
   * @param connectionId Unique identifier for the connection
   * @param logger Optional logging function
   * @returns The reassembled data, or undefined if more fragments are needed
   */
  public static handleFragmentedClientHello(
    buffer: Buffer,
    connectionId: string,
    logger?: LoggerFunction
  ): Buffer | undefined {
    const log = logger || (() => {});

    this.cleanupExpiredFragments();

    const existingInfo = this.fragmentedBuffers.get(connectionId);

    if (existingInfo === undefined) {
      // New connection: only start tracking if the first record is still incomplete.
      // A complete record must not leave an entry behind, or later data on this
      // connection id would be appended to (and resolved against) the stale bytes.
      const recordLength = this.getTotalRecordLength(buffer);
      if (recordLength === undefined) {
        log(
          `Initial buffer too small (${buffer.length} bytes), needs at least 5 bytes for TLS header`
        );
      } else {
        log(`Initial buffer size: ${buffer.length}, expected record length: ${recordLength}`);
        if (buffer.length >= recordLength) {
          log(`Initial buffer contains complete ClientHello, length: ${buffer.length}`);
          return buffer;
        }
      }

      this.fragmentedBuffers.set(connectionId, {
        buffer,
        timestamp: Date.now(),
        connectionId
      });
      log(`Started buffering connection ${connectionId}, initial size: ${buffer.length}`);
      return undefined; // Need more fragments
    }

    // Existing connection: append and refresh the expiry timestamp
    const newBuffer = Buffer.concat([existingInfo.buffer, buffer]);
    this.fragmentedBuffers.set(connectionId, {
      ...existingInfo,
      buffer: newBuffer,
      timestamp: Date.now()
    });
    log(`Appended to buffer for ${connectionId}, new size: ${newBuffer.length}`);

    const recordLength = this.getTotalRecordLength(newBuffer);
    if (recordLength === undefined) {
      return undefined; // Record header still incomplete
    }
    log(`Reassembled buffer size: ${newBuffer.length}, expected record length: ${recordLength}`);

    if (newBuffer.length < recordLength) {
      return undefined; // Record payload still incomplete
    }
    log(`Assembled complete ClientHello, length: ${newBuffer.length}, needed: ${recordLength}`);

    // The complete record may be followed by more data; keep only the record itself
    const completeRecord = newBuffer.subarray(0, recordLength);

    // Byte 5 is the handshake type of the record's first handshake message
    if (completeRecord.length > 5 && completeRecord[5] === TlsHandshakeType.CLIENT_HELLO) {
      log(`Verified record is a ClientHello handshake message`);
      this.fragmentedBuffers.delete(connectionId);
      return completeRecord;
    }

    log(`Record is complete but not a ClientHello handshake, continuing to buffer`);

    // The first record might precede the ClientHello: look at the header of the next
    // record (type byte at recordLength, handshake type 5 bytes later)
    if (newBuffer.length > recordLength + 5) {
      const nextRecordType = newBuffer[recordLength];
      log(`Next record type: ${nextRecordType} (looking for ${TlsRecordType.HANDSHAKE})`);

      if (nextRecordType === TlsRecordType.HANDSHAKE) {
        const handshakeType = newBuffer[recordLength + 5];
        log(
          `Next handshake type: ${handshakeType} (looking for ${TlsHandshakeType.CLIENT_HELLO})`
        );

        if (handshakeType === TlsHandshakeType.CLIENT_HELLO) {
          log(`Found ClientHello in subsequent record, returning full buffer`);
          this.fragmentedBuffers.delete(connectionId);
          return newBuffer;
        }
      }
    }

    // NOTE(review): once the bytes inspected above are present and are not a
    // ClientHello, further appends cannot change the outcome; the entry then keeps
    // growing until the connection is quiet for `fragmentTimeout` ms. Consider capping.
    return undefined; // Still need more fragments
  }

  /**
   * Parses the first TLS record in `buffer` as a ClientHello and extracts its
   * components. Only bytes inside that record are parsed; anything after it is ignored.
   *
   * NOTE(review): a ClientHello fragmented across several TLS records is not supported.
   *
   * @param buffer The buffer containing the ClientHello message
   * @param logger Optional logging function
   * @returns The parse result; never undefined. On failure `isValid` is false and
   *          `error` describes the problem.
   */
  public static parseClientHello(
    buffer: Buffer,
    logger?: LoggerFunction
  ): ClientHelloParseResult {
    const log = logger || (() => {});
    const result: ClientHelloParseResult = {
      isValid: false,
      hasSessionId: false,
      extensions: [],
      hasSessionTicket: false,
      hasPsk: false,
      hasEarlyData: false
    };

    try {
      // TLS record header: type (1) + version (2) + length (2)
      if (buffer.length < 5) {
        result.error = 'Buffer too small for TLS record header';
        return result;
      }

      if (buffer[0] !== TlsRecordType.HANDSHAKE) {
        result.error = `Not a TLS handshake record: ${buffer[0]}`;
        return result;
      }

      const majorVersion = buffer.readUInt8(1);
      const minorVersion = buffer.readUInt8(2);
      result.version = [majorVersion, minorVersion];
      log(`TLS record version: ${majorVersion}.${minorVersion}`);

      const recordLength = buffer.readUInt16BE(3);
      log(`Record length: ${recordLength}`);

      if (buffer.length < recordLength + 5) {
        result.error = 'Buffer smaller than expected record length';
        return result;
      }

      // Every read below is bounded by the end of this record, not the buffer: bytes
      // after the record (e.g. a following TLS record) are not part of the ClientHello.
      const end = 5 + recordLength;
      let pos = 5;

      // Handshake header: type (1) + length (3)
      if (pos + 4 > end) {
        result.error = 'Record too small for handshake header';
        return result;
      }

      if (buffer[pos] !== TlsHandshakeType.CLIENT_HELLO) {
        result.error = `Not a ClientHello message: ${buffer[pos]}`;
        return result;
      }

      const handshakeLength = buffer.readUIntBE(pos + 1, 3);
      log(`Handshake length: ${handshakeLength}`);
      pos += 4;

      // client_version (2) followed by client random (32)
      if (pos + 2 + 32 > end) {
        result.error = 'Buffer too small for client random';
        return result;
      }

      const clientMajorVersion = buffer.readUInt8(pos);
      const clientMinorVersion = buffer.readUInt8(pos + 1);
      log(`Client version: ${clientMajorVersion}.${clientMinorVersion}`);
      pos += 2;

      result.random = buffer.subarray(pos, pos + 32);
      log(`Client random: ${result.random.toString('hex')}`);
      pos += 32;

      // Session ID: length (1) + data
      if (pos + 1 > end) {
        result.error = 'Buffer too small for session ID length';
        return result;
      }

      const sessionIdLength = buffer.readUInt8(pos);
      log(`Session ID length: ${sessionIdLength}`);
      pos += 1;

      result.hasSessionId = sessionIdLength > 0;

      if (sessionIdLength > 0) {
        if (pos + sessionIdLength > end) {
          result.error = 'Buffer too small for session ID';
          return result;
        }

        result.sessionId = buffer.subarray(pos, pos + sessionIdLength);
        log(`Session ID: ${result.sessionId.toString('hex')}`);
        pos += sessionIdLength;
      }

      // Cipher suites: length (2, big-endian) + data
      if (pos + 2 > end) {
        result.error = 'Buffer too small for cipher suites length';
        return result;
      }

      const cipherSuitesLength = buffer.readUInt16BE(pos);
      log(`Cipher suites length: ${cipherSuitesLength}`);
      pos += 2;

      if (pos + cipherSuitesLength > end) {
        result.error = 'Buffer too small for cipher suites';
        return result;
      }

      result.cipherSuites = buffer.subarray(pos, pos + cipherSuitesLength);
      pos += cipherSuitesLength;

      // Compression methods: length (1) + data
      if (pos + 1 > end) {
        result.error = 'Buffer too small for compression methods length';
        return result;
      }

      const compressionMethodsLength = buffer.readUInt8(pos);
      log(`Compression methods length: ${compressionMethodsLength}`);
      pos += 1;

      if (pos + compressionMethodsLength > end) {
        result.error = 'Buffer too small for compression methods';
        return result;
      }

      result.compressionMethods = buffer.subarray(pos, pos + compressionMethodsLength);
      pos += compressionMethodsLength;

      // The extensions block is optional; a record ending here is still valid
      if (pos + 2 > end) {
        result.isValid = true;
        return result;
      }

      const extensionsLength = buffer.readUInt16BE(pos);
      log(`Extensions length: ${extensionsLength}`);
      pos += 2;

      const extensionsEnd = pos + extensionsLength;
      if (extensionsEnd > end) {
        result.error = 'Extensions length exceeds buffer size';
        return result;
      }

      const serverNames: string[] = [];

      // Each extension: type (2) + length (2) + data. Trailing bytes too short for an
      // extension header are ignored.
      while (pos + 4 <= extensionsEnd) {
        const extensionType = buffer.readUInt16BE(pos);
        log(`Extension type: 0x${extensionType.toString(16).padStart(4, '0')}`);
        const extensionLength = buffer.readUInt16BE(pos + 2);
        log(`Extension length: ${extensionLength}`);
        pos += 4;

        if (pos + extensionLength > extensionsEnd) {
          result.error = `Extension ${extensionType} data exceeds bounds`;
          return result;
        }

        const extensionData = buffer.subarray(pos, pos + extensionLength);
        result.extensions.push({
          type: extensionType,
          length: extensionLength,
          data: extensionData
        });

        if (extensionType === TlsExtensionType.SERVER_NAME) {
          this.parseServerNameExtension(extensionData, serverNames, logger);
        } else if (extensionType === TlsExtensionType.SESSION_TICKET) {
          // Set on presence alone, even when the ticket data is empty
          result.hasSessionTicket = true;
        } else if (extensionType === TlsExtensionType.PRE_SHARED_KEY) {
          result.hasPsk = true;
        } else if (extensionType === TlsExtensionType.EARLY_DATA) {
          result.hasEarlyData = true;
        }

        pos += extensionLength;
      }

      if (serverNames.length > 0) {
        result.serverNameList = serverNames;
      }

      result.isValid = true;
      return result;
    } catch (error) {
      const errorMessage = error instanceof Error ? error.message : String(error);
      log(`Error parsing ClientHello: ${errorMessage}`);
      result.error = errorMessage;
      return result;
    }
  }

  /**
   * Parses server_name (SNI) extension data and appends every host_name entry
   * (name type 0) to `serverNames`; entries of other name types are skipped.
   *
   * @param data Extension data buffer
   * @param serverNames Array to populate with found server names (mutated in place)
   * @param logger Optional logging function
   * @returns true if `serverNames` is non-empty afterwards; false on malformed data
   */
  private static parseServerNameExtension(
    data: Buffer,
    serverNames: string[],
    logger?: LoggerFunction
  ): boolean {
    const log = logger || (() => {});

    try {
      // Server name list: length (2), then entries of type (1) + length (2) + name
      if (data.length < 2) {
        log('SNI extension too small for server name list length');
        return false;
      }

      const listEnd = 2 + data.readUInt16BE(0);
      if (listEnd > data.length) {
        log('SNI server name list exceeds extension data');
        return false;
      }

      let pos = 2;

      // The loop condition guarantees the 3-byte entry header is inside the list
      while (pos + 3 <= listEnd) {
        const nameType = data.readUInt8(pos);
        const nameLength = data.readUInt16BE(pos + 1);
        pos += 3;

        if (nameType !== 0) {
          pos += nameLength;
          continue;
        }

        if (pos + nameLength > listEnd) {
          log('SNI hostname truncated');
          return false;
        }

        const hostname = data.subarray(pos, pos + nameLength).toString('utf8');
        log(`Found SNI hostname: ${hostname}`);
        serverNames.push(hostname);

        pos += nameLength;
      }

      return serverNames.length > 0;
    } catch (error) {
      log(`Error parsing SNI extension: ${error}`);
      return false;
    }
  }

  /**
   * Determines if a ClientHello contains session resumption indicators.
   *
   * Any of these counts: a session_ticket extension, a pre_shared_key extension, an
   * early_data extension, or a non-empty session ID.
   *
   * NOTE(review): the session_ticket extension is counted even when empty, which per
   * RFC 5077 only advertises ticket support; and TLS 1.3 clients commonly send a
   * non-empty legacy_session_id (RFC 8446 appendix D.4). Both can cause false positives.
   *
   * @param buffer The ClientHello buffer
   * @param logger Optional logging function
   * @returns Session resumption result; both flags false when the buffer is not a parsable ClientHello
   */
  public static hasSessionResumption(
    buffer: Buffer,
    logger?: LoggerFunction
  ): SessionResumptionResult {
    const log = logger || (() => {});

    if (!TlsUtils.isClientHello(buffer)) {
      return { isResumption: false, hasSNI: false };
    }

    const parseResult = this.parseClientHello(buffer, logger);
    if (!parseResult.isValid) {
      log(`ClientHello parse failed: ${parseResult.error}`);
      return { isResumption: false, hasSNI: false };
    }

    const { hasSessionId, hasSessionTicket, hasPsk, hasEarlyData } = parseResult;
    const hasSNI = !!parseResult.serverNameList && parseResult.serverNameList.length > 0;

    const isResumption = hasSessionTicket || hasPsk || hasEarlyData || hasSessionId;

    if (isResumption) {
      const mechanisms: string[] = [];
      if (hasSessionTicket) mechanisms.push('session ticket');
      if (hasPsk) mechanisms.push('PSK');
      if (hasEarlyData) mechanisms.push('early data');
      if (hasSessionId) mechanisms.push('session ID');

      log(
        `Session resumption detected: ${mechanisms.join(', ')}` +
          (hasSNI ? ', with SNI' : ', without SNI')
      );
    }

    return { isResumption, hasSNI };
  }

  /**
   * Checks if a ClientHello matches the "tab reactivation" pattern: a non-empty
   * session ID plus a session_ticket or pre_shared_key extension, and no SNI.
   *
   * @param buffer The ClientHello buffer
   * @param logger Optional logging function
   * @returns true if the pattern matches; false otherwise or if the buffer cannot be parsed
   */
  public static isTabReactivationHandshake(
    buffer: Buffer,
    logger?: LoggerFunction
  ): boolean {
    const log = logger || (() => {});

    if (!TlsUtils.isClientHello(buffer)) {
      return false;
    }

    const parseResult = this.parseClientHello(buffer, logger);
    if (!parseResult.isValid) {
      return false;
    }

    const hasSessionId = parseResult.hasSessionId;
    const hasSessionTicket = parseResult.hasSessionTicket;
    const hasPsk = parseResult.hasPsk;
    const hasSNI = !!parseResult.serverNameList && parseResult.serverNameList.length > 0;

    if (hasSessionId && (hasSessionTicket || hasPsk) && !hasSNI) {
      log('Detected tab reactivation pattern: session resumption without SNI');
      return true;
    }

    return false;
  }
}