/*
 * Decompiled with CFR 0.152.
 */
package loss;

import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataStep;
import java.util.ArrayList;
import java.util.List;
import loss.Loss;
import matrix.Matrix;
import model.Model;
import util.Util;

/**
 * Softmax cross-entropy loss over a vector of unnormalized scores ("logprobs").
 *
 * <p>The target is given as a vector in which the first entry exactly equal to
 * {@code 1.0} marks the correct class.
 */
public class LossSoftmax
implements Loss {

    /**
     * Writes the gradient of {@link #measure} with respect to the scores into {@code logprobs.dw}.
     *
     * <p>The gradient of {@code -ln(softmax(z)[t])} with respect to {@code z} is
     * {@code softmax(z) - onehot(t)}. This method overwrites {@code logprobs.dw}; it does not
     * accumulate into it.
     *
     * @param logprobs     unnormalized scores; its {@code dw} array receives the gradient
     * @param targetOutput target vector; the first entry equal to {@code 1.0} is the target class
     * @throws Exception if {@code targetOutput} contains no entry equal to {@code 1.0}
     */
    @Override
    public void backward(Matrix logprobs, Matrix targetOutput) throws Exception {
        // Resolve the target first so a malformed target fails before dw is touched.
        int targetIndex = getTargetIndex(targetOutput);
        Matrix probs = getSoftmaxProbs(logprobs, 1.0);
        System.arraycopy(probs.w, 0, logprobs.dw, 0, probs.w.length);
        logprobs.dw[targetIndex] -= 1.0;
    }

    /**
     * Returns the cross-entropy cost {@code -ln(p[target])}, where {@code p = softmax(logprobs)}.
     *
     * <p>The cost is a natural logarithm.
     *
     * @param logprobs     unnormalized scores (not modified)
     * @param targetOutput target vector; the first entry equal to {@code 1.0} is the target class
     * @return the negative natural-log probability of the target class
     * @throws Exception if {@code targetOutput} contains no entry equal to {@code 1.0}
     */
    @Override
    public double measure(Matrix logprobs, Matrix targetOutput) throws Exception {
        int targetIndex = getTargetIndex(targetOutput);
        Matrix probs = getSoftmaxProbs(logprobs, 1.0);
        return -Math.log(probs.w[targetIndex]);
    }

    /**
     * Computes the median, over {@code sequences}, of each sequence's base-2 perplexity.
     *
     * <p>For each sequence, the model's state is reset and every step is run forward. Each
     * step's target is scored. The sequence perplexity is {@code 2^(H)}, where {@code H} is the
     * mean of {@code -log2(p[target])} over the sequence's steps.
     *
     * <p>Sequences with no steps have undefined perplexity and are skipped. If nothing is
     * scored, the result is whatever {@code Util.median} returns for an empty list.
     * NOTE(review): confirm that behavior.
     *
     * @param model     model to evaluate; its state is reset before each sequence
     * @param sequences sequences to score
     * @return the median per-sequence perplexity
     * @throws Exception propagated from the model's forward pass or from a malformed target
     */
    public static double calculateMedianPerplexity(Model model, List<DataSequence> sequences) throws Exception {
        final double temperature = 1.0;
        final double ln2 = Math.log(2.0);
        ArrayList<Double> ppls = new ArrayList<Double>();
        for (DataSequence seq : sequences) {
            int steps = 0;
            double neglog2ppl = 0.0;
            // Presumably a forward-only graph (no backprop bookkeeping); confirm against Graph.
            Graph g = new Graph(false);
            model.resetState();
            for (DataStep step : seq.steps) {
                Matrix logprobs = model.forward(step.input, g);
                Matrix probs = getSoftmaxProbs(logprobs, temperature);
                int targetIndex = getTargetIndex(step.targetOutput);
                // Change of base: log2(p) = ln(p) / ln(2).
                neglog2ppl -= Math.log(probs.w[targetIndex]) / ln2;
                steps++;
            }
            if (steps == 0) {
                // Avoid 0/0: an empty sequence has no scored predictions.
                continue;
            }
            // Average over exactly the steps that were summed.
            ppls.add(Math.pow(2.0, neglog2ppl / steps));
        }
        return Util.median(ppls);
    }

    /**
     * Returns a new matrix holding {@code softmax(logprobs / temperature)}.
     *
     * <p>The maximum scaled score is subtracted before exponentiating, so large scores do not
     * overflow. The input matrix is never modified.
     *
     * @param logprobs    unnormalized scores (read only)
     * @param temperature divisor applied to every score before the softmax; must be {@code > 0}
     * @return a freshly allocated matrix of probabilities that sum to 1
     * @throws IllegalArgumentException if {@code temperature} is not strictly positive (or is NaN)
     * @throws Exception                declared for compatibility with existing callers
     */
    public static Matrix getSoftmaxProbs(Matrix logprobs, double temperature) throws Exception {
        if (!(temperature > 0.0)) {
            throw new IllegalArgumentException("temperature must be positive, got " + temperature);
        }
        double[] logits = logprobs.w;
        Matrix probs = new Matrix(logits.length);
        // Scale into the output buffer rather than in place, so the caller's scores stay intact.
        double maxval = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < logits.length; i++) {
            double scaled = logits[i] / temperature;
            probs.w[i] = scaled;
            if (scaled > maxval) {
                maxval = scaled;
            }
        }
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            probs.w[i] = Math.exp(probs.w[i] - maxval);
            sum += probs.w[i];
        }
        for (int i = 0; i < probs.w.length; i++) {
            probs.w[i] /= sum;
        }
        return probs;
    }

    /**
     * Returns the index of the first entry of {@code targetOutput.w} that equals {@code 1.0}.
     *
     * @param targetOutput target vector
     * @return the index of the target class
     * @throws Exception if no entry equals {@code 1.0}
     */
    private static int getTargetIndex(Matrix targetOutput) throws Exception {
        for (int i = 0; i < targetOutput.w.length; i++) {
            if (targetOutput.w[i] == 1.0) {
                return i;
            }
        }
        throw new Exception("no target index selected");
    }
}

