"""
Management command to test BaZi AI analysis.
"""
import logging
from django.core.management.base import BaseCommand, CommandError
from django.utils import timezone
from bazi.models import Person
from ai.utils.bazi import analyze_bazi
from ai.services.groq import GroqService
from ai.utils.bazi import prepare_bazi_prompt
from ai.services.factory import LLMServiceFactory
from django.conf import settings

# Module-level logger; the command uses it to record the traceback when an
# analysis run fails.
logger = logging.getLogger(__name__)

class Command(BaseCommand):
    """Run a BaZi AI analysis from the command line for manual testing.

    Analyzes the person given by ``--id`` or, when omitted, the most recently
    created person whose ``bazi_result`` is set. It can also only print the
    prompt, or list the configured LLM providers/models. Failures are raised
    as ``CommandError`` so the process exits with a non-zero status (and
    ``call_command`` callers see an exception) instead of silently succeeding.
    """

    help = 'Test BaZi AI analysis on a specific person or the most recent record'

    def add_arguments(self, parser):
        """Register the command-line options accepted by this command."""
        parser.add_argument('--id', type=int, help='Person ID to analyze')
        parser.add_argument('--model', type=str, help='Specific model key to use from settings.GROQ_MODELS or settings.OPENAI_MODELS')
        parser.add_argument('--provider', type=str, choices=['groq', 'openai'], help='LLM provider to use (groq or openai)')
        parser.add_argument('--list-models', action='store_true', help='List available models from all providers')
        parser.add_argument('--list-providers', action='store_true', help='List available providers')
        parser.add_argument('--output', type=str, help='File to save the analysis output to')
        parser.add_argument('--print-prompt', action='store_true', help='Print the prompt that will be sent to the LLM')

    def handle(self, *args, **options):
        """Dispatch to a listing mode, or run a single BaZi analysis.

        The ``--list-models`` / ``--list-providers`` modes print their listing
        and return without analyzing anyone.

        Raises:
            CommandError: if no person can be found, BaZi data cannot be
                calculated, or any step of the analysis fails.
        """
        if options['list_models']:
            self._list_models()
            return

        if options['list_providers']:
            self._list_providers()
            return

        try:
            self._run_analysis(options)
        except CommandError:
            # Already a user-facing error message; let Django report it as-is.
            raise
        except Exception as e:
            logger.exception("Error in test_bazi_analysis command")
            # Raising (instead of only printing) gives a non-zero exit status,
            # so scripts and CI can detect the failure.
            raise CommandError(f'{type(e).__name__}: {e}') from e

    def _run_analysis(self, options):
        """Load the person, build the prompt, and run (or just print) the analysis.

        Raises:
            CommandError: if no person is found or BaZi data cannot be calculated.
        """
        person = self._get_person(options['id'])
        if person is None:
            raise CommandError('No person record found to analyze.')

        bazi_result = person.calculate_bazi()
        if not bazi_result:
            raise CommandError(f'Could not calculate BaZi data for person (ID: {person.id}).')

        # Assigned on the in-memory instance (this command never calls save())
        # so prompt preparation sees the freshly calculated data.
        person.bazi_result = bazi_result

        self.stdout.write(self.style.SUCCESS(f'Analyzing BaZi for {person.name} (ID: {person.id})...'))

        prompt = prepare_bazi_prompt(person)

        if options['print_prompt']:
            # Prompt-only mode: show what would be sent and stop before the LLM call.
            self._write_section('BaZi Analysis Prompt:', prompt)
            return

        model_key = options['model']
        provider = options['provider']

        if provider:
            self.stdout.write(f'Using provider: {provider}')
        if model_key:
            self.stdout.write(f'Using model: {model_key}')

        start_time = timezone.now()
        analysis = analyze_bazi(person, model_key, provider)
        duration = (timezone.now() - start_time).total_seconds()
        self.stdout.write(f'Analysis completed in {duration:.2f} seconds.')

        # Extract the text before touching the output file, so a malformed
        # result does not leave behind a truncated/empty file.
        analysis_text = analysis['bazi_analysis']

        if options['output']:
            with open(options['output'], 'w', encoding='utf-8') as f:
                f.write(analysis_text)
            self.stdout.write(f'Analysis saved to {options["output"]}')

        self._write_section('BaZi Analysis Result:', analysis_text)

        # NOTE(review): assumes analyze_bazi() sets these attributes on person — confirm.
        self.stdout.write(f'Model used: {person.analysis_model}')
        self.stdout.write(f'Analysis timestamp: {person.analysis_timestamp}')

    def _write_section(self, title, body):
        """Write *body* to stdout framed by 80-column '=' rules and headed by *title*."""
        rule = '=' * 80
        self.stdout.write('\n' + rule)
        self.stdout.write(title)
        self.stdout.write(rule + '\n')
        self.stdout.write(body)
        self.stdout.write('\n' + rule)

    def _get_person(self, person_id=None):
        """Return the Person to analyze.

        Args:
            person_id: Primary key to look up. When None, fall back to the most
                recently created person whose ``bazi_result`` is not null.

        Returns:
            The matching Person, or None if the fallback query finds nothing.

        Raises:
            CommandError: if ``person_id`` is given but no such Person exists.
        """
        # `is not None` so an explicit ID of 0 is looked up, not ignored.
        if person_id is not None:
            try:
                return Person.objects.get(id=person_id)
            except Person.DoesNotExist:
                raise CommandError(f'Person with ID {person_id} does not exist.') from None
        return Person.objects.filter(bazi_result__isnull=False).order_by('-created_at').first()

    def _write_default_provider(self):
        """Write the configured default LLM provider (settings.DEFAULT_LLM_PROVIDER, 'groq' if unset)."""
        default_provider = getattr(settings, 'DEFAULT_LLM_PROVIDER', 'groq')
        self.stdout.write(f'\nCurrent default provider: {default_provider}')

    def _list_models(self):
        """List the models of every available provider, then the default provider.

        Raises:
            CommandError: if the providers cannot be retrieved or listed.
        """
        try:
            providers = LLMServiceFactory.get_available_providers()

            self.stdout.write(self.style.SUCCESS('Available Models:'))
            for provider, models in providers.items():
                self.stdout.write(f'\n{provider.upper()} Models:')
                for key, model_id in models.items():
                    self.stdout.write(f'  {key}: {model_id}')

            self._write_default_provider()
        except Exception as e:
            raise CommandError(f'Error listing models: {e}') from e

    def _list_providers(self):
        """List the names of the available LLM providers, then the default provider.

        Raises:
            CommandError: if the providers cannot be retrieved or listed.
        """
        try:
            providers = LLMServiceFactory.get_available_providers()

            self.stdout.write(self.style.SUCCESS('Available LLM Providers:'))
            for provider in providers:
                self.stdout.write(f'  {provider}')

            self._write_default_provider()
        except Exception as e:
            raise CommandError(f'Error listing providers: {e}') from e


