/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import edu.cornell.lassp.houle.RngPack.RanMT;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.List;
import matrix.Matrix;
import matrix.Tensor;
import model.TensorLayer;

/**
 * Element-wise dropout layer for 3-D tensors indexed by (width, height, depth).
 *
 * <p>On each forward pass one random draw is taken per element. An element whose draw
 * is at or below {@code chance} is zeroed in place. The tensor passed in is the tensor
 * returned.
 */
public class ConvDropout implements TensorLayer {
    /** Drop threshold: an element is zeroed when its random draw is {@code <= chance}. */
    private double chance;
    /** Random source for the per-element draws; each instance creates its own generator. */
    private final RanMT rng;

    /**
     * Creates a dropout layer.
     *
     * @param chance threshold compared against each {@code rng.raw()} draw; presumably a
     *     probability in [0, 1] (NOTE(review): confirm the range returned by RanMT.raw())
     */
    public ConvDropout(double chance) {
        this.chance = chance;
        this.rng = new RanMT();
    }

    /**
     * Zeroes randomly selected elements of {@code input} in place and returns it.
     *
     * <p>When {@code g.applyBackprop()} is true, a backprop step is registered. That step
     * clears the gradient ({@code DW}) of exactly the elements dropped by this call.
     *
     * <p>NOTE(review): the mask is sampled on every call, even when
     * {@code g.applyBackprop()} is false. Surviving elements are not rescaled. Confirm
     * that this is the intended behavior at inference time.
     *
     * @param input tensor that is modified in place and also returned
     * @param g computation graph that may receive the backprop step
     * @return {@code input}, after dropout has been applied
     */
    @Override
    public Tensor forward(final Tensor input, Graph g) throws Exception {
        // Entries set to 1.0 record which elements were dropped, so backprop can find them.
        final Tensor mask = new Tensor(input.getWidth(), input.getHeight(), input.getDepth());
        // Iteration order (width, then height, then depth) determines which RNG draw
        // belongs to which element.
        for (int x = 0; x < input.getWidth(); x++) {
            for (int y = 0; y < input.getHeight(); y++) {
                for (int z = 0; z < input.getDepth(); z++) {
                    if (shouldDrop()) {
                        // The matrix for depth slice z is addressed as (height, width).
                        input.getMatrixAt(z).setW(y, x, 0.0);
                        mask.setValueAt(x, y, z, 1.0);
                    }
                }
            }
        }
        if (g.applyBackprop()) {
            g.addBackprop(new DropGradient(input, mask));
        }
        return input;
    }

    /** Draws one random value and reports whether the current element should be dropped. */
    private boolean shouldDrop() {
        return this.rng.raw() <= this.chance;
    }

    /**
     * Backprop step that clears the gradient of every element marked in the dropout mask.
     */
    private static final class DropGradient implements Runnable {
        private final Tensor target;
        private final Tensor mask;

        DropGradient(Tensor target, Tensor mask) {
            this.target = target;
            this.mask = mask;
        }

        @Override
        public void run() {
            for (int x = 0; x < target.getWidth(); x++) {
                for (int y = 0; y < target.getHeight(); y++) {
                    for (int z = 0; z < target.getDepth(); z++) {
                        if (mask.getValueAt(x, y, z) == 1.0) {
                            target.getMatrixAt(z).setDW(y, x, 0.0);
                        }
                    }
                }
            }
        }
    }

    /** Does nothing; this layer keeps no state between calls. */
    @Override
    public void resetState() {
    }

    /**
     * Returns {@code null}. The layer has no trainable parameters, so callers must handle
     * a {@code null} result.
     */
    @Override
    public List<Matrix> getParameters() {
        return null;
    }

    /** Returns a new layer with the same {@code chance} and its own freshly created RNG. */
    @Override
    public TensorLayer clone() {
        return new ConvDropout(this.chance);
    }

    /** Writes {@code chance} as one double. The RNG state is not saved. */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeDouble(this.chance);
    }

    /** Reads one double into {@code chance}, mirroring {@link #saveState}. */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.chance = fis.readDouble();
    }
}

