/*
 * Decompiled with CFR 0.152.
 */
package trainer;

import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import java.util.List;
import java.util.Random;
import loss.Loss;
import matrix.Tensor;
import model.NeuralNetwork;
import trainer.Optimizer;
import util.CLUtils;
import util.FileIO;
import util.NNDevice;

/**
 * Drives epoch-based training of a {@link NeuralNetwork} over a {@link DataSet}.
 *
 * <p>One {@link Graph} is created at construction time and reused for every training and
 * evaluation pass. Call {@link #dispose()} to release it; afterwards {@link #train} throws.
 */
public class Trainer {
    /** Minimum wall-clock interval between forward-pass progress messages, in milliseconds. */
    private static final long PROGRESS_INTERVAL_MS = 10000L;

    /** Optimizer used to apply accumulated gradients to the model after each training sequence. */
    private final Optimizer method;
    /** Graph shared by all passes; {@code null} once {@link #dispose()} has been called. */
    private Graph trainGraph;

    /**
     * Creates a trainer backed by a graph built with {@code new Graph(true)}.
     *
     * @param method optimizer used to update model parameters
     */
    public Trainer(Optimizer method) {
        this.method = method;
        this.trainGraph = new Graph(true);
    }

    /**
     * Creates a trainer whose graph is built through {@link CLUtils} for the given device(s).
     *
     * <p>Only {@code dev[0]} and {@code dev[1]} are consulted; further devices are ignored. When
     * {@code dev} is {@code null}, empty, or its first element is {@code null}, this falls back
     * to {@code new Graph(true)}.
     *
     * @param method optimizer used to update model parameters
     * @param dev device(s) to build the graph on
     * @throws Exception if {@link CLUtils} fails to create the graph
     */
    public Trainer(Optimizer method, NNDevice... dev) throws Exception {
        this.method = method;
        this.trainGraph = createGraph(dev, true);
    }

    /**
     * Trains {@code model} for up to {@code epochs} epochs.
     *
     * <p>Each epoch runs one training pass over {@code data.training}. It then runs evaluation
     * passes over {@code data.validation} and {@code data.testing} when those are non-null, and
     * prints the resulting losses. Training stops early once both the training loss and the
     * validation loss are exactly {@code 0.0}; the validation loss counts as {@code 0.0} when
     * there is no validation set.
     *
     * @param model network to train, updated in place
     * @param learningRate learning rate forwarded to the optimizer
     * @param epochs maximum number of epochs to run
     * @param data training data plus the training and reporting loss functions
     * @param reportEveryNthEpoch call {@code data.DisplayReport} after every N-th epoch; values
     *     {@code <= 0} disable reporting
     * @param savefile path used for loading and autosaving the model
     * @param loadFromSave if true, try to initialise {@code model} from {@code savefile} first;
     *     on failure a warning is printed and training continues with the current weights
     * @param autosave if true, save the model to {@code savefile} after every epoch
     * @param rng random source forwarded to {@code data.DisplayReport}
     * @return the last loss computed in the final epoch that ran: the testing loss if a testing
     *     set exists, else the validation loss if a validation set exists, else the training
     *     loss; {@code 1.0} if no epoch ran
     * @throws Exception if the trainer has been disposed, if the training loss is NaN or
     *     infinite, or if a pass or the autosave fails
     */
    public double train(NeuralNetwork model, double learningRate, int epochs, DataSet data, int reportEveryNthEpoch, String savefile, boolean loadFromSave, boolean autosave, Random rng) throws Exception {
        if (this.trainGraph == null) {
            throw new Exception("Trainer disposed");
        }
        if (loadFromSave) {
            System.out.println("initializing model from saved state...");
            try {
                FileIO.loadNeuralNetwork(savefile, model);
            }
            catch (Exception e) {
                // Deliberate best-effort: a missing or corrupt save file must not abort training.
                System.out.println("Oops. Unable to load from a saved state.");
                System.out.println("WARNING: " + e.getMessage());
                System.out.println("Continuing from freshly initialized model instead.");
            }
        }
        double result = 1.0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            System.out.println("Trainer: epoch[" + (epoch + 1) + "/" + epochs + "]");

            double reportedLossTrain = passTensors(learningRate, model, data.training, true, data.lossTraining, data.lossReporting);
            System.gc();
            result = reportedLossTrain;
            if (Double.isNaN(reportedLossTrain) || Double.isInfinite(reportedLossTrain)) {
                throw new Exception("WARNING: invalid value for training loss. Try lowering learning rate.");
            }

            double reportedLossValidation = 0.0;
            double reportedLossTesting = 0.0;
            if (data.validation != null) {
                reportedLossValidation = passTensors(learningRate, model, data.validation, false, data.lossTraining, data.lossReporting);
                result = reportedLossValidation;
            }
            if (data.testing != null) {
                reportedLossTesting = passTensors(learningRate, model, data.testing, false, data.lossTraining, data.lossReporting);
                result = reportedLossTesting;
            }

            StringBuilder show = new StringBuilder("train loss = ").append(String.format("%.5f", reportedLossTrain));
            if (data.validation != null) {
                show.append(", valid loss = ").append(String.format("%.5f", reportedLossValidation));
            }
            if (data.testing != null) {
                show.append(", test loss = ").append(String.format("%.5f", reportedLossTesting));
            }
            System.out.println(show);

            // Guard against a zero or negative period, which previously threw ArithmeticException.
            if (reportEveryNthEpoch > 0 && epoch % reportEveryNthEpoch == reportEveryNthEpoch - 1) {
                data.DisplayReport(model, rng);
            }
            if (autosave) {
                System.out.println("Autosaving model...");
                FileIO.saveNeuralNetwork(savefile, model);
                System.out.println("Complete!");
            }
            if (reportedLossTrain == 0.0 && reportedLossValidation == 0.0) {
                System.out.println("Optimal loss reached. Trainer done.");
                break;
            }
        }
        return result;
    }

    /**
     * Builds the computation graph for the given device(s).
     *
     * <p>The length check comes before any element access, so an empty array falls back to a
     * plain {@link Graph} instead of throwing.
     *
     * @param dev candidate devices; may be {@code null} or empty
     * @param applyTraining passed through to the graph constructor or factory
     * @return a {@link CLUtils}-built graph when {@code dev[0]} is usable, otherwise
     *     {@code new Graph(applyTraining)}
     * @throws Exception if {@link CLUtils} fails to create the graph
     */
    private synchronized Graph createGraph(NNDevice[] dev, boolean applyTraining) throws Exception {
        if (dev == null || dev.length == 0 || dev[0] == null) {
            return new Graph(applyTraining);
        }
        if (dev.length == 1) {
            return CLUtils.createGraph(dev[0], applyTraining);
        }
        return CLUtils.createGraph(dev[0], dev[1], applyTraining);
    }

    /**
     * Releases the training graph via {@code Graph.cleanUp()}.
     *
     * <p>Safe to call more than once: later calls do nothing.
     */
    public void dispose() {
        if (this.trainGraph != null) {
            this.trainGraph.cleanUp();
            this.trainGraph = null;
        }
    }

    /**
     * Runs every sequence in {@code sequences} through {@code model}, optionally training on it.
     *
     * <p>For each step that has a target, the reporting loss of every depth slice is added to a
     * running sum. The returned value is that sum divided by the number of target-bearing steps.
     * If no step has a target, the result is {@code 0.0 / 0.0}, i.e. NaN. When
     * {@code applyTraining} is true, the training loss's backward pass is invoked per slice. After
     * each sequence the graph is back-propagated and the optimizer updates the model.
     *
     * @param learningRate learning rate forwarded to the optimizer
     * @param model network to run (and update when training)
     * @param sequences sequences to process; the model state is reset before each one
     * @param applyTraining true for a training pass, false for an evaluation-only pass
     * @param lossTraining loss whose {@code backward} is invoked during training
     * @param lossReporting loss whose {@code measure} produces the reported value
     * @return the averaged reporting loss, or the first NaN/infinite per-slice loss encountered
     *     (returned immediately, without finishing the pass)
     * @throws Exception propagated from the model, losses, graph or optimizer
     */
    private double passTensors(double learningRate, NeuralNetwork model, List<DataSequence> sequences, boolean applyTraining, Loss lossTraining, Loss lossReporting) throws Exception {
        // Evaluation passes must not record backprop. The flag is restored in `finally` so that an
        // early return (non-finite loss) or an exception cannot leave the shared graph with
        // backprop disabled for subsequent training passes.
        boolean suspendBackprop = !applyTraining;
        if (suspendBackprop) {
            this.trainGraph.setApplyingBackprop(false);
        }
        double numerLoss = 0.0;
        double denomLoss = 0.0;
        try {
            int seqIndex = 0;
            for (DataSequence seq : sequences) {
                seqIndex++;
                System.out.println("Trainer: Passing sequence " + seqIndex + "/" + sequences.size());
                model.resetState();
                long startTime = System.currentTimeMillis();
                for (int h = 0; h < seq.getSequenceLength(); h++) {
                    DataStep step = seq.getDataStep(h);
                    if (System.currentTimeMillis() - startTime >= PROGRESS_INTERVAL_MS) {
                        System.out.println(String.valueOf((double) h / (double) seq.getSequenceLength() * 100.0) + "% done with forward pass");
                        startTime = System.currentTimeMillis();
                        System.gc();
                    }
                    Tensor out = model.forward(step.input, this.trainGraph);
                    if (step.targetOutput == null) {
                        continue;
                    }
                    for (int x = 0; x < out.depth; x++) {
                        double loss = lossReporting.measure(out.matrices[x], step.targetOutput.matrices[x]);
                        if (Double.isNaN(loss) || Double.isInfinite(loss)) {
                            return loss;
                        }
                        numerLoss += loss;
                        if (applyTraining) {
                            lossTraining.backward(out.matrices[x], step.targetOutput.matrices[x]);
                        }
                    }
                    denomLoss += 1.0;
                }
                if (applyTraining) {
                    System.out.println("Backpropagating and updating Model parameters...");
                    this.trainGraph.backward();
                    // NOTE(review): the third argument's meaning is defined by Optimizer; presumably
                    // a normalisation factor — confirm against the Optimizer implementation.
                    this.method.updateParameters(model, learningRate, seq.getSequenceLength());
                    System.out.println("Pass Complete!");
                }
                System.gc();
            }
        } finally {
            if (suspendBackprop) {
                this.trainGraph.setApplyingBackprop(true);
            }
        }
        model.resetState();
        return numerLoss / denomLoss;
    }
}

