import { createSlice } from "@reduxjs/toolkit";
import { Api } from "../../api/api";
import { Utils } from "../../utils/utils";
import {
  setIsWarningModalOpen,
  setModalWarning,
  setSelectedModel,
} from "./toolbar";

/**
 * Redux slice holding the saved models and the state of the current training
 * session.
 *
 * State:
 * - temporaryModel: id of the model trained in this session (set by
 *   createModelSuccess, cleared by clearTrainingModels).
 * - trainingModelTrained: true once a training run finished or a model was
 *   chosen for retraining.
 * - data: saved models listed for selection.
 * - fetching: true while the model list is being fetched.
 * - training: true while a training run is in progress.
 * - trainingTime: Date.now() timestamp while training; elapsed milliseconds
 *   after trainingFinish.
 * - trainingPatternsNumber: number of patterns used for the training run.
 * - retrainModel: model chosen for retraining, if any.
 */
const modelsSlice = createSlice({
  name: "models",
  initialState: {
    temporaryModel: null,
    trainingModelTrained: false,
    data: [],
    fetching: false,
    training: false,
    trainingTime: null,
    trainingPatternsNumber: null,
    retrainModel: null,
  },
  reducers: {
    fetchingModels: (state) => {
      state.data = [];
      state.fetching = true;
    },
    getModelsSuccess: (state, action) => {
      state.data = action.payload;
      state.fetching = false;
    },
    trainingStart: (state, action) => {
      state.training = true;
      // Start timestamp; converted to a duration by trainingFinish.
      state.trainingTime = Date.now();
      state.trainingPatternsNumber = action.payload;
    },
    trainingFinish: (state) => {
      state.training = false;
      state.trainingTime = Date.now() - state.trainingTime;
      state.trainingModelTrained = true;
    },
    clearTrainingModels: (state) => {
      state.temporaryModel = null;
      state.trainingModelTrained = false;
      state.retrainModel = null;
    },
    createModelSuccess: (state, action) => {
      state.temporaryModel = action.payload.id;
    },
    setModelToRetrain: (state, action) => {
      state.retrainModel = action.payload;
      state.trainingModelTrained = true;
    },
    // Backs the exported `trainingPatternsNumberSet` action creator:
    // createSlice only generates creators for reducers declared here, so
    // without this entry the export below would be undefined.
    trainingPatternsNumberSet: (state, action) => {
      state.trainingPatternsNumber = action.payload;
    },
  },
});

export const {
  fetchingModels,
  getModelsSuccess,
  trainingStart,
  trainingFinish,
  clearTrainingModels,
  createModelSuccess,
  trainingPatternsNumberSet,
  setModelToRetrain,
} = modelsSlice.actions;

export default modelsSlice.reducer;

// Actions

// Placeholder name sent when a model is created for training (see
// formatPatternsForApi / formatColdPatternsForApi). splitModels() treats any
// model still carrying this name as temporary (unsaved), and getModels
// deletes such models.
const DEFAULT_MODEL_NAME = "_saffronblue_default_model";

/**
 * Thunk: fetch the list of models for selection.
 *
 * Saved models go into the store. Temporary (default-named) models left over
 * from previous sessions are deleted. The toolbar's selected model is then
 * refreshed.
 *
 * @param {{ keepSelectedModel: boolean }} options - when true, keep the
 *   currently selected model if it still exists; otherwise select the first
 *   saved model.
 */
export const getModels =
  ({ keepSelectedModel }) =>
  async (dispatch, getState) => {
    dispatch(fetchingModels());
    const { models } = await Api.getModels();

    const [savedModels, temporaryModels] = splitModels(models);

    // Discard temporary models that are unsaved leftovers from previous
    // sessions. The model trained in the current session is still waiting to
    // be saved (onSaveModel reads it from models.temporaryModel), so keep it.
    const currentTemporaryModelId = getState().models.temporaryModel;
    temporaryModels
      .filter((model) => model.id !== currentTemporaryModelId)
      .forEach((model) => dispatch(removeModel(model.id)));
    dispatch(getModelsSuccess(savedModels));

    // Only saved models are listed, so pick the selection from that list.
    // Picking from all models could select a temporary model being deleted.
    const newSelectedModel = getSelectedModel(
      keepSelectedModel,
      savedModels,
      getState().toolbar.selectedModel
    );

    dispatch(setSelectedModel(newSelectedModel));
  };

/**
 * Thunk: delete a model, then refetch the model list while keeping the
 * current selection.
 *
 * @param {*} id - id of the model to delete.
 * @param {?function(string):void} [onError=null] - called with the error
 *   message when deletion fails.
 * @param {?function():void} [onSuccess=null] - called after a successful
 *   deletion.
 */
export const removeModel = (id, onError = null, onSuccess = null) => async (dispatch) => {
  try {
    await Api.removeModel(id);
    dispatch(getModels({ keepSelectedModel: true }));
    if (onSuccess) {
      onSuccess();
    }
  } catch (err) {
    // Report the failure through onError, passing the message as onSaveModel does.
    if (onError) {
      onError(err.message);
    }
    console.log(err);
  }
};

/**
 * Thunk: rename a model and store the current detection confidence with it.
 *
 * The model list is refetched (keeping the selection) only when the API
 * returns a truthy response.
 *
 * @param {*} id - id of the model to edit.
 * @param {string} name - new model name.
 * @returns {Promise<*>} the raw API response.
 */
export const editModel = (id, name) => async (dispatch, getState) => {
  const { confidence: confidence_threshold } = getState().detectedPatterns;
  const response = await Api.editModel(id, { name, confidence_threshold });

  if (response) {
    dispatch(getModels({ keepSelectedModel: true }));
  }

  return response;
};

/**
 * Thunk: save the model trained in this session under `name`.
 *
 * When a model was chosen for retraining, Api.replaceModel is called with that
 * model's id, passing the temporary model's id as `model_id_to_replace_with`.
 * Otherwise the temporary model itself is updated through Api.editModel. Both
 * paths send the current confidence threshold. On success the training state
 * is cleared and the model list is refetched.
 *
 * @param {string} name - name to save the model under.
 * @param {function(string):void} onError - receives the error message on failure.
 * @param {function(boolean):void} setSuccess - receives true on success, false on failure.
 */
export const onSaveModel = (name, onError, setSuccess) => async (dispatch, getState) => {
  const {
    models: { temporaryModel: temporaryModelId, retrainModel },
    detectedPatterns: { confidence },
  } = getState();
  const payload = { name, confidence_threshold: confidence };

  try {
    if (retrainModel) {
      await Api.replaceModel(retrainModel.id, {
        ...payload,
        model_id_to_replace_with: temporaryModelId,
      });
    } else {
      await Api.editModel(temporaryModelId, payload);
    }

    dispatch(clearTrainingModels());
    dispatch(getModels({ keepSelectedModel: false }));
    setSuccess(true);
  } catch (err) {
    onError(err.message);
    setSuccess(false);
  }
};

/**
 * Thunk: train a model from the selected patterns and the selected detected
 * patterns.
 *
 * Dispatches trainingStart, creates the model through the API, then hands the
 * result to processTrainedModel. processTrainedModel ends the training state
 * and either stores the model or shows a "Training failed" warning.
 *
 * @param {number} numberOfPatterns - payload for trainingStart.
 * @returns {Promise<boolean>} true once the training result was processed.
 * @throws whatever Api.createModel rejects with, after the training state has
 *   been reset.
 */
export const trainModel = (numberOfPatterns) => async (dispatch, getState) => {
  const selectedPatterns = Object.values(getState().selectedPatterns.entities);
  const selectedDetectedPatterns = Object.values(
    getState().selectedDetectedPatterns.entities
  );

  const model = formatPatternsForApi(selectedPatterns, selectedDetectedPatterns);
  dispatch(trainingStart(numberOfPatterns));

  let newModel;
  try {
    newModel = await Api.createModel(model);
  } catch (err) {
    // Without this, a rejected request would leave `training` stuck at true.
    // Reuse the empty-response path (reset + warning), then let callers see
    // the original error.
    await dispatch(processTrainedModel(null));
    throw err;
  }

  await dispatch(processTrainedModel(newModel));
  return true;
};

/**
 * Thunk: train a model from the cold-start patterns in the store.
 *
 * Dispatches trainingStart, creates the model through the cold-start API,
 * then hands the result to processTrainedModel. processTrainedModel ends the
 * training state and either stores the model or shows a "Training failed"
 * warning.
 *
 * @param {number} numberOfPatterns - payload for trainingStart.
 * @returns {Promise<boolean>} true once the training result was processed.
 * @throws whatever Api.createColdModel rejects with, after the training state
 *   has been reset.
 */
export const coldTrainModel =
  (numberOfPatterns) => async (dispatch, getState) => {
    const rawPatterns = getState().coldStart.entities;
    const model = formatColdPatternsForApi(rawPatterns);

    dispatch(trainingStart(numberOfPatterns));

    let newModel;
    try {
      newModel = await Api.createColdModel(model);
    } catch (err) {
      // Without this, a rejected request would leave `training` stuck at
      // true. Reuse the empty-response path (reset + warning), then let
      // callers see the original error.
      await dispatch(processTrainedModel(null));
      throw err;
    }

    await dispatch(processTrainedModel(newModel));
    return true;
  };

/**
 * Thunk: end the training phase and record its outcome.
 *
 * Always dispatches trainingFinish first. For a truthy `model`, the full model
 * is fetched by id and passed to createModelSuccess, which stores its id as
 * temporaryModel. For a falsy `model`, the warning modal opens with
 * "Training failed".
 *
 * @param {?{id: *}} model - model returned by the create API, or a falsy value
 *   when training failed.
 */
export const processTrainedModel = (model) => async (dispatch) => {
  dispatch(trainingFinish());

  if (model) {
    const trainedModel = await Api.getModel(model.id);
    dispatch(createModelSuccess(trainedModel));
    return;
  }

  dispatch(setModalWarning(`Training failed`));
  dispatch(setIsWarningModalOpen(true));
};

/**
 * Build the create-model payload for cold-start training.
 *
 * Each pattern's `shape` is copied without its final point. That last point
 * is the baseline, which only the UI uses.
 *
 * @param {Object} rawPatterns - map of cold-start patterns (the store's
 *   coldStart.entities); each value must have an array-like `shape`.
 * @returns {{data: Array, name: string}} payload named DEFAULT_MODEL_NAME.
 */
function formatColdPatternsForApi(rawPatterns) {
  const shapesWithoutBaseline = [];
  for (const { shape } of Object.values(rawPatterns)) {
    // drop the trailing baseline point (UI only)
    shapesWithoutBaseline.push(shape.slice(0, -1));
  }
  return { data: shapesWithoutBaseline, name: DEFAULT_MODEL_NAME };
}
/**
 * Convert a chart point into the API's point format.
 *
 * @param {{time: number, price: *}} point - `time` is multiplied by 1000
 *   before conversion (presumably seconds to milliseconds; confirm against
 *   Utils.convertTimestampToDatetime).
 * @returns {{datetime: *, price: *}}
 */
const formatPoint = ({ time, price }) => ({
  datetime: Utils.convertTimestampToDatetime(time * 1000),
  price,
});
/**
 * Convert a drawn pattern into the API format.
 *
 * Only the pattern's first shape is used; any further shapes (trend lines
 * etc.) are ignored. That shape's type also becomes the pattern type.
 *
 * @param {{shapes: Array<{shape_type: *, points: Array}>}} pattern
 * @returns {{pattern_type: *, shapes: Array}}
 */
function formatPattern(pattern) {
  const primaryShape = pattern.shapes[0];
  const shapeType = primaryShape.shape_type;
  const formattedPoints = primaryShape.points.map((p) => formatPoint(p));

  return {
    pattern_type: shapeType,
    shapes: [{ shape_type: shapeType, points: formattedPoints }],
  };
}

/**
 * Build the create-model payload from the user-selected patterns and the
 * selected detected patterns.
 *
 * @param {Array} selectedPatterns
 * @param {Array} selectedDetectedPatterns
 * @returns {{data: Array, name: string}} per-chart pattern groups (from
 *   formatPatternsData), named DEFAULT_MODEL_NAME.
 */
function formatPatternsForApi(selectedPatterns, selectedDetectedPatterns) {
  const groupedByChart = formatPatternsData(selectedPatterns, selectedDetectedPatterns);
  return { data: groupedByChart, name: DEFAULT_MODEL_NAME };
}

/**
 * Group patterns by chart and format each one for the API.
 *
 * Selected patterns are processed first, then selected detected patterns.
 * Within a chart, patterns keep that order. Charts appear in the order
 * produced by Object.values on a plain object: integer-like chart ids first
 * (ascending), then other ids in first-seen order.
 *
 * @param {Array<{chartId: *}>} selectedPatterns
 * @param {Array<{chartId: *}>} selectedDetectedPatterns
 * @returns {Array<{chart_id: *, patterns: Array}>}
 */
function formatPatternsData(selectedPatterns, selectedDetectedPatterns) {
  // Plain object on purpose: its key ordering is part of the output order.
  const chartsById = {};

  for (const group of [selectedPatterns, selectedDetectedPatterns]) {
    for (const pattern of group) {
      const { chartId } = pattern;
      let chart = chartsById[chartId];
      if (!chart) {
        chart = { chart_id: chartId, patterns: [] };
        chartsById[chartId] = chart;
      }
      chart.patterns.push(formatPattern(pattern));
    }
  }

  return Object.values(chartsById);
}

/**
 * Partition models into saved ones and temporary ones. A model is temporary
 * when its name equals DEFAULT_MODEL_NAME. Input order is preserved in both
 * lists.
 *
 * @param {Array<{name: string}>} models
 * @returns {[Array, Array]} [savedModels, temporaryModels]
 */
const splitModels = (models) => {
  const isTemporary = (model) => model.name === DEFAULT_MODEL_NAME;
  return [
    models.filter((model) => !isTemporary(model)),
    models.filter((model) => isTemporary(model)),
  ];
};
/**
 * Choose which model should be selected after the model list is refreshed.
 *
 * @param {boolean} keepSelectedModel - try to keep `previousModel` selected.
 * @param {Array<{id: *}>} models - candidate models; the first one is the fallback.
 * @param {?{id: *}} previousModel - currently selected model. May be
 *   null/undefined when nothing has been selected yet.
 * @returns {?Object} the model in `models` with previousModel's id, otherwise
 *   models[0] (undefined when `models` is empty).
 */
const getSelectedModel = (keepSelectedModel, models, previousModel) => {
  // Nothing to keep: also covers the initial state with no prior selection,
  // which previously threw on `previousModel.id`.
  if (!keepSelectedModel || !previousModel) return models[0];
  const newModel = models.find((model) => model.id === previousModel.id);
  return newModel || models[0];
};
