/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import matrix.Matrix;
import matrix.Tensor;
import model.TensorLayer;

/**
 * Nearest-neighbour upsampling layer.
 *
 * <p>For every channel {@code c} of the input, each element at (row, col) is copied into a
 * {@code sizeMultiplier x sizeMultiplier} patch of the output starting at
 * (row * sizeMultiplier, col * sizeMultiplier). The backward pass sums the gradients of each
 * patch back into the corresponding input element.
 *
 * <p>The layer has no trainable parameters. Its only state is the integer multiplier, which is
 * what {@link #saveState} and {@link #loadState} persist.
 */
public class ConvUpsample
implements TensorLayer {
    /** Factor by which each spatial dimension of the input is enlarged. */
    private int sizeMultiplier;

    /**
     * Creates an upsampling layer.
     *
     * @param sizeMultiplier factor applied to both spatial dimensions; it must be at least 1
     *     when {@link #forward} is called (it may be overwritten by {@link #loadState} first)
     */
    public ConvUpsample(int sizeMultiplier) {
        this.sizeMultiplier = sizeMultiplier;
    }

    /**
     * Upsamples {@code input} by replicating every element into a square patch.
     *
     * <p>When {@code g.isApplyingBackprop()} is true, a backward step is registered on {@code g}.
     * That step adds the output gradients of each patch into {@code input}'s {@code dw}.
     *
     * @param input tensor to upsample
     * @param g computation graph used to register the backward step
     * @return a new tensor of size (width * m, height * m, depth), where m is the multiplier
     * @throws IllegalStateException if the multiplier is less than 1
     */
    @Override
    public Tensor forward(final Tensor input, Graph g) throws Exception {
        // Snapshot the factor so the backward step uses exactly the geometry of this forward
        // pass, even if loadState() changes the field before the graph is back-propagated.
        final int m = this.sizeMultiplier;
        if (m < 1) {
            throw new IllegalStateException("sizeMultiplier must be >= 1, got " + m);
        }
        final Tensor output = new Tensor(input.width * m, input.height * m, input.depth);
        for (int c = 0; c < input.depth; c++) {
            Matrix src = input.matrices[c];
            Matrix dst = output.matrices[c];
            for (int r = 0; r < src.rows; r++) {
                for (int col = 0; col < src.cols; col++) {
                    int srcIdx = r * src.cols + col;
                    // Write the value into each of the m output rows covered by this element.
                    for (int dr = 0; dr < m; dr++) {
                        int rowStart = (r * m + dr) * dst.cols + col * m;
                        Arrays.fill(dst.w, rowStart, rowStart + m, src.w[srcIdx]);
                    }
                }
            }
        }
        if (g.isApplyingBackprop()) {
            g.addBackprop(new Runnable() {

                @Override
                public void run() {
                    for (int c = 0; c < input.depth; c++) {
                        Matrix src = input.matrices[c];
                        Matrix dst = output.matrices[c];
                        for (int r = 0; r < src.rows; r++) {
                            for (int col = 0; col < src.cols; col++) {
                                int srcIdx = r * src.cols + col;
                                // Each input element fed m*m outputs, so it receives the sum
                                // of their gradients. The accumulation order matches the
                                // forward layout: patch row by row, left to right.
                                for (int dr = 0; dr < m; dr++) {
                                    int rowStart = (r * m + dr) * dst.cols + col * m;
                                    for (int dc = 0; dc < m; dc++) {
                                        src.dw[srcIdx] += dst.dw[rowStart + dc];
                                    }
                                }
                            }
                        }
                    }
                }
            });
        }
        return output;
    }

    /** No-op: the layer keeps no state between calls besides its configured multiplier. */
    @Override
    public void resetState() {
    }

    /**
     * Returns the trainable parameters of this layer.
     *
     * @return an empty mutable list, since upsampling has no trainable parameters; an empty
     *     list rather than {@code null} lets callers iterate or {@code addAll} safely
     */
    @Override
    public List<Matrix> getParameters() {
        return new ArrayList<Matrix>();
    }

    /**
     * Writes the layer configuration.
     *
     * @param fos destination stream; receives the multiplier as a single {@code int}
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        fos.writeInt(this.sizeMultiplier);
    }

    /**
     * Reads the layer configuration written by {@link #saveState}.
     *
     * @param fis source stream; a single {@code int} is read and becomes the multiplier
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.sizeMultiplier = fis.readInt();
    }

    /**
     * Creates an independent copy of this layer.
     *
     * @return a new layer with the same multiplier
     */
    @Override
    public TensorLayer clone() {
        return new ConvUpsample(this.sizeMultiplier);
    }
}

