/*
 * Decompiled with CFR 0.152.
 */
package trainer;

import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import datastructs.TensorDataSequence;
import datastructs.TensorDataSet;
import datastructs.TensorDataStep;
import java.util.List;
import java.util.Random;
import loss.Loss;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvNet;
import model.NeuralNetwork;
import trainer.TrainingMethod;
import util.FileIO;

/**
 * Runs epoch-based training loops for {@link NeuralNetwork} and {@link ConvNet} models.
 *
 * <p>Each epoch performs one training pass over {@code data.training} and, when present,
 * evaluation passes over {@code data.validation} and {@code data.testing}. Parameter updates
 * are delegated to the {@link TrainingMethod} supplied at construction time. Progress is
 * written to {@code System.out}; the wall-clock duration of each training pass (in
 * milliseconds) is written to {@code System.err}.
 */
public class NewTrainer {
    /** Nanoseconds per millisecond, for converting {@link System#nanoTime()} deltas. */
    private static final double NANOS_PER_MILLI = 1000000.0;

    /** Strategy that applies the accumulated gradients to the model parameters. */
    private final TrainingMethod method;

    /**
     * Creates a trainer that updates model parameters with the given method.
     *
     * @param method the parameter-update strategy invoked after every training sequence
     */
    public NewTrainer(TrainingMethod method) {
        this.method = method;
    }

    /**
     * Trains {@code model} for up to {@code epochs} epochs.
     *
     * <p>Training stops early when both the training loss and the validation loss of an epoch
     * are exactly {@code 0.0} (the validation loss counts as {@code 0.0} when there is no
     * validation set).
     *
     * @param model the network to train
     * @param learningRate learning rate forwarded to the training method
     * @param epochs maximum number of epochs to run
     * @param data training / validation / testing sequences and the loss functions to use
     * @param reportEveryNthEpoch call {@code data.DisplayReport} on every N-th epoch; values
     *     less than or equal to zero disable the report
     * @param savefile path used for loading and autosaving the model
     * @param loadFromSave whether to try to initialise the model from {@code savefile} first;
     *     a failed load is reported and training continues with the model as passed in
     * @param autosave whether to save the model to {@code savefile} after every epoch
     * @param rng random source forwarded to {@code data.DisplayReport}
     * @return the last loss measured in the final epoch that ran: testing loss if a testing set
     *     exists, otherwise validation loss if a validation set exists, otherwise training loss;
     *     {@code 1.0} if no epoch ran
     * @throws Exception if the training loss of an epoch is NaN or infinite, or if a pass,
     *     report or save fails
     */
    public double train(NeuralNetwork model, double learningRate, int epochs, DataSet data, int reportEveryNthEpoch, String savefile, boolean loadFromSave, boolean autosave, Random rng) throws Exception {
        if (loadFromSave) {
            System.out.println("initializing model from saved state...");
            try {
                FileIO.loadNeuralNetwork(savefile, model);
            }
            catch (Exception e) {
                // Deliberate best effort: fall back to the model as it was passed in.
                NewTrainer.reportLoadFailure(e);
            }
        }
        double result = 1.0;
        for (int epoch = 0; epoch < epochs; ++epoch) {
            String header = NewTrainer.epochHeader(epoch, epochs);
            System.out.println(header);

            long startTime = System.currentTimeMillis();
            double trainLoss = this.pass(learningRate, model, data.training, true, data.lossTraining, data.lossReporting);
            System.err.println(System.currentTimeMillis() - startTime);
            System.gc();
            NewTrainer.requireFiniteTrainingLoss(trainLoss);
            result = trainLoss;

            boolean hasValidation = data.validation != null;
            boolean hasTesting = data.testing != null;
            double validationLoss = 0.0;
            double testingLoss = 0.0;
            if (hasValidation) {
                validationLoss = this.pass(learningRate, model, data.validation, false, data.lossTraining, data.lossReporting);
                result = validationLoss;
            }
            if (hasTesting) {
                testingLoss = this.pass(learningRate, model, data.testing, false, data.lossTraining, data.lossReporting);
                result = testingLoss;
            }
            System.out.println(NewTrainer.epochSummary(header, trainLoss, hasValidation, validationLoss, hasTesting, testingLoss));

            if (NewTrainer.isReportEpoch(epoch, reportEveryNthEpoch)) {
                data.DisplayReport(model, rng);
            }
            if (autosave) {
                System.out.println("Autosaving model...");
                FileIO.saveNeuralNetwork(savefile, model);
                System.out.println("Complete!");
            }
            if (trainLoss == 0.0 && validationLoss == 0.0) {
                System.out.println("Optimal loss reached. Trainer done.");
                break;
            }
        }
        return result;
    }

    /**
     * Trains a convolutional {@code model} for up to {@code epochs} epochs.
     *
     * <p>Mirrors {@link #train} but operates on tensor data sets and uses {@link #passTensors}.
     *
     * @param model the network to train
     * @param learningRate learning rate forwarded to the training method
     * @param epochs maximum number of epochs to run
     * @param data training / validation / testing sequences and the loss functions to use
     * @param reportEveryNthEpoch call {@code data.DisplayReport} on every N-th epoch; values
     *     less than or equal to zero disable the report
     * @param savefile path used for loading and autosaving the model
     * @param loadFromSave whether to try to initialise the model from {@code savefile} first;
     *     a failed load is reported and training continues with the model as passed in
     * @param autosave whether to save the model to {@code savefile} after every epoch
     * @param rng random source forwarded to {@code data.DisplayReport}
     * @return the last loss measured in the final epoch that ran: testing loss if a testing set
     *     exists, otherwise validation loss if a validation set exists, otherwise training loss;
     *     {@code 1.0} if no epoch ran
     * @throws Exception if the training loss of an epoch is NaN or infinite, or if a pass,
     *     report or save fails
     */
    public double trainConvNet(ConvNet model, double learningRate, int epochs, TensorDataSet data, int reportEveryNthEpoch, String savefile, boolean loadFromSave, boolean autosave, Random rng) throws Exception {
        if (loadFromSave) {
            System.out.println("initializing model from saved state...");
            try {
                FileIO.loadNeuralNetwork(savefile, model);
            }
            catch (Exception e) {
                // Deliberate best effort: fall back to the model as it was passed in.
                NewTrainer.reportLoadFailure(e);
            }
        }
        double result = 1.0;
        for (int epoch = 0; epoch < epochs; ++epoch) {
            String header = NewTrainer.epochHeader(epoch, epochs);
            System.out.println(header);

            long startTime = System.currentTimeMillis();
            double trainLoss = this.passTensors(learningRate, model, data.training, true, data.lossTraining, data.lossReporting);
            System.err.println(System.currentTimeMillis() - startTime);
            System.gc();
            NewTrainer.requireFiniteTrainingLoss(trainLoss);
            result = trainLoss;

            boolean hasValidation = data.validation != null;
            boolean hasTesting = data.testing != null;
            double validationLoss = 0.0;
            double testingLoss = 0.0;
            if (hasValidation) {
                validationLoss = this.passTensors(learningRate, model, data.validation, false, data.lossTraining, data.lossReporting);
                result = validationLoss;
            }
            if (hasTesting) {
                testingLoss = this.passTensors(learningRate, model, data.testing, false, data.lossTraining, data.lossReporting);
                result = testingLoss;
            }
            System.out.println(NewTrainer.epochSummary(header, trainLoss, hasValidation, validationLoss, hasTesting, testingLoss));

            if (NewTrainer.isReportEpoch(epoch, reportEveryNthEpoch)) {
                data.DisplayReport(model, rng);
            }
            if (autosave) {
                System.out.println("Autosaving model...");
                FileIO.saveNeuralNetwork(savefile, model);
                System.out.println("Complete!");
            }
            if (trainLoss == 0.0 && validationLoss == 0.0) {
                System.out.println("Optimal loss reached. Trainer done.");
                break;
            }
        }
        return result;
    }

    /**
     * Runs the model over every sequence once and returns the mean reported loss.
     *
     * <p>{@code model.t} is incremented once per call, for evaluation passes as well as
     * training passes. For every step that has a target output the reporting loss is measured;
     * when {@code applyTraining} is set, the training loss is back-propagated per step and the
     * parameters are updated once per sequence.
     *
     * <p>Console output is throttled to at most one "Passing sequence" line per second (and the
     * per-sequence backprop messages are suppressed) once the first sequence has been seen to
     * complete in under 2048 ms.
     *
     * @param learningRate learning rate forwarded to the training method
     * @param model the network to run
     * @param sequences the sequences to process
     * @param applyTraining whether to back-propagate and update parameters
     * @param lossTraining loss whose {@code backward} is called when training
     * @param lossReporting loss used for the returned value
     * @return the sum of reported losses divided by the number of steps with a target; NaN when
     *     no step had a target (0.0 / 0.0); or the first NaN/infinite step loss, returned
     *     immediately without processing further steps
     * @throws Exception if the model, graph, loss or training method fails
     */
    public double pass(double learningRate, NeuralNetwork model, List<DataSequence> sequences, boolean applyTraining, Loss lossTraining, Loss lossReporting) throws Exception {
        ++model.t;
        long updateNanos = 0L;
        long updateCount = 0L;
        double numerLoss = 0.0;
        double denomLoss = 0.0;
        int sequenceIndex = 0;
        boolean firstSequence = true;
        boolean throttleOutput = false;
        long lastProgressPrint = System.currentTimeMillis();
        for (DataSequence seq : sequences) {
            if (!throttleOutput || System.currentTimeMillis() - lastProgressPrint >= 1000L) {
                System.out.println("Trainer: Passing sequence " + (sequenceIndex + 1) + "/" + sequences.size());
                lastProgressPrint = System.currentTimeMillis();
            }
            ++sequenceIndex;
            model.resetState();
            Graph g = new Graph(applyTraining);
            long startTime = System.currentTimeMillis();
            double stepsDone = 0.0;
            for (DataStep step : seq.steps) {
                // Long sequences: emit a progress line (and request a GC) every 10 seconds.
                if (System.currentTimeMillis() - startTime >= 10000L) {
                    System.out.println((stepsDone / (double)seq.steps.size() * 100.0) + "% done with forward pass");
                    startTime = System.currentTimeMillis();
                    System.gc();
                }
                stepsDone += 1.0;
                Matrix output = model.forward(step.input, g);
                if (step.targetOutput == null) {
                    continue;
                }
                double loss = lossReporting.measure(output, step.targetOutput);
                if (NewTrainer.isInvalidLoss(loss)) {
                    return loss;
                }
                numerLoss += loss;
                denomLoss += 1.0;
                if (applyTraining) {
                    lossTraining.backward(output, step.targetOutput);
                }
            }
            if (applyTraining) {
                if (!throttleOutput) {
                    System.out.println("Backpropagating and updating Model parameters...");
                }
                g.backward();
                long updateStart = System.nanoTime();
                this.method.updateParameters(model, learningRate, seq.steps.size());
                updateNanos += System.nanoTime() - updateStart;
                ++updateCount;
                g.cleanUp();
                System.gc();
                if (!throttleOutput) {
                    System.out.println("Epoch Complete!");
                }
            }
            // Decide once, after the first sequence, whether sequences are fast enough that
            // per-sequence output should be throttled.
            if (firstSequence) {
                firstSequence = false;
                if (System.currentTimeMillis() - startTime < 2048L) {
                    throttleOutput = true;
                }
            }
        }
        // Only meaningful when at least one parameter update was timed; evaluation passes
        // perform none and would otherwise print NaN.
        if (updateCount > 0L) {
            double averageMillis = (double)updateNanos / (double)updateCount / NANOS_PER_MILLI;
            System.out.println("Average pass time in MS is " + Double.toString(averageMillis));
        }
        return numerLoss / denomLoss;
    }

    /**
     * Runs the convolutional model over every tensor sequence once and returns the mean
     * reported loss per step.
     *
     * <p>{@code model.t} is incremented once per call. For every step that has a target output,
     * the reporting loss is measured separately for each depth slice of the output tensor and
     * summed; the denominator counts steps, not slices. When {@code applyTraining} is set, the
     * training loss is back-propagated per slice and the parameters are updated once per
     * sequence.
     *
     * @param learningRate learning rate forwarded to the training method
     * @param model the network to run
     * @param sequences the sequences to process
     * @param applyTraining whether to back-propagate and update parameters
     * @param lossTraining loss whose {@code backward} is called when training
     * @param lossReporting loss used for the returned value
     * @return the sum of reported slice losses divided by the number of steps with a target;
     *     NaN when no step had a target (0.0 / 0.0); or the first NaN/infinite slice loss,
     *     returned immediately without processing further steps
     * @throws Exception if the model, graph, loss or training method fails
     */
    public double passTensors(double learningRate, ConvNet model, List<TensorDataSequence> sequences, boolean applyTraining, Loss lossTraining, Loss lossReporting) throws Exception {
        ++model.t;
        double numerLoss = 0.0;
        double denomLoss = 0.0;
        int sequenceIndex = 0;
        for (TensorDataSequence seq : sequences) {
            System.out.println("Trainer: Passing sequence " + (sequenceIndex + 1) + "/" + sequences.size());
            ++sequenceIndex;
            model.resetState();
            Graph g = new Graph(applyTraining);
            long startTime = System.currentTimeMillis();
            double stepsDone = 0.0;
            for (int stepIndex = 0; stepIndex < seq.getSequenceLength(); ++stepIndex) {
                TensorDataStep step = seq.getDataStep(stepIndex);
                // Long sequences: emit a progress line (and request a GC) every 10 seconds.
                if (System.currentTimeMillis() - startTime >= 10000L) {
                    System.out.println((stepsDone / (double)seq.getSequenceLength() * 100.0) + "% done with forward pass");
                    startTime = System.currentTimeMillis();
                    System.gc();
                }
                stepsDone += 1.0;
                Tensor out = model.forward(step.input, g);
                if (step.targetOutput == null) {
                    continue;
                }
                for (int slice = 0; slice < out.getDepth(); ++slice) {
                    double loss = lossReporting.measure(out.getMatrixAt(slice), step.targetOutput.getMatrixAt(slice));
                    if (NewTrainer.isInvalidLoss(loss)) {
                        return loss;
                    }
                    numerLoss += loss;
                    if (applyTraining) {
                        lossTraining.backward(out.getMatrixAt(slice), step.targetOutput.getMatrixAt(slice));
                    }
                }
                denomLoss += 1.0;
            }
            if (applyTraining) {
                System.out.println("Backpropagating and updating Model parameters...");
                g.backward();
                this.method.updateParameters(model, learningRate, seq.getSequenceLength());
                g.cleanUp();
                System.gc();
                System.out.println("Epoch Complete!");
            }
        }
        return numerLoss / denomLoss;
    }

    /** Prints the console notice emitted when loading a saved model fails. */
    private static void reportLoadFailure(Exception e) {
        System.out.println("Oops. Unable to load from a saved state.");
        System.out.println("WARNING: " + e.getMessage());
        System.out.println("Continuing from freshly initialized model instead.");
    }

    /** Builds the "Trainer: epoch[i/n]" header for the zero-based {@code epoch}. */
    private static String epochHeader(int epoch, int epochs) {
        return "Trainer: epoch[" + (epoch + 1) + "/" + epochs + "]";
    }

    /** Appends the formatted losses of one epoch to {@code header}. */
    private static String epochSummary(String header, double trainLoss, boolean hasValidation, double validationLoss, boolean hasTesting, double testingLoss) {
        StringBuilder summary = new StringBuilder(header);
        summary.append("train loss = ").append(String.format("%.5f", trainLoss));
        if (hasValidation) {
            summary.append("valid loss = ").append(String.format("%.5f", validationLoss));
        }
        if (hasTesting) {
            summary.append("test loss  = ").append(String.format("%.5f", testingLoss));
        }
        return summary.toString();
    }

    /**
     * Tells whether the zero-based {@code epoch} is the last one of a reporting interval.
     * A non-positive interval never reports (and avoids a division by zero in the remainder).
     */
    private static boolean isReportEpoch(int epoch, int reportEveryNthEpoch) {
        return reportEveryNthEpoch > 0 && epoch % reportEveryNthEpoch == reportEveryNthEpoch - 1;
    }

    /** Tells whether {@code loss} is NaN or infinite. */
    private static boolean isInvalidLoss(double loss) {
        return Double.isNaN(loss) || Double.isInfinite(loss);
    }

    /** Throws when the training loss of an epoch is NaN or infinite. */
    private static void requireFiniteTrainingLoss(double loss) throws Exception {
        if (NewTrainer.isInvalidLoss(loss)) {
            throw new Exception("WARNING: invalid value for training loss. Try lowering learning rate.");
        }
    }
}

