from django.core.management.base import BaseCommand
from django.core.cache import cache
from django.utils import timezone
from django.conf import settings

from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import time
import random
import string
import json
import os
from urllib.parse import urlencode
from urllib.request import Request, urlopen
from urllib.error import URLError, HTTPError

from rest_framework_simplejwt.tokens import RefreshToken

from bazi.models import Person
from django.contrib.auth import get_user_model


class Command(BaseCommand):
    """Load-test ``/api/bazi/good-days/`` on an already running server.

    Up to two timed phases are run, each with ``--concurrency`` worker threads:

    * ``no-cache`` - requests cycle through different month ranges (explicit
      ``start``/``end`` query params) so that a per-range cache is less likely
      to answer them.
    * ``with-cache`` - every request uses the same default URL (only
      ``person_id``), after one warm-up request.

    Each phase prints a JSON summary to stdout and appends it to a log file.
    """

    help = "Stress test /api/bazi/good-days/ with and without caching. Requires server running."

    def add_arguments(self, parser):
        """Register CLI options for host, phase duration, concurrency, target person and timeout."""
        parser.add_argument('--host', default='http://127.0.0.1:8000', help='Base host, e.g. http://127.0.0.1:8000')
        parser.add_argument('--duration', type=int, default=60, help='Duration in seconds per phase')
        parser.add_argument('--concurrency', type=int, default=50, help='Concurrent workers')
        parser.add_argument('--person-id', type=int, default=None, help='Existing person_id to use (must belong to created test user if you set --reuse-user)')
        parser.add_argument('--only', choices=['nocache', 'cache', 'both'], default='both', help='Which phases to run')
        # NOTE: kept for CLI compatibility. handle() always reuses an existing
        # user with the fixed test phone (get_or_create), so this flag is not read.
        parser.add_argument('--reuse-user', action='store_true', help='Reuse or create a persistent test user if exists')
        parser.add_argument('--timeout', type=float, default=20.0, help='Per-request timeout in seconds')

    def handle(self, *args, **options):
        """Prepare the test user/person and a JWT, then run the selected phases.

        Raises:
            SystemExit: if ``--person-id`` does not exist or ``--concurrency`` < 1.
        """
        host = options['host'].rstrip('/')
        duration = int(options['duration'])
        concurrency = int(options['concurrency'])
        only = options['only']
        timeout = float(options.get('timeout') or 20.0)

        if concurrency < 1:
            # ThreadPoolExecutor(max_workers=0) would raise a bare ValueError.
            raise SystemExit("--concurrency must be at least 1")

        user = self._get_or_create_test_user()
        person = self._resolve_person(user, options['person_id'])

        # Build JWT token for Authorization header.
        access = str(RefreshToken.for_user(user).access_token)
        headers = {'Authorization': f'Bearer {access}', 'Accept': 'application/json'}
        base_url = f"{host}/api/bazi/good-days/"
        person_id = person.id

        if only in ('nocache', 'both'):
            # Each worker walks month offsets starting at its own index.
            # NOTE(review): only 120 distinct month windows exist. Once all of them
            # have been requested, a server cache keyed on the range (if there is
            # one) could start serving hits. Widen the cycle if that matters.
            self._run_phase(
                'no-cache',
                lambda idx, sent: self._good_days_url(base_url, person_id, (idx + sent) % 120),
                duration=duration,
                concurrency=concurrency,
                timeout=timeout,
                headers=headers,
            )

        if only in ('cache', 'both'):
            cached_url = self._good_days_url(base_url, person_id)
            self._run_phase(
                'with-cache',
                lambda idx, sent: cached_url,
                duration=duration,
                concurrency=concurrency,
                timeout=timeout,
                headers=headers,
                warmup_url=cached_url,
            )

        self.stdout.write(self.style.SUCCESS("Stress test completed."))

    def _get_or_create_test_user(self):
        """Return the fixed stress-test user, setting its password only when newly created."""
        User = get_user_model()
        user, created = User.objects.get_or_create(
            phone='19999999999',
            defaults={'email': 'stress+gooddays@example.com', 'first_name': 'Stress', 'last_name': 'Tester'},
        )
        if created:
            user.set_password('GoodDays!123')
            user.save()
        return user

    def _resolve_person(self, user, person_id):
        """Return the Person to query: the explicit id, else the user's newest owner=True person.

        If the user has no such person, a minimal one is created with a fixed
        ``bazi_result``, and the model's BaZi calculation is skipped via
        ``_skip_bazi_calculation``.
        """
        if person_id is not None:
            try:
                person = Person.objects.get(id=person_id)
            except Person.DoesNotExist:
                raise SystemExit(f"Person id {person_id} not found") from None
            owner_id = getattr(person, 'created_by_id', None)
            if owner_id is not None and owner_id != user.pk:
                # The token is issued for the test user, so requests for someone
                # else's person may be rejected if the API enforces ownership.
                self.stderr.write(self.style.WARNING(
                    f"Person {person_id} was not created by the stress-test user; requests may fail."
                ))
            return person

        person = Person.objects.filter(created_by=user, owner=True).order_by('-created_at').first()
        if person is None:
            # Fixed bazi_result so the API has data without running the calculation.
            person = Person(
                name='Stress Tester',
                birth_date=timezone.now().date(),
                created_by=user,
                owner=True,
                bazi_result={'year': {'god': 6, 'earth': 8}, 'month': {'god': 4, 'earth': 5}, 'day': {'god': 2, 'earth': 6}, 'hour': {'god': 0, 'earth': 1}},
            )
            person._skip_bazi_calculation = True
            person.save()
        return person

    @staticmethod
    def _good_days_url(base_url: str, person_id, month_offset: int | None = None) -> str:
        """Build the endpoint URL.

        With ``month_offset=None`` only ``person_id`` is sent, so the server picks
        its default range. Otherwise ``start``/``end`` cover days 1-28 of the
        month ``month_offset`` months after the current one. Day 28 exists in
        every month.
        """
        if month_offset is None:
            return f"{base_url}?{urlencode({'person_id': person_id})}"
        now = timezone.now()
        months_from_jan = now.month - 1 + month_offset
        year = now.year + months_from_jan // 12
        month = months_from_jan % 12 + 1
        query = {
            'person_id': person_id,
            'start': f"{year:04d}-{month:02d}-01",
            'end': f"{year:04d}-{month:02d}-28",
        }
        return f"{base_url}?{urlencode(query)}"

    @staticmethod
    def _timed_get(url: str, headers: dict, timeout: float):
        """Perform one GET, reading the full body and closing the connection.

        Returns:
            ``(status, elapsed_ms)``. For transport-level failures (timeouts,
            refused connections, protocol errors) it returns ``(None, None)``.
        """
        started = time.perf_counter()
        try:
            with urlopen(Request(url, headers=headers), timeout=timeout) as resp:
                resp.read()
                status = getattr(resp, 'status', 200)
        except HTTPError as err:
            # urlopen raises HTTPError for 4xx/5xx; keep the real status code
            # and release the underlying connection.
            status = err.code
            err.close()
        except Exception:
            # Deliberately broad: in a load test any other failure counts as one
            # errored request, not a crash.
            return None, None
        return status, (time.perf_counter() - started) * 1000.0

    @staticmethod
    def _percentile(sorted_values: list, fraction: float):
        """Nearest-rank percentile of an ascending list, or 0 for an empty list."""
        if not sorted_values:
            return 0
        index = min(int(fraction * len(sorted_values)), len(sorted_values) - 1)
        return sorted_values[index]

    def _run_phase(self, label, url_for, *, duration, concurrency, timeout, headers, warmup_url=None):
        """Hammer the endpoint for ``duration`` seconds and report a summary.

        ``url_for(worker_index, requests_sent_by_worker)`` returns the URL for the
        next request. Only HTTP 200 responses contribute to the latency stats.
        Every other status and transport failure (counted as ``'ex'``) counts
        as an error.
        """
        self.stdout.write(self.style.NOTICE(f"\nRunning phase: {label} (duration={duration}s, concurrency={concurrency})"))
        if warmup_url is not None:
            # Prime any server-side cache; the outcome is intentionally ignored.
            self._timed_get(warmup_url, headers, timeout)

        lock = threading.Lock()
        total = 0
        errors = 0
        latencies = []
        status_counts = {}
        phase_start = time.monotonic()
        stop_at = phase_start + duration

        def worker(idx: int) -> int:
            nonlocal total, errors
            sent = 0
            while time.monotonic() < stop_at:
                status, elapsed_ms = self._timed_get(url_for(idx, sent), headers, timeout)
                sent += 1
                key = 'ex' if status is None else status
                with lock:
                    total += 1
                    status_counts[key] = status_counts.get(key, 0) + 1
                    if status == 200:
                        latencies.append(elapsed_ms)
                    else:
                        errors += 1
            return sent

        with ThreadPoolExecutor(max_workers=concurrency) as pool:
            futures = [pool.submit(worker, i) for i in range(concurrency)]
            for future in as_completed(futures):
                future.result()  # surface unexpected worker crashes instead of hiding them

        # Real wall time includes requests still in flight when the deadline passed,
        # so RPS is not inflated by dividing by the nominal duration.
        elapsed_s = time.monotonic() - phase_start
        ordered = sorted(latencies)
        p50 = self._percentile(ordered, 0.50)
        p95 = self._percentile(ordered, 0.95)
        p99 = self._percentile(ordered, 0.99)

        summary = {
            'phase': label,
            'total_requests': total,
            'errors': errors,
            'rps': round(total / elapsed_s, 2) if elapsed_s > 0 else 0.0,
            'avg_ms': round(sum(latencies) / len(latencies), 2) if latencies else 0,
            'p50_ms': round(p50, 2),
            'p95_ms': round(p95, 2),
            'p99_ms': round(p99, 2),
            'elapsed_s': round(elapsed_s, 2),
            'status_counts': status_counts,
        }
        self.stdout.write(self.style.SUCCESS(json.dumps(summary, ensure_ascii=False)))
        self._write_log(summary)

    def _write_log(self, summary: dict):
        """Append ``summary`` as one JSON line to a per-second timestamped log file.

        The file lives in ``BASE_DIR/iching/setup/stress_tests``. Logging is best
        effort: filesystem errors are reported on stderr and do not abort the run.
        """
        base_dir = os.path.join(settings.BASE_DIR, 'iching', 'setup', 'stress_tests')
        ts = timezone.now().strftime('%Y%m%d_%H%M%S')
        filename = os.path.join(base_dir, f'stress_good_days_{ts}.log')
        try:
            os.makedirs(base_dir, exist_ok=True)
            with open(filename, 'a', encoding='utf-8') as f:
                f.write(json.dumps(summary, ensure_ascii=False) + "\n")
        except OSError as exc:
            self.stderr.write(f"Could not write stress log {filename}: {exc}")
