/*
 * Decompiled with CFR 0.152.
 */
package datasets;

import datastructs.TensorDataSequence;
import datastructs.TensorDataSet;
import datastructs.TensorDataStep;
import java.util.ArrayList;
import java.util.Random;
import loss.LossSoftmax;
import matrix.Tensor;
import model.ConvNet;
import nonlinearities.LinearUnit;
import nonlinearities.Nonlinearity;
import theGhastModding.console.main.Console;
import util.CifarLoader;

/**
 * A {@link TensorDataSet} built from the images held by a {@link CifarLoader}.
 *
 * <p>Each image becomes one {@link TensorDataStep} whose target is a one-hot vector over
 * {@value #CLASS_COUNT} categories. Images are grouped, in loader order, into training
 * sequences of at most {@code maxImagesPerSequence} steps each. Only the training split is
 * filled in by this class.
 */
public class CifarDataset extends TensorDataSet {
    private static final long serialVersionUID = 1L;

    /** Number of target categories; the one-hot target vector has this many entries. */
    private static final int CLASS_COUNT = 10;

    /**
     * Builds the training split from every image held by {@code loader}.
     *
     * <p>Images are chunked into consecutive sequences of exactly
     * {@code maxImagesPerSequence} images; only the last sequence may be shorter. If the
     * loader holds no images, the training split contains a single empty sequence.
     *
     * @param loader source of images and their category indices
     * @param maxImagesPerSequence maximum number of images per training sequence; must be positive
     * @throws IllegalArgumentException if {@code maxImagesPerSequence} is not positive
     */
    public CifarDataset(CifarLoader loader, int maxImagesPerSequence) {
        if (maxImagesPerSequence <= 0) {
            throw new IllegalArgumentException(
                    "maxImagesPerSequence must be positive, got " + maxImagesPerSequence);
        }
        // NOTE(review): presumably 32x32 pixels with 3 colour channels — confirm the axis
        // order expected by TensorDimensions.
        this.inputDimension = new TensorDataSet.TensorDimensions(32, 32, 3);
        this.outputDimension = new TensorDataSet.TensorDimensions(1, CLASS_COUNT, 1);
        this.lossTraining = new LossSoftmax();
        this.lossReporting = new LossSoftmax();
        this.training = new ArrayList<>();

        TensorDataSequence current = new TensorDataSequence();
        this.training.add(current);
        int imageCount = loader.getImageCount();
        for (int i = 0; i < imageCount; i++) {
            // Open a new sequence only when the current one is full AND another image is about
            // to be added. Every sequence then holds exactly maxImagesPerSequence images (the
            // last may hold fewer), and no empty trailing sequence is ever created.
            if (i > 0 && i % maxImagesPerSequence == 0) {
                current = new TensorDataSequence();
                this.training.add(current);
            }
            Tensor target = new Tensor(1, CLASS_COUNT, 1);
            // One-hot target. Assumes a freshly constructed Tensor is zero-filled; an
            // out-of-range category index will throw ArrayIndexOutOfBoundsException here.
            target.getMatrixAt(0).w[loader.getCategory(i)] = 1.0;
            current.addDataStep(
                    new TensorDataStep(CifarLoader.asTensor(loader.getImage(i)), target));
        }
    }

    /**
     * Intentionally does nothing: this dataset produces no custom report.
     */
    @Override
    public void DisplayReport(ConvNet model, Random rng, Console c) throws Exception {
    }

    /**
     * Returns a {@link LinearUnit}, so the model's final layer emits raw (unsquashed) scores.
     * NOTE(review): presumably the softmax is applied inside {@link LossSoftmax} — confirm.
     *
     * @return a new {@link LinearUnit}
     */
    @Override
    public Nonlinearity getModelOutputUnitToUse() {
        return new LinearUnit();
    }
}

