from flask_sqlalchemy import SQLAlchemy
from datetime import datetime, timedelta
from werkzeug.security import generate_password_hash, check_password_hash
from sqlalchemy import Numeric, Index
import uuid
import json

# Shared Flask-SQLAlchemy handle; every model below subclasses db.Model.
db = SQLAlchemy()

# Association table for case team members: many-to-many link between cases
# and users, keyed by the composite primary key (case_id, user_id).
# Used as `secondary` by User.case_assignments (backref Case.team_members).
# NOTE(review): `role` is an extra column on a plain `secondary` table; the
# relationship itself never sets it — presumably written via direct inserts,
# confirm against callers.
case_team = db.Table('case_team',
    db.Column('case_id', db.Integer, db.ForeignKey('cases.id'), primary_key=True),
    db.Column('user_id', db.Integer, db.ForeignKey('users.id'), primary_key=True),
    db.Column('role', db.String(50)),
    db.Column('added_at', db.DateTime, default=datetime.utcnow)
)

class Company(db.Model):
    """An organization account; users and cases belong to it via ``company_id``."""
    __tablename__ = 'companies'

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(150), unique=True, nullable=False, index=True)
    description = db.Column(db.Text)
    # NOTE(review): tier names listed here ('premium') differ from the
    # SubscriptionPlan names ('pro') — confirm which one is authoritative.
    subscription_tier = db.Column(db.String(50), default='free', index=True)  # free, premium, enterprise
    active = db.Column(db.Boolean, default=True, index=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    # onupdate refreshes this timestamp whenever SQLAlchemy issues an UPDATE.
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships
    # Backrefs add `User.company` and `Case.company`.
    # NOTE(review): no delete cascade here — deleting a Company makes the ORM
    # null out child foreign keys; Case.company_id is NOT NULL, so that would
    # fail while cases exist. Confirm the intended deletion policy.
    users = db.relationship('User', backref='company', lazy=True)
    cases = db.relationship('Case', backref='company', lazy=True)

class User(db.Model):
    """Application account: credentials, verification/lockout state and permissions.

    A user may belong to a Company (``company_id``). Besides the schema, this
    class provides role checks, per-case permission checks and helpers that
    read the user's active Subscription.
    """
    __tablename__ = 'users'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(100), unique=True, nullable=False, index=True)
    email = db.Column(db.String(150), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(256), nullable=False)
    first_name = db.Column(db.String(100))
    last_name = db.Column(db.String(100))
    phone_number = db.Column(db.String(20))
    role = db.Column(db.String(50), default='user', index=True)  # user, company_admin, admin
    active = db.Column(db.Boolean, default=True, index=True)
    email_verified = db.Column(db.Boolean, default=False, index=True)
    email_verification_token = db.Column(db.String(100), unique=True)
    email_verification_sent_at = db.Column(db.DateTime)
    last_login = db.Column(db.DateTime)
    failed_login_attempts = db.Column(db.Integer, default=0)
    locked_until = db.Column(db.DateTime, nullable=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)

    # Foreign Keys
    company_id = db.Column(db.Integer, db.ForeignKey('companies.id'), nullable=True, index=True)

    # Stripe Integration
    stripe_customer_id = db.Column(db.String(100), index=True)

    # Relationships with cascade delete for dependent records
    # NOTE(review): 'all, delete-orphan' on uploaded_documents, case_notes and
    # case_activities means deleting a user also deletes the case documents,
    # notes and activity entries they authored — confirm this is intended.
    analysis_jobs = db.relationship('AnalysisJob', backref='user', lazy=True, cascade='all, delete-orphan')
    summarization_jobs = db.relationship('SummarizationJob', backref='user', lazy=True, cascade='all, delete-orphan')
    sessions = db.relationship('Session', backref='user', lazy=True, cascade='all, delete-orphan')
    activities = db.relationship('ActivityLog', backref='user', lazy=True, cascade='all, delete-orphan')
    invitations_created = db.relationship('Invitation', backref='creator', foreign_keys='Invitation.created_by', cascade='all, delete-orphan')
    led_cases = db.relationship('Case', foreign_keys='Case.lead_attorney_id', backref='lead_attorney')
    created_cases = db.relationship('Case', foreign_keys='Case.created_by_id', backref='created_by')
    case_assignments = db.relationship('Case', secondary=case_team, backref='team_members')
    uploaded_documents = db.relationship('CaseDocument', backref='uploaded_by', cascade='all, delete-orphan')
    case_notes = db.relationship('CaseNote', backref='user', cascade='all, delete-orphan')
    case_activities = db.relationship('CaseActivity', backref='user', cascade='all, delete-orphan')
    subscriptions = db.relationship('Subscription', backref='user', lazy=True, cascade='all, delete-orphan')
    password_resets = db.relationship('PasswordReset', backref='user', lazy=True, cascade='all, delete-orphan')

    def set_password(self, password):
        """Hash *password* with werkzeug and store the result in ``password_hash``."""
        self.password_hash = generate_password_hash(password)

    def check_password(self, password):
        """Return True if *password* matches the stored hash."""
        return check_password_hash(self.password_hash, password)

    def is_admin(self):
        """Return True for site-wide administrators (role ``'admin'``)."""
        return self.role == 'admin'

    def is_company_admin(self):
        """Return True for company admins; site admins also qualify."""
        return self.role in ['company_admin', 'admin']

    def can_view_case(self, case):
        """Check if user can view this case.

        Site admins see everything. Otherwise the user must be in the case's
        company, and then may view it if they are a company admin, the case
        creator, the lead attorney or a team member, or if the case is not
        confidential.
        """
        # Admins can see everything
        if self.is_admin():
            return True

        # Must be in the same company
        if self.company_id != case.company_id:
            return False

        # Company admins can see all company cases
        if self.is_company_admin():
            return True

        # Case creator — kept consistent with can_delete_case() and
        # Case.can_user_access(), which both grant the creator access.
        if case.created_by_id == self.id:
            return True

        # Case lead attorney
        if case.lead_attorney_id == self.id:
            return True

        # Team members
        if self in case.team_members:
            return True

        # For non-confidential cases, all company users can view
        if not case.is_confidential:
            return True

        return False

    def can_delete_case(self, case):
        """Check if user can delete a specific case.

        Allowed for site admins, company admins of the case's company, and
        the case's creator or lead attorney.
        """
        if self.is_admin():
            return True

        if self.is_company_admin() and case.company_id == self.company_id:
            return True

        if case.created_by_id == self.id or case.lead_attorney_id == self.id:
            return True

        return False

    def can_edit_case(self, case):
        """Check if user can modify this case (admins and the lead attorney only)."""
        return (self.is_admin() or
                (self.is_company_admin() and self.company_id == case.company_id) or
                case.lead_attorney_id == self.id)

    # Subscription Methods
    def get_subscription(self):
        """Return the user's first subscription with status 'active', or None."""
        return Subscription.query.filter_by(user_id=self.id, status='active').first()

    def has_active_subscription(self):
        """Return True if the user has an active subscription (always a bool)."""
        subscription = self.get_subscription()
        return subscription is not None and subscription.status == 'active'

    def get_plan_name(self):
        """Return the active plan's ``name``, or ``'free'`` without a subscription."""
        subscription = self.get_subscription()
        return subscription.plan.name if subscription else 'free'

    def can_access_feature(self, feature):
        """Return the value of *feature* for the user's plan.

        Without an active subscription a built-in free-plan table is used and
        unknown features return False. Note that some features are numeric
        limits (e.g. ``'monthly_analyses'``), not booleans.
        """
        subscription = self.get_subscription()
        if not subscription:
            # Default free plan features
            free_features = {
                'basic_analysis': True,
                'case_management': False,
                'arguments_generation': False,
                'comparison_tools': False,
                'monthly_analyses': 5
            }
            return free_features.get(feature, False)

        return subscription.is_feature_enabled(feature)

    def get_monthly_analyses_remaining(self):
        """Get remaining analyses for the month.

        Returns 5 without a subscription, ``float('inf')`` when the plan's
        ``monthly_analyses`` feature is -1 or missing, and otherwise the
        non-negative difference between that limit and the usage counter.
        """
        subscription = self.get_subscription()
        if not subscription:
            return 5  # Free plan default

        # get_feature() tolerates a plan whose `features` JSON is NULL/empty.
        # NOTE(review): a missing key means "unlimited" here, whereas
        # SubscriptionPlan.get_monthly_analysis_limit() defaults to 5 — confirm.
        limit = subscription.plan.get_feature('monthly_analyses', -1)

        # If limit is -1 (unlimited), return infinity
        if limit == -1:
            return float('inf')

        # The counter's column default is only applied at INSERT time, so a
        # freshly created, unflushed subscription may still hold None.
        used = subscription.monthly_analyses_used or 0
        return max(0, limit - used)

    def get_full_name(self):
        """Return "first last", whichever name part exists, or the username."""
        name = ' '.join(part for part in (self.first_name, self.last_name) if part)
        return name or self.username

class Session(db.Model):
    """Login session identified by an opaque random token with a fixed expiry."""
    __tablename__ = 'sessions'

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False, index=True)
    session_token = db.Column(db.String(256), unique=True, nullable=False, index=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    expires_at = db.Column(db.DateTime, nullable=False, index=True)
    ip_address = db.Column(db.String(45))
    user_agent = db.Column(db.String(256))

    @property
    def is_expired(self):
        """True once the current naive-UTC time is later than ``expires_at``."""
        now = datetime.utcnow()
        return self.expires_at < now

    @staticmethod
    def create_for_user(user_id, days_valid=1):
        """Build an unsaved Session for *user_id* that expires in *days_valid* days.

        The caller is responsible for adding it to the DB session and committing.
        """
        lifetime = timedelta(days=days_valid)
        return Session(
            user_id=user_id,
            session_token=str(uuid.uuid4()),
            expires_at=datetime.utcnow() + lifetime,
        )

class AnalysisJob(db.Model):
    """One uploaded file queued for AI analysis, optionally attached to a Case.

    ``job_uuid`` is the public identifier; CaseDocument links to a job by that
    uuid rather than by the integer primary key.
    """
    __tablename__ = 'analysis_jobs'

    id = db.Column(db.Integer, primary_key=True)
    job_uuid = db.Column(db.String(36), unique=True, nullable=False, index=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False, index=True)
    company_id = db.Column(db.Integer, db.ForeignKey('companies.id'), index=True)
    case_id = db.Column(db.Integer, db.ForeignKey('cases.id'), nullable=True, index=True)
    filename = db.Column(db.String(256), nullable=False)
    original_filename = db.Column(db.String(256), index=True)
    file_path = db.Column(db.String(512))
    perspective = db.Column(db.String(50), nullable=False, index=True)  # prosecutor, defense, neutral
    additional_instructions = db.Column(db.Text, nullable=True)
    practice_area = db.Column(db.String(50), index=True)  # business, family, criminal, civil, general
    analysis_type = db.Column(db.String(100))  # contract_review, custody_evaluation, asset_division, etc.
    status = db.Column(db.String(50), default='pending', index=True)  # pending, processing, completed, failed
    error_message = db.Column(db.Text)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    completed_at = db.Column(db.DateTime, index=True)

    # Relationships: one-to-one result row, and case documents joined on the
    # string uuid (hence the explicit primaryjoin / foreign_keys).
    results = db.relationship('AnalysisResult', backref='job', uselist=False, cascade='all, delete-orphan')
    case_documents = db.relationship('CaseDocument', 
                                   foreign_keys='CaseDocument.analysis_job_id',
                                   primaryjoin='AnalysisJob.job_uuid == CaseDocument.analysis_job_id',
                                   backref='analysis_job')

    def get_display_info(self):
        """Return a plain dict describing this job for the dashboard.

        ``case`` is the linked case's name (or None), ``user`` the owner's full
        name (or 'Unknown'), and ``has_results`` is True only for completed
        jobs that have a result row.
        """
        perspective_label = self.perspective.title()
        case_name = self.case.case_name if self.case else None
        owner_name = self.user.get_full_name() if self.user else 'Unknown'
        results_ready = self.status == 'completed' and self.results is not None

        info = {}
        info['job_id'] = self.job_uuid
        info['filename'] = self.original_filename
        info['status'] = self.status
        info['perspective'] = perspective_label
        info['created_at'] = self.created_at
        info['completed_at'] = self.completed_at
        info['case'] = case_name
        info['user'] = owner_name
        info['has_results'] = results_ready
        return info

    def is_image_analysis(self):
        """True when ``original_filename`` ends in a known image extension."""
        name = self.original_filename
        if not name:
            return False
        _, dot, extension = name.rpartition('.')
        if not dot:
            return False
        return extension.lower() in {'jpg', 'jpeg', 'png', 'tif', 'tiff', 'bmp', 'gif', 'webp'}

    def is_document_analysis(self):
        """True for every job that is not an image analysis."""
        image_job = self.is_image_analysis()
        return not image_job

    # Composite indexes for common queries
    __table_args__ = (
        Index('idx_analysis_user_status', 'user_id', 'status'),
        Index('idx_analysis_company_status', 'company_id', 'status'),
        Index('idx_analysis_case_status', 'case_id', 'status'),
        Index('idx_analysis_created_status', 'created_at', 'status'),
    )

class AnalysisResult(db.Model):
    """Output of an AnalysisJob (one-to-one via ``AnalysisJob.results``).

    ``section_analyses`` and ``citations`` are JSON columns. The ``*_json``
    properties also accept a raw JSON string stored in them and decode it on
    access, always yielding ``[]`` when there is no data.
    """
    __tablename__ = 'analysis_results'

    id = db.Column(db.Integer, primary_key=True)
    job_id = db.Column(db.Integer, db.ForeignKey('analysis_jobs.id'), nullable=False, index=True)
    final_analysis = db.Column(db.Text)
    section_analyses = db.Column(db.JSON)
    citations = db.Column(db.JSON)
    arguments = db.Column(db.Text)  # For storing generated arguments
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)

    @staticmethod
    def _decode_json_list(value):
        """Return *value* as a decoded JSON object, substituting [] for missing data.

        A ``str`` is parsed with ``json.loads`` (a blank string counts as
        missing). ``None``, a decoded JSON ``null`` or any other falsy value
        becomes ``[]``. Malformed JSON still raises ``json.JSONDecodeError``
        so corrupt data is not silently hidden.
        """
        if isinstance(value, str):
            value = json.loads(value) if value.strip() else None
        return value or []

    @property
    def section_analyses_json(self):
        """Per-section analyses as Python data; ``[]`` when absent."""
        return self._decode_json_list(self.section_analyses)

    @property
    def citations_json(self):
        """Citations as Python data; ``[]`` when absent."""
        return self._decode_json_list(self.citations)

    def get_word_count(self):
        """Get approximate (whitespace-split) word count of the final analysis."""
        if not self.final_analysis:
            return 0
        return len(self.final_analysis.split())

    def get_section_count(self):
        """Get number of sections analyzed (0 when there are none)."""
        return len(self.section_analyses_json)

class ActivityLog(db.Model):
    """Audit trail entry recording an action by a user and/or within a company."""
    __tablename__ = 'activity_logs'

    id = db.Column(db.Integer, primary_key=True)
    # Both owners are nullable, so entries may exist without a user or company.
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), index=True)
    company_id = db.Column(db.Integer, db.ForeignKey('companies.id'), index=True)
    activity_type = db.Column(db.String(50), nullable=False, index=True)
    description = db.Column(db.Text)
    # Generic (type, id) pointer to the affected record; there is no foreign
    # key, so the database does not enforce that the target exists.
    entity_type = db.Column(db.String(50), index=True)  # user, analysis_job, company, etc.
    entity_id = db.Column(db.Integer)
    # 45 chars fits the longest textual IPv6 form.
    ip_address = db.Column(db.String(45))
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)

    # Composite indexes for common queries
    __table_args__ = (
        Index('idx_activity_user_created', 'user_id', 'created_at'),
        Index('idx_activity_company_created', 'company_id', 'created_at'),
        Index('idx_activity_type_created', 'activity_type', 'created_at'),
    )

# SUBSCRIPTION MODELS
class SubscriptionPlan(db.Model):
    """A purchasable plan; ``features`` is a JSON map of limits and flags."""
    __tablename__ = 'subscription_plans'
    
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), nullable=False, unique=True, index=True)  # 'free', 'pro', 'enterprise'
    display_name = db.Column(db.String(100), nullable=False)
    price = db.Column(Numeric(10, 2), nullable=False)  # Monthly price
    stripe_price_id = db.Column(db.String(100), index=True)  # Stripe price ID
    features = db.Column(db.JSON)  # Feature limits and permissions
    is_active = db.Column(db.Boolean, default=True, index=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    
    # Relationships
    subscriptions = db.relationship('Subscription', backref='plan', lazy=True)

    def get_feature(self, feature_name, default=None):
        """Look up *feature_name* in ``features``; return *default* if absent or features is empty."""
        feature_map = self.features
        return feature_map.get(feature_name, default) if feature_map else default

    def get_monthly_analysis_limit(self):
        """The plan's ``monthly_analyses`` value, defaulting to 5 when unset."""
        return self.get_feature('monthly_analyses', default=5)

class Subscription(db.Model):
    """A user's subscription to a SubscriptionPlan, with Stripe ids and usage counters.

    Only the ``'free'`` plan is metered; every other plan is treated as
    unlimited by the helpers below. A ``monthly_analyses`` limit of ``-1`` in
    the plan's features also means unlimited.
    """
    __tablename__ = 'subscriptions'
    
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False, index=True)
    company_id = db.Column(db.Integer, db.ForeignKey('companies.id'), index=True)
    plan_id = db.Column(db.Integer, db.ForeignKey('subscription_plans.id'), nullable=False, index=True)
    
    # Stripe information
    stripe_customer_id = db.Column(db.String(100), index=True)
    stripe_subscription_id = db.Column(db.String(100), index=True)
    
    # Subscription status
    status = db.Column(db.String(20), default='active', index=True)  # active, canceled, past_due, incomplete
    current_period_start = db.Column(db.DateTime, index=True)
    current_period_end = db.Column(db.DateTime, index=True)
    cancel_at_period_end = db.Column(db.Boolean, default=False)
    
    # Usage tracking
    monthly_analyses_used = db.Column(db.Integer, default=0, index=True)
    monthly_reset_date = db.Column(db.DateTime, index=True)
    
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    def _analyses_used(self):
        """Return the usage counter, treating a not-yet-flushed ``None`` as 0.

        The column default (0) is only applied at INSERT time, so a freshly
        constructed subscription still holds None until it is flushed.
        """
        return self.monthly_analyses_used or 0
    
    def reset_monthly_usage(self):
        """Zero the usage counter, schedule the next reset 30 days out, and commit."""
        self.monthly_analyses_used = 0
        self.monthly_reset_date = datetime.utcnow() + timedelta(days=30)
        db.session.commit()
    
    def can_analyze_document(self):
        """Check if the user can analyze another document.

        Non-free plans are always allowed. Free plans are allowed while under
        the plan's monthly limit (-1 = unlimited), or when the user still
        holds an unused single-analysis purchase.
        """
        if self.plan.name != 'free':
            return True  # Pro and Enterprise have unlimited

        limit = self.plan.get_monthly_analysis_limit()
        if limit == -1 or self._analyses_used() < limit:
            return True

        # Fall back to a pay-per-analysis credit, if one is available.
        return SingleAnalysisPurchase.has_unused_purchase(self.user_id)
    
    def increment_usage(self):
        """Add one to the analysis usage counter and commit.

        NOTE(review): this is a read-modify-write in Python; concurrent
        requests for the same subscription can lose increments. Consider an
        SQL-side increment if that matters.
        """
        self.monthly_analyses_used = self._analyses_used() + 1
        db.session.commit()
    
    def is_feature_enabled(self, feature):
        """Return the plan's value for *feature*, or False if it is not set."""
        return self.plan.get_feature(feature, False)

    def get_usage_percentage(self):
        """Get monthly usage as a percentage (capped at 100) for free plans.

        Returns 0 for non-free plans and for a free plan configured as
        unlimited (-1), and 100 when the limit is 0.
        """
        if self.plan.name != 'free':
            return 0  # Unlimited plans
        
        limit = self.plan.get_monthly_analysis_limit()
        if limit == -1:
            return 0  # Free plan configured as unlimited
        if limit == 0:
            return 100
        
        return min(100, (self._analyses_used() / limit) * 100)

    # Composite indexes for common queries
    __table_args__ = (
        Index('idx_subscription_user_status', 'user_id', 'status'),
        Index('idx_subscription_plan_status', 'plan_id', 'status'),
    )

class PaymentHistory(db.Model):
    """A single payment record attached to a Subscription."""
    __tablename__ = 'payment_history'
    
    id = db.Column(db.Integer, primary_key=True)
    subscription_id = db.Column(db.Integer, db.ForeignKey('subscriptions.id'), nullable=False, index=True)
    stripe_payment_intent_id = db.Column(db.String(100), index=True)
    # Exact decimal amount (10 digits, 2 after the point) to avoid float rounding.
    amount = db.Column(Numeric(10, 2), nullable=False)
    currency = db.Column(db.String(3), default='usd')
    status = db.Column(db.String(20), index=True)  # succeeded, failed, pending
    description = db.Column(db.String(255))
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    
    # Relationships
    # Adds the `Subscription.payments` backref.
    # NOTE(review): the backref has no delete cascade, and subscription_id is
    # NOT NULL — deleting a Subscription (e.g. via the User cascade) that has
    # payments will likely fail when the ORM tries to null this FK. Decide
    # deliberately whether payment records should be kept or removed.
    subscription = db.relationship('Subscription', backref='payments')

class SingleAnalysisPurchase(db.Model):
    """Pay-per-use analysis credit bought by a (typically free-plan) user.

    Lifecycle is tracked in ``status``; a row in status ``'completed'`` is an
    unused credit, and ``mark_as_used`` moves it to ``'used'``.
    """
    __tablename__ = 'single_analysis_purchases'

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False, index=True)
    company_id = db.Column(db.Integer, db.ForeignKey('companies.id'), index=True)
    stripe_payment_intent_id = db.Column(db.String(100), unique=True, index=True)
    stripe_session_id = db.Column(db.String(100), unique=True, index=True)
    amount = db.Column(Numeric(10, 2), nullable=False)  # Should be 15.00
    currency = db.Column(db.String(3), default='usd')
    status = db.Column(db.String(20), default='pending', index=True)  # pending, completed, failed, used
    analysis_job_id = db.Column(db.String(255), index=True)  # Link to the analysis job when used
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    used_at = db.Column(db.DateTime)  # When the analysis was performed

    # Relationships
    user = db.relationship('User', backref='single_analysis_purchases')
    company = db.relationship('Company', backref='single_analysis_purchases')

    def mark_as_used(self, job_id):
        """Consume this credit for analysis *job_id*, timestamp it, and commit."""
        self.analysis_job_id = job_id
        self.used_at = datetime.utcnow()
        self.status = 'used'
        db.session.commit()

    @staticmethod
    def has_unused_purchase(user_id):
        """Return True if *user_id* owns at least one purchase in status 'completed'."""
        first_unused = SingleAnalysisPurchase.query.filter_by(
            user_id=user_id, status='completed'
        ).first()
        return first_unused is not None

class Invitation(db.Model):
    """Token-based invitation for *email* to join a Company with a given role."""
    __tablename__ = 'invitations'

    id = db.Column(db.Integer, primary_key=True)
    email = db.Column(db.String(150), nullable=False, index=True)
    company_id = db.Column(db.Integer, db.ForeignKey('companies.id'), nullable=False, index=True)
    role = db.Column(db.String(50), default='user')
    invitation_token = db.Column(db.String(256), unique=True, nullable=False, index=True)
    created_by = db.Column(db.Integer, db.ForeignKey('users.id'), index=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    expires_at = db.Column(db.DateTime, index=True)
    accepted = db.Column(db.Boolean, default=False, index=True)
    accepted_at = db.Column(db.DateTime)

    # Relationships
    company = db.relationship('Company', backref='invitations')

    @property
    def is_expired(self):
        """True once ``expires_at`` (naive UTC) has passed.

        ``expires_at`` is a nullable column. An invitation without an expiry
        is treated as expired (fail closed) rather than raising TypeError
        when compared with the current time.
        """
        if self.expires_at is None:
            return True
        return datetime.utcnow() > self.expires_at

    @staticmethod
    def create_invitation(email, company_id, role='user', created_by=None, days_valid=7):
        """Build an unsaved Invitation with a random token valid for *days_valid* days."""
        invitation_token = str(uuid.uuid4())
        expires_at = datetime.utcnow() + timedelta(days=days_valid)
        return Invitation(
            email=email,
            company_id=company_id,
            role=role,
            invitation_token=invitation_token,
            created_by=created_by,
            expires_at=expires_at
        )

class PasswordReset(db.Model):
    """Single-use password-reset token for a user, valid until ``expires_at``."""
    __tablename__ = 'password_resets'

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False, index=True)
    token = db.Column(db.String(256), unique=True, nullable=False, index=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    expires_at = db.Column(db.DateTime, nullable=False, index=True)
    used = db.Column(db.Boolean, default=False, index=True)
    used_at = db.Column(db.DateTime)

    @property
    def is_expired(self):
        """Truthy when the deadline has passed or the token was already consumed.

        Past the deadline this is True; otherwise it is the value of ``used``.
        """
        if datetime.utcnow() > self.expires_at:
            return True
        return self.used

    @staticmethod
    def create_token(user_id, hours_valid=24):
        """Build an unsaved PasswordReset for *user_id* expiring in *hours_valid* hours."""
        validity = timedelta(hours=hours_valid)
        return PasswordReset(
            user_id=user_id,
            token=str(uuid.uuid4()),
            expires_at=datetime.utcnow() + validity,
        )

# Case Management Models
class Case(db.Model):
    """A legal matter owned by a Company, with a creator, lead attorney and team.

    Documents, notes and activity entries are deleted together with the case;
    linked AnalysisJobs are not part of that cascade.
    """
    __tablename__ = 'cases'

    id = db.Column(db.Integer, primary_key=True)
    case_number = db.Column(db.String(100), unique=True, nullable=False, index=True)
    case_name = db.Column(db.String(200), nullable=False, index=True)
    case_type = db.Column(db.String(50), index=True)  # litigation, contract, criminal, etc.
    practice_area = db.Column(db.String(50), index=True)  # business, family, criminal, civil, general
    client_name = db.Column(db.String(200), index=True)
    opposing_party = db.Column(db.String(200))
    jurisdiction = db.Column(db.String(100))
    status = db.Column(db.String(50), default='active', index=True)  # active, closed, archived
    description = db.Column(db.Text)

    # Company-level ownership
    company_id = db.Column(db.Integer, db.ForeignKey('companies.id'), nullable=False, index=True)

    # Case lead/primary attorney
    lead_attorney_id = db.Column(db.Integer, db.ForeignKey('users.id'), index=True)

    # Who created the case
    created_by_id = db.Column(db.Integer, db.ForeignKey('users.id'), index=True)

    # Timestamps
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, index=True)
    date_opened = db.Column(db.Date, index=True)
    date_closed = db.Column(db.Date, index=True)

    # Access control
    is_confidential = db.Column(db.Boolean, default=False, index=True)

    # Process Service - Simple tracking
    process_server_name = db.Column(db.String(200))
    service_date = db.Column(db.DateTime)
    person_served = db.Column(db.String(200))
    documents_served = db.Column(db.Text)  # Comma-separated or JSON list
    service_notes = db.Column(db.Text)
    proof_of_service_path = db.Column(db.String(500))

    # Relationships use string class names so models can be declared in any order.
    # lazy='dynamic' makes each of these a query object (supports .count()).
    analyses = db.relationship('AnalysisJob', backref='case', lazy='dynamic')
    documents = db.relationship('CaseDocument', backref='case', lazy='dynamic', cascade='all, delete-orphan')
    notes = db.relationship('CaseNote', backref='case', lazy='dynamic', cascade='all, delete-orphan')
    activities = db.relationship('CaseActivity', backref='case', lazy='dynamic', cascade='all, delete-orphan')

    def get_team_member_count(self):
        """Number of users assigned to this case through the case_team table."""
        members = self.team_members
        return len(members)

    def get_document_count(self):
        """Number of CaseDocument rows attached to this case (via a COUNT query)."""
        document_query = self.documents
        return document_query.count()

    def get_analysis_count(self):
        """Number of AnalysisJob rows linked to this case (via a COUNT query)."""
        analysis_query = self.analyses
        return analysis_query.count()

    def can_user_access(self, user):
        """
        Decide whether *user* may access this case.

        Granted, in order of evaluation, to:
        - site admins
        - company admins of the owning company
        - the case creator or the lead attorney
        - members of the case team
        - any user of the owning company when the case is not confidential
        """
        if user.is_admin():
            return True

        same_company = self.company_id == user.company_id

        if same_company and user.is_company_admin():
            return True

        # Creator or lead attorney, regardless of company
        if user.id in (self.created_by_id, self.lead_attorney_id):
            return True

        if user in self.team_members:
            return True

        # Open (non-confidential) cases are visible company-wide
        return (not self.is_confidential) and same_company

    # Composite indexes for common queries
    __table_args__ = (
        Index('idx_case_company_status', 'company_id', 'status'),
        Index('idx_case_lead_status', 'lead_attorney_id', 'status'),
        Index('idx_case_created_status', 'created_at', 'status'),
    )

class CaseDocument(db.Model):
    """A file attached to a Case, optionally linked to analysis/summarization jobs."""
    __tablename__ = 'case_documents'

    id = db.Column(db.Integer, primary_key=True)
    case_id = db.Column(db.Integer, db.ForeignKey('cases.id'), nullable=False, index=True)
    filename = db.Column(db.String(200), nullable=False)
    original_filename = db.Column(db.String(200), index=True)
    file_path = db.Column(db.String(500))
    document_type = db.Column(db.String(50), index=True)  # contract, pleading, discovery, etc.
    description = db.Column(db.Text)

    # Who uploaded it
    uploaded_by_id = db.Column(db.Integer, db.ForeignKey('users.id'), index=True)
    uploaded_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)

    # Optional link to an analysis job; references the job's string uuid
    # (analysis_jobs.job_uuid), not its integer primary key.
    analysis_job_id = db.Column(db.String(36), db.ForeignKey('analysis_jobs.job_uuid'), index=True)

    # Optional link to a summarization job, also by job uuid.
    summarization_job_id = db.Column(db.String(36), db.ForeignKey('summarization_jobs.job_uuid'), nullable=True, index=True)

    # Full-text content extracted from document for searching
    content_text = db.Column(db.Text)

    def has_analysis(self):
        """Return True if an analysis job uuid is linked to this document."""
        return self.analysis_job_id is not None

    def has_summarization(self):
        """Return True if a summarization job uuid is linked to this document."""
        return self.summarization_job_id is not None

    # Composite indexes for common queries
    __table_args__ = (
        Index('idx_casedoc_case_uploaded', 'case_id', 'uploaded_at'),
        Index('idx_casedoc_uploader_uploaded', 'uploaded_by_id', 'uploaded_at'),
    )

class CaseNote(db.Model):
    """A free-text note written by a user on a Case."""
    __tablename__ = 'case_notes'

    id = db.Column(db.Integer, primary_key=True)
    case_id = db.Column(db.Integer, db.ForeignKey('cases.id'), nullable=False, index=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False, index=True)
    # NOTE(review): notes are capped at 200 characters (String(200)); strict
    # backends will reject longer text. Confirm whether db.Text was intended
    # (that change would need a migration).
    note_text = db.Column(db.String(200), nullable=False)
    note_type = db.Column(db.String(50), index=True)  # general, strategy, client_comm, research
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Composite indexes for common queries
    __table_args__ = (
        Index('idx_casenote_case_created', 'case_id', 'created_at'),
        Index('idx_casenote_user_created', 'user_id', 'created_at'),
    )

class CaseActivity(db.Model):
    """Per-case activity feed entry recording what a user did on a Case."""
    __tablename__ = 'case_activities'

    id = db.Column(db.Integer, primary_key=True)
    case_id = db.Column(db.Integer, db.ForeignKey('cases.id'), nullable=False, index=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False, index=True)
    activity_type = db.Column(db.String(50), index=True)  # document_added, analysis_complete, note_added, team_member_added
    description = db.Column(db.Text)
    # Generic (type, id) pointer to the affected record; no foreign key, so
    # the database does not enforce that the target exists.
    entity_type = db.Column(db.String(50), index=True)  # document, analysis, note, etc.
    entity_id = db.Column(db.Integer)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)

    # Composite indexes for common queries
    __table_args__ = (
        Index('idx_caseactivity_case_created', 'case_id', 'created_at'),
        Index('idx_caseactivity_user_created', 'user_id', 'created_at'),
        Index('idx_caseactivity_type_created', 'activity_type', 'created_at'),
    )

# Summarization Models - FIXED VERSION
class SummarizationJob(db.Model):
    """A request to summarize one uploaded document, plus its lifecycle state.

    ``job_uuid`` is the externally visible job id. A job is owned by a user,
    may be scoped to a company and/or a case, and records the summarization
    options chosen at submission time. Generated output, if any, lives in the
    one-to-one ``results`` relationship.
    """
    __tablename__ = 'summarization_jobs'

    id = db.Column(db.Integer, primary_key=True)
    job_uuid = db.Column(db.String(36), unique=True, nullable=False, index=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False, index=True)
    company_id = db.Column(db.Integer, db.ForeignKey('companies.id'), index=True)
    case_id = db.Column(db.Integer, db.ForeignKey('cases.id'), nullable=True, index=True)
    filename = db.Column(db.String(256), nullable=False)
    original_filename = db.Column(db.String(256), index=True)
    file_path = db.Column(db.String(512))
    practice_area = db.Column(db.String(50), index=True)  # business, family, criminal, civil, general
    summary_type = db.Column(db.String(50), nullable=False, index=True)  # standard, executive, bullet_points, key_facts
    summary_length = db.Column(db.String(50), nullable=False, index=True)  # short, medium, long
    additional_instructions = db.Column(db.Text)
    status = db.Column(db.String(50), default='pending', index=True)  # pending, processing, completed, failed
    error_message = db.Column(db.Text)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    completed_at = db.Column(db.DateTime, index=True)

    # Relationships
    company = db.relationship('Company', backref='summarization_jobs')
    case = db.relationship('Case', backref='summarization_jobs', foreign_keys=[case_id])
    # One-to-one (uselist=False); deleting the job also deletes its result.
    results = db.relationship('SummarizationResult', backref='job', uselist=False, cascade='all, delete-orphan')
    # Joined on the job's UUID string, not on its integer primary key.
    case_documents = db.relationship('CaseDocument',
                                   foreign_keys='CaseDocument.summarization_job_id',
                                   primaryjoin='SummarizationJob.job_uuid == CaseDocument.summarization_job_id',
                                   backref='summarization_job')

    def get_display_info(self):
        """Return a dict of display-ready fields for the dashboard.

        Keys: ``job_id`` (the UUID), ``filename`` (original upload name),
        ``status``, ``summary_type`` / ``summary_length`` (human-readable, or
        None when unset), ``created_at``, ``completed_at``, ``case`` (case name
        or None), ``user`` (owner's full name or 'Unknown') and ``has_results``.
        """
        # Both option columns are NOT NULL in the schema, but an instance that
        # has not been flushed yet may still hold None; don't crash the dashboard.
        summary_type = (self.summary_type.replace('_', ' ').title()
                        if self.summary_type else None)
        summary_length = self.summary_length.title() if self.summary_length else None

        # NOTE(review): this model declares no ``user`` relationship; it is
        # presumably supplied by a backref on User -- confirm. getattr keeps the
        # method working (falling back to 'Unknown') if that backref is absent.
        user = getattr(self, 'user', None)

        return {
            'job_id': self.job_uuid,
            'filename': self.original_filename,
            'status': self.status,
            'summary_type': summary_type,
            'summary_length': summary_length,
            'created_at': self.created_at,
            'completed_at': self.completed_at,
            'case': self.case.case_name if self.case else None,
            'user': user.get_full_name() if user else 'Unknown',
            'has_results': self.status == 'completed' and self.results is not None
        }

    # Composite indexes for common queries
    __table_args__ = (
        Index('idx_summary_user_status', 'user_id', 'status'),
        Index('idx_summary_company_status', 'company_id', 'status'),
        Index('idx_summary_case_status', 'case_id', 'status'),
        Index('idx_summary_created_status', 'created_at', 'status'),
        Index('idx_summary_type_status', 'summary_type', 'status'),
    )

class SummarizationResult(db.Model):
    """Generated output of a SummarizationJob (reachable from the job as ``results``).

    ``section_summaries`` and ``processing_info`` are JSON columns. The
    ``*_json`` accessors also accept values that were stored as an
    already-encoded JSON string, and fall back to an empty list/dict when the
    value is empty or None.
    """
    __tablename__ = 'summarization_results'

    id = db.Column(db.Integer, primary_key=True)
    job_id = db.Column(db.Integer, db.ForeignKey('summarization_jobs.id'), nullable=False, index=True)
    summary = db.Column(db.Text, nullable=False)
    section_summaries = db.Column(db.JSON)  # List of section summaries
    # Processing metadata (tokens, chunks, etc.). Renamed from 'metadata', which
    # is a reserved attribute name on SQLAlchemy declarative models.
    processing_info = db.Column(db.JSON)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)

    @staticmethod
    def _decode_json(value, default):
        """Return *value* decoded if it is a JSON string, else as-is; *default* if empty.

        Raises:
            json.JSONDecodeError: if *value* is a string that is not valid JSON.
        """
        if isinstance(value, str):
            # Double-encoded value: the JSON column round-tripped a string.
            value = json.loads(value)
        return value or default

    @property
    def section_summaries_json(self):
        """Section summaries as a Python object ([] when unset)."""
        return self._decode_json(self.section_summaries, [])

    @property
    def metadata_json(self):
        """Processing metadata as a Python object ({} when unset).

        Keeps the pre-rename property name for template compatibility.
        """
        return self._decode_json(self.processing_info, {})

    def get_word_count(self):
        """Return the approximate (whitespace-delimited) word count of the summary."""
        if not self.summary:
            return 0
        return len(self.summary.split())

    def get_section_count(self):
        """Return the number of section summaries stored for this result."""
        sections = self.section_summaries_json
        return len(sections) if sections else 0

# Law Library Models
class LawLibraryDocument(db.Model):
    """A company-owned reference document (brief, motion, ruling, ...) in the law library.

    Holds file location info, case/court metadata used for filtering,
    free-form keywords/tags, optional extracted full text for searching, and
    simple access tracking.
    """
    __tablename__ = 'law_library_documents'
    
    id = db.Column(db.Integer, primary_key=True)
    company_id = db.Column(db.Integer, db.ForeignKey('companies.id'), nullable=False, index=True)
    
    # Basic document info
    title = db.Column(db.String(300), nullable=False, index=True)
    document_type = db.Column(db.String(100), index=True)  # brief, motion, contract, memo, ruling, etc.
    description = db.Column(db.Text)
    
    # File information
    filename = db.Column(db.String(300), nullable=False)
    original_filename = db.Column(db.String(300))
    file_path = db.Column(db.String(500))
    file_size = db.Column(db.Integer)  # in bytes
    
    # Metadata for searching and filtering
    case_name = db.Column(db.String(300), index=True)
    case_number = db.Column(db.String(100), index=True)
    court = db.Column(db.String(200), index=True)
    judge = db.Column(db.String(200), index=True)
    date_filed = db.Column(db.Date, index=True)
    practice_area = db.Column(db.String(100), index=True)  # criminal, civil, corporate, family, etc.
    outcome = db.Column(db.String(50))  # won, lost, settled, pending
    opposing_party = db.Column(db.String(300))
    opposing_counsel = db.Column(db.String(300))
    
    # Keywords and tags for searching
    keywords = db.Column(db.Text)  # Comma-separated keywords
    tags = db.Column(db.JSON)  # Array of tags
    
    # Full text content for searching (extracted from document).
    # NOTE: the model declares db.Text, but the deployed column is MEDIUMTEXT
    # (up to 16MB, for large documents such as Supreme Court decisions);
    # db.create_all() would not reproduce that.
    content_text = db.Column(db.Text)
    
    # Tracking
    uploaded_by_id = db.Column(db.Integer, db.ForeignKey('users.id'), index=True)
    uploaded_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    last_accessed_at = db.Column(db.DateTime)
    access_count = db.Column(db.Integer, default=0)
    
    # Relationships
    company = db.relationship('Company', backref='law_library_documents')
    uploaded_by = db.relationship('User', backref='uploaded_library_documents')
    
    # Composite indexes for common queries (all scoped by company first)
    __table_args__ = (
        Index('idx_lawlib_company_type', 'company_id', 'document_type'),
        Index('idx_lawlib_company_practice', 'company_id', 'practice_area'),
        Index('idx_lawlib_company_court', 'company_id', 'court'),
        Index('idx_lawlib_company_date', 'company_id', 'date_filed'),
    )
    
    def increment_access(self):
        """Record one access: bump ``access_count``, stamp ``last_accessed_at``, commit.

        NOTE: this commits the whole current ``db.session``, so any other
        pending changes in that session are committed as well.
        """
        # default=0 is applied only at INSERT, so an unflushed instance (or a
        # row with a NULL count) can still hold None here.
        self.access_count = (self.access_count or 0) + 1
        self.last_accessed_at = datetime.utcnow()
        db.session.commit()

    # Summary helpers (section_summaries_json, metadata_json, get_word_count,
    # get_section_count) do not belong on this model: they read summary /
    # section_summaries / processing_info, which it does not define, so every
    # call raised AttributeError. They live on SummarizationResult.

class FailedLoginAttempt(db.Model):
    """One row per failed login attempt, used for IP-based blocking.

    The (ip_address, attempt_time) composite index supports counting recent
    failures from a single address.
    """
    __tablename__ = 'failed_login_attempts'

    id = db.Column(db.Integer, primary_key=True)
    ip_address = db.Column(db.String(45), nullable=False, index=True)  # 45 chars fits the longest textual IPv6 form
    email = db.Column(db.String(150), index=True)  # Optional; presumably the email entered in the attempt
    attempt_time = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    user_agent = db.Column(db.String(500))

    # Composite index for IP and time queries
    __table_args__ = (
        Index('idx_failed_login_ip_time', 'ip_address', 'attempt_time'),
    )

class ProcessService(db.Model):
    """Record of one service of process on a case; a case may have many.

    Reachable from a Case as ``case.process_services``. That backref is a
    dynamic (query-returning) relationship, and its delete-orphan cascade
    removes these rows when the owning case is deleted.
    """
    __tablename__ = 'process_services'

    id = db.Column(db.Integer, primary_key=True)
    case_id = db.Column(db.Integer, db.ForeignKey('cases.id'), nullable=False, index=True)

    # Service details
    process_server_name = db.Column(db.String(200), nullable=False)
    service_date = db.Column(db.DateTime, nullable=False, index=True)
    person_served = db.Column(db.String(200))
    documents_served = db.Column(db.Text)
    service_notes = db.Column(db.Text)
    proof_of_service_path = db.Column(db.String(500))  # Path to the stored proof-of-service file

    # Tracking
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    created_by_id = db.Column(db.Integer, db.ForeignKey('users.id'), index=True)

    # Relationships
    case = db.relationship('Case', backref=db.backref('process_services', lazy='dynamic', cascade='all, delete-orphan'))
    created_by = db.relationship('User', foreign_keys=[created_by_id])  # No backref on User

    # Composite index: a case's services ordered by service date
    __table_args__ = (
        Index('idx_process_service_case_date', 'case_id', 'service_date'),
    )

# AI Assistant Models
class AssistantConversation(db.Model):
    """An AI assistant chat session belonging to one user within one company.

    ``conversation_uuid`` is the externally visible id. Messages are exposed
    as a dynamic (query-returning) relationship. They are deleted together
    with the conversation through the delete-orphan cascade.
    """
    __tablename__ = 'assistant_conversations'

    id = db.Column(db.Integer, primary_key=True)
    conversation_uuid = db.Column(db.String(36), unique=True, nullable=False, index=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False, index=True)
    company_id = db.Column(db.Integer, db.ForeignKey('companies.id'), nullable=False, index=True)
    title = db.Column(db.String(200))  # Auto-generated from first message
    is_active = db.Column(db.Boolean, default=True, index=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)
    # onupdate= refreshes this when the conversation row itself is UPDATEd via the ORM
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships
    user = db.relationship('User', backref='assistant_conversations')
    company = db.relationship('Company', backref='assistant_conversations')
    messages = db.relationship('AssistantMessage', backref='conversation', lazy='dynamic', cascade='all, delete-orphan')

    # Composite indexes: a user's / a company's active conversations
    __table_args__ = (
        Index('idx_assistant_conv_user_active', 'user_id', 'is_active'),
        Index('idx_assistant_conv_company_active', 'company_id', 'is_active'),
    )

class AssistantMessage(db.Model):
    """A single user or assistant message within an AssistantConversation.

    Assistant messages may also record what action was taken and on which
    entity (``action_type`` / ``entity_type`` / ``entity_id``).
    """
    __tablename__ = 'assistant_messages'

    id = db.Column(db.Integer, primary_key=True)
    conversation_id = db.Column(db.Integer, db.ForeignKey('assistant_conversations.id'), nullable=False, index=True)
    message_type = db.Column(db.String(20), nullable=False, index=True)  # 'user' or 'assistant'
    message_text = db.Column(db.Text, nullable=False)

    # For tracking what the assistant did
    action_type = db.Column(db.String(50), index=True)  # summarize_case, analyze_document, search, etc.
    entity_type = db.Column(db.String(50))  # case, document, analysis_job, etc.
    # String rather than Integer, presumably so it can also hold UUIDs (e.g. job ids) -- confirm
    entity_id = db.Column(db.String(100))  # ID of the entity acted upon
    # Named msg_metadata because 'metadata' is reserved on declarative models.
    msg_metadata = db.Column(db.JSON)  # Additional data (e.g., which documents were included)

    created_at = db.Column(db.DateTime, default=datetime.utcnow, index=True)

    # Composite indexes: a conversation's messages in order; messages by type over time
    __table_args__ = (
        Index('idx_assistant_msg_conv_created', 'conversation_id', 'created_at'),
        Index('idx_assistant_msg_type_created', 'message_type', 'created_at'),
    )
