import * as tf from '@tensorflow/tfjs'

import loggit from './Loggit.js'
import { getModelParams } from './helpers.js'

/**
 * Wrapper around a TensorFlow.js GraphModel that runs inference on a single
 * feature row whose expected length is derived from `nmels`, `mstruct` and `fw`.
 *
 * The model is fetched from `/tf/<mstruct>/<last segment of tensorflowModelName>/model.json`.
 * Loading starts in the constructor and runs asynchronously; `ready` settles
 * once loading has finished (successfully or not), after which `loadedModel`
 * is either the loaded model or still `null`.
 */
class TensorflowModel {
    /**
     * @param {string} tensorflowModelName - Model name or path; only the part after the last '/' is used.
     * @param {string} mstruct - Feature layout: 'binSums', 'binMeans', 'binStacks' or 'binDeltas'
     *   (any other value uses nmels as-is). Also used as a path segment of the model URL.
     * @param {number|string} nmels - Number of mel bands; coerced with Number().
     * @param {number|string} fw - Multiplier applied to nmels for 'binStacks' only
     *   (presumably the number of stacked frames — confirm against the feature extractor).
     */
    constructor(tensorflowModelName, mstruct, nmels, fw) {
        this.loadedModel = null;

        // Keep only the right-hand side of the last '/', e.g. 'a/b/net' -> 'net'.
        this.modelFullName = tensorflowModelName.split('/').pop();
        this.modelStructure = mstruct;

        // Width of the single input row the model expects.
        switch (mstruct) {
            case 'binSums':
            case 'binMeans':
                loggit.debug('     Tensorflow >>>>> binSums or binMeans');
                this.numMelBands = Number(nmels);
                break;
            case 'binStacks':
                loggit.debug('     Tensorflow >>>>> binStacks');
                this.numMelBands = Number(nmels) * Number(fw);
                break;
            case 'binDeltas':
                loggit.debug('     Tensorflow >>>>> binDeltas');
                this.numMelBands = Number(nmels) * 2;
                break;
            default:
                this.numMelBands = Number(nmels);
        }

        this.inputName = null;

        // loadModel catches its own errors, so this promise always resolves;
        // callers can `await model.ready` and then inspect `loadedModel`.
        this.ready = this.loadModel();
    }

    /**
     * Fetches the GraphModel and records the name of its first input.
     * Failures are logged with console.error and leave `loadedModel` null.
     * @returns {Promise<void>}
     */
    async loadModel() {
        try {
            loggit.ghost('     Tensorflow >>>>> numMelBands:', this.numMelBands, '     tensorflowModelName:', this.modelFullName);
            const model = await tf.loadGraphModel(`/tf/${this.modelStructure}/${this.modelFullName}/model.json`);
            const [firstInput] = model.inputs;

            // Publish the model only after its input name was read successfully,
            // so makeVowelPrediction never sees a "loaded" model without an inputName.
            this.inputName = firstInput.name;
            this.loadedModel = model;

            // NOTE(review): toggles loggit's ghost mode (semantics defined in Loggit.js).
            loggit.ghostingOff();
            loggit.ghost('     Tensorflow >>>>> Model loaded successfully: ', this.loadedModel.signature, '     inputName:', this.inputName, '     shape:', firstInput.shape);
        } catch (error) {
            console.error('     Tensorflow >>>>> Failed to load model', error);
        }
    }

    /**
     * Runs the model on one feature row.
     *
     * @param {ArrayLike<number>} inputData - Flat feature row of length `numMelBands`.
     * @returns {Promise<ArrayLike<number>|null|undefined>} The output tensor's values,
     *   `undefined` if the model has not loaded yet, or `null` if the input length is wrong.
     */
    async makeVowelPrediction(inputData) {
        if (!this.loadedModel) {
            loggit.ghost('     Tensorflow >>>>> Model not loaded yet');
            return;
        }

        // NOTE(review): 'binStacks' inputs skip this check, so a mismatched
        // binStacks row makes tf.tensor2d throw instead of returning null — confirm intended.
        if (inputData.length !== this.numMelBands && this.modelStructure !== 'binStacks') {
            loggit.warning('WARNING     Tensorflow >>>>> Input data is not the correct shape', inputData.length, this.numMelBands);
            return null;
        }

        let inputTensor;
        let prediction;
        try {
            // Shape the row as a batch of one.
            inputTensor = tf.tensor2d([inputData], [1, this.numMelBands], 'float32');
            loggit.ghost('     Tensorflow >>>>> Input Data:', inputTensor.dataSync());

            prediction = this.loadedModel.predict({ [this.inputName]: inputTensor });

            // data() reads back asynchronously instead of blocking the thread like dataSync().
            const results = await prediction.data();

            loggit.ghost(this.modelFullName, ' Prediction >>>>> ', results);
            return results;
        } finally {
            // TF.js tensors are not garbage-collected; release both on every path,
            // including when predict() throws. tf.dispose ignores undefined entries.
            tf.dispose([inputTensor, prediction]);
        }
    }
}

export { TensorflowModel };
