Skip to content
Snippets Groups Projects
Select Git revision
  • c67a8a213f99480fac951458b4b0b35876ea88e5
  • master default
  • main
  • 1.2.2
  • 1.2.1
  • 1.2.0
  • 1.1.0
7 results

main.c

Blame
  • user avatar
    Carl Philipp Klemm authored
    31e51ea2
    History
    main.c 3.18 KiB
    #include <assert.h>
    #include <stdbool.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <threads.h>
    
    #ifdef USE_RLX
    #include <relaxisloader/relaxisloader.h>
    #endif
    
    #include "kissinference/kissinference.h"
    
    /* Completion state shared between main() and result_cb(): the callback sets
     * done and broadcasts cond while holding mutex; main() waits on cond until
     * done is set. The flag makes the handshake immune to lost and spurious
     * wakeups. */
    struct inference_sync
    {
    	mtx_t mutex;
    	cnd_t cond;
    	bool done;
    };
    
    /* Completion callback registered with kiss_load_network().
     * Prints every output label with its value, frees the results buffer (this
     * callback takes ownership of it), then marks the shared inference_sync as
     * done and wakes main().
     * data: the struct inference_sync * passed to kiss_async_run_inference_complex().
     * NOTE(review): assumed to run on a library-owned thread (async API) and
     * that results is non-NULL with net->output_size entries - confirm. */
    static void result_cb(float* results, struct kiss_network* net, void *data)
    {
    	puts("Result:");
    	for(size_t i = 0; i < net->output_size; ++i)
    	{
    		printf("%s: %f\n", net->output_labels[i], results[i]);
    	}
    	free(results);
    
    	struct inference_sync *sync = data;
    	// Set the flag under the mutex so main() cannot miss the wakeup even if
    	// the result arrives before it has started waiting.
    	if(mtx_lock(&sync->mutex) != thrd_success)
    	{
    		// Without the lock main() would wait forever; fail loudly instead.
    		puts("unable to lock iso thread mutex");
    		abort();
    	}
    	sync->done = true;
    	cnd_broadcast(&sync->cond);
    	mtx_unlock(&sync->mutex);
    }
    
    /* Block until result_cb() has set sync->done.
     * Waits in a predicate loop so a spurious wakeup does not end the wait early.
     * Returns false if locking, waiting or unlocking fails. */
    static bool wait_for_result(struct inference_sync *sync)
    {
    	if(mtx_lock(&sync->mutex) != thrd_success)
    		return false;
    
    	bool ok = true;
    	while(!sync->done)
    	{
    		if(cnd_wait(&sync->cond, &sync->mutex) != thrd_success)
    		{
    			ok = false;
    			break;
    		}
    	}
    
    	if(mtx_unlock(&sync->mutex) != thrd_success)
    		ok = false;
    	return ok;
    }
    
    /* Test driver: loads the network given in argv[1], builds a complex input
     * (from a relaxis file when built with USE_RLX, otherwise a constant 0.5
     * spectrum), runs asynchronous inference and waits for result_cb().
     * Exit codes: 1 usage error, 2 network/input/inference error,
     * 3 thread-primitive, load or allocation failure.
     * Error paths return directly; this is a short-lived program and the OS
     * reclaims its memory at exit. */
    int main(int argc, char** argv)
    {
    	if(argc < 2)
    	{
    		puts("A model path is required");
    		return 1;
    	}
    
    #ifdef USE_RLX
    	unsigned long int spectra_id = 0;
    	if(argc < 3)
    	{
    		puts("A relaxis file is required");
    		puts("usage: kissinference_test [NETWORK FILE] [RELAXIS FILE] [SPECTRA ID]");
    		return 1;
    	}
    	else if(argc < 4)
    	{
    		puts("No spectra id specified will use spectra 0");
    	}
    	else
    	{
    		char *ptr;
    		spectra_id = strtoul(argv[3], &ptr, 10);
    		// Reject trailing garbage such as "12abc", not only an empty parse.
    		if(ptr == argv[3] || *ptr != '\0')
    		{
    			printf("%s is not a valid spectra id\n", argv[3]);
    			return 1;
    		}
    	}
    #endif
    
    	struct inference_sync sync = {.done = false};
    	if(cnd_init(&sync.cond) != thrd_success)
    	{
    		puts("unable to create iso thread condition");
    		return 3;
    	}
    
    	if(mtx_init(&sync.mutex, mtx_plain) != thrd_success)
    	{
    		puts("unable to create iso thread mutex");
    		return 3;
    	}
    
    	struct kiss_network *net = kiss_load_network(argv[1], result_cb, false);
    	if(!net)
    	{
    		printf("Unable to load network %s\n", argv[1]);
    		return 3;
    	}
    	if(!net->ready)
    	{
    		puts(kiss_get_strerror(net));
    		return 2;
    	}
    
    	float *real;
    	float *imaginary;
    	size_t length;
    
    #ifdef USE_RLX
    	const char *error;
    	struct rlxfile *file = rlx_open_file(argv[2], &error);
    	if(!file)
    	{
    		printf("Unable to open %s: %s\n", argv[2], error);
    		return 2;
    	}
    
    	size_t len;
    	struct rlx_project **projects = rlx_get_projects(file, &len);
    	if(!projects)
    	{
    		rlx_close_file(file);
    		printf("%s contains no projects\n", argv[2]);
    		return 2;
    	}
    
    	struct rlx_spectra *spectra = rlx_get_spectra(file, projects[0], spectra_id);
    	if(!spectra)
    	{
    		rlx_project_free_array(projects);
    		rlx_close_file(file);
    		printf("%s does not contain a spectra with id %lu\n", argv[2], spectra_id);
    		return 2;
    	}
    
    	float *re_raw;
    	float *im_raw;
    	float *omega_raw;
    	rlx_get_float_arrays(spectra, &re_raw, &im_raw, &omega_raw);
    
    	float *re_filtered;
    	float *im_filtered;
    	kiss_reduce_spectra(re_raw, im_raw, omega_raw, spectra->length, 0.01f, false, &re_filtered, &im_filtered, &length);
    
    	// The network consumes input_size/2 complex points, split into real and
    	// imaginary arrays.
    	kiss_resample_spectra(re_filtered, im_filtered, length, &real, &imaginary, net->input_size/2);
    
    	rlx_spectra_free(spectra);
    	rlx_project_free_array(projects);
    	rlx_close_file(file);
    	free(re_raw);
    	free(im_raw);
    	free(omega_raw);
    	free(re_filtered);
    	free(im_filtered);
    #else
    	length = net->input_size/2;
    	real = calloc(length, sizeof(*real));
    	imaginary = calloc(length, sizeof(*imaginary));
    	if(!real || !imaginary)
    	{
    		puts("unable to allocate network input");
    		free(real);
    		free(imaginary);
    		return 3;
    	}
    
    	for(size_t i = 0; i < length; ++i)
    	{
    		real[i] = 0.5f;
    		imaginary[i] = 0.5f;
    	}
    #endif
    
    	if(!kiss_async_run_inference_complex(net, real, imaginary, &sync))
    	{
    		puts(kiss_get_strerror(net));
    		free(real);
    		free(imaginary);
    		return 2;
    	}
    
    	// Not wrapped in assert(): with NDEBUG the wait itself would vanish and the
    	// buffers below would be freed while inference is still running.
    	if(!wait_for_result(&sync))
    	{
    		puts("unable to wait for inference result");
    		// Inference may still be using real/imaginary/net, so leave them alone.
    		return 3;
    	}
    
    	// Free the network before the sync primitives its callback used.
    	// NOTE(review): presumably kiss_free_network stops any worker thread - confirm.
    	kiss_free_network(net);
    	free(real);
    	free(imaginary);
    	cnd_destroy(&sync.cond);
    	mtx_destroy(&sync.mutex);
    	return 0;
    }