Select Git revision
network.cs 6.53 KiB
using System;
using System.Runtime.InteropServices;
using System.Collections.Generic;
using System.Threading;
namespace Kiss
{
/// <summary>
/// This is the main class of kissinference, it allows you to load and use Kiss neural networks
/// </summary>
public class Network: System.IDisposable
{
    /// <summary>
    /// Callback invoked when an asynchronous inference started by Run completes.
    /// </summary>
    /// <param name="result">
    /// the output of the network, OutputSize values long
    /// </param>
    /// <param name="network">
    /// the Network instance that ran the inference
    /// </param>
    public delegate void ResultDlg(float[] result, Network network);

    // Bookkeeping for one pending asynchronous inference request.
    private struct Flight
    {
        public ResultDlg Callback;
        public AutoResetEvent Signal;
    }

    // NOTE(review): net is passed by ref to the native library at load time and a CNetwork is handed back
    // by ref in CResultCb; if native code retains the address of this field, a GC relocation of this object
    // would invalidate it. Confirm against libkissinference whether the struct must live in pinned/unmanaged memory.
    private CNetwork net;
    private bool disposed = false;
    // Pending requests keyed by the user-data value passed to the native library.
    // Accessed by callers of Run and by native worker threads (CResultCb), so every access
    // to inflight and flightcounter must hold inflightLock.
    private readonly Dictionary<IntPtr, Flight> inflight = new Dictionary<IntPtr, Flight>();
    private readonly object inflightLock = new object();
    private IntPtr flightcounter = (IntPtr)0;

    // Invoked by libkissinference when an asynchronous inference finishes.
    // result is a native buffer of OutputSize floats that this method must free; data is the key chosen in Submit.
    // NOTE(review): this method is handed to the native library as a delegate created implicitly in the
    // constructor. Unless Capi keeps a reference to it, the GC may collect that delegate while native code
    // still holds its function pointer. Consider storing it in a field once Capi's delegate type is confirmed.
    private void CResultCb(IntPtr result, ref CNetwork network, IntPtr data)
    {
        float[] managedResult = new float[(int)net.outputSize];
        try
        {
            Marshal.Copy(result, managedResult, 0, managedResult.Length);
        }
        finally
        {
            // The native buffer is ours to release regardless of whether the copy succeeded.
            Capi.free(result);
        }

        Flight flight;
        bool found;
        lock(inflightLock)
        {
            found = inflight.TryGetValue(data, out flight);
            if(found)
                inflight.Remove(data);
        }

        // Unknown key: nobody is waiting for this result. Throwing here would unwind into native code.
        if(!found)
            return;

        try
        {
            flight.Callback(managedResult, this);
        }
        finally
        {
            // Release anyone waiting on the event even if the user callback throws.
            flight.Signal.Set();
        }
    }

    /// <summary>
    /// Constructs a new Network object by loading a network off disk
    /// </summary>
    /// <param name="path">
    /// the path to the onnx network file to load
    /// </param>
    /// <param name="verbose">
    /// if true is set here some extra debug information will be printed by the runtime
    /// </param>
    /// <exception cref="System.IO.FileLoadException">
    /// Thrown when the network file could not be loaded
    /// </exception>
    /// <exception cref="Kiss.EnviromentException">
    /// Thrown if the environment is unacceptable for the operation of this library
    /// </exception>
    public Network(string path, bool verbose = false)
    {
        try
        {
            Utils.CheckEnvThrow();
        }
        catch
        {
            // Nothing has been loaded yet and the native library may be unusable in this environment,
            // so keep the finalizer from calling kiss_free_network_prealloc.
            disposed = true;
            GC.SuppressFinalize(this);
            throw;
        }

        byte ret = Capi.kiss_load_network_prealloc(ref net, path, CResultCb, Convert.ToByte(verbose));
        if(ret == 0)
            throw new System.IO.FileLoadException(getError());
    }

    /// <summary>
    /// This method runs an inference pass on the given spectra asynchronously
    /// </summary>
    /// <param name="spectra">
    /// the spectra to run inference on; its Real and Imaginary parts must be of equal length
    /// and together contain InputSize values (i.e. each part is InputSize / 2 long)
    /// </param>
    /// <param name="callback">
    /// a delegate that will be called when inference completes, this callback will be called from a native thread in libkissinference's thread pool.
    /// Exceptions escaping the callback propagate into that native thread.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// Thrown when callback is null
    /// </exception>
    /// <exception cref="ArgumentException">
    /// Thrown when the input is of an incorrect size for the network
    /// </exception>
    /// <exception cref="ObjectDisposedException">
    /// Thrown when the network has already been disposed
    /// </exception>
    /// <exception cref="InferenceException">
    /// Thrown when the inference engine encounters an internal error
    /// </exception>
    /// <returns>
    /// The AutoResetEvent returned can be waited on to ensure that the inference is finished and the callback has been executed.
    /// </returns>
    public AutoResetEvent Run(Spectra spectra, ResultDlg callback)
    {
        ThrowIfDisposed();
        if(callback == null)
            throw new ArgumentNullException(nameof(callback));
        if(spectra.Real.Length*2 != (int)net.inputSize)
            throw new ArgumentException("Spectra is the wrong size for network, use resample() first");
        // Both arrays are handed to native code without a length, so they must match exactly.
        if(spectra.Imaginary.Length != spectra.Real.Length)
            throw new ArgumentException("Spectra real and imaginary parts must be of equal length");
        return Submit(callback, key => Capi.kiss_async_run_inference_complex(ref net, spectra.Real, spectra.Imaginary, key));
    }

    /// <summary>
    /// This method runs an inference pass on the given data asynchronously
    /// </summary>
    /// <param name="data">
    /// an array of floats containing the data to run inference on, must be of InputSize length
    /// </param>
    /// <param name="callback">
    /// a delegate that will be called when inference completes, this callback will be called from a native thread in libkissinference's thread pool.
    /// Exceptions escaping the callback propagate into that native thread.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// Thrown when data or callback is null
    /// </exception>
    /// <exception cref="ArgumentException">
    /// Thrown when the input is of an incorrect size for the network
    /// </exception>
    /// <exception cref="ObjectDisposedException">
    /// Thrown when the network has already been disposed
    /// </exception>
    /// <exception cref="InferenceException">
    /// Thrown when the inference engine encounters an internal error
    /// </exception>
    /// <returns>
    /// The AutoResetEvent returned can be waited on to ensure that the inference is finished and the callback has been executed.
    /// </returns>
    public AutoResetEvent Run(float[] data, ResultDlg callback)
    {
        ThrowIfDisposed();
        if(data == null)
            throw new ArgumentNullException(nameof(data));
        if(callback == null)
            throw new ArgumentNullException(nameof(callback));
        // NOTE(review): when InputSize is -1 ("Any") this check rejects every input; the native call
        // receives no length, so that case is not supported by this wrapper as written.
        if(data.Length != (int)net.inputSize)
            throw new ArgumentException("Data is the wrong size for network");
        return Submit(callback, key => Capi.kiss_async_run_inference(ref net, data, key));
    }

    // Registers a pending request under a fresh key, starts it via start(key), and returns its completion event.
    // On native failure the registration is rolled back so the key is not left behind in inflight.
    private AutoResetEvent Submit(ResultDlg callback, Func<IntPtr, byte> start)
    {
        var flight = new Flight();
        flight.Callback = callback;
        flight.Signal = new AutoResetEvent(false);

        IntPtr key;
        lock(inflightLock)
        {
            key = flightcounter;
            flightcounter += 1;
            // Registered before starting so a fast native completion always finds its entry.
            inflight.Add(key, flight);
        }

        byte ret = start(key);
        if(ret == 0)
        {
            // Read the error before cleanup so it reflects this failure as closely as possible.
            string error = getError();
            lock(inflightLock)
            {
                inflight.Remove(key);
            }
            flight.Signal.Dispose();
            throw new InferenceException(error);
        }
        return flight.Signal;
    }

    /// <summary>
    /// Returns the error string for the last error to have occurred in libkissinference.
    /// Used internally to build the messages of the exceptions thrown by this class.
    /// </summary>
    /// <returns>
    /// A string describing the last reported error, or null if the library reports none.
    /// </returns>
    private string getError()
    {
        return PtrToUtf8String(Capi.kiss_get_strerror(ref net));
    }

    // Decodes a NUL-terminated UTF-8 string owned by the native library; IntPtr.Zero yields null.
    private static string PtrToUtf8String(IntPtr ptr)
    {
        if(ptr == IntPtr.Zero)
            return null;
#if (NET6_0_OR_GREATER)
        return Marshal.PtrToStringUTF8(ptr);
#else
        return Capi.UTF8StringFromPointer(ptr);
#endif
    }

    // Guards members that would otherwise touch native state already released by Dispose.
    private void ThrowIfDisposed()
    {
        if(disposed)
            throw new ObjectDisposedException(GetType().FullName);
    }

    /// <summary>
    /// Releases the native network. Further calls to Run or to the string properties throw ObjectDisposedException.
    /// </summary>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Releases the native network exactly once.
    /// </summary>
    /// <param name="disposing">
    /// true when called from Dispose(), false when called from the finalizer
    /// </param>
    protected virtual void Dispose(bool disposing)
    {
        if(disposed)
            return;
        disposed = true;
        // NOTE(review): requests still in flight are not awaited here; confirm that
        // kiss_free_network_prealloc waits for (or cancels) outstanding native work.
        Capi.kiss_free_network_prealloc(ref net);
    }

    ~Network()
    {
        Dispose(false);
    }

    /// <summary>
    /// This Property holds the number of input values the loaded network expects.
    /// For networks with a convolutional input this may be -1 designating "Any"
    /// </summary>
    public int InputSize
    {
        get{return (int)net.inputSize;}
    }

    /// <summary>
    /// This Property holds the number of output values the loaded network produces.
    /// </summary>
    public int OutputSize
    {
        get{return (int)net.outputSize;}
    }

    /// <summary>
    /// This Property holds a string describing the purpose the network was trained on.
    /// </summary>
    public string Purpose
    {
        get
        {
            ThrowIfDisposed();
            return PtrToUtf8String(net.purpose);
        }
    }

    /// <summary>
    /// This Property holds a string describing the type of input the network expects.
    /// </summary>
    public string InputLabel
    {
        get
        {
            ThrowIfDisposed();
            return PtrToUtf8String(net.inputLabel);
        }
    }

    /// <summary>
    /// This Property is true if the network expects a complex valued input
    /// </summary>
    public bool ComplexInput
    {
        get{return Convert.ToBoolean(net.complexInput);}
    }

    /// <summary>
    /// This Property is an array of OutputSize strings naming what every output corresponds to
    /// </summary>
    public string[] OutputLabels
    {
        get
        {
            ThrowIfDisposed();
            return Capi.IntPtrToUtf8Array(net.outputLabels);
        }
    }
}
}