/*
 * Decompiled with CFR 0.152.
 */
package trainer;

import matrix.Matrix;
import model.ConvNet;
import model.NeuralNetwork;
import theGhastModding.lstmStuff.gpu.GPUUtils;
import trainer.TrainingMethod;

/**
 * Adam optimizer applied to every {@link Matrix} returned by a network's {@code getParameters()}.
 *
 * <p>On the CPU path the optimizer keeps <em>bias-corrected</em> moment estimates directly in
 * {@code Matrix.stepCache}. For a matrix with {@code n = w.length} weights:
 * <ul>
 *   <li>{@code stepCache[i]} holds the first-moment estimate for weight {@code i}, and</li>
 *   <li>{@code stepCache[i + n]} holds the second-moment estimate for weight {@code i},</li>
 * </ul>
 * so {@code stepCache} must have at least {@code 2 * n} entries.
 *
 * <p>Instead of storing the raw moments and dividing by {@code 1 - beta^t} on every step, the
 * stored value is already corrected. It is advanced with the per-step factor
 * {@code c_t = beta * (1 - beta^(t-1)) / (1 - beta^t)} as
 * {@code mHat_t = c_t * mHat_(t-1) + (1 - c_t) * g}. Algebraically this equals
 * {@code m_t / (1 - beta^t)} of the textbook formulation. It requires {@code t >= 1}.
 *
 * <p>Note: epsilon is added <em>inside</em> the square root ({@code sqrt(v + eps)}), not after
 * it as in the original paper ({@code sqrt(v) + eps}). This behavior is intentionally preserved.
 */
public class Adam
extends TrainingMethod {
    private final double decayRateB1;
    private final double decayRateB2;
    private final double epsilon;

    /**
     * Creates an Adam optimizer.
     *
     * @param decayRateB1 exponential decay rate for the first-moment estimate; must be in [0, 1)
     * @param decayRateB2 exponential decay rate for the second-moment estimate; must be in [0, 1)
     * @param epsilon     constant added to the second-moment estimate inside the square root
     * @throws IllegalArgumentException if either decay rate is outside [0, 1) or is NaN
     */
    public Adam(double decayRateB1, double decayRateB2, double epsilon) {
        this.decayRateB1 = requireDecayRate(decayRateB1, "decayRateB1");
        this.decayRateB2 = requireDecayRate(decayRateB2, "decayRateB2");
        this.epsilon = epsilon;
    }

    /**
     * Applies one Adam step to all parameters of {@code network} and clears their gradients.
     *
     * @param network      network whose parameters are updated; {@code network.t} is the step count
     * @param learningRate step size
     * @param batchSize    not used by this optimizer
     * @param gpu          if non-null, the update is delegated to {@link GPUUtils#adam}
     * @throws IllegalStateException if the CPU path is taken and {@code network.t < 1}
     * @throws Exception             propagated from the GPU implementation
     */
    @Override
    public void updateParameters(NeuralNetwork network, double learningRate, int batchSize, GPUUtils gpu) throws Exception {
        if (gpu != null) {
            for (Matrix m : network.getParameters()) {
                copyGpuResult(m, gpu.adam(m, this.decayRateB1, this.decayRateB2, network.t, learningRate, this.epsilon));
            }
            return;
        }
        double t = requireValidTimestep(network.t);
        // Loop-invariant: depends only on the betas and the step count, not on the element.
        double b1Step = correctedDecay(this.decayRateB1, t);
        double b2Step = correctedDecay(this.decayRateB2, t);
        for (Matrix m : network.getParameters()) {
            adamStep(m, b1Step, b2Step, learningRate);
        }
    }

    /**
     * Applies one Adam step to all parameters of {@code network} and clears their gradients.
     *
     * @param network      network whose parameters are updated; {@code network.t} is the step count
     * @param learningRate step size
     * @param batchSize    not used by this optimizer
     * @param gpu          if non-null, the update is delegated to {@link GPUUtils#adam}
     * @throws IllegalStateException if the CPU path is taken and {@code network.t < 1}
     * @throws Exception             propagated from the GPU implementation
     */
    @Override
    public void updateParameters(ConvNet network, double learningRate, int batchSize, GPUUtils gpu) throws Exception {
        if (gpu != null) {
            for (Matrix m : network.getParameters()) {
                copyGpuResult(m, gpu.adam(m, this.decayRateB1, this.decayRateB2, network.t, learningRate, this.epsilon));
            }
            return;
        }
        double t = requireValidTimestep(network.t);
        // Loop-invariant: depends only on the betas and the step count, not on the element.
        double b1Step = correctedDecay(this.decayRateB1, t);
        double b2Step = correctedDecay(this.decayRateB2, t);
        for (Matrix m : network.getParameters()) {
            adamStep(m, b1Step, b2Step, learningRate);
        }
    }

    /** Adopts the weight, step-cache and gradient arrays produced by the GPU update. */
    private static void copyGpuResult(Matrix target, Matrix result) {
        target.w = result.w;
        target.stepCache = result.stepCache;
        target.dw = result.dw;
    }

    /**
     * Performs the CPU Adam update on one matrix using precomputed per-step decay factors,
     * then zeroes its gradient.
     */
    private void adamStep(Matrix m, double b1Step, double b2Step, double learningRate) {
        final int n = m.w.length;
        for (int i = 0; i < n; i++) {
            double g = m.dw[i];
            m.stepCache[i] = b1Step * m.stepCache[i] + (1.0 - b1Step) * g;
            m.stepCache[i + n] = b2Step * m.stepCache[i + n] + (1.0 - b2Step) * (g * g);
            m.w[i] = m.w[i] - learningRate * m.stepCache[i] / Math.sqrt(m.stepCache[i + n] + this.epsilon);
            m.dw[i] = 0.0;
        }
    }

    /**
     * Returns {@code beta * (1 - beta^(t-1)) / (1 - beta^t)}. This is the decay factor that
     * advances an already bias-corrected moving average from step {@code t-1} to step {@code t}.
     */
    private static double correctedDecay(double beta, double t) {
        return beta * (1.0 - Math.pow(beta, t - 1.0)) / (1.0 - Math.pow(beta, t));
    }

    /** Rejects step counts for which the bias correction is undefined (division by zero at t = 0). */
    private static double requireValidTimestep(double t) {
        if (!(t >= 1.0)) {
            throw new IllegalStateException("Adam requires a step count t >= 1 but got t = " + t);
        }
        return t;
    }

    /** Validates that a decay rate lies in [0, 1); NaN is rejected as well. */
    private static double requireDecayRate(double value, String name) {
        if (!(value >= 0.0 && value < 1.0)) {
            throw new IllegalArgumentException(name + " must be in [0, 1) but was " + value);
        }
        return value;
    }
}

