Skip to content
Snippets Groups Projects
Select Git revision
  • 6d8d88cb9b221eee88442b17a3d2da953e902d21
  • master default
  • main
  • 1.2.2
  • 1.2.1
  • 1.2.0
  • 1.1.0
7 results

kissinference.c

Blame
  • kissinference.c 18.48 KiB
    #include "kissinference/kissinference.h"
    #include <stddef.h>
    #include <onnxruntime_c_api.h>
    #include <stdint.h>
    #include <stdio.h>
    #include <assert.h>
    #include <stdlib.h>
    #include <string.h>
    #include <float.h>
    
    #define _USE_MATH_DEFINES
    #include <math.h>
    
    #define KISS_STRERROR_LEN 256
    
    /* Private onnxruntime state attached to a struct kiss_network through net->priv. */
    struct kiss_priv
    {
    	/* Result of OrtGetApiBase(); used for GetApi() and GetVersionString(). */
    	const struct OrtApiBase *base_api;
    	/* Function table returned by base_api->GetApi(ORT_API_VERSION). */
    	const struct OrtApi *api;
    	/* Environment created by CreateEnv() in kiss_load_network_prealloc(). */
    	OrtEnv *env;
    	/* Session for the model loaded from the path given to kiss_load_network_prealloc(). */
    	OrtSession *session;
    
    	/* Last error message written by this file; returned by kiss_get_strerror(). */
    	char err[KISS_STRERROR_LEN];
    };
    
    /* State of one asynchronous inference, handed to RunAsync as user data and consumed by kiss_run_cb(). */
    struct kiss_inference_req
    {
    	/* Network the request runs on. */
    	struct kiss_network *net;
    	/* One-element array pointing at net->input_label (the string itself is not owned). */
    	const char **input_names;
    	/* One-element array holding the input tensor passed to RunAsync. */
    	const OrtValue **input_tensors;
    	/* Heap copy of the caller's input data; it backs `input` and is freed by kiss_inference_req_free(). */
    	float *input_array;
    	/* Tensor created over input_array; kiss_run_cb() releases it. */
    	OrtValue *input;
    	/* One-element array pointing at net->purpose (the string itself is not owned). */
    	const char **output_names;
    	/* One-element output array passed to RunAsync, zero-initialized before the call. */
    	OrtValue **output_tensors;
    	/* Caller data forwarded unchanged to net->result_cb. */
    	void *user_data;
    };
    
    /*
     * Releases the heap arrays owned by an inference request and the request itself.
     * The OrtValue in req->input and the req->output_tensors array are NOT touched here;
     * whoever calls this must already have dealt with them.
     */
    void kiss_inference_req_free(struct kiss_inference_req *req)
    {
    	/* free(NULL) is a no-op, so none of these need a guard. */
    	free(req->input_names);
    	free(req->input_tensors);
    	free(req->input_array);
    	free(req->output_names);
    	free(req);
    }
    
    const struct kiss_version_fixed kiss_get_version(void)
    {
    	static const struct kiss_version_fixed version = {1, 1, 0};
    	return version;
    }
    
    /*
     * Frees every string of a NULL-terminated array of heap strings, then the array itself.
     * A NULL array is accepted and ignored, so callers need no guard.
     */
    static void kiss_free_str_array(char **array)
    {
    	if(!array)
    		return;

    	for(char **ptr = array; *ptr; ++ptr)
    		free(*ptr);
    	free(array);
    }
    
    /*
     * Destroys a network created by kiss_load_network(): releases its resources via
     * kiss_free_network_prealloc() and then frees the struct itself. NULL is ignored.
     */
    void kiss_free_network(struct kiss_network *net)
    {
    	if(!net)
    		return;

    	kiss_free_network_prealloc(net);
    	free(net);
    }
    
    void kiss_free_network_prealloc(struct kiss_network *net)
    {
    	if(net) {
    		if(net->priv->session)
    			net->priv->api->ReleaseSession(net->priv->session);
    		if(net->priv->env)
    			net->priv->api->ReleaseEnv(net->priv->env);
    		if(net->priv)
    			free(net->priv);
    		if(net->input_label)
    			free(net->input_label);
    		if(net->purpose)
    			free(net->purpose);
    		if(net->output_labels)
    			kiss_free_str_array(net->output_labels);
    	}
    }
    
    /*
     * Splits a comma-separated label list into a NULL-terminated array of heap strings.
     * Empty fields (leading, trailing or repeated commas) are skipped, as strtok would skip them.
     * The input is not modified.
     *
     * *token_count receives the number of labels actually stored.
     * Returns NULL, with *token_count set to 0, when the input holds no labels or an
     * allocation fails. The result is released with kiss_free_str_array().
     */
    static char **kiss_parse_output_lables(const char *output_labels, size_t *token_count)
    {
    	*token_count = 0;
    	if(!output_labels)
    		return NULL;

    	/* First pass: count the non-empty fields so the array can be sized exactly. */
    	size_t labels = 0;
    	for(const char *p = output_labels; *p; ) {
    		size_t len = strcspn(p, ",");
    		if(len > 0)
    			++labels;
    		p += len;
    		if(*p == ',')
    			++p;
    	}
    	if(labels == 0)
    		return NULL;

    	char **result = calloc(labels + 1, sizeof(*result));
    	if(!result)
    		return NULL;

    	/* Second pass: copy every non-empty field into its own string. */
    	size_t i = 0;
    	for(const char *p = output_labels; *p; ) {
    		size_t len = strcspn(p, ",");
    		if(len > 0) {
    			char *label = malloc(len + 1);
    			if(!label) {
    				while(i > 0)
    					free(result[--i]);
    				free(result);
    				return NULL;
    			}
    			memcpy(label, p, len);
    			label[len] = '\0';
    			result[i++] = label;
    		}
    		p += len;
    		if(*p == ',')
    			++p;
    	}

    	*token_count = i;
    	return result;
    }
    
    static int64_t kiss_get_tensor_size(const OrtTensorTypeAndShapeInfo *info, const struct OrtApi *api)
    {
    	size_t  sizes_size;
    
    	enum ONNXTensorElementDataType type;
    	assert(api->GetTensorElementType(info, &type) == NULL);
    	assert(type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
    
    	assert(api->GetDimensionsCount(info, &sizes_size) == NULL);
    	int64_t *sizes = malloc(sizeof(*sizes)*sizes_size);
    	assert(api->GetDimensions(info, sizes, sizes_size) == NULL);
    	assert(sizes_size == 2);
    	int64_t size = sizes[1];
    	free(sizes);
    	return size;
    }
    
    /*
     * Tells whether a network input name marks complex input. Only the exact name "EIS" does.
     * The result is stored in net->complex_input.
     */
    static bool kiss_check_complex(char *input_name)
    {
    	return !strcmp(input_name, "EIS");
    }
    
    bool kiss_load_network_prealloc(struct kiss_network* net, const char *path, void (*result_cb)(float* result, struct kiss_network* net, void* user_data), bool verbose)
    {
    	memset(net, 0, sizeof(*net));
    
    	OrtSessionOptions *opts = NULL;
    	OrtModelMetadata *model_meta = NULL;
    	OrtTypeInfo *type_info = NULL;
    	const OrtTensorTypeAndShapeInfo *tensor_info = NULL;
    	char *input_name;
    	char *output_name;
    	char *output_labels;
    	char *is_softmax;
    	OrtAllocator *allocator;
    	size_t count;
    
    	net->priv = calloc(1, sizeof(*net->priv));
    	net->priv->base_api = OrtGetApiBase();
    	net->priv->api = net->priv->base_api->GetApi(ORT_API_VERSION);
    	net->result_cb = result_cb;
    
    	const struct OrtApi *api = net->priv->api;
    
    	if(verbose)
    		printf("Inialized onnxruntime %s\n", net->priv->base_api->GetVersionString());
    
    
    	OrtStatus *status = api->CreateEnv(verbose ? ORT_LOGGING_LEVEL_VERBOSE : ORT_LOGGING_LEVEL_WARNING, "KissInference", &net->priv->env);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to load onnxruntime: %s\n", api->GetErrorMessage(status));
    		goto exit_failue;
    	}
    
    	status = api->CreateSessionOptions(&opts);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "%s\n", api->GetErrorMessage(status));
    		goto exit_failue;
    	}
    
    	status = api->CreateSession(net->priv->env, path, opts, &net->priv->session);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable create session: %s", api->GetErrorMessage(status));
    		goto exit_failue;
    	}
    
    	status = api->SessionGetModelMetadata(net->priv->session, &model_meta);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get model metadata: %s", api->GetErrorMessage(status));
    		goto exit_failue;
    	}
    
    	assert(api->SessionGetInputCount(net->priv->session, &count) == NULL);
    	if(count != 1) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Expected model with input count 1 but got %zu", count);
    		goto exit_failue;
    	}
    
    	status = api->GetAllocatorWithDefaultOptions(&allocator);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to create allocator: %s", api->GetErrorMessage(status));
    		goto exit_failue;
    	}
    
    	status = api->SessionGetInputTypeInfo(net->priv->session, 0, &type_info);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get input type info: %s", api->GetErrorMessage(status));
    		goto exit_failue;
    	}
    	assert(api->CastTypeInfoToTensorInfo(type_info, &tensor_info) == NULL && tensor_info);
    	net->input_size = kiss_get_tensor_size(tensor_info, api);
    	api->ReleaseTypeInfo(type_info);
    
    	status = api->SessionGetOutputTypeInfo(net->priv->session, 0, &type_info);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get output type info: %s", api->GetErrorMessage(status));
    		goto exit_failue;
    	}
    	assert(api->CastTypeInfoToTensorInfo(type_info, &tensor_info) == NULL && tensor_info);
    	net->output_size = kiss_get_tensor_size(tensor_info, api);
    	api->ReleaseTypeInfo(type_info);
    
    	status = api->SessionGetInputName(net->priv->session, 0, allocator, &input_name);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get network input name: %s", api->GetErrorMessage(status));
    		goto exit_failue;
    	}
    	net->input_label = strdup(input_name);
    	allocator->Free(allocator, input_name);
    	if(verbose)
    		printf("Got network with input name: %s\n", net->input_label);
    	net->complex_input = kiss_check_complex(net->input_label);
    
    	assert(api->SessionGetOutputCount(net->priv->session, &count) == NULL);
    	if(count != 1) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Expected model with output count 1 but got %zu", count);
    		goto exit_failue;
    	}
    
    	status = api->SessionGetOutputName(net->priv->session, 0, allocator, &output_name);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get network output name: %s", api->GetErrorMessage(status));
    		goto exit_failue;
    	}
    	net->purpose = strdup(output_name);
    	allocator->Free(allocator, output_name);
    
    	status = api->ModelMetadataLookupCustomMetadataMap(model_meta, allocator, "outputLabels", &output_labels);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get output labels: %s", api->GetErrorMessage(status));
    		goto exit_failue;
    	}
    	else if(!output_labels) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "The network contains no output lables");
    		goto exit_failue;
    	}
    	net->output_labels = kiss_parse_output_lables(output_labels, &count);
    	allocator->Free(allocator, output_labels);
    	if(!net->output_labels || count != net->output_size) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to parse output labels, %zu != %zu", count, net->output_size);
    		goto exit_failue;
    	}
    
    	status = api->ModelMetadataLookupCustomMetadataMap(model_meta, allocator, "softmax", &is_softmax);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get softmax state: %s", api->GetErrorMessage(status));
    		goto exit_failue;
    	}
    	else if(!is_softmax) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get softmax state");
    		goto exit_failue;
    	}
    
    	if(strcmp(is_softmax, "False") == 0) {
    		net->softmax = false;
    	}
    	else if(strcmp(is_softmax, "True") == 0) {
    		net->softmax = true;
    	}
    	else {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to parse softmax state got: %s expected \"True\" or \"False\"", is_softmax);
    		allocator->Free(allocator, is_softmax);
    		goto exit_failue;
    	}
    	allocator->Free(allocator, is_softmax);
    
    	api->ReleaseSessionOptions(opts);
    	api->ReleaseModelMetadata(model_meta);
    	net->ready = true;
    	return net->ready;
    
    	exit_failue:
    	if(status)
    		api->ReleaseStatus(status);
    	if(net->priv->env)
    		api->ReleaseEnv(net->priv->env);
    	net->priv->env = NULL;
    	if(net->priv->session)
    		api->ReleaseSession(net->priv->session);
    	net->priv->session = NULL;
    	if(opts)
    		api->ReleaseSessionOptions(opts);
    	if(model_meta)
    		api->ReleaseModelMetadata(model_meta);
    	return net->ready;
    }
    
    /*
     * Allocates a network and loads the model at `path` into it via kiss_load_network_prealloc().
     * The network is returned even when loading fails; check net->ready and use
     * kiss_get_strerror() for the reason. Returns NULL only if the struct cannot be allocated.
     * Release the result with kiss_free_network().
     */
    struct kiss_network *kiss_load_network(const char *path, void (*result_cb)(float*, struct kiss_network*, void*), bool verbose)
    {
    	struct kiss_network *net = calloc(1, sizeof(*net));
    	if(!net)
    		return NULL;

    	kiss_load_network_prealloc(net, path, result_cb, verbose);
    	return net;
    }
    
    /*
     * Copies the float data of one inference output tensor into a malloc'd array of
     * net->output_size elements. Returns NULL and sets the network error string on failure.
     */
    static float *kiss_copy_output(struct kiss_network *net, OrtValue *output)
    {
    	const struct OrtApi *api = net->priv->api;
    	OrtTensorTypeAndShapeInfo *info = NULL;
    	void *raw = NULL;

    	/* The ORT calls stay outside assert() so they still run when NDEBUG is defined. */
    	OrtStatus *status = api->GetTensorTypeAndShape(output, &info);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get output tensor shape: %s\n", api->GetErrorMessage(status));
    		api->ReleaseStatus(status);
    		return NULL;
    	}
    	int64_t size = kiss_get_tensor_size(info, api);
    	api->ReleaseTensorTypeAndShapeInfo(info);
    	if(size != (int64_t)net->output_size) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Network produced %lld outputs but %zu were expected\n", (long long)size, net->output_size);
    		return NULL;
    	}

    	status = api->GetTensorMutableData(output, &raw);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to access output tensor data: %s\n", api->GetErrorMessage(status));
    		api->ReleaseStatus(status);
    		return NULL;
    	}

    	float *result = malloc(sizeof(*result)*net->output_size);
    	if(!result) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to allocate inference result\n");
    		return NULL;
    	}
    	memcpy(result, raw, sizeof(*result)*net->output_size);
    	return result;
    }

    /*
     * RunAsync completion callback.
     *
     * Converts the output into a malloc'd float array (softmaxed if net->softmax) and passes
     * it to net->result_cb. On any failure it passes NULL instead, with the reason available
     * from kiss_get_strerror(). The request, its input tensor and the output array are
     * released afterwards.
     *
     * NOTE(review): this runs on an onnxruntime thread and writes net->priv->err without locking.
     */
    static void kiss_run_cb(void *user_data, OrtValue **outputs, size_t num_outputs, OrtStatusPtr status)
    {
    	struct kiss_inference_req *req = user_data;
    	struct kiss_network *net = req->net;
    	const struct OrtApi *api = net->priv->api;
    	void *kiss_user_data = req->user_data;
    	float *output_floats = NULL;

    	/* On failure ORT passes outputs == NULL and num_outputs == 0, so the status is checked first. */
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Could not perform inference: %s\n", api->GetErrorMessage(status));
    		api->ReleaseStatus(status);
    	}
    	else if(!outputs || num_outputs != 1) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Could not perform inference: expected 1 output but got %zu\n", num_outputs);
    	}
    	else {
    		output_floats = kiss_copy_output(net, outputs[0]);
    		if(output_floats && net->softmax)
    			kiss_softmax(output_floats, net->output_size);
    	}

    	/* On success ORT hands back the array we gave RunAsync (req->output_tensors); on
    	 * failure it passes NULL. Either way that one-element array, and any value ORT
    	 * stored in it, are ours to release. */
    	OrtValue **output_array = outputs ? outputs : req->output_tensors;
    	if(output_array) {
    		if(output_array[0])
    			api->ReleaseValue(output_array[0]);
    		free(output_array);
    	}
    	req->output_tensors = NULL;

    	net->result_cb(output_floats, net, kiss_user_data);

    	api->ReleaseValue(req->input);
    	kiss_inference_req_free(req);
    }
    
    /*
     * Starts an asynchronous inference on net->input_size floats read from `inputs`.
     * The data is copied, so `inputs` may be reused as soon as this returns.
     *
     * On success returns true, and net->result_cb later receives the result together
     * with `user_data`. On failure returns false, the callback is not invoked, and the
     * reason is available from kiss_get_strerror().
     */
    bool kiss_async_run_inference(struct kiss_network *net, const float *inputs, void* user_data)
    {
    	if(!net->priv)
    		return false;
    	const struct OrtApi *api = net->priv->api;
    	/* A failed load leaves the session NULL and ready false; refuse to run on it. */
    	if(!api || !net->ready) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "kissinferene not ready\n");
    		return false;
    	}

    	OrtMemoryInfo *mem_info = NULL;
    	OrtValue *input_tensor = NULL;
    	float *input_array = NULL;
    	struct kiss_inference_req *req = NULL;
    	const int64_t shape[] = {1, net->input_size};
    	const size_t input_bytes = sizeof(*input_array)*net->input_size;

    	OrtStatus *status = api->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &mem_info);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to create allocator: %s\n", api->GetErrorMessage(status));
    		goto exit_failure;
    	}

    	input_array = malloc(input_bytes);
    	if(!input_array) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to allocate input buffer\n");
    		goto exit_failure;
    	}
    	memcpy(input_array, inputs, input_bytes);

    	/* The tensor wraps input_array without copying it, so the buffer must outlive the tensor. */
    	status = api->CreateTensorWithDataAsOrtValue(mem_info, input_array, input_bytes, shape, 2, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to create tensor: %s\n", api->GetErrorMessage(status));
    		goto exit_failure;
    	}

    	req = calloc(1, sizeof(*req));
    	if(req) {
    		req->input_names = malloc(sizeof(*req->input_names));
    		req->input_tensors = malloc(sizeof(*req->input_tensors));
    		req->output_names = malloc(sizeof(*req->output_names));
    		req->output_tensors = calloc(1, sizeof(*req->output_tensors));
    	}
    	if(!req || !req->input_names || !req->input_tensors || !req->output_names || !req->output_tensors) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to allocate inference request\n");
    		goto exit_failure;
    	}
    	req->input_names[0] = net->input_label;
    	req->input_tensors[0] = input_tensor;
    	req->output_names[0] = net->purpose;
    	req->input = input_tensor;
    	req->net = net;
    	req->user_data = user_data;
    	req->input_array = input_array;

    	status = api->RunAsync(net->priv->session, NULL,
    		req->input_names, req->input_tensors, 1,
    		req->output_names, 1, req->output_tensors,
    		kiss_run_cb, req);
    	if(status) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to run inference: %s\n", api->GetErrorMessage(status));
    		goto exit_failure;
    	}

    	/* From here on req belongs to kiss_run_cb(), which may already have run; do not touch it. */
    	api->ReleaseMemoryInfo(mem_info);
    	return true;

    exit_failure:
    	if(status)
    		api->ReleaseStatus(status);
    	if(req) {
    		free(req->output_tensors);
    		req->output_tensors = NULL;
    		/* input_array is freed below, after the tensor that references it. */
    		req->input_array = NULL;
    		kiss_inference_req_free(req);
    	}
    	if(input_tensor)
    		api->ReleaseValue(input_tensor);
    	free(input_array);
    	if(mem_info)
    		api->ReleaseMemoryInfo(mem_info);
    	return false;
    }
    
    /*
     * Starts an asynchronous inference on complex input for networks with net->complex_input set.
     * `real` and `imaginary` each hold net->input_size/2 values. They are packed as
     * [real..., imaginary...] and passed to kiss_async_run_inference(). Returns its
     * result, or false when the network rejects complex input or allocation fails.
     */
    bool kiss_async_run_inference_complex(struct kiss_network *net, const float *real, const float *imaginary, void* user_data)
    {
    	if(!net->complex_input) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "This network dosent support complex inputs");
    		return false;
    	}
    	if(net->input_size % 2 != 0) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "This network claims to support complex inputs but its input size vector is not divisible by 2");
    		return false;
    	}

    	const size_t half = net->input_size/2;
    	float *input = malloc(net->input_size*sizeof(*input));
    	if(!input) {
    		snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to allocate complex input buffer");
    		return false;
    	}
    	memcpy(input, real, half*sizeof(*input));
    	memcpy(input+half, imaginary, half*sizeof(*input));

    	/* kiss_async_run_inference() copies the data, so the buffer can be freed right away. */
    	bool ret = kiss_async_run_inference(net, input, user_data);
    	free(input);

    	return ret;
    }
    
    /*
     * Linearly resamples a spectrum of input_length points onto output_length evenly
     * spaced positions. The first and last output samples map to the first and last input
     * samples; a single output sample maps to the first input sample.
     *
     * *out_re and *out_im receive malloc'd arrays of output_length floats owned by the
     * caller. Both are set to NULL if an allocation fails. An empty input yields zeros.
     */
    void kiss_resample_spectra(float *in_re, float *in_im, size_t input_length, float **out_re, float **out_im, size_t output_length)
    {
    	float *re = malloc(output_length*sizeof(float));
    	float *im = malloc(output_length*sizeof(float));
    	if(output_length > 0 && (!re || !im)) {
    		free(re);
    		free(im);
    		*out_re = NULL;
    		*out_im = NULL;
    		return;
    	}
    	*out_re = re;
    	*out_im = im;

    	for(size_t i = 0; i < output_length; ++i) {
    		if(input_length == 0) {
    			re[i] = 0;
    			im[i] = 0;
    			continue;
    		}

    		/* Guard the output_length == 1 case, which would otherwise compute 0/0. */
    		double position = output_length > 1 ? ((double)i)/(output_length-1) : 0.0;
    		double source_position_double = ((double)(input_length-1))*position;
    		size_t source_position = source_position_double;
    		double frac = source_position_double - source_position;
    		if(source_position >= input_length-1) {
    			re[i] = in_re[input_length-1];
    			im[i] = in_im[input_length-1];
    		}
    		else {
    			re[i] = in_re[source_position]*(1-frac) + in_re[source_position+1]*frac;
    			im[i] = in_im[source_position]*(1-frac) + in_im[source_position+1]*frac;
    		}
    	}
    }
    
    /*
     * Normalizes a spectrum in place:
     *  - in_re becomes (re - min(re)) / (max|re| - min(re)), with the divisor replaced by 1
     *    when the two are equal;
     *  - in_im is divided by max|im|.
     * The magnitude maxima start at FLT_MIN, so the imaginary "== 0" guard never triggers.
     * NOTE(review): the real span mixes the largest magnitude with the smallest signed value.
     * For negative real parts this differs from max(re) - min(re); confirm that is intended.
     */
    void kiss_normalize_spectra(float *in_re, float *in_im, size_t input_length)
    {
    	float peak_re = FLT_MIN;
    	float peak_im = FLT_MIN;
    	float floor_re = FLT_MAX;
    	size_t n;

    	for(n = 0; n < input_length; n++) {
    		float mag_re = fabsf(in_re[n]);
    		float mag_im = fabsf(in_im[n]);

    		if(mag_re > peak_re)
    			peak_re = mag_re;
    		if(mag_im > peak_im)
    			peak_im = mag_im;
    		if(in_re[n] < floor_re)
    			floor_re = in_re[n];
    	}

    	float span_re = (peak_re == floor_re) ? 1 : peak_re - floor_re;
    	float span_im = (peak_im == 0) ? 1 : peak_im;

    	for(n = 0; n < input_length; n++) {
    		in_re[n] = (in_re[n] - floor_re) / span_re;
    		in_im[n] = in_im[n] / span_im;
    	}
    }
    
    /*
     * Clamps `index` into [1, input_length - 2] so that index-1 and index+1 are both valid
     * for central differences. Callers ensure input_length >= 3.
     */
    static size_t kiss_grad_index(size_t input_length, size_t index)
    {
    	const size_t lowest = 1;
    	const size_t highest = input_length - 2;

    	if(index < lowest)
    		return lowest;
    	if(index > highest)
    		return highest;
    	return index;
    }
    
    /*
     * Returns a malloc'd pair {|d re/d omega|, |d im/d omega|}. Each is a central difference
     * around `index`, clamped by kiss_grad_index() so both neighbours exist.
     * With fewer than 3 points the pair is {1, 1}.
     * Returns NULL if the allocation fails; the caller frees the result.
     */
    float *kiss_absgrad(float *in_re, float *in_im, float *omega, size_t input_length, size_t index)
    {
    	float *out = malloc(2*sizeof(*out));
    	if(!out)
    		return NULL;
    	out[0] = 1;
    	out[1] = 1;
    	if(input_length < 3)
    		return out;

    	index = kiss_grad_index(input_length, index);

    	out[0] = fabsf((in_re[index+1]-in_re[index-1])/(omega[index+1]-omega[index-1]));
    	out[1] = fabsf((in_im[index+1]-in_im[index-1])/(omega[index+1]-omega[index-1]));

    	return out;
    }
    
    /*
     * Central-difference derivative of `data` with respect to `omega` at `index`.
     * The index is clamped by kiss_grad_index() so both neighbours exist.
     * Returns 0 when fewer than 3 points are available.
     */
    float kiss_grad(float *data, float *omega, size_t input_length, size_t index)
    {
    	if(input_length >= 3) {
    		size_t at = kiss_grad_index(input_length, index);
    		return (data[at+1] - data[at-1]) / (omega[at+1] - omega[at-1]);
    	}
    	return 0;
    }
    
    /*
     * qsort() comparator for ascending floats. It compares the pointed-to values; the
     * previous version compared the pointers themselves, which sorted by address.
     * NaNs compare equal to everything.
     */
    static int kiss_cmp_float(const void *x, const void *y)
    {
    	const float a = *(const float *)x;
    	const float b = *(const float *)y;

    	return (a > b) - (a < b);
    }
    
    /*
     * Returns the median of `data`. For an even count it is the mean of the two middle
     * values. The input is not modified; a sorted copy is used.
     * Returns NAN for an empty input or if the copy cannot be allocated.
     */
    float kiss_median(float *data, size_t input_length)
    {
    	/* Without this guard, input_length/2-1 would wrap to SIZE_MAX. */
    	if(input_length == 0)
    		return NAN;

    	float *data_cpy = malloc(input_length*sizeof(*data));
    	if(!data_cpy)
    		return NAN;
    	memcpy(data_cpy, data, input_length*sizeof(*data));
    	qsort(data_cpy, input_length, sizeof(*data_cpy), kiss_cmp_float);

    	const size_t mid = input_length/2;
    	float retval;
    	if(input_length % 2 == 0)
    		retval = (data_cpy[mid] + data_cpy[mid-1])/2;
    	else
    		retval = data_cpy[mid];

    	free(data_cpy);
    	return retval;
    }
    
    /*
     * Arithmetic mean of `data`, summed front to back in double precision.
     * An empty input yields 0/0, i.e. NaN.
     */
    float kiss_mean(float *data, size_t input_length)
    {
    	double sum = 0;
    	size_t k = 0;

    	while(k < input_length)
    		sum += data[k++];

    	return sum / input_length;
    }
    
    float *kiss_create_range(float start, float end, size_t length, bool log)
    {
    	float *out = malloc(sizeof(*out)*length);
    	float startL = log ? log10f(start) : start;
    	float endL = log ? log10f(end) : end;
    	float step = (endL-startL)/(length-1);
    
    	if(log) {
    		for(size_t i = 0; i < length; ++i)
    			out[i] = pow(10, step*i+log10f(start));
    	}
    	else {
    		for(size_t i = 0; i < length; ++i)
    			out[i] = start+step*i;
    	}
    	return out;
    }
    
    /*
     * Trims the flat leading and trailing parts of a spectrum.
     *
     * in_re/in_im are first normalized IN PLACE with kiss_normalize_spectra(). A gradient
     * magnitude is then computed per point, optionally differentiated a second time. The
     * threshold is median(gradient) * thresh_factor, or 1e-12 * thresh_factor when the
     * second derivative is used. Leading and trailing points below the threshold are
     * dropped. If every point from index 1 onward is below the threshold, the whole
     * spectrum is kept.
     *
     * On success *out_re/*out_im receive malloc'd copies of the kept range and
     * *output_length its size. Returns false, leaving the outputs untouched, for fewer
     * than 3 points or on allocation failure.
     */
    bool kiss_reduce_spectra(float *in_re, float *in_im, float *omegas, size_t input_length,
                             float thresh_factor, bool use_second_deriv,
                             float **out_re, float **out_im, size_t *output_length)
    {
    	if(input_length < 3)
    		return false;

    	kiss_normalize_spectra(in_re, in_im, input_length);

    	float *grads = malloc(sizeof(*grads)*input_length);
    	if(!grads)
    		return false;
    	for(size_t i = 0; i < input_length; ++i) {
    		float *grad = kiss_absgrad(in_re, in_im, omegas, input_length, i);
    		if(!grad) {
    			free(grads);
    			return false;
    		}
    		grads[i] = sqrtf(pow(grad[0], 2)+pow(grad[1], 2));
    		free(grad);
    	}

    	if(use_second_deriv) {
    		/* Write into a separate buffer. kiss_grad() reads grads[i-1], which an in-place
    		 * update would already have overwritten with a second-derivative value. */
    		float *second = malloc(sizeof(*second)*input_length);
    		if(!second) {
    			free(grads);
    			return false;
    		}
    		for(size_t i = 0; i < input_length; ++i)
    			second[i] = fabsf(kiss_grad(grads, omegas, input_length, i));
    		free(grads);
    		grads = second;
    	}

    	float grad_thresh;
    	if(!use_second_deriv)
    		grad_thresh = kiss_median(grads, input_length)*thresh_factor;
    	else
    		grad_thresh = 1e-12f*thresh_factor;

    	size_t start = 0;
    	for(size_t i = 1; i < input_length-1; ++i) {
    		if(grads[i] < grad_thresh)
    			start = i;
    		else
    			break;
    	}

    	size_t end = input_length-1;
    	for(size_t i = input_length-1; i > 1; --i) {
    		if(grads[i] < grad_thresh)
    			end = i;
    		else
    			break;
    	}

    	free(grads);

    	/* start >= end only when every gradient from index 1 onward is below the threshold.
    	 * Keep the whole spectrum then, instead of underflowing end-start. */
    	if(start >= end) {
    		start = 0;
    		end = input_length-1;
    	}

    	const size_t kept = end-start+1;
    	float *re = malloc(sizeof(*re)*kept);
    	float *im = malloc(sizeof(*im)*kept);
    	if(!re || !im) {
    		free(re);
    		free(im);
    		return false;
    	}
    	memcpy(re, in_re+start, sizeof(*re)*kept);
    	memcpy(im, in_im+start, sizeof(*im)*kept);

    	*out_re = re;
    	*out_im = im;
    	*output_length = kept;
    	return true;
    }
    
    /*
     * Prepares a spectrum for inference in two steps:
     *  1. trim its flat ends with kiss_reduce_spectra() (thresh_factor 0.01, first derivative);
     *  2. resample the rest to output_length points with kiss_resample_spectra().
     * Note that in_re/in_im are normalized in place by kiss_reduce_spectra().
     * Returns false when kiss_reduce_spectra() reports failure. Otherwise *out_re/*out_im
     * hold the resampled arrays, owned by the caller.
     */
    bool kiss_filter_spectra(float *in_re, float *in_im, float *omegas, size_t input_length, float **out_re, float **out_im, size_t output_length)
    {
    	float *reduced_re = NULL;
    	float *reduced_im = NULL;
    	size_t reduced_length = 0;

    	bool reduced = kiss_reduce_spectra(in_re, in_im, omegas, input_length, 0.01, false,
    	                                   &reduced_re, &reduced_im, &reduced_length);
    	if(reduced)
    		kiss_resample_spectra(reduced_re, reduced_im, reduced_length, out_re, out_im, output_length);

    	/* The trimmed intermediate copies are never handed out. */
    	free(reduced_re);
    	free(reduced_im);

    	return reduced;
    }
    
    void kiss_softmax(float *data, size_t input_length)
    {
    	float accum = 0;
    	for(size_t i = 0; i < input_length; ++i) {
    		data[i] = powf(M_E, data[i]);
    		accum += data[i];
    	}
    
    	for(size_t i = 0; i < input_length; ++i)
    		data[i] /= accum;
    }
    
    const char *kiss_get_strerror(struct kiss_network *net)
    {
    	return net->priv->err;
    }
    
    /*
     * Approximate float equality. True when a == b, or when
     * |a - b| <= FLT_EPSILON * |a + b| * ulp.
     * FLT_EPSILON is nextafterf(1.0f, INFINITY) - 1.0f.
     * Equal infinities compare equal; an infinity never equals a different value;
     * NaN equals nothing.
     * The tolerance is computed in double so that a + b cannot overflow.
     */
    bool kiss_float_eq(float a, float b, unsigned int ulp)
    {
    	if(a == b)
    		return true;
    	if(isinf(a) || isinf(b))
    		return false;

    	double tolerance = (double)FLT_EPSILON * fabs((double)a + (double)b) * ulp;
    	return fabs((double)a - (double)b) <= tolerance;
    }
    
    /*
     * Releases memory returned by this library, e.g. the result arrays handed to result_cb
     * or produced by the spectrum helpers.
     * It forwards to free() from the library's own C runtime; NULL is accepted.
     */
    void kiss_free(void *data)
    {
    	free(data);
    }