// Copyright 2021-2024 Luminary Cloud, Inc. All Rights Reserved.
/*
  A stacked bar chart, based on visx's BarStack, with a categorical x-axis
  and a numeric y-axis.
*/

import React, { useEffect, useMemo, useRef, useState } from 'react';

import { AxisBottom, AxisRight } from '@visx/axis';
import { localPoint } from '@visx/event';
import { GridRows } from '@visx/grid';
import { scaleBand, scaleLinear, scaleOrdinal } from '@visx/scale';
import { BarStack } from '@visx/shape';
import { SeriesPoint } from '@visx/shape/lib/types';
import { defaultStyles, useTooltip, useTooltipInPortal } from '@visx/tooltip';

import { colors } from '../lib/designSystem';

// A data point representing a single stack of bars at one X value.
// Expected keys:
// "label": string
//     The label for the X axis at this data point.
// For each category in props.categories:
// {category.name}: number
//     The Y value (bar height) for this category at this X value.
// Other keys are not read by this chart: the stack keys come only from
// props.categories, and the X label only from "label".
export interface Datum {
  [category: string]: number | string;
}

// Payload stored in the tooltip state: the BarStack bar object under the pointer,
// passed unchanged to showTooltip() by the bar's onMouseMove handler.
type TooltipData = {
  // The stacked series point; `bar.data` is the Datum this bar came from.
  bar: SeriesPoint<Datum>;
  // Category name of the hovered bar (used to look up its value and color).
  key: string;
  // Bar geometry, in the same SVG pixel coordinates used to draw the bar.
  height: number;
  width: number;
  x: number;
  y: number;
  // Fill color of the bar.
  color: string;
};

// Inline styles for the hover tooltip: visx's defaultStyles, overridden with a
// minimum width and the design-system background and text colors.
const tooltipStyles = {
  ...defaultStyles,
  minWidth: 60,
  backgroundColor: colors.surfaceMedium1,
  color: colors.highEmphasisText,
};

export interface Category {
  name: string;
  color: string;
}

/* A chart for calories consumed by meal might have props like the following:
 * categories: [
 *   { name: 'breakfast', color: '#92d143' },
 *   { name: 'lunch', color: '#6367d5' },
 *   { name: 'dinner', color: '#ff8d4d' },
 *   { name: 'snack', color: '#26a2ff' },
 * ],
 * data: [
 *   { label: 'June 15', breakfast: 1000, lunch: 630, dinner: 840, snack: 325 },
 *   { label: 'June 16', breakfast: 0, lunch: 900, dinner: 540, snack: 750 },
 *   ...
 * ],
 * unit: 'calories',
 * ...
 * The resulting chart would have four visible bars for June 15 and three for June 16
 * (breakfast is 0 that day, so its bar has zero height), with the 'breakfast' bar at
 * the bottom of each stack and the 'snack' bar at the top.
 * A bar's height is determined by the number of calories specified in the datum,
 * while its color is determined as specified in the corresponding category.
 */

// Props for the BarStackChart.
interface BarStackChartProps {
  // The categories (series) stacked at each X value, one bar per category.
  // Their order is the stacking order (first category at the bottom); each category
  // supplies the key looked up in every Datum and the color of its bars.
  categories: Category[];
  // One datum per X value, in X-axis order. Each datum holds the X-axis label
  // ("label") and the value for each category at that X value.
  data: Datum[];
  // Height of the chart's SVG, in pixels. The bottom 28 pixels are reserved for
  // the X-axis tick labels.
  height: number;
  // Width of the chart's SVG, in pixels. The rightmost 32 pixels are reserved for
  // the Y-axis tick labels; nothing is rendered when the width is below 10.
  width: number;
  // Color of the axis lines, ticks and horizontal grid lines.
  axisColor: string;
  // Background color for the chart.
  bgColor: string;
  // Color for the axis tick labels.
  labelColor: string;
  // Unit of the Y values; shown after the value in the hover tooltip.
  unit: string;
}

// Accessor for a datum's X-axis label: its "label" field.
const xLabel = ({ label }: Datum): string => label as string;

// These wrappers are not called in this file; they exist only so ChartState can
// name the concrete visx scale types via ReturnType<typeof ...>.
// NOTE(review): `arg` is `any` only to satisfy each factory's parameter type.
const xScaleFn = (arg: any) => scaleBand<string>(arg);
const yScaleFn = (arg: any) => scaleLinear<number>(arg);
const colorScaleFn = (arg: any) => scaleOrdinal<string, string>(arg);

// This is the state of the chart, as reproducible from the inputs of the chart
// (see computeChartState), so it only needs recomputing when those inputs change.
interface ChartState {
  // Category names in props.categories order; used as the BarStack keys.
  keys: string[];
  // Maps each datum's label to a horizontal band within [0, xMax].
  xScale: ReturnType<typeof xScaleFn>;
  // Maps a stacked value to a y pixel: 0 maps to yMax, the domain maximum to 0.
  yScale: ReturnType<typeof yScaleFn>;
  // Maps a category name to that category's color.
  colorScale: ReturnType<typeof colorScaleFn>;
  // Right edge of the plot area, in pixels; the right-hand axis is placed here.
  xMax: number;
  // Bottom edge of the plot area, in pixels; the bottom axis is placed here.
  yMax: number;
}

// Pixels reserved to the right of the plot area for the right-hand Y-axis tick labels.
const Y_AXIS_LABEL_WIDTH = 32;
// Pixels reserved below the plot area for the bottom X-axis tick labels.
const X_AXIS_LABEL_HEIGHT = 28;

/**
 * Given the chart props, determine the chart state.
 *
 * Only `categories`, `data`, `height` and `width` are read.
 * - `xScale` maps each datum's label to a band across [0, xMax].
 * - `yScale` maps [0, largest stack total] (extended with `nice`) onto [yMax, 0].
 * - `colorScale` maps each category name to its color.
 *
 * Robustness: an empty `data` array gives a y-domain of [0, 0] instead of
 * [0, -Infinity]. A category missing from a datum counts as 0 toward that stack's
 * total instead of making every total NaN.
 */
function computeChartState(props: BarStackChartProps): ChartState {
  const keys = props.categories.map((cat) => cat.name);

  // Sum each stack. Number() keeps a stray string value from turning `+` into
  // string concatenation; `?? 0` treats a missing category as an empty bar.
  const yTotals = props.data.map(
    (datum) => keys.reduce((total, key) => total + Number(datum[key] ?? 0), 0),
  );
  // Seeding with 0 avoids Math.max()'s -Infinity for empty data and avoids spreading
  // an arbitrarily long array into Math.max's argument list.
  const maxTotal = yTotals.reduce((max, total) => Math.max(max, total), 0);

  const xScale = scaleBand<string>({
    domain: props.data.map(xLabel),
    padding: 0.2,
  });

  const yScale = scaleLinear<number>({
    domain: [0, maxTotal],
    nice: true,
  });

  const colorScale = scaleOrdinal<string, string>({
    domain: keys,
    range: props.categories.map((cat) => cat.color),
  });

  const xMax = props.width - Y_AXIS_LABEL_WIDTH;
  const yMax = props.height - X_AXIS_LABEL_HEIGHT;

  xScale.rangeRound([0, xMax]);
  // Inverted range: SVG y grows downward, so value 0 sits at the bottom (yMax).
  yScale.range([yMax, 0]);

  return {
    keys,
    xScale,
    yScale,
    colorScale,
    xMax,
    yMax,
  };
}

/**
 * Stacked bar chart: one stack per datum along a categorical x-axis, with a numeric
 * y-axis labelled on the right-hand side. Hovering a bar shows a tooltip with the
 * bar's category, its value followed by `props.unit`, and the datum's label.
 *
 * Renders nothing when `props.width` is below 10 pixels.
 */
const BarStackChart = (props: BarStackChartProps) => {
  const {
    tooltipOpen,
    tooltipLeft,
    tooltipTop,
    tooltipData,
    hideTooltip,
    showTooltip,
  } = useTooltip<TooltipData>();

  const { containerRef, TooltipInPortal } = useTooltipInPortal({
    scroll: true,
  });

  // Handle of the most recent delayed hideTooltip() timer (0 before any is set).
  const tooltipTimeout = useRef<number>(0);

  // The tooltip causes a lot of re-renders, so only recompute the scales when their
  // inputs change. Deriving them with useMemo, rather than copying them into state
  // from an effect, keeps them in sync with props.data on the same render. There is
  // no empty first render, and no render that draws new data with the previous scales.
  const chartState = useMemo(
    () => computeChartState(props),
    // computeChartState only reads categories, data, height and width.
    // eslint-disable-next-line react-hooks/exhaustive-deps
    [props.categories, props.data, props.height, props.width],
  );

  // On unmount, cancel any pending delayed hide so hideTooltip() cannot run afterwards.
  useEffect(() => {
    const timeoutRef = tooltipTimeout;
    return () => window.clearTimeout(timeoutRef.current);
  }, []);

  if (props.width < 10) {
    return null;
  }

  return (
    <div style={{ position: 'relative', display: 'flex' }}>
      <svg height={props.height} ref={containerRef} width={props.width}>
        <rect fill={props.bgColor} height={props.height} rx={3} width={props.width} x={0} y={0} />
        <GridRows
          left={-8}
          numTicks={4}
          scale={chartState.yScale}
          stroke={props.axisColor}
          top={0}
          width={chartState.xMax}
        />
        <BarStack<Datum, string>
          color={chartState.colorScale}
          data={props.data}
          keys={chartState.keys}
          x={xLabel}
          xScale={chartState.xScale}
          yScale={chartState.yScale}>
          {(barStacks) => barStacks.map((barStack) => barStack.bars.map((bar) => (
            <rect
              fill={bar.color}
              height={bar.height}
              key={`bar-stack-${barStack.index}-${bar.index}`}
              onMouseLeave={() => {
                // Hide after a short delay. Moving onto another bar within that time
                // cancels the hide in onMouseMove, keeping the tooltip open.
                window.clearTimeout(tooltipTimeout.current);
                tooltipTimeout.current = window.setTimeout(() => {
                  hideTooltip();
                }, 300);
              }}
              onMouseMove={(event) => {
                window.clearTimeout(tooltipTimeout.current);
                // Horizontally centered on the bar; vertically at the pointer.
                const eventSvgCoords = localPoint(event);
                showTooltip({
                  tooltipData: bar,
                  tooltipTop: eventSvgCoords?.y,
                  tooltipLeft: bar.x + bar.width / 2,
                });
              }}
              width={bar.width}
              x={bar.x}
              y={bar.y}
            />
          )))}
        </BarStack>
        <AxisBottom
          scale={chartState.xScale}
          stroke={props.axisColor}
          tickLabelProps={() => ({
            fill: props.labelColor,
            fontSize: 10,
            textAnchor: 'middle',
          })}
          tickStroke={props.axisColor}
          top={chartState.yMax}
        />
        <AxisRight
          hideAxisLine
          hideTicks
          left={chartState.xMax}
          numTicks={4}
          scale={chartState.yScale}
          stroke={props.axisColor}
          tickLabelProps={() => ({
            fill: props.labelColor,
            fontSize: 10,
            textAnchor: 'middle',
          })}
          tickStroke={props.axisColor}
        />
      </svg>
      {tooltipOpen && tooltipData && (
        <TooltipInPortal left={tooltipLeft} style={tooltipStyles} top={tooltipTop}>
          <div style={{ color: chartState.colorScale(tooltipData.key) }}>
            <strong>{tooltipData.key}</strong>
          </div>
          <div>{`${tooltipData.bar.data[tooltipData.key]} ${props.unit}`}</div>
          <div>
            <small>{tooltipData.bar.data.label}</small>
          </div>
        </TooltipInPortal>
      )}
    </div>
  );
};

export default BarStackChart;
