import { StimChannelConfiguration } from '@egzotech/exo-session/features/electrostim';
import { logger } from 'helpers/logger';
import { EMSExerciseDefinition } from 'libs/exo-session-manager/core';

import { BaseChartTimelineGenerator } from './BaseTimelineChartGenerator';
import { BaseTimeLinePoint, ChannelIndex, ChannelMap } from './types';

/** Divisor applied to program durations: converts microseconds (us) to seconds. */
const TIME_REDUCTOR = 1e6;

/** A chart point that also remembers the index of the program phase that produced it. */
type ElectrostimTimeLinePoint = BaseTimeLinePoint & { phaseIndex: number };

/** Everything needed to render the timeline of a single output channel. */
type PreparedElectrostimData = {
  /** Output channel this data is rendered for. */
  channelIndex: ChannelIndex;
  /** Total program duration, in the program's raw units (divided by TIME_REDUCTOR when rendered). */
  programTime: number;
  /** Fully resolved stimulation settings, indexed by phase index. */
  phases: StimChannelConfiguration[];
};

/** Prepared data for every output channel, keyed by channel index. */
type PreparedElectrostimProgram = ChannelMap<PreparedElectrostimData>;

/**
 * Builds per-channel chart timelines for an EMS (electrostimulation) exercise program.
 *
 * Each timeline is a list of points whose `time` is the program duration divided by
 * TIME_REDUCTOR. Each point marks the end of an envelope segment: value 1 after rise/plateau
 * segments, 0 after delay/fall/pause segments.
 */
export class ElectrostimChartTimelineGenerator extends BaseChartTimelineGenerator<ElectrostimTimeLinePoint> {
  constructor() {
    super();
  }

  dispose() {
    super.dispose();
  }

  /**
   * Replaces the generated timelines with ones computed from `exerciseDefinition.ems.program`.
   *
   * @param exerciseDefinition - exercise whose EMS program is rendered.
   * @param channelMapping - maps each output channel index to the source channel index in the
   *   program whose settings that output channel uses.
   */
  setProgram(exerciseDefinition: EMSExerciseDefinition, channelMapping: Record<number, number>) {
    const preparedProgram = this.prepareElectrostimProgram(exerciseDefinition, channelMapping);

    logger.info('ElectrostimChartTimelineGenerator.setProgram', 'preparedProgram', preparedProgram);
    this._timelines = this.analyzeElectrostimProgram(preparedProgram);
  }

  /**
   * Resolves the stimulation settings of every phase for every mapped output channel.
   *
   * A phase's settings are built in three steps:
   * - start from the source channel's entry in `defaultChannelValues`;
   * - rewrite its `channelIndex` to the output channel;
   * - overlay `phase.channels[sourceChannel]` on top.
   *
   * Output channels whose source channel has no default values are skipped. No channel is
   * produced when the program has no phases.
   */
  private prepareElectrostimProgram(exerciseDefinition: EMSExerciseDefinition, channelMapping: Record<number, number>) {
    const prepared: PreparedElectrostimProgram = new ChannelMap();
    const emsProgram = exerciseDefinition.ems.program;
    const programTime = emsProgram.programTime;
    // Loop-invariant: the same phase list is resolved for every channel.
    const phaseEntries = Object.entries(emsProgram.phases);

    for (const [channelKey, sourceChannel] of Object.entries(channelMapping)) {
      const channelIndex = parseInt(channelKey) as ChannelIndex;

      const defaultChannelValues = emsProgram.defaultChannelValues.find(v => v.channelIndex == sourceChannel);
      if (!defaultChannelValues || phaseEntries.length === 0) {
        continue;
      }
      const updatedChannelValues = { ...defaultChannelValues, channelIndex };

      const channelPhases: StimChannelConfiguration[] = [];
      for (const [phaseKey, phase] of phaseEntries) {
        const stimChannel = {
          ...updatedChannelValues,
          ...phase.channels[sourceChannel],
        };
        channelPhases[Number(phaseKey)] = stimChannel;
      }

      prepared.set(channelIndex, {
        programTime,
        channelIndex,
        phases: channelPhases,
      });
    }
    return prepared;
  }

  /**
   * Converts every prepared channel into a timeline of points.
   *
   * A channel's phases are rendered back to back, and the whole phase sequence is repeated
   * until the timeline reaches the program time. The last pass may overshoot it.
   *
   * The sequence is rendered only once in two cases:
   * - the program time is missing or not positive;
   * - a pass does not move time forward, so repeating it could never reach the program time.
   */
  private analyzeElectrostimProgram(program: PreparedElectrostimProgram) {
    const timelines = new ChannelMap<ElectrostimTimeLinePoint[]>();

    for (const channelData of program.values()) {
      const channelIndex = channelData.channelIndex as ChannelIndex;
      const programTime = channelData.programTime / TIME_REDUCTOR;

      if (!timelines.channelExists(channelIndex)) {
        timelines.set(channelIndex, []);
      }
      const timeline = timelines.get(channelIndex);
      // Continue after any points already present for this channel.
      let channelTimeOffset = timeline.at(-1)?.time ?? 0;

      do {
        const passStart = channelTimeOffset;
        for (const [phaseKey, phase] of Object.entries(channelData.phases)) {
          const points = this.electrostimConvertPhaseToPoints(phase, Number(phaseKey), channelTimeOffset);
          channelTimeOffset = points.at(-1)?.time ?? channelTimeOffset;
          // Push one by one: spreading a very long phase could exceed the argument limit.
          for (const point of points) {
            timeline.push(point);
          }
        }
        // Also covers NaN times and an empty phase list; without this the loop would never end.
        if (!(channelTimeOffset > passStart)) {
          break;
        }
      } while (programTime > 0 && channelTimeOffset < programTime);
    }
    return timelines;
  }

  /**
   * Renders one phase of one channel as envelope points.
   *
   * The phase repeats the cycle delay -> rise -> plateau -> fall -> pause until `runTime`
   * elapses. The segment that crosses `runTime` is truncated there.
   *
   * The output starts with a value-0 point at `timeOffset`, followed by one point per segment
   * end: value 1 after rise/plateau, 0 after delay/fall/pause. If one cycle has no positive
   * total length, the phase is drawn instead as a constant value of 1 from `timeOffset` to
   * `timeOffset + runTime`.
   *
   * @param stimChannel - phase settings; its durations are divided by TIME_REDUCTOR.
   * @param phaseIndex - index stored on every produced point.
   * @param timeOffset - time at which the phase starts.
   * @returns points with absolute times (`timeOffset` already added).
   */
  private electrostimConvertPhaseToPoints(
    stimChannel: StimChannelConfiguration,
    phaseIndex: number,
    timeOffset = 0,
  ): ElectrostimTimeLinePoint[] {
    const dataSet = {
      delayTime: stimChannel.delay / TIME_REDUCTOR,
      riseTime: stimChannel.riseTime / TIME_REDUCTOR,
      plateauTime: (stimChannel.plateauTime ?? 0) / TIME_REDUCTOR,
      fallTime: stimChannel.fallTime / TIME_REDUCTOR,
      // Every cycle starts with delayTime again. Subtracting the delay here makes the gap from
      // the end of one fall to the next rise ((pause - delay) + delay) equal the pause time.
      pauseTime: ((stimChannel.pauseTime ?? 0) - stimChannel.delay) / TIME_REDUCTOR,
    };
    // Object key insertion order fixes the order of segments within a cycle.
    const dataSetKeys = Object.keys(dataSet) as (keyof typeof dataSet)[];
    const calculatedTime = Object.values(dataSet).reduce((acc, val) => acc + val, 0);

    // Fall back to a single full cycle when the phase has no run time set.
    const runTime = stimChannel.runTime != null ? stimChannel.runTime / TIME_REDUCTOR : calculatedTime;

    // A cycle without a positive total length (0, NaN or negative) cannot be repeated up to
    // runTime; repeating it could loop forever. Draw a constant "on" line instead.
    if (!(calculatedTime > 0)) {
      return [
        { time: timeOffset, value: 1, phaseIndex },
        { time: runTime + timeOffset, value: 1, phaseIndex },
      ];
    }

    const result: ElectrostimTimeLinePoint[] = [{ time: timeOffset, value: 0, phaseIndex }];
    let currentPhaseTime = 0;
    let index = 0;

    while (currentPhaseTime < runTime) {
      const key = dataSetKeys[index];
      // Clamp so the final (possibly partial) segment ends exactly at runTime.
      currentPhaseTime = Math.min(currentPhaseTime + dataSet[key], runTime);
      result.push({
        time: currentPhaseTime + timeOffset,
        value: key === 'riseTime' || key === 'plateauTime' ? 1 : 0,
        phaseIndex,
      });
      index = (index + 1) % dataSetKeys.length;
    }
    return result;
  }
}
