/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import matrix.Matrix;
import matrix.Tensor;
import model.TensorLayer;

/**
 * Reshapes a column vector into a 3-D tensor.
 *
 * <p>The input must have depth 1, width 1 and height {@code inLength}. The output is a
 * {@code Tensor(outWidth, outHeight, outDepth)}. Input values are laid out column by column:
 * the first {@code outHeight} values fill column 0 of depth slice 0, the next {@code outHeight}
 * fill column 1, and so on. After {@code outWidth} columns, filling continues in the next depth
 * slice.
 *
 * <p>If the input is shorter than the output volume, the remaining output cells keep whatever
 * value {@code new Tensor(...)} initialises them to. If it is longer, the surplus input values
 * are ignored. The layer has no trainable parameters.
 */
public class ConvExpand
implements TensorLayer {
    /** Expected height of the (width 1, depth 1) input tensor. */
    private int inLength;
    /** Width (number of columns) of each output depth slice. */
    private int outWidth;
    /** Height (number of rows) of each output depth slice; also the run length per column. */
    private int outHeight;
    /** Number of output depth slices. */
    private int outDepth;

    /**
     * Creates a reshape layer.
     *
     * @param inLength  expected input height
     * @param outWidth  output width
     * @param outHeight output height
     * @param outDepth  output depth
     */
    public ConvExpand(int inLength, int outWidth, int outHeight, int outDepth) {
        this.inLength = inLength;
        this.outWidth = outWidth;
        this.outHeight = outHeight;
        this.outDepth = outDepth;
    }

    /**
     * Copies the input vector into a new output tensor.
     *
     * <p>If {@code g} is applying backprop, this also registers a step that adds the output
     * gradients back into the input gradients.
     *
     * @param input tensor of depth 1, width 1 and height {@code inLength}
     * @param g     graph used to register the backward step
     * @return the reshaped tensor
     * @throws Exception if the input dimensions do not match
     */
    @Override
    public Tensor forward(final Tensor input, Graph g) throws Exception {
        if (input.depth != 1) {
            throw new Exception("Invalid tensor depth: Is " + Integer.toString(input.depth) + ", expected: 1");
        }
        if (input.width != 1) {
            throw new Exception("Invalid tensor width: Is " + Integer.toString(input.width) + ", expected: 1");
        }
        if (input.height != this.inLength) {
            throw new Exception("Invalid tensor height: Is " + Integer.toString(input.height) + ", expected: " + Integer.toString(this.inLength));
        }
        final Tensor toReturn = new Tensor(this.outWidth, this.outHeight, this.outDepth);
        final Matrix inMatrix = input.matrices[0];
        // Snapshot the geometry so the backward step does not depend on fields
        // that loadState() could overwrite before backprop runs.
        final int height = this.outHeight;
        final int sliceSize = this.outWidth * this.outHeight;
        // Every input index below this bound has a destination cell. When sliceSize is 0,
        // count is 0 and the divisions below are never evaluated.
        final int count = Math.min(this.inLength, sliceSize * this.outDepth);
        for (int p = 0; p < count; p++) {
            Matrix outm = toReturn.matrices[p / sliceSize];
            int within = p % sliceSize;
            // Row = within % height, column = within / height (column-major fill).
            outm.w[(within % height) * outm.cols + within / height] = inMatrix.w[p];
        }
        if (g.isApplyingBackprop()) {
            g.addBackprop(new Runnable() {

                @Override
                public void run() {
                    for (int p = 0; p < count; p++) {
                        Matrix outm = toReturn.matrices[p / sliceSize];
                        int within = p % sliceSize;
                        // Accumulate rather than assign: the input matrix may also feed
                        // other nodes whose gradients must not be overwritten.
                        inMatrix.dw[p] += outm.dw[(within % height) * outm.cols + within / height];
                    }
                }
            });
        }
        return toReturn;
    }

    /** This layer is stateless; nothing to reset. */
    @Override
    public void resetState() {
    }

    /**
     * Returns the trainable parameters of this layer.
     *
     * @return an empty mutable list, because this layer has no parameters
     */
    @Override
    public List<Matrix> getParameters() {
        return new ArrayList<Matrix>();
    }

    /**
     * Writes the layer geometry as four ints in the order
     * inLength, outHeight, outWidth, outDepth.
     *
     * @param fos destination stream
     * @throws Exception on write failure
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeInt(this.inLength);
        fos.writeInt(this.outHeight);
        fos.writeInt(this.outWidth);
        fos.writeInt(this.outDepth);
    }

    /**
     * Reads the layer geometry in the same order {@link #saveState} writes it.
     *
     * @param fis source stream
     * @throws Exception on read failure
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.inLength = fis.readInt();
        this.outHeight = fis.readInt();
        this.outWidth = fis.readInt();
        this.outDepth = fis.readInt();
    }

    /**
     * Returns a new layer with the same geometry.
     *
     * @return an independent copy of this layer
     */
    @Override
    public TensorLayer clone() {
        return new ConvExpand(this.inLength, this.outWidth, this.outHeight, this.outDepth);
    }
}

