/*
 * Decompiled with CFR 0.152.
 */
package trainer;

import autodiff.GPUGraph;
import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import java.util.List;
import java.util.Random;
import matrix.Matrix;
import matrix.Tensor;
import model.NeuralNetwork;
import trainer.Optimizer;
import util.CLUtils;
import util.FileIO;
import util.NNDevice;

public class GANTrainer {
    private Optimizer method;
    private double gamma;
    private Graph trainGraph;
    public double k_s = 0.0;

    public GANTrainer(Optimizer method, double gamma) {
        this.method = method;
        this.gamma = gamma;
        this.trainGraph = new Graph(true);
    }

    public GANTrainer(Optimizer method, double gamma, NNDevice ... dev) throws Exception {
        this.method = method;
        this.gamma = gamma;
        this.trainGraph = this.createGraph(dev, true);
    }

    public double train(NeuralNetwork generator, NeuralNetwork discriminator, double learningRate, int epochs, int batchSize, DataSet data, int reportEveryNthEpoch, String generatorSavefile, String discriminatorSavefile, boolean loadFromSave, boolean autosave, Random rng) throws Exception {
        if (this.trainGraph == null) {
            throw new Exception("Trainer disposed");
        }
        if (loadFromSave) {
            System.out.println("initializing model from saved state...");
            try {
                FileIO.loadNeuralNetwork(generatorSavefile, generator);
                FileIO.loadNeuralNetwork(discriminatorSavefile, discriminator);
            }
            catch (Exception e) {
                System.out.println("Oops. Unable to load from a saved state.");
                System.out.println("WARNING: " + e.getMessage());
                System.out.println("Continuing from freshly initialized model instead.");
            }
        }
        double result = 1.0;
        int epoch = 0;
        while (epoch < epochs) {
            double reportedLossTrain;
            String show = "Trainer: epoch[" + Integer.toString(epoch + 1) + "/" + Integer.toString(epochs) + "]";
            System.out.println(show);
            result = reportedLossTrain = this.passTensors(learningRate, generator, discriminator, batchSize, data.inputDimension, data.training, true, rng);
            if (Double.isNaN(reportedLossTrain) || Double.isInfinite(reportedLossTrain)) {
                throw new Exception("WARNING: invalid value for training loss. Try lowering learning rate.");
            }
            show = "train loss = " + String.format("%.5f", reportedLossTrain);
            System.out.println(show);
            if (epoch % reportEveryNthEpoch == reportEveryNthEpoch - 1) {
                data.DisplayReport(generator, rng);
            }
            if (autosave) {
                System.out.println("Autosaving model...");
                FileIO.saveNeuralNetwork(generatorSavefile, generator);
                FileIO.saveNeuralNetwork(discriminatorSavefile, discriminator);
                System.out.println("Complete!");
            }
            if (reportedLossTrain == 0.0) {
                System.out.println("Optimal loss reached. Trainer done.");
                break;
            }
            ++epoch;
        }
        return result;
    }

    public void dispose() {
        this.trainGraph.cleanUp();
        this.trainGraph = null;
    }

    private synchronized Graph createGraph(NNDevice[] dev, boolean applyTraining) throws Exception {
        if (dev != null && dev[0] != null && dev.length != 0) {
            if (dev.length == 1) {
                return CLUtils.createGraph(dev[0], applyTraining);
            }
            return CLUtils.createGraph(dev[0], dev[1], applyTraining);
        }
        return new Graph(applyTraining);
    }

    /*
     * Unable to fully structure code
     */
    private double passTensors(double learningRate, NeuralNetwork generator, NeuralNetwork discriminator, int batchSize, DataSet.TensorDimensions seedSize, List<DataSequence> sequences, boolean applyTraining, Random rng) throws Exception {
        numLossDisc = 0.0;
        denomLossDisc = 0.0;
        numLossGen = 0.0;
        denomLossGen = 0.0;
        genIn = null;
        discIn2 = null;
        genOut = null;
        discOut = null;
        discOut2 = null;
        System.out.print("|");
        i = 0;
        while (i < 98) {
            System.out.print("-");
            ++i;
        }
        System.out.println("|");
        totalSteps = batchSize;
        prevPercent = 0;
        if (this.trainGraph instanceof GPUGraph) {
            ((GPUGraph)this.trainGraph).forceKeepParameters(true);
        }
        j = 0;
        ** GOTO lbl60
        {
            System.out.print(">");
            ++prevPercent;
            do {
                if ((int)((double)j / (double)totalSteps * 100.0) > prevPercent) continue block1;
                discriminator.resetState();
                generator.resetState();
                genIn = this.betterRandom(seedSize, rng);
                this.trainGraph.setApplyingBackprop(false);
                genOut = generator.forward(genIn, this.trainGraph);
                this.trainGraph.setApplyingBackprop(true);
                discIn2 = this.getRandomTensorStep(sequences, (Random)rng).input.clone();
                i = 0;
                while (i < genOut.depth) {
                    m1 = genOut.matrices[i];
                    m2 = discIn2.matrices[i];
                    l = 0;
                    while (l < m1.w.length) {
                        v0 = l;
                        m1.w[v0] = m1.w[v0] + (2.0 * rng.nextDouble() - 1.0) * 0.1;
                        v1 = l++;
                        m2.w[v1] = m2.w[v1] + (2.0 * rng.nextDouble() - 1.0) * 0.1;
                    }
                    ++i;
                }
                genOut2 = genOut.clone();
                discOut = generator.forward(discriminator.forward(genOut2, this.trainGraph), this.trainGraph);
                discOut2 = generator.forward(discriminator.forward(discIn2, this.trainGraph), this.trainGraph);
                numLossDisc += this.full_d_loss(discIn2, genOut2, discOut2, discOut, this.k_s);
                denomLossDisc += 1.0;
                numLossGen += this.L_loss(discOut, genOut);
                denomLossGen += 1.0;
                if (applyTraining) {
                    this.full_d_loss_backward(discIn2, genOut2, discOut2, discOut, this.k_s);
                    this.L_loss_backward(discOut, genOut, 1.0);
                }
                this.k_s += 0.001 * (this.gamma * this.L_loss(discIn2, discOut2) - this.L_loss(genOut, discOut));
                this.k_s = Math.max(0.0, Math.min(1.0, this.k_s));
                ++j;
lbl60:
                // 2 sources

            } while (j < batchSize);
        }
        if (applyTraining) {
            this.trainGraph.backward();
            this.method.updateParameters(discriminator, learningRate, batchSize);
            this.method.updateParameters(generator, learningRate, batchSize);
        }
        if (this.trainGraph instanceof GPUGraph) {
            ((GPUGraph)this.trainGraph).forceKeepParameters(false);
        }
        i = 0;
        while (i < 100 - prevPercent) {
            System.out.print(">");
            ++i;
        }
        System.out.println();
        return Math.abs((numLossDisc / denomLossDisc + numLossGen / denomLossGen) * 0.5);
    }

    private DataStep getRandomTensorStep(List<DataSequence> s, Random rng) {
        DataSequence seq = s.get(rng.nextInt(s.size()));
        return seq.getDataStep(rng.nextInt(seq.getSequenceLength()));
    }

    private Tensor betterRandom(DataSet.TensorDimensions dim, Random r) {
        Tensor t = new Tensor(dim.getWidth(), dim.getHeight(), dim.getDepth());
        double d = 0.0;
        int i = 0;
        while (i < dim.getDepth()) {
            Matrix m = t.matrices[i];
            int j = 0;
            while (j < m.w.length) {
                d = r.nextGaussian();
                while (d < 0.0 || d > 1.0) {
                    d = r.nextGaussian();
                }
                m.w[j] = d;
                ++j;
            }
            ++i;
        }
        return t;
    }

    public double D_loss(Matrix real_out, Matrix fake_out) throws Exception {
        return 0.5 * Math.pow(real_out.w[0] - 1.0, 2.0) + Math.pow(fake_out.w[0], 2.0);
    }

    public void D_loss_backward(Matrix real_out, Matrix fake_out) throws Exception {
        real_out.dw[0] = real_out.w[0];
        fake_out.dw[0] = 2.0 * fake_out.w[0];
    }

    public double G_loss(Matrix D_out) throws Exception {
        return 0.5 * Math.pow(D_out.w[0] - 1.0, 2.0);
    }

    public void G_loss_backward(Matrix D_out) throws Exception {
        D_out.dw[0] = D_out.w[0] - 1.0;
    }

    private double full_d_loss(Tensor real_original, Tensor fake_original, Tensor real_D, Tensor fake_D, double k) throws Exception {
        return this.L_loss(real_original, real_D) - k * this.L_loss(fake_original, fake_D);
    }

    private void full_d_loss_backward(Tensor real_original, Tensor fake_original, Tensor real_D, Tensor fake_D, double k) throws Exception {
        this.L_loss_backward(real_original, real_D, 1.0);
        this.L_loss_backward(fake_original, fake_D, -k);
    }

    private double L_loss(Tensor original, Tensor d_out) throws Exception {
        if (d_out.depth != original.depth || d_out.height != original.height || d_out.width != original.width) {
            throw new Exception("Invalid dims");
        }
        int n = d_out.depth * d_out.width * d_out.height;
        double res = 0.0;
        int i = 0;
        while (i < d_out.depth) {
            Matrix m1 = original.matrices[i];
            Matrix m2 = d_out.matrices[i];
            int j = 0;
            while (j < m1.w.length) {
                res += Math.abs(m1.w[j] - m2.w[j]);
                ++j;
            }
            ++i;
        }
        return res /= (double)n;
    }

    private void L_loss_backward(Tensor original, Tensor d_out, double k) throws Exception {
        if (d_out.depth != original.depth || d_out.height != original.height || d_out.width != original.width) {
            throw new Exception("Invalid dims");
        }
        int i = 0;
        while (i < d_out.depth) {
            Matrix m1 = original.matrices[i];
            Matrix m2 = d_out.matrices[i];
            int j = 0;
            while (j < m1.dw.length) {
                m2.dw[j] = k * ((m2.w[j] - m1.w[j]) / Math.abs(m2.w[j] - m1.w[j]));
                ++j;
            }
            ++i;
        }
    }
}

