"""
Security Utility Functions for EPO-LAW
Password validation, input sanitization, and security helpers
"""
import html
import logging
import os
import re
from datetime import datetime
from urllib.parse import urlsplit

from flask import current_app, request
from werkzeug.utils import secure_filename

# Module-level logger, named after this module so its level/handlers can be
# configured per module. Security events are written to the separate
# 'security' logger by log_security_event instead.
logger = logging.getLogger(__name__)


# Passwords rejected outright (compared case-insensitively). Built once at
# import time instead of on every call.
_COMMON_WEAK_PASSWORDS = frozenset({
    'password', '12345678', 'qwerty', 'admin', 'letmein',
    'welcome', 'monkey', 'dragon', 'master', 'password123',
})


def validate_password(password):
    """
    Validate password against security policy

    The policy is read from the Flask app config:
    MIN_PASSWORD_LENGTH (default 12), REQUIRE_UPPERCASE, REQUIRE_LOWERCASE,
    REQUIRE_DIGITS and REQUIRE_SPECIAL_CHARS (each default True). A password
    found in a small list of common passwords is always rejected.

    Args:
        password: Candidate password string. ``None`` or ``''`` is rejected
            as missing, consistent with validate_email/validate_username.

    Returns:
        tuple: (is_valid, error_message). error_message lists every failed
        rule joined with "; ", or is None when the password is valid.
    """
    # Guard first: len(None) would raise, and there is no point consulting
    # the app config for a missing value.
    if not password:
        return False, "Password is required"

    config = current_app.config

    min_length = config.get('MIN_PASSWORD_LENGTH', 12)
    require_uppercase = config.get('REQUIRE_UPPERCASE', True)
    require_lowercase = config.get('REQUIRE_LOWERCASE', True)
    require_digits = config.get('REQUIRE_DIGITS', True)
    require_special = config.get('REQUIRE_SPECIAL_CHARS', True)

    errors = []

    # Check length
    if len(password) < min_length:
        errors.append(f"Password must be at least {min_length} characters long")

    # Check for uppercase
    if require_uppercase and not re.search(r'[A-Z]', password):
        errors.append("Password must contain at least one uppercase letter")

    # Check for lowercase
    if require_lowercase and not re.search(r'[a-z]', password):
        errors.append("Password must contain at least one lowercase letter")

    # Check for digits
    if require_digits and not re.search(r'\d', password):
        errors.append("Password must contain at least one digit")

    # Check for special characters (only this explicit set counts)
    if require_special and not re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
        errors.append("Password must contain at least one special character (!@#$%^&*...)")

    # Check for common weak passwords
    if password.lower() in _COMMON_WEAK_PASSWORDS:
        errors.append("This password is too common and not allowed")

    if errors:
        return False, "; ".join(errors)

    return True, None


def validate_email(email):
    """
    Validate email address format

    Uses a simplified RFC 5322-style pattern plus the RFC 5321 length limits
    (254 characters overall, 64 for the local part).

    Args:
        email: Address string to check; falsy values are reported as missing.

    Returns:
        tuple: (is_valid, error_message) -- error_message is None when valid.
    """
    if not email:
        return False, "Email is required"

    # Enforce the overall length limit before running the regex: the domain
    # part of the pattern ([a-zA-Z0-9.-]+ followed by \.) can backtrack
    # quadratically on long hostile input.
    if len(email) > 254:  # RFC 5321
        return False, "Email address is too long"

    # RFC 5322 simplified pattern. fullmatch (instead of match with '$') so
    # that a trailing newline such as "user@example.com\n" is rejected.
    pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
    if not re.fullmatch(pattern, email):
        return False, "Invalid email format"

    # The pattern guarantees exactly one '@'.
    local = email.partition('@')[0]
    if len(local) > 64:  # RFC 5321
        return False, "Email local part is too long"

    return True, None


def validate_username(username):
    """
    Validate username format

    A valid username is 3-50 characters made only of ASCII letters, digits,
    underscores and hyphens.

    Args:
        username: Candidate username; falsy values are reported as missing.

    Returns:
        tuple: (is_valid, error_message) -- error_message is None when valid.
    """
    if not username:
        return False, "Username is required"

    if len(username) < 3:
        return False, "Username must be at least 3 characters long"

    if len(username) > 50:
        return False, "Username must be less than 50 characters"

    # Alphanumeric, underscore, hyphen only. fullmatch rather than
    # match('^...$'), because '$' also matches before a trailing newline,
    # which let "alice\n" through.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', username):
        return False, "Username can only contain letters, numbers, underscores, and hyphens"

    return True, None


def sanitize_filename(filename):
    """
    Sanitize uploaded filename to prevent path traversal and other attacks

    Runs werkzeug's secure_filename, applies an extra character whitelist,
    and caps the result at 255 characters while trying to keep the extension.

    Args:
        filename: Client-supplied filename.

    Returns:
        str: Sanitized filename, or None if nothing usable remains.
    """
    if not filename:
        return None

    # Use werkzeug's secure_filename
    filename = secure_filename(filename)

    # Defense in depth: strip anything outside word chars, whitespace,
    # '-' and '.'.
    filename = re.sub(r'[^\w\s\-\.]', '', filename)

    # Checked after both passes so an emptied result is never returned.
    if not filename:
        return None

    max_len = 255
    if len(filename) > max_len:
        name, ext = os.path.splitext(filename)
        if len(ext) >= max_len:
            # Pathological "extension" that alone exceeds the limit: there is
            # no stem worth keeping, so just hard-truncate.
            filename = filename[:max_len]
        else:
            # Keep extension, truncate name. The stem is cut to 250 chars as
            # before, but never more than the room the extension leaves, so
            # long extensions can no longer push the result past max_len.
            filename = name[:min(250, max_len - len(ext))] + ext

    return filename


def allowed_file(filename):
    """
    Report whether *filename* carries an extension listed in the app config.

    The extension is the lowercased text after the last '.', and it is looked
    up in ``current_app.config['ALLOWED_EXTENSIONS']`` (an empty set if that
    key is absent).

    Returns:
        bool: True if allowed, False otherwise (including no name/no '.').
    """
    has_extension = bool(filename) and '.' in filename
    if not has_extension:
        return False

    _, _, suffix = filename.rpartition('.')
    permitted = current_app.config.get('ALLOWED_EXTENSIONS', set())
    return suffix.lower() in permitted


def get_file_extension(filename):
    """
    Return the lowercased text after the last '.' in *filename*.

    Returns:
        str or None: The extension without its dot (possibly '' for a name
        ending in '.'), or None when filename is empty/None or has no '.'.
    """
    if not filename:
        return None
    _, dot, suffix = filename.rpartition('.')
    if not dot:
        return None
    return suffix.lower()


def validate_file_upload(file, allowed_extensions=None):
    """
    Comprehensive file upload validation

    Checks, in order: a file was provided, it has a filename, the filename
    survives sanitize_filename, its extension is allowed, and its size is
    within the limit.

    Args:
        file: Uploaded file object exposing ``filename``, ``seek`` and
            ``tell``. The stream is rewound to position 0 after measuring.
        allowed_extensions: Collection of lowercase extensions (no dot). Falls
            back to ``current_app.config['ALLOWED_EXTENSIONS']`` when None.

    Returns:
        tuple: (is_valid, error_message, sanitized_filename). On failure
        sanitized_filename is None; on success error_message is None.
    """
    if not file:
        return False, "No file provided", None

    if file.filename == '':
        return False, "No file selected", None

    # Sanitize filename
    safe_filename = sanitize_filename(file.filename)
    if not safe_filename:
        return False, "Invalid filename", None

    # Check allowed extensions
    if allowed_extensions is None:
        allowed_extensions = current_app.config.get('ALLOWED_EXTENSIONS', set())

    ext = get_file_extension(safe_filename)
    if not ext or ext not in allowed_extensions:
        # Sorted so the message is stable (sets have no defined order).
        allowed_list = ', '.join(sorted(allowed_extensions))
        return False, f"File type not allowed. Allowed types: {allowed_list}", None

    # Check file size. Flask's default config defines MAX_CONTENT_LENGTH as
    # None, so a .get() default never applies and "size > None" would raise;
    # treat an unset/None limit as the 16MB fallback.
    max_size = current_app.config.get('MAX_CONTENT_LENGTH')
    if max_size is None:
        max_size = 16 * 1024 * 1024

    file.seek(0, os.SEEK_END)
    size = file.tell()
    file.seek(0)  # Reset to beginning so the caller can save/read the file

    if size > max_size:
        max_mb = max_size / (1024 * 1024)
        return False, f"File too large. Maximum size: {max_mb}MB", None

    return True, None, safe_filename


def _escape_log_value(value):
    """Render *value* with str() and escape CR/LF so it cannot forge extra log lines."""
    return str(value).replace('\r', '\\r').replace('\n', '\\n')


def log_security_event(event_type, description, user_id=None, severity='INFO'):
    """
    Log security-related events

    Writes one line to the 'security' logger and returns a structured record
    of the event. Values interpolated into the log line have CR/LF escaped,
    because description/user_id may carry user-controlled text that could
    otherwise inject fake log entries.

    Args:
        event_type: Type of security event (login_failed, unauthorized_access, etc.)
        description: Detailed description
        user_id: User ID if applicable
        severity: Log severity (INFO, WARNING, ERROR, CRITICAL); any other
            value is logged at INFO.

    Returns:
        dict: timestamp (naive UTC ISO string), event_type, description,
        user_id, ip_address, user_agent and severity, with raw values.
    """
    # Outside a request context the request proxy is falsy, so fall back.
    ip_address = request.remote_addr if request else 'unknown'
    user_agent = request.headers.get('User-Agent', 'unknown') if request else 'unknown'

    log_entry = {
        'timestamp': datetime.utcnow().isoformat(),
        'event_type': event_type,
        'description': description,
        'user_id': user_id,
        'ip_address': ip_address,
        'user_agent': user_agent,
        'severity': severity
    }

    level = {
        'CRITICAL': logging.CRITICAL,
        'ERROR': logging.ERROR,
        'WARNING': logging.WARNING,
    }.get(severity, logging.INFO)

    # Log to security log
    security_logger = logging.getLogger('security')
    security_logger.log(
        level,
        "SECURITY: %s - %s - User: %s - IP: %s",
        _escape_log_value(event_type),
        _escape_log_value(description),
        _escape_log_value(user_id),
        _escape_log_value(ip_address),
    )

    return log_entry


def sanitize_input(text, max_length=None, allow_html=False):
    """
    Sanitize text input to prevent XSS and injection attacks

    Steps: convert to str, truncate to max_length, drop NUL bytes, then
    (unless allow_html) HTML-escape & < > " and '.

    Args:
        text: Input to sanitize. None yields ''; other values (including 0
            and False) are converted with str() rather than discarded.
        max_length: Maximum allowed length, applied before escaping so that
            an entity is never cut in half. Falsy means no limit.
        allow_html: Whether to allow HTML (will be escaped if False)

    Returns:
        str: Sanitized text
    """
    # Only None means "no input". Testing truthiness here used to turn
    # numeric 0 (or False) into '' instead of '0'.
    if text is None:
        return ''

    # Convert to string
    text = str(text)

    # Truncate if needed
    if max_length and len(text) > max_length:
        text = text[:max_length]

    # Remove null bytes
    text = text.replace('\x00', '')

    # If HTML not allowed, escape it. html.escape(quote=True) produces exactly
    # &amp; &lt; &gt; &quot; &#x27; -- the same output as the former manual
    # replace chain.
    if not allow_html:
        text = html.escape(text, quote=True)

    return text


def is_safe_redirect_url(target):
    """
    Check if redirect URL is safe (prevents open redirect vulnerabilities)

    Only same-site relative references are accepted. A target is rejected if
    it is empty; contains control characters, leading whitespace, '://' or a
    backslash; starts with '//'; or parses with any URL scheme or host (this
    catches "javascript:..." and "https:evil.com", which need no '//').

    Returns:
        bool: True if safe, False otherwise
    """
    if not target:
        return False

    # Browsers drop tab/CR/LF inside URLs and strip leading control chars and
    # spaces, so "/\t/evil.com" or " //evil.com" would become "//evil.com".
    if any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in target):
        return False
    if target != target.lstrip():
        return False

    # Only allow relative URLs (no protocol)
    if '://' in target:
        return False

    # Don't allow // (protocol-relative URLs)
    if target.startswith('//'):
        return False

    # Don't allow backslashes: browsers treat '\' like '/', so "/\evil.com"
    # would be protocol-relative.
    if '\\' in target:
        return False

    # Anything that still parses with a scheme or network location points
    # off-site or at a script URL, e.g. "javascript:alert(1)", "https:evil.com".
    try:
        parts = urlsplit(target)
    except ValueError:
        return False
    return not parts.scheme and not parts.netloc
