import re
import fitz  # PyMuPDF
import openai
import os
import time
import base64
from typing import List, Dict, Any, Optional

# Import OCR libraries
import pytesseract
from PIL import Image
import io
import tiktoken

# Import practice area templates
try:
    from practice_area_templates import (
        detect_document_type,
        get_analysis_prompt,
        get_summary_prompt,
    )
except ImportError:
    # The templates module is optional. Without it, expose stand-ins with the
    # same call signatures: no specialised prompt is ever available and every
    # document is classified as general.
    def get_analysis_prompt(practice_area, analysis_type):
        """Stand-in used when templates are missing; always returns None."""
        return None

    def get_summary_prompt(practice_area, summary_type):
        """Stand-in used when templates are missing; always returns None."""
        return None

    def detect_document_type(text, filename=""):
        """Stand-in used when templates are missing; always reports 'general'."""
        fallback_classification = ('general', 'general_analysis')
        return fallback_classification

class DocumentSummarizationProcessor:
    def __init__(self, api_key, model="gpt-4o", max_tokens=7000):
        """Store the API settings and choose the tokenizer and chunk budget.

        Args:
            api_key: OpenAI API key passed to every client created by this class.
            model: Chat model name used for summarisation requests.
            max_tokens: Completion token limit sent with each summarisation request.
        """
        self.api_key = api_key
        self.model = model
        self.max_tokens = max_tokens

        # Count tokens with the tokenizer that belongs to the configured model,
        # so chunk sizes are measured the same way the model measures them.
        # tiktoken raises KeyError for model names it does not know; in that
        # case keep the previous behaviour and use the "gpt-4" encoding.
        try:
            self.encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            self.encoding = tiktoken.encoding_for_model("gpt-4")

        # Input budget per chunk: 5000 tokens for gpt-4o, and a conservative
        # 2500 tokens for gpt-3.5-turbo and every other model.
        self.max_chunk_tokens = 5000 if model == "gpt-4o" else 2500

    def extract_text_from_pdf(self, file_path):
        """Extract text from a PDF, using OCR for pages without a text layer.

        PDFs with more than 200 pages are sampled rather than read in full
        (see ``_select_pdf_pages``). A page whose embedded text is empty is
        rendered to an image and sent to OpenAI Vision; if that raises,
        Tesseract is used instead.

        Args:
            file_path: Path of the PDF file to read.

        Returns:
            The concatenated page text, or None if the PDF could not be
            opened or processed.
        """
        try:
            # The context manager releases the document handle on every exit
            # path, including when a page fails to load.
            with fitz.open(file_path) as doc:
                total_pages = len(doc)
                print(f"Processing PDF with {total_pages} pages")

                pages_to_process = self._select_pdf_pages(total_pages)

                # Collect per-page text and join once at the end.
                parts = []
                for page_num in pages_to_process:
                    page = doc[page_num]
                    page_text = page.get_text()

                    # Pages with an embedded text layer are used verbatim.
                    if page_text.strip():
                        parts.append(page_text)
                        continue

                    # No text layer (likely an image/scan): fall back to OCR.
                    print(f"Using OCR for page {page_num+1} (no text found in PDF)")
                    try:
                        # Render at 2x scale for better OCR accuracy.
                        pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
                        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

                        # Try OpenAI Vision first, fall back to Tesseract.
                        try:
                            page_text = self.extract_text_with_openai_vision(img)
                            print(f"Used OpenAI Vision for page {page_num+1}")
                        except Exception as vision_error:
                            print(f"OpenAI Vision failed for page {page_num+1}, falling back to Tesseract: {vision_error}")
                            page_text = pytesseract.image_to_string(img)

                        parts.append(page_text + "\n\n")
                    except Exception as ocr_error:
                        # Keep a visible marker so a failed page is not silently lost.
                        print(f"OCR error on page {page_num+1}: {ocr_error}")
                        parts.append(f"[OCR FAILED FOR PAGE {page_num+1}]\n\n")

                return "".join(parts)
        except Exception as e:
            print(f"Error extracting text from PDF: {e}")
            return None

    def _select_pdf_pages(self, total_pages):
        """Return the zero-based page indices to read from a PDF of the given size.

        Up to 200 pages: every page. 201-250 pages: the first 30, the last 15
        and every 15th page in between. More than 250 pages: pages 0-9, 20-39
        and the last 10 (the code treats such files as Supreme Court style
        opinions; that is a size heuristic only, the content is not inspected).
        """
        if total_pages <= 200:
            return range(total_pages)

        print(f"⚠️ Very large PDF detected ({total_pages} pages). Using aggressive sampling.")
        if total_pages > 250:  # Likely a Supreme Court decision like Dobbs
            print("📜 Detected Supreme Court decision format - focusing on key sections")
            selected = list(range(10))  # Syllabus
            selected.extend(range(20, 40))  # Main opinion start
            selected.extend(range(total_pages - 10, total_pages))  # Conclusion
            print(f"Supreme Court mode: Processing {len(selected)} key pages")
        else:
            selected = list(range(30))  # First 30
            selected.extend(range(total_pages - 15, total_pages))  # Last 15
            selected.extend(range(30, total_pages - 15, 15))  # Every 15th page in between

        # De-duplicate and restore document order.
        selected = sorted(set(selected))
        print(f"Processing {len(selected)} key pages out of {total_pages}")
        return selected

    def extract_text_from_docx(self, file_path):
        """Read the text of a Word file, trying several extractors in turn.

        Legacy ``.doc`` files are first handed to the external ``antiword``
        program. After that (and for ``.docx`` files directly) the method
        tries mammoth, then docx2txt, then python-docx, and returns the first
        non-empty result. Returns None when every extractor fails or the
        document yields no text.
        """
        # A path ending in '.doc' can never also end in '.docx'.
        if file_path.lower().endswith('.doc'):
            try:
                import subprocess
                print("Detected legacy .doc file - using antiword")
                proc = subprocess.run(
                    ['antiword', file_path],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
                if proc.returncode == 0 and proc.stdout.strip():
                    print(f"Successfully extracted text using antiword: {len(proc.stdout)} characters")
                    return proc.stdout
                print(f"antiword failed with return code {proc.returncode}")
                if proc.stderr:
                    print(f"antiword stderr: {proc.stderr}")
            except Exception as e:
                print(f"antiword extraction failed: {e}")

        # First choice: mammoth.
        try:
            import mammoth
            with open(file_path, "rb") as handle:
                raw = mammoth.extract_raw_text(handle)
                if raw.value and raw.value.strip():
                    print(f"Successfully extracted text using mammoth: {len(raw.value)} characters")
                    return raw.value
        except Exception as e:
            print(f"Mammoth failed: {e}")

        # Second choice: docx2txt.
        try:
            import docx2txt
            plain = docx2txt.process(file_path)
            if plain and plain.strip():
                print(f"Successfully extracted text using docx2txt: {len(plain)} characters")
                return plain
        except Exception as e:
            print(f"docx2txt failed: {e}")

        # Last resort: python-docx, reading paragraphs and then tables.
        try:
            import docx
            document = docx.Document(file_path)

            pieces = [paragraph.text + "\n" for paragraph in document.paragraphs]
            for table in document.tables:
                for row in table.rows:
                    # Cells are space-separated, each row ends with a newline.
                    pieces.extend(cell.text + " " for cell in row.cells)
                    pieces.append("\n")
            text = "".join(pieces)

            if not text.strip():
                print("python-docx extracted empty text")
                return None
            print(f"Successfully extracted text using python-docx: {len(text)} characters")
            return text
        except Exception as e:
            print(f"python-docx failed: {e}")
            return None

    def image_to_base64(self, image):
        """Encode a PIL image as a base64 JPEG string.

        The image is converted to RGB when it is in any other mode, saved as
        JPEG at quality 95 into memory, and the bytes are base64-encoded.
        Returns the encoded string, or None if any step raises.
        """
        try:
            # Non-RGB modes (e.g. PNG with transparency) are converted first.
            rgb_image = image if image.mode == 'RGB' else image.convert('RGB')

            with io.BytesIO() as jpeg_buffer:
                rgb_image.save(jpeg_buffer, format='JPEG', quality=95)
                encoded = base64.b64encode(jpeg_buffer.getvalue())

            return encoded.decode('utf-8')
        except Exception as e:
            print(f"Error converting image to base64: {e}")
            return None

    def extract_text_with_openai_vision(self, image):
        """Transcribe the text in an image with the OpenAI Vision API.

        The image is sent to the "gpt-4o" model as a base64 JPEG data URL with
        transcription instructions. Returns the model's reply, or an empty
        string when the reply contains the "[NO TEXT FOUND IN IMAGE]" marker.
        Any exception is logged and re-raised so that callers can fall back
        to Tesseract.
        """
        no_text_marker = "[NO TEXT FOUND IN IMAGE]"
        ocr_instructions = (
            "Extract all text from this image. Focus on:\n"
            "1. Maintaining original formatting and structure\n"
            "2. Preserving line breaks and paragraph structure  \n"
            "3. Including all text, even if handwritten or in tables\n"
            "4. If this appears to be a legal document, preserve section numbers, dates, and formal language exactly\n"
            "5. If no text is visible, respond with: [NO TEXT FOUND IN IMAGE]\n"
            "\n"
            "Return only the extracted text, no additional commentary."
        )

        try:
            encoded_image = self.image_to_base64(image)
            if not encoded_image:
                raise Exception("Failed to convert image to base64")

            vision_client = openai.OpenAI(api_key=self.api_key, timeout=180.0)

            text_part = {"type": "text", "text": ocr_instructions}
            image_part = {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{encoded_image}",
                    # High detail gives better text extraction.
                    "detail": "high",
                },
            }

            completion = vision_client.chat.completions.create(
                model="gpt-4o",  # GPT-4o supports vision
                messages=[{"role": "user", "content": [text_part, image_part]}],
                max_tokens=4000,  # generous limit for transcribed text
                temperature=0.1,  # low temperature for consistent transcription
            )
            transcription = completion.choices[0].message.content

            # The model was told to answer with the marker when nothing is legible.
            if no_text_marker in transcription:
                print("OpenAI Vision found no text in image")
                return ""

            return transcription

        except Exception as e:
            print(f"OpenAI Vision API error: {e}")
            raise e  # Re-raise to allow fallback to Tesseract

    def extract_text_from_image(self, file_path):
        """Extract text from an image file, preferring OpenAI Vision over Tesseract.

        Args:
            file_path: Path of the image file to read.

        Returns:
            The extracted text. Tesseract is used when OpenAI Vision raises or
            returns only whitespace. Returns None if the image cannot be
            opened or Tesseract itself fails.
        """
        try:
            # Close the image's file handle on every exit path instead of
            # leaving it to garbage collection.
            with Image.open(file_path) as img:
                # Try OpenAI Vision first
                try:
                    print("Attempting text extraction with OpenAI Vision API...")
                    text = self.extract_text_with_openai_vision(img)
                    if text and text.strip():
                        print(f"OpenAI Vision extracted {len(text)} characters")
                        return text
                    print("OpenAI Vision returned empty text, trying Tesseract...")
                except Exception as vision_error:
                    print(f"OpenAI Vision failed: {vision_error}")
                    print("Falling back to Tesseract OCR...")

                # Fallback to Tesseract OCR
                text = pytesseract.image_to_string(img)
                print(f"Tesseract OCR extracted {len(text)} characters")
                return text

        except Exception as e:
            print(f"Error extracting text from image: {e}")
            return None

    def extract_text_from_file(self, file_path):
        """Extract text from PDF, TXT, Word, or image files"""
        lower_path = file_path.lower()
        if lower_path.endswith('.pdf'):
            return self.extract_text_from_pdf(file_path)
        elif lower_path.endswith('.txt'):
            with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
                return f.read()
        elif lower_path.endswith(('.docx', '.doc')):
            return self.extract_text_from_docx(file_path)
        elif lower_path.endswith(('.jpg', '.jpeg', '.png', '.tiff', '.tif', '.bmp')):
            return self.extract_text_from_image(file_path)
        else:
            print(f"Unsupported file format: {file_path}")
            return None

    def tokenize(self, text: str) -> List[int]:
        """Return the token ids for *text* under this processor's encoding."""
        token_ids = self.encoding.encode(text)
        return token_ids

    def chunk_on_delimiter(self, input_string: str, max_tokens: int, delimiter: str) -> List[str]:
        """Split text on *delimiter* and regroup the pieces under a token budget.

        The pieces are merged back together by
        ``combine_chunks_with_no_minimum``; a warning is printed when that
        step reports dropped pieces. Returns the list of merged chunks.
        """
        pieces = input_string.split(delimiter)
        merged = self.combine_chunks_with_no_minimum(
            pieces, max_tokens, chunk_delimiter=delimiter, add_ellipsis_for_overflow=True
        )
        # The helper returns (chunks, total token count, dropped count).
        combined_chunks, dropped_chunk_count = merged[0], merged[2]
        if dropped_chunk_count > 0:
            print(f"Warning: Dropped {dropped_chunk_count} chunks due to length")
        return combined_chunks

    def combine_chunks_with_no_minimum(self, chunks: List[str], max_tokens: int,
                                     chunk_delimiter="\n\n", header: str = "",
                                     add_ellipsis_for_overflow=False):
        """Greedily merge text pieces into chunks of at most *max_tokens* tokens.

        Pieces are appended to the current chunk, joined by *chunk_delimiter*,
        until the next piece would push it over *max_tokens*; the chunk is
        then emitted and a new one is started.

        A piece whose own token count (with *header*) exceeds
        ``self.max_chunk_tokens`` is oversized. If it is longer than 1000
        characters it is split with ``split_large_chunk`` and its sub-pieces
        are merged instead; otherwise it is dropped. Sub-pieces that are still
        oversized are dropped too.

        Args:
            chunks: Text pieces in document order.
            max_tokens: Token budget for each merged chunk.
            chunk_delimiter: Separator inserted between merged pieces.
            header: Optional text placed at the start of every chunk.
            add_ellipsis_for_overflow: Accepted for interface compatibility;
                this implementation does not use it.

        Returns:
            A tuple ``(chunks, total_tokens, dropped_count)``: the merged
            chunks, the sum of their token counts, and the number of pieces
            dropped for being oversized.
        """
        dropped_chunk_count = 0
        output = []
        candidate = header if header else ""

        for chunk in chunks:
            chunk_with_header = header + chunk if header else chunk
            chunk_tokens = len(self.tokenize(chunk_with_header))

            # Oversize is judged against the instance limit rather than
            # max_tokens, so that valid pieces are not dropped.
            if chunk_tokens <= self.max_chunk_tokens:
                candidate = self._merge_or_flush(
                    candidate, chunk, chunk_delimiter, header, max_tokens, output
                )
                continue

            print(f"WARNING: Chunk over token limit ({chunk_tokens} tokens > {self.max_chunk_tokens})")

            # Only reasonably large pieces are worth splitting.
            if len(chunk) <= 1000:
                dropped_chunk_count += 1
                continue

            print("Attempting to split large chunk...")
            # The first sub-piece is a new piece relative to what came before,
            # so it is joined with the normal delimiter. Later sub-pieces
            # already end with the delimiter they were split on; joining them
            # with chunk_delimiter again would insert breaks that are not in
            # the source text and waste tokens.
            separator = chunk_delimiter
            for sub_chunk in self.split_large_chunk(chunk):
                sub_chunk_tokens = len(self.tokenize(sub_chunk))
                if sub_chunk_tokens > self.max_chunk_tokens:
                    print(f"Sub-chunk still too large ({sub_chunk_tokens} tokens), dropping")
                    dropped_chunk_count += 1
                    continue
                candidate = self._merge_or_flush(
                    candidate, sub_chunk, separator, header, max_tokens, output
                )
                separator = ""

        if candidate:
            output.append(candidate)

        output_len = sum(len(self.tokenize(merged)) for merged in output)
        return output, output_len, dropped_chunk_count

    def _merge_or_flush(self, candidate, piece, separator, header, max_tokens, output):
        """Add *piece* to *candidate*, or emit *candidate* and start a new chunk.

        Returns the new current chunk. When the merged text would exceed
        *max_tokens*, a non-empty *candidate* is appended to *output* and the
        new chunk starts with *header* followed by *piece*.
        """
        merged = candidate + separator + piece if candidate else piece
        if len(self.tokenize(merged)) <= max_tokens:
            return merged
        if candidate:  # never emit an empty chunk
            output.append(candidate)
        return header + piece if header else piece

    def split_large_chunk(self, chunk: str) -> List[str]:
        """Split an oversized piece of text into smaller pieces.

        The text is split on the first delimiter it contains, tried in this
        order: blank line, newline, sentence end (". "), space. Each piece
        except the last keeps its trailing delimiter, and whitespace-only
        pieces are discarded. Text containing none of the delimiters is cut
        into slices of ``len(chunk) // 3`` characters (at least one
        character), which yields three slices plus a shorter remainder when
        the length is not divisible by three.

        Args:
            chunk: The text to split.

        Returns:
            The pieces in original order; an empty list for empty input.
        """
        # Delimiters in order of preference, coarsest first.
        for delimiter in ('\n\n', '\n', '. ', ' '):
            if delimiter not in chunk:
                continue
            *leading, last = chunk.split(delimiter)
            # Re-attach the delimiter to every piece except the final one.
            pieces = [part + delimiter for part in leading]
            pieces.append(last)
            return [piece for piece in pieces if piece.strip()]

        # No usable delimiter: slice by character count. The slice size is
        # clamped to 1 because a step of 0 makes range() raise for inputs
        # shorter than three characters.
        piece_size = max(1, len(chunk) // 3)
        return [chunk[start:start + piece_size] for start in range(0, len(chunk), piece_size)]

    def get_summary_prompt(self, summary_type: str, summary_length: str, additional_instructions: str = "",
                          source_type: str = "document", practice_area: str = None, analysis_type: str = None) -> str:
        """Build the instruction text for a summary or analysis request.

        When *practice_area* is set to something other than 'general', the
        module-level template functions are consulted first: the analysis
        prompt if *analysis_type* is given, otherwise the summary prompt. If
        a template prompt is found it is returned, with the additional
        instructions appended when present.

        Otherwise a generic prompt is assembled from the summary type
        (executive, bullet_points, key_facts; anything else is treated as
        standard), the length (short, long; anything else is medium) and
        wording that matches the source type ("image" or document).
        """
        # Practice-specific templates take precedence over the generic prompt.
        if practice_area and practice_area != 'general':
            if analysis_type:
                specialised = get_analysis_prompt(practice_area, analysis_type)
            else:
                specialised = get_summary_prompt(practice_area, summary_type)
            if specialised:
                if not additional_instructions:
                    return specialised
                return f"{specialised}\n\nAdditional instructions: {additional_instructions}"

        # Generic prompt: choose the noun and verb that fit the source.
        if source_type == "image":
            noun = "image"
            verb = "shows" if summary_type == "standard" else "contains"
        else:
            noun = "document"
            verb = "provides"

        standard_instruction = f"Create a comprehensive summary that captures the main themes, important details, and overall content of this {noun}."
        instructions_by_type = {
            "executive": f"Create an executive summary focusing on key decisions, recommendations, and high-level insights from this {noun} that would be most relevant to decision-makers and stakeholders.",
            "bullet_points": f"Create a summary in bullet point format, organizing the main points from this {noun} clearly and concisely.",
            "key_facts": f"Extract and summarize the key facts, figures, dates, names, and important details from this {noun}.",
        }
        base_instruction = instructions_by_type.get(summary_type, standard_instruction)

        length_by_name = {
            "short": "Keep the summary concise (100-200 words).",
            "long": "Provide a detailed summary (400-600 words).",
        }
        length_instruction = length_by_name.get(summary_length, "Create a balanced summary (200-400 words).")

        opening_hint = f"Begin your summary naturally (e.g., 'The {noun} {verb}...' or 'This {noun} {verb}...')."

        sections = [base_instruction, length_instruction, opening_hint]
        if additional_instructions:
            sections.append(f"Additional requirements: {additional_instructions}")
        return " ".join(sections)

    def detect_practice_area(self, text: str, filename: str = "") -> tuple:
        """Classify a document by delegating to the module-level ``detect_document_type``.

        Returns whatever that function returns for the given text and filename.
        """
        classification = detect_document_type(text, filename)
        return classification

    def create_summary(self, text: str, summary_type: str = "standard",
                      summary_length: str = "medium", source_type: str = "document",
                      practice_area: str = None, analysis_type: str = None) -> Dict[str, Any]:
        """
        Create a summary from text with better error handling and token management
        
        Args:
            text: The text to summarize
            summary_type: Type of summary (standard, executive, bullet_points, key_facts)  
            summary_length: Length of summary (short, medium, long)
            
        Returns:
            Dictionary with success status and summary data
        """
        try:
            if not text or not text.strip():
                return {
                    'success': False,
                    'error': 'No text provided for summarization',
                    'word_count': 0,
                    'compression_ratio': 0
                }

            print(f"Creating summary from {len(text)} characters")

            # Get document length in tokens
            document_length = len(self.tokenize(text))
            print(f"Document length: {document_length} tokens")

            # Check if document is within our processing limits
            if document_length <= self.max_chunk_tokens:
                print(f"Document fits in single chunk ({document_length} <= {self.max_chunk_tokens} tokens) - processing directly")
                
                # Get summary prompt with source type
                system_prompt = self.get_summary_prompt(summary_type, summary_length, source_type=source_type)
                
                client = openai.OpenAI(api_key=self.api_key, timeout=180.0)
                
                messages = [
                    {"role": "system", "content": f"You are an expert document summarizer. {system_prompt}"},
                    {"role": "user", "content": f"Please summarize the following text:\n\n{text}"}
                ]

                response = client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    temperature=0.3
                )

                summary = response.choices[0].message.content
                
                # Calculate metrics
                original_words = len(text.split())
                summary_words = len(summary.split())
                compression_ratio = round((1 - summary_words / original_words) * 100, 1) if original_words > 0 else 0
                
                return {
                    'success': True,
                    'summary': summary,
                    'word_count': summary_words,
                    'compression_ratio': compression_ratio,
                    'original_length': document_length,
                    'chunks_processed': 1
                }
            else:
                # Document is too large - need to chunk it
                print(f"Document too large ({document_length} tokens) - using chunking approach")
                
                # Calculate reasonable chunk size (leave room for system prompt)
                target_chunk_size = self.max_chunk_tokens - 500  # Reserve tokens for prompts
                
                # Split text into chunks
                text_chunks = self.chunk_on_delimiter(text, target_chunk_size, "\n\n")
                
                if not text_chunks:
                    return {
                        'success': False,
                        'error': 'Failed to create processable chunks from document',
                        'word_count': 0,
                        'compression_ratio': 0
                    }

                print(f"Split document into {len(text_chunks)} chunks")

                # Get summary prompt with source type
                system_prompt = self.get_summary_prompt(summary_type, summary_length, source_type=source_type)

                # Process chunks
                chunk_summaries = []
                client = openai.OpenAI(api_key=self.api_key, timeout=180.0)

                for i, chunk in enumerate(text_chunks):
                    try:
                        print(f"Processing chunk {i+1}/{len(text_chunks)}")

                        messages = [
                            {"role": "system", "content": f"You are an expert document summarizer. {system_prompt}"},
                            {"role": "user", "content": f"Please summarize this section of the document:\n\n{chunk}"}
                        ]

                        response = client.chat.completions.create(
                            model=self.model,
                            messages=messages,
                            max_tokens=self.max_tokens,
                            temperature=0.3
                        )

                        chunk_summary = response.choices[0].message.content
                        chunk_summaries.append(chunk_summary)

                        # Small delay to avoid rate limits
                        time.sleep(0.5)

                    except Exception as e:
                        print(f"Error processing chunk {i+1}: {e}")
                        chunk_summaries.append(f"[Error processing section {i+1}]")

                # Create final summary from chunk summaries
                if len(chunk_summaries) == 1:
                    final_summary = chunk_summaries[0]
                else:
                    combined_summaries = "\n\n".join(chunk_summaries)

                    messages = [
                        {"role": "system", "content": f"You are an expert document summarizer. Create a cohesive final summary from the following section summaries. {system_prompt}"},
                        {"role": "user", "content": f"Please create a unified summary from these section summaries:\n\n{combined_summaries}"}
                    ]

                    try:
                        response = client.chat.completions.create(
                            model=self.model,
                            messages=messages,
                            max_tokens=self.max_tokens,
                            temperature=0.3
                        )

                        final_summary = response.choices[0].message.content
                    except Exception as e:
                        print(f"Error creating final summary: {e}")
                        final_summary = "\n\n".join(chunk_summaries)

                # Calculate metrics
                original_words = len(text.split())
                summary_words = len(final_summary.split())
                compression_ratio = round((1 - summary_words / original_words) * 100, 1) if original_words > 0 else 0

                return {
                    'success': True,
                    'summary': final_summary,
                    'word_count': summary_words,
                    'compression_ratio': compression_ratio,
                    'original_length': document_length,
                    'chunks_processed': len(text_chunks)
                }

        except Exception as e:
            import traceback
            print(f"Error in create_summary: {e}")
            print(traceback.format_exc())
            return {
                'success': False,
                'error': f"Error creating summary: {str(e)}",
                'word_count': 0,
                'compression_ratio': 0
            }

    def summarize_document(self, file_path: str, summary_type: str = "standard",
                          summary_length: str = "medium", additional_instructions: str = "",
                          detail: float = 0.5) -> Dict[str, Any]:
        """
        Summarize a document with specified type and length

        Args:
            file_path: Path to the document file
            summary_type: Type of summary (standard, executive, bullet_points, key_facts)
            summary_length: Length of summary (short, medium, long)
            additional_instructions: Custom instructions for the summary
            detail: Level of detail (0.0 to 1.0)

        Returns:
            Dictionary containing the summary and metadata
        """
        try:
            # Extract text from document
            text = self.extract_text_from_file(file_path)
            if not text:
                return {
                    "summary": "Failed to extract text from document",
                    "metadata": {
                        "error": "Text extraction failed",
                        "chunks_processed": 0,
                        "total_tokens": 0
                    }
                }

            # Determine source type based on file extension
            lower_path = file_path.lower()
            if lower_path.endswith(('.jpg', '.jpeg', '.png', '.tiff', '.tif', '.bmp')):
                source_type = "image"
            else:
                source_type = "document"

            print(f"Extracted {len(text)} characters from {source_type}")
            
            # Check for very large documents (like Dobbs decision)
            if len(text) > 500000:  # More than 500k characters
                print(f"⚠️ Very large document detected ({len(text)} chars). Using aggressive chunking.")
                # For very large documents, force more conservative settings
                self.max_chunk_tokens = min(self.max_chunk_tokens, 3000)
                self.max_tokens = min(self.max_tokens, 3000)
                
                # Add special instruction for large legal documents
                if "supreme court" in text[:5000].lower() or "opinion of the court" in text[:5000].lower():
                    if additional_instructions:
                        additional_instructions += " Focus on the majority opinion, key holdings, and legal reasoning."
                    else:
                        additional_instructions = "Focus on the majority opinion, key holdings, and legal reasoning."

            # Use the new create_summary method with source type
            result = self.create_summary(text, summary_type, summary_length, source_type)
            
            if result['success']:
                return {
                    "summary": result['summary'],
                    "section_summaries": [result['summary']],  # For compatibility
                    "metadata": {
                        "summary_type": summary_type,
                        "summary_length": summary_length,
                        "chunks_processed": result.get('chunks_processed', 1),
                        "total_tokens": result.get('original_length', 0),
                        "detail_level": detail,
                        "word_count": result.get('word_count', 0),
                        "compression_ratio": result.get('compression_ratio', 0)
                    }
                }
            else:
                return {
                    "summary": f"Error: {result['error']}",
                    "metadata": {
                        "error": result['error'],
                        "chunks_processed": 0,
                        "total_tokens": 0
                    }
                }

        except Exception as e:
            import traceback
            print(f"Error in summarize_document: {e}")
            print(traceback.format_exc())
            return {
                "summary": f"Error processing document: {str(e)}",
                "metadata": {
                    "error": str(e),
                    "chunks_processed": 0,
                    "total_tokens": 0
                }
            }
