/*
 * Decompiled with CFR 0.152.
 */
package autodiff;

import autodiff.GPUGraph;
import java.util.ArrayList;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import matrix.Matrix;
import matrix.Tensor;
import nonlinearities.Nonlinearity;

/**
 * A {@link GPUGraph} that splits {@code fullConv} work between two sub-graphs and merges their
 * results.
 *
 * <p>This graph is constructed on {@code g1}'s command queue, context and device, so every
 * operation not overridden here goes through the inherited {@link GPUGraph} implementation with
 * those handles. Only {@code fullConv} is divided: {@code g1} receives
 * {@code round(numFiltersPerDepth * perfMeasure1)} output channels (clamped to
 * {@code [0, numFiltersPerDepth]}) and {@code g2} receives the rest. Both halves, and their
 * recorded backprop steps, run concurrently on a private cached thread pool that is shut down by
 * {@link #cleanUp()}.
 */
public class MultiGPUGraph
extends GPUGraph {
    /** First sub-graph; receives the {@code perf1} share of every convolution. */
    private final GPUGraph g1;
    /** Second sub-graph; receives the remaining share of every convolution. */
    private final GPUGraph g2;
    /** Fraction of each convolution's output channels assigned to {@code g1}. */
    private final double perf1;
    /** Runs sub-graph work concurrently; shut down in {@link #cleanUp()}. */
    private final ThreadPoolExecutor pool;

    /**
     * Creates a graph that splits convolution work evenly between {@code g1} and {@code g2}.
     *
     * @param g1 first sub-graph; also supplies the command queue, context and device of this graph
     * @param g2 second sub-graph
     * @param applyBackprop whether this graph and both sub-graphs record backprop steps
     * @throws Exception if the {@link GPUGraph} constructor fails
     */
    public MultiGPUGraph(GPUGraph g1, GPUGraph g2, boolean applyBackprop) throws Exception {
        this(g1, 0.5, g2, 0.5, applyBackprop);
    }

    /**
     * Creates a graph that splits convolution work between {@code g1} and {@code g2}.
     *
     * @param g1 first sub-graph; also supplies the command queue, context and device of this graph
     * @param perfMeasure1 fraction of each convolution's output channels assigned to {@code g1};
     *     the resulting channel count is clamped to {@code [0, numFiltersPerDepth]}
     * @param g2 second sub-graph; always receives the remaining channels
     * @param perfMeasure2 currently unused. NOTE(review): if these were meant as relative speeds,
     *     the split should probably be {@code perfMeasure1 / (perfMeasure1 + perfMeasure2)} —
     *     confirm with callers before changing.
     * @param applyBackprop whether this graph and both sub-graphs record backprop steps
     * @throws Exception if the {@link GPUGraph} constructor fails
     */
    public MultiGPUGraph(GPUGraph g1, double perfMeasure1, GPUGraph g2, double perfMeasure2, boolean applyBackprop) throws Exception {
        super(g1.command_queue, g1.context, g1.device, applyBackprop);
        this.g1 = g1;
        this.g2 = g2;
        this.perf1 = perfMeasure1;
        // Each sub-graph gets a fresh backprop list when recording is on (so fullConv can collect
        // the step it records), and no list at all otherwise.
        for (GPUGraph sub : new GPUGraph[]{g1, g2}) {
            sub.applyBackprop = applyBackprop;
            sub.backprop = applyBackprop ? new ArrayList() : null;
        }
        this.pool = (ThreadPoolExecutor)Executors.newCachedThreadPool();
    }

    /**
     * Runs a full convolution by splitting the output channels between the two sub-graphs and
     * concatenating their output matrices ({@code g1}'s first, then {@code g2}'s).
     *
     * <p>When backprop is recorded, the step each working sub-graph appended to its own backprop
     * list is removed from that list and wrapped into a single step on this graph that runs both
     * halves concurrently.
     *
     * @throws Exception if {@code filters.length != t.depth * numFiltersPerDepth}, if either
     *     sub-graph's {@code fullConv} fails, or if a sub-graph that received work returned an
     *     empty result while backprop is being recorded
     */
    @Override
    public synchronized Tensor fullConv(boolean doBackprop, int pad, int stride, int numFiltersPerDepth, Tensor t, Matrix bias, Matrix[] filters, Nonlinearity nonlin, UUID parameterUUID) throws Exception {
        if (filters.length != t.depth * numFiltersPerDepth) {
            throw new Exception("Filter count does not match input depth times out depth");
        }
        Tensor out = new Tensor((t.width + pad * 2 - filters[0].cols) / stride + 1, (t.height + pad * 2 - filters[0].rows) / stride + 1, numFiltersPerDepth);
        boolean recordBackprop = doBackprop && this.applyBackprop;

        // Split contiguously: g1 gets the first share1 * t.depth filters. This assumes each output
        // channel's t.depth filters are stored consecutively — TODO confirm against GPUGraph.fullConv.
        int share1 = this.firstGraphShare(numFiltersPerDepth);
        int share2 = numFiltersPerDepth - share1;
        int filterCnt1 = share1 * t.depth;
        int filterCnt2 = filters.length - filterCnt1;
        Matrix[] filters1 = new Matrix[filterCnt1];
        Matrix[] filters2 = new Matrix[filterCnt2];
        System.arraycopy(filters, 0, filters1, 0, filterCnt1);
        System.arraycopy(filters, filterCnt1, filters2, 0, filterCnt2);

        ConvThread ct1 = new ConvThread(recordBackprop, pad, stride, share1, t, bias, filters1, parameterUUID, nonlin);
        ct1.g = this.g1;
        ConvThread ct2 = new ConvThread(recordBackprop, pad, stride, share2, t, bias, filters2, parameterUUID, nonlin);
        ct2.g = this.g2;
        // ConvThread captures its own exceptions, so waiting only fails on interruption or an Error.
        MultiGPUGraph.waitForAll(this.pool.submit(ct1), this.pool.submit(ct2));
        if (ct1.e != null) {
            throw ct1.e;
        }
        if (ct2.e != null) {
            throw ct2.e;
        }
        System.arraycopy(ct1.res.matrices, 0, out.matrices, 0, ct1.res.depth);
        System.arraycopy(ct2.res.matrices, 0, out.matrices, ct1.res.depth, ct2.res.depth);

        if (recordBackprop) {
            // A sub-graph with a zero share never ran (see ConvThread.run), so it recorded no step
            // and there is nothing to take from its backprop list.
            final Runnable bp1 = share1 > 0 ? MultiGPUGraph.takeLastBackprop(this.g1, ct1.res) : null;
            final Runnable bp2 = share2 > 0 ? MultiGPUGraph.takeLastBackprop(this.g2, ct2.res) : null;
            super.addBackprop(new Runnable(){

                @Override
                public void run() {
                    try {
                        Future<?> f1 = bp1 == null ? null : MultiGPUGraph.this.pool.submit(bp1);
                        Future<?> f2 = bp2 == null ? null : MultiGPUGraph.this.pool.submit(bp2);
                        // Always wait for both halves so a failure in one never lets the graph
                        // move on while the other is still running.
                        MultiGPUGraph.waitForAll(f1, f2);
                    }
                    catch (ExecutionException e) {
                        // Report the step's own exception rather than the executor's wrapper.
                        // NOTE(review): errors a sub-graph step reports only through its own
                        // graph's setBackpropException (if it does so) are not forwarded here.
                        Throwable cause = e.getCause();
                        MultiGPUGraph.this.setBackpropException(cause instanceof Exception ? (Exception)cause : e);
                    }
                    catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        MultiGPUGraph.this.setBackpropException(e);
                    }
                    catch (Exception e) {
                        // e.g. the pool rejecting the task after cleanUp().
                        MultiGPUGraph.this.setBackpropException(e);
                    }
                }
            });
        }
        return out;
    }

    /**
     * Returns how many of {@code numFiltersPerDepth} output channels go to {@code g1}:
     * {@code round(numFiltersPerDepth * perf1)} clamped to {@code [0, numFiltersPerDepth]}, so an
     * out-of-range {@code perf1} cannot produce a negative or oversized filter split.
     */
    private int firstGraphShare(int numFiltersPerDepth) {
        long share = Math.round((double)numFiltersPerDepth * this.perf1);
        return (int)Math.max(0L, Math.min((long)numFiltersPerDepth, share));
    }

    /**
     * Removes and returns the last entry of {@code g}'s backprop list, which is assumed to be the
     * step its {@code fullConv} call just recorded.
     *
     * @throws Exception if {@code res} has depth 0, i.e. the sub-graph produced no output
     */
    private static Runnable takeLastBackprop(GPUGraph g, Tensor res) throws Exception {
        if (res.depth == 0) {
            throw new Exception("Graph didn't create backprop thread");
        }
        return (Runnable)g.backprop.remove(g.backprop.size() - 1);
    }

    /**
     * Waits for every non-null future, continuing past failures, then rethrows the first
     * {@link ExecutionException} encountered. An interrupt is propagated immediately.
     */
    private static void waitForAll(Future<?>... futures) throws InterruptedException, ExecutionException {
        ExecutionException firstFailure = null;
        for (Future<?> future : futures) {
            if (future == null) {
                continue;
            }
            try {
                future.get();
            }
            catch (ExecutionException e) {
                if (firstFailure == null) {
                    firstFailure = e;
                }
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }

    /**
     * Runs {@code first} and {@code second} concurrently on the pool and waits for both. If waiting
     * fails, the error is printed and both tasks are run again, sequentially, on the calling thread.
     */
    private void runOnBothWithFallback(Runnable first, Runnable second) {
        Future<?> f1 = this.pool.submit(first);
        Future<?> f2 = this.pool.submit(second);
        try {
            MultiGPUGraph.waitForAll(f1, f2);
        }
        catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            e.printStackTrace();
            first.run();
            second.run();
        }
    }

    /**
     * Cleans up this graph and both sub-graphs, then shuts down the worker pool. The pool is shut
     * down even if one of the clean-ups throws.
     */
    @Override
    public void cleanUp() {
        try {
            super.cleanUp();
            this.g1.cleanUp();
            this.g2.cleanUp();
        }
        finally {
            this.pool.shutdown();
        }
    }

    /** Propagates the backprop-recording flag to this graph and both sub-graphs. */
    @Override
    public void setApplyingBackprop(boolean applyBackprop) {
        super.setApplyingBackprop(applyBackprop);
        this.g1.setApplyingBackprop(applyBackprop);
        this.g2.setApplyingBackprop(applyBackprop);
    }

    /**
     * Resets backprop state on this graph, then on both sub-graphs concurrently (falling back to a
     * sequential reset if waiting fails).
     */
    @Override
    public void resetBackprop() {
        super.resetBackprop();
        this.runOnBothWithFallback(new Runnable(){

            @Override
            public void run() {
                MultiGPUGraph.this.g1.resetBackprop();
            }
        }, new Runnable(){

            @Override
            public void run() {
                MultiGPUGraph.this.g2.resetBackprop();
            }
        });
    }

    /**
     * Forwards {@code force} to this graph and both sub-graphs. When {@code force} is false the two
     * sub-graphs are updated concurrently (falling back to sequential calls if waiting fails);
     * otherwise they are updated sequentially on the calling thread.
     */
    @Override
    public void forceKeepParameters(final boolean force) {
        super.forceKeepParameters(force);
        if (force) {
            this.g1.forceKeepParameters(force);
            this.g2.forceKeepParameters(force);
            return;
        }
        this.runOnBothWithFallback(new Runnable(){

            @Override
            public void run() {
                MultiGPUGraph.this.g1.forceKeepParameters(force);
            }
        }, new Runnable(){

            @Override
            public void run() {
                MultiGPUGraph.this.g2.forceKeepParameters(force);
            }
        });
    }

    /**
     * Runs the recorded backward pass, then resets backprop state on this graph and both sub-graphs
     * whether or not the pass succeeded. Any exception from the pass is propagated.
     */
    @Override
    public void backward() throws Exception {
        try {
            super.backward();
        }
        finally {
            this.resetBackprop();
        }
    }

    /**
     * Runs one sub-graph's share of a {@code fullConv} call on the pool. The result is stored in
     * {@link #res}; any exception is captured in {@link #e} instead of escaping {@code run()}.
     */
    private static class ConvThread
    implements Runnable {
        public final Tensor t;
        public Tensor res;
        public final boolean doBackprop;
        public GPUGraph g;
        public final int pad;
        public final int stride;
        public final int numFiltersPerDepth;
        public final Matrix bias;
        public final Matrix[] filters;
        public final UUID parameterUUID;
        public final Nonlinearity nonlin;
        public Exception e;

        public ConvThread(boolean doBackprop, int pad, int stride, int numFiltersPerDepth, Tensor t, Matrix bias, Matrix[] filters, UUID parameterUUID, Nonlinearity nonlin) {
            this.doBackprop = doBackprop;
            this.pad = pad;
            this.stride = stride;
            this.numFiltersPerDepth = numFiltersPerDepth;
            this.t = t;
            this.bias = bias;
            this.filters = filters;
            this.parameterUUID = parameterUUID;
            this.nonlin = nonlin;
        }

        @Override
        public void run() {
            this.e = null;
            if (this.numFiltersPerDepth == 0) {
                // Nothing assigned to this sub-graph: produce an empty placeholder without calling it.
                this.res = new Tensor(0, 0, 0);
                return;
            }
            try {
                this.res = this.g.fullConv(this.doBackprop, this.pad, this.stride, this.numFiltersPerDepth, this.t, this.bias, this.filters, this.nonlin, this.parameterUUID);
            }
            catch (Exception e1) {
                this.e = e1;
            }
        }
    }
}

