diff --git a/eisgenerator/eistype.h b/eisgenerator/eistype.h index c6e6a74029e1b8787d96b4040186fb9783ffb82c..0f6df5a46fdcd26f4fd1c009b51b2229e2c3f796 100644 --- a/eisgenerator/eistype.h +++ b/eisgenerator/eistype.h @@ -149,6 +149,8 @@ public: EisSpectra(const std::vector<DataPoint>& dataIn, const std::string& modelIn, const std::string& headerIn, std::vector<double> labelsIn = std::vector<double>(), std::vector<std::string> labelNamesIn = std::vector<std::string>()); + EisSpectra(const std::vector<DataPoint>& dataIn, const std::string& modelIn, const std::string& headerIn, + std::vector<float> labelsIn, std::vector<std::string> labelNamesIn = std::vector<std::string>()); EisSpectra(const std::vector<DataPoint>& dataIn, const std::string& modelIn, const std::string& headerIn, std::vector<size_t> labelsIn, std::vector<std::string> labelNamesIn = std::vector<std::string>()); EisSpectra(const std::vector<DataPoint>& dataIn, const std::string& modelIn, const std::string& headerIn, @@ -157,7 +159,11 @@ public: void setLabel(size_t label, size_t maxLabel); size_t getLabel(); void setSzLabels(std::vector<size_t> label); + void setLabels(const std::vector<double>& labelsIn); + void setLabels(const std::vector<float>& labelsIn); std::vector<size_t> getSzLabels() const; + bool isMulticlass(); + std::vector<fvalue> getFvalueLabels(); }; bool saveToDisk(const EisSpectra& data, const std::filesystem::path& path); diff --git a/eistype.cpp b/eistype.cpp index 2f0558eebf18e52d3f67d6cfa4df280847284d9f..7c6e9443748cb016550521d7dc0fdaa1ba7afca8 100644 --- a/eistype.cpp +++ b/eistype.cpp @@ -3,6 +3,7 @@ #include <sstream> #include <algorithm> #include <string> +#include <vector> #include "strops.h" #include "log.h" @@ -237,6 +238,13 @@ data(dataIn), model(modelIn), header(headerIn), labels(labelsIn), labelNames(lab } +EisSpectra::EisSpectra(const std::vector<DataPoint>& dataIn, const std::string& modelIn, + const std::string& headerIn, std::vector<float> labelsIn, std::vector<std::string> labelNamesIn): +data(dataIn), model(modelIn), header(headerIn), labelNames(labelNamesIn) +{ + setLabels(labelsIn); +} + EisSpectra::EisSpectra(const std::vector<DataPoint>& dataIn, const std::string& modelIn, const std::string& headerIn, std::vector<size_t> labelsIn, std::vector<std::string> labelNamesIn): data(dataIn), model(modelIn), header(headerIn), labelNames(labelNamesIn) @@ -275,7 +283,53 @@ std::vector<size_t> EisSpectra::getSzLabels() const size_t EisSpectra::getLabel() { for(size_t i = 0; i < labels.size(); ++i) - if(labels[i] > 0.5) + if(labels[i] != 0) return i; return 0; } + +bool EisSpectra::isMulticlass() +{ + bool foundFirst = false; + for(size_t i = 0; i < labels.size(); ++i) + { + if(labels[i] != 0) + { + if(foundFirst) + return true; + else + foundFirst = true; + } + } + return false; +} + +void EisSpectra::setLabels(const std::vector<float>& labelsIn) +{ + labels.assign(labelsIn.size(), 0); + for(size_t i = 0; i < labels.size(); ++i) + labels[i] = labelsIn[i]; +} + +void EisSpectra::setLabels(const std::vector<double>& labelsIn) +{ + labels = labelsIn; +} + +std::vector<fvalue> EisSpectra::getFvalueLabels() +{ + if constexpr(std::is_same<fvalue, double>::value) + { + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" + return *reinterpret_cast<std::vector<fvalue>*>(&labels); + #pragma GCC diagnostic pop + } + else + { + std::vector<fvalue> out(labels.size()); + for(size_t i = 0; i < labels.size(); ++i) + out[i] = static_cast<fvalue>(labels[i]); + return out; + } +}