/*
 * Decompiled with CFR 0.152.
 */
package util;

import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import matrix.Tensor;
import util.FileChannelInputStream;

/**
 * Loads images and category labels from CIFAR-10 binary batch files
 * ({@code data_batch_<n>.bin} / {@code test_batch.bin}) and converts between
 * {@link BufferedImage} and {@link Tensor}.
 */
public class CifarLoader {
    /** Side length, in pixels, of each square image in a batch file. */
    private static final int IMAGE_SIZE = 32;
    /** Number of bytes in one colour plane of a record ({@code IMAGE_SIZE * IMAGE_SIZE}). */
    private static final int PIXELS_PER_IMAGE = IMAGE_SIZE * IMAGE_SIZE;
    /** Upper bound on records read from one batch file; larger requests are clamped to this. */
    private static final int MAX_IMAGES_PER_BATCH = 10000;
    /** Highest batch number mapped to a training file; larger numbers select the test batch. */
    private static final int LAST_TRAINING_BATCH = 5;
    /** Display names indexed by category id; must stay in sync with the constants below. */
    private static final String[] CATEGORY_NAMES = {
        "Airplane", "Automobile", "Bird", "Cat", "Deer",
        "Dog", "Frog", "Horse", "Ship", "Truck"
    };

    private final File folder;
    private BufferedImage[] images;
    private int[] categories;

    // Category ids, as stored in the label byte that precedes each image record.
    public static final int AIRPLANE = 0;
    public static final int AUTOMOBILE = 1;
    public static final int BIRD = 2;
    public static final int CAT = 3;
    public static final int DEER = 4;
    public static final int DOG = 5;
    public static final int FROG = 6;
    public static final int HORSE = 7;
    public static final int SHIP = 8;
    public static final int TRUCK = 9;

    /**
     * @param folder directory containing the CIFAR-10 {@code .bin} batch files
     */
    public CifarLoader(File folder) {
        this.folder = folder;
    }

    /**
     * Reads the first {@code imageCountToLoad} records of one batch file, replacing any
     * previously loaded images. The loader's state is only replaced if the whole read succeeds.
     *
     * @param batch values up to 5 select {@code data_batch_<batch>.bin}; anything larger
     *     selects {@code test_batch.bin}
     * @param imageCountToLoad number of records to read; values above 10000 are clamped to 10000
     * @return always {@code true}; failures are reported by exception
     * @throws IllegalArgumentException if {@code imageCountToLoad} is negative
     * @throws FileNotFoundException if the folder or the selected batch file does not exist
     * @throws EOFException if the file ends before the requested number of records was read
     * @throws Exception if the folder is not a directory
     */
    public boolean load(int batch, int imageCountToLoad) throws Exception {
        if (imageCountToLoad < 0) {
            throw new IllegalArgumentException("imageCountToLoad must be >= 0: " + imageCountToLoad);
        }
        int count = Math.min(imageCountToLoad, MAX_IMAGES_PER_BATCH);
        if (!this.folder.exists()) {
            throw new FileNotFoundException(this.folder.getPath());
        }
        if (!this.folder.isDirectory()) {
            throw new Exception("Not a directory");
        }
        String fileName = batch <= LAST_TRAINING_BATCH
                ? "data_batch_" + batch + ".bin"
                : "test_batch.bin";
        File batchFile = new File(this.folder, fileName);

        // Fill locals first so a failed read does not leave half-populated arrays behind.
        BufferedImage[] loadedImages = new BufferedImage[count];
        int[] loadedCategories = new int[count];
        try (FileInputStream fis = new FileInputStream(batchFile)) {
            FileChannelInputStream in = new FileChannelInputStream(fis.getChannel());
            try {
                // Each record is one label byte followed by the image's pixel planes.
                for (int i = 0; i < count; i++) {
                    loadedCategories[i] = readByte(in);
                    loadedImages[i] = readImage(in);
                }
            } finally {
                in.close();
            }
        }
        this.images = loadedImages;
        this.categories = loadedCategories;
        return true;
    }

    /**
     * Reads one byte from the batch stream and fails loudly at end of stream.
     * NOTE(review): assumes FileChannelInputStream.read() follows the InputStream contract
     * (0-255, or -1 at end of stream) — confirm.
     */
    private static int readByte(FileChannelInputStream in) throws Exception {
        int value = in.read();
        if (value < 0) {
            throw new EOFException("Unexpected end of CIFAR batch file");
        }
        return value;
    }

    /**
     * Reads one 32x32 image stored as three consecutive 1024-byte planes (red, green, blue),
     * each laid out row by row (byte p is pixel x = p % 32, y = p / 32).
     */
    private static BufferedImage readImage(FileChannelInputStream in) throws Exception {
        int[][] planes = new int[3][PIXELS_PER_IMAGE];
        for (int[] plane : planes) {
            for (int p = 0; p < PIXELS_PER_IMAGE; p++) {
                plane[p] = readByte(in);
            }
        }
        BufferedImage image = new BufferedImage(IMAGE_SIZE, IMAGE_SIZE, BufferedImage.TYPE_INT_RGB);
        for (int p = 0; p < PIXELS_PER_IMAGE; p++) {
            int rgb = new Color(planes[0][p], planes[1][p], planes[2][p]).getRGB();
            image.setRGB(p % IMAGE_SIZE, p / IMAGE_SIZE, rgb);
        }
        return image;
    }

    /** Returns the loaded image at {@code index}; {@link #load} must have been called first. */
    public BufferedImage getImage(int index) {
        return this.images[index];
    }

    /** Returns the category id of the loaded image at {@code index}; requires a prior {@link #load}. */
    public int getCategory(int index) {
        return this.categories[index];
    }

    /**
     * Returns the category id of a loaded image, or -1 if the image is not among the loaded ones.
     * Images are matched with {@code equals}; BufferedImage does not override it, so this is an
     * identity match against the instances returned by {@link #getImage}.
     */
    public int getCategoryForImage(BufferedImage image) {
        for (int i = 0; i < this.getImageCount(); i++) {
            if (this.getImage(i).equals(image)) {
                return this.getCategory(i);
            }
        }
        return -1;
    }

    /** Returns the number of loaded images, or 0 if {@link #load} has not succeeded yet. */
    public int getImageCount() {
        return this.images == null ? 0 : this.images.length;
    }

    /**
     * Converts an image to a 3-channel tensor with red, green and blue in channels 0, 1, 2,
     * each scaled from 0-255 to [0, 1]. Values are stored with {@code setW(y, x, ...)}.
     */
    public static Tensor asTensor(BufferedImage image) {
        Tensor toReturn = new Tensor(image.getWidth(), image.getHeight(), 3);
        for (int x = 0; x < image.getWidth(); x++) {
            for (int y = 0; y < image.getHeight(); y++) {
                int argb = image.getRGB(x, y);
                toReturn.matrices[0].setW(y, x, ((argb >> 16) & 0xFF) / 255.0);
                toReturn.matrices[1].setW(y, x, ((argb >> 8) & 0xFF) / 255.0);
                toReturn.matrices[2].setW(y, x, (argb & 0xFF) / 255.0);
            }
        }
        return toReturn;
    }

    /**
     * Converts a tensor back to an image, the inverse of {@link #asTensor}. Channels 0-2 become
     * red, green and blue (missing channels are 0). Each value is multiplied by 255, rounded and
     * clamped to 0-255. When the tensor has 4 or more channels an ARGB image is produced and
     * channel 3 is used as alpha — inferred from that image-type choice; confirm with callers.
     * Otherwise the image is opaque RGB.
     */
    public static BufferedImage asImage(Tensor t) {
        boolean hasAlpha = t.depth >= 4;
        BufferedImage toReturn = new BufferedImage(t.width, t.height,
                hasAlpha ? BufferedImage.TYPE_INT_ARGB : BufferedImage.TYPE_INT_RGB);
        for (int x = 0; x < toReturn.getWidth(); x++) {
            for (int y = 0; y < toReturn.getHeight(); y++) {
                int r = t.depth > 0 ? channelToByte(t, 0, x, y) : 0;
                int g = t.depth > 1 ? channelToByte(t, 1, x, y) : 0;
                int b = t.depth > 2 ? channelToByte(t, 2, x, y) : 0;
                // Without explicit alpha bits an ARGB pixel would be fully transparent.
                int a = hasAlpha ? channelToByte(t, 3, x, y) : 0xFF;
                toReturn.setRGB(x, y, (a << 24) | (r << 16) | (g << 8) | b);
            }
        }
        return toReturn;
    }

    /**
     * Scales one tensor value to a 0-255 byte. Rounding (not truncation) makes
     * asImage(asTensor(img)) reproduce the original pixel values despite floating-point error.
     */
    private static int channelToByte(Tensor t, int channel, int x, int y) {
        long scaled = Math.round(t.matrices[channel].getW(y, x) * 255.0);
        return (int) Math.max(0L, Math.min(255L, scaled));
    }

    /**
     * Returns the display name of a category id, or {@code "Unknown"} for ids outside 0-9.
     */
    public String getCategoryString(int category) {
        if (category < 0 || category >= CATEGORY_NAMES.length) {
            return "Unknown";
        }
        return CATEGORY_NAMES[category];
    }
}

