Skip to content
Snippets Groups Projects
Select Git revision
  • 0f49232b0f3e30815bf9952d66ebb735f5568c5b
  • master default protected
  • file_refactoring
  • 1.1.0
4 results

network.cs

Blame
  • user avatar
    Carl Philipp Klemm authored
    These functions check if the lib is able to service requests as currently loaded
    Also use these functions to check if everything is ok at entry points to the various other classes
    add Utils.FloatEq() that uses the same logic as libeisgenerator to compare float values
    0f49232b
    History
    network.cs 6.53 KiB
    using System;
    using System.Runtime.InteropServices;
    using System.Collections.Generic;
    using System.Threading;
    
    namespace Kiss
    {
    
    /// <summary>
    /// This is the main class of kissinference; it allows you to load and use Kiss neural networks.
    /// </summary>
    /// <remarks>
    /// Inference is asynchronous: results are delivered through a <see cref="ResultDlg"/> that is invoked from a
    /// native thread of libkissinference's thread pool, so the bookkeeping of in-flight requests is guarded by a lock.
    /// </remarks>
    public class Network: System.IDisposable
    {
    	/// <summary>
    	/// Delegate invoked with the network's output once an asynchronous inference pass completes.
    	/// </summary>
    	/// <param name="result">the output values, OutputSize in length</param>
    	/// <param name="network">the network that produced the result</param>
    	public delegate void ResultDlg(float[] result, Network network);
    
    	// Bookkeeping for one outstanding inference request: the user callback and the event
    	// that is signalled once that callback has run.
    	private struct Flight
    	{
    		public ResultDlg Callback;
    		public AutoResetEvent Signal;
    	}
    
    	private CNetwork net;
    	private bool disposed = false;
    	// Guards inflight and flightcounter: Run() mutates them on the caller's thread while
    	// CResultCb() mutates them on a native thread-pool thread.
    	private readonly object inflightLock = new object();
    	private Dictionary<IntPtr, Flight> inflight = new Dictionary<IntPtr, Flight>();
    	private IntPtr flightcounter = (IntPtr)0;
    
    	// Native completion callback registered in the constructor. Copies the native result buffer into a
    	// managed array, frees the native buffer and dispatches to the callback registered for request id `data`.
    	// NOTE(review): the delegate created from this method group in the constructor is not stored anywhere,
    	// so nothing keeps it reachable while native code holds its function pointer; consider keeping a
    	// reference to it in a field for the lifetime of this object.
    	private void CResultCb(IntPtr result, ref CNetwork network, IntPtr data)
    	{
    		int outputSize = (int)net.outputSize;
    		var managedResult = new float[outputSize];
    		Marshal.Copy(result, managedResult, 0, outputSize);
    		Capi.free(result);
    
    		Flight flight;
    		lock(inflightLock)
    		{
    			flight = inflight[data];
    			inflight.Remove(data);
    		}
    
    		try
    		{
    			flight.Callback(managedResult, this);
    		}
    		finally
    		{
    			// Release waiters even if the user callback throws, otherwise they would block forever.
    			flight.Signal.Set();
    		}
    	}
    
    	/// <summary>
    	/// Constructs a new Network object by loading a network off disk
    	/// </summary>
    	/// <param name="path">
    	/// the path to the onnx network file to load
    	/// </param>
    	/// <param name="verbose">
    	/// if true is set here some extra debug information will be printed by the runtime
    	/// </param>
    	/// <exception cref="System.IO.FileLoadException">
    	/// Thrown when the network file could not be loaded
    	/// </exception>
    	/// <exception cref="Kiss.EnviromentException">
    	/// Thrown if the environment is unacceptable for the operation of this library
    	/// </exception>
    	public Network(string path, bool verbose = false)
    	{
    		Utils.CheckEnvThrow();
    		byte ret = Capi.kiss_load_network_prealloc(ref net, path, CResultCb, Convert.ToByte(verbose));
    		// NOTE(review): when loading fails the finalizer still runs Dispose(false) on this partially
    		// constructed object and calls kiss_free_network_prealloc on net; confirm that is safe.
    		if(ret == 0)
    			throw new System.IO.FileLoadException(getError());
    	}
    
    	/// <summary>
    	/// This method runs an inference pass on the given spectra asynchronously
    	/// </summary>
    	/// <param name="spectra">
    	/// the spectra to run inference on, must be of InputSize length
    	/// </param>
    	/// <param name="callback">
    	/// a delegate that will be called when inference completes, this callback will be called from a native thread in libkissinference's thread pool
    	/// </param>
    	/// <exception cref="ArgumentNullException">
    	/// Thrown when callback is null
    	/// </exception>
    	/// <exception cref="ArgumentException">
    	/// Thrown when the input is of an incorrect size for the network
    	/// </exception>
    	/// <exception cref="InferenceException">
    	/// Thrown when the inference engine encounters an internal error
    	/// </exception>
    	/// <exception cref="ObjectDisposedException">
    	/// Thrown when the network has already been disposed
    	/// </exception>
    	/// <returns>
    	/// The AutoResetEvent returned can be waited on to ensure that the inference is finished and the callback has been executed.
    	/// </returns>
    	public AutoResetEvent Run(Spectra spectra, ResultDlg callback)
    	{
    		ThrowIfDisposed();
    		if(callback == null)
    			throw new ArgumentNullException(nameof(callback));
    		// NOTE(review): InputSize may be -1 ("Any") for convolutional networks, in which case this check
    		// rejects every input; confirm whether such networks are meant to go through this overload.
    		if(spectra.Real.Length*2 != (int)net.inputSize)
    			throw new ArgumentException("Spectra is the wrong size for network, use resample() first");
    
    		var signal = new AutoResetEvent(false);
    		IntPtr id = RegisterFlight(callback, signal);
    		byte ret = Capi.kiss_async_run_inference_complex(ref net, spectra.Real, spectra.Imaginary, id);
    		if(ret == 0)
    		{
    			// The request never started: drop its bookkeeping so the id and event are not leaked.
    			UnregisterFlight(id);
    			signal.Dispose();
    			throw new InferenceException(getError());
    		}
    		return signal;
    	}
    
    	/// <summary>
    	/// This method runs an inference pass on the given data asynchronously
    	/// </summary>
    	/// <param name="data">
    	/// an array of floats containing the data to run inference on, must be of InputSize length
    	/// </param>
    	/// <param name="callback">
    	/// a delegate that will be called when inference completes, this callback will be called from a native thread in libkissinference's thread pool
    	/// </param>
    	/// <exception cref="ArgumentNullException">
    	/// Thrown when data or callback is null
    	/// </exception>
    	/// <exception cref="ArgumentException">
    	/// Thrown when the input is of an incorrect size for the network
    	/// </exception>
    	/// <exception cref="InferenceException">
    	/// Thrown when the inference engine encounters an internal error
    	/// </exception>
    	/// <exception cref="ObjectDisposedException">
    	/// Thrown when the network has already been disposed
    	/// </exception>
    	/// <returns>
    	/// The AutoResetEvent returned can be waited on to ensure that the inference is finished and the callback has been executed.
    	/// </returns>
    	public AutoResetEvent Run(float[] data, ResultDlg callback)
    	{
    		ThrowIfDisposed();
    		if(data == null)
    			throw new ArgumentNullException(nameof(data));
    		if(callback == null)
    			throw new ArgumentNullException(nameof(callback));
    		if(data.Length != (int)net.inputSize)
    			throw new ArgumentException("Data is the wrong size for network");
    
    		var signal = new AutoResetEvent(false);
    		IntPtr id = RegisterFlight(callback, signal);
    		byte ret = Capi.kiss_async_run_inference(ref net, data, id);
    		if(ret == 0)
    		{
    			// The request never started: drop its bookkeeping so the id and event are not leaked.
    			UnregisterFlight(id);
    			signal.Dispose();
    			throw new InferenceException(getError());
    		}
    		return signal;
    	}
    
    	// Allocates a fresh request id and records the callback/signal pair for it. The id is reserved
    	// under the lock so concurrent Run() calls can never hand the same id to the native library.
    	private IntPtr RegisterFlight(ResultDlg callback, AutoResetEvent signal)
    	{
    		var flight = new Flight();
    		flight.Callback = callback;
    		flight.Signal = signal;
    		lock(inflightLock)
    		{
    			IntPtr id = flightcounter;
    			flightcounter += 1;
    			inflight.Add(id, flight);
    			return id;
    		}
    	}
    
    	// Removes the bookkeeping for a request the native library refused to start.
    	private void UnregisterFlight(IntPtr id)
    	{
    		lock(inflightLock)
    		{
    			inflight.Remove(id);
    		}
    	}
    
    	// Guards entry points against using the native network after it has been freed.
    	private void ThrowIfDisposed()
    	{
    		if(disposed)
    			throw new ObjectDisposedException(nameof(Network));
    	}
    
    	// Decodes a NUL-terminated UTF-8 string from a native pointer; returns null for a null pointer.
    	private static string PtrToUtf8String(IntPtr ptr)
    	{
    		if(ptr == IntPtr.Zero)
    			return null;
    #if (NET6_0_OR_GREATER)
    		return Marshal.PtrToStringUTF8(ptr);
    #else
    		return Capi.UTF8StringFromPointer(ptr);
    #endif
    	}
    
    	/// <summary>
    	/// Returns the error string libkissinference reports for the last error on this network.
    	/// It is used to build the messages of the exceptions thrown by this class.
    	/// </summary>
    	/// <returns>
    	/// A string describing the last reported error.
    	/// </returns>
    	private string getError()
    	{
    		return PtrToUtf8String(Capi.kiss_get_strerror(ref net));
    	}
    
    	/// <summary>
    	/// Frees the native network. Calling this more than once has no further effect.
    	/// </summary>
    	public void Dispose()
    	{
    		Dispose(true);
    		GC.SuppressFinalize(this);
    	}
    
    	/// <summary>
    	/// Frees the native network the first time it is called.
    	/// </summary>
    	/// <param name="disposing">true when called from Dispose(), false when called from the finalizer</param>
    	protected virtual void Dispose(bool disposing)
    	{
    		if(!disposed)
    		{
    			disposed = true;
    			// NOTE(review): requests still in flight are not waited for here; confirm that
    			// kiss_free_network_prealloc handles outstanding asynchronous inference safely.
    			Capi.kiss_free_network_prealloc(ref net);
    		}
    	}
    
    	~Network()
    	{
    		Dispose(false);
    	}
    
    	/// <summary>
    	/// This Property holds the number of input values the loaded network expects.
    	/// For networks with a convolutional input this may be -1 designating "Any"
    	/// </summary>
    	public int InputSize
    	{
    		get{return (int)net.inputSize;}
    	}
    
    	/// <summary>
    	/// This Property holds the number of output values the loaded network produces.
    	/// </summary>
    	public int OutputSize
    	{
    		get{return (int)net.outputSize;}
    	}
    
    	/// <summary>
    	/// This Property holds a string describing the purpose the network was trained for.
    	/// </summary>
    	public string Purpose
    	{
    		get{return PtrToUtf8String(net.purpose);}
    	}
    
    	/// <summary>
    	/// This Property holds a string describing the type of input the network expects.
    	/// </summary>
    	public string InputLabel
    	{
    		get{return PtrToUtf8String(net.inputLabel);}
    	}
    
    	/// <summary>
    	/// This Property is true if the network expects a complex valued input
    	/// </summary>
    	public bool ComplexInput
    	{
    		get{return Convert.ToBoolean(net.complexInput);}
    	}
    
    	/// <summary>
    	/// This Property is an array of OutputSize strings naming what every output corresponds to
    	/// </summary>
    	public string[] OutputLabels
    	{
    		get{return Capi.IntPtrToUtf8Array(net.outputLabels);}
    	}
    }
    
    }