Skip to content
Snippets Groups Projects
Select Git revision
  • 812ed426a5341ba1e7a2cd20ef1b848a791688ed
  • 5.4 default protected
  • 5.5
  • dev/5.5
  • dev/5.4
  • dev/5.3_downgrade
  • feature/experimenttime_hack
  • 5.3 protected
  • _IntenSelect5.3
  • IntenSelect5.3
  • 4.27 protected
  • 4.26 protected
  • 5.0 protected
  • 4.22 protected
  • 4.21 protected
  • UE5.4-2024.1
  • UE5.4-2024.1-rc1
  • UE5.3-2023.1-rc3
  • UE5.3-2023.1-rc2
  • UE5.3-2023.1-rc
20 results

RWTHVRToolkit.Build.cs

Blame
  • classify.cpp 5.52 KiB
    #include "eistype.h"
    #include <execution>
    #include <iostream>
    #include <filesystem>
    #include <algorithm>
    #include <system_error>
    #include <condition_variable>
    #include <mutex>
    #include <limits>
    #include <algorithm>
    
    #include "kissinference/kissinference.h"
    
    // Output label that resultCallback() treats specially when sorting files
    // (matched against the network's output labels in main()).
    #define PASS_STR "Pass"
    
    // Per-class output directories (<OUT_DIRECTORY>/<label>), pushed in the same order as the
    // network outputs so they can be indexed by the winning output index. Filled in main()
    // only when an output directory is given; when empty, resultCallback() prints the class
    // likelihoods instead of moving files.
    static std::vector<std::filesystem::path> dirs;
    // Used to serialize console output in pipeline() and resultCallback().
    std::mutex printmtx;
    // Index of the PASS_STR class among the network outputs, or -1 if the network has no such class.
    long long passIndex = -1;
    
    // On Windows the entry point is wmain() with wide-character arguments, so the wide console stream is used there.
    #ifdef WINDOWS
    #define OUTSTREAM std::wcout
    #else
    #define OUTSTREAM std::cout
    #endif
    
    /**
     * Collect the input files to process.
     *
     * If dir names an existing regular file, that file alone is returned (provided it matches
     * extension). If dir names a directory, the regular files directly inside it (not recursive)
     * that match extension are returned, sorted by path so the processing order is reproducible.
     *
     * @param dir       a file or a directory
     * @param extension required extension including the dot (e.g. ".csv"); empty accepts any file
     * @return the matching files; empty if dir does not exist, cannot be read, or has no matches
     */
    static std::vector<std::filesystem::path> getFiles(const std::filesystem::path& dir, const std::string& extension = "")
    {
    	std::vector<std::filesystem::path> out;
    	std::error_code ec;

    	// Only accept a single-file argument that really exists, so a mistyped path is reported
    	// as "no files found" by the caller instead of failing later when it is loaded.
    	if(std::filesystem::is_regular_file(dir, ec))
    	{
    		if(extension.empty() || dir.extension() == extension)
    			out.push_back(dir);
    		return out;
    	}

    	if(!std::filesystem::is_directory(dir, ec))
    		return out;

    	// Non-throwing construction: an unreadable directory yields an empty result
    	// instead of an uncaught filesystem_error.
    	std::filesystem::directory_iterator entries(dir, ec);
    	if(ec)
    		return out;

    	for(const std::filesystem::directory_entry& dirent : entries)
    	{
    		// Skip subdirectories and other non-regular entries, even if their name ends in extension.
    		std::error_code entryEc;
    		if(!dirent.is_regular_file(entryEc))
    			continue;
    		if(!extension.empty() && dirent.path().extension() != extension)
    			continue;
    		out.push_back(dirent.path());
    	}

    	// directory_iterator yields entries in an unspecified order.
    	std::sort(out.begin(), out.end());
    	return out;
    }
    
    static void getArrays(const std::vector<DataPoint>& data, float **re, float **im, float **omega)
    {
    	(*re) = new float[data.size()];
    	(*im) = new float[data.size()];
    	(*omega) = new float[data.size()];
    
    	for(size_t i = 0; i < data.size(); ++i)
    	{
    		(*re)[i] = data[i].im.real();
    		(*im)[i] = data[i].im.imag();
    		(*omega)[i] = data[i].omega;
    	}
    }
    
    /**
     * Ensure outDir exists as a directory, creating it (and any missing parent directories) if needed.
     *
     * The error_code overloads of std::filesystem are used so that creation failures
     * (missing permissions, a regular file in the way, ...) are reported through the return
     * value and a message instead of an uncaught filesystem_error.
     *
     * @param outDir directory that must exist
     * @return true if outDir is (now) a directory; false otherwise, after printing the reason
     */
    static bool checkDir(const std::filesystem::path& outDir)
    {
    	std::error_code checkEc;
    	if(std::filesystem::is_directory(outDir, checkEc))
    		return true;

    	std::error_code createEc;
    	std::filesystem::create_directories(outDir, createEc);

    	// Re-check rather than trusting the return value: create_directories() returns false
    	// when the directory appeared concurrently, which is success for our purposes.
    	if(std::filesystem::is_directory(outDir, checkEc))
    		return true;

    	OUTSTREAM<<outDir<<" does not exist and can not be created";
    	if(createEc)
    		OUTSTREAM<<": "<<createEc.message().c_str();
    	OUTSTREAM<<'\n';
    	return false;
    }
    
    struct Request
    {
    	std::filesystem::path path;
    	std::condition_variable cond;
    	std::mutex mutex;
    };
    
    static size_t topIndex(float* data, size_t size)
    {
    	float max = std::numeric_limits<float>::min();
    	size_t maxIndex = 0;
    	for(size_t i = 0; i < size; ++i)
    	{
    		if(data[i] > max)
    		{
    			max = data[i];
    			maxIndex = i;
    		}
    	}
    	return maxIndex;
    }
    
    static void resultCallback(float* data, struct kiss_network* net, void* userData)
    {
    	Request* rq = reinterpret_cast<Request*>(userData);
    	size_t index = topIndex(data, net->output_size);
    
    	if(!dirs.empty())
    	{
    		if(static_cast<long long>(index) == passIndex)
    		{
    			if(data[index] < 0.7)
    				index = !index;
    		}
    		std::error_code ec;
    		std::filesystem::rename(rq->path, dirs[index]/rq->path.filename(), ec);
    		if(ec)
    		{
    			printmtx.lock();
    			OUTSTREAM<<"unable to move "<<rq->path<<" to "<<dirs[index]/rq->path.filename();
    			printmtx.unlock();
    		}
    	}
    	else
    	{
    		std::vector<std::pair<std::string, float>> results;
    		for(size_t i = 0; i < net->output_size; ++i)
    			results.push_back({net->output_labels[i], data[i]});
    		std::sort(results.begin(), results.end(), [](std::pair<std::string, float> a, std::pair<std::string, float> b){return a.second > b.second;});
    		printmtx.lock();
    		OUTSTREAM<<"\nResult for "<<rq->path<<'\n';
    		OUTSTREAM<<"Classes:\nNumber\tName\tLikelyhood\n";
    		for(size_t i = 0; i < net->output_size; ++i)
    			OUTSTREAM<<i<<'\t'<<results[i].first<<'\t'<<results[i].second<<'\n';
    		printmtx.unlock();
    	}
    
    	free(data);
    
    	rq->cond.notify_all();
    }
    
    static void pipeline(const std::filesystem::path& path, struct kiss_network* net)
    {
    	float* re;
    	float* im;
    	float* omega;
    	float* re_filtered;
    	float* im_filtered;
    	bool ret;
    
    	try
    	{
    		EisSpectra spectra = EisSpectra::loadFromDisk(path);
    
    		getArrays(spectra.data, &re, &im, &omega);
    		ret = kiss_filter_spectra(re, im, omega, spectra.data.size(), &re_filtered, &im_filtered, net->input_size/2);
    		delete[] re;
    		delete[] im;
    		delete[] omega;
    
    		if(!ret)
    		{
    			printmtx.lock();
    			OUTSTREAM<<"Could not filter "<<path<<'\n';
    			printmtx.unlock();
    			return;
    		}
    	}
    	catch(const file_error& err)
    	{
    		OUTSTREAM<<"Could not load: "<<err.what()<<'\n';
    		return;
    	}
    
    	Request rq;
    	rq.path = path;
    
    	std::unique_lock<std::mutex> lk(rq.mutex);
    	ret = kiss_async_run_inference_complex(net, re_filtered, im_filtered, &rq);
    	free(re_filtered);
    	free(im_filtered);
    	if(!ret)
    	{
    		printmtx.lock();
    		OUTSTREAM<<"Could not run inference on "<<path<<": "<<kiss_get_strerror(net)<<'\n';
    		printmtx.unlock();
    	}
    	else
    	{
    		rq.cond.wait(lk);
    	}
    }
    
    #ifdef WINDOWS
    int wmain(int argc, wchar_t** argv)
    #else
    int main(int argc, char** argv)
    #endif
    {
    	// Usage: classify NETWORK CSV_FILE_OR_DIRECTORY [OUT_DIRECTORY]
    	// Without OUT_DIRECTORY the class likelihoods of every file are printed; with it, every
    	// file is moved into OUT_DIRECTORY/<class label>/ (see resultCallback()).
    	if(argc < 3)
    	{
    		if(argc > 0)
    			OUTSTREAM<<"Usage: "<<argv[0]<<" [NETWORK] [CSV_FILE_OR_DIRECTORY] [OUT_DIRECTORY]"<<'\n';
    		return 1;
    	}

    	std::vector<std::filesystem::path> files = getFiles(argv[2], ".csv");
    	if(files.empty())
    	{
    		OUTSTREAM<<"Could not find any .csv files in "<<argv[2]<<'\n';
    		return 1;
    	}

    	// kiss_load_network() takes a narrow string. Going through std::filesystem::path converts
    	// the wide wmain() argument on Windows instead of reinterpreting its bytes, and is a
    	// plain copy elsewhere.
    	const std::string networkPath = std::filesystem::path(argv[1]).string();
    	struct kiss_network* net = kiss_load_network(networkPath.c_str(), resultCallback, false);
    	if(!net)
    	{
    		OUTSTREAM<<"Could not load network from "<<argv[1]<<'\n';
    		return 1;
    	}
    	if(!net->ready)
    	{
    		OUTSTREAM<<"Could not load network: "<<kiss_get_strerror(net)<<'\n';
    		kiss_free_network(net);
    		return 1;
    	}

    	const bool sortIntoDirs = argc > 3;
    	if(sortIntoDirs && !checkDir(argv[3]))
    	{
    		OUTSTREAM<<"Could not create output folders at "<<argv[3]<<'\n';
    		kiss_free_network(net);
    		return 1;
    	}

    	OUTSTREAM<<"Network has the following classes:\n";
    	for(size_t i = 0; i < net->output_size; ++i)
    	{
    		OUTSTREAM<<net->output_labels[i]<<'\n';
    		if(sortIntoDirs)
    		{
    			// dirs must be pushed in output order: resultCallback() indexes it by output index.
    			dirs.push_back(std::filesystem::path(argv[3])/net->output_labels[i]);
    			if(!checkDir(dirs.back()))
    			{
    				OUTSTREAM<<"Could not create output folders at "<<argv[3]<<'\n';
    				kiss_free_network(net);
    				return 1;
    			}
    		}
    		if(std::string(net->output_labels[i]) == PASS_STR)
    			passIndex = static_cast<long long>(i);
    	}

    	std::for_each(std::execution::seq, files.begin(), files.end(), [net](const std::filesystem::path& path){pipeline(path, net);});

    	kiss_free_network(net);

    	return 0;
    }