/*
 * Decompiled with CFR 0.152.
 */
package trainer;

import matrix.Matrix;
import model.ConvNet;
import model.NeuralNetwork;
import trainer.TrainingMethod;

/**
 * Plain stochastic gradient descent with optional classical momentum.
 *
 * <p>For every parameter element the update is:
 * <pre>
 *   v  = momentum * v + dw
 *   w  = w - learningRate * v
 *   dw = 0
 * </pre>
 * where {@code v} is stored in {@code Matrix.stepCache}. With {@code momentum == 0}
 * this reduces to vanilla SGD ({@code w -= learningRate * dw}).
 */
public class BasicSGD
extends TrainingMethod {
    /** Decay factor applied to the previous step (the velocity) before adding the new gradient. */
    private final double momentum;

    /** Creates a plain SGD trainer without momentum. */
    public BasicSGD() {
        this(0.0);
    }

    /**
     * Creates an SGD trainer with classical momentum.
     *
     * @param momentum factor multiplying the previous step in {@code stepCache}; 0 disables momentum
     */
    public BasicSGD(double momentum) {
        this.momentum = momentum;
    }

    /**
     * Applies one SGD(+momentum) step to every parameter matrix of {@code network}
     * and clears the accumulated gradients.
     *
     * @param network      the network whose parameters are updated in place
     * @param learningRate step size multiplying the momentum buffer
     * @param batchSize    currently unused; NOTE(review): confirm whether {@code dw} is
     *                     already averaged over the batch upstream
     * @throws Exception   declared by {@link TrainingMethod}; not thrown here directly
     */
    @Override
    public void updateParameters(NeuralNetwork network, double learningRate, int batchSize) throws Exception {
        for (Matrix m : network.getParameters()) {
            this.step(m, learningRate);
        }
    }

    /**
     * Applies one SGD(+momentum) step to every parameter matrix of {@code network}
     * and clears the accumulated gradients.
     *
     * @param network      the convolutional network whose parameters are updated in place
     * @param learningRate step size multiplying the momentum buffer
     * @param batchSize    currently unused; NOTE(review): confirm whether {@code dw} is
     *                     already averaged over the batch upstream
     * @throws Exception   declared by {@link TrainingMethod}; not thrown here directly
     */
    @Override
    public void updateParameters(ConvNet network, double learningRate, int batchSize) throws Exception {
        for (Matrix m : network.getParameters()) {
            this.step(m, learningRate);
        }
    }

    /**
     * Updates a single parameter matrix in place: refreshes its momentum buffer,
     * moves the weights against it scaled by {@code learningRate}, and zeroes the gradient.
     */
    private void step(Matrix m, double learningRate) {
        for (int i = 0; i < m.w.length; ++i) {
            m.stepCache[i] = m.stepCache[i] * this.momentum + m.dw[i];
            // Previously the learning rate was ignored (an implicit step size of 1);
            // scale the step so the caller-supplied learningRate takes effect.
            m.w[i] -= learningRate * m.stepCache[i];
            // Reset the gradient so the next backward pass accumulates from zero.
            m.dw[i] = 0.0;
        }
    }
}

