// ScatterPlot.tsx

import { Margin } from '@nivo/core';
import { Theme } from '@nivo/core';
import {
  ResponsiveScatterPlot,
  ScatterPlotCustomSvgLayer,
  ScatterPlotDatum,
  ScatterPlotRawSerie,
} from '@nivo/scatterplot';
import { extent } from 'd3-array';
import { scaleOrdinal, scaleSqrt } from 'd3-scale';
import React, { useContext, useEffect, useState } from 'react';
import { ThemeContext } from 'styled-components';
import {
  DEFAULT_LEGEND_CONFIG,
  DEFAULT_TITLE_SIZE,
  DEFAULT_X_AXIS_STYLE,
  DEFAULT_Y_AXIS_STYLE,
} from '@dataviz/constants';
import { getDatavizTheme } from '@plotting/single-plot-view/plot-panel/plot.themes';
import { KeyColor, AnyScaleSpec } from '@plotting/single-plot-view/plot.types';
import { COLORS } from '@utils/scales/color/ColorSchemes';
import { getScatterplotNode } from './getScatterplotNode';
import { getScatterplotRegressionLayer } from './getScatterplotRegressionLayer';
import { getScatterplotTooltip } from './getScatterplotTooltip';
import { runLinearRegression } from './runLinearRegression';

/** Chart margins applied when the `margin` prop is not supplied. */
const DEFAULT_MARGIN = { top: 60, right: 160, bottom: 90, left: 116 };

/**
 * Node size used when no `circleSize` prop is given, and for nodes that carry
 * no `size` value when `circleSize` is not a single number.
 */
export const DEFAULT_CIRCLE_SIZE = 6;

/** One plotted point: nivo's datum plus the extra fields this file reads. */
type ScatterPlotPoint = ScatterPlotDatum & { size: number; color?: string };

/** Props accepted by `ScatterPlot`. */
type ScatterPlotProps = {
  /**
   * Series to plot. Points whose `x` or `y` is null, undefined or an empty
   * string are removed before anything is computed or rendered.
   */
  data: ScatterPlotRawSerie<ScatterPlotPoint>[];
  /** Text drawn centred above the plotting area. */
  title?: string;
  /** Title font size; `DEFAULT_TITLE_SIZE` is used when falsy. */
  titleSize?: number;
  /** Legend of the bottom axis (defaults to 'X'). */
  xAxisName?: string;
  /** Requested x scale; only `type === 'log'` is consulted, and only for numeric data. */
  xAxisScale?: AnyScaleSpec;
  /** Legend of the left axis (defaults to 'Y'). */
  yAxisName?: string;
  /** Declared for callers, but not read by `ScatterPlot` in this file. */
  labelSize?: number;
  /** Requested y scale; only `type === 'log'` is consulted, and only for numeric data. */
  yAxisScale?: AnyScaleSpec;
  /** Complete nivo theme; when omitted one is built with `getDatavizTheme`. */
  datavizTheme?: Theme;
  /** Chart margins (defaults to `DEFAULT_MARGIN`). */
  margin?: Margin;
  /**
   * A single number gives every node that size; a `[min, max]` pair is the
   * output range of the square-root scale that maps each point's `size`.
   */
  circleSize?: number | [number, number];
  /** Adds a regression layer computed per series. */
  isLinearRegressionEnabled?: boolean;
  /** Shows the legend, but only when there is more than one series. */
  isLegendEnabled?: boolean;
  /** Per-series colour overrides, matched against the series id. */
  colorConfig?: KeyColor[];
  /** Passed to `getDatavizTheme` when `datavizTheme` is not provided. */
  axisLabelFontSize?: number;
  /** Passed to `getDatavizTheme` when `datavizTheme` is not provided. */
  tickLabelFontSize?: number;
};

/**
 * Centralized filtering function: keeps only the points that have a usable
 * value on both axes.
 *
 * A coordinate is rejected when it is `null`, `undefined` or the empty string.
 * Falsy-but-meaningful values such as `0` are kept.
 *
 * The signature is generic so the element type of the input survives the call
 * (the untyped version returned `any[]`).
 *
 * @param points Points to inspect; the array is not modified.
 * @returns A new array with the accepted points, in their original order.
 */
const filterDataPoints = <T extends { x?: unknown; y?: unknown }>(
  points: readonly T[]
): T[] =>
  points.filter(
    (item) =>
      item.x != null && item.y != null && item.x !== '' && item.y !== ''
  );

/**
 * Scatter plot rendered with nivo's `ResponsiveScatterPlot`.
 *
 * Pipeline, in order:
 *  1. points without a usable `x`/`y` are dropped (`filterDataPoints`);
 *  2. a square-root size scale and a per-series colour lookup are built;
 *  3. optionally, one linear regression per series is computed and exposed as
 *     an extra layer;
 *  4. the x/y scale specs are derived from the first available point;
 *  5. for axes that resolved to a log scale, numeric values `<= 0` are removed.
 *
 * @throws Error ('Data is empty or invalid') when no series contains a valid
 *   point, because the scale types cannot be derived.
 */
export const ScatterPlot = ({
  data,
  title,
  titleSize,
  xAxisName = 'X',
  xAxisScale,
  yAxisName = 'Y',
  yAxisScale,
  datavizTheme,
  margin = DEFAULT_MARGIN,
  circleSize,
  isLinearRegressionEnabled,
  isLegendEnabled,
  colorConfig,
  axisLabelFontSize,
  tickLabelFontSize,
}: ScatterPlotProps) => {
  const { palette } = useContext(ThemeContext);
  const finalDatavizTheme =
    datavizTheme ??
    getDatavizTheme({ axisLabelFontSize, tickLabelFontSize }, palette);

  // Inner chart dimensions, reported back by the title layer and forwarded to
  // the tooltip. The initial values are placeholders until the first report.
  const [width, setWidth] = useState(800);
  const [height, setHeight] = useState(400);

  // The parameter default only covers `undefined`; this also covers `null`,
  // and the same value is used for the chart and for the title offset.
  const resolvedMargin = margin || DEFAULT_MARGIN;

  // Apply centralized filtering to all series (without reassigning the prop).
  const cleanData = data.map((series) => ({
    ...series,
    data: filterDataPoints(series.data),
  }));

  const plotTitleStyle = {
    fontFamily: finalDatavizTheme.fontFamily,
    fontSize: titleSize || DEFAULT_TITLE_SIZE,
    fill: finalDatavizTheme.textColor,
    textAnchor: 'middle',
  } as const;

  // Custom layer drawing the title centred above the plotting area.
  // The dimensions are published from an effect: updating this component's
  // state while the layer is rendering is a state update during another
  // component's render, which React warns about.
  // NOTE(review): relies on nivo rendering function layers as components —
  // confirm for the installed nivo version.
  const PlotTitle = ({
    innerWidth,
    innerHeight,
  }: {
    innerWidth: number;
    innerHeight: number;
  }) => {
    useEffect(() => {
      setWidth(innerWidth);
      setHeight(innerHeight);
    }, [innerWidth, innerHeight]);

    return (
      <text
        x={innerWidth / 2}
        y={-(resolvedMargin.top ?? 0) / 2}
        style={plotTitleStyle}
      >
        {title}
      </text>
    );
  };

  const hasMultipleSeries = cleanData.length > 1;

  // SIZES
  const allSizes = cleanData.flatMap((category) =>
    category.data.map((point) => point.size)
  );
  const sizeDomain = extent(allSizes);
  const sizeRange =
    typeof circleSize === 'undefined'
      ? [DEFAULT_CIRCLE_SIZE, DEFAULT_CIRCLE_SIZE]
      : typeof circleSize === 'number'
      ? [circleSize, circleSize]
      : circleSize;
  const sizeScale = scaleSqrt().domain(sizeDomain).range(sizeRange);

  // COLORS
  const groups = [...new Set(cleanData.map((d) => d.id))];
  const defaultColors = COLORS.find((col) => col.id === 'aseda');
  const defaultColorScale = scaleOrdinal<string>()
    .domain(groups.map((g) => String(g)))
    .range(defaultColors.scheme);

  // Resolves the colour of a series: an explicit `colorConfig` entry wins,
  // otherwise the default ordinal scale is used. When a series id is the
  // literal string 'undefined' (presumably ungrouped data — verify against the
  // caller), the first configured colour, or else the first default colour,
  // is used instead of the scale.
  const getColorFromGroup = (name: string) => {
    if (name && name !== 'undefined') {
      const configuredColor = colorConfig?.find(
        (config) => config.id === name
      )?.color;
      if (configuredColor) {
        return configuredColor;
      }
    }
    // Fallback to default color
    if (!groups.includes('undefined')) {
      return defaultColorScale(name);
    }
    return colorConfig && colorConfig.length > 0
      ? colorConfig[0].color
      : defaultColors.scheme[0];
  };

  // LINEAR REGRESSION
  let linearRegressionLayer: ScatterPlotCustomSvgLayer<
    ScatterPlotDatum & { size: number }
  > = () => null;

  if (isLinearRegressionEnabled) {
    const linearRegressionResults = cleanData.map((grp) => ({
      id: grp.id,
      regression: runLinearRegression(grp.data),
    }));

    linearRegressionLayer = getScatterplotRegressionLayer(
      linearRegressionResults,
      getColorFromGroup
    );
  }

  // SCALES
  // Derives the nivo scale spec of each axis from the type of the first point.
  const determineScales = (): {
    xAxisScale: AnyScaleSpec;
    yAxisScale: AnyScaleSpec;
  } => {
    // Use the first series that still has points: the leading series may have
    // been emptied by the filtering above while later ones are plottable.
    const firstDataPoint = cleanData.find(
      (series) => series.data.length > 0
    )?.data[0];

    if (!firstDataPoint) {
      throw new Error('Data is empty or invalid');
    }

    const determineScaleType = (
      axisType: AnyScaleSpec | undefined,
      value: unknown
    ): AnyScaleSpec => {
      if (typeof value === 'number') {
        // A log scale is only honoured for numeric data.
        return axisType?.type === 'log'
          ? { type: 'log', base: 10, min: 'auto', max: 'auto' }
          : { type: 'linear', min: 'auto', max: 'auto' };
      }
      if (value instanceof Date) {
        return { type: 'time', format: 'native', min: 'auto', max: 'auto' };
      }
      // Strings and anything unrecognised are treated as categories.
      return { type: 'point' };
    };

    return {
      xAxisScale: determineScaleType(xAxisScale, firstDataPoint.x),
      yAxisScale: determineScaleType(yAxisScale, firstDataPoint.y),
    };
  };

  const scales = determineScales();

  // Only numbers on a log scale can be invalid (they must be > 0).
  const isValidForScale = (value: unknown, scaleType?: string) => {
    if (scaleType === 'log' && typeof value === 'number') {
      return value > 0;
    }
    return true; // Always valid for non-logarithmic scales
  };

  // Filter against the scales actually handed to nivo, so points are only
  // dropped when the axis really is logarithmic.
  const filteredData = cleanData.map((series) => ({
    ...series,
    data: series.data.filter(
      (point) =>
        isValidForScale(point.y, scales.yAxisScale.type) &&
        isValidForScale(point.x, scales.xAxisScale.type)
    ),
  }));

  return (
    <ResponsiveScatterPlot<ScatterPlotDatum & { size: number }>
      data={filteredData}
      margin={resolvedMargin}
      theme={finalDatavizTheme}
      layers={[
        PlotTitle,
        'grid',
        'axes',
        'nodes',
        'markers',
        linearRegressionLayer,
      ]}
      xScale={scales.xAxisScale}
      yScale={scales.yAxisScale}
      axisBottom={{ ...DEFAULT_X_AXIS_STYLE, legend: xAxisName }}
      axisLeft={{ ...DEFAULT_Y_AXIS_STYLE, legend: yAxisName }}
      enableGridX={true}
      enableGridY={true}
      animate={false}
      tooltip={getScatterplotTooltip({ innerWidth: width, innerHeight: height })}
      legends={
        hasMultipleSeries && isLegendEnabled
          ? [DEFAULT_LEGEND_CONFIG]
          : undefined
      }
      nodeComponent={getScatterplotNode}
      nodeSize={(node) => {
        // Nodes without a (truthy) size fall back to a fixed size.
        return node.data.size
          ? sizeScale(node.data.size)
          : typeof circleSize === 'number'
          ? circleSize
          : DEFAULT_CIRCLE_SIZE;
      }}
      colors={(group) => {
        return getColorFromGroup(String(group.serieId));
      }}
      useMesh={false}
    />
  );
};
