#include "eistype.h"

#include <algorithm>
#include <condition_variable>
#include <cstdlib>
#include <execution>
#include <filesystem>
#include <iostream>
#include <iterator>
#include <limits>
#include <mutex>
#include <string>
#include <system_error>
#include <utility>
#include <vector>

#include "kissinference/kissinference.h"

// Label of the class that is treated as the "pass" class, see passIndex.
#define PASS_STR "Pass"

// One output directory per network class, indexed like net->output_labels.
// Stays empty when no output directory was given; results are then printed instead.
static std::vector<std::filesystem::path> dirs;
// Serializes console output between the main thread and result callbacks.
std::mutex printmtx;
// Index of the class labelled PASS_STR, or -1 if the network has no such class.
long long passIndex = -1;

#ifdef WINDOWS
#define OUTSTREAM std::wcout
#else
#define OUTSTREAM std::cout
#endif

/**
 * @brief Collects the files to process.
 *
 * If dir is not a directory it is treated as a single file and returned alone
 * (when it matches the extension filter). Otherwise the direct children of
 * dir are returned; the directory is not traversed recursively.
 *
 * @param dir file or directory to inspect
 * @param extension extension including the leading dot to filter by, empty to accept everything
 * @return the matching paths
 */
static std::vector<std::filesystem::path> getFiles(const std::filesystem::path& dir, const std::string& extension = "")
{
	std::vector<std::filesystem::path> out;

	if(!std::filesystem::is_directory(dir))
	{
		if(extension.empty() || dir.extension() == extension)
			out.push_back(dir);
		return out;
	}

	for(const std::filesystem::directory_entry& dirent : std::filesystem::directory_iterator{dir})
	{
		if(!extension.empty() && dirent.path().extension() != extension)
			continue;
		out.push_back(dirent.path());
	}
	return out;
}

/**
 * @brief Splits a spectrum into three newly allocated float arrays.
 *
 * Each array holds data.size() elements and is allocated with new[];
 * the caller has to release all three with delete[].
 *
 * @param data the spectrum to split
 * @param re receives the real parts
 * @param im receives the imaginary parts
 * @param omega receives the omega values
 */
static void getArrays(const std::vector<DataPoint>& data, float **re, float **im, float **omega)
{
	(*re) = new float[data.size()];
	(*im) = new float[data.size()];
	(*omega) = new float[data.size()];

	for(size_t i = 0; i < data.size(); ++i)
	{
		(*re)[i] = data[i].im.real();
		(*im)[i] = data[i].im.imag();
		(*omega)[i] = data[i].omega;
	}
}

/**
 * @brief Makes sure outDir exists, creating it (one level only) if required.
 *
 * @param outDir the directory to check
 * @return false if the directory does not exist and could not be created
 */
static bool checkDir(const std::filesystem::path& outDir)
{
	if(!std::filesystem::is_directory(outDir))
	{
		if(!std::filesystem::create_directory(outDir))
		{
			OUTSTREAM<<outDir<<" dose not exist and can not be created\n";
			return false;
		}
	}
	return true;
}

/**
 * @brief State shared between pipeline() and resultCallback() for one inference run.
 */
struct Request
{
	std::filesystem::path path;   // the file this request belongs to
	std::condition_variable cond; // signalled once the result has been handled
	std::mutex mutex;             // guards done
	bool done = false;            // set by resultCallback(), makes the wait immune to lost and spurious wakeups
};

/**
 * @brief Finds the index of the largest element.
 *
 * @param data the values to search
 * @param size number of elements in data
 * @return index of the first maximum, 0 if data is empty
 */
static size_t topIndex(float* data, size_t size)
{
	if(!data || size == 0)
		return 0;
	// max_element makes no assumption about the sign of the values, unlike a
	// search seeded with numeric_limits<float>::min() (the smallest POSITIVE float).
	return static_cast<size_t>(std::distance(data, std::max_element(data, data + size)));
}

/**
 * @brief Handles one inference result: moves the file into its class directory
 * or, when no output directories are configured, prints the class scores.
 *
 * @param data net->output_size scores; released here with free()
 * @param net the network that produced the result
 * @param userData the Request passed to kiss_async_run_inference_complex()
 */
static void resultCallback(float* data, struct kiss_network* net, void* userData)
{
	Request* rq = static_cast<Request*>(userData);
	size_t index = topIndex(data, net->output_size);

	if(!dirs.empty())
	{
		// A "pass" verdict below the confidence threshold is not trusted and is
		// switched to the other of the first two classes.
		if(static_cast<long long>(index) == passIndex && data[index] < 0.7)
		{
			const size_t fallback = index == 0 ? 1 : 0;
			// Guard against networks with only a single class.
			if(fallback < dirs.size())
				index = fallback;
		}

		const std::filesystem::path target = dirs[index]/rq->path.filename();
		std::error_code ec;
		std::filesystem::rename(rq->path, target, ec);
		if(ec)
		{
			std::lock_guard<std::mutex> printLock(printmtx);
			OUTSTREAM<<"unable to move "<<rq->path<<" to "<<target<<'\n';
		}
	}
	else
	{
		std::vector<std::pair<std::string, float>> results;
		results.reserve(net->output_size);
		for(size_t i = 0; i < net->output_size; ++i)
			results.push_back({net->output_labels[i], data[i]});

		// Most likely class first.
		std::sort(results.begin(), results.end(),
			[](const std::pair<std::string, float>& a, const std::pair<std::string, float>& b){return a.second > b.second;});

		std::lock_guard<std::mutex> printLock(printmtx);
		OUTSTREAM<<"\nResult for "<<rq->path<<'\n';
		OUTSTREAM<<"Classes:\nNumber\tName\tLikelyhood\n";
		// c_str() so that the label can also be streamed to std::wcout.
		for(size_t i = 0; i < results.size(); ++i)
			OUTSTREAM<<i<<'\t'<<results[i].first.c_str()<<'\t'<<results[i].second<<'\n';
	}

	free(data);

	// Notify while holding the mutex: the waiter owns the Request and destroys
	// it as soon as it wakes up, so rq must not be touched after the unlock.
	std::lock_guard<std::mutex> doneLock(rq->mutex);
	rq->done = true;
	rq->cond.notify_all();
}

/**
 * @brief Loads one spectrum, filters it, runs inference on it and blocks until
 * resultCallback() has handled the result.
 *
 * Failures are reported on OUTSTREAM and otherwise ignored.
 *
 * @param path the spectrum file to process
 * @param net the network to run
 */
static void pipeline(const std::filesystem::path& path, struct kiss_network* net)
{
	float* re_filtered = nullptr;
	float* im_filtered = nullptr;

	try
	{
		EisSpectra spectra = EisSpectra::loadFromDisk(path);

		float* re = nullptr;
		float* im = nullptr;
		float* omega = nullptr;
		getArrays(spectra.data, &re, &im, &omega);

		const bool filtered = kiss_filter_spectra(re, im, omega, spectra.data.size(), &re_filtered, &im_filtered, net->input_size/2);
		delete[] re;
		delete[] im;
		delete[] omega;

		if(!filtered)
		{
			std::lock_guard<std::mutex> printLock(printmtx);
			OUTSTREAM<<"Could not filter "<<path<<'\n';
			return;
		}
	}
	catch(const file_error& err)
	{
		std::lock_guard<std::mutex> printLock(printmtx);
		OUTSTREAM<<"Could not load: "<<err.what()<<'\n';
		return;
	}

	Request rq;
	rq.path = path;

	const bool started = kiss_async_run_inference_complex(net, re_filtered, im_filtered, &rq);
	free(re_filtered);
	free(im_filtered);

	if(!started)
	{
		std::lock_guard<std::mutex> printLock(printmtx);
		OUTSTREAM<<"Could not run inference on "<<path<<": "<<kiss_get_strerror(net)<<'\n';
		return;
	}

	// rq lives on this stack frame, so do not return before the callback is done with it.
	// The predicate covers both a callback that already finished and spurious wakeups.
	std::unique_lock<std::mutex> lk(rq.mutex);
	rq.cond.wait(lk, [&rq](){return rq.done;});
}

/**
 * @brief Entry point: [NETWORK] [CSV_FILE_OR_DIRECTORY] [OUT_DIRECTORY]
 *
 * Classifies every .csv file found. With an output directory the files are
 * moved into one sub directory per class, otherwise the scores are printed.
 *
 * @return 0 on success, 1 on usage or setup errors
 */
#ifdef WINDOWS
int wmain(int argc, wchar_t** argv)
#else
int main(int argc, char** argv)
#endif
{
	if(argc < 3)
	{
		if(argc > 0)
			OUTSTREAM<<"Usage: "<<argv[0]<<" [NETWORK] [CSV_FILE_OR_DIRECTORY] [OUT_DIRECTORY]"<<'\n';
		return 1;
	}

	std::vector<std::filesystem::path> files = getFiles(argv[2], ".csv");
	if(files.empty())
	{
		OUTSTREAM<<"Could not find any .csv files in "<<argv[2]<<'\n';
		return 1;
	}

	// Go through std::filesystem::path so that wide character arguments are
	// converted to a narrow string instead of being reinterpreted byte-wise.
	const std::string networkPath = std::filesystem::path(argv[1]).string();
	struct kiss_network* net = kiss_load_network(networkPath.c_str(), resultCallback, false);
	if(!net)
	{
		OUTSTREAM<<"Could not load network from "<<argv[1]<<'\n';
		return 1;
	}
	else if(!net->ready)
	{
		OUTSTREAM<<"Could not load network: "<<kiss_get_strerror(net)<<'\n';
		kiss_free_network(net);
		return 1;
	}

	const bool sortIntoDirs = argc > 3;

	if(sortIntoDirs && !checkDir(argv[3]))
	{
		OUTSTREAM<<"Could not create output folders at "<<argv[3]<<'\n';
		kiss_free_network(net);
		return 1;
	}

	OUTSTREAM<<"Network has the following classes:\n";
	for(size_t i = 0; i < net->output_size; ++i)
	{
		OUTSTREAM<<net->output_labels[i]<<'\n';

		if(sortIntoDirs)
		{
			dirs.push_back(std::filesystem::path(argv[3])/net->output_labels[i]);
			// Only bail out if the class directory really could not be created.
			if(!checkDir(dirs.back()))
			{
				OUTSTREAM<<"Could not create output folders at "<<argv[3]<<'\n';
				kiss_free_network(net);
				return 1;
			}
		}

		if(std::string(net->output_labels[i]) == PASS_STR)
			passIndex = static_cast<long long>(i);
	}

	std::for_each(std::execution::seq, files.begin(), files.end(), [&net](std::filesystem::path& path){pipeline(path, net);});

	kiss_free_network(net);
	return 0;
}