<template>
  <CytomineModal :active="active" @close="$emit('update:active', false)">
    <BLoading :is-full-page="false" :active="loading" class="small" />
    <template v-if="!loading">
      <BSteps
        v-model="activeStep"
        class="flex flex-column h-full w-full align-center justify-center"
      >
        <!-- 1ST STEP -- GENERAL -->
        <BStepItem
          :step="1"
          :label="$t('general')"
          class="flex flex-column justify-center h-full"
          style="width:500px; margin:0 auto;"
        >
          <BField :label="$t('title') + ' *'">
            <IdxInput v-model="title" label="" type="text" name="title" />
          </BField>
          <BField :label="$t('tags')" class="mt-3">
            <DomainTagInput
              v-model="tags"
              :domains="[]"
              :allow-new="true"
              placeholder="tag-placeholder"
            />
          </BField>
          <BField
            v-if="existingTrainings && existingTrainings.length"
            class="mt-3"
            :label="$t('starting-weights')"
          >
            <BSelect v-model="selectedTrainingId" expanded class="w-full">
              <option
                v-for="training in existingTrainings"
                :key="training.pid"
                :value="training.pid"
              >
                {{ training.title }} ({{ training.pid }})
              </option>
            </BSelect>
          </BField>
          <!-- Checkpoints of the training picked above. `|| []` guards against a
               details payload without a `checkpoints` array (would otherwise throw
               during render). -->
          <BField
            v-if="
              selectedTrainingDetails &&
                (selectedTrainingDetails.checkpoints || []).length
            "
          >
            <BSelect v-model="selectedWeights" expanded class="w-full">
              <option
                v-for="weights in selectedTrainingDetails.checkpoints"
                :key="weights.file_path"
                :value="weights.file_path"
              >
                Epoch: {{ weights.epoch }} - Step: {{ weights.step }}
              </option>
            </BSelect>
          </BField>
          <p
            v-if="
              selectedTrainingDetails &&
                !(selectedTrainingDetails.checkpoints || []).length
            "
            class="italic text-center"
          >
            {{ $t('no-available-weights') }}
          </p>
          <BField
            v-if="trainingConfigurations && trainingConfigurations.length"
            class="mt-3"
            :label="$t('configuration')"
          >
            <BSelect
              v-model="selectedConfigurationName"
              expanded
              class="w-full"
            >
              <option
                v-for="option in trainingConfigurations"
                :key="option.configuration_name"
                :value="option.configuration_name"
              >
                {{ option.configuration_name }}
              </option>
            </BSelect>
          </BField>
        </BStepItem>

        <!-- 2ND STEP -- PARAMETERS -->
        <BStepItem :step="2" :label="$t('parameters')" class="px-5">
          <AssayParameters
            :assay-params="
              selectedConfiguration && selectedConfiguration.parameters
            "
            :is-valid.sync="paramsValid"
            :id-project="idProject"
            @valuesChanged="(values) => (paramsOutput = values)"
            @isValid="(value) => (paramsValid = value)"
          />
        </BStepItem>

        <!-- 3RD STEP -- TERMS -->
        <BStepItem :step="3" label="Terms" class="pt-4">
          <h2 class="mt-5 text-center">
            {{ $t('training-terms-description') }}
          </h2>
          <p
            class="mt-4 mb-5"
            style="font-style:italic;max-width:700px; text-align:center; margin:0 auto; font-size:0.8em;"
          >
            {{ $t('parent-term-note') }}
          </p>
          <div class="mt-5 pl-5">
            <table style="width:700px; margin:0 auto;">
              <tbody>
                <tr v-for="(termGroup, i) in groupedTerms" :key="i">
                  <td v-for="term in termGroup" :key="term.id">
                    <!-- v-model holds the array of checked names; native-value is
                         what this box contributes to it. -->
                    <BCheckbox
                      v-model="selectedTerms"
                      :native-value="term.name"
                    >
                      <span
                        class="inline-block p-2"
                        :style="{ background: term.color }"
                        aria-hidden="true"
                      />
                      {{ term.name }}
                    </BCheckbox>
                  </td>
                </tr>
              </tbody>
            </table>
          </div>
        </BStepItem>

        <!-- 4TH STEP -- ANNOTATIONS -->
        <BStepItem :step="4" label="Annotations" class="pt-4">
          <h2 class="mt-5 pb-5 text-center" style="font-style:italic;">
            {{ $t('training-users-description') }}
          </h2>
          <div class="mt-5 pl-5">
            <table style="width:600px; margin:0 auto;">
              <thead>
                <tr>
                  <td class="pb-3" />
                  <td class="pb-3 weight-6">
                    {{ $t('users') }}
                  </td>
                  <td align="center" class="pb-3 weight-6">
                    # {{ $t('annotations') }}
                  </td>
                </tr>
              </thead>

              <tbody>
                <tr v-for="(userLayer, i) in layersViewModel" :key="i">
                  <td class="pt-1">
                    <BCheckbox
                      v-model="selectedUserIds"
                      :native-value="userLayer.userId"
                    />
                  </td>
                  <td class="pt-1">
                    {{ userLayer.fullName }}
                  </td>
                  <td align="center" class="pt-1">
                    {{ userLayer.numAnnotations }}
                  </td>
                </tr>
              </tbody>
            </table>
          </div>
        </BStepItem>

        <!-- 5TH STEP -- TRAINING IMAGES -->
        <BStepItem :step="5" :label="$t('images')" class="step-images pt-4">
          <h2 class="mt-5 pb-5 text-center" style="font-style:italic;">
            {{ $t('training-images-description') }}
          </h2>
          <div class="mt-5 pl-5">
            <table style="width:700px; margin:0 auto;">
              <thead>
                <tr>
                  <td class="pb-3" />
                  <td align="center" class="pb-3 weight-6">
                    {{ $t('train') }}
                  </td>
                  <td align="center" class="pb-3 weight-6">
                    {{ $t('validation') }}
                  </td>
                </tr>
              </thead>
              <tbody>
                <tr v-for="image in selectedImageDetails" :key="image.imageId">
                  <td class="pt-2">
                    {{ image.imageName }}
                  </td>
                  <td align="center" class="pt-2">
                    <BRadio
                      v-model="image.type"
                      class="mx-0"
                      :name="'image_' + image.imageId"
                      native-value="train"
                    />
                  </td>
                  <td align="center" class="pt-2">
                    <BRadio
                      v-model="image.type"
                      class="mx-0"
                      :name="'image_' + image.imageId"
                      native-value="validation"
                    />
                  </td>
                </tr>
              </tbody>
            </table>
          </div>
        </BStepItem>

        <!-- NAVIGATION BUTTONS -->
        <template #navigation="{previous, next}">
          <div class="flex w-full pr-5 justify-end">
            <span
              v-if="!paramsValid"
              class="mt-1 mr-4 red"
              style="font-size: 0.8rem; display: inline-block"
            >
              {{ $t('fix-validation-first') }}
            </span>
            <IdxBtn
              class="mr-3"
              :disabled="previous.disabled"
              @click="previous.action"
            >
              {{ $t('previous') }}
            </IdxBtn>
            <!-- On the last step this button launches the training instead of advancing. -->
            <IdxBtn
              type="submit"
              color="primary"
              :disabled="!isStepValid"
              @click="goNext(next.action)"
            >
              {{
                activeStep === steps.length - 1
                  ? $t('training-launch')
                  : $t('next')
              }}
            </IdxBtn>
          </div>
        </template>
      </BSteps>
    </template>
    <!-- stub to hide footer -->
    <template #footer>
      <span />
    </template>
  </CytomineModal>
</template>
<script>
import { ImageInstance } from 'cytomine-client';
import {
  GetOutputs,
  GetTrainingConfigurations,
  GetTrainingStatus,
  RunTraining,
} from '../../services/trainingFramework.js';
import AssayParameters from '../utils/AssayParameters.vue';
import CytomineModal from '@/components/utils/CytomineModal.vue';
import DomainTagInput from '@/components/utils/DomainTagInput.vue';

/**
 * Build a brand-new copy of the modal's reactive state.
 *
 * Used as the component's `data()` and again to reset the form after a
 * training has been launched, so every call returns fresh arrays/objects
 * (no state is shared between calls).
 *
 * @returns {Object} the initial state of the launch-training wizard
 */
function getDefaultState() {
  return {
    title: '', // training title (required on the first step)
    tags: [], // free-form tags attached to the run
    loading: true, // spinner shown until the opening fetches complete
    adaptedParams: [], // not referenced elsewhere in this file
    existingTrainings: [], // candidate trainings for starting weights
    selectedTrainingId: null, // pid of the training picked for starting weights
    selectedTrainingDetails: null, // status/details of that training
    startingWeights: [], // not referenced elsewhere in this file
    selectedWeights: null, // checkpoint file path sent with the launch
    paramsValid: false, // validity reported by AssayParameters
    paramsOutput: null, // parameter values reported by AssayParameters
    selectedConfigurationName: '', // chosen training configuration
    selectedTerms: [], // term names to train on
    trainingConfigurations: [], // configurations available for the model
    selectedUserIds: [], // annotation authors to include
    steps: ['General', 'Parameters', 'Terms', 'Annotations', 'Images'],
    activeStep: 0, // index of the visible wizard step
    selectedImageDetails: [], // { imageId, imageName, imagePath, type }
    layers: [], // annotation index entries for the selected images
  };
}

export default {
  name: 'LaunchTrainingModal',
  components: {
    AssayParameters,
    CytomineModal,
    DomainTagInput,
  },
  props: {
    /** Whether the modal is open; closing it emits `update:active` with `false`. */
    active: Boolean,
    idProject: {
      type: [String, Number],
      required: true,
    },
    /** Model identifier passed to GetTrainingConfigurations and RunTraining. */
    selectedTrainingModel: {
      type: [String, Number],
      default: () => null,
    },
    /** Image instances to train on when no previous training is pre-selected. */
    selectedImages: {
      type: Array,
      default: () => [],
    },
    /** Previous training whose settings pre-fill the wizard (optional). */
    selectedTraining: {
      type: Object,
      default: null,
    },
    /** Not read anywhere in this component. */
    revision: {
      type: Number,
      default: () => 0,
    },
  },
  data() {
    return {
      ...getDefaultState(),
      // True while the `active` watcher loads the wizard's data. Kept outside
      // getDefaultState() so the post-launch reset does not touch it.
      isInitializing: false,
    };
  },
  computed: {
    ontology() {
      return this.$store.state.currentProject.ontology;
    },
    /** @returns {CytoUser} */
    currentUser() {
      return this.$store.state.currentUser.user;
    },
    projectMembers() {
      return this.$store.state.currentProject.members;
    },
    /** Configuration whose `configuration_name` matches the selection, or undefined. */
    selectedConfiguration() {
      return this.trainingConfigurations.find(
        (configuration) =>
          configuration.configuration_name === this.selectedConfigurationName
      );
    },
    /**
     * One row per annotation author found in `layers`: the user id, the
     * member's full name ('Deleted user' when the id is not a project member)
     * and the summed `countAnnotation` over that user's layers.
     */
    layersViewModel() {
      const layersByUserId = this.layers.reduce((group, layer) => {
        if (group[layer.user] === undefined) {
          group[layer.user] = [];
        }
        group[layer.user].push(layer);
        return group;
      }, {});

      const layerViewModels = [];
      for (const key in layersByUserId) {
        const userLayers = layersByUserId[key];
        const user = this.projectMembers.find(
          (member) => member.id === userLayers[0].user
        );
        layerViewModels.push({
          userId: user?.id || userLayers[0].user,
          fullName: user?.fullName || 'Deleted user',
          numAnnotations: userLayers.reduce(
            (sum, layer) => sum + layer.countAnnotation,
            0
          ),
        });
      }
      return layerViewModels;
    },
    /** Whether the "next"/"launch" button is enabled on the current step. */
    isStepValid() {
      switch (this.activeStep) {
        case 0:
          // A whitespace-only title is not a usable title.
          return this.title != null && String(this.title).trim() !== '';
        case 1:
          return this.paramsValid === true;
        case 2:
          return this.selectedTerms.length > 0;
        case 3:
          return this.selectedUserIds.length > 0;
        case 4:
          return true;
        default:
          return false;
      }
    },
    filteredTerms() {
      // remove parent term - backend is already using this term by default.
      // A project may have no ontology, in which case there is nothing to list.
      const terms = (this.ontology && this.ontology.terms) || [];
      return terms.filter((term) => term.name.toLowerCase() !== 'parent');
    },
    /** `filteredTerms` split into rows of two for the terms table. */
    groupedTerms() {
      const groupSize = 2;
      const groups = [];
      for (let i = 0; i < this.filteredTerms.length; i += groupSize) {
        groups.push(this.filteredTerms.slice(i, i + groupSize));
      }
      return groups;
    },
  },
  watch: {
    active: {
      // Also run on creation so a modal mounted already open still loads.
      immediate: true,
      /**
       * Load everything the wizard needs each time the modal opens:
       * configurations, optional pre-fill from `selectedTraining`, candidate
       * trainings for starting weights, the image list and annotation layers.
       * `loading` is always cleared, even when a request fails, so the modal
       * never stays stuck on the spinner.
       */
      async handler(isActive) {
        if (!isActive) {
          return;
        }

        this.loading = true;
        this.isInitializing = true;
        try {
          await this.fetchTrainingConfigurations();
          if (this.trainingConfigurations.length > 0) {
            this.selectedConfigurationName = this.trainingConfigurations[0].configuration_name;
          }

          if (this.selectedTraining) {
            await this.loadSelectedTrainingDefaults();
          }

          // get existing trainings for selecting starting weights
          const trainings = await this.fetchTrainings();
          this.existingTrainings = trainings.filter(
            (training) =>
              (training.valohai_job_status === 'complete' &&
                training.model === this.selectedTrainingModel) ||
              (this.selectedTraining &&
                this.selectedTraining.pid === training.pid)
          );

          this.selectedImageDetails = this.buildSelectedImageDetails();
          // Always overwrite, so layers from a previous opening never linger.
          this.layers = await this.fetchAnnotationLayers(
            this.selectedImageDetails.map((image) => image.imageId)
          );

          // set defaults for pre-selected trainings (copied, not shared with the prop)
          if (this.selectedTraining) {
            this.title = this.selectedTraining.title;
            this.selectedTerms = [...(this.selectedTraining.terms || [])];
            this.selectedUserIds = [...(this.selectedTraining.user_ids || [])];
          }
        } catch (error) {
          // Service helpers notify the user themselves; just log here.
          console.error(error);
        } finally {
          this.isInitializing = false;
          this.loading = false;
        }
      },
    },
    /**
     * Refresh the checkpoint list when the user picks another training.
     * A previously chosen checkpoint belongs to the old training, so it is
     * cleared. While the modal is initializing the opening sequence fetches
     * the details itself, so the duplicate request is skipped.
     */
    async selectedTrainingId(newValue) {
      this.selectedWeights = null;
      if (this.isInitializing) {
        return;
      }
      if (newValue == null) {
        // e.g. after the post-launch reset: nothing to fetch.
        this.selectedTrainingDetails = null;
        return;
      }

      this.loading = true;
      try {
        this.selectedTrainingDetails = await this.getTrainingDetails(newValue);
      } catch (error) {
        this.selectedTrainingDetails = null;
        this.$notify({
          type: 'error',
          text: 'Error retrieving training details.',
        });
        console.error(error);
      } finally {
        this.loading = false;
      }
    },
  },
  methods: {
    /** Advance the wizard, or launch the training when on the last step. */
    goNext(nextStep) {
      if (this.activeStep === this.steps.length - 1) {
        this.launchTraining();
      } else {
        nextStep();
      }
    },
    /**
     * Load the configurations available for `selectedTrainingModel` into
     * `trainingConfigurations`. Notifies the user and rethrows on failure.
     */
    async fetchTrainingConfigurations() {
      try {
        const configurations = await GetTrainingConfigurations(
          this.selectedTrainingModel
        );
        this.trainingConfigurations = configurations || [];
      } catch (error) {
        this.$notify({
          type: 'error',
          text: 'Error retrieving training configuration.',
        });
        console.error(error);
        throw error;
      }
    },
    /**
     * Return the project's trainings. On failure the user is notified and an
     * empty list is returned, so the rest of the wizard still loads (the
     * starting-weights selector simply stays hidden).
     */
    async fetchTrainings() {
      try {
        const trainings = await GetOutputs(this.idProject);
        return trainings || [];
      } catch (error) {
        this.$notify({
          type: 'error',
          text: 'Error retrieving trainings.',
        });
        console.error(error);
        return [];
      }
    },
    async getTrainingDetails(trainingId) {
      return await GetTrainingStatus(trainingId);
    },
    /**
     * Select `selectedTraining` as the starting-weights training, switch to the
     * configuration it used and seed that configuration's parameter defaults
     * with the values the previous run used. Parameters the previous run did
     * not record keep their existing default.
     */
    async loadSelectedTrainingDefaults() {
      this.selectedTrainingId = this.selectedTraining.pid;
      this.selectedTrainingDetails = await this.getTrainingDetails(
        this.selectedTraining.pid
      );

      const details = this.selectedTrainingDetails;
      if (!details?.config) {
        return;
      }

      this.selectedConfigurationName = details.config;
      const previousParams = details.parameters || [];
      const configParams = this.selectedConfiguration?.parameters || [];
      for (const param of configParams) {
        if (!param.default) {
          continue;
        }
        const previous = previousParams.find((p) => p.name === param.name);
        if (previous) {
          param.default.value = previous.value;
        }
      }
    },
    /**
     * Rows for the images step: the train/validation split of
     * `selectedTraining` when re-running one, otherwise every image of
     * `selectedImages` marked as 'train'.
     */
    buildSelectedImageDetails() {
      if (this.selectedTraining) {
        const toDetails = (type) => (image) => ({
          imageId: image.image_id,
          imageName: image.image_name,
          imagePath: image.image_path,
          type,
        });
        return [
          ...(this.selectedTraining.train_image_ids || []).map(toDetails('train')),
          ...(this.selectedTraining.validation_image_ids || []).map(
            toDetails('validation')
          ),
        ];
      }
      return this.selectedImages.map((image) => ({
        imageId: image.id,
        imageName: image.instanceFilename,
        imagePath: image.fullPath,
        type: 'train', // 'train' or 'validation'
      }));
    },
    /** Fetch the annotation index of every given image, flattened into one list. */
    async fetchAnnotationLayers(imageIds) {
      const layersByImage = await Promise.all(
        imageIds.map(async (imageId) => {
          const imageInstance = await ImageInstance.fetch(imageId);
          return imageInstance.fetchAnnotationsIndex();
        })
      );
      return layersByImage.flat();
    },
    /**
     * Submit the training run, then reset the wizard and close the modal.
     * Does nothing while the parameters are invalid.
     */
    async launchTraining() {
      if (!this.paramsValid) {
        return;
      }

      this.loading = true;
      try {
        await RunTraining(
          this.idProject,
          this.selectedTrainingModel,
          this.selectedConfigurationName,
          this.tags,
          this.currentUser.id,
          this.selectedUserIds,
          this.selectedImageDetails.map((image) => ({
            // eslint-disable-next-line camelcase
            image_id: image.imageId,
            // eslint-disable-next-line camelcase
            image_name: image.imageName,
            // eslint-disable-next-line camelcase
            image_path: image.imagePath,
            type: image.type,
          })),
          this.selectedTerms,
          this.paramsOutput,
          this.title,
          process.env.NODE_ENV === 'production' ? 'PROD' : 'STAGE',
          this.selectedWeights
        );

        // reset modal state
        const defaults = getDefaultState();
        for (const key of Object.keys(defaults)) {
          this[key] = defaults[key];
        }

        this.$notify({
          type: 'success',
          text: this.$t('training-run-start'),
        });

        this.$emit('update:active', false);
      } catch (error) {
        console.error(error);
        this.$notify({
          type: 'error',
          text: this.$t('training-run-error'),
        });
      } finally {
        this.loading = false;
      }
    },
  },
};
</script>

<style scoped>
/*
 * `>>>` is vue-loader's deep combinator: these rules reach elements rendered
 * inside child components (presumably the Buefy modal wrapped by
 * CytomineModal — verify if that component changes).
 */
>>> .animation-content {
  max-width: 60% !important;
  width: 60%;
}

>>> .modal-card {
  width: 100%;
  height: 80vh;
}

/* Colour of the "fix validation first" hint next to the navigation buttons. */
.red {
  color: red;
}
</style>
<style>
/*
 * Unscoped: these rules target Buefy's BSteps / label internals and therefore
 * apply to every matching element in the app while this component is loaded.
 */
.b-steps .step-content {
  height: 100%;
  width: 100%;
  flex: 1;
  overflow: auto !important;
}

.b-steps .steps {
  width: 100%;
}

/*
 * NOTE(review): removes the focus outline with no visible replacement; if any
 * of these elements is keyboard-focusable, consider a :focus-visible style.
 */
.b-steps .step-item,
.b-steps .step-content > div {
  outline: none;
  box-shadow: none;
}

/* Let the train/validation radios sit centred in their table cells. */
.step-images .b-radio.radio .control-label {
  padding-left: 0 !important;
}

/* The :not(:last-child) variant is needed to out-rank the framework's rule of that shape. */
label .label,
label .label:not(:last-child) {
  margin-bottom: 0px;
}
</style>
