import { type ChartSerieTimeData } from '../useHighchartOptions'

/** Unit used to bucket chart data; passed to `getStartOfPeriod` to compute a bucket's start. */
export type GranularityType = 'day' | 'week' | 'month' | 'quarter'
/** How the values that fall into one bucket are combined into a single chart point. */
export type ChartMetricApproximation = 'average' | 'sum'

/**
 * Describes how `groupDataBasedOnGranularity` re-buckets a time series.
 */
export type GranularityInfo = {
  /** Aggregation applied per bucket; `groupDataBasedOnGranularity` falls back to `'sum'` when omitted. */
  approximation?: ChartMetricApproximation
  /** Bucket unit (`name`) and number of units per bucket (`count`), both forwarded to `getStartOfPeriod`. */
  unit: {
    // NOTE(review): narrower than GranularityType ('quarter' is not accepted here) — confirm this is intended
    name: 'day' | 'week' | 'month'
    count: number
  }
}

/**
 * Returns the UTC epoch-millisecond start of the bucket that `date` falls into.
 *
 * Each bucket is `count` units wide:
 * - `day`: UTC midnight. With `count > 1`, days are grouped into runs of
 *   `count` days aligned to the Unix epoch (1970-01-01).
 * - `week`: Monday 00:00 UTC (Sunday belongs to the week that started the
 *   previous Monday). With `count > 1`, weeks are grouped into runs of `count`
 *   weeks aligned to Monday 1970-01-05, so every date in the same run maps to
 *   the same key and the key is never later than `date`.
 * - `month`: the 1st of the month. Months are grouped within the calendar year
 *   so that bucket starts are month indexes divisible by `count`
 *   (e.g. `count = 3` yields calendar quarters).
 * - `quarter`: like `month`, with `3 * count` months per bucket.
 *
 * @param date - a Date or epoch-millisecond timestamp
 * @param period - bucket unit
 * @param count - units per bucket; values below 1 (or fractional parts) are floored and clamped to 1
 * @returns epoch milliseconds of the bucket start, at 00:00:00.000 UTC
 * @throws Error if `period` is not one of the known granularities
 */
function getStartOfPeriod(
  date: Date | number,
  period: GranularityType,
  count: number,
): number {
  const MS_PER_DAY = 24 * 60 * 60 * 1000
  const MS_PER_WEEK = 7 * MS_PER_DAY
  // 1970-01-05T00:00:00Z is the first Monday after the Unix epoch: origin of the week grid.
  const FIRST_MONDAY_MS = 4 * MS_PER_DAY
  const step = Math.max(1, Math.floor(count))
  // Floors `index` to a multiple of `step`; the double modulo keeps pre-1970 (negative) indexes correct.
  const alignDown = (index: number): number =>
    index - (((index % step) + step) % step)

  const parsedDate = new Date(date)
  // UTC has no DST, so after this every day boundary is an exact multiple of MS_PER_DAY.
  parsedDate.setUTCHours(0, 0, 0, 0)

  switch (period) {
    case 'day': {
      const dayIndex = Math.floor(parsedDate.getTime() / MS_PER_DAY)
      return alignDown(dayIndex) * MS_PER_DAY
    }
    case 'week': {
      const weekday = parsedDate.getUTCDay()
      // move back to Monday; Sunday (0) is the last day of the week that began 6 days earlier
      parsedDate.setUTCDate(
        parsedDate.getUTCDate() - (weekday === 0 ? 6 : weekday - 1),
      )
      const weekIndex = Math.floor(
        (parsedDate.getTime() - FIRST_MONDAY_MS) / MS_PER_WEEK,
      )
      return FIRST_MONDAY_MS + alignDown(weekIndex) * MS_PER_WEEK
    }
    case 'month':
    case 'quarter': {
      const monthsPerBucket = period === 'quarter' ? 3 * step : step
      const month = parsedDate.getUTCMonth()
      // setting month and day together avoids overflow (e.g. Jan 31 -> "Feb 31" -> Mar 2)
      parsedDate.setUTCMonth(month - (month % monthsPerBucket), 1)
      return parsedDate.getTime()
    }
    default: {
      // compile-time exhaustiveness check; still guards untyped callers at runtime
      const unknownPeriod: never = period
      throw new Error(
        `Invalid period '${String(unknownPeriod)}'. Choose from 'day', 'week', 'month', 'quarter'.`,
      )
    }
  }
}

/**
 * Re-buckets a time series to a coarser granularity for charting.
 *
 * The last `zoom` points of `serieData` select the visible window. Note that
 * `slice(-0)` keeps the whole series, so `zoom = 0` means "everything". The
 * window is then widened back to the start of the bucket containing its first
 * point, so that the first bucket is not partial. Every point of the full
 * `serieData` whose timestamp is at or after that instant is assigned to the
 * bucket returned by `getStartOfPeriod(timestamp, unit.name, unit.count)`.
 *
 * Each bucket yields one `[bucketStart, value]` point, where `value` is:
 * - `null` when every value in the bucket is `null`;
 * - otherwise the sum (`approximation = 'sum'`, the default) or the mean
 *   (any other approximation, i.e. `'average'`) of the non-null values.
 *
 * Buckets are emitted in the order they are first met in `serieData`.
 * NOTE(review): this is chronological only if the input is sorted by timestamp —
 * assumed, not checked here.
 *
 * @param serieData - series points; element 0 is the timestamp, element 1 the value or `null`
 * @param granularity - bucket unit/count and aggregation mode
 * @param zoom - number of trailing points that define the visible window
 * @returns the aggregated series, or the (empty) zoomed slice when there is nothing to group
 */
export function groupDataBasedOnGranularity(
  serieData: ChartSerieTimeData,
  { approximation = 'sum', unit }: GranularityInfo,
  zoom: number,
): ChartSerieTimeData {
  const serieDataForPeriod = serieData.slice(-1 * zoom)
  if (!serieDataForPeriod.length) {
    return serieDataForPeriod
  }

  const startOfPeriodBasedOnZoom = getStartOfPeriod(
    serieDataForPeriod[0][0],
    unit.name,
    unit.count,
  )

  // remove all data that precedes the bucket-aligned visible window
  const filteredData = serieData.filter(
    ([timestamp]) => timestamp >= startOfPeriodBasedOnZoom,
  )

  // Per bucket: running sum of the non-null values and how many there were.
  // Nulls are skipped entirely so they neither add to the sum nor count toward the average.
  // Map iteration order is insertion order, so buckets stay in first-seen order.
  const buckets = new Map<number, { sum: number; nonNullCount: number }>()
  for (const [timestamp, value] of filteredData) {
    const startOfPeriod = getStartOfPeriod(timestamp, unit.name, unit.count)
    const bucket = buckets.get(startOfPeriod) ?? { sum: 0, nonNullCount: 0 }

    if (value !== null) {
      bucket.sum += value
      bucket.nonNullCount++
    }

    buckets.set(startOfPeriod, bucket)
  }

  const result: ChartSerieTimeData = []
  for (const [startOfPeriod, { sum, nonNullCount }] of buckets) {
    if (nonNullCount === 0) {
      // only null values in this bucket: keep the gap instead of reporting 0
      result.push([startOfPeriod, null])
    } else {
      result.push([
        startOfPeriod,
        approximation === 'sum' ? sum : sum / nonNullCount,
      ])
    }
  }

  return result
}
