/*
 * Decompiled with CFR 0.152.
 */
package trainer;

import java.util.List;
import matrix.Matrix;
import trainer.Optimizer;

/**
 * Adam optimizer with bias-corrected first- and second-moment estimates plus an additional
 * weight-decay term.
 *
 * <p>Per-element optimizer state is stored in {@code Matrix.stepCache}:
 * <ul>
 *   <li>slots {@code [0, w.length)} hold the first-moment (running mean of the gradient) estimate;
 *   <li>slots {@code [w.length, 2 * w.length)} hold the second-moment (running mean of the squared
 *       gradient) estimate.
 * </ul>
 * So {@code stepCache} must have at least {@code 2 * w.length} entries.
 */
public class Adam extends Optimizer {
    /** Weight-decay coefficient used by the constructors that do not take one explicitly. */
    public static final double DEFAULT_REGULARIZATION = 2.5E-4;

    private final double decayRateB1;
    private final double decayRateB2;
    private final double epsilon;
    private final double regularization;

    /**
     * Creates an Adam optimizer with the given hyper-parameters and
     * {@link #DEFAULT_REGULARIZATION} weight decay.
     *
     * @param decayRateB1 exponential decay rate of the first-moment estimate (beta1)
     * @param decayRateB2 exponential decay rate of the second-moment estimate (beta2)
     * @param epsilon small constant added to the update denominator for numerical stability
     */
    public Adam(double decayRateB1, double decayRateB2, double epsilon) {
        this(decayRateB1, decayRateB2, epsilon, DEFAULT_REGULARIZATION);
    }

    /**
     * Creates an Adam optimizer with an explicit weight-decay coefficient.
     *
     * @param decayRateB1 exponential decay rate of the first-moment estimate (beta1)
     * @param decayRateB2 exponential decay rate of the second-moment estimate (beta2)
     * @param epsilon small constant added to the update denominator for numerical stability
     * @param regularization weight-decay coefficient; every step subtracts
     *     {@code regularization * w} from each weight. This term is NOT scaled by the
     *     learning rate. Pass {@code 0.0} to disable weight decay.
     */
    public Adam(double decayRateB1, double decayRateB2, double epsilon, double regularization) {
        this.decayRateB1 = decayRateB1;
        this.decayRateB2 = decayRateB2;
        this.epsilon = epsilon;
        this.regularization = regularization;
    }

    /**
     * Creates an Adam optimizer with beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8 and
     * {@link #DEFAULT_REGULARIZATION} weight decay.
     */
    public Adam() {
        this(0.9, 0.999, 1.0E-8);
    }

    /**
     * Applies one Adam step, in place, to every parameter matrix and then clears its gradient.
     *
     * <p>Each {@code dw} entry is divided by {@code batchSize} before use, so {@code dw} is treated
     * as a sum of gradients over the batch. This is presumably accumulated during back-propagation;
     * confirm with the caller. After the update every {@code dw} entry is reset to {@code 0.0}.
     *
     * @param parameters matrices whose {@code w} is updated and whose {@code dw} is cleared
     * @param learningRate step size applied to the bias-corrected Adam update
     * @param timestep 1-based step counter used for bias correction; must be at least 1
     * @param batchSize number of samples {@code dw} was accumulated over; must be at least 1
     * @throws IllegalArgumentException if {@code timestep < 1} or {@code batchSize < 1}. Such
     *     values would otherwise divide by zero (or a negative number) and fill the weights with
     *     NaN/Infinity. The check runs before any parameter is modified.
     */
    @Override
    public void updateParameters(List<Matrix> parameters, double learningRate, int timestep, int batchSize) throws Exception {
        if (timestep < 1) {
            throw new IllegalArgumentException(
                    "timestep must be >= 1 (bias correction divides by 1 - beta^t), got " + timestep);
        }
        if (batchSize < 1) {
            throw new IllegalArgumentException("batchSize must be >= 1, got " + batchSize);
        }

        // Bias-correction denominators (1 - beta^t); identical for every element, so computed once.
        double biasCorrection1 = 1.0 - Math.pow(decayRateB1, timestep);
        double biasCorrection2 = 1.0 - Math.pow(decayRateB2, timestep);
        // Loop-invariant factors, grouped exactly as the per-element expressions evaluate them.
        double firstMomentScale = (1.0 - decayRateB1) / (double) batchSize;
        double secondMomentScale = 1.0 - decayRateB2;

        for (Matrix m : parameters) {
            int n = m.w.length;
            for (int i = 0; i < n; i++) {
                // Batch-averaged gradient for this element.
                double grad = m.dw[i] / (double) batchSize;

                // Exponential moving averages of the gradient and the squared gradient.
                m.stepCache[i] = decayRateB1 * m.stepCache[i] + firstMomentScale * m.dw[i];
                m.stepCache[i + n] = decayRateB2 * m.stepCache[i + n] + secondMomentScale * (grad * grad);

                double mHat = m.stepCache[i] / biasCorrection1;
                double vHat = m.stepCache[i + n] / biasCorrection2;

                // Adam step plus weight decay; both terms use the weight value from before this step.
                m.w[i] = m.w[i] - mHat * learningRate / (Math.sqrt(vHat) + epsilon) - regularization * m.w[i];

                // Gradient consumed: clear it for the next accumulation.
                m.dw[i] = 0.0;
            }
        }
    }
}

