import React, { useState, useCallback, useMemo } from "react";
import { Group } from '@visx/group';
import { scaleLinear, scaleQuantize } from "@visx/scale";
import { localPoint } from '@visx/event';
import { withTooltip, TooltipWithBounds } from '@visx/tooltip';
import { WithTooltipProvidedProps } from '@visx/tooltip/lib/enhancers/withTooltip';
import { bisector, extent } from "d3-array";
import { hierarchy, Pack } from '@visx/hierarchy';
import { AxisBottom, AxisLeft } from '@visx/axis';
import { ToggleTag } from "../primitives/ToggleTag";

/**
 * Props for the section bubble chart (default export of this file).
 *
 * `ExamResultsType` and `SectionResults` are declared outside this file view.
 */
export type AreaProps = {
    /** Not read by the chart component in this file (it is destructured but never used). */
    examResults: ExamResultsType;
    /**
     * Per-section results keyed by section name. The chart reads `time_spent_all_users`,
     * `percentages_all_users`, `total_time`, `result_percentage` and `section_abbreviation`
     * from each entry; one toggle tag is rendered per key.
     */
    sectionResults: Record<string, SectionResults>;
    /** Total SVG width in pixels; the chart renders nothing when this is below 10. */
    width: number;
    /** Total SVG height in pixels. */
    height: number;
    /** Inner plot margins; the component defaults to { top: 70, right: 80, bottom: 40, left: 60 }. */
    margin?: { top: number; right: number; bottom: number; left: number };
};

/**
 * One (possibly merged) bubble in the chart, also used as the tooltip payload.
 */
interface TooltipData {
    /** Time-spent value of the bubble (its centroid when several points were merged). */
    x: number;
    /** Percentage value of the bubble (its centroid when several points were merged). */
    y: number;
    /**
     * Number of user data points merged into this bubble; drives the bubble size.
     * NOTE(review): the tooltip renders this value under the label "Percentile" — confirm intent.
     */
    density: number;
}

/** Accessor for a point's time-spent value (x coordinate). */
const getScore = ({ x }: TooltipData): number => x;
/** Accessor for a point's percentage value (y coordinate). */
const getDistribution = ({ y }: TooltipData): number => y;
/** Accessor for the number of points merged into a bubble. */
const getPercentile = ({ density }: TooltipData): number => density;
/**
 * Left bisector keyed on a point's `x` (time spent): returns the index at which a given x value
 * would be inserted. The array passed in must already be sorted by ascending `x`.
 */
const { left: bisectData } = bisector<TooltipData, number>((point) => point.x);

/** A single (time spent, percentage) pair for one user. */
type ScorePoint = { x: number; y: number };

/** Clustered bubbles plus the position at which to draw the "You are here" flag. */
type MergedPoints = { data: TooltipData[]; flagX: number; flagY: number };

/**
 * Clusters nearby (time spent, percentage) points into count-weighted bubbles.
 *
 * Exact duplicates are first collapsed into one entry with a count. Each distinct point is then
 * merged into the first existing cluster whose centroid lies strictly within `thresholdDistance`
 * (Euclidean distance in raw data units — x and y are not rescaled, so this mixes their units),
 * or starts a new cluster otherwise. A merged cluster's centroid is the count-weighted mean.
 *
 * The flag follows whichever cluster absorbed the point equal to `(flagX, flagY)` and is placed at
 * that cluster's final centroid; if no input point equals the flag, the raw flag position is kept.
 *
 * @param points - points to cluster
 * @param flagX - the current user's x value
 * @param flagY - the current user's y value
 * @param thresholdDistance - merge radius in data units (30 matches the previous hard-coded value)
 * @returns the clusters (in creation order) and the flag position
 */
function mergeNearbyPoints(
    points: ReadonlyArray<ScorePoint>,
    flagX: number,
    flagY: number,
    thresholdDistance = 30
): MergedPoints {
    // Collapse exact duplicates so each distinct point is visited once, carrying its count.
    const counts = new Map<string, { x: number; y: number; count: number }>();
    for (const { x, y } of points) {
        const key = `${x}-${y}`;
        const seen = counts.get(key);
        if (seen) {
            seen.count += 1;
        } else {
            counts.set(key, { x, y, count: 1 });
        }
    }

    const clusters: TooltipData[] = [];
    // Index of the cluster that absorbed the user's own point; -1 while not seen.
    let flagCluster = -1;

    counts.forEach(({ x, y, count }) => {
        // Greedy: join the first cluster whose centroid is within the threshold.
        const target = clusters.findIndex(
            (c) => Math.sqrt((x - c.x) ** 2 + (y - c.y) ** 2) < thresholdDistance
        );
        let landedIn: number;
        if (target === -1) {
            landedIn = clusters.push({ x, y, density: count }) - 1;
        } else {
            const cluster = clusters[target];
            const density = count + cluster.density;
            clusters[target] = {
                x: (x * count + cluster.x * cluster.density) / density,
                y: (y * count + cluster.y * cluster.density) / density,
                density,
            };
            landedIn = target;
        }
        if (x === flagX && y === flagY) {
            flagCluster = landedIn;
        }
    });

    if (flagCluster === -1) {
        return { data: clusters, flagX, flagY };
    }
    // Use the cluster's *final* centroid so merges that happened after the flag's point joined
    // are reflected and the flag stays on its bubble.
    const { x: snappedX, y: snappedY } = clusters[flagCluster];
    return { data: clusters, flagX: snappedX, flagY: snappedY };
}

/**
 * Bubble chart of every user's time spent (x, axis label "Speed") against their percentage
 * (y, axis label "Accuracy") for one section at a time, with toggle tags to switch sections and a
 * "You are here" label pointing at the current user's bubble.
 *
 * The x scale is inverted (smallest time maps to the right edge). Nearby points are merged into
 * bubbles by `mergeNearbyPoints`. Renders nothing when `width < 10` or `sectionResults` is empty.
 */
export default withTooltip<AreaProps, TooltipData>(
    ({
         examResults,
         sectionResults,
         width,
         height,
         margin = { top: 70, right: 80, bottom: 40, left: 60 },
         showTooltip,
         hideTooltip,
         tooltipData,
         tooltipTop = 0,
         tooltipLeft = 0,
     }: AreaProps & WithTooltipProvidedProps<TooltipData>) => {
        // Every hook below runs unconditionally on every render; the early return comes after
        // them (Rules of Hooks), so resizing across the width threshold cannot break React.
        const sections = Object.keys(sectionResults);
        const [activeAbbreviation, setAbbreviation] = useState<string | undefined>(sections[0]);

        // Derive the graph from props each render instead of copying it into state, so it never
        // goes stale when `sectionResults` changes; fall back to the first section if the
        // selected one disappeared.
        const activeSection =
            activeAbbreviation !== undefined && sections.includes(activeAbbreviation)
                ? activeAbbreviation
                : sections[0];
        const activeGraph = activeSection !== undefined ? sectionResults[activeSection] : undefined;

        // (time spent, percentage) per user, skipping entries with a missing value on either axis.
        const points = useMemo<ScorePoint[]>(() => {
            if (!activeGraph) return [];
            return activeGraph.time_spent_all_users
                .map((x, i) => ({ x, y: activeGraph.percentages_all_users[i] }))
                .filter((d): d is ScorePoint => d.x != null && d.y != null);
        }, [activeGraph]);

        const xValues = useMemo(() => points.map((d) => d.x), [points]);
        const yValues = useMemo(() => points.map((d) => d.y), [points]);

        const combinedData = useMemo<MergedPoints>(
            () =>
                activeGraph
                    ? mergeNearbyPoints(points, activeGraph.total_time, activeGraph.result_percentage)
                    : { data: [], flagX: 0, flagY: 0 },
            [points, activeGraph]
        );

        const xMin = 0;
        const xMax = width - margin.left - margin.right;
        const yMax = height - margin.top - margin.bottom;

        // Inverted on purpose: the smallest time maps to the right edge.
        const scoreScale = useMemo(() => {
            const [minScore = 0, maxScore = 0] = extent(xValues);
            return scaleLinear({
                domain: [minScore, maxScore],
                range: [xMax + 10, xMin],
            });
        }, [xValues, xMax]);

        const percentileScale = useMemo(() => {
            const [, maxPercentage = 0] = extent(yValues);
            return scaleLinear({
                domain: [0, maxPercentage],
                range: [yMax, 0],
            });
        }, [yValues, yMax]);

        // Pack orders circles by value, not by data index, so look up each datum's own gradient.
        const gradientIndexByDatum = useMemo(
            () => new Map(combinedData.data.map((d, i) => [d, i] as const)),
            [combinedData.data]
        );

        // bisectData requires ascending x; the clustered data is in creation order.
        const scoreSortedData = useMemo(
            () => [...combinedData.data].sort((a, b) => getScore(a) - getScore(b)),
            [combinedData.data]
        );

        const handleActiveGraph = (abv: string) => {
            setAbbreviation(abv);
        };

        // NOTE(review): not currently attached to any element rendered below.
        const handleTooltip = useCallback(
            (event: React.TouchEvent<SVGRectElement> | React.MouseEvent<SVGRectElement>) => {
                const { x } = localPoint(event) || { x: 0 };
                // localPoint is relative to the <svg>; the scales live inside the margin-offset Group.
                const x0 = scoreScale.invert(x - margin.left);
                const index = bisectData(scoreSortedData, x0, 1);
                const d0 = scoreSortedData[index - 1];
                const d1 = scoreSortedData[index];
                if (d0 === undefined) return;
                let d = d0;
                if (d1 !== undefined && x0 - getScore(d0) > getScore(d1) - x0) {
                    d = d1;
                }
                showTooltip({
                    tooltipData: d,
                    tooltipLeft: x,
                    tooltipTop: percentileScale(getDistribution(d)) + margin.top,
                });
            },
            [showTooltip, percentileScale, scoreScale, scoreSortedData, margin.left, margin.top]
        );

        if (width < 10 || !activeGraph) return null;

        const packData = { children: combinedData.data, name: 'root', density: 0 };

        const [minDensity = 0, maxDensity = 0] = extent(combinedData.data, (d) => d.density);
        const colorScale = scaleQuantize({
            domain: [minDensity, maxDensity],
            range: ['#7dd3fc'],
        });

        const root = hierarchy(packData)
            .sum((d) => d.density)
            .sort((a, b) => b.value! - a.value!);

        const actualFlagXScaled = scoreScale(combinedData.flagX);
        const actualFlagYScaled = percentileScale(combinedData.flagY);

        const constrainedFlagX = Math.max(margin.left, Math.min(actualFlagXScaled, width - margin.right - 70));
        const constrainedFlagY = Math.max(margin.top, Math.min(actualFlagYScaled, height - margin.bottom - 50));

        const distanceBetweenFlagAndBubble = 50;
        // Leader-line direction: downward (+) when clamping pushed the label well below the
        // point, upward (-) otherwise. Shared by the line and the label box.
        const flagOffset =
            constrainedFlagY - distanceBetweenFlagAndBubble > actualFlagYScaled
                ? distanceBetweenFlagAndBubble / 2
                : -distanceBetweenFlagAndBubble / 2;

        const getAbbreviation = (section: string) => !sections.includes(section) ? section : sectionResults[section].section_abbreviation || section;

        return (
            <div className="flex items-center">
                <svg width={width} height={height}>
                    <defs>
                        {combinedData.data.map((d, i) => (
                            <radialGradient key={i} id={`grad-${i}`} cx="50%" cy="50%" r="50%" fx="50%" fy="50%">
                                <stop offset="0%" style={{ stopColor: "#7dd3fc", stopOpacity: .4 }} />
                                <stop offset="100%" style={{ stopColor: colorScale(d.density), stopOpacity: 1 }} />
                            </radialGradient>
                        ))}
                    </defs>

                    <Group left={margin.left} top={margin.top}>
                        <foreignObject x={-margin.left} y={-margin.top + 10} width={width} height={40}>
                            <div className={"w-full flex justify-center h-full"}>
                                <div className={"flex flex-row gap-x-4 h-full"}>
                                    {sections.map((abv) => (
                                        <div className={"h-full"} key={abv} onClick={() => handleActiveGraph(abv)}>
                                            <ToggleTag
                                                active={abv === activeSection}>{getAbbreviation(abv)}</ToggleTag>
                                        </div>
                                    ))}
                                </div>
                            </div>
                        </foreignObject>
                        <AxisBottom
                            top={yMax}
                            scale={scoreScale}
                            label="Speed"
                            stroke="#D6D3D1"
                            tickLength={0}
                            labelOffset={10}
                            labelProps={{
                                fill: '#D6D3D1',
                            }}
                            tickLabelProps={() => ({
                                fill: 'transparent',
                            })}
                        />
                        <AxisLeft
                            scale={percentileScale}
                            label="Accuracy"
                            stroke="#D6D3D1"
                            tickLength={0}
                            labelOffset={10}
                            labelProps={{
                                fill: '#D6D3D1',
                            }}
                            tickLabelProps={() => ({
                                fill: 'transparent',
                            })}
                        />
                        <Pack root={root} size={[xMax, yMax]}>
                            {(packed) => {
                                // Skip the synthetic root; only its children are bubbles.
                                const circles = packed.descendants().slice(1);
                                return (
                                    <Group>
                                        {circles.map((circle, i) => {
                                            const data = circle.data as unknown as TooltipData;
                                            const gradientIndex = gradientIndexByDatum.get(data) ?? i;
                                            // Pack only supplies the radius; position comes from the data scales.
                                            return (
                                                <circle
                                                    key={`circle-${i}`}
                                                    r={circle.r * 0.3}
                                                    cx={scoreScale(data.x)}
                                                    cy={percentileScale(data.y)}
                                                    fill={`url(#grad-${gradientIndex})`}
                                                    stroke={"#0ea5e9"}
                                                    strokeWidth={0}
                                                />
                                            );
                                        })}
                                        <line
                                            x1={actualFlagXScaled}
                                            y1={actualFlagYScaled}
                                            x2={actualFlagXScaled}
                                            y2={actualFlagYScaled + flagOffset}
                                            stroke="#292524"
                                            strokeWidth={1}
                                        />
                                        <foreignObject
                                            x={constrainedFlagX - 70}
                                            y={constrainedFlagY - distanceBetweenFlagAndBubble + flagOffset}
                                            width={140}
                                            height={50}>
                                            <div
                                                className={"w-full flex justify-center bg-backgroundPrimary h-full flex-col gap-y-2.5 font-inter rounded-lg border border-contentPrimary items-center"}>
                                                <p className={"text-contentPrimary text-lg font-medium"}>
                                                    You are here
                                                </p>
                                            </div>
                                        </foreignObject>
                                    </Group>
                                );
                            }}
                        </Pack>
                    </Group>
                </svg>
                {tooltipData && (
                    <div>
                        {/* Random key forces a remount so TooltipWithBounds re-measures its bounds. */}
                        <TooltipWithBounds
                            key={Math.random()}
                            top={tooltipTop - 12}
                            left={tooltipLeft + 12}
                            className={"bg-backgroundPrimary border border-borderOpaque text-contentPrimary"}
                        >
                            <div className={"flex flex-col gap-1"}>
                                <p>Percentile {`${getPercentile(tooltipData)}`}</p>
                                <p>Time Spent: {`${tooltipData.x}`}</p>
                            </div>
                        </TooltipWithBounds>
                    </div>
                )}
            </div>
        );
    }
);
