import React from 'react';
import { curveCatmullRom, line } from 'd3-shape';
import { compile } from 'mathjs';
import {
  ScatterPlotCustomSvgLayer,
  ScatterPlotDatum,
  ScatterPlotLayerProps,
} from '@nivo/scatterplot/dist/types/types';
import { ScaleOrdinal } from 'd3-scale';
import { DrcCurve } from './dose-response-curve.types';

// Divisor applied to the plot's inner width (see getCurveLayer) to derive how
// many log-spaced intervals each curve is sampled on: roughly one interval per
// NUMBER_OF_INTERVALS pixels of plot width, not a fixed interval count.
const NUMBER_OF_INTERVALS = 100;

/**
 * Dose-response model expression, compiled once at module load so that
 * drawing a curve does not re-parse the formula for every sampled dose.
 */
const ic50CurveModel = compile(
  'min + (max - min) / (1 + (IC50 / dose) ^ slope)'
);

/**
 * Evaluate the dose-response model `min + (max - min) / (1 + (IC50 / dose) ^ slope)`
 * at a single dose.
 *
 * @param min - asymptote approached as `(IC50 / dose) ^ slope` grows without bound
 * @param max - asymptote approached as `(IC50 / dose) ^ slope` tends to 0
 * @param IC50 - dose at which the response is exactly halfway between `min` and `max`
 * @param dose - dose to evaluate; assumed > 0 (callers sample it on a log grid)
 * @param slope - exponent controlling how steep the transition between asymptotes is
 * @returns the modeled response at `dose`
 */
const getModeledResponseFromDose = (
  min: number,
  max: number,
  IC50: number,
  dose: number,
  slope: number
): number => {
  return ic50CurveModel.evaluate({ min, max, IC50, dose, slope });
};

/**
 * Sample x values evenly spaced in log10 space across `domainX`.
 *
 * `steps` is the requested number of intervals. It may be fractional because
 * callers derive it from a pixel width. It is rounded up to a whole number,
 * and at least one interval is always used. The result therefore has
 * `count + 1` points, each distinct and evenly spaced in log space, starting
 * at `domainX[0]` and ending exactly at `domainX[1]`.
 *
 * Both domain bounds must be > 0; otherwise `Math.log10` yields NaN/-Infinity.
 *
 * @param steps - requested number of intervals (non-finite or < 1 is treated as 1)
 * @param domainX - `[start, end]` of the x domain
 * @returns the sampled x values, from `domainX[0]` to `domainX[1]`
 */
const getLogIntervals = (steps: number, domainX: [number, number]) => {
  const powMin = Math.log10(domainX[0]);
  const powMax = Math.log10(domainX[1]);
  // Guard against 0 / NaN / Infinity (e.g. a zero-width plot on first render),
  // which would otherwise produce NaN samples or an unbounded loop.
  const count = Number.isFinite(steps) ? Math.max(1, Math.ceil(steps)) : 1;
  const step = (powMax - powMin) / count;
  const intervals: number[] = [];
  // Stop one short of `count`: the upper bound is appended exactly below, so it
  // is neither duplicated nor subject to accumulated floating-point error.
  for (let i = 0; i < count; i++) {
    intervals.push(Math.pow(10, powMin + i * step));
  }
  intervals.push(Math.pow(10, powMax));
  return intervals;
};

/**
 * Build a custom Nivo scatter-plot layer that draws one fitted dose-response
 * curve for each entry in `curveData`.
 *
 * Each curve is evaluated on a log-spaced grid across the current x domain.
 * The sampled points are joined with a Catmull-Rom spline and stroked in the
 * colour `colorScale` assigns to the curve's `id`.
 *
 * @param curveData - fitted curves; each supplies `id`, `min`, `max`, `IC50` and `slope`
 * @param colorScale - maps a curve id to its stroke colour
 * @returns a layer renderer suitable for the scatter plot's custom `layers`
 */
export const getCurveLayer = (
  curveData: DrcCurve[],
  colorScale: ScaleOrdinal<string, string, never>
): ScatterPlotCustomSvgLayer<ScatterPlotDatum> => {
  const CurveLayer = (props: ScatterPlotLayerProps<ScatterPlotDatum>) => {
    const { innerWidth, xScale, yScale } = props;

    const domainX = xScale.domain();
    // NOTE(review): NUMBER_OF_INTERVALS is used as a pixel spacing here, so the
    // number of samples scales with plot width (one interval per 100px). The
    // spline interpolates between these points; confirm this density is enough
    // for steep slopes.
    const intervalDots = getLogIntervals(
      innerWidth / NUMBER_OF_INTERVALS,
      domainX
    );

    return (
      <>
        {curveData.map((curve) => {
          const { min, max, IC50, slope } = curve;

          const curveGenerator = line<number>()
            .x((dose) => xScale(dose))
            .y((dose) => {
              const modeledResponse = getModeledResponseFromDose(
                min,
                max,
                IC50,
                dose,
                slope
              );
              return yScale(modeledResponse);
            })
            .curve(curveCatmullRom);

          // d3's line generator returns null when there is nothing to draw;
          // React's `d` prop expects string | undefined.
          const pathData = curveGenerator(intervalDots) ?? undefined;

          // shape-rendering is an inherited SVG property, so the path picks up
          // geometricPrecision from this group.
          return (
            <g key={curve.id} shapeRendering='geometricPrecision'>
              <path
                d={pathData}
                stroke={colorScale(curve.id)}
                strokeWidth={2}
                fill={'none'}
              />
            </g>
          );
        })}
      </>
    );
  };

  return CurveLayer;
};


// const DoseResponseCurve = (
//   layer: LayerProps,
//   curve: DrcCurve,
//   color: string,
//   theme?: PlotTheme
// ) => {
//   const innerWidth = layer.innerWidth;
//   const xScale = layer.xScale.copy().interpolate(interpolateNumber); // override Nivo/d3 scale rounding
//   const yScale = layer.yScale.copy().interpolate(interpolateNumber); // override Nivo/d3 scale rounding
//   const { switches } = useContext(PlotContext);
//   const { showCurveSampling } = switches || {};
//   const ic50curve = compile('min + (max - min) / (1 + (IC50 / dose) ^ slope)');
//   const scope = { ...curve } as Dictionary;
//   const yMin = yScale.domain()[0];

//   const expression = (dose) => {
//     scope.dose = dose;
//     return ic50curve.evaluate(scope);
//   };

//   const curveGenerator = line()
//     .x((d) => xScale(d))
//     .y((d) => yScale(expression(d)))
//     .curve(curveNatural);
//   const lineGenerator = line<Coordinate>()
//     .x((d) => xScale(d.x))
//     .y((d) => yScale(d.y));
//   const domainY = yScale.domain();
//   const ic50 = [
//     { x: curve.IC50, y: domainY[0] },
//     { x: curve.IC50, y: domainY[1] },
//   ] as Coordinate[];

//   const getLogIntervals = (steps) => {
//     const domainX = xScale.domain();
//     const powMin = Math.log(domainX[0]) / Math.log(10);
//     const powMax = Math.log(domainX[1]) / Math.log(10);
//     const step = (powMax - powMin) / steps;
//     const intervals = [];
//     for (let i = 0; i <= steps; i++) {
//       intervals.push(Math.pow(10, powMin + i * step));
//     }
//     intervals.push(Math.pow(10, powMax));
//     return intervals;
//   };

//   const intervalDots = getLogIntervals(innerWidth / 16);
//   const pixelDots = showCurveSampling ? getLogIntervals(innerWidth / 4) : [];
//   const mapCoords = (intervals) =>
//     intervals.map((interval) => {
//       return [xScale(interval), yScale(expression(interval))];
//     });

//   const IC50 = curve.IC50;

//   const marker = (
//     <StyledGroup>
//       <g transform={`translate(${xScale(IC50)},${yScale(yMin) + 4})`}>
//         <text textAnchor='middle' style={{ fontSize: `${8}px` }} fill={color}>
//           ▲
//         </text>
//       </g>
//       <g transform={`translate(${xScale(IC50)},${yScale(yMin) + 14})`}>
//         <text
//           textAnchor='middle'
//           fill={color}
//           style={{
//             fontSize: theme?.fontSize,
//             fontFamily: theme?.fontFamily,
//           }}
//         >
//           {IC50}
//         </text>
//       </g>
//     </StyledGroup>
//   );

//   return (
//     <g key={curve.id} shapeRendering='geometricPrecision'>
//       <path
//         d={curveGenerator(intervalDots)}
//         stroke={color}
//         strokeWidth={
//           theme?.lines?.strokeWidth !== undefined
//             ? theme?.lines?.strokeWidth
//             : 2
//         }
//         fill={'none'}
//         shapeRendering='optimizeQuality'
//       />
//       <path
//         d={lineGenerator(ic50)}
//         stroke={color}
//         strokeWidth={
//           theme?.lines?.secondaryStrokeWidth !== undefined
//             ? theme?.lines?.secondaryStrokeWidth
//             : 2
//         }
//         fill={'none'}
//       />
//       {theme?.fontSize > 0 && marker}
//       {showCurveSampling &&
//         mapCoords(pixelDots).map((coord, index) => (
//           <g key={index} transform={`translate(${coord[0]},${coord[1]})`}>
//             <rect x={-0.5} y={-0.5} width={1.0} height={1.0} fill={'orange'} />
//           </g>
//         ))}
//       {showCurveSampling &&
//         mapCoords(intervalDots).map((coord, index) => (
//           <g key={index} transform={`translate(${coord[0]},${coord[1]})`}>
//             <rect x={-0.5} y={-0.5} width={1.0} height={1.0} fill={'black'} />
//           </g>
//         ))}
//     </g>
//   );
// };

// const StyledGroup = styled.g`
//   font-family: ${(props) => props.theme?.fonts?.main};
//   font-size: 11px;
//   font-weight: 100;
// `;
