/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import matrix.Matrix;
import matrix.Tensor;
import model.TensorLayer;

/**
 * Layer that flattens a multi-channel tensor into a single-matrix tensor.
 *
 * <p>Channel {@code k} of the input (its {@code w} array of {@code inWidth * inHeight} values)
 * is copied into the contiguous slice {@code [k * inWidth * inHeight, (k + 1) * inWidth * inHeight)}
 * of the output matrix. The ordering of values within a channel is whatever order the
 * input matrix's {@code w} array already uses.
 *
 * <p>The layer has no trainable parameters. Its only state is the expected input shape.
 */
public class ConvFlatten
implements TensorLayer {
    private int inWidth;
    private int inHeight;
    private int inDepth;
    /** Cached {@code inWidth * inHeight * inDepth}; the length of the flattened output. */
    private int outSize;

    /**
     * Creates a flatten layer for inputs of the given shape.
     *
     * @param inWidth  expected input width
     * @param inHeight expected input height
     * @param inDepth  expected number of input channels
     */
    public ConvFlatten(int inWidth, int inHeight, int inDepth) {
        this.outSize = inWidth * inHeight * inDepth;
        this.inWidth = inWidth;
        this.inHeight = inHeight;
        this.inDepth = inDepth;
    }

    /**
     * Flattens {@code t} into a new single-matrix tensor of length {@link #getOutSize()}.
     *
     * <p>If {@code g} is applying backprop, a closure is registered that adds the output
     * gradient back into each input channel's {@code dw}.
     *
     * @param t input tensor; must have exactly the depth, width and height this layer expects
     * @param g computation graph used to register the backward pass
     * @return a new tensor wrapping the flattened matrix
     * @throws Exception if {@code t}'s depth, width or height does not match this layer
     */
    @Override
    public Tensor forward(final Tensor t, Graph g) throws Exception {
        checkDimension("depth", t.depth, this.inDepth);
        checkDimension("width", t.width, this.inWidth);
        checkDimension("height", t.height, this.inHeight);

        // Capture the layout used for this forward pass so the backward closure mirrors it
        // exactly, even if this layer's fields are later changed (e.g. by loadState).
        final int channelSize = this.inWidth * this.inHeight;
        final int depth = this.inDepth;

        Matrix flat = new Matrix(this.outSize);
        for (int k = 0; k < depth; k++) {
            System.arraycopy(t.matrices[k].w, 0, flat.w, k * channelSize, channelSize);
        }
        final Tensor toReturn = new Tensor(flat);

        if (g.isApplyingBackprop()) {
            g.addBackprop(new Runnable() {

                @Override
                public void run() {
                    // Scatter the flattened output gradient back to each input channel,
                    // accumulating (+=) because other consumers may also feed these gradients.
                    Matrix flatGrad = toReturn.matrices[0];
                    for (int k = 0; k < depth; k++) {
                        Matrix channel = t.matrices[k];
                        int offset = k * channelSize;
                        for (int i = 0; i < channelSize; i++) {
                            channel.dw[i] += flatGrad.dw[offset + i];
                        }
                    }
                }
            });
        }
        return toReturn;
    }

    /**
     * Throws if an input dimension does not match the expected value.
     *
     * @param name     dimension name used in the error message
     * @param actual   dimension found on the input tensor
     * @param expected dimension this layer was configured with
     * @throws Exception if {@code actual != expected}
     */
    private static void checkDimension(String name, int actual, int expected) throws Exception {
        if (actual != expected) {
            throw new Exception("Invalid tensor " + name + ": Is " + actual + ", expected: " + expected);
        }
    }

    /**
     * Returns the length of the flattened output.
     *
     * @return {@code inWidth * inHeight * inDepth}
     */
    public int getOutSize() {
        return this.outSize;
    }

    /** No-op: this layer keeps no per-sequence state. */
    @Override
    public void resetState() {
    }

    /**
     * Returns this layer's trainable parameters.
     *
     * @return a new, empty, mutable list; this layer has no trainable parameters
     *     (previously {@code null}, which forced callers to null-check)
     */
    @Override
    public List<Matrix> getParameters() {
        return new ArrayList<Matrix>();
    }

    /**
     * Returns a new layer configured with the same input shape.
     *
     * @return a fresh {@code ConvFlatten} with identical dimensions
     */
    @Override
    public TensorLayer clone() {
        return new ConvFlatten(this.inWidth, this.inHeight, this.inDepth);
    }

    /**
     * Writes the input shape as three ints, in the order depth, height, width.
     *
     * @param fos destination stream
     * @throws Exception if writing fails
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeInt(this.inDepth);
        fos.writeInt(this.inHeight);
        fos.writeInt(this.inWidth);
    }

    /**
     * Reads the input shape in the order written by {@link #saveState} and recomputes
     * the output size.
     *
     * @param fis source stream
     * @throws Exception if reading fails
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.inDepth = fis.readInt();
        this.inHeight = fis.readInt();
        this.inWidth = fis.readInt();
        this.outSize = this.inWidth * this.inHeight * this.inDepth;
    }
}

