import Meyda from "meyda";
import { kmeans } from "ml-kmeans";
import { KMeansResult } from "ml-kmeans/lib/KMeansResult";
import { agnes } from "ml-hclust";
import { decibelToVolume, volumeToDecibel } from "./calcVolume";
import { downloadJson } from "./downloadJson";
import { isDebug } from "../pages/calibrationV2";

// [Debug only] Set to true to also download the intermediate results as JSON files.
const NEED_DOWNLOAD_JSON = false;

/**
 * Lowers a volume threshold by one decibel.
 *
 * The value is converted to decibels with `volumeToDecibel`, reduced by 1 and
 * converted back with `decibelToVolume`.
 *
 * @param threshold - Threshold expressed as a volume value.
 * @returns The threshold after the 1 dB reduction, expressed as a volume value.
 */
function calcThresholdMinus1Decibel(threshold: number): number {
  const decibelMargin = 1;
  return decibelToVolume(volumeToDecibel(threshold) - decibelMargin);
}

class HClustUtils {
  /**
   * Runs hierarchical clustering (`agnes`) on the feature vectors and returns
   * the mean vector of each cluster found directly below the root of the
   * resulting tree. The caller uses these as initial centroids for k-means.
   *
   * @param mfccResult - Feature vectors, one row per sample.
   * @returns One centroid (component-wise mean) per top-level cluster.
   */
  static analyze(mfccResult: number[][]): number[][] {
    // Only the children of the tree root are treated as clusters.
    const { children: topLevelClusters } = agnes(mfccResult);

    const initialCentroids: number[][] = [];
    for (const cluster of topLevelClusters) {
      const memberIndices: number[] = cluster.indices();
      // Component-wise sum over every member of this cluster.
      const sums: number[] = new Array(mfccResult[0].length).fill(0);
      for (const memberIndex of memberIndices) {
        const point = mfccResult[memberIndex];
        for (let dim = 0; dim < point.length; dim += 1) {
          sums[dim] += point[dim];
        }
      }
      initialCentroids.push(sums.map((sum) => sum / memberIndices.length));
    }

    if (isDebug) {
      console.log("initialCentroids", initialCentroids);
      if (NEED_DOWNLOAD_JSON) {
        downloadJson("initialCentroids", initialCentroids);
      }
    }

    return initialCentroids;
  }
}

class KMeansUtils {
  /**
   * Turns k-means cluster labels into binary labels.
   *
   * Centroids are ranked in ascending order of their first component. Samples
   * belonging to the lowest-ranked centroid are labelled 0, all others 1.
   *
   * @param kmeansResult - Result object returned by `kmeans`.
   * @returns One binary label (0 or 1) per sample, in the original sample order.
   */
  static reorderClusters(kmeansResult: KMeansResult): number[] {
    const { centroids, clusters } = kmeansResult;

    // Original centroid indices, ordered by ascending first component.
    const rankedOriginalIndices: number[] = centroids
      .map((centroid: number[], index: number) => ({ index, centroid }))
      .sort(
        (a: { centroid: number[] }, b: { centroid: number[] }) =>
          a.centroid[0] - b.centroid[0]
      )
      .map((entry: { index: number }) => entry.index);

    // rankOf[originalIndex] is the position of that centroid after sorting.
    const rankOf: number[] = [];
    rankedOriginalIndices.forEach((originalIndex, rank) => {
      rankOf[originalIndex] = rank;
    });

    // Collapse the ranks: rank 0 stays 0, everything else becomes 1.
    return clusters.map((label: number) => (rankOf[label] === 0 ? 0 : 1));
  }

  /**
   * Runs k-means seeded with the given centroids and returns binary labels.
   *
   * @param mfccResult - Feature vectors, one row per sample.
   * @param initialCentroids - Starting centroids; their count is the cluster count.
   * @returns Binary labels produced by {@link KMeansUtils.reorderClusters}.
   */
  static analyze(
    mfccResult: number[][],
    initialCentroids: number[][]
  ): number[] {
    const clusterCount = initialCentroids.length;
    const kmeansResult = kmeans(mfccResult, clusterCount, {
      initialization: initialCentroids,
    });
    const kmeansReorderClusters = this.reorderClusters(kmeansResult);

    if (isDebug) {
      console.log("kmeansResult", kmeansResult);
      console.log("kmeansReorderClusters", kmeansReorderClusters);

      if (NEED_DOWNLOAD_JSON) {
        downloadJson("kmeansCentroids", kmeansResult.centroids);
        downloadJson("kmeansReorderClusters", kmeansReorderClusters);
      }
    }

    return kmeansReorderClusters;
  }
}

class MfccUtils {
  /**
   * Standardization (z-score), applied independently to every column.
   *
   * Each value is first replaced by its absolute value; every column is then
   * shifted by its mean and divided by its population standard deviation.
   * A column whose standard deviation is 0 is mapped to 0 (instead of the NaN
   * that 0 / 0 would produce).
   *
   * @param features - Feature vectors (rows) of equal length.
   * @returns The standardized rows; an empty array when `features` is empty.
   */
  static standardize(features: number[][]): number[][] {
    // Nothing to do, and features[0] must not be read on an empty list.
    if (features.length === 0) {
      return [];
    }

    // Convert the features to absolute values.
    const absFeatures = features.map((feature) =>
      feature.map((value) => Math.abs(value))
    );
    const columnCount = absFeatures[0].length;
    const means: number[] = Array(columnCount).fill(0);
    const stdDevs: number[] = Array(columnCount).fill(0);

    // Column means.
    for (const feature of absFeatures) {
      for (let i = 0; i < feature.length; i++) {
        means[i] += feature[i];
      }
    }
    for (let i = 0; i < means.length; i++) {
      means[i] /= absFeatures.length;
    }

    // Column (population) standard deviations.
    for (const feature of absFeatures) {
      for (let i = 0; i < feature.length; i++) {
        stdDevs[i] += (feature[i] - means[i]) ** 2;
      }
    }
    for (let i = 0; i < stdDevs.length; i++) {
      stdDevs[i] = Math.sqrt(stdDevs[i] / absFeatures.length);
    }

    // Standardize; a constant column carries no information, so it becomes 0.
    return absFeatures.map((feature) =>
      feature.map((value, index) =>
        stdDevs[index] === 0 ? 0 : (value - means[index]) / stdDevs[index]
      )
    );
  }

  /**
   * Min-max scaling, applied independently to every column.
   *
   * Each column is mapped linearly so that its minimum becomes 0 and its
   * maximum becomes 1. A column whose minimum equals its maximum is mapped to
   * 0 (instead of the NaN that 0 / 0 would produce).
   *
   * @param features - Feature vectors (rows) of equal length.
   * @returns The scaled rows; an empty array when `features` is empty.
   */
  static minMaxScale(features: number[][]): number[][] {
    // Nothing to do, and features[0] must not be read on an empty list.
    if (features.length === 0) {
      return [];
    }

    const columnCount = features[0].length;
    const mins: number[] = Array(columnCount).fill(Infinity);
    const maxs: number[] = Array(columnCount).fill(-Infinity);

    // Column minima and maxima.
    for (const feature of features) {
      for (let i = 0; i < feature.length; i++) {
        mins[i] = Math.min(mins[i], feature[i]);
        maxs[i] = Math.max(maxs[i], feature[i]);
      }
    }

    // Scale; a zero-width range would divide by zero, so it becomes 0.
    return features.map((feature) =>
      feature.map((value, index) => {
        const range = maxs[index] - mins[index];
        return range === 0 ? 0 : (value - mins[index]) / range;
      })
    );
  }

  /**
   * Splits the recording into fixed-size frames, extracts the MFCC of every
   * frame with Meyda, and returns the standardized + min-max-scaled result
   * together with the per-frame peak volume.
   *
   * `normalizedMfccResult` and `maxVolumeListByMfcc` always have the same
   * length: a frame for which Meyda yields no MFCC is left out of both.
   *
   * @param recordedData - Recorded samples.
   * @param sr - Sample rate used to convert the frame length into samples.
   * @param detectLengthMilliSec - Requested frame length in milliseconds; the
   *   actual frame size is rounded to the nearest power of two in samples.
   * @returns The normalized MFCC rows, the frame length actually used (ms),
   *   and the maximum absolute sample value of each analysed frame.
   */
  static analyze(
    recordedData: Float32Array,
    sr: number,
    detectLengthMilliSec: number
  ): {
    normalizedMfccResult: number[][];
    actualMilliSecond: number;
    maxVolumeListByMfcc: number[];
  } {
    // NOTE: MFCC extraction with Meyda only works when bufferSize is a power of two.
    const bufferSize = Math.pow(
      2,
      Math.round(Math.log((sr * detectLengthMilliSec) / 1000) / Math.log(2))
    );
    // NOTE: bufferSize must be at least 512 (with fewer samples the MFCC result is null).
    if (bufferSize < 512) {
      console.error("bufferSize length must be more than 512.");
    }

    // Frame length that is actually analysed, in milliseconds.
    const actualMilliSecond = (bufferSize / sr) * 1000;
    // Per-frame peak volume, kept for verification.
    const maxVolumeListByMfcc: number[] = [];
    const mfccResults: number[][] = [];

    // NOTE(review): this bound leaves out the last complete frame and any
    // trailing partial frame. It is kept unchanged because the number of
    // frames is visible to callers — confirm whether it is intended.
    const frameCount = Math.floor(recordedData.length / bufferSize) - 1;
    for (let i = 0; i < frameCount; ++i) {
      const begin = i * bufferSize;
      let frame = recordedData.slice(begin, begin + bufferSize);
      // Zero-pad a short frame. With the bound above every slice is already
      // full-length, so this only acts as a guard.
      if (frame.length < bufferSize) {
        const padded = new Float32Array(bufferSize);
        padded.set(frame);
        frame = padded;
      }

      // Extract the MFCC with Meyda.
      const mfcc = Meyda.extract("mfcc", frame);
      if (!mfcc) {
        // Skip the frame entirely so both output lists stay index-aligned.
        continue;
      }
      mfccResults.push(mfcc as number[]);

      let maxValue = -Infinity;
      for (const sample of frame) {
        maxValue = Math.max(maxValue, Math.abs(sample));
      }
      maxVolumeListByMfcc.push(maxValue);
    }

    const standardizedMfccResult = this.standardize(mfccResults);
    const normalizedMfccResult = this.minMaxScale(standardizedMfccResult);

    if (isDebug) {
      console.log("maxVolumeListWithMfcc", maxVolumeListByMfcc);
      console.log("mfccResults", mfccResults);
      console.log("normalizedMfccResult", normalizedMfccResult);
      if (NEED_DOWNLOAD_JSON) {
        downloadJson("maxVolumeListWithMfcc", maxVolumeListByMfcc);
        downloadJson("mfccResults", mfccResults);
        downloadJson("normalizedMfccResult", normalizedMfccResult);
      }
    }

    return { normalizedMfccResult, actualMilliSecond, maxVolumeListByMfcc };
  }
}

/**
 * Derives a volume threshold from the per-frame maximum volumes that fall
 * inside a detected speech segment.
 *
 * Steps:
 * 1. `start` / `end` are mapped to indices of `maxVolumeList` with
 *    `floor(frame * actualMilliSecond / detectLengthMilliSec) + 1`.
 * 2. A detection threshold is taken as the minimum of the first 18 volumes
 *    from the start index, lowered by 1 dB.
 * 3. Within the segment, candidate values at or below that threshold are tried
 *    in ascending order; for the first candidate that has 30 consecutive
 *    volumes at or below it, the loudest volume of that run, lowered by 1 dB,
 *    is returned.
 * 4. If no candidate qualifies, the detection threshold from step 2 is returned.
 *
 * @param maxVolumeList - Per-frame maximum volumes.
 * @param detectLengthMilliSec - Divisor used when mapping frame indices.
 * @param actualMilliSecond - Multiplier used when mapping frame indices.
 * @param start - First frame of the speech segment.
 * @param end - Last frame of the speech segment.
 * @returns The resulting threshold as a volume value.
 */
function calcThreshold(
  maxVolumeList: number[],
  detectLengthMilliSec: number,
  actualMilliSecond: number,
  start: number,
  end: number
): number {
  // Stricter than audioRecorder's speech condition (17 consecutive frames).
  const onsetFrameCount = 17 + 1;
  const quietRunLength = 30;

  const toVolumeFrame = (frame: number): number =>
    Math.floor((frame * actualMilliSecond) / detectLengthMilliSec) + 1;
  const startFrame = toVolumeFrame(start);
  const endFrame = toVolumeFrame(end);

  // Pick a value that matches audioRecorder's speech condition so that speech
  // is detected inside the speech range.
  const onsetVolumes = maxVolumeList.slice(
    startFrame,
    startFrame + onsetFrameCount
  );
  const thresholdForDetectVolume = calcThresholdMinus1Decibel(
    Math.min(...onsetVolumes)
  );

  // From the speech range, find the smallest volume at which the utterance is
  // not interrupted.
  const speechVolumes = maxVolumeList.slice(startFrame, endFrame);

  // Index of the first run of `quietRunLength` volumes that are all <= limit,
  // or -1 when there is no such run.
  const firstQuietRunStart = (limit: number): number => {
    let runLength = 0;
    for (let i = 0; i < speechVolumes.length; i += 1) {
      runLength = speechVolumes[i] <= limit ? runLength + 1 : 0;
      if (runLength >= quietRunLength) {
        return i - quietRunLength + 1;
      }
    }
    return -1;
  };

  // Only values at or below the initial threshold are considered, smallest first.
  const candidates = speechVolumes
    .filter((volume) => volume <= thresholdForDetectVolume)
    .sort((a, b) => a - b);

  // Falls back to the initial threshold when no candidate qualifies.
  let thresholdForUtterance = thresholdForDetectVolume;
  for (const candidate of candidates) {
    const runStart = firstQuietRunStart(candidate);
    if (runStart === -1) {
      continue;
    }
    const loudestInRun = Math.max(
      ...speechVolumes.slice(runStart, runStart + quietRunLength)
    );
    thresholdForUtterance = calcThresholdMinus1Decibel(loudestInRun);
    break;
  }

  if (isDebug) {
    console.log("thresholdForDetectVolume", thresholdForDetectVolume);
    console.log("thresholdForUtterance", thresholdForUtterance);
  }

  return thresholdForUtterance;
}

/**
 * Helper that checks whether `count` consecutive entries, beginning at
 * `startIndex`, are all equal to 1.
 *
 * Positions past the end of the array never match, so a window that does not
 * fit yields `false`. A `count` of 0 yields `true`.
 *
 * @param indexes - Label sequence to inspect.
 * @param startIndex - Position of the first entry of the window.
 * @param count - Number of entries that must all be 1.
 * @returns `true` when every entry in the window is 1.
 */
function isConsecutiveOnes(
  indexes: number[],
  startIndex: number,
  count: number
): boolean {
  for (let offset = 0; offset < count; offset += 1) {
    const label = indexes[startIndex + offset];
    if (label === 1) {
      continue;
    }
    // First entry that is not 1 (or is out of range) ends the check.
    return false;
  }
  return true;
}

/**
 * Locates the speech segment in a sequence of binary labels.
 *
 * The start is the first index at which 5 consecutive 1s begin (scanning
 * forwards); the end is the last index of the last run of 5 consecutive 1s
 * (scanning backwards). Either value is -1 when no such run exists.
 *
 * @param indexes - Per-frame labels (1 marks a frame counted as speech).
 * @returns The `start` and `end` indices, each -1 when not found.
 */
function findSpeechSegment(indexes: number[]): {
  start: number;
  end: number;
} {
  const runForStart = 5; // number of consecutive 1s that marks the speech start
  const runForEnd = 5; // number of consecutive 1s that marks the speech end
  const total = indexes.length;

  // Forward scan for the speech start.
  let start = -1;
  for (let head = 0; head <= total - runForStart && start < 0; head += 1) {
    if (isConsecutiveOnes(indexes, head, runForStart)) {
      start = head;
    }
  }

  // Backward scan for the speech end.
  let end = -1;
  for (let head = total - runForEnd; head >= 0 && end < 0; head -= 1) {
    if (isConsecutiveOnes(indexes, head, runForEnd)) {
      end = head + runForEnd - 1; // index of the last 1 in the run
    }
  }

  return { start, end };
}

/**
 * Estimates a voice-detection volume threshold from a recording.
 *
 * Pipeline: MFCC extraction and normalization, hierarchical clustering to
 * obtain initial centroids, k-means to label every frame, search for the
 * speech segment in the labels, and finally threshold calculation from
 * `maxVolumeListState` over that segment.
 *
 * When no speech segment is found (`start` / `end` are -1) the threshold
 * calculation is skipped and the threshold is 0. A threshold that is not a
 * finite number (Infinity, -Infinity or NaN) is also reported as 0.
 *
 * @param recordedData - Recorded samples.
 * @param sr - Sample rate of `recordedData`.
 * @param maxVolumeListState - Per-frame maximum volumes used for the
 *   threshold calculation.
 * @param detectLengthMilliSec - Length of data analysed per MFCC frame (ms).
 * @returns The threshold, the per-frame peak volume paired with its cluster
 *   label, and the `[start, end]` frame indices of the detected segment.
 */
export function autoVad(
  recordedData: Float32Array,
  sr: number,
  maxVolumeListState: number[],
  detectLengthMilliSec = 10 // length of data analysed per MFCC frame (ms)
): {
  thresholdByAutoVad: number;
  maxVolumeListWithVadResult: { value: number; class: number }[];
  voiceDetectionSegmentByAutoVad: number[];
} {
  // Run MFCC and obtain the normalized data.
  const { normalizedMfccResult, actualMilliSecond, maxVolumeListByMfcc } =
    MfccUtils.analyze(recordedData, sr, detectLengthMilliSec);
  // Run hierarchical clustering to obtain the initial centers for k-means.
  const initialCentroids = HClustUtils.analyze(normalizedMfccResult);

  // Label every frame with k-means, seeded with those initial centers.
  const detectSpeechResultList = KMeansUtils.analyze(
    normalizedMfccResult,
    initialCentroids
  );

  const { start, end } = findSpeechSegment(detectSpeechResultList);
  // FIXME: 10 is the frame length (ms) of maxVolumeListState and should become
  // a shared parameter.
  const startFrame = Math.floor((start * actualMilliSecond) / 10);
  const endFrame = Math.floor((end * actualMilliSecond) / 10);

  // -1 is the "not found" sentinel. Feeding it into calcThreshold would turn
  // it into an arbitrary frame index, so the calculation is skipped instead.
  const speechFound = start >= 0 && end >= 0;
  const threshold = speechFound
    ? calcThreshold(
        maxVolumeListState,
        detectLengthMilliSec,
        actualMilliSecond,
        start,
        end
      )
    : 0;

  if (isDebug) {
    console.log("maxVolumeList", maxVolumeListState);
    console.log("actualMilliSecond", actualMilliSecond);
    console.log("startFrame", startFrame);
    console.log("endFrame", endFrame);
    if (NEED_DOWNLOAD_JSON) {
      const vadResult = {
        startFrame: startFrame,
        endFrame: endFrame,
        threshold: threshold,
        maxVolumeList: maxVolumeListState,
      };
      downloadJson("vadResult", vadResult);
    }
  }

  return {
    // Infinity, -Infinity and NaN all fall back to 0.
    thresholdByAutoVad: Number.isFinite(threshold) ? threshold : 0,
    maxVolumeListWithVadResult: maxVolumeListByMfcc.map((value, index) => ({
      value,
      class: detectSpeechResultList[index],
    })),
    voiceDetectionSegmentByAutoVad: [start, end],
  };
}
