/*
 * Decompiled with CFR 0.152.
 */
package trainer;

import matrix.Matrix;
import model.ConvNet;
import model.NeuralNetwork;
import trainer.TrainingMethod;

/**
 * Adam (adaptive moment estimation) parameter update rule.
 *
 * <p>Each parameter {@link Matrix} is expected to carry a {@code stepCache} of at least
 * {@code 2 * w.length} entries. The first {@code w.length} slots hold the running average of the
 * gradient (first moment). The next {@code w.length} slots hold the running average of the squared
 * gradient (second moment).
 *
 * <p>Rather than storing raw moments and dividing by {@code 1 - beta^t} at every step, the decay
 * rate is rescaled each step to {@code beta * (1 - beta^(t-1)) / (1 - beta^t)}. With that factor
 * the cached moments are directly the bias-corrected estimates. This assumes {@code network.t} is a
 * step counter that starts at 1 and grows by one per update. On the first step (t == 1) the factor
 * is 0, so both moments are initialised from the current gradient. Note that {@code epsilon} is
 * added <em>inside</em> the square root of the second moment.
 */
public class Adam extends TrainingMethod {
    private static final double DEFAULT_DECAY_RATE_B1 = 0.9;
    private static final double DEFAULT_DECAY_RATE_B2 = 0.999;
    private static final double DEFAULT_EPSILON = 1.0E-8;

    /** Exponential decay rate of the first-moment (gradient mean) estimate. */
    private final double decayRateB1;
    /** Exponential decay rate of the second-moment (squared gradient) estimate. */
    private final double decayRateB2;
    /** Numerical-stability constant added to the second moment before the square root. */
    private final double epsilon;

    /**
     * Creates an Adam optimizer with explicit hyper-parameters.
     *
     * @param decayRateB1 decay rate of the first moment; must be in {@code [0, 1)}
     * @param decayRateB2 decay rate of the second moment; must be in {@code [0, 1)}
     * @param epsilon stability constant; must be non-negative
     * @throws IllegalArgumentException if any argument is out of range or NaN
     */
    public Adam(double decayRateB1, double decayRateB2, double epsilon) {
        this.decayRateB1 = requireDecayRate("decayRateB1", decayRateB1);
        this.decayRateB2 = requireDecayRate("decayRateB2", decayRateB2);
        // A negative epsilon lets sqrt(v + eps) become NaN when v == 0.
        if (!(epsilon >= 0.0)) {
            throw new IllegalArgumentException("epsilon must be >= 0 but was " + epsilon);
        }
        this.epsilon = epsilon;
    }

    /** Creates an Adam optimizer with decay rates 0.9 / 0.999 and epsilon 1e-8. */
    public Adam() {
        this(DEFAULT_DECAY_RATE_B1, DEFAULT_DECAY_RATE_B2, DEFAULT_EPSILON);
    }

    /**
     * Applies one Adam step to every parameter matrix of a fully-connected network and clears the
     * accumulated gradients.
     *
     * @param network network whose {@code getParameters()} are updated in place; {@code network.t}
     *     is read as the (1-based) step counter
     * @param learningRate step size
     * @param batchSize not used by this implementation; {@code dw} is consumed as accumulated
     * @throws IllegalStateException if {@code network.t < 1}
     */
    @Override
    public void updateParameters(NeuralNetwork network, double learningRate, int batchSize) throws Exception {
        final double b1 = correctedDecay(this.decayRateB1, network.t);
        final double b2 = correctedDecay(this.decayRateB2, network.t);
        for (Matrix m : network.getParameters()) {
            applyUpdate(m, b1, b2, learningRate, this.epsilon);
        }
    }

    /**
     * Applies one Adam step to every parameter matrix of a convolutional network and clears the
     * accumulated gradients.
     *
     * @param network network whose {@code getParameters()} are updated in place; {@code network.t}
     *     is read as the (1-based) step counter
     * @param learningRate step size
     * @param batchSize not used by this implementation; {@code dw} is consumed as accumulated
     * @throws IllegalStateException if {@code network.t < 1}
     */
    @Override
    public void updateParameters(ConvNet network, double learningRate, int batchSize) throws Exception {
        final double b1 = correctedDecay(this.decayRateB1, network.t);
        final double b2 = correctedDecay(this.decayRateB2, network.t);
        for (Matrix m : network.getParameters()) {
            applyUpdate(m, b1, b2, learningRate, this.epsilon);
        }
    }

    /** Validates that a decay rate lies in [0, 1); a rate of 1 would make the factor 0/0. */
    private static double requireDecayRate(String name, double rate) {
        if (!(rate >= 0.0 && rate < 1.0)) {
            throw new IllegalArgumentException(name + " must be in [0, 1) but was " + rate);
        }
        return rate;
    }

    /**
     * Returns {@code rate * (1 - rate^(t-1)) / (1 - rate^t)}. It is the per-step decay that keeps
     * the cached moments bias-corrected.
     */
    private static double correctedDecay(double rate, double t) {
        // For t < 1 the denominator collapses to (nearly) zero and the weights would be corrupted.
        if (!(t >= 1.0)) {
            throw new IllegalStateException("Adam step counter t must be >= 1 but was " + t);
        }
        final double prevPow = Math.pow(rate, t - 1.0);
        return rate * (1.0 - prevPow) / (1.0 - prevPow * rate);
    }

    /** Updates both moment estimates and the weights of one matrix, then zeroes its gradients. */
    private static void applyUpdate(Matrix m, double b1, double b2, double learningRate, double epsilon) {
        final int n = m.w.length;
        for (int i = 0; i < n; i++) {
            final double grad = m.dw[i];
            final double mean = b1 * m.stepCache[i] + (1.0 - b1) * grad;
            final double sqMean = b2 * m.stepCache[i + n] + (1.0 - b2) * (grad * grad);
            m.stepCache[i] = mean;
            m.stepCache[i + n] = sqMean;
            m.w[i] = m.w[i] - learningRate * mean / Math.sqrt(sqMean + epsilon);
            m.dw[i] = 0.0;
        }
    }
}

