/*
 * Decompiled with CFR 0.152.
 */
package trainer;

import matrix.Matrix;
import model.ConvNet;
import model.NeuralNetwork;
import trainer.TrainingMethod;

/**
 * RMSProp optimizer with element-wise gradient clipping and a small weight-decay term.
 *
 * <p>For every element {@code i} of every parameter matrix the update is:
 * <pre>
 *   g        = dw[i]
 *   cache[i] = cache[i] * decayRate + (1 - decayRate) * g * g      (uses the raw, unclipped g)
 *   g'       = g clamped to [-gradientClipValue, gradientClipValue]
 *   w[i]    += -learningRate * g' / sqrt(cache[i] + smoothEpsilon) - regularization * w[i]
 *   dw[i]    = 0
 * </pre>
 * The running cache accumulates the unclipped gradient; clipping only affects the step itself.
 */
public class RMSProp extends TrainingMethod {
    private static final double DEFAULT_DECAY_RATE = 0.999;
    private static final double DEFAULT_SMOOTH_EPSILON = 1.0E-8;
    private static final double DEFAULT_GRADIENT_CLIP_VALUE = 5.0;
    private static final double DEFAULT_REGULARIZATION = 1.0E-6;

    /** Exponential decay applied to the running mean of squared gradients. */
    private final double decayRate;
    /** Added inside the square root to avoid division by zero. */
    private final double smoothEpsilon;
    /** Gradients are clamped to {@code [-gradientClipValue, gradientClipValue]} before the step. */
    private final double gradientClipValue;
    /** Coefficient of the weight-decay term subtracted on every update. */
    private final double regularization;

    /**
     * Creates an optimizer with every hyper-parameter specified.
     *
     * @param decayRate         decay of the squared-gradient cache (typically in [0, 1))
     * @param gradientClipValue absolute clipping bound applied to each gradient element
     * @param smoothEpsilon     smoothing term added under the square root
     * @param regularization    weight-decay coefficient applied as {@code w -= regularization * w}
     */
    public RMSProp(double decayRate, double gradientClipValue, double smoothEpsilon, double regularization) {
        this.decayRate = decayRate;
        this.gradientClipValue = gradientClipValue;
        this.smoothEpsilon = smoothEpsilon;
        this.regularization = regularization;
    }

    /** Creates an optimizer with the default regularization (1e-6). */
    public RMSProp(double decayRate, double gradientClipValue, double smoothEpsilon) {
        this(decayRate, gradientClipValue, smoothEpsilon, DEFAULT_REGULARIZATION);
    }

    /** Creates an optimizer with the default smoothing epsilon (1e-8) and regularization (1e-6). */
    public RMSProp(double decayRate, double gradientClipValue) {
        this(decayRate, gradientClipValue, DEFAULT_SMOOTH_EPSILON);
    }

    /** Creates an optimizer with the default clip value (5.0), epsilon (1e-8) and regularization (1e-6). */
    public RMSProp(double decayRate) {
        this(decayRate, DEFAULT_GRADIENT_CLIP_VALUE);
    }

    /** Creates an optimizer with all defaults (decay rate 0.999). */
    public RMSProp() {
        this(DEFAULT_DECAY_RATE);
    }

    /**
     * Applies one RMSProp step to every parameter of {@code network} and zeroes its gradients.
     *
     * @param network      network whose parameters are updated in place
     * @param learningRate step size
     * @param batchSize    currently unused; gradients in {@code dw} are applied as-is
     */
    @Override
    public void updateParameters(NeuralNetwork network, double learningRate, int batchSize) throws Exception {
        for (Matrix m : network.getParameters()) {
            this.updateMatrix(m, learningRate);
        }
    }

    /**
     * Applies one RMSProp step to every parameter of {@code network} and zeroes its gradients.
     *
     * @param network      network whose parameters are updated in place
     * @param learningRate step size
     * @param batchSize    currently unused; gradients in {@code dw} are applied as-is
     */
    @Override
    public void updateParameters(ConvNet network, double learningRate, int batchSize) throws Exception {
        for (Matrix m : network.getParameters()) {
            this.updateMatrix(m, learningRate);
        }
    }

    /** Performs the element-wise RMSProp update on one matrix and clears its gradient buffer. */
    private void updateMatrix(Matrix m, double learningRate) {
        for (int i = 0; i < m.w.length; i++) {
            double grad = m.dw[i];
            // The cache is updated with the raw gradient, before clipping.
            m.stepCache[i] = m.stepCache[i] * this.decayRate + (1.0 - this.decayRate) * grad * grad;
            // Two independent checks (not else-if) to keep the original clamping semantics exactly.
            if (grad > this.gradientClipValue) {
                grad = this.gradientClipValue;
            }
            if (grad < -this.gradientClipValue) {
                grad = -this.gradientClipValue;
            }
            m.w[i] += -learningRate * grad / Math.sqrt(m.stepCache[i] + this.smoothEpsilon)
                    - this.regularization * m.w[i];
            m.dw[i] = 0.0;
        }
    }
}

