/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import autodiff.Graph;
import com.leff.midi.MidiFile;
import com.leff.midi.MidiTrack;
import com.leff.midi.event.MidiEvent;
import com.leff.midi.event.NoteOff;
import com.leff.midi.event.NoteOn;
import com.leff.midi.event.meta.Tempo;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import loss.LossSumOfSquares;
import matrix.Matrix;
import matrix.Tensor;
import model.ConvExpand;
import model.ConvFlatten;
import model.ConvNonlinLayer;
import model.Dropout;
import model.FeedForwardLayer;
import model.LinearLayer;
import model.NeuralNetwork;
import nonlinearities.ExponentialLinearUnit;
import nonlinearities.LinearUnit;
import nonlinearities.Nonlinearity;
import trainer.Adam;
import trainer.GANTrainer;
import util.FileIO;

public class MIDINet {
    // The two networks of the GAN. Both are (re)built by loadNets() and
    // passed together to GANTrainer.train() in main().
    private static NeuralNetwork generator;
    private static NeuralNetwork discriminator;
    // Not referenced anywhere in this file; the value 124 appears as a literal
    // in loadNets() (generator input width) and generateExample() (noise size).
    private static final int seedLength = 124;
    // Graph handed to generator.forward() when sampling an example.
    // NOTE(review): the meaning of the `false` constructor flag is not visible
    // here — presumably "no backprop bookkeeping"; confirm in autodiff.Graph.
    private static final Graph staticGraph;
    // Single RNG shared by layer initialisation, training and sampling.
    private static final Random staticRandom;
    // Not referenced anywhere in this file; the model paths used by loadNets()
    // and saveNets() are spelled out as literals that begin with this prefix.
    private static final String saveBasePath = "D:/models/midi_net";

    static {
        // Switch on two flags of the MIDI library before any file is parsed.
        // NOTE(review): their exact effect is defined in MidiTrack/MidiEvent,
        // not here — presumably parser workarounds; confirm.
        MidiTrack.deltaFix = true;
        MidiEvent.statusByteFix = true;
        staticGraph = new Graph(false);
        staticRandom = new Random();
    }

    /**
     * Entry point: loads the MIDI dataset and both networks, then runs the
     * GAN training loop, periodically writing a sample and saving the models.
     *
     * <p>The learning rate starts at 1e-4 and is multiplied by 0.95 once for
     * every 2000 units of {@code generator.t}. The number of decays already
     * applied is tracked explicitly, so a run resumed from a saved model
     * continues the schedule it would have followed without the restart.
     * (The previous code seeded its bookkeeping with {@code t / 2000} but then
     * compared it against the raw counter {@code t}, which made a resumed run
     * decay again right after its first iteration.)
     *
     * <p>Any exception is printed to stderr; the trainer is disposed even when
     * training fails.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        try {
            MidiDataset dataset = new MidiDataset(new File("D:\\midiz\\"));
            MIDINet.loadNets();
            GANTrainer trainer = new GANTrainer(new Adam(0.5, 0.5555, 1.0E-9), 0.64);
            try {
                final int iters = 1024;
                final int decayInterval = 2000;
                final double decayFactor = 0.95;
                double lr = 1.0E-4;

                // Re-apply the decays a previously saved model has already been through.
                // NOTE(review): generator.t is presumably the network's training step
                // counter restored by FileIO.loadNeuralNetwork — confirm.
                int decaysApplied = MIDINet.generator.t / decayInterval;
                for (int i = 0; i < decaysApplied; ++i) {
                    lr *= decayFactor;
                }

                for (int i = 0; i < iters; ++i) {
                    System.out.println("Iteration " + (i + 1) + " of " + iters);
                    trainer.train(generator, discriminator, lr, 2, 16, dataset, Integer.MAX_VALUE, null, null, false, false, staticRandom);
                    if (i % 5 == 0) {
                        dataset.DisplayReport(generator, staticRandom);
                        MIDINet.saveNets();
                    }
                    // Decay once for every interval boundary crossed during this iteration.
                    while (MIDINet.generator.t / decayInterval > decaysApplied) {
                        lr *= decayFactor;
                        ++decaysApplied;
                    }
                }
            } finally {
                // Release trainer resources on both the success and the failure path.
                trainer.dispose();
            }
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Builds the generator and discriminator architectures, then restores
     * their weights from disk when both model files exist.
     *
     * <p>If either file is missing the freshly initialised networks are kept
     * and {@code WAIT!} is printed to stderr. Finally the total number of
     * weights across both networks is printed.
     *
     * @throws Exception if constructing a layer or reading a model file fails
     */
    private static void loadNets() throws Exception {
        // Generator: 124 inputs -> 2048 -> 256 -> 147456, expanded to 96x96x16.
        // NOTE(review): the Math.sqrt(...) argument is presumably the initial
        // weight standard deviation — confirm against the layer constructors.
        final NeuralNetwork gen = new NeuralNetwork();
        generator = gen;
        gen.addLayer(new FeedForwardLayer(124, 2048, new ExponentialLinearUnit(1.0), Math.sqrt(0.008064516129032258), staticRandom));
        gen.addLayer(new LinearLayer(2048, 256, Math.sqrt(9.765625E-4), staticRandom));
        gen.addLayer(new ConvNonlinLayer(new ExponentialLinearUnit(1.0)));
        gen.addLayer(new FeedForwardLayer(256, 147456, new LinearUnit(), Math.sqrt(6.544502617801048E-4), staticRandom));
        gen.addLayer(new ConvExpand(147456, 96, 96, 16));

        // Discriminator: 96x96x16 flattened -> 256 -> 2048 -> 124, with dropout in between.
        final NeuralNetwork disc = new NeuralNetwork();
        discriminator = disc;
        disc.addLayer(new ConvFlatten(96, 96, 16));
        disc.addLayer(new FeedForwardLayer(147456, 256, new ExponentialLinearUnit(1.0), Math.sqrt(6.781684027777777E-6), staticRandom));
        disc.addLayer(new Dropout(0.25));
        disc.addLayer(new FeedForwardLayer(256, 2048, new ExponentialLinearUnit(1.0), Math.sqrt(4.8828125E-4), staticRandom));
        disc.addLayer(new Dropout(0.125));
        disc.addLayer(new FeedForwardLayer(2048, 124, new LinearUnit(), Math.sqrt(9.765625E-4), staticRandom));

        // Weights are only restored when BOTH files are present.
        final File generatorWeights = new File("D:/models/midi_net_generator.dat");
        final File discriminatorWeights = new File("D:/models/midi_net_discriminator.dat");
        if (!generatorWeights.exists() || !discriminatorWeights.exists()) {
            System.err.println("WAIT!");
        } else {
            FileIO.loadNeuralNetwork(generatorWeights.getPath(), gen);
            FileIO.loadNeuralNetwork(discriminatorWeights.getPath(), disc);
        }

        // Sum the backing-array lengths of every parameter matrix of both nets.
        long parameterCount = 0L;
        for (NeuralNetwork net : new NeuralNetwork[] {gen, disc}) {
            for (Matrix weights : net.getParameters()) {
                parameterCount += weights.w.length;
            }
        }
        System.out.println("Loaded. " + parameterCount + " total parameters.");
    }

    /**
     * Writes the generator and the discriminator to their model files and
     * prints {@code Saved.} afterwards.
     *
     * @throws Exception if either network cannot be written
     */
    private static void saveNets() throws Exception {
        final String generatorPath = "D:/models/midi_net_generator.dat";
        final String discriminatorPath = "D:/models/midi_net_discriminator.dat";
        FileIO.saveNeuralNetwork(generatorPath, generator);
        FileIO.saveNeuralNetwork(discriminatorPath, discriminator);
        System.out.println("Saved.");
    }

    /**
     * Samples one output from the generator and converts it to a MIDI file.
     *
     * <p>A random 1x124x1 tensor is fed through the generator; the network
     * state is reset both before and after the forward pass.
     *
     * @return the generated tensor rendered through {@link #tensorsToMIDI}
     * @throws Exception if the forward pass or the conversion fails
     */
    private static MidiFile generateExample() throws Exception {
        final Tensor noise = Tensor.rand(1, 124, 1, 1.0, staticRandom);
        generator.resetState();
        final Tensor generated = generator.forward(noise, staticGraph);
        generator.resetState();
        return MIDINet.tensorsToMIDI(generated);
    }

    /**
     * Converts a MIDI file into an array of 96x96x16 piano-roll tensors.
     *
     * <p>Time is cut into slices of {@code ticksPerSlice} ticks (three times the
     * ticks-per-second derived from the first tempo event and the file
     * resolution). Every slice becomes one 96x96 matrix and 16 consecutive
     * slices form one tensor; trailing slices that do not fill a whole tensor
     * are dropped. Within a matrix the first index is {@code noteValue - 16}
     * and the second index is the time column inside the slice. A note-on is
     * written as {@code 1.0}, a note-off as {@code -1.0}.
     *
     * <p>Processing stops cleanly when the event list is exhausted. The
     * previous version kept reading past the end of the list in that case and
     * failed with an {@code IndexOutOfBoundsException}.
     *
     * <p>NOTE(review): events are accepted only when {@code noteValue - 16} is
     * strictly greater than 0, so row 0 is never written although
     * {@link #tensorsToMIDI} reads it — possibly an off-by-one; left unchanged.
     *
     * @param midiFile the MIDI file to read
     * @return one tensor per 16 complete slices; empty if the file is shorter
     * @throws Exception if the file cannot be read or parsed
     */
    public static Tensor[] MIDItoTensors(File midiFile) throws Exception {
        MidiFile midi = new MidiFile(midiFile);
        List<MidiEvent> allEvents = MIDINet.getMidiEvents(midi);
        Tempo t = MIDINet.findFirstTempo(allEvents);
        // (beats per minute / 60) * resolution; presumably ticks per second,
        // assuming getResolution() is ticks per beat — confirm.
        double tps = (double)t.getBpm() / 60.0 * (double)midi.getResolution();
        double ticksPerSlice = tps * 3.0;
        int matrixNum = (int)Math.floor((double)midi.getLengthInTicks() / ticksPerSlice);
        int tensorNum = matrixNum / 16;

        Tensor[] toReturn = new Tensor[tensorNum];
        for (int i = 0; i < tensorNum; ++i) {
            toReturn[i] = new Tensor(96, 96, 16);
        }

        // Per-row count of note-ons that have not yet been matched by a note-off.
        int[] noteOnCntrs = new int[96];
        int eventCount = allEvents.size();
        int eventNum = 0;

        for (int i = 0; i < tensorNum && eventNum < eventCount; ++i) {
            Tensor res = toReturn[i];
            for (int j = 0; j < 16 && eventNum < eventCount; ++j) {
                double sliceStartTick = (double)(i * 16 + j) * ticksPerSlice;
                double sliceEndTick = sliceStartTick + ticksPerSlice;
                // Consume every event that starts before the end of this slice.
                while (eventNum < eventCount) {
                    MidiEvent event = allEvents.get(eventNum);
                    if ((double)event.getTick() >= sliceEndTick) {
                        break;
                    }
                    if (event instanceof NoteOn) {
                        NoteOn on = (NoteOn)event;
                        int noteVal = on.getNoteValue() - 16;
                        // Only sufficiently loud note-ons (velocity >= 75) are recorded.
                        if (noteVal > 0 && noteVal < 96 && on.getVelocity() >= 75) {
                            int column = MIDINet.sliceColumn(event.getTick(), sliceStartTick, ticksPerSlice);
                            MIDINet.markNoteOn(res, j, noteVal, column, noteOnCntrs);
                        }
                    }
                    if (event instanceof NoteOff) {
                        int noteVal = ((NoteOff)event).getNoteValue() - 16;
                        if (noteVal > 0 && noteVal < 96) {
                            int column = MIDINet.sliceColumn(event.getTick(), sliceStartTick, ticksPerSlice);
                            MIDINet.markNoteOff(res, j, noteVal, column, noteOnCntrs);
                        }
                    }
                    ++eventNum;
                }
            }
        }
        return toReturn;
    }

    /** Maps a tick to its column (0..95) inside the slice that starts at {@code sliceStartTick}. */
    private static int sliceColumn(long tick, double sliceStartTick, double ticksPerSlice) {
        int column = (int)(((double)tick - sliceStartTick) / ticksPerSlice * 96.0);
        // Events left over from an earlier slice map below 0; clamp into the matrix.
        return Math.max(0, Math.min(95, column));
    }

    /**
     * Writes a note-on (1.0) at the first cell at or after the requested position
     * that does not already hold a note-off (-1.0); the event is dropped if the
     * search runs off the end of the tensor.
     */
    private static void markNoteOn(Tensor res, int slice, int noteVal, int column, int[] noteOnCntrs) {
        int indx = slice;
        Matrix resM = res.matrices[indx];
        while (resM.getW(noteVal, column) == -1.0) {
            if (++column != 96) {
                continue;
            }
            if (indx == 15) {
                return; // no free cell left in this tensor
            }
            resM = res.matrices[++indx];
            column = 0;
        }
        resM.setW(noteVal, column, 1.0);
        noteOnCntrs[noteVal] = noteOnCntrs[noteVal] + 1;
    }

    /**
     * Writes a note-off (-1.0) at the first cell at or before the requested
     * position that does not hold a note-on (1.0), provided an unmatched note-on
     * exists for that row; the event is dropped if the search runs off the start
     * of the tensor.
     */
    private static void markNoteOff(Tensor res, int slice, int noteVal, int column, int[] noteOnCntrs) {
        int indx = slice;
        Matrix resM = res.matrices[indx];
        while (resM.getW(noteVal, column) == 1.0) {
            if (--column != -1) {
                continue;
            }
            if (indx == 0) {
                return; // no free cell before this position in the tensor
            }
            resM = res.matrices[--indx];
            column = 95;
        }
        if (noteOnCntrs[noteVal] > 0) {
            resM.setW(noteVal, column, -1.0);
            noteOnCntrs[noteVal] = noteOnCntrs[noteVal] - 1;
        }
    }

    /**
     * Renders piano-roll tensors as a two-track MIDI file (resolution 480).
     *
     * <p>The first track holds a single default {@code Tempo} event; the second
     * holds the notes. Each tensor is read as 16 matrices of 96 columns, and
     * every column advances the output position by 30 ticks. For each cell,
     * a value below -0.75 emits a note-off and a value above 0.25 emits a
     * note-on with velocity 100, both on channel 0 with note number
     * {@code row + 16}.
     *
     * @param t the tensors to render, in playback order
     * @return the assembled MIDI file
     * @throws Exception declared for callers; propagated from the MIDI library if it throws
     */
    public static MidiFile tensorsToMIDI(Tensor ... t) throws Exception {
        final double noteOnThreshold = 0.25;
        final double noteOffThreshold = -0.75;
        // 960 ticks/s * 3 s per slice / 96 columns per slice.
        // NOTE(review): 960 presumably corresponds to resolution 480 at the
        // default tempo — confirm against Tempo's default BPM.
        final double ticksPerColumn = 960.0 * 3.0 / 96.0;

        MidiFile midi = new MidiFile(480);
        MidiTrack tempoTrack = new MidiTrack();
        tempoTrack.insertEvent(new Tempo());
        midi.addTrack(tempoTrack);

        MidiTrack notes = new MidiTrack();
        double tick = 0.0;
        for (Tensor tensor : t) {
            for (int slice = 0; slice < 16; ++slice) {
                Matrix roll = tensor.matrices[slice];
                for (int column = 0; column < 96; ++column) {
                    for (int row = 0; row < 96; ++row) {
                        double pixel = roll.getW(row, column);
                        if (pixel < noteOffThreshold) {
                            notes.insertEvent(new NoteOff((long)tick, 0, row + 16, 0));
                        } else if (pixel > noteOnThreshold) {
                            notes.insertEvent(new NoteOn((long)tick, 0, row + 16, 100));
                        }
                    }
                    tick += ticksPerColumn;
                }
            }
        }
        midi.addTrack(notes);
        return midi;
    }

    /**
     * Returns the first {@code Tempo} event located at tick 0.
     *
     * @param events the events to search, in iteration order
     * @return the matching event, or a new default {@code Tempo} when none is found
     */
    private static Tempo findFirstTempo(List<MidiEvent> events) {
        for (MidiEvent candidate : events) {
            boolean tempoAtStart = candidate instanceof Tempo && candidate.getTick() == 0L;
            if (tempoAtStart) {
                return (Tempo)candidate;
            }
        }
        return new Tempo();
    }

    /**
     * Collects the events of every track into one list and sorts it with
     * {@link #sortEvents}.
     *
     * @param midi the file whose tracks are merged
     * @return all events of all tracks, as returned by {@code sortEvents}
     * @throws Exception declared for callers; propagated from the MIDI library if it throws
     */
    private static List<MidiEvent> getMidiEvents(MidiFile midi) throws Exception {
        ArrayList<MidiEvent> merged = new ArrayList<MidiEvent>();
        for (int track = 0; track < midi.getTrackCount(); ++track) {
            // Append this track's events in the order its iterator yields them.
            for (Iterator<MidiEvent> events = midi.getTrack(track).getEvents().iterator(); events.hasNext(); ) {
                merged.add(events.next());
            }
        }
        return MIDINet.sortEvents(merged);
    }

    /**
     * Sorts the given events in place by ascending tick.
     *
     * <p>The sort is stable, so events that share a tick keep their relative
     * order — the same result the former hand-written insertion sort gave,
     * without its quadratic cost on large merged track lists.
     *
     * @param list the events to sort; modified in place
     * @return the same list instance, now sorted
     */
    private static List<MidiEvent> sortEvents(List<MidiEvent> list) {
        Collections.sort(list, new Comparator<MidiEvent>() {
            @Override
            public int compare(MidiEvent first, MidiEvent second) {
                return Long.compare(first.getTick(), second.getTick());
            }
        });
        return list;
    }

    private static class MidiDataset
    extends DataSet {
        /*
         * Unable to fully structure code
         */
        public MidiDataset(File midiFileFolder) throws Exception {
            super();
            this.inputDimension = new DataSet.TensorDimensions(1, 124, 1);
            this.outputDimension = new DataSet.TensorDimensions(96, 96, 16);
            this.training = new ArrayList<E>();
            System.out.println("Loading MIDI dataset...");
            System.out.print("|");
            i = 0;
            while (i < 98) {
                System.out.print("-");
                ++i;
            }
            System.out.println("|");
            prevPercent = 0;
            fs = midiFileFolder.listFiles();
            i = 0;
            ** GOTO lbl35
            {
                System.out.print(">");
                ++prevPercent;
                do {
                    if ((int)((double)i / (double)fs.length * 100.0) > prevPercent) continue block1;
                    ds = new DataSequence();
                    tensors = MIDINet.MIDItoTensors(fs[i]);
                    if (tensors.length != 0) {
                        var10_10 = tensors;
                        var9_9 = tensors.length;
                        var8_8 = 0;
                        while (var8_8 < var9_9) {
                            t = var10_10[var8_8];
                            ds.addDataStep(new DataStep(t, null));
                            ++var8_8;
                        }
                        this.training.add(ds);
                    }
                    ++i;
lbl35:
                    // 2 sources

                } while (i < fs.length);
            }
            System.out.println();
            this.lossReporting = this.lossTraining = new LossSumOfSquares();
            System.out.println("Done!");
        }

        @Override
        public void DisplayReport(NeuralNetwork model, Random rng) throws Exception {
            try {
                MIDINet.generateExample().writeToFile(new File("midi_net_output_" + Long.toString(System.currentTimeMillis()) + ".mid"));
            }
            catch (Exception e) {
                System.err.println("Error generating test output: ");
                e.printStackTrace();
                return;
            }
            System.err.println("a");
        }

        @Override
        public Nonlinearity getModelOutputUnitToUse() {
            return new LinearUnit();
        }
    }
}

