diff --git a/src/capi.cs b/src/capi.cs index dd875d8b74e050d4808b4b71fbf4f3e2a46809d3..98341e2fded02d5e3fbfa70f085cb794af0a47ca 100644 --- a/src/capi.cs +++ b/src/capi.cs @@ -14,6 +14,7 @@ public struct CNetwork private IntPtr _private; public UIntPtr inputSize; public UIntPtr outputSize; + public IntPtr outputMask; public IntPtr purpose; public IntPtr inputLabel; public IntPtr outputLabels; @@ -52,12 +53,12 @@ internal class Capi public static extern byte kiss_reduce_spectra(float[] in_re, float[] in_im, float[] omegas, UIntPtr input_length, float thresh_factor, byte use_second_deriv, ref IntPtr out_re, ref IntPtr out_im, ref UIntPtr output_length); - [DllImport("kissinference", CharSet = CharSet.Unicode, SetLastError = true, CallingConvention=CallingConvention.StdCall)] + [DllImport("kissinference", SetLastError = true, CallingConvention=CallingConvention.StdCall)] public static extern byte kiss_filter_spectra(float[] in_re, float[] in_im, float[] omegas, UIntPtr input_length, ref IntPtr out_re, ref IntPtr out_im, UIntPtr output_length); - [DllImport("kissinference", CharSet = CharSet.Unicode, SetLastError = true, CallingConvention=CallingConvention.StdCall)] - public static extern byte kiss_load_network_prealloc(ref CNetwork net, string path, resultDlg callback, byte verbose); + [DllImport("kissinference", SetLastError = true, CallingConvention=CallingConvention.StdCall)] + public static extern byte kiss_load_network_prealloc(ref CNetwork net, byte[] path, resultDlg callback, byte verbose); [DllImport("kissinference", SetLastError = true, CallingConvention=CallingConvention.StdCall)] public static extern void kiss_free_network_prealloc(ref CNetwork net); @@ -68,6 +69,9 @@ internal class Capi [DllImport("kissinference", SetLastError = true, CallingConvention=CallingConvention.StdCall)] public static extern byte kiss_async_run_inference_complex(ref CNetwork net, [Out] float[] real, [Out] float[] imag, IntPtr data); + [DllImport("kissinference", CharSet = CharSet.Ansi, SetLastError = true, CallingConvention=CallingConvention.StdCall)] + public static extern byte kiss_set_output_mask(ref CNetwork net, [Out] byte[] output_mask); + [DllImport("kissinference", SetLastError = true, CallingConvention=CallingConvention.StdCall)] public static extern IntPtr kiss_get_strerror(ref CNetwork net); @@ -117,6 +121,22 @@ internal class Capi return System.Text.Encoding.UTF8.GetString(bytes.ToArray()); } + + public static bool[] IntPtrToBoolArray(IntPtr ptr, int size) + { + var output = new bool[size]; + for(int i = 0; i < size; ++i) + output[i] = Convert.ToBoolean(Marshal.ReadByte(ptr, i)); + return output; + } + + public static byte[] StringToPlatformBytes(String input) + { + if(RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + return System.Text.Encoding.Unicode.GetBytes(input); + else + return System.Text.Encoding.UTF8.GetBytes(input); + } } } diff --git a/src/network.cs b/src/network.cs index d06d866d372b60b3b249339a9215473db63711ed..4e64a9b604134cd25a1f8252b8900c29e5a7bb9c 100644 --- a/src/network.cs +++ b/src/network.cs @@ -53,7 +53,7 @@ public class Network: System.IDisposable public Network(string path, bool verbose = false) { Utils.CheckEnvThrow(); - byte ret = Capi.kiss_load_network_prealloc(ref net, path, CResultCb, Convert.ToByte(verbose)); + byte ret = Capi.kiss_load_network_prealloc(ref net, Capi.StringToPlatformBytes(path), CResultCb, Convert.ToByte(verbose)); Console.WriteLine("ret: {0:D}", ret); if(ret == 0) throw new System.IO.FileLoadException(getError()); @@ -211,6 +211,34 @@ public class Network: System.IDisposable { get{return Capi.IntPtrToUtf8Array(net.outputLabels);} } + + /// <summary> + /// This Property is an array of bools enableing or disableing a given output. + /// + /// This Property must only be set while no inference requests are pending. + /// Disabled outputs will set to zero, exact effect on non-disabled outputs is network depenant. + /// For classifier networks the likelyhoods will reflect the reduction in problem space. + /// </summary> + public bool[] OutputMask + { + get + { + return Capi.IntPtrToBoolArray(net.outputMask, OutputSize); + } + set + { + if(inflight.Count > 0) + throw new InferenceException("OutputMask can not be set while there are inference requests in flight"); + if(value.Length != OutputSize) + throw new ArgumentException("The given array needs to be the same length as the networks OutputSize"); + var bytes = new byte[OutputSize]; + for(int i = 0; i < OutputSize; ++i) + bytes[i] = Convert.ToByte(value[i]); + byte ret = Capi.kiss_set_output_mask(ref net, bytes); + if(ret == 0) + throw new ArgumentException(getError()); + } + } } } diff --git a/src/test.cs b/src/test.cs index 83811b09a224df428cd0f9f3ac3088c02360078b..90e84339da503f7fce66003ec7bfb84cf53878e3 100644 --- a/src/test.cs +++ b/src/test.cs @@ -1,5 +1,6 @@ using System; using System.Threading; +using System.Linq; using Kiss; public class Program @@ -10,6 +11,14 @@ public class Program Console.WriteLine("{0:F}+{1:F}i", spectra.Real[i], spectra.Imaginary[i]); } + private static void PrintOutputMask(Kiss.Network network) + { + var outputMask = network.OutputMask; + var labels = network.OutputLabels; + for(int i = 0; i < outputMask.Length; ++i) + Console.WriteLine("{0:S}: {1:B}", labels[i], outputMask[i]); + } + public static void Result(float[] result, Kiss.Network network) { Console.WriteLine("Thread id: {0:D}", System.Environment.CurrentManagedThreadId); @@ -57,7 +66,17 @@ public class Program PrintSpectra(spectra); var signal = net.Run(spectra, Result); signal.WaitOne(); - Console.WriteLine("Awaited"); + Console.WriteLine("Awaited\n"); + + PrintOutputMask(net); + var newMask = Enumerable.Repeat(true, net.OutputSize).ToArray(); + newMask[7] = false; + net.OutputMask = newMask; + Console.WriteLine(""); + + signal = net.Run(spectra, Result); + signal.WaitOne(); + Console.WriteLine("Awaited\n"); } catch(System.IO.FileLoadException e) { diff --git a/src/utils.cs b/src/utils.cs index 03682989118d115aedbb9ec6a7c6b5a650615685..28bc84293bb918a676803a02a401275b7a83a9c8 100644 --- a/src/utils.cs +++ b/src/utils.cs @@ -114,8 +114,8 @@ public class Utils return "libkissinferencesharp only works on X86 or X64 cpu architecture"; var version = VersionFixed.GetVersionNoCheck(); - if(version.Major != 1 || version.Minor != 2) - return "This version of libkissinferencesharp only supports native code backends of version 1.2.x, but the loaded libary is " + version; + if(version.Major != 1 || version.Minor != 3) + return "This version of libkissinferencesharp only supports native code backends of version 1.3.x, but the loaded libary is " + version; EnviromentChecked = true; return string.Empty;