/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.EOFException;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.List;
import matrix.Matrix;
import matrix.Tensor;
import model.TensorLayer;
import theGhastModding.utils.math.ByteConverters;

/**
 * Layer that flattens an {@code inWidth x inHeight x inDepth} tensor into a single
 * {@code outSize}-element matrix, where {@code outSize = inWidth * inHeight * inDepth}.
 *
 * <p>Elements are emitted depth slice by depth slice. Within a slice they go column by column
 * ({@code i} over the width), and within a column row by row ({@code j} over the height),
 * reading {@code getW(j, i)}. The backward pass scatters gradients back in the same order.
 *
 * <p>The layer has no trainable parameters.
 */
public class ConvFlatten
implements TensorLayer {
    private int inWidth;
    private int inHeight;
    private int inDepth;
    private int outSize;

    /**
     * Creates a flatten layer for inputs of the given dimensions.
     *
     * @param inWidth  expected input width (second index passed to {@code getW})
     * @param inHeight expected input height (first index passed to {@code getW})
     * @param inDepth  expected number of matrices (depth slices) in the input tensor
     */
    public ConvFlatten(int inWidth, int inHeight, int inDepth) {
        this.outSize = inWidth * inHeight * inDepth;
        this.inWidth = inWidth;
        this.inHeight = inHeight;
        this.inDepth = inDepth;
    }

    /**
     * Flattens {@code t} into a new tensor whose matrix 0 holds all input values.
     *
     * <p>When {@code g.applyBackprop()} is true, this method registers a backward step. That step
     * adds the output gradient into the corresponding input gradients.
     *
     * @param t input tensor; its depth, width and height must match this layer's configuration
     * @param g autodiff graph used to register the backward step
     * @return a tensor constructed as {@code new Tensor(1, outSize, 1)} holding the flattened values
     * @throws Exception if the input tensor's dimensions do not match the configured dimensions
     */
    @Override
    public Tensor forward(final Tensor t, Graph g) throws Exception {
        if (t.getDepth() != this.inDepth) {
            throw new Exception("Invalid tensor depth: Is " + t.getDepth() + ", expected: " + this.inDepth);
        }
        if (t.getWidth() != this.inWidth) {
            throw new Exception("Invalid tensor width: Is " + t.getWidth() + ", expected: " + this.inWidth);
        }
        if (t.getHeight() != this.inHeight) {
            throw new Exception("Invalid tensor height: Is " + t.getHeight() + ", expected: " + this.inHeight);
        }
        final Tensor toReturn = new Tensor(1, this.outSize, 1);
        Matrix outMatrix = new Matrix(this.outSize);
        int loc = 0;
        for (int k = 0; k < this.inDepth; ++k) {
            Matrix slice = t.getMatrixAt(k); // loop-invariant for the inner loops
            for (int i = 0; i < this.inWidth; ++i) {
                for (int j = 0; j < this.inHeight; ++j) {
                    outMatrix.setW(loc, 0, slice.getW(j, i));
                    ++loc;
                }
            }
        }
        toReturn.setMatrixAt(0, outMatrix);
        if (g.applyBackprop()) {
            g.addBackprop(new Runnable() {

                @Override
                public void run() {
                    // Gradient flowing back from the flattened output.
                    Matrix outGrad = toReturn.getMatrixAt(0);
                    int idx = 0;
                    for (int k = 0; k < ConvFlatten.this.inDepth; ++k) {
                        Matrix slice = t.getMatrixAt(k);
                        for (int i = 0; i < ConvFlatten.this.inWidth; ++i) {
                            for (int j = 0; j < ConvFlatten.this.inHeight; ++j) {
                                // Accumulate rather than overwrite, so gradients from other
                                // consumers of the same input tensor are not lost.
                                slice.setDW(j, i, slice.getDW(j, i) + outGrad.getDW(idx, 0));
                                ++idx;
                            }
                        }
                    }
                }
            });
        }
        return toReturn;
    }

    /**
     * Returns the number of elements produced by {@link #forward}.
     *
     * @return {@code inWidth * inHeight * inDepth}
     */
    public int getOutSize() {
        return this.outSize;
    }

    /** No-op: this layer keeps no state between forward passes. */
    @Override
    public void resetState() {
    }

    /**
     * This layer has no trainable parameters.
     *
     * @return {@code null}; kept as-is because callers may branch on a {@code null} result
     */
    @Override
    public List<Matrix> getParameters() {
        return null;
    }

    /**
     * Writes the layer configuration as three ints: depth, height, width (in that order),
     * each encoded with {@code ByteConverters.intToBytes}.
     *
     * @param fos destination stream
     * @throws Exception if writing fails
     */
    @Override
    public void saveState(FileOutputStream fos) throws Exception {
        fos.write(ByteConverters.intToBytes(this.inDepth));
        fos.write(ByteConverters.intToBytes(this.inHeight));
        fos.write(ByteConverters.intToBytes(this.inWidth));
    }

    /**
     * Reads the configuration written by {@link #saveState}, then recomputes {@code outSize}.
     * Fields are updated only after all three values have been read successfully.
     *
     * @param fis source stream positioned at this layer's saved state
     * @throws EOFException if the stream ends before all 12 bytes are read
     * @throws Exception    if reading fails for any other reason
     */
    @Override
    public void loadState(FileInputStream fis) throws Exception {
        byte[] intBuffer = new byte[4];
        readFully(fis, intBuffer);
        int depth = ByteConverters.bytesToInt(intBuffer);
        readFully(fis, intBuffer);
        int height = ByteConverters.bytesToInt(intBuffer);
        readFully(fis, intBuffer);
        int width = ByteConverters.bytesToInt(intBuffer);
        this.inDepth = depth;
        this.inHeight = height;
        this.inWidth = width;
        this.outSize = this.inWidth * this.inHeight * this.inDepth;
    }

    /** Fills {@code buffer} completely from {@code in}, since a single read may return fewer bytes. */
    private static void readFully(FileInputStream in, byte[] buffer) throws IOException {
        int offset = 0;
        while (offset < buffer.length) {
            int n = in.read(buffer, offset, buffer.length - offset);
            if (n < 0) {
                throw new EOFException("Unexpected end of stream while reading ConvFlatten state (read "
                        + offset + " of " + buffer.length + " bytes)");
            }
            offset += n;
        }
    }
}

