from __future__ import annotations

import os
import logging
import traceback
from typing import List, Tuple

from django.core.cache import cache
from django.core.management.base import BaseCommand, CommandParser
from django.db import transaction
from django.utils import timezone as dj_tz
from django.conf import settings
from django.contrib.auth import get_user_model

from bazi.models import Person as BaziPerson
from iching.utils.bazi_relations import (
    find_sanhe_groups_for_owner,
    find_sanxing_groups_for_owner,
    evaluate_pair,
)
from bazi.models_group import GroupRelation


def _cache_keys(user_id: int):
    prefix = f"bazi:grp:{user_id}:"
    return {
        "state": prefix + "state",
        "started_at": prefix + "started_at",
        "updated_at": prefix + "updated_at",
        "error_at": prefix + "error_at",
        "last_retry_at": prefix + "last_retry_at",
        "results": prefix + "results",
        "count": prefix + "count",
    }


class Command(BaseCommand):
    """Recalculate pairwise and three-party (三合/三刑) BaZi relations.

    Runs for a single user (``--user``) or every user with a BaZi record
    (``--all-users``), writing pairwise results onto each ``BaziPerson`` row
    and replacing the user's ``GroupRelation`` rows.  ``--force`` bypasses the
    "only run while state is 'processing'" guard.
    """

    help = "Recalculate pairwise and 3-party BaZi relations for a user or all users, store results in DB."

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._setup_logging()

    def _setup_logging(self):
        """Set up a per-process logger writing to logs/recalc_bazi_relations.log.

        Console output is suppressed in test environments so test runs stay
        quiet; the file handler is always attached.
        """
        # Create logs directory if it doesn't exist
        log_dir = "logs"
        os.makedirs(log_dir, exist_ok=True)

        # Per-PID logger name avoids sharing handlers between worker processes.
        self.logger = logging.getLogger(f"recalc_bazi_relations_{os.getpid()}")
        self.logger.setLevel(logging.DEBUG)

        # Prevent duplicate handlers when the command is instantiated repeatedly
        if not self.logger.handlers:
            # One shared formatter for both handlers (previously configured twice).
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )

            # Console handler - only add if not in test environment
            # This prevents verbose output during testing
            if not self._is_test_environment():
                console_handler = logging.StreamHandler()
                console_handler.setLevel(logging.INFO)
                console_handler.setFormatter(formatter)
                self.logger.addHandler(console_handler)

            # File handler for detailed logging — always added
            log_file = os.path.join(log_dir, "recalc_bazi_relations.log")
            file_handler = logging.FileHandler(log_file, encoding='utf-8')
            file_handler.setLevel(logging.DEBUG)
            file_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)

        # Only log startup message if not in test environment
        if not self._is_test_environment():
            self.logger.info(f"=== Starting BaZi Relations Recalculation (PID: {os.getpid()}) ===")

    def _is_test_environment(self):
        """Heuristically detect a test run by scanning sys.argv for 'test'."""
        import sys
        return any('test' in arg.lower() for arg in sys.argv)

    def add_arguments(self, parser: CommandParser) -> None:
        """Register --user / --all-users (mutually exclusive, required) and --force."""
        group = parser.add_mutually_exclusive_group(required=True)
        group.add_argument("--user", type=int, help="Owner user id")
        group.add_argument("--all-users", action="store_true", help="Recalculate for all users")
        parser.add_argument("--force", action="store_true", help="Force recalculation regardless of state")

    def handle(self, *args, **options):
        """Entry point: dispatch to all-users or single-user processing.

        Any unhandled exception is logged with a traceback and re-raised so
        the command exits non-zero.
        """
        all_users = options.get("all_users", False)
        force = options.get('force', False)

        # Only log if not in test environment
        if not self._is_test_environment():
            self.logger.info(f"Command started - all_users: {all_users}, force: {force}")

        try:
            if all_users:
                self._handle_all_users(force)
            else:
                user_id = options["user"]
                if not self._is_test_environment():
                    self.logger.info(f"Processing single user: {user_id}")
                self._handle_single_user(user_id, force)
        except Exception as e:
            self.logger.error(f"Fatal error in command: {str(e)}")
            self.logger.error(f"Traceback: {traceback.format_exc()}")
            raise
        finally:
            if not self._is_test_environment():
                self.logger.info("=== Command completed ===")

    def _handle_all_users(self, force: bool):
        """Recalculate relations for every user that owns a BaZi record.

        A failure for one user marks that user's state 'error' and processing
        continues with the next user rather than aborting the batch.
        """
        User = get_user_model()

        self.logger.info(f"Starting bulk recalculation for all users (force: {force})")

        # Get all users who have BaZi records
        users_with_bazi = User.objects.filter(
            person__owner=True,
            person__bazi_result__isnull=False
        ).distinct()

        total_users = users_with_bazi.count()
        self.logger.info(f"Found {total_users} users with BaZi records")
        self.stdout.write(f"Found {total_users} users with BaZi records")

        for i, user in enumerate(users_with_bazi, 1):
            self.stdout.write(f"Processing user {i}/{total_users}: {user.id} ({user.email or user.username})")
            self.logger.info(f"Processing user {i}/{total_users}: {user.id} ({user.email or user.username})")
            start_time = dj_tz.now()
            try:
                self._handle_single_user(user.id, force)
                end_time = dj_tz.now()
                duration = end_time - start_time
                self.stdout.write(f"✓ Completed user {user.id} in {duration.total_seconds():.2f}s")
                self.logger.info(f"✓ Completed user {user.id} in {duration.total_seconds():.2f}s")
            except Exception as e:
                end_time = dj_tz.now()
                duration = end_time - start_time
                error_msg = f"✗ Failed user {user.id} after {duration.total_seconds():.2f}s: {str(e)}"
                self.stdout.write(error_msg)
                self.logger.error(error_msg)
                self.logger.error(f"Error details for user {user.id}: {traceback.format_exc()}")

                # Mark user as error state so the failure is visible afterwards
                try:
                    user.group_relations_state = 'error'
                    user.group_relations_error_at = dj_tz.now()
                    user.save(update_fields=['group_relations_state', 'group_relations_error_at'])
                    success_msg = f"  → Marked user {user.id} as error state"
                    self.stdout.write(success_msg)
                    self.logger.info(success_msg)
                except Exception as save_error:
                    error_msg = f"  → Failed to mark error state: {str(save_error)}"
                    self.stdout.write(error_msg)
                    self.logger.error(error_msg)
                    self.logger.error(f"Save error details: {traceback.format_exc()}")
                # Continue with next user instead of stopping
                continue

        self.stdout.write(f"Completed processing {total_users} users")
        self.logger.info(f"Completed processing {total_users} users")
        self.stdout.write("Use --force flag if you need to recalculate for users who are not currently processing")

    def _handle_single_user(self, user_id: int, force: bool):
        """Recalculate relations for one user under a 5-minute timeout guard.

        Main thread: SIGALRM-based timeout (interrupts the work).
        Background thread: a daemon watchdog thread sets an Event after
        5 minutes; the event is only checked at coarse points, so long work
        is not interrupted mid-flight.
        """
        self.logger.info(f"Starting single user processing: user_id={user_id}, force={force}")

        # Check if we're running in a background thread
        import threading
        is_main_thread = threading.current_thread() is threading.main_thread()

        if is_main_thread:
            # Use signal-based timeout for main thread.
            # NOTE(review): signal.SIGALRM is Unix-only — confirm this code
            # path never runs in a main thread on Windows.
            import signal

            def timeout_handler(signum, frame):
                timeout_msg = f"Processing timeout for user {user_id}"
                self.logger.error(timeout_msg)
                raise TimeoutError(timeout_msg)

            # Set 5-minute timeout
            if not self._is_test_environment():
                self.logger.info(f"Setting 5-minute timeout for user {user_id} (signal-based)")
            signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(300)  # 5 minutes

            try:
                self._process_single_user(user_id, force)
            except TimeoutError:
                self.logger.error(f"Timeout occurred for user {user_id}")
                raise
            except Exception as e:
                self.logger.error(f"Error processing user {user_id}: {str(e)}")
                self.logger.error(f"Traceback: {traceback.format_exc()}")
                raise
            finally:
                # Clear timeout
                signal.alarm(0)
                if not self._is_test_environment():
                    self.logger.info(f"Cleared timeout for user {user_id}")
        else:
            # Use thread-based timeout for background threads
            if not self._is_test_environment():
                self.logger.info(f"Running in background thread, using thread-based timeout for user {user_id}")

            import time

            # Event the watchdog sets after 5 minutes
            timeout_event = threading.Event()

            def timeout_worker():
                time.sleep(300)  # 5 minutes
                timeout_event.set()

            # Daemon thread: never blocks interpreter exit even if still sleeping
            timeout_thread = threading.Thread(target=timeout_worker, daemon=True)
            timeout_thread.start()

            try:
                # Process with timeout checking
                self._process_single_user_with_timeout(user_id, force, timeout_event)
            except TimeoutError:
                self.logger.error(f"Timeout occurred for user {user_id}")
                raise
            except Exception as e:
                self.logger.error(f"Error processing user {user_id}: {str(e)}")
                self.logger.error(f"Traceback: {traceback.format_exc()}")
                raise
            finally:
                # Setting the event marks the run finished; the watchdog thread
                # itself keeps sleeping but is a daemon, so that is harmless.
                timeout_event.set()
                if not self._is_test_environment():
                    self.logger.info(f"Cleared timeout for user {user_id}")

    def _process_single_user_with_timeout(self, user_id: int, force: bool, timeout_event: threading.Event):
        """Process a single user, raising TimeoutError if *timeout_event* fires.

        The event is checked only before and after the synchronous processing
        call, so an in-flight calculation is not interrupted — it merely fails
        once it returns.
        """
        # Check timeout before starting
        if timeout_event.is_set():
            raise TimeoutError(f"Processing timeout for user {user_id}")

        # Call the original processing method
        self._process_single_user(user_id, force)

        # Final timeout check
        if timeout_event.is_set():
            raise TimeoutError(f"Processing timeout for user {user_id}")

    def _process_single_user(self, user_id: int, force: bool):
        """Core recalculation for one user.

        Steps:
        1. State guard: run only when state is 'processing', unless *force*
           (forced runs are stamped 'processing' with a started_at).
        2. Load the owner's most recent BaZi record; if missing or invalid,
           clear stored relations and mark completed.
        3. Recompute pairwise relations onto every other BaziPerson row.
        4. Recompute sanhe/sanxing groups and replace GroupRelation rows.
        5. Conditionally mark state 'completed' and refresh the legacy cache
           marker.
        """
        if not self._is_test_environment():
            self.logger.info(f"Processing user {user_id} (force: {force})")
        keys = _cache_keys(user_id)

        # Guard: only run if currently processing (unless forced)
        User = get_user_model()
        try:
            usr = User.objects.get(id=user_id)
            state = usr.group_relations_state or 'idle'
            if not self._is_test_environment():
                self.logger.info(f"User {user_id} current state: {state}")

            # Bug fix: the force branch was previously nested under a
            # `state != 'processing' and not force` condition, making it
            # unreachable — forced runs never recorded the 'processing'
            # state or a started_at timestamp.
            if state != 'processing':
                if force:
                    # Force mode: set state to processing and continue
                    usr.group_relations_state = 'processing'
                    usr.group_relations_started_at = dj_tz.now()
                    usr.save(update_fields=['group_relations_state', 'group_relations_started_at'])
                    force_msg = f"Forcing recalculation for user {user_id} (state was: {state})"
                    if not self._is_test_environment():
                        self.stdout.write(force_msg)
                        self.logger.info(force_msg)
                else:
                    if not self._is_test_environment():
                        self.logger.info(f"Skipping user {user_id} - not processing and not forced")
                    return

        except User.DoesNotExist:
            if not self._is_test_environment():
                self.logger.warning(f"User {user_id} not found in database")
            return

        # Load the most recent owner BaZi record
        self.logger.info(f"Loading owner BaZi record for user {user_id}")
        owner = (
            BaziPerson.objects.filter(created_by_id=user_id, owner=True)
            .order_by("-created_at")
            .first()
        )

        if owner:
            self.logger.info(f"Found owner BaZi record: {owner.id}")
        else:
            self.logger.warning(f"No owner BaZi record found for user {user_id}")

        if owner is None or not owner.bazi_result or "day" not in owner.bazi_result:
            # No data → empty results
            self.logger.info(f"User {user_id} has no valid BaZi data, clearing relations")
            with transaction.atomic():
                # Clear all group rows
                deleted_count = GroupRelation.objects.filter(owner_user_id=user_id).delete()[0]
                self.logger.info(f"Deleted {deleted_count} group relations for user {user_id}")

                # Mark DB state completed (regardless of current state when forced)
                try:
                    usr = User.objects.select_for_update().get(id=user_id)
                    if force or usr.group_relations_state == 'processing':
                        usr.group_relations_state = 'completed'
                        usr.group_relations_updated_at = dj_tz.now()
                        usr.save(update_fields=['group_relations_state', 'group_relations_updated_at'])
                        self.logger.info(f"Marked user {user_id} as completed (no data)")
                except User.DoesNotExist:
                    self.logger.warning(f"User {user_id} not found when marking completed")
            return

        owner_day = owner.bazi_result.get("day", {})
        owner_g = owner_day.get("god")
        owner_e = owner_day.get("earth")

        if owner_g is None or owner_e is None:
            # Owner record is incomplete: clear all group rows and pairwise data
            with transaction.atomic():
                GroupRelation.objects.filter(owner_user_id=user_id).delete()
                BaziPerson.objects.filter(created_by_id=user_id).exclude(id=owner.id).update(
                    relation_good=None, relation_bad=None,
                    relation_good_count=0, relation_bad_count=0,
                    relation_updated_at=dj_tz.now()
                )

                # Mark DB state completed (regardless of current state when forced)
                try:
                    usr = User.objects.select_for_update().get(id=user_id)
                    if force or usr.group_relations_state == 'processing':
                        usr.group_relations_state = 'completed'
                        usr.group_relations_updated_at = dj_tz.now()
                        usr.save(update_fields=['group_relations_state', 'group_relations_updated_at'])
                except User.DoesNotExist:
                    pass
            return

        # Calculate pairwise relations for all persons
        self.logger.info(f"Starting pairwise relation calculations for user {user_id}")
        persons_qs = BaziPerson.objects.filter(created_by_id=user_id).exclude(id=owner.id)
        persons: List[Tuple[int, int, int]] = []  # (id, god, earth)

        total_persons = persons_qs.count()
        self.logger.info(f"Found {total_persons} persons to process for user {user_id}")
        if total_persons > 0:
            self.stdout.write(f"  Processing {total_persons} persons for user {user_id}...")

        processed_count = 0
        skipped_count = 0
        for i, p in enumerate(persons_qs, 1):
            try:
                if p.bazi_result and p.bazi_result.get("day"):
                    day = p.bazi_result["day"]
                    g = day.get("god")
                    e = day.get("earth")
                    if g is not None and e is not None:
                        persons.append((p.id, g, e))

                        # Calculate pairwise relation for this person
                        result = evaluate_pair(owner_g, owner_e, g, e)
                        good_reasons = result.get("good", [])
                        bad_reasons = result.get("bad", [])

                        # Update person's relation fields.
                        # Refresh from DB first to avoid clobbering concurrent API updates.
                        p.refresh_from_db()
                        p.relation_good = good_reasons if good_reasons else None
                        p.relation_bad = bad_reasons if bad_reasons else None
                        p.relation_good_count = len(good_reasons)
                        p.relation_bad_count = len(bad_reasons)
                        p.relation_updated_at = dj_tz.now()
                        p.save(update_fields=[
                            'relation_good', 'relation_bad',
                            'relation_good_count', 'relation_bad_count',
                            'relation_updated_at'
                        ])

                        processed_count += 1
                        if total_persons > 10 and i % 10 == 0:  # Progress indicator for large datasets
                            progress_msg = f"    Processed {i}/{total_persons} persons..."
                            self.stdout.write(progress_msg)
                            self.logger.info(progress_msg)
                    else:
                        self.logger.warning(f"Person {p.id} missing god/earth data: g={g}, e={e}")
                        skipped_count += 1
                else:
                    self.logger.warning(f"Person {p.id} missing bazi_result or day data")
                    skipped_count += 1
            except Exception as exc:  # renamed from `e` to avoid shadowing the earth-branch local
                error_msg = f"    ⚠️  Skipped person {p.id}: {str(exc)}"
                self.stdout.write(error_msg)
                self.logger.error(f"Error processing person {p.id}: {str(exc)}")
                self.logger.error(f"Person {p.id} error traceback: {traceback.format_exc()}")
                skipped_count += 1
                continue

        self.logger.info(f"Person processing completed: {processed_count} processed, {skipped_count} skipped")

        # Convert to (person_id, earth-branch) pairs for the group finders
        group_persons: List[Tuple[int, int]] = [(pid, e) for pid, g, e in persons]

        # Compute 3-party groups involving the owner
        results: List[dict] = []

        # Find 三合 (sanhe) groups
        for p1_id, p2_id in find_sanhe_groups_for_owner(owner_e, group_persons):
            # Look up each member's earth branch for the "by" reference payload
            e1 = next(e for pid, e in group_persons if pid == p1_id)
            e2 = next(e for pid, e in group_persons if pid == p2_id)
            results.append({
                "t": 0,
                "p1": p1_id,
                "p2": p2_id,
                "by": [["o", "d", "e", owner_e], ["p1", "d", "e", e1], ["p2", "d", "e", e2]],
            })

        # Find 三刑 (sanxing) groups
        for p1_id, p2_id in find_sanxing_groups_for_owner(owner_e, group_persons):
            e1 = next(e for pid, e in group_persons if pid == p1_id)
            e2 = next(e for pid, e in group_persons if pid == p2_id)
            results.append({
                "t": 1,
                "p1": p1_id,
                "p2": p2_id,
                "by": [["o", "d", "e", owner_e], ["p1", "d", "e", e1], ["p2", "d", "e", e2]],
            })

        # Persist to DB (replace existing rows)
        self.logger.info(f"Persisting group relations to database for user {user_id}")
        with transaction.atomic():
            deleted_count = GroupRelation.objects.filter(owner_user_id=user_id).delete()[0]
            self.logger.info(f"Deleted {deleted_count} existing group relations for user {user_id}")

            bulk = []
            for item in results:
                rel_type = 'sanhe' if item['t'] == 0 else 'sanxing'
                bulk.append(GroupRelation(
                    owner_user_id=user_id,
                    person1_id=item['p1'],
                    person2_id=item['p2'],
                    relation_type=rel_type,
                    by=item.get('by'),
                ))
            if bulk:
                GroupRelation.objects.bulk_create(bulk)
                self.logger.info(f"Created {len(bulk)} new group relations for user {user_id}")
            else:
                self.logger.info(f"No group relations found for user {user_id}")

            # Mark DB state completed (regardless of current state when forced)
            try:
                usr = User.objects.select_for_update().get(id=user_id)
                if force or usr.group_relations_state == 'processing':
                    usr.group_relations_state = 'completed'
                    usr.group_relations_updated_at = dj_tz.now()
                    usr.save(update_fields=['group_relations_state', 'group_relations_updated_at'])
                    self.logger.info(f"Successfully marked user {user_id} as completed")
            except User.DoesNotExist:
                self.logger.warning(f"User {user_id} not found when marking completed")

        # Conditional completion to avoid stale overwrite.
        # Cache marker (optional backward-compat); not used by API anymore.
        if (cache.get(keys["state"]) or "idle") == "processing":
            cache.set(keys["state"], "completed")
            self.logger.info(f"Updated cache state to completed for user {user_id}")

        self.logger.info(f"User {user_id} processing completed successfully")

