import json
import logging

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from django.contrib.auth import get_user_model
from django.db.models import Q

from .models import Conversation, Message, SupportThread, SupportMessage


class ConversationConsumer(AsyncWebsocketConsumer):
    """Booking chat: /ws/chat/<conversation_id>/

    Only the conversation's student or tutor may connect. A text frame of the
    form ``{"type": "message", "text": "..."}`` is stored as a ``Message`` and
    broadcast to every socket in the ``conv_<conversation_id>`` group.
    Malformed frames are ignored rather than crashing the consumer.
    """

    async def connect(self):
        """Authorize the socket and join the conversation group.

        Closes with 4401 for missing/anonymous users and 4403 for users who are
        neither the student nor the tutor of the conversation.
        """
        self.conversation_id = self.scope["url_route"]["kwargs"]["conversation_id"]
        self.group_name = f"conv_{self.conversation_id}"

        user = self.scope.get("user")
        if user is None or not user.is_authenticated:
            await self.close(code=4401)
            return

        allowed = await self._user_allowed(user.id, self.conversation_id)
        if not allowed:
            await self.close(code=4403)
            return

        await self.channel_layer.group_add(self.group_name, self.channel_name)
        await self.accept()

    async def disconnect(self, close_code):
        """Leave the conversation group; best-effort, also runs for rejected sockets."""
        try:
            await self.channel_layer.group_discard(self.group_name, self.channel_name)
        except Exception:
            # Best-effort cleanup: the socket is going away regardless, and for
            # a rejected socket the channel was never added to the group.
            pass

    async def receive(self, text_data=None, bytes_data=None):
        """Persist and broadcast a chat message sent by the connected user.

        Ignored: invalid JSON, JSON that is not an object, a ``type`` other than
        ``"message"`` (missing/empty defaults to ``"message"``), and a ``text``
        that is missing, blank, or not a string. Binary frames are ignored.
        """
        data = self._load_json_object(text_data)
        if data is None:
            return

        if self._event_type(data) != "message":
            return

        text = self._clean_str(data.get("text"))
        if not text:
            return

        user = self.scope["user"]
        msg = await self._save_message(self.conversation_id, user.id, text)

        await self.channel_layer.group_send(
            self.group_name,
            {
                "type": "chat.message",
                "message": {
                    "id": msg["id"],
                    "conversation_id": int(self.conversation_id),
                    "sender_id": user.id,
                    "sender_name": msg["sender_name"],
                    "text": msg["text"],
                    "created_at": msg["created_at"],
                },
            },
        )

    async def chat_message(self, event):
        """Handle ``chat.message`` group events by forwarding the payload as JSON."""
        await self.send(text_data=json.dumps(event["message"]))

    @staticmethod
    def _load_json_object(text_data):
        """Decode a text frame into a dict; return None if it is not a JSON object."""
        try:
            data = json.loads(text_data or "{}")
        except (TypeError, ValueError, RecursionError):
            # ValueError covers JSONDecodeError; RecursionError is raised for
            # pathologically nested input from an untrusted client.
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def _event_type(data):
        """Return the lower-cased event type (default ``"message"``), or None if not a string."""
        raw = data.get("type") or "message"
        return raw.strip().lower() if isinstance(raw, str) else None

    @staticmethod
    def _clean_str(value):
        """Return ``value`` stripped if it is a string, otherwise ``""``."""
        return value.strip() if isinstance(value, str) else ""

    @database_sync_to_async
    def _user_allowed(self, user_id, conversation_id):
        """True if the conversation exists and ``user_id`` is its student or tutor."""
        return Conversation.objects.filter(id=conversation_id).filter(
            Q(student_id=user_id) | Q(tutor_id=user_id)
        ).exists()

    @database_sync_to_async
    def _save_message(self, conversation_id, user_id, text):
        """Create a ``Message`` from ``user_id`` in the conversation and summarize it.

        Returns a JSON-serializable dict with ``id``, ``text``, ``created_at``
        (ISO 8601) and ``sender_name``. Raises ``DoesNotExist`` if the
        conversation or user row is missing.
        """
        User = get_user_model()
        conv = Conversation.objects.get(id=conversation_id)
        user = User.objects.get(id=user_id)
        m = Message.objects.create(conversation=conv, sender=user, text=text)
        return {
            "id": m.id,
            "text": m.text,
            "created_at": m.created_at.isoformat(),
            "sender_name": (user.get_full_name() or user.username or "User"),
        }


class SupportConsumer(AsyncWebsocketConsumer):
    """ET-chat support inbox: /ws/support/<thread_id>/

    Staff, the thread's owning user, or a guest whose session key matches the
    thread's may connect. Two client events are handled:

    * ``{"type": "typing", "is_typing": ...}``: broadcast only, nothing stored.
    * ``{"type": "message", "text": str, "client_id": str}``: stored as a
      ``SupportMessage`` and broadcast. ``client_id`` (optional) is echoed so
      the sender can mark its message as delivered.

    Malformed frames are ignored rather than crashing the consumer.
    """

    async def connect(self):
        """Authorize the socket (see ``_support_allowed``) and join the thread group.

        Closes with 4403 when the thread does not exist or access is denied.
        """
        self.thread_id = self.scope["url_route"]["kwargs"]["thread_id"]
        self.group_name = f"support_{self.thread_id}"

        user = self.scope.get("user")
        session = self.scope.get("session")

        allowed = await self._support_allowed(user, session, self.thread_id)
        if not allowed:
            await self.close(code=4403)
            return

        await self.channel_layer.group_add(self.group_name, self.channel_name)
        await self.accept()

    async def disconnect(self, close_code):
        """Leave the thread group; best-effort, also runs for rejected sockets."""
        try:
            await self.channel_layer.group_discard(self.group_name, self.channel_name)
        except Exception:
            # Best-effort cleanup: the socket is going away regardless.
            pass

    async def receive(self, text_data=None, bytes_data=None):
        """Dispatch a client frame to typing broadcast or message persistence.

        Ignored: invalid JSON, non-object JSON, a non-string ``type``, unknown
        event types, and messages whose ``text`` is missing, blank or not a string.
        """
        data = self._load_json_object(text_data)
        if data is None:
            return

        event_type = self._event_type(data)

        # Typing indicator (no DB write)
        if event_type == "typing":
            await self.channel_layer.group_send(
                self.group_name,
                {
                    "type": "support.typing",
                    "message": {
                        "type": "typing",
                        "thread_id": int(self.thread_id),
                        "sender_label": self._sender_label(self.scope.get("user")),
                        "is_typing": bool(data.get("is_typing")),
                    },
                },
            )
            return

        if event_type != "message":
            return

        text = self._clean_str(data.get("text"))
        if not text:
            return
        client_id = self._clean_str(data.get("client_id"))

        user = self.scope.get("user")
        session = self.scope.get("session")

        msg = await self._save_support_message(self.thread_id, user, session, text)

        # Echo to all clients; include client_id to let sender mark ✓ delivered
        msg_payload = dict(msg)
        if client_id:
            msg_payload["client_id"] = client_id

        await self.channel_layer.group_send(
            self.group_name,
            {"type": "support.message", "message": msg_payload},
        )

    async def support_typing(self, event):
        """Handle ``support.typing`` group events by forwarding the payload as JSON."""
        await self.send(text_data=json.dumps(event["message"]))

    async def support_message(self, event):
        """Handle ``support.message`` group events by forwarding the payload as JSON."""
        await self.send(text_data=json.dumps(event["message"]))

    @staticmethod
    def _load_json_object(text_data):
        """Decode a text frame into a dict; return None if it is not a JSON object."""
        try:
            data = json.loads(text_data or "{}")
        except (TypeError, ValueError, RecursionError):
            # ValueError covers JSONDecodeError; RecursionError is raised for
            # pathologically nested input from an untrusted client.
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def _event_type(data):
        """Return the lower-cased event type (default ``"message"``), or None if not a string."""
        raw = data.get("type") or "message"
        return raw.strip().lower() if isinstance(raw, str) else None

    @staticmethod
    def _clean_str(value):
        """Return ``value`` stripped if it is a string, otherwise ``""``."""
        return value.strip() if isinstance(value, str) else ""

    @staticmethod
    def _sender_label(user):
        """Classify ``user`` as ``"staff"``, ``"user"``, or ``"guest"`` (missing/anonymous)."""
        if user and getattr(user, "is_authenticated", False):
            return "staff" if getattr(user, "is_staff", False) else "user"
        return "guest"

    @database_sync_to_async
    def _support_allowed(self, user, session, thread_id):
        """Return True if this socket may access the thread.

        Staff may access any existing thread, an authenticated non-staff user
        only a thread whose ``user_id`` is theirs, and an anonymous client only
        a thread whose non-empty ``session_key`` equals its session's key.
        """
        try:
            t = SupportThread.objects.get(id=thread_id)
        except SupportThread.DoesNotExist:
            return False

        if user and getattr(user, "is_authenticated", False):
            if getattr(user, "is_staff", False):
                return True
            return t.user_id == user.id

        # guest via session key
        if session is not None:
            sk = getattr(session, "session_key", "") or ""
            return bool(sk and t.session_key and t.session_key == sk)

        return False

    @database_sync_to_async
    def _save_support_message(self, thread_id, user, session, text):
        """Store ``text`` as a ``SupportMessage`` on the thread and summarize it.

        Authenticated users are recorded as the sender. Guests are stored with
        no sender, and the thread's ``session_key`` is filled in if it is empty.
        Returns a JSON-serializable dict for broadcasting. Raises
        ``SupportThread.DoesNotExist`` if the thread is gone.
        """
        t = SupportThread.objects.get(id=thread_id)

        sender_label = self._sender_label(user)
        sender = user if sender_label != "guest" else None

        if sender is None and session is not None:
            # ensure session key is saved on thread
            if not session.session_key:
                session.save()
            if not t.session_key:
                t.session_key = session.session_key or ""
                t.save(update_fields=["session_key", "updated_at"])

        m = SupportMessage.objects.create(
            thread=t, sender=sender, sender_label=sender_label, text=text
        )

        # Touch the thread so updated_at reflects the new message.
        try:
            t.save(update_fields=["updated_at"])
        except Exception:
            # Best-effort: the message is already stored, so don't fail the
            # send, but don't hide the failure either.
            logging.getLogger(__name__).exception(
                "Failed to touch SupportThread %s after new message", t.id
            )

        if sender is None:
            sender_name = "Guest"
        else:
            fallback = "Support" if sender_label == "staff" else "User"
            sender_name = sender.get_full_name() or sender.username or fallback

        return {
            "id": m.id,
            "thread_id": t.id,
            "sender_label": sender_label,
            "sender_name": sender_name,
            "text": m.text,
            "created_at": m.created_at.isoformat(),
        }
