/*
 * Decompiled with CFR 0.152.
 */
package model;

import autodiff.Graph;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvLayer;
import model.TensorLayer;
import nonlinearities.LinearUnit;

/**
 * Self-attention layer built from three 1x1 convolutions ({@code f}, {@code g}, {@code h})
 * and a single learnable scalar {@code gamma}.
 *
 * <p>The forward pass flattens the three convolution outputs to matrices, combines
 * {@code g} and {@code f} into an attention matrix, normalises each row with a softmax,
 * scales it by {@code gamma}, applies it to {@code h} and finally adds the (flattened)
 * input back as a residual. Because {@code gamma} starts at a tiny value, the residual
 * term dominates the output of a freshly constructed layer.
 *
 * <p>NOTE(review): the exact semantics of {@code Graph.mul}, {@code Graph.scalMul} and
 * {@code Graph.add} are defined outside this file; the description above assumes they are
 * matrix product, scalar scaling and element-wise addition respectively.
 */
public class ConvSelfAttention
implements TensorLayer {
    /** Initial value of the learnable attention scale. */
    private static final double GAMMA_INIT = 1.0E-8;
    /** Replacement denominator used when a softmax row sums to exactly zero. */
    private static final double MIN_SOFTMAX_SUM = 1.0E-8;

    private final ConvLayer f;
    private final ConvLayer g;
    private final ConvLayer h;
    private final Matrix gamma;

    /**
     * Creates a CPU, single-threaded attention layer.
     *
     * @param inWidth width of the input tensor
     * @param inHeight height of the input tensor
     * @param inDepth depth (number of matrices) of the input tensor
     * @param initParamsStdDev standard deviation forwarded to the convolution layers
     * @param rng random source forwarded to the convolution layers
     */
    public ConvSelfAttention(int inWidth, int inHeight, int inDepth, double initParamsStdDev, Random rng) {
        this(inWidth, inHeight, inDepth, initParamsStdDev, rng, false, false, 1);
    }

    /**
     * Creates an attention layer, forwarding the execution options to the three
     * underlying convolution layers.
     *
     * @param inWidth width of the input tensor
     * @param inHeight height of the input tensor
     * @param inDepth depth (number of matrices) of the input tensor
     * @param initParamsStdDev standard deviation forwarded to the convolution layers
     * @param rng random source forwarded to the convolution layers
     * @param gpu forwarded unchanged to {@code ConvLayer}
     * @param multithreading forwarded unchanged to {@code ConvLayer}
     * @param cores forwarded unchanged to {@code ConvLayer}
     */
    public ConvSelfAttention(int inWidth, int inHeight, int inDepth, double initParamsStdDev, Random rng, boolean gpu, boolean multithreading, int cores) {
        this.f = new ConvLayer(inWidth, inHeight, inDepth, 1, 1, inDepth, 1, 0, new LinearUnit(), initParamsStdDev, rng, gpu, multithreading, cores);
        this.g = new ConvLayer(inWidth, inHeight, inDepth, 1, 1, inDepth, 1, 0, new LinearUnit(), initParamsStdDev, rng, gpu, multithreading, cores);
        this.h = new ConvLayer(inWidth, inHeight, inDepth, 1, 1, inDepth, 1, 0, new LinearUnit(), initParamsStdDev, rng, gpu, multithreading, cores);
        this.gamma = new Matrix(1);
        this.gamma.w[0] = GAMMA_INIT;
    }

    /**
     * Assembles a layer from already-built parts; used by {@link #clone()} so that copying
     * does not have to construct (and then discard) fresh convolution layers.
     */
    private ConvSelfAttention(ConvLayer f, ConvLayer g, ConvLayer h, Matrix gamma) {
        this.f = f;
        this.g = g;
        this.h = h;
        this.gamma = gamma;
    }

    /**
     * Runs the attention computation on {@code input}.
     *
     * @param input tensor to attend over; its width and height define the output shape
     * @param graph computation graph that records the backward steps
     * @return a tensor with {@code input.width} x {@code input.height} matrices
     * @throws Exception if any of the underlying convolution layers fails
     */
    @Override
    public Tensor forward(Tensor input, Graph graph) throws Exception {
        Tensor fOut = this.f.forward(input, graph);
        Tensor gOut = this.g.forward(input, graph);
        Tensor hOut = this.h.forward(input, graph);
        // f is flattened transposed ((width*height) x depth), g and h as depth x (width*height).
        Matrix fFlat = this.hwFlatten(fOut, true, graph);
        Matrix gFlat = this.hwFlatten(gOut, false, graph);
        Matrix hFlat = this.hwFlatten(hOut, false, graph);
        Matrix attention = this.applySoftmax(graph.mul(gFlat, fFlat), graph);
        Matrix scaledAttention = graph.scalMul(attention, this.gamma);
        // Residual connection: attended values plus the flattened input.
        Matrix o = graph.add(graph.mul(scaledAttention, hFlat), this.hwFlatten(input, false, graph));
        return this.hwExpand(o, input.width, input.height, graph);
    }

    /**
     * Applies a softmax independently to every row of {@code in}.
     *
     * <p>Each row is shifted by its maximum before exponentiation so that large logits do
     * not overflow to {@code Infinity} (which would turn the row into {@code NaN}). The
     * shift does not change the mathematical result.
     *
     * <p>The registered backward step uses the full softmax Jacobian,
     * {@code dx_j = s_j * (dy_j - sum_k dy_k * s_k)}, and accumulates into {@code in.dw}.
     *
     * @param in matrix of logits
     * @param g graph on which the backward step is registered when backprop is enabled
     * @return a new matrix of the same shape holding the row-wise softmax
     */
    private Matrix applySoftmax(final Matrix in, Graph g) {
        final Matrix result = new Matrix(in.rows, in.cols);
        for (int row = 0; row < in.rows; ++row) {
            int inBase = in.cols * row;
            int outBase = result.cols * row;

            double shift = Double.NEGATIVE_INFINITY;
            for (int col = 0; col < in.cols; ++col) {
                shift = Math.max(shift, in.w[inBase + col]);
            }
            // An infinite maximum cannot be subtracted meaningfully (inf - inf is NaN);
            // fall back to the unshifted computation for such rows.
            if (Double.isInfinite(shift)) {
                shift = 0.0;
            }

            double sum = 0.0;
            for (int col = 0; col < in.cols; ++col) {
                double e = Math.exp(in.w[inBase + col] - shift);
                result.w[outBase + col] = e;
                sum += e;
            }
            // Only reachable on the unshifted fallback path (every exponential underflowed).
            if (sum == 0.0) {
                sum = MIN_SOFTMAX_SUM;
            }
            for (int col = 0; col < in.cols; ++col) {
                result.w[outBase + col] /= sum;
            }
        }
        if (g.isApplyingBackprop()) {
            g.addBackprop(new Runnable(){

                @Override
                public void run() {
                    for (int row = 0; row < in.rows; ++row) {
                        int inBase = in.cols * row;
                        int outBase = result.cols * row;
                        // Every output of a row depends on every input of that row, so the
                        // upstream gradient must be projected through the whole Jacobian,
                        // not just its diagonal.
                        double weightedGrad = 0.0;
                        for (int col = 0; col < in.cols; ++col) {
                            weightedGrad += result.dw[outBase + col] * result.w[outBase + col];
                        }
                        for (int col = 0; col < in.cols; ++col) {
                            in.dw[inBase + col] += result.w[outBase + col] * (result.dw[outBase + col] - weightedGrad);
                        }
                    }
                }
            });
        }
        return result;
    }

    /**
     * Turns a matrix back into a tensor: row {@code i} of {@code in} becomes the weights of
     * matrix {@code i} of the result. Inverse of the non-transposed {@link #hwFlatten}.
     *
     * @param in matrix with one row per output channel; each row is copied in full, so
     *           {@code in.cols} must fit into one {@code originalWidth x originalHeight} matrix
     * @param originalWidth width of the tensor to create
     * @param originalHeight height of the tensor to create
     * @param g graph on which the backward step is registered when backprop is enabled
     * @return a new tensor with depth {@code in.rows}
     */
    private Tensor hwExpand(final Matrix in, int originalWidth, int originalHeight, Graph g) {
        final Tensor result = new Tensor(originalWidth, originalHeight, in.rows);
        for (int channel = 0; channel < result.depth; ++channel) {
            System.arraycopy(in.w, channel * in.cols, result.matrices[channel].w, 0, in.cols);
        }
        if (g.isApplyingBackprop()) {
            g.addBackprop(new Runnable(){

                @Override
                public void run() {
                    for (int channel = 0; channel < result.depth; ++channel) {
                        int rowStart = channel * in.cols;
                        Matrix channelMatrix = result.matrices[channel];
                        for (int k = 0; k < in.cols; ++k) {
                            // Accumulate, matching how hwFlatten propagates gradients.
                            in.dw[rowStart + k] += channelMatrix.dw[k];
                        }
                    }
                }
            });
        }
        return result;
    }

    /**
     * Flattens a tensor into a matrix.
     *
     * <p>Without {@code transpose} the result has {@code in.depth} rows and
     * {@code in.width * in.height} columns (one row per channel). With {@code transpose}
     * the result is the transpose of that: {@code in.width * in.height} rows and
     * {@code in.depth} columns.
     *
     * @param in tensor to flatten
     * @param transpose whether to produce the transposed layout
     * @param g graph on which the backward step is registered when backprop is enabled
     * @return a new matrix holding a copy of the tensor's weights
     */
    private Matrix hwFlatten(final Tensor in, final boolean transpose, Graph g) {
        final int dim = in.width * in.height;
        final Matrix result = transpose ? new Matrix(dim, in.depth) : new Matrix(in.depth, dim);
        if (transpose) {
            for (int channel = 0; channel < in.depth; ++channel) {
                Matrix channelMatrix = in.matrices[channel];
                for (int k = 0; k < dim; ++k) {
                    result.w[result.cols * k + channel] = channelMatrix.w[k];
                }
            }
        } else {
            for (int channel = 0; channel < in.depth; ++channel) {
                System.arraycopy(in.matrices[channel].w, 0, result.w, channel * dim, dim);
            }
        }
        if (g.isApplyingBackprop()) {
            g.addBackprop(new Runnable(){

                @Override
                public void run() {
                    for (int channel = 0; channel < in.depth; ++channel) {
                        Matrix channelMatrix = in.matrices[channel];
                        for (int k = 0; k < dim; ++k) {
                            // Mirror the forward index mapping for the chosen layout.
                            int flatIndex = transpose ? result.cols * k + channel : channel * dim + k;
                            channelMatrix.dw[k] += result.dw[flatIndex];
                        }
                    }
                }
            });
        }
        return result;
    }

    /** Resets the state of the three convolution layers. */
    @Override
    public void resetState() {
        this.f.resetState();
        this.g.resetState();
        this.h.resetState();
    }

    /**
     * Returns the trainable parameters of this layer.
     *
     * @return the parameters of {@code f}, {@code g} and {@code h} (in that order) followed
     *         by {@code gamma}
     */
    @Override
    public List<Matrix> getParameters() {
        List<Matrix> allParams = new ArrayList<Matrix>();
        allParams.addAll(this.f.getParameters());
        allParams.addAll(this.g.getParameters());
        allParams.addAll(this.h.getParameters());
        allParams.add(this.gamma);
        return allParams;
    }

    /**
     * Writes the layer state in the order {@code f}, {@code g}, {@code h}, {@code gamma}.
     *
     * @param fos stream to write to
     * @throws Exception if writing any part fails
     */
    @Override
    public void saveState(DataOutputStream fos) throws Exception {
        this.f.saveState(fos);
        this.g.saveState(fos);
        this.h.saveState(fos);
        this.gamma.save(fos);
    }

    /**
     * Reads the layer state in the same order {@link #saveState} writes it.
     *
     * @param fis stream to read from
     * @throws Exception if reading any part fails
     */
    @Override
    public void loadState(DataInputStream fis) throws Exception {
        this.f.loadState(fis);
        this.g.loadState(fis);
        this.h.loadState(fis);
        this.gamma.load(fis);
    }

    /**
     * Creates a copy of this layer whose convolution layers and {@code gamma} are clones of
     * this layer's.
     *
     * @return the copied layer
     */
    @Override
    public TensorLayer clone() {
        return new ConvSelfAttention(
                (ConvLayer)this.f.clone(),
                (ConvLayer)this.g.clone(),
                (ConvLayer)this.h.clone(),
                this.gamma.clone());
    }
}

