/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.List;
import java.util.Random;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvLayer;
import model.TensorLayer;

/**
 * Layer that upsamples its input by spreading each value onto a sparser, larger grid and then
 * applies a wrapped {@link ConvLayer} to the result.
 *
 * <p>For every depth slice, input element {@code (r, c)} is copied to position
 * {@code (r * sizeMultiplier, c * sizeMultiplier)} of a fresh tensor whose width and height are
 * {@code sizeMultiplier} times larger. Positions that are not written keep whatever value
 * {@code new Tensor(...)} initializes them to (presumably zero — confirm against {@code Tensor}).
 * All trainable parameters live in the wrapped {@link ConvLayer}.
 */
public class DeconvLayer implements TensorLayer {
    private ConvLayer conv;
    private int sizeMultiplier;

    /**
     * Creates a layer whose wrapped {@link ConvLayer} is sized for the upsampled input.
     *
     * @param sizeMultiplier upsampling factor applied to width and height; must be {@code >= 1}
     * @param inWidth width of tensors passed to {@link #forward} (before upsampling)
     * @param inHeight height of tensors passed to {@link #forward} (before upsampling)
     * @param inDepth depth of tensors passed to {@link #forward}
     * @param filterWidth forwarded to {@link ConvLayer}
     * @param filterHeight forwarded to {@link ConvLayer}
     * @param numFiltersPerDepth forwarded to {@link ConvLayer}
     * @param stride forwarded to {@link ConvLayer}
     * @param pad forwarded to {@link ConvLayer}
     * @param initParamsStdDev forwarded to {@link ConvLayer}
     * @param rng forwarded to {@link ConvLayer}
     * @throws IllegalArgumentException if {@code sizeMultiplier < 1}
     */
    public DeconvLayer(int sizeMultiplier, int inWidth, int inHeight, int inDepth, int filterWidth, int filterHeight, int numFiltersPerDepth, int stride, int pad, double initParamsStdDev, Random rng) {
        this.sizeMultiplier = DeconvLayer.checkSizeMultiplier(sizeMultiplier);
        this.conv = new ConvLayer(inWidth * sizeMultiplier, inHeight * sizeMultiplier, inDepth, filterWidth, filterHeight, numFiltersPerDepth, stride, pad, initParamsStdDev, rng);
    }

    /**
     * Same as the 11-argument constructor, additionally forwarding the threading settings to the
     * wrapped {@link ConvLayer}.
     *
     * @param multithreading forwarded to {@link ConvLayer}
     * @param cores forwarded to {@link ConvLayer}
     * @throws IllegalArgumentException if {@code sizeMultiplier < 1}
     */
    public DeconvLayer(int sizeMultiplier, int inWidth, int inHeight, int inDepth, int filterWidth, int filterHeight, int numFiltersPerDepth, int stride, int pad, double initParamsStdDev, Random rng, boolean multithreading, int cores) {
        this.sizeMultiplier = DeconvLayer.checkSizeMultiplier(sizeMultiplier);
        this.conv = new ConvLayer(inWidth * sizeMultiplier, inHeight * sizeMultiplier, inDepth, filterWidth, filterHeight, numFiltersPerDepth, stride, pad, initParamsStdDev, rng, multithreading, cores);
    }

    /** Wraps an existing conv layer without re-initializing anything; used by {@link #clone()}. */
    private DeconvLayer(int sizeMultiplier, ConvLayer conv) {
        this.sizeMultiplier = sizeMultiplier;
        this.conv = conv;
    }

    /** Rejects upsampling factors that would produce an empty or degenerate tensor. */
    private static int checkSizeMultiplier(int sizeMultiplier) {
        if (sizeMultiplier < 1) {
            throw new IllegalArgumentException("sizeMultiplier must be >= 1, got " + sizeMultiplier);
        }
        return sizeMultiplier;
    }

    /**
     * Upsamples {@code input} and feeds it through the wrapped {@link ConvLayer}.
     *
     * <p>When {@code g.applyBackprop()} is true, a closure is registered that routes the
     * upsampled tensor's gradient back to {@code input}. The gradient is added to
     * {@code input}'s existing gradient rather than overwriting it.
     *
     * @param input tensor to upsample; its matrices receive gradient during backprop
     * @param g computation graph used for gradient bookkeeping
     * @return the wrapped conv layer's output for the upsampled tensor
     * @throws Exception propagated from the wrapped {@link ConvLayer}
     */
    @Override
    public Tensor forward(final Tensor input, Graph g) throws Exception {
        // Capture once so the backprop closure uses the same stride as this forward pass,
        // even if loadState changes the field before backprop runs.
        final int multiplier = this.sizeMultiplier;
        final Tensor patchedTensor = new Tensor(input.getWidth() * multiplier, input.getHeight() * multiplier, input.getDepth());
        for (int d = 0; d < input.getDepth(); ++d) {
            Matrix patched = patchedTensor.getMatrixAt(d);
            Matrix source = input.getMatrixAt(d);
            for (int r = 0; r < source.rows; ++r) {
                for (int c = 0; c < source.cols; ++c) {
                    patched.setW(r * multiplier, c * multiplier, source.getW(r, c));
                }
            }
        }
        if (g.applyBackprop()) {
            g.addBackprop(new Runnable() {

                @Override
                public void run() {
                    for (int d = 0; d < input.getDepth(); ++d) {
                        Matrix patched = patchedTensor.getMatrixAt(d);
                        Matrix source = input.getMatrixAt(d);
                        for (int r = 0; r < source.rows; ++r) {
                            for (int c = 0; c < source.cols; ++c) {
                                // Accumulate: input may also receive gradient from other ops in the graph.
                                source.setDW(r, c, source.getDW(r, c) + patched.getDW(r * multiplier, c * multiplier));
                            }
                        }
                    }
                }
            });
        }
        return this.conv.forward(patchedTensor, g);
    }

    /** Delegates to the wrapped conv layer. */
    @Override
    public void resetState() {
        this.conv.resetState();
    }

    /** @return the wrapped conv layer's output width */
    public int getOutWidth() {
        return this.conv.getOutWidth();
    }

    /** @return the wrapped conv layer's output height */
    public int getOutHeight() {
        return this.conv.getOutHeight();
    }

    /** @return the wrapped conv layer's parameters; this layer adds none of its own */
    @Override
    public List<Matrix> getParameters() {
        return this.conv.getParameters();
    }

    /**
     * Delegates to the wrapped conv layer.
     *
     * @throws Exception propagated from {@link ConvLayer#setCores}
     */
    public void setCores(int cores) throws Exception {
        this.conv.setCores(cores);
    }

    /** @return whether the wrapped conv layer is configured for multithreading */
    public boolean isMultithreading() {
        return this.conv.isMultithreading();
    }

    /**
     * Delegates to the wrapped conv layer.
     *
     * @throws Exception propagated from {@link ConvLayer#setMultithreading}
     */
    public void setMultithreading(boolean mt) throws Exception {
        this.conv.setMultithreading(mt);
    }

    /**
     * Returns a copy that shares no conv-layer state with this one.
     *
     * <p>The wrapped layer is copied via {@link ConvLayer#clone()}. No throwaway conv layer is
     * constructed.
     */
    @Override
    public TensorLayer clone() {
        return new DeconvLayer(this.sizeMultiplier, (ConvLayer) this.conv.clone());
    }

    /**
     * Writes the wrapped conv layer's state, followed by {@code sizeMultiplier} as an int.
     * {@link #loadState} reads in the same order.
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        this.conv.saveState(fos);
        fos.writeInt(this.sizeMultiplier);
    }

    /**
     * Reads state written by {@link #saveState}: the wrapped conv layer's state, then
     * {@code sizeMultiplier}.
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.conv.loadState(fis);
        this.sizeMultiplier = fis.readInt();
    }
}

