#!/usr/bin/env python3
"""
Add CSRF tokens to all POST forms in templates
"""

import os
import re
import glob

def add_csrf_to_form(content):
    """Insert a hidden CSRF input into every POST form that lacks one.

    A form is considered protected when the text ``csrf_token`` appears
    anywhere inside its own body, i.e. between its opening tag and the
    matching ``</form>`` (or the next ``<form`` / end of input when the
    closing tag is missing). Unprotected forms get the hidden input inserted
    directly after the opening tag. The operation is idempotent.

    Args:
        content: Full template source as a string.

    Returns:
        The template source with CSRF inputs added where needed. The input
        string is returned unchanged when nothing had to be inserted.
    """
    # Opening <form> tag whose method attribute is POST (any case, quoted or
    # unquoted, optional whitespace around '='). The leading \s keeps
    # attributes such as data-method="post" from matching.
    form_pattern = re.compile(
        r'<form\b[^>]*\smethod\s*=\s*["\']?post\b["\']?[^>]*>',
        re.IGNORECASE,
    )
    # Where the current form's body ends: its closing tag, or the start of
    # another form if the closing tag is missing.
    boundary_pattern = re.compile(r'</form\s*>|<form\b', re.IGNORECASE)

    csrf_html = '\n                <input type="hidden" name="csrf_token" value="{{ csrf_token() }}">'

    def replace_form(match):
        form_tag = match.group(0)
        body_start = match.end()

        # Inspect only this form's body. A fixed-size window would either
        # miss a token placed further down, or pick up a token that belongs
        # to a neighbouring form.
        boundary = boundary_pattern.search(content, body_start)
        body_end = boundary.start() if boundary else len(content)
        if 'csrf_token' in content[body_start:body_end]:
            return form_tag

        return form_tag + csrf_html

    return form_pattern.sub(replace_form, content)

def process_template_file(filepath):
    """Add missing CSRF tokens to one template file, rewriting it in place.

    Args:
        filepath: Path of the UTF-8 encoded template file to process.

    Returns:
        True if the file was modified and written back; False if no change
        was needed or the file could not be read or written.
    """
    print(f"Processing {filepath}...")

    try:
        # newline='' disables newline translation so that a rewritten file
        # keeps its original line endings (e.g. CRLF) outside the inserted
        # token line.
        with open(filepath, 'r', encoding='utf-8', newline='') as f:
            content = f.read()

        # No file-level "csrf_token() count >= POST form count" shortcut:
        # such a comparison skips the whole file when one form mentions the
        # token several times while another form has none. The decision is
        # made per form by add_csrf_to_form, which leaves protected forms
        # untouched.
        updated_content = add_csrf_to_form(content)

        if updated_content == content:
            print("  ✓ No changes needed")
            return False

        with open(filepath, 'w', encoding='utf-8', newline='') as f:
            f.write(updated_content)
        print("  ✅ Added CSRF token")
        return True

    except (OSError, UnicodeError) as e:
        # Unreadable, unwritable or non-UTF-8 file: report it and keep going
        # so that one bad template does not abort the whole run.
        print(f"  ❌ Error: {e}")
        return False

def main(templates_dir='/var/www/lawbot/templates'):
    """Add CSRF tokens to every ``.html`` template under a directory tree.

    Args:
        templates_dir: Root directory searched recursively for ``.html``
            files. Defaults to the deployed application's template folder.

    Prints a per-file report followed by a summary, and a reminder to reload
    the application when at least one file was changed.
    """
    # os.walk yields nothing for a missing directory, which would otherwise
    # look like a successful run over zero templates.
    if not os.path.isdir(templates_dir):
        print(f"❌ Templates directory not found: {templates_dir}")
        return

    # Find all HTML template files
    template_files = sorted(
        os.path.join(root, name)
        for root, _dirs, names in os.walk(templates_dir)
        for name in names
        if name.endswith('.html')
    )

    print(f"Found {len(template_files)} template files")
    print("=" * 50)

    # process_template_file returns True only for files it actually rewrote.
    updated_count = sum(
        1 for filepath in template_files if process_template_file(filepath)
    )

    print("=" * 50)
    print(f"✅ Updated {updated_count} files with CSRF tokens")

    if updated_count > 0:
        print("\n⚠️  Remember to reload the application:")
        print("   sudo kill -HUP $(pgrep -f gunicorn)")

# Run the template update only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()