Select Git revision
RWTHVRToolkit.Build.cs
-
David Gilbert authored
classify.cpp 5.52 KiB
#include "eistype.h"
#include <execution>
#include <iostream>
#include <filesystem>
#include <algorithm>
#include <system_error>
#include <condition_variable>
#include <mutex>
#include <limits>
#include <algorithm>
#include "kissinference/kissinference.h"
// Label of the network class that is treated as the "pass" verdict (see resultCallback).
#define PASS_STR "Pass"

// One output sub-directory per network class, in the same order as the network outputs.
// Stays empty when no OUT_DIRECTORY is given; results are then printed instead of files being moved.
static std::vector<std::filesystem::path> dirs;

// Serializes console output written by pipeline() and resultCallback().
static std::mutex printmtx;

// Index of the output labelled PASS_STR, or -1 if the network has no such class.
static long long passIndex = -1;

// Windows builds use the wide entry point (wmain) and therefore the wide console stream.
#ifdef WINDOWS
#define OUTSTREAM std::wcout
#else
#define OUTSTREAM std::cout
#endif
/**
 * Collects the input files to classify.
 *
 * @param dir        A single file or a directory; a directory is scanned non-recursively.
 * @param extension  Extension filter including the leading dot (e.g. ".csv"); empty accepts everything.
 * @return If dir is not a directory: {dir} when its extension matches the filter, otherwise {}.
 *         If dir is a directory: every regular file directly inside it whose extension matches
 *         the filter, sorted by path so the processing order is deterministic.
 * @throws std::filesystem::filesystem_error if the directory can not be opened or iterated.
 */
static std::vector<std::filesystem::path> getFiles(const std::filesystem::path& dir, std::string extension = "")
{
	std::vector<std::filesystem::path> out;
	if(!std::filesystem::is_directory(dir))
	{
		if(extension.empty() || dir.extension() == extension)
			out.push_back(dir);
		return out;
	}
	for(const std::filesystem::directory_entry& dirent : std::filesystem::directory_iterator{dir})
	{
		// Sub-directories (even ones named "*.csv") and unreadable entries can not be loaded as spectra.
		std::error_code ec;
		if(!dirent.is_regular_file(ec))
			continue;
		if(!extension.empty() && dirent.path().extension() != extension)
			continue;
		out.push_back(dirent.path());
	}
	// directory_iterator yields entries in an unspecified order.
	std::sort(out.begin(), out.end());
	return out;
}
static void getArrays(const std::vector<DataPoint>& data, float **re, float **im, float **omega)
{
(*re) = new float[data.size()];
(*im) = new float[data.size()];
(*omega) = new float[data.size()];
for(size_t i = 0; i < data.size(); ++i)
{
(*re)[i] = data[i].im.real();
(*im)[i] = data[i].im.imag();
(*omega)[i] = data[i].omega;
}
}
/**
 * Ensures that outDir exists as a directory, creating it (non-recursively) if needed.
 *
 * @param outDir  directory to check or create; its parent must already exist.
 * @return true if outDir is a directory afterwards. Otherwise a message is written to
 *         OUTSTREAM and false is returned; filesystem errors are reported, not thrown.
 */
static bool checkDir(const std::filesystem::path& outDir)
{
	std::error_code ec;
	if(std::filesystem::is_directory(outDir, ec))
		return true;

	std::filesystem::create_directory(outDir, ec);
	// Re-check rather than trusting the return value: create_directory() also returns false when
	// another process created the directory first or when a non-directory file is in the way.
	if(!ec && std::filesystem::is_directory(outDir, ec))
		return true;

	OUTSTREAM<<outDir<<" does not exist and can not be created";
	// c_str(): OUTSTREAM may be std::wcout, which has no operator<< for std::string.
	if(ec)
		OUTSTREAM<<": "<<ec.message().c_str();
	OUTSTREAM<<'\n';
	return false;
}
struct Request
{
std::filesystem::path path;
std::condition_variable cond;
std::mutex mutex;
};
static size_t topIndex(float* data, size_t size)
{
float max = std::numeric_limits<float>::min();
size_t maxIndex = 0;
for(size_t i = 0; i < size; ++i)
{
if(data[i] > max)
{
max = data[i];
maxIndex = i;
}
}
return maxIndex;
}
static void resultCallback(float* data, struct kiss_network* net, void* userData)
{
Request* rq = reinterpret_cast<Request*>(userData);
size_t index = topIndex(data, net->output_size);
if(!dirs.empty())
{
if(static_cast<long long>(index) == passIndex)
{
if(data[index] < 0.7)
index = !index;
}
std::error_code ec;
std::filesystem::rename(rq->path, dirs[index]/rq->path.filename(), ec);
if(ec)
{
printmtx.lock();
OUTSTREAM<<"unable to move "<<rq->path<<" to "<<dirs[index]/rq->path.filename();
printmtx.unlock();
}
}
else
{
std::vector<std::pair<std::string, float>> results;
for(size_t i = 0; i < net->output_size; ++i)
results.push_back({net->output_labels[i], data[i]});
std::sort(results.begin(), results.end(), [](std::pair<std::string, float> a, std::pair<std::string, float> b){return a.second > b.second;});
printmtx.lock();
OUTSTREAM<<"\nResult for "<<rq->path<<'\n';
OUTSTREAM<<"Classes:\nNumber\tName\tLikelyhood\n";
for(size_t i = 0; i < net->output_size; ++i)
OUTSTREAM<<i<<'\t'<<results[i].first<<'\t'<<results[i].second<<'\n';
printmtx.unlock();
}
free(data);
rq->cond.notify_all();
}
static void pipeline(const std::filesystem::path& path, struct kiss_network* net)
{
float* re;
float* im;
float* omega;
float* re_filtered;
float* im_filtered;
bool ret;
try
{
EisSpectra spectra = EisSpectra::loadFromDisk(path);
getArrays(spectra.data, &re, &im, &omega);
ret = kiss_filter_spectra(re, im, omega, spectra.data.size(), &re_filtered, &im_filtered, net->input_size/2);
delete[] re;
delete[] im;
delete[] omega;
if(!ret)
{
printmtx.lock();
OUTSTREAM<<"Could not filter "<<path<<'\n';
printmtx.unlock();
return;
}
}
catch(const file_error& err)
{
OUTSTREAM<<"Could not load: "<<err.what()<<'\n';
return;
}
Request rq;
rq.path = path;
std::unique_lock<std::mutex> lk(rq.mutex);
ret = kiss_async_run_inference_complex(net, re_filtered, im_filtered, &rq);
free(re_filtered);
free(im_filtered);
if(!ret)
{
printmtx.lock();
OUTSTREAM<<"Could not run inference on "<<path<<": "<<kiss_get_strerror(net)<<'\n';
printmtx.unlock();
}
else
{
rq.cond.wait(lk);
}
}
/**
 * Entry point: classify [CSV_FILE_OR_DIRECTORY] with the network at [NETWORK].
 *
 * Without [OUT_DIRECTORY] the per-class likelihoods of every file are printed.
 * With [OUT_DIRECTORY] one sub-directory per network class is created inside it, and
 * every file is moved into the directory of its class.
 * Returns 0 on success and 1 on usage, input, network or directory errors.
 */
#ifdef WINDOWS
int wmain(int argc, wchar_t** argv)
#else
int main(int argc, char** argv)
#endif
{
	if(argc < 3)
	{
		if(argc > 0)
			OUTSTREAM<<"Usage: "<<argv[0]<<" [NETWORK] [CSV_FILE_OR_DIRECTORY] [OUT_DIRECTORY]"<<'\n';
		return 1;
	}

	std::vector<std::filesystem::path> files;
	try
	{
		files = getFiles(argv[2], ".csv");
	}
	catch(const std::filesystem::filesystem_error& err)
	{
		OUTSTREAM<<"Could not read "<<argv[2]<<": "<<err.what()<<'\n';
		return 1;
	}
	if(files.empty())
	{
		OUTSTREAM<<"Could not find any .csv files in "<<argv[2]<<'\n';
		return 1;
	}

	// kiss_load_network() takes a narrow string. Converting through std::filesystem::path
	// re-encodes wide arguments on Windows instead of reinterpreting their bytes.
	const std::string networkPath = std::filesystem::path(argv[1]).string();
	struct kiss_network* net = kiss_load_network(networkPath.c_str(), resultCallback, false);
	if(!net)
	{
		OUTSTREAM<<"Could not load network from "<<argv[1]<<'\n';
		return 1;
	}
	else if(!net->ready)
	{
		OUTSTREAM<<"Could not load network: "<<kiss_get_strerror(net)<<'\n';
		kiss_free_network(net);
		return 1;
	}

	const bool sortIntoDirs = argc > 3;
	if(sortIntoDirs && !checkDir(argv[3]))
	{
		OUTSTREAM<<"Could not create output folders at "<<argv[3]<<'\n';
		kiss_free_network(net);
		return 1;
	}

	OUTSTREAM<<"Network has the following classes:\n";
	for(size_t i = 0; i < net->output_size; ++i)
	{
		OUTSTREAM<<net->output_labels[i]<<'\n';
		if(sortIntoDirs)
		{
			// dirs[i] must correspond to output i; resultCallback() indexes it by output index.
			dirs.push_back(std::filesystem::path(argv[3])/net->output_labels[i]);
			if(!checkDir(dirs.back()))
			{
				OUTSTREAM<<"Could not create output folders at "<<argv[3]<<'\n';
				kiss_free_network(net);
				return 1;
			}
		}
		if(std::string(net->output_labels[i]) == PASS_STR)
			passIndex = static_cast<long long>(i);
	}

	std::for_each(std::execution::seq, files.begin(), files.end(),
		[net](const std::filesystem::path& path){pipeline(path, net);});
	kiss_free_network(net);
	return 0;
}