/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import edu.cornell.lassp.houle.RngPack.RanMT;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import matrix.Matrix;
import model.Model;

/**
 * Dropout layer. On each forward pass, every element of the input matrix is
 * independently zeroed when a fresh draw from {@code rng.raw()} is
 * {@code <= chance}. When the graph records backprop, the gradient entries of
 * those zeroed elements are cleared as well.
 *
 * <p>The layer works <em>in place</em>: the matrix passed to {@link #forward}
 * is modified and returned as the output. Surviving elements are not rescaled
 * (there is no {@code 1/(1-chance)} factor). Dropout is applied on every call,
 * whether or not {@code g.applyBackprop()} is true.
 *
 * <p>The layer has no trainable parameters and no recurrent state.
 */
public class Dropout implements Model {
    /**
     * Drop threshold compared against {@code rng.raw()}. It acts as a drop
     * probability assuming {@code raw()} is uniform on (0,1) — TODO confirm
     * against RngPack docs.
     */
    private double chance;

    /** Random source for the per-element drop decisions; one per layer instance. */
    private final RanMT rng;

    /**
     * Creates a dropout layer with its own freshly constructed RNG.
     *
     * @param chance per-element drop threshold (see {@link #chance})
     */
    public Dropout(double chance) {
        this.chance = chance;
        this.rng = new RanMT();
    }

    /**
     * Zeroes a random subset of {@code input}'s elements in place and returns
     * {@code input} itself.
     *
     * @param input matrix to apply dropout to; its {@code w} array is mutated
     * @param g     graph; when {@code g.applyBackprop()} is true, a step is
     *              registered that clears {@code input.dw} at the dropped
     *              positions
     * @return the same {@code input} instance, after dropout
     * @throws Exception declared by the {@link Model} interface
     */
    @Override
    public Matrix forward(final Matrix input, Graph g) throws Exception {
        // Remember which positions were zeroed so the backprop step can find them.
        // This replaces a full Matrix used as a 0/1 mask.
        final boolean[] dropped = new boolean[input.w.length];
        for (int i = 0; i < dropped.length; i++) {
            if (this.rng.raw() <= this.chance) {
                input.w[i] = 0.0;
                dropped[i] = true;
            }
        }
        if (g.applyBackprop()) {
            g.addBackprop(new Runnable() {
                @Override
                public void run() {
                    // Output and input are the same matrix. Clearing dropped
                    // entries of input.dw removes their gradient. This presumes
                    // Graph runs backprop steps in reverse registration order,
                    // so downstream gradients have already accumulated — TODO confirm.
                    for (int i = 0; i < dropped.length; i++) {
                        if (dropped[i]) {
                            input.dw[i] = 0.0;
                        }
                    }
                }
            });
        }
        return input;
    }

    /** No recurrent state to reset; intentionally a no-op. */
    @Override
    public void resetState() {
    }

    /**
     * Returns this layer's trainable parameters. Dropout has none.
     *
     * @return a new, empty, mutable list (never {@code null}), so callers can
     *         iterate or {@code addAll} it without a null check
     */
    @Override
    public List<Matrix> getParameters() {
        return new ArrayList<Matrix>();
    }

    /**
     * Creates a new dropout layer with the same {@code chance}. The copy gets
     * its own freshly constructed RNG rather than sharing this one.
     *
     * @return a new {@code Dropout} instance
     */
    @Override
    public Model clone() {
        return new Dropout(this.chance);
    }

    /**
     * Writes {@code chance} as a single double. RNG state is not saved.
     *
     * @param fos destination stream
     * @throws Exception on write failure
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeDouble(this.chance);
    }

    /**
     * Reads {@code chance} back as a single double, mirroring
     * {@link #saveState}.
     *
     * @param fis source stream
     * @throws Exception on read failure
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.chance = fis.readDouble();
    }
}

