import { useEffect, MutableRefObject } from "react";
import * as d3 from "d3";
import { CompensationModel } from "@modules/compensation/model";
import { getAmountVestedInYear } from "@modules/compensation/api";

/** Inputs accepted by `useGraphHook`. */
interface GraphHookOptions {
  /** Ref to the `<svg>` element the chart is rendered into; may hold null. */
  d3Container: MutableRefObject<SVGSVGElement | null>;
  /** Compensation model whose projected total compensation is plotted. */
  data: CompensationModel;
}

/**
 * Total compensation for the calendar year of `date`: base salary, plus the
 * target bonus (a percentage of base), plus the equity from every grant that
 * `getAmountVestedInYear` reports as vesting in that year.
 *
 * @param data - Compensation model (base, bonus %, share price, grants).
 * @param date - Only its calendar year (`getFullYear()`) is used.
 * @returns The summed compensation amount.
 */
function totalComp(data: GraphHookOptions["data"], date: Date): number {
  const { baseSalary, targetBonusPercentage } = data;
  const cash = baseSalary + (baseSalary * targetBonusPercentage) / 100;

  let total = cash;
  data.equityGrants.forEach((grant) => {
    total += getAmountVestedInYear(data.sharePrice, grant, date.getFullYear());
  });
  return total;
}

/**
 * Space reserved on each side of the plotting area. The chart group is
 * translated by `left`/`top`, and all four values are subtracted from the SVG's
 * client size to get the drawable width and height (the axes render into it).
 */
const margin = {
  top: 10,
  right: 30,
  bottom: 30,
  left: 60,
};

/** One plotted point: a date (only its year matters) and the total comp for that year. */
interface DataPoint {
  date: Date;
  totalComp: number;
}

/**
 * Builds the plotted series: one point per consecutive calendar year, starting
 * at the current year.
 *
 * @param data - Compensation model to project.
 * @param years - Number of years to include (defaults to 4).
 * @returns One `DataPoint` per year, in ascending year order.
 */
function buildDataset(data: CompensationModel, years = 4): DataPoint[] {
  const currentYear = new Date().getFullYear();
  return Array.from({ length: years }, (_, i) => {
    // Only the year is read downstream (getFullYear()); month/day are arbitrary.
    const date = new Date(currentYear + i, 1, 1);
    return { date, totalComp: totalComp(data, date) };
  });
}

/**
 * React hook that renders a projected total-compensation chart (points, a
 * connecting line, and x/y axes) with d3 into the SVG held by `d3Container`.
 *
 * The first effect creates a margin-translated `g.chartGroup` inside the SVG
 * and removes it on cleanup. The second effect redraws everything inside that
 * group whenever `data` or `d3Container` changes.
 *
 * @param options.data - Compensation model to plot; nothing is drawn if falsy.
 * @param options.d3Container - Ref to the target `<svg>`; nothing is drawn while null.
 * @throws Error("Invalid date extent") if the year extent cannot be computed.
 */
export const useGraphHook = ({ data, d3Container }: GraphHookOptions) => {
  // Setup: create the group that every other element is drawn into.
  useEffect(() => {
    const group = d3
      .select(d3Container.current)
      .append("g")
      .attr("transform", `translate(${margin.left}, ${margin.top})`)
      .classed("chartGroup", true);

    // Without this, re-running the effect (ref change, React StrictMode's
    // double invocation) would leave stale duplicate groups in the SVG.
    return () => {
      group.remove();
    };
  }, [d3Container]);

  // Redraw every time the data changes.
  useEffect(() => {
    const svgElement = d3Container.current;
    if (!data || !svgElement) {
      return;
    }

    const dataset = buildDataset(data);

    const [minYear, maxYear] = d3.extent(dataset, (d) => d.date.getFullYear());
    if (minYear === undefined || maxYear === undefined) {
      throw new Error("Invalid date extent");
    }

    const svg = d3.select(svgElement);
    // Scoped to this SVG: a document-wide d3.select(".chartGroup") would pick
    // another chart's group when several graphs are mounted.
    const chartGroup = svg.select<SVGGElement>(".chartGroup");

    const innerWidth = svgElement.clientWidth - margin.left - margin.right;
    const innerHeight = svgElement.clientHeight - margin.top - margin.bottom;

    // X axis: linear scale over calendar years (ticks formatted as integers below).
    const xScale = d3
      .scaleLinear()
      .domain([minYear, maxYear])
      .range([0, innerWidth]);

    // Y axis: 0 .. (max total comp rounded up to a multiple of yStep, plus one
    // extra yStep), so the highest point is not flush with the top edge.
    const yStep = 20000;
    const yMax =
      d3.max(dataset, (d) => Math.ceil(d.totalComp / yStep + 1) * yStep) ?? 0;
    const yScale = d3.scaleLinear().domain([0, yMax]).range([innerHeight, 0]);

    const xPos = (d: DataPoint) => xScale(d.date.getFullYear());
    const yPos = (d: DataPoint) => yScale(d.totalComp);

    // Points: create missing circles, position new + existing ones, drop extras.
    const circles = chartGroup
      .selectAll<SVGCircleElement, DataPoint>("circle")
      .data(dataset);
    circles
      .enter()
      .append("circle")
      .merge(circles)
      .attr("cx", xPos)
      .attr("cy", yPos)
      .attr("r", 5)
      .attr("fill", "steelblue");
    circles.exit().remove();

    // Line: a single path bound to the whole dataset.
    const line = d3.line<DataPoint>().x(xPos).y(yPos);
    const path = chartGroup
      .selectAll<SVGPathElement, DataPoint[]>(".line")
      .data([dataset]);
    path
      .enter()
      .append("path")
      .attr("class", "line")
      .attr("fill", "none")
      .attr("stroke", "steelblue")
      .attr("stroke-width", 1.5)
      .merge(path)
      .attr("d", line);
    path.exit().remove();

    // Axes are re-rendered from scratch. Both previous axis groups must be
    // removed first, otherwise each update would stack another axis.
    chartGroup.selectAll(".x-axis, .y-axis").remove();
    chartGroup
      .append("g")
      .attr("transform", `translate(0,${innerHeight})`)
      .call(
        d3.axisBottom(xScale).ticks(dataset.length).tickFormat(d3.format("d"))
      )
      .classed("x-axis", true);
    chartGroup.append("g").call(d3.axisLeft(yScale)).classed("y-axis", true);
  }, [data, d3Container]);
};
