import { DragEndEvent, DragOverEvent, DragStartEvent } from '@dnd-kit/core'
import { arrayMove } from '@dnd-kit/sortable'
import { BoardIssue, BoardStatus } from '@schema'
import { useParams } from '@tanstack/react-router'
import { useCallback, useState } from 'react'
import { unstable_batchedUpdates } from 'react-dom'

import { trpc } from '../utils/trpc'

/**
 * Drag-and-drop state and dnd-kit handlers for a board of statuses (columns) holding issues.
 *
 * The hook keeps its own copy of the board in `states`, seeded once from `initialStatuses`
 * (later changes to the prop are not picked up; callers can overwrite it via `setStates`).
 * - `handleDragStart` records the dragged issue in `activeIssue`.
 * - `handleDragOver` moves the issue into another column as soon as it hovers over it.
 * - `handleDragEnd` clears `activeIssue`, reorders the issue within its final column and
 *   persists the move with `issue.move`; on success the `project` and `issue.getByProjectId`
 *   caches are invalidated.
 *
 * Ids are compared as strings throughout so dnd-kit ids (`string | number`) and schema ids
 * match regardless of their runtime type.
 */
export const useDnd = ({ initialStatuses }: { initialStatuses: BoardStatus[] }) => {
    const { projectId } = useParams({ strict: false })
    const [activeIssue, setActiveIssue] = useState<BoardIssue | null>(null)

    const context = trpc.useUtils()

    const { mutateAsync } = trpc.issue.move.useMutation({
        onSuccess: async () => {
            // The two invalidations are independent, so run them concurrently.
            await Promise.all([
                context.project.invalidate(),
                context.issue.getByProjectId.invalidate({ projectId }),
            ])
        },
    })

    const handleDragStart = useCallback((event: DragStartEvent) => {
        // Assumes the draggable registered its BoardIssue as `data` — TODO confirm at call sites.
        // `?? null` keeps `activeIssue` within its declared `BoardIssue | null` type.
        setActiveIssue((event.active.data.current as BoardIssue | undefined) ?? null)
    }, [])

    const [states, setStates] = useState<BoardStatus[]>(initialStatuses)

    /** Returns the column that currently contains the issue with id `issueId`, or null. */
    const findState = useCallback(
        (issueId: string): BoardStatus | null =>
            states.find((state) => state.issues.some((issue) => String(issue.id) === issueId)) ??
            null,
        [states],
    )

    /**
     * Moves the dragged issue into the hovered column while the drag is in progress.
     * Does nothing when either column cannot be resolved or both are the same column
     * (same-column reordering is left to `handleDragEnd`).
     */
    const handleDragOver = useCallback(
        (event: DragOverEvent) => {
            const { active, over, delta } = event
            const activeId = String(active.id)
            const overId = over ? String(over.id) : undefined

            // When an id is not an issue (e.g. hovering a column body), fall back to the
            // column registered in the droppable's data.
            const activeContainer =
                findState(activeId) ?? (active.data.current as BoardStatus | undefined)
            const overContainer =
                (overId === undefined ? null : findState(overId)) ??
                (over?.data.current?.state as BoardStatus | undefined)

            // Compare by id: the droppable's `data.state` need not be the same object as the
            // entry in `states`, and a reference mismatch would turn a same-column hover into
            // a "cross-column" move that deletes the issue.
            if (!activeContainer || !overContainer || activeContainer.id === overContainer.id) {
                return
            }

            unstable_batchedUpdates(() => {
                // Positions are computed from `prev` so the move applies to the latest board.
                setStates((prev) => {
                    const source = prev.find((state) => state.id === activeContainer.id)
                    const target = prev.find((state) => state.id === overContainer.id)
                    if (!source || !target) {
                        return prev
                    }

                    const activeIndex = source.issues.findIndex(
                        (issue) => String(issue.id) === activeId,
                    )
                    const moved = source.issues[activeIndex]
                    // Without this guard a -1 index would drop the source's last issue and
                    // insert `undefined` into the target.
                    if (activeIndex < 0 || moved === undefined) {
                        return prev
                    }

                    const overIndex = target.issues.findIndex(
                        (issue) => String(issue.id) === overId,
                    )
                    // Over an issue: insert after it when the drag moved downward (`delta.y > 0`),
                    // otherwise before it. Over the column itself: append.
                    const newIndex =
                        overIndex >= 0 ? overIndex + (delta.y > 0 ? 1 : 0) : target.issues.length

                    return prev.map((state) => {
                        if (state.id === source.id) {
                            return {
                                ...state,
                                issues: state.issues.filter((_, index) => index !== activeIndex),
                            }
                        }
                        if (state.id === target.id) {
                            return {
                                ...state,
                                issues: [
                                    ...state.issues.slice(0, newIndex),
                                    moved,
                                    ...state.issues.slice(newIndex),
                                ],
                            }
                        }
                        return state
                    })
                })
            })
        },
        [findState],
    )

    /**
     * Finishes a drag: clears `activeIssue`, reorders the issue within its column and persists
     * the final column and index with `issue.move`.
     *
     * `activeIssue` is cleared on every path — including drops with no resolvable target and a
     * rejected mutation — so UI driven by it (presumably the drag overlay) cannot get stuck.
     * If `issue.move` rejects, the rejection propagates from the returned promise; local
     * `states` is not rolled back.
     */
    const handleDragEnd = useCallback(
        async ({ active, over }: DragEndEvent) => {
            const activeId = String(active.id)
            const overId = over ? String(over.id) : undefined

            const activeContainer =
                findState(activeId) ?? (active.data.current as BoardStatus | undefined)
            const overContainer =
                (overId === undefined ? null : findState(overId)) ??
                (over?.data.current?.state as BoardStatus | undefined)

            // handleDragOver has already moved the issue into the hovered column, so a valid
            // drop ends in the column that holds the issue (compared by id, not reference).
            const column =
                activeContainer && overContainer && activeContainer.id === overContainer.id
                    ? states.find((state) => state.id === overContainer.id)
                    : undefined
            const activeIndex =
                column?.issues.findIndex((issue) => String(issue.id) === activeId) ?? -1

            if (!column || activeIndex < 0) {
                // Nothing to save (no drop target, or columns could not be matched):
                // still end the drag.
                setActiveIssue(null)
                return
            }

            const overIndex = column.issues.findIndex((issue) => String(issue.id) === overId)
            // Dropped on the column itself rather than on an issue: keep the issue where it
            // currently is instead of moving/saving to index -1.
            const targetIndex = overIndex >= 0 ? overIndex : activeIndex
            const columnId = column.id

            unstable_batchedUpdates(() => {
                setActiveIssue(null)
                if (activeIndex !== targetIndex) {
                    setStates((prev) =>
                        prev.map((state) =>
                            state.id === columnId
                                ? {
                                      ...state,
                                      issues: arrayMove(state.issues, activeIndex, targetIndex),
                                  }
                                : state,
                        ),
                    )
                }
            })

            await mutateAsync({
                issueId: activeId,
                statusId: columnId,
                index: targetIndex,
            })
        },
        [findState, mutateAsync, states],
    )

    return {
        states,
        handleDragStart,
        handleDragOver,
        handleDragEnd,
        activeIssue,
        setStates,
    }
}
