import { Annotation, HtmlLabel } from "@visx/annotation";
import { css } from "@emotion/core";
import { Group } from "@visx/group";
import { scaleLinear } from "@visx/scale";
import { Accessor } from "@visx/shape/lib/types";
import { bisector, max, min } from "d3-array";
import { AnimatePresence } from "framer-motion";
import React, { useEffect, useMemo, useState } from "react";

import { reduceVerticalOverlap } from "lib/reduceOverlap";
import { usePointerPosition } from "../../../hooks";
import { lighten } from "../../../lib/colorManipulator";
import * as M from "../../../materials";
import { O, pipe } from "../../../prelude";
import {
  HOVER_CIRCLE_RADIUS,
  HOVER_CIRCLE_RADIUS_ACTIVE,
} from "../shared/constants";
import HoverCircle from "../shared/HoverCircle";
import HoverLines from "../shared/HoverLines";
import { groupBy } from "../utils";
import Axis from "./components/Axis";
import Band from "./components/Band";
import Legend from "./components/Legend";
import Rule from "./components/Rule";
import Series from "./components/Series";
import Tooltip from "./components/Tooltip";
import { LIGHTEN_COEFFICIENT, PADDING } from "./constants";
import {
  ActivePoint,
  BandItem,
  Orientation,
  RuleItem,
  SeriesItem,
} from "./types";

/**
 * Props for the `Trajectories` chart.
 *
 * `A` is the shape of a single datum in `dataset`; every accessor and
 * formatter below receives values of that type.
 */
export interface TrajectoriesChartProps<A> {
  // ---- Layout -------------------------------------------------------------

  /** Total width of the SVG; horizontal padding is subtracted internally. */
  width: number;
  /** Total height of the SVG; vertical padding is subtracted internally. */
  height: number;

  // ---- Data & accessors ---------------------------------------------------

  /** Flat list of points across all series; grouped by `seriesAccessor`. */
  dataset: Array<A>;
  /**
   * X value of a datum. Hover snaps to `Math.round` of the pointer's x, so
   * only data at integer x values can become the active point.
   */
  xAccessor: Accessor<A, number>;
  /**
   * Y value of a datum (the y scale is fixed to [0, 1]). `null` marks a
   * missing value; such points are ignored by hover.
   */
  yAccessor: Accessor<A, number | null>;
  /** Series key of a datum; used for grouping, styling and the legend. */
  seriesAccessor: Accessor<A, string>;
  /**
   * Forwarded to the `Series` component; presumably selects which values
   * of a series are drawn — confirm against `Series`.
   */
  filterSeriesValues?: (d: A) => boolean;

  // ---- Overlays -----------------------------------------------------------

  /**
   * Shaded bands. A band is only drawn while the active series is its
   * `upper` or `lower` series. Defaults to no bands.
   */
  bands?: BandItem[];
  /** Reference rules; each is rendered with the active point's x. */
  rules: RuleItem[];

  // ---- Series appearance --------------------------------------------------

  /**
   * Line style for a series. Its `stroke` is also used for hover circles,
   * tooltip colour and (via the `Legend` component) the legend.
   */
  lineStyle: (
    series: string | number,
    values?: A[]
  ) => React.CSSProperties & { stroke: string; strokeWidth: number };
  /**
   * Returns `true` for de-emphasised series: they are drawn first
   * (underneath), get lightened hover marks, and get no tooltip unless they
   * hold the active point. Defaults to treating no series as background.
   */
  backgroundSeriesFilter?: (series: string | number) => boolean;
  /** Forwarded to the `Series` component to label a series. */
  formatSeriesLabel?: (series: string | number, values: A[]) => string;

  // ---- Axes ---------------------------------------------------------------

  /** Title of the x axis. */
  xAxisTitle: string;
  /** Optional secondary line under the x axis title. */
  xAxisSubtitle?: string;
  /** Title of the y axis. */
  yAxisTitle: string;

  // ---- Tooltip & interaction ----------------------------------------------

  /** Label text shown in the tooltip for a datum. */
  formatTooltipLabel: (d: A) => string;
  /** Value text shown in the tooltip for a datum. */
  formatTooltipValue: (d: A) => string;
  /**
   * When `true`, tooltips for points right of the plot's horizontal
   * midpoint are anchored to the left of the point. Defaults to `false`.
   */
  flipTooltip?: boolean;
  /** When `false`, pointer hover never changes the active point. Defaults to `true`. */
  isInteractive?: boolean;
  /** Pins the active point to this datum and disables hover handling. */
  lockActivePoint?: A;

  // ---- Legend -------------------------------------------------------------

  /** Render a colour legend after the chart (only for a non-empty dataset). Defaults to `false`. */
  showColorLegend?: boolean;
  /** Series keys listed in the legend; defaults to every distinct series in `dataset`. */
  legendValues?: Array<string>;
  /** Custom renderer for a legend entry's label. */
  formatLegendLabel?: (series: string) => React.ReactNode;
  /** Legend layout direction. Defaults to `"vertical"`. */
  legendOrientation?: Orientation;

  // ---- Misc ---------------------------------------------------------------

  /** Accessible description rendered into the SVG's `<desc>` element. */
  description?: string;
  /** Message shown centred over the chart when `dataset` is empty. Defaults to `""`. */
  emptyChartMessage?: string;
  /** Accepted for compatibility; not currently rendered by `Trajectories`. */
  children?: React.ReactNode;
}

/**
 * Default for `backgroundSeriesFilter`: no series is treated as background.
 *
 * Kept at module scope so its identity is stable across renders. An inline
 * `() => false` default would be a new function on every render and would
 * invalidate every `useMemo` that depends on `backgroundSeriesFilter`.
 */
const defaultBackgroundSeriesFilter = (): boolean => false;

/**
 * Line chart of per-series trajectories with a hover (or locked) active point.
 *
 * - The y scale is fixed to [0, 1] (values are rendered as percentages).
 * - Hover picks the x value `Math.round(scaleX.invert(pointer.x))`. Among the
 *   non-null points at that x, it then picks the first one whose y is >= the
 *   hovered y, or the highest one if the pointer is above all of them.
 * - When `lockActivePoint` is set (or `isInteractive` is false), hover
 *   handling is disabled.
 * - Bands are drawn only while the active series is a band's `upper`/`lower`.
 */
function Trajectories<A extends Record<string, unknown>>({
  width,
  height,
  description,
  dataset,
  seriesAccessor,
  xAccessor,
  yAccessor,
  lineStyle,
  formatTooltipLabel,
  formatTooltipValue,
  formatSeriesLabel,
  yAxisTitle,
  xAxisTitle,
  xAxisSubtitle,
  rules,
  lockActivePoint,
  bands = [],
  filterSeriesValues,
  backgroundSeriesFilter = defaultBackgroundSeriesFilter,
  isInteractive = true,
  showColorLegend = false,
  formatLegendLabel,
  legendValues,
  emptyChartMessage = "",
  flipTooltip = false,
  legendOrientation = "vertical",
}: TrajectoriesChartProps<A>): React.ReactElement {
  const [activePoint, setActivePoint] = useState<ActivePoint<A> | null>(null);
  const [pointerPositionRef, pointerPosition] =
    usePointerPosition<SVGRectElement>();

  const boundedWidth = width - PADDING.left - PADDING.right;
  const boundedHeight = height - PADDING.top - PADDING.bottom;
  // Legend entries: explicit `legendValues`, otherwise every distinct series
  // in `dataset` in first-seen order.
  const seriesList = legendValues ?? [
    ...new Set(dataset.map((d) => seriesAccessor(d))),
  ];

  // SCALES
  const scaleX = useMemo(
    () =>
      scaleLinear({
        domain: [min(dataset, xAccessor) ?? -1, max(dataset, xAccessor) ?? -1],
        range: [0, boundedWidth],
      }),
    [dataset, xAccessor, boundedWidth]
  );

  const scaleY = useMemo(
    () =>
      scaleLinear({
        domain: [0, 1], // We always show percentages from 0% to 100%
        range: [PADDING.top + boundedHeight, PADDING.top],
      }),
    [boundedHeight]
  );

  const activeDatum = activePoint?.datum;

  // Bands are only shown for the series the user is currently focused on.
  const bandsToRender = activeDatum
    ? bands.filter((band) =>
        [band.upper, band.lower].includes(seriesAccessor(activeDatum))
      )
    : [];

  /**
   * Group data by series for lines and sort for rendering order
   * (later entries are drawn on top).
   **/
  const seriesData: SeriesItem<A>[] = useMemo(() => {
    const groupedData = groupBy(dataset, seriesAccessor) as Array<{
      key: string;
      values: A[];
    }>;

    // Copy so the array returned by `groupBy` is never mutated, then move
    // background series to the front so they are drawn first (underneath).
    // The comparator must be consistent (compare(a, b) === -compare(b, a)).
    // Returning -1 for both orders of two background series would leave the
    // result implementation-defined.
    const sorted = [...groupedData].sort(
      (a, b) =>
        Number(backgroundSeriesFilter(b.key)) -
        Number(backgroundSeriesFilter(a.key))
    );

    // If there's an active point, draw its series last (on top). Array#sort
    // is stable, so the background-first ordering is otherwise preserved.
    if (activeDatum) {
      const activeKey = seriesAccessor(activeDatum);
      sorted.sort(
        (a, b) => Number(a.key === activeKey) - Number(b.key === activeKey)
      );
    }

    return sorted;
  }, [dataset, seriesAccessor, backgroundSeriesFilter, activeDatum]);

  /**
   * Scale used to vertically position tooltips, in order to avoid overlaps.
   * Maps a series key to the tooltip's top offset; unknown keys map to 0.
   **/
  const yTooltipScale = useMemo(() => {
    // Vertical space reserved per tooltip label, in the same units as
    // `scaleY`'s output.
    const labelHeight = 20 + 2;
    const getCurrentDatum = (series: SeriesItem<A>) =>
      series.values.find((d: A) => xAccessor(d) === activePoint?.x);

    const lookup = reduceVerticalOverlap<SeriesItem<A>, string | number>(
      // Only series that will actually show a tooltip take part in the
      // overlap reduction: the active series and all non-background ones.
      seriesData.filter((series) => {
        if (!activePoint || !activePoint.datum) return false;
        return (
          series.key === seriesAccessor(activePoint.datum) ||
          !backgroundSeriesFilter(series.key)
        );
      }),
      labelHeight,
      (series) => series.key,
      (series) => {
        const currentSeriesDatum = getCurrentDatum(series);
        if (!currentSeriesDatum) return 0;
        return scaleY(yAccessor(currentSeriesDatum) ?? 0);
      }
    );
    return (key: string | number) => {
      const el = lookup.find((d) => d.id === key);
      return el ? el.top : 0;
    };
  }, [
    seriesData,
    backgroundSeriesFilter,
    seriesAccessor,
    xAccessor,
    yAccessor,
    scaleY,
    activePoint,
  ]);

  // A locked point always becomes the active point.
  useEffect(() => {
    if (lockActivePoint) {
      setActivePoint({ x: xAccessor(lockActivePoint), datum: lockActivePoint });
    }
  }, [lockActivePoint, xAccessor]);

  // Handle on hover interactions
  useEffect(() => {
    // If a point is locked, or the chart is static, we don't handle hover events
    if (lockActivePoint || !isInteractive) return;

    pipe(
      pointerPosition,
      O.fold(
        // When user is not hovering chart, there is no value for 'pointerPosition'.
        () => {
          setActivePoint(null);
        },
        // Given hovered coordinates, we get corresponding data values.
        (pos) => {
          const hoveredY = scaleY.invert(pos.y);
          const hoverRuleX = Math.round(scaleX.invert(pos.x));
          const [domainMin, domainMax] = scaleX.domain();

          // Outside the x domain (on either side) there is nothing to highlight.
          if (hoverRuleX < domainMin || hoverRuleX > domainMax) {
            setActivePoint(null);
            return;
          }

          const pointsAtX = dataset
            .filter((d) => xAccessor(d) === hoverRuleX && yAccessor(d) !== null)
            .sort((a, b) => (yAccessor(a) ?? 0) - (yAccessor(b) ?? 0));

          // First point whose y is >= the hovered y; if the pointer is above
          // every point, fall back to the highest one.
          const hoverPointIndex = bisector(yAccessor).left(pointsAtX, hoveredY);
          const hoverPoint =
            pointsAtX[hoverPointIndex] ?? pointsAtX[pointsAtX.length - 1];

          setActivePoint({ x: hoverRuleX, datum: hoverPoint });
        }
      )
    );
  }, [
    dataset,
    pointerPosition,
    xAccessor,
    scaleX,
    yAccessor,
    scaleY,
    lockActivePoint,
    isInteractive,
  ]);

  // Horizontal midpoint of the plot area, used to decide tooltip flipping.
  const [rangeStartX, rangeEndX] = scaleX.range();
  const plotMidX = (rangeStartX + rangeEndX) / 2;

  return (
    <div
      css={css`
        margin-top: ${M.spacing.base8(8)};
      `}
    >
      <svg width={width} height={height} style={{ overflow: "visible" }}>
        <desc>{description}</desc>
        <Group>
          <Axis<A>
            scaleX={scaleX}
            scaleY={scaleY}
            xAxisTitle={xAxisTitle}
            xAxisSubtitle={xAxisSubtitle}
            yAxisTitle={yAxisTitle}
            boundedHeight={boundedHeight}
            boundedWidth={boundedWidth}
          />
          {bandsToRender.length > 0 && (
            <Group id="bands">
              {bandsToRender.map((band, index) => (
                <Band<A>
                  key={index}
                  band={band}
                  seriesData={seriesData}
                  xAccessor={xAccessor}
                  yAccessor={yAccessor}
                  scaleX={scaleX}
                  scaleY={scaleY}
                />
              ))}
            </Group>
          )}
          {rules && (
            <Group id="rules">
              {rules.map((rule, index) => (
                <Rule
                  key={index}
                  rule={rule}
                  scaleX={scaleX}
                  scaleY={scaleY}
                  activePointX={activePoint?.x}
                />
              ))}
            </Group>
          )}
          <Series
            data={seriesData}
            seriesAccessor={seriesAccessor}
            xAccessor={xAccessor}
            yAccessor={yAccessor}
            scaleX={scaleX}
            scaleY={scaleY}
            lineStyle={lineStyle}
            activePoint={activePoint}
            boundedWidth={boundedWidth}
            backgroundSeriesFilter={backgroundSeriesFilter}
            formatSeriesLabel={formatSeriesLabel}
            filterSeriesValues={filterSeriesValues}
          />
          {/*  Render Hover Elements */}
          {activePoint && activePoint.datum && (
            <Group>
              <HoverLines
                xPos={scaleX(activePoint.x) ?? 0}
                y1={scaleY.range()[0]}
                y2={scaleY.range()[1]}
              />
              {seriesData.map(({ key, values }) => {
                const seriesPointOnLine = values.find(
                  (d) => xAccessor(d) === activePoint.x && yAccessor(d) !== null
                );
                if (!seriesPointOnLine) return null;

                // While bands are shown, only the band series get hover marks.
                if (
                  bandsToRender.length > 0 &&
                  !bandsToRender
                    .flatMap((d) => [d.lower, d.upper])
                    .includes(key)
                ) {
                  return null;
                }

                const pointX = scaleX(xAccessor(seriesPointOnLine));
                const isFlipped = flipTooltip && pointX > plotMidX;

                const isActivePoint =
                  yAccessor(activePoint.datum) ===
                    yAccessor(seriesPointOnLine) &&
                  seriesAccessor(activePoint.datum) ===
                    seriesAccessor(seriesPointOnLine);

                // Background series stay muted unless they hold the active point.
                const isHighlighted =
                  isActivePoint || !backgroundSeriesFilter(key);

                const baseStroke = lineStyle(key, values).stroke;
                const stroke = isHighlighted
                  ? baseStroke
                  : lighten(baseStroke, LIGHTEN_COEFFICIENT);

                return (
                  <AnimatePresence key={key}>
                    <Group left={pointX}>
                      <HoverCircle
                        r={
                          isHighlighted
                            ? HOVER_CIRCLE_RADIUS_ACTIVE
                            : HOVER_CIRCLE_RADIUS
                        }
                        fill={isHighlighted ? M.whiteText : stroke}
                        stroke={stroke}
                        cx={0}
                        cy={scaleY(yAccessor(seriesPointOnLine) ?? 0)}
                      />
                      <AnimatePresence>
                        {isHighlighted && (
                          <Annotation
                            y={yTooltipScale(key)}
                            dx={
                              +M.spacing.base8(1.5, "") * (isFlipped ? -1 : 1)
                            }
                          >
                            <HtmlLabel
                              showAnchorLine={false}
                              horizontalAnchor={isFlipped ? "end" : "start"}
                              containerStyle={{ position: "fixed" }}
                            >
                              <Tooltip
                                value={formatTooltipValue(seriesPointOnLine)}
                                label={formatTooltipLabel(seriesPointOnLine)}
                                color={stroke}
                              />
                            </HtmlLabel>
                          </Annotation>
                        )}
                      </AnimatePresence>
                    </Group>
                  </AnimatePresence>
                );
              })}
            </Group>
          )}
        </Group>

        {/* Transparent overlay that captures pointer position for hover. */}
        <rect
          ref={pointerPositionRef}
          x={0}
          y={0}
          width={width}
          height={height - PADDING.bottom}
          fill="transparent"
        />
        {dataset.length === 0 && (
          <Annotation x={width / 2} y={height / 2}>
            <HtmlLabel
              showAnchorLine={false}
              horizontalAnchor="middle"
              verticalAnchor="middle"
              containerStyle={{ minWidth: width * 0.8 }}
            >
              <div
                css={css`
                  width: 100%;
                  text-align: center;
                  background: white;
                  border: ${M.divider} solid 1px;
                  padding: ${M.spacing.base8(2)} ${M.spacing.base8(4)};
                  ${M.fontBody2}
                `}
              >
                <p>{emptyChartMessage}</p>
              </div>
            </HtmlLabel>
          </Annotation>
        )}
      </svg>
      {showColorLegend && dataset.length > 0 && (
        <Legend
          series={seriesList}
          lineStyle={lineStyle}
          formatLegendLabel={formatLegendLabel}
          orientation={legendOrientation}
        />
      )}
    </div>
  );
}

export default Trajectories;
