/*
 * Decompiled with CFR 0.152.
 */
package theGhastModding.lstmStuff.main;

import autodiff.Graph;
import datastructs.DataSequence;
import datastructs.DataSet;
import datastructs.DataStep;
import java.io.File;
import java.util.ArrayList;
import java.util.Random;
import loss.LossSumOfSquares;
import matrix.Matrix;
import model.FeedForwardLayer;
import model.LstmLayer;
import model.Model;
import model.NeuralNetwork;
import nonlinearities.LinearUnit;
import nonlinearities.Nonlinearity;
import nonlinearities.ReLuUnit;
import theGhastModding.lstmStuff.main.WavFile;
import trainer.NewTrainer;
import trainer.RMSProp;
import util.FileIO;

/**
 * Command-line entry point that builds a recurrent network over raw WAV audio.
 *
 * <p>Without arguments the network is trained on a fixed set of WAV files and
 * saved after every training iteration. When the first argument is {@code g}
 * the (possibly previously saved) network is used to generate a new WAV file
 * instead, and the process exits.
 */
public class WavAI {
    /**
     * Loads the training audio, builds the network, then either generates audio
     * (first argument {@code g}) or runs the training loop.
     *
     * <p>Any exception is reported on standard error and terminates the process
     * with exit status 1.
     *
     * @param args command-line arguments; only {@code args[0]} is inspected
     */
    public static void main(String[] args) {
        try {
            String savefile = "wav.dat";
            String[] audioFiles = new String[]{"1.wav", "2.wav", "3.wav", "4.wav"};
            WavFile[] wavFiles = new WavFile[audioFiles.length];
            WavDataset dataset;
            try {
                for (int fileIndex = 0; fileIndex < wavFiles.length; ++fileIndex) {
                    wavFiles[fileIndex] = WavFile.openWavFile(new File(audioFiles[fileIndex]));
                }
                System.out.println("Loaded files");
                dataset = new WavDataset(500, 500, wavFiles);
            } finally {
                // The dataset copies every sample it needs, so the input files
                // can be released as soon as construction finishes (or fails).
                WavAI.closeAll(wavFiles);
            }
            System.out.println("Loaded dataset");

            Random rng = new Random();
            ArrayList<Model> networkLayers = new ArrayList<Model>();
            networkLayers.add(new FeedForwardLayer(dataset.inputDimension, 256, new ReLuUnit(), 0.08, rng));
            networkLayers.add(new LstmLayer(256, 256, 0.08, rng));
            networkLayers.add(new LstmLayer(256, 500, 0.08, rng));
            networkLayers.add(new FeedForwardLayer(500, dataset.outputDimension, new ReLuUnit(), 0.08, rng));
            NeuralNetwork net = new NeuralNetwork(networkLayers);

            // Resume from a previous run when a save file is present.
            if (new File(savefile).exists()) {
                FileIO.loadNeuralNetwork(savefile, net);
            }

            if (args.length != 0 && args[0].equals("g")) {
                System.out.println("Generating");
                dataset.generate("generated_wav.wav", net, 60, rng);
                System.out.println("Done.");
                System.exit(0);
            }

            NewTrainer trainer = new NewTrainer(new RMSProp(0.99, 3.0, 1.0E-8));
            int iterations = 100;
            int epochs = 5;
            for (int iteration = 1; iteration <= iterations; ++iteration) {
                System.out.println("Iteration " + Integer.toString(iteration) + "/" + Integer.toString(iterations));
                trainer.train(net, 0.001, epochs, dataset, 4, savefile, false, false, rng);
                // Persist after every iteration so an interrupted run loses little work.
                FileIO.saveNeuralNetwork(savefile, net);
            }
            System.out.println("Done.");
        }
        catch (Exception e) {
            System.err.println("Error: ");
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Closes every non-null entry of {@code wavFiles}. A failure to close one
     * file is reported on standard error and does not prevent the remaining
     * files from being closed.
     *
     * @param wavFiles files to close; entries may be {@code null} if opening
     *                 was interrupted part-way through
     */
    private static void closeAll(WavFile[] wavFiles) {
        for (WavFile wav : wavFiles) {
            if (wav == null) {
                continue;
            }
            try {
                wav.close();
            }
            catch (Exception closeError) {
                System.err.println("Warning: failed to close input WAV file: " + closeError);
            }
        }
    }

    /**
     * Data set of next-chunk prediction examples cut from WAV files.
     *
     * <p>Each training step pairs one chunk of {@code samplesPerInput} samples
     * (the input) with the chunk that immediately follows it in the same file
     * (the target).
     */
    private static class WavDataset
    extends DataSet {
        /** Number of seed steps fed through the network before generation starts. */
        private static final int WARMUP_STEPS = 5;

        /** Sample rate, in Hz, used for the generated WAV file. */
        private static final int SAMPLE_RATE = 8000;

        /** Number of samples in one input (and one output) chunk. */
        private int samplesPerInput;

        /**
         * Reads every file in {@code wavs} chunk by chunk and groups the
         * resulting (chunk, next chunk) steps into training sequences.
         *
         * <p>A sequence is closed once it holds {@code maxSequenceLength} steps.
         * The trailing partial sequence of a file is kept only if it holds more
         * than {@code maxSequenceLength / 4} steps.
         *
         * @param maxSequenceLength maximum number of steps per training sequence
         * @param samplesPerInput   number of samples per chunk; also used as the
         *                          input and output dimension of the data set
         * @param wavs              open WAV files to read from; they are read but
         *                          not closed here
         * @throws Exception if reading from any of the WAV files fails
         */
        public WavDataset(int maxSequenceLength, int samplesPerInput, WavFile ... wavs) throws Exception {
            this.samplesPerInput = samplesPerInput;
            this.inputDimension = samplesPerInput;
            this.outputDimension = samplesPerInput;
            this.training = new ArrayList<DataSequence>();
            this.lossTraining = new LossSumOfSquares();
            this.lossReporting = new LossSumOfSquares();
            for (int fileIndex = 0; fileIndex < wavs.length; ++fileIndex) {
                WavFile wav = wavs[fileIndex];
                int chunkCount = (int)(wav.getNumFrames() / (long)samplesPerInput);
                double[] currentChunk = new double[samplesPerInput];
                double[] nextChunk = new double[samplesPerInput];
                ArrayList<DataStep> currentSequence = new ArrayList<DataStep>();

                wav.readFrames(currentChunk, samplesPerInput);
                WavDataset.normalize(currentChunk);

                // chunkCount chunks yield chunkCount - 1 (input, target) pairs.
                for (int pair = 0; pair < chunkCount - 1; ++pair) {
                    wav.readFrames(nextChunk, samplesPerInput);
                    WavDataset.normalize(nextChunk);
                    if (currentSequence.size() >= maxSequenceLength) {
                        this.training.add(new DataSequence(currentSequence));
                        currentSequence = new ArrayList<DataStep>();
                    }
                    // The read buffers are reused, so each step gets its own copies.
                    currentSequence.add(new DataStep(currentChunk.clone(), nextChunk.clone()));
                    // The target of this step becomes the input of the next one.
                    System.arraycopy(nextChunk, 0, currentChunk, 0, samplesPerInput);
                }
                if (currentSequence.size() > maxSequenceLength / 4) {
                    this.training.add(new DataSequence(currentSequence));
                }
                System.out.println("Processed " + Integer.toString(fileIndex + 1) + "/" + Integer.toString(wavs.length));
            }
            System.out.println("Num sequences: " + Integer.toString(this.training.size()));
        }

        /**
         * Rescales every sample in place with {@code (x + 1) / 2}.
         * Presumably the WAV reader yields samples in [-1, 1], which this maps
         * to [0, 1] -- TODO confirm against the WavFile implementation.
         *
         * @param samples buffer to rescale in place
         */
        private static void normalize(double[] samples) {
            for (int j = 0; j < samples.length; ++j) {
                samples[j] = (samples[j] + 1.0) / 2.0;
            }
        }

        /** Prints a placeholder line; no actual report is produced. */
        @Override
        public void DisplayReport(Model model, Random rng) throws Exception {
            System.out.println("a");
        }

        /** Returns a new {@link LinearUnit} as the output unit for models of this data set. */
        @Override
        public Nonlinearity getModelOutputUnitToUse() {
            return new LinearUnit();
        }

        /**
         * Generates audio with {@code net} and writes it to a mono, 16-bit,
         * 8000 Hz WAV file.
         *
         * <p>A random position in a random training sequence is chosen as the
         * seed. The first {@value #WARMUP_STEPS} seed chunks are fed through the
         * network (their outputs are discarded) and written to the file as-is.
         * After that the network's output is written and fed back as its next
         * input, {@code length * 8000 / samplesPerInput} times.
         *
         * <p>NOTE(review): the file is declared with {@code length * 8000}
         * frames, but the warm-up chunks plus the generated chunks add up to
         * more than that. How the surplus is handled depends on the WavFile
         * implementation -- confirm it is silently dropped.
         *
         * @param wavFileName path of the WAV file to create
         * @param net         network to run; its state is reset before and after
         * @param length      length of the generated audio in seconds
         * @param rng         source of randomness for picking the seed position
         * @throws IllegalStateException if there are no training sequences, or
         *                               the chosen sequence has too few steps to
         *                               seed the network
         * @throws Exception if running the network or writing the file fails
         */
        public void generate(String wavFileName, NeuralNetwork net, int length, Random rng) throws Exception {
            if (this.training.isEmpty()) {
                throw new IllegalStateException("Cannot generate audio: the data set contains no training sequences");
            }
            int randomSeq = rng.nextInt(this.training.size());
            DataSequence seed = (DataSequence)this.training.get(randomSeq);
            if (seed.steps.size() <= WARMUP_STEPS) {
                throw new IllegalStateException("Cannot generate audio: seed sequence has " + seed.steps.size()
                        + " steps but more than " + WARMUP_STEPS + " are required");
            }
            int randomStep = rng.nextInt(seed.steps.size() - WARMUP_STEPS);

            WavFile wav = WavFile.newWavFile(new File(wavFileName), 1, length * SAMPLE_RATE, 16, (long)SAMPLE_RATE);
            try {
                Graph g = new Graph(false);
                net.resetState();

                // Warm-up: prime the network with real audio and copy that audio to the output.
                for (int step = 0; step < WARMUP_STEPS; ++step) {
                    Matrix seedInput = seed.steps.get(randomStep + step).input;
                    net.forward(seedInput, g);
                    double[] samples = new double[seedInput.w.length];
                    for (int j = 0; j < samples.length; ++j) {
                        // Undo the (x + 1) / 2 rescaling applied when the data set was built.
                        samples[j] = seedInput.w[j] * 2.0 - 1.0;
                    }
                    wav.writeFrames(samples, samples.length);
                }

                Matrix buffer = seed.steps.get(randomStep + WARMUP_STEPS).input;
                int chunksToGenerate = length * SAMPLE_RATE / this.samplesPerInput;
                for (int chunk = 0; chunk < chunksToGenerate; ++chunk) {
                    Matrix out = net.forward(buffer, g);
                    double[] samples = new double[out.w.length];
                    for (int j = 0; j < samples.length; ++j) {
                        // An exactly-zero output is replaced by 0.5, which maps to a
                        // 0.0 sample below. The replacement is made in the output
                        // matrix itself, so it is also what gets fed back.
                        if (out.w[j] == 0.0) {
                            out.w[j] = 0.5;
                        }
                        samples[j] = out.w[j] * 2.0 - 1.0;
                    }
                    wav.writeFrames(samples, samples.length);
                    buffer = out.clone();
                }
                net.resetState();
            }
            finally {
                // Release the output file even when generation fails part-way through.
                wav.close();
            }
        }
    }
}

