#include "eistype.h"
#include <execution>
#include <iostream>
#include <filesystem>
#include <algorithm>
#include <system_error>
#include <condition_variable>
#include <mutex>
#include <limits>
#include <algorithm>

#include "kissinference/kissinference.h"

// Label of the network output class that is treated as a passing measurement.
#define PASS_STR "Pass"

// Output directory per network class, in the same order as net->output_labels.
// Only filled when an OUT_DIRECTORY argument is given; empty means "print results".
static std::vector<std::filesystem::path> dirs{};
// Guards console output written from pipeline() and resultCallback().
std::mutex printmtx{};
// Position of the PASS_STR class among the network outputs, or -1 if there is none.
long long passIndex{-1};

// Console stream: wide when built with WINDOWS (where the entry point is wmain).
#if defined(WINDOWS)
#define OUTSTREAM std::wcout
#else
#define OUTSTREAM std::cout
#endif

/* Collects the input files to process.
   If dir is not a directory it is treated as a single file and returned on its own
   (provided it matches extension). Otherwise the direct, non-recursive entries of dir
   are listed. An empty extension disables filtering. */
static std::vector<std::filesystem::path> getFiles(const std::filesystem::path dir, std::string extension = "")
{
	const bool filterByExtension = !extension.empty();
	const auto wanted = [&](const std::filesystem::path& candidate)
	{
		return !filterByExtension || candidate.extension() == extension;
	};

	std::vector<std::filesystem::path> found;
	if(std::filesystem::is_directory(dir))
	{
		for(const auto& entry : std::filesystem::directory_iterator(dir))
		{
			if(wanted(entry.path()))
				found.push_back(entry.path());
		}
	}
	else if(wanted(dir))
	{
		found.push_back(dir);
	}
	return found;
}

static void getArrays(const std::vector<DataPoint>& data, float **re, float **im, float **omega)
{
	(*re) = new float[data.size()];
	(*im) = new float[data.size()];
	(*omega) = new float[data.size()];

	for(size_t i = 0; i < data.size(); ++i)
	{
		(*re)[i] = data[i].im.real();
		(*im)[i] = data[i].im.imag();
		(*omega)[i] = data[i].omega;
	}
}

/* Ensures outDir exists as a directory, creating it (and any missing parents) if needed.
   Returns true when outDir is a directory afterwards. Otherwise a diagnostic is written
   to OUTSTREAM and false is returned. Filesystem errors are reported, not thrown. */
static bool checkDir(const std::filesystem::path& outDir)
{
	std::error_code ec;
	if(std::filesystem::is_directory(outDir, ec))
		return true;

	// error_code overload: the throwing overload would escape to main() uncaught
	// instead of letting this function report the failure.
	std::filesystem::create_directories(outDir, ec);
	// Re-check rather than trusting the return value: create_directories() returns false
	// both on failure and when another process created the directory first.
	if(!ec && std::filesystem::is_directory(outDir, ec))
		return true;

	OUTSTREAM<<outDir<<" does not exist and can not be created";
	if(ec)
		OUTSTREAM<<": "<<ec.message().c_str();
	OUTSTREAM<<'\n';
	return false;
}

struct Request
{
	std::filesystem::path path;
	std::condition_variable cond;
	std::mutex mutex;
};

static size_t topIndex(float* data, size_t size)
{
	float max = std::numeric_limits<float>::min();
	size_t maxIndex = 0;
	for(size_t i = 0; i < size; ++i)
	{
		if(data[i] > max)
		{
			max = data[i];
			maxIndex = i;
		}
	}
	return maxIndex;
}

static void resultCallback(float* data, struct kiss_network* net, void* userData)
{
	Request* rq = reinterpret_cast<Request*>(userData);
	size_t index = topIndex(data, net->output_size);

	if(!dirs.empty())
	{
		if(static_cast<long long>(index) == passIndex)
		{
			if(data[index] < 0.7)
				index = !index;
		}
		std::error_code ec;
		std::filesystem::rename(rq->path, dirs[index]/rq->path.filename(), ec);
		if(ec)
		{
			printmtx.lock();
			OUTSTREAM<<"unable to move "<<rq->path<<" to "<<dirs[index]/rq->path.filename();
			printmtx.unlock();
		}
	}
	else
	{
		std::vector<std::pair<std::string, float>> results;
		for(size_t i = 0; i < net->output_size; ++i)
			results.push_back({net->output_labels[i], data[i]});
		std::sort(results.begin(), results.end(), [](std::pair<std::string, float> a, std::pair<std::string, float> b){return a.second > b.second;});
		printmtx.lock();
		OUTSTREAM<<"\nResult for "<<rq->path<<'\n';
		OUTSTREAM<<"Classes:\nNumber\tName\tLikelyhood\n";
		for(size_t i = 0; i < net->output_size; ++i)
			OUTSTREAM<<i<<'\t'<<results[i].first<<'\t'<<results[i].second<<'\n';
		printmtx.unlock();
	}

	free(data);

	rq->cond.notify_all();
}

static void pipeline(const std::filesystem::path& path, struct kiss_network* net)
{
	float* re;
	float* im;
	float* omega;
	float* re_filtered;
	float* im_filtered;
	bool ret;

	try
	{
		EisSpectra spectra = EisSpectra::loadFromDisk(path);

		getArrays(spectra.data, &re, &im, &omega);
		ret = kiss_filter_spectra(re, im, omega, spectra.data.size(), &re_filtered, &im_filtered, net->input_size/2);
		delete[] re;
		delete[] im;
		delete[] omega;

		if(!ret)
		{
			printmtx.lock();
			OUTSTREAM<<"Could not filter "<<path<<'\n';
			printmtx.unlock();
			return;
		}
	}
	catch(const file_error& err)
	{
		OUTSTREAM<<"Could not load: "<<err.what()<<'\n';
		return;
	}

	Request rq;
	rq.path = path;

	std::unique_lock<std::mutex> lk(rq.mutex);
	ret = kiss_async_run_inference_complex(net, re_filtered, im_filtered, &rq);
	free(re_filtered);
	free(im_filtered);
	if(!ret)
	{
		printmtx.lock();
		OUTSTREAM<<"Could not run inference on "<<path<<": "<<kiss_get_strerror(net)<<'\n';
		printmtx.unlock();
	}
	else
	{
		rq.cond.wait(lk);
	}
}

/* Entry point: [NETWORK] [CSV_FILE_OR_DIRECTORY] [OUT_DIRECTORY]
   Every .csv spectrum is run through the network. Without OUT_DIRECTORY the class
   likelihoods are printed. With it, each file is moved into OUT_DIRECTORY/<class label>.
   Returns 1 on usage or setup errors, 0 otherwise. */
#ifdef WINDOWS
int wmain(int argc, wchar_t** argv)
#else
int main(int argc, char** argv)
#endif
{
	if(argc < 3)
	{
		if(argc > 0)
			OUTSTREAM<<"Usage: "<<argv[0]<<" [NETWORK] [CSV_FILE_OR_DIRECTORY] [OUT_DIRECTORY]"<<'\n';
		return 1;
	}

	std::vector<std::filesystem::path> files = getFiles(argv[2], ".csv");
	if(files.empty())
	{
		OUTSTREAM<<"Could not find any .csv files in "<<argv[2]<<'\n';
		return 1;
	}

	// kiss_load_network() takes a narrow string. Under WINDOWS argv is wide, so convert
	// through std::filesystem::path instead of reinterpreting the wide buffer as chars.
	const std::string networkPath = std::filesystem::path(argv[1]).string();
	struct kiss_network* net = kiss_load_network(networkPath.c_str(), resultCallback, false);
	if(!net)
	{
		OUTSTREAM<<"Could not load network from "<<argv[1]<<'\n';
		return 1;
	}
	else if(!net->ready)
	{
		OUTSTREAM<<"Could not load network: "<<kiss_get_strerror(net)<<'\n';
		kiss_free_network(net);
		return 1;
	}

	// With OUT_DIRECTORY, results are sorted into one sub-directory per class.
	const bool sortIntoDirs = argc > 3;
	if(sortIntoDirs && !checkDir(argv[3]))
	{
		OUTSTREAM<<"Could not create output folders at "<<argv[3]<<'\n';
		kiss_free_network(net);
		return 1;
	}

	OUTSTREAM<<"Network has the following classes:\n";
	for(size_t i = 0; i < net->output_size; ++i)
	{
		OUTSTREAM<<net->output_labels[i]<<'\n';
		if(sortIntoDirs)
		{
			dirs.push_back(std::filesystem::path(argv[3])/net->output_labels[i]);
			if(!checkDir(dirs.back()))
			{
				OUTSTREAM<<"Could not create output folders at "<<argv[3]<<'\n';
				kiss_free_network(net);
				return 1;
			}
		}
		if(std::string(net->output_labels[i]) == PASS_STR)
			passIndex = static_cast<long long>(i);
	}

	std::for_each(std::execution::seq, files.begin(), files.end(),
		[net](const std::filesystem::path& path){pipeline(path, net);});

	kiss_free_network(net);

	return 0;
}