/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.List;
import matrix.Matrix;
import matrix.Tensor;
import model.TensorLayer;

/**
 * Max-pooling layer: slides a {@code filterWidth x filterHeight} window over every depth slice of
 * the input tensor, moving {@code stride} steps at a time in both directions, and emits the largest
 * value found in each window.
 *
 * <p>The output keeps the input depth and has spatial size
 * {@code ((inWidth - filterWidth) / stride + 1) x ((inHeight - filterHeight) / stride + 1)} (integer
 * division), so trailing input rows/columns that do not fill a whole window are never read.
 * The layer has no trainable parameters.
 */
public class PoolLayer
implements TensorLayer {
    private int stride;
    private int inWidth;
    private int inHeight;
    private int inDepth;
    private int filterWidth;
    private int filterHeight;
    private int outWidth;
    private int outHeight;

    /**
     * Creates a max-pooling layer for inputs of a fixed shape.
     *
     * @param inWidth      expected input width
     * @param inHeight     expected input height
     * @param inDepth      expected input depth (each depth slice is pooled independently)
     * @param filterWidth  pooling window width
     * @param filterHeight pooling window height
     * @param stride       step between consecutive windows, horizontally and vertically
     * @throws IllegalArgumentException if {@code stride} is not positive
     */
    public PoolLayer(int inWidth, int inHeight, int inDepth, int filterWidth, int filterHeight, int stride) {
        if (stride <= 0) {
            // A zero stride would divide by zero below; a negative one would index before the input.
            throw new IllegalArgumentException("stride must be positive, got " + stride);
        }
        this.stride = stride;
        this.inWidth = inWidth;
        this.inHeight = inHeight;
        this.inDepth = inDepth;
        this.filterHeight = filterHeight;
        this.filterWidth = filterWidth;
        this.outWidth = (inWidth - filterWidth) / stride + 1;
        this.outHeight = (inHeight - filterHeight) / stride + 1;
    }

    /**
     * Max-pools {@code t} and, when {@code g.applyBackprop()} is true, registers the matching
     * backward step on {@code g}.
     *
     * @param t input tensor; its width, height and depth must match the constructor arguments
     * @param g autodiff graph used to register the backward step
     * @return a new tensor of size {@code outWidth x outHeight x inDepth} holding the window maxima
     * @throws IllegalArgumentException if the shape of {@code t} does not match this layer
     */
    @Override
    public Tensor forward(final Tensor t, Graph g) throws Exception {
        this.checkShape(t);
        // Snapshot the geometry so the backward step stays consistent with this forward pass.
        final int depth = this.inDepth;
        final int rows = this.outHeight;
        final int cols = this.outWidth;
        final Tensor toReturn = new Tensor(cols, rows, depth);
        // Input coordinates of the winning element of every output cell, indexed by
        // (d * rows + r) * cols + c. Recording the exact winner per window (instead of flagging
        // winners in a shared input-sized mask) keeps overlapping windows (stride < filter size)
        // from routing gradient to elements they did not select.
        final int[] argRow = new int[depth * rows * cols];
        final int[] argCol = new int[depth * rows * cols];
        for (int d = 0; d < depth; ++d) {
            for (int r = 0; r < rows; ++r) {
                for (int c = 0; c < cols; ++c) {
                    int cell = (d * rows + r) * cols + c;
                    this.findWindowMax(t, d, r * this.stride, c * this.stride, cell, argRow, argCol);
                    // Read the value back from the winner so NaN / -Infinity windows propagate as-is.
                    toReturn.getMatrixAt(d).setW(r, c, t.getMatrixAt(d).getW(argRow[cell], argCol[cell]));
                }
            }
        }
        if (g.applyBackprop()) {
            g.addBackprop(new Runnable() {

                @Override
                public void run() {
                    PoolLayer.backward(t, toReturn, argRow, argCol, depth, rows, cols);
                }
            });
        }
        return toReturn;
    }

    /**
     * Verifies that {@code t} has the shape this layer was built for.
     *
     * @throws IllegalArgumentException naming the first mismatching dimension (depth, width, height)
     */
    private void checkShape(Tensor t) {
        if (t.getDepth() != this.inDepth) {
            throw new IllegalArgumentException("Invalid tensor depth: Is " + Integer.toString(t.getDepth()) + ", expected: " + Integer.toString(this.inDepth));
        }
        if (t.getWidth() != this.inWidth) {
            throw new IllegalArgumentException("Invalid tensor width: Is " + Integer.toString(t.getWidth()) + ", expected: " + Integer.toString(this.inWidth));
        }
        if (t.getHeight() != this.inHeight) {
            throw new IllegalArgumentException("Invalid tensor height: Is " + Integer.toString(t.getHeight()) + ", expected: " + Integer.toString(this.inHeight));
        }
    }

    /**
     * Finds the largest element of one pooling window and stores its input coordinates in
     * {@code argRow[cell]} / {@code argCol[cell]}.
     *
     * <p>The window spans rows {@code top .. top + filterHeight - 1} and columns
     * {@code left .. left + filterWidth - 1} of depth slice {@code depth}. On ties, the first
     * maximum in row-major window order wins. NaN values are never selected over a comparable value.
     */
    private void findWindowMax(Tensor t, int depth, int top, int left, int cell, int[] argRow, int[] argCol) {
        // Default to the window's first element so a winner always exists, even when every value
        // is NaN or -Infinity; a fixed "very small" sentinel would leave the winner unset then.
        int bestRow = top;
        int bestCol = left;
        double best = Double.NEGATIVE_INFINITY;
        for (int m = 0; m < this.filterHeight; ++m) {
            for (int n = 0; n < this.filterWidth; ++n) {
                double value = t.getMatrixAt(depth).getW(top + m, left + n);
                if (value > best) {
                    best = value;
                    bestRow = top + m;
                    bestCol = left + n;
                }
            }
        }
        argRow[cell] = bestRow;
        argCol[cell] = bestCol;
    }

    /**
     * Routes each output gradient back to the input element that won its window.
     *
     * <p>Gradients are added to the input's existing dW rather than written over it. An element
     * that wins several overlapping windows, or that also receives gradient from elsewhere in the
     * graph, therefore gets the sum of its contributions.
     */
    private static void backward(Tensor input, Tensor output, int[] argRow, int[] argCol, int depth, int rows, int cols) {
        for (int d = 0; d < depth; ++d) {
            for (int r = 0; r < rows; ++r) {
                for (int c = 0; c < cols; ++c) {
                    int cell = (d * rows + r) * cols + c;
                    int row = argRow[cell];
                    int col = argCol[cell];
                    double grad = output.getMatrixAt(d).getDW(r, c);
                    input.getMatrixAt(d).setDW(row, col, input.getMatrixAt(d).getDW(row, col) + grad);
                }
            }
        }
    }

    /** No-op: pooling keeps no state between calls. */
    @Override
    public void resetState() {
    }

    /**
     * Returns {@code null}: pooling has no trainable parameters. Kept as {@code null} rather than an
     * empty list for compatibility with existing callers.
     */
    @Override
    public List<Matrix> getParameters() {
        return null;
    }

    /** Returns a new pooling layer with the same geometry. */
    @Override
    public TensorLayer clone() {
        return new PoolLayer(this.inWidth, this.inHeight, this.inDepth, this.filterWidth, this.filterHeight, this.stride);
    }

    /**
     * Writes the layer geometry as eight ints, in this order: stride, inWidth, inHeight, inDepth,
     * filterWidth, filterHeight, outWidth, outHeight.
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeInt(this.stride);
        fos.writeInt(this.inWidth);
        fos.writeInt(this.inHeight);
        fos.writeInt(this.inDepth);
        fos.writeInt(this.filterWidth);
        fos.writeInt(this.filterHeight);
        fos.writeInt(this.outWidth);
        fos.writeInt(this.outHeight);
    }

    /**
     * Reads the eight ints written by {@link #saveState(DataOutputStream)}, in the same order.
     * Output dimensions are taken from the stream as stored, not recomputed.
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.stride = fis.readInt();
        this.inWidth = fis.readInt();
        this.inHeight = fis.readInt();
        this.inDepth = fis.readInt();
        this.filterWidth = fis.readInt();
        this.filterHeight = fis.readInt();
        this.outWidth = fis.readInt();
        this.outHeight = fis.readInt();
    }
}

