Skip to content
Snippets Groups Projects
Commit fca970ad authored by Carl Philipp Klemm's avatar Carl Philipp Klemm
Browse files

improve error handling throughout

parent c4a21f6f
No related branches found
No related tags found
No related merge requests found
......@@ -108,15 +108,34 @@ static char **kiss_parse_output_lables(char *output_labels, size_t *token_count)
static int64_t kiss_get_tensor_size(const OrtTensorTypeAndShapeInfo *info, const struct OrtApi *api)
{
size_t sizes_size;
OrtStatus *status;
enum ONNXTensorElementDataType type;
assert(api->GetTensorElementType(info, &type) == NULL);
assert(type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
status = api->GetTensorElementType(info, &type);
if(status) {
api->ReleaseStatus(status);
return -1;
}
else if(type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
return -1;
}
status = api->GetDimensionsCount(info, &sizes_size);
if(status) {
api->ReleaseStatus(status);
return -1;
}
assert(api->GetDimensionsCount(info, &sizes_size) == NULL);
int64_t *sizes = malloc(sizeof(*sizes)*sizes_size);
assert(api->GetDimensions(info, sizes, sizes_size) == NULL);
assert(sizes_size == 2);
status = api->GetDimensions(info, sizes, sizes_size);
if(status) {
api->ReleaseStatus(status);
return -1;
}
if(sizes_size != 2)
return -1;
int64_t size = sizes[1];
free(sizes);
return size;
......@@ -143,6 +162,8 @@ bool kiss_load_network_prealloc(struct kiss_network* net, const char *path, void
char *is_softmax;
OrtAllocator *allocator;
size_t count;
int64_t input_size;
int64_t output_size;
net->priv = calloc(1, sizeof(*net->priv));
net->priv->base_api = OrtGetApiBase();
......@@ -179,8 +200,12 @@ bool kiss_load_network_prealloc(struct kiss_network* net, const char *path, void
goto exit_failue;
}
assert(api->SessionGetInputCount(net->priv->session, &count) == NULL);
if(count != 1) {
status = api->SessionGetInputCount(net->priv->session, &count);
if(status) {
snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get input count: %s", api->GetErrorMessage(status));
goto exit_failue;
}
else if(count != 1) {
snprintf(net->priv->err, KISS_STRERROR_LEN, "Expected model with input count 1 but got %zu", count);
goto exit_failue;
}
......@@ -196,18 +221,36 @@ bool kiss_load_network_prealloc(struct kiss_network* net, const char *path, void
snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get input type info: %s", api->GetErrorMessage(status));
goto exit_failue;
}
assert(api->CastTypeInfoToTensorInfo(type_info, &tensor_info) == NULL && tensor_info);
net->input_size = kiss_get_tensor_size(tensor_info, api);
status = api->CastTypeInfoToTensorInfo(type_info, &tensor_info);
if(status || !tensor_info) {
snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to cast type_info to tensor_info: %s", api->GetErrorMessage(status));
goto exit_failue;
}
input_size = kiss_get_tensor_size(tensor_info, api);
api->ReleaseTypeInfo(type_info);
if(input_size < 1) {
snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get input size");
goto exit_failue;
}
net->input_size = input_size;
status = api->SessionGetOutputTypeInfo(net->priv->session, 0, &type_info);
if(status) {
snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get output type info: %s", api->GetErrorMessage(status));
goto exit_failue;
}
assert(api->CastTypeInfoToTensorInfo(type_info, &tensor_info) == NULL && tensor_info);
net->output_size = kiss_get_tensor_size(tensor_info, api);
status = api->CastTypeInfoToTensorInfo(type_info, &tensor_info);
if(status || !tensor_info) {
snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to cast type_info to tensor_info: %s", status ? api->GetErrorMessage(status) : "Unkown");
goto exit_failue;
}
output_size = kiss_get_tensor_size(tensor_info, api);
api->ReleaseTypeInfo(type_info);
if(output_size < 1) {
snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get output size");
goto exit_failue;
}
net->output_size = output_size;
status = api->SessionGetInputName(net->priv->session, 0, allocator, &input_name);
if(status) {
......@@ -220,8 +263,12 @@ bool kiss_load_network_prealloc(struct kiss_network* net, const char *path, void
printf("Got network with input name: %s\n", net->input_label);
net->complex_input = kiss_check_complex(net->input_label);
assert(api->SessionGetOutputCount(net->priv->session, &count) == NULL);
if(count != 1) {
status = api->SessionGetOutputCount(net->priv->session, &count);
if(status) {
snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get output count: %s", api->GetErrorMessage(status));
goto exit_failue;
}
else if(count != 1) {
snprintf(net->priv->err, KISS_STRERROR_LEN, "Expected model with output count 1 but got %zu", count);
goto exit_failue;
}
......@@ -308,15 +355,31 @@ static void kiss_run_cb(void *user_data, OrtValue **outputs, size_t num_outputs,
const struct OrtApi *api = net->priv->api;
void *kiss_user_data = req->user_data;
assert(num_outputs == 1);
if(num_outputs != 1) {
snprintf(net->priv->err, KISS_STRERROR_LEN, "Expected one output but got %zu", num_outputs);
net->result_cb(NULL, net, kiss_user_data);
}
if(outputs) {
OrtTensorTypeAndShapeInfo *info;
assert(api->GetTensorTypeAndShape(outputs[0], &info) == NULL);
assert(net->output_size == kiss_get_tensor_size(info, api));
OrtStatus *status;
status = api->GetTensorTypeAndShape(outputs[0], &info);
if(status) {
snprintf(net->priv->err, KISS_STRERROR_LEN, "Unable to get tnsor type and shape: %s\n", api->GetErrorMessage(status));
api->ReleaseStatus(status);
net->result_cb(NULL, net, kiss_user_data);
}
if (net->output_size != kiss_get_tensor_size(info, api)) {
snprintf(net->priv->err, KISS_STRERROR_LEN, "Output size and inference tensor result size are not the same");
net->result_cb(NULL, net, kiss_user_data);
}
float *data;
assert(api->GetTensorMutableData(outputs[0], (void**)&data) == NULL);
status = api->GetTensorMutableData(outputs[0], (void**)&data);
assert(!status);
float *output_floats = malloc(sizeof(*output_floats)*net->output_size);
memcpy(output_floats, data, net->output_size*sizeof(*data));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment