/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.List;
import java.util.Random;
import matrix.Matrix;
import matrix.Tensor;
import model.TensorLayer;

/**
 * Dropout layer: on each {@link #forward} call, every element of the input tensor is
 * independently set to zero with probability {@code chance}.
 *
 * <p>The input tensor is modified in place and returned as the layer's output. When the
 * graph is applying backprop, a backward step is registered that zeroes the {@code DW}
 * entry (presumably the gradient) at every position that was dropped in that pass.
 *
 * <p>NOTE(review): surviving activations are not rescaled by {@code 1 / (1 - chance)}, and
 * masking happens whether or not the graph is applying backprop. If
 * {@code isApplyingBackprop()} is what distinguishes training from inference, inference
 * passes are randomly masked as well — confirm this is intended.
 */
public class Dropout implements TensorLayer {
    /** Probability in [0, 1] that any single element is dropped. Not final: see {@link #loadState}. */
    private double chance;
    private final Random rng;

    /**
     * Creates a dropout layer.
     *
     * @param chance probability in [0, 1] that each element is zeroed
     * @throws IllegalArgumentException if {@code chance} is NaN or outside [0, 1]
     */
    public Dropout(double chance) {
        this.chance = checkChance(chance);
        this.rng = new Random();
    }

    /**
     * Randomly zeroes elements of {@code input} in place and returns the same tensor.
     *
     * @param input tensor to mask; its values are overwritten for dropped positions
     * @param g graph used to register the backward step when backprop is being applied
     * @return {@code input} itself, after masking
     */
    @Override
    public Tensor forward(final Tensor input, Graph g) throws Exception {
        // The mask is only needed by the backward step, so skip allocating it otherwise.
        final boolean[][][] dropped = g.isApplyingBackprop()
                ? new boolean[input.width][input.height][input.depth]
                : null;

        for (int x = 0; x < input.width; x++) {
            for (int y = 0; y < input.height; y++) {
                for (int z = 0; z < input.depth; z++) {
                    // nextDouble() is uniform on [0, 1): a strict '<' drops with probability
                    // exactly `chance`, and never when chance == 0.
                    if (rng.nextDouble() < chance) {
                        input.matrices[z].setW(y, x, 0.0);
                        if (dropped != null) {
                            dropped[x][y][z] = true;
                        }
                    }
                }
            }
        }

        if (dropped != null) {
            g.addBackprop(new Runnable() {
                @Override
                public void run() {
                    // Iterate over the mask's own dimensions (fixed at forward time).
                    for (int x = 0; x < dropped.length; x++) {
                        for (int y = 0; y < dropped[x].length; y++) {
                            for (int z = 0; z < dropped[x][y].length; z++) {
                                if (dropped[x][y][z]) {
                                    input.matrices[z].setDW(y, x, 0.0);
                                }
                            }
                        }
                    }
                }
            });
        }
        return input;
    }

    /** No-op: this layer keeps no per-sequence state between calls. */
    @Override
    public void resetState() {
    }

    /**
     * Returns {@code null}: this layer has no trainable parameters.
     *
     * @return always {@code null}
     */
    @Override
    public List<Matrix> getParameters() {
        return null;
    }

    /** Returns a new dropout layer with the same {@code chance} and its own random generator. */
    @Override
    public TensorLayer clone() {
        return new Dropout(this.chance);
    }

    /** Writes {@code chance} as a single double. */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeDouble(this.chance);
    }

    /**
     * Reads {@code chance} as a single double, mirroring {@link #saveState}.
     *
     * @throws IllegalArgumentException if the stored value is NaN or outside [0, 1]
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.chance = checkChance(fis.readDouble());
    }

    /** Validates that {@code chance} is a probability; returns it unchanged when valid. */
    private static double checkChance(double chance) {
        // Written in negated form so NaN (which fails every comparison) is rejected too.
        if (!(chance >= 0.0 && chance <= 1.0)) {
            throw new IllegalArgumentException("dropout chance must be in [0, 1], got " + chance);
        }
        return chance;
    }
}

