import { ScatterPlotDatum, ScatterPlotLayerProps } from '@nivo/scatterplot';
import { line } from 'd3-shape';
import React, { useState } from 'react';
import { LinearRegressionResult } from './runLinearRegression';

/**
 * Builds a custom nivo scatterplot layer that draws one regression line per
 * group and, while a line is hovered, shows that group's slope, intercept and
 * R2 in the top-left corner of the plot area.
 *
 * NOTE(review): every call to this factory creates a brand-new component type.
 * If it is invoked during each render of the parent, React will remount the
 * layer and the hover state will be lost — call it once (e.g. inside
 * `useMemo`) and reuse the result.
 *
 * @param linearRegressionResults - one regression result per group. `id` is
 *   used as the React key and, stringified, as the key for hover lookup and
 *   colour lookup.
 * @param getColorFromGroup - maps a stringified group id to the colour used
 *   for that group's line and detail text.
 * @returns a layer component to be passed in nivo's `layers` prop.
 */
export const getScatterplotRegressionLayer = (
  linearRegressionResults: {
    id: string | number;
    regression: LinearRegressionResult;
  }[],
  getColorFromGroup: (groupId: string) => string | undefined
) => {
  const ScatterplotRegressionLayer = ({
    xScale,
    yScale,
  }: ScatterPlotLayerProps<ScatterPlotDatum & { size: number }>) => {
    // Each regression line is drawn across the x-scale's full domain.
    const domainX = xScale.domain();

    // Stringified id of the hovered group; null when nothing is hovered.
    const [selectedGroup, setSelectedGroup] = useState<string | null>(null);

    const allLines = linearRegressionResults.map((group) => {
      const { slope, intercept } = group.regression;
      const groupId = String(group.id);

      // Map domain x-values to pixel coordinates along y = slope * x + intercept.
      const lineGenerator = line<number>()
        .x((x) => xScale(x))
        .y((x) => yScale(slope * x + intercept));

      const color = getColorFromGroup(groupId);

      return (
        <g key={groupId}>
          <path
            // line() yields null for empty input; React's `d` expects undefined.
            d={lineGenerator(domainX) ?? undefined}
            stroke={color}
            strokeWidth={2}
            fill={'none'}
            onMouseEnter={() => setSelectedGroup(groupId)}
            onMouseLeave={() => setSelectedGroup(null)}
            // Only the stroke should react to hover; the unfilled interior must not.
            pointerEvents='visibleStroke'
            cursor={'pointer'}
          />
        </g>
      );
    });

    /** Renders the hovered group's regression statistics, or nothing. */
    const renderRegressionDetail = () => {
      if (selectedGroup === null) {
        return null;
      }
      // Ids may be numbers, while the hover state always holds a string, so
      // compare on the stringified id.
      const selected = linearRegressionResults.find(
        (d) => String(d.id) === selectedGroup
      );
      if (!selected) {
        return null;
      }
      const { slope, intercept, r2 } = selected.regression;

      return (
        <g fill={getColorFromGroup(selectedGroup)} fontSize={14}>
          <text x={30} y={30}>Linear Regression</text>
          <text x={30} y={50}>
            {`Slope: ${slope.toFixed(2)}`}
          </text>
          <text x={30} y={70}>
            {`Intercept: ${intercept.toFixed(2)}`}
          </text>
          <text x={30} y={90}>
            {`R2: ${r2.toFixed(2)}`}
          </text>
        </g>
      );
    };

    return (
      <>
        {allLines}
        {renderRegressionDetail()}
      </>
    );
  };

  return ScatterplotRegressionLayer;
};
