From b1be644413fc604246556a6ea6eb29d444f31891 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 15:12:14 +0100 Subject: [PATCH 01/75] Added tensors library --- Bonsai.ML.sln | 7 +++++++ .../Bonsai.ML.Tensors.csproj | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index c5a91b13..22b8a35a 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -30,6 +30,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.LinearDynamicalSy EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.HiddenMarkovModels.Design", "src\Bonsai.ML.HiddenMarkovModels.Design\Bonsai.ML.HiddenMarkovModels.Design.csproj", "{FC395DDC-62A4-4E14-A198-272AB05B33C7}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Tensors", "src\Bonsai.ML.Tensors\Bonsai.ML.Tensors.csproj", "{06FCC9AF-CE38-44BB-92B3-0D451BE88537}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -72,6 +74,10 @@ Global {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Debug|Any CPU.Build.0 = Debug|Any CPU {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.ActiveCfg = Release|Any CPU {FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.Build.0 = Release|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.Build.0 = Debug|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.ActiveCfg = Release|Any CPU + {06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -86,6 +92,7 @@ Global {39A4414F-52B1-42D7-82FA-E65DAD885264} = {12312384-8828-4786-AE19-EFCEDF968290} {A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13} = {12312384-8828-4786-AE19-EFCEDF968290} {17DF50BE-F481-4904-A4C8-5DF9725B2CA1} = {12312384-8828-4786-AE19-EFCEDF968290} + {06FCC9AF-CE38-44BB-92B3-0D451BE88537} = {12312384-8828-4786-AE19-EFCEDF968290} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {B6468F13-97CD-45E0-9E1E-C122D7F1E09F} diff --git a/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj b/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj new file mode 100644 index 00000000..2a0a76e2 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj @@ -0,0 +1,19 @@ + + + Bonsai.ML.Tensors + A Bonsai package for TorchSharp tensor manipulations. + Bonsai Rx ML Tensors TorchSharp + net472;netstandard2.0 + 12.0 + + + + + + + + + + + + \ No newline at end of file From 158f682126a98e142f48a33ab35f06e5eaba3237 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 15:12:24 +0100 Subject: [PATCH 02/75] Added arange function --- src/Bonsai.ML.Tensors/Arange.cs | 39 +++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Arange.cs diff --git a/src/Bonsai.ML.Tensors/Arange.cs b/src/Bonsai.ML.Tensors/Arange.cs new file mode 100644 index 00000000..e3c355d0 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Arange.cs @@ -0,0 +1,39 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Creates a 1-D tensor of values within a given range given the start, end, and step. + /// + [Combinator] + [Description("Creates a 1-D tensor of values within a given range given the start, end, and step.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Arange + { + /// + /// The start of the range. + /// + public int Start { get; set; } = 0; + + /// + /// The end of the range. + /// + public int End { get; set; } = 10; + + /// + /// The step of the range. + /// + public int Step { get; set; } = 1; + + /// + /// Generates an observable sequence of 1-D tensors created with the function. + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(arange(Start, End, Step))); + } + } +} From ceabd853974f3f7776e5b7399aa8448e35147adc Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 15:12:34 +0100 Subject: [PATCH 03/75] Added concat class --- src/Bonsai.ML.Tensors/Concat.cs | 44 +++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Concat.cs diff --git a/src/Bonsai.ML.Tensors/Concat.cs b/src/Bonsai.ML.Tensors/Concat.cs new file mode 100644 index 00000000..1a11eb0e --- /dev/null +++ b/src/Bonsai.ML.Tensors/Concat.cs @@ -0,0 +1,44 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Concatenates tensors along a given dimension. + /// + [Combinator] + [Description("Concatenates tensors along a given dimension.")] + [WorkflowElementCategory(ElementCategory.Combinator)] + public class Concat + { + /// + /// The dimension along which to concatenate the tensors. + /// + public long Dimension { get; set; } = 0; + + /// + /// Takes any number of observable sequences of tensors and concatenates the input tensors along the specified dimension by zipping each tensor together. + /// + public IObservable Process(params IObservable[] sources) + { + return sources.Aggregate((current, next) => + current.Zip(next, (tensor1, tensor2) => + cat(new Tensor[] { tensor1, tensor2 }, Dimension))); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + var tensor1 = value.Item1; + var tensor2 = value.Item2; + return cat(new Tensor[] { tensor1, tensor2 }, Dimension); + }); + } + } +} \ No newline at end of file From 0c33017bde41387c0111fa10feffd22e85382c02 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:55:34 +0100 Subject: [PATCH 04/75] Added arange function --- src/Bonsai.ML.Tensors/Arange.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Bonsai.ML.Tensors/Arange.cs b/src/Bonsai.ML.Tensors/Arange.cs index e3c355d0..2a1eda40 100644 --- a/src/Bonsai.ML.Tensors/Arange.cs +++ b/src/Bonsai.ML.Tensors/Arange.cs @@ -2,6 +2,7 @@ using System.ComponentModel; using System.Reactive.Linq; using static TorchSharp.torch; +using TorchSharp; namespace Bonsai.ML.Tensors { @@ -29,7 +30,7 @@ public class Arange public int Step { get; set; } = 1; /// - /// Generates an observable sequence of 1-D tensors created with the function. + /// Generates an observable sequence of 1-D tensors created with the function. /// public IObservable Process() { From 518a9892a077221bf92aebceb6616762f5515f84 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:55:48 +0100 Subject: [PATCH 05/75] Added linspace --- src/Bonsai.ML.Tensors/Linspace.cs | 40 +++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Linspace.cs diff --git a/src/Bonsai.ML.Tensors/Linspace.cs b/src/Bonsai.ML.Tensors/Linspace.cs new file mode 100644 index 00000000..aa263500 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Linspace.cs @@ -0,0 +1,40 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count. + /// + [Combinator] + [Description("Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Linspace + { + /// + /// The start of the range. + /// + public int Start { get; set; } = 0; + + /// + /// The end of the range. + /// + public int End { get; set; } = 1; + + /// + /// The number of points to generate. + /// + public int Count { get; set; } = 10; + + /// + /// Generates an observable sequence of 1-D tensors created with the function. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(linspace(Start, End, Count))); + } + } +} \ No newline at end of file From 004db9b7a1223c5274d8ce048ce2b88cb844d5a9 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:56:15 +0100 Subject: [PATCH 06/75] Added meshgrid --- src/Bonsai.ML.Tensors/MeshGrid.cs | 33 +++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/MeshGrid.cs diff --git a/src/Bonsai.ML.Tensors/MeshGrid.cs b/src/Bonsai.ML.Tensors/MeshGrid.cs new file mode 100644 index 00000000..6b0a2c73 --- /dev/null +++ b/src/Bonsai.ML.Tensors/MeshGrid.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Collections.Generic; +using static TorchSharp.torch; +using System.Linq; + +namespace Bonsai.ML.Tensors +{ + /// + /// Creates a mesh grid from an observable sequence of enumerable of 1-D tensors. + /// + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Source)] + public class MeshGrid + { + /// + /// The indexing mode to use for the mesh grid. + /// + public string Indexing { get; set; } = "ij"; + + /// + /// Creates a mesh grid from the input tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(tensors => meshgrid(tensors, indexing: Indexing)); + } + } +} \ No newline at end of file From a8e4a833fa54ea4d9fa138827e5917b7a0ccfd4d Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:56:22 +0100 Subject: [PATCH 07/75] Added ones --- src/Bonsai.ML.Tensors/Ones.cs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Ones.cs diff --git a/src/Bonsai.ML.Tensors/Ones.cs b/src/Bonsai.ML.Tensors/Ones.cs new file mode 100644 index 00000000..499012bd --- /dev/null +++ b/src/Bonsai.ML.Tensors/Ones.cs @@ -0,0 +1,30 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Creates a tensor filled with ones. + /// + [Combinator] + [Description("Creates a tensor filled with ones.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Ones + { + /// + /// The size of the tensor. + /// + public long[] Size { get; set; } = [0]; + + /// + /// Generates an observable sequence of tensors filled with ones. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(ones(Size))); + } + } +} \ No newline at end of file From d5bf19619616b5b076c680bcee5e870501c5c226 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:56:30 +0100 Subject: [PATCH 08/75] Added zeros --- src/Bonsai.ML.Tensors/Zeros.cs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Zeros.cs diff --git a/src/Bonsai.ML.Tensors/Zeros.cs b/src/Bonsai.ML.Tensors/Zeros.cs new file mode 100644 index 00000000..af220641 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Zeros.cs @@ -0,0 +1,30 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Creates a tensor filled with zeros. + /// + [Combinator] + [Description("Creates a tensor filled with zeros.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Zeros + { + /// + /// The size of the tensor. + /// + public long[] Size { get; set; } = [0]; + + /// + /// Generates an observable sequence of tensors filled with zeros. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(ones(Size))); + } + } +} \ No newline at end of file From 9f2293242ee0291887b04707571780e9c774bbfa Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:56:40 +0100 Subject: [PATCH 09/75] Added device initialization --- .../InitializeTorchDevice.cs | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/InitializeTorchDevice.cs diff --git a/src/Bonsai.ML.Tensors/InitializeTorchDevice.cs b/src/Bonsai.ML.Tensors/InitializeTorchDevice.cs new file mode 100644 index 00000000..dc9123f0 --- /dev/null +++ b/src/Bonsai.ML.Tensors/InitializeTorchDevice.cs @@ -0,0 +1,35 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using TorchSharp; + +namespace Bonsai.ML.Tensors +{ + /// + /// Initializes the Torch device with the specified device type. + /// + [Combinator] + [Description("Initializes the Torch device with the specified device type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class InitializeTorchDevice + { + /// + /// The device type to initialize. + /// + public DeviceType DeviceType { get; set; } + + /// + /// Initializes the Torch device with the specified device type. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => + { + InitializeDeviceType(DeviceType); + return Observable.Return(new Device(DeviceType)); + }); + } + } +} \ No newline at end of file From 7da5f092b14ecfcb1217066259570293dc9a87de Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:56:56 +0100 Subject: [PATCH 10/75] Added ability to move tensor to device --- src/Bonsai.ML.Tensors/ToDevice.cs | 34 +++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/ToDevice.cs diff --git a/src/Bonsai.ML.Tensors/ToDevice.cs b/src/Bonsai.ML.Tensors/ToDevice.cs new file mode 100644 index 00000000..574be5f3 --- /dev/null +++ b/src/Bonsai.ML.Tensors/ToDevice.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Moves the input tensor to the specified device. + /// + [Combinator] + [Description("Moves the input tensor to the specified device.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToDevice + { + /// + /// The device to which the input tensor should be moved. + /// + public Device Device { get; set; } + + /// + /// Returns the input tensor moved to the specified device. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.to(Device); + }); + } + } +} \ No newline at end of file From a6dc975bd4aa307eff24ecaaccf5680b8e4ffb2c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:57:11 +0100 Subject: [PATCH 11/75] Added permute --- src/Bonsai.ML.Tensors/Permute.cs | 33 ++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Permute.cs diff --git a/src/Bonsai.ML.Tensors/Permute.cs b/src/Bonsai.ML.Tensors/Permute.cs new file mode 100644 index 00000000..7f037d79 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Permute.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Permutes the dimensions of the input tensor according to the specified permutation. + /// + [Combinator] + [Description("Permutes the dimensions of the input tensor according to the specified permutation.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Permute + { + /// + /// The permutation of the dimensions. + /// + public long[] Dimensions { get; set; } = [0]; + + /// + /// Returns an observable sequence that permutes the dimensions of the input tensor according to the specified permutation. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.permute(Dimensions); + }); + } + } +} \ No newline at end of file From 111e97e3fef0f925a740c7b4365036518f400c09 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:57:18 +0100 Subject: [PATCH 12/75] Added reshape --- src/Bonsai.ML.Tensors/Reshape.cs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Reshape.cs diff --git a/src/Bonsai.ML.Tensors/Reshape.cs b/src/Bonsai.ML.Tensors/Reshape.cs new file mode 100644 index 00000000..4fef3d83 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Reshape.cs @@ -0,0 +1,32 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Reshapes the input tensor according to the specified dimensions. + /// + [Combinator] + [Description("Reshapes the input tensor according to the specified dimensions.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Reshape + { + /// + /// The dimensions of the reshaped tensor. + /// + public long[] Dimensions { get; set; } = [0]; + + /// + /// Reshapes the input tensor according to the specified dimensions. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => input.reshape(Dimensions)); + } + } +} \ No newline at end of file From 2776c09fd88c9e06d92d4b14918b536d32b4671b Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 17:57:26 +0100 Subject: [PATCH 13/75] Added set --- src/Bonsai.ML.Tensors/Set.cs | 48 ++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Set.cs diff --git a/src/Bonsai.ML.Tensors/Set.cs b/src/Bonsai.ML.Tensors/Set.cs new file mode 100644 index 00000000..3f2a6f50 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Set.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Sets the value of the input tensor at the specified index. + /// + [Combinator] + [Description("Sets the value of the input tensor at the specified index.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Set + { + /// + /// The index at which to set the value. + /// + public string Index + { + get => Helpers.IndexParser.SerializeIndexes(indexes); + set => indexes = Helpers.IndexParser.ParseString(value); + } + + private TensorIndex[] indexes; + + /// + /// The value to set at the specified index. + /// + [XmlIgnore] + public Tensor Value { get; set; } = null; + + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.index_put_(Value, indexes); + }); + } + } +} \ No newline at end of file From 1bb53119d00de662356ed959cd8b57ee32e403fa Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:01:42 +0100 Subject: [PATCH 14/75] Updated csproj with opencv.net package --- src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj b/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj index 2a0a76e2..8d87ac9b 100644 --- a/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj +++ b/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj @@ -4,15 +4,13 @@ A Bonsai package for TorchSharp tensor manipulations. Bonsai Rx ML Tensors TorchSharp net472;netstandard2.0 - 12.0 + true + - - - From 747f8d1f11d51edde3ffba801f6ff5a7728c5ee2 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:02:07 +0100 Subject: [PATCH 15/75] Added concatenate class --- src/Bonsai.ML.Tensors/Concat.cs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/Bonsai.ML.Tensors/Concat.cs b/src/Bonsai.ML.Tensors/Concat.cs index 1a11eb0e..1dd99b7b 100644 --- a/src/Bonsai.ML.Tensors/Concat.cs +++ b/src/Bonsai.ML.Tensors/Concat.cs @@ -1,5 +1,6 @@ -using System; +using System; using System.ComponentModel; +using System.Linq; using System.Reactive.Linq; using static TorchSharp.torch; @@ -23,9 +24,9 @@ public class Concat /// public IObservable Process(params IObservable[] sources) { - return sources.Aggregate((current, next) => - current.Zip(next, (tensor1, tensor2) => - cat(new Tensor[] { tensor1, tensor2 }, Dimension))); + return sources.Aggregate((current, next) => + current.Zip(next, (tensor1, tensor2) => + cat([tensor1, tensor2], Dimension))); } /// @@ -33,12 +34,12 @@ public IObservable Process(params IObservable[] sources) /// public IObservable Process(IObservable> source) { - return source.Select(value => + return source.Select(value => { var tensor1 = value.Item1; var tensor2 = value.Item2; - return cat(new Tensor[] { tensor1, tensor2 }, Dimension); + return cat([tensor1, tensor2], Dimension); }); } } -} \ No newline at end of file +} From 78aa6be186415b5057b03d43b7c9eedd023d9d93 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:03:04 +0100 Subject: [PATCH 16/75] Added convert data type --- src/Bonsai.ML.Tensors/ConvertDataType.cs | 32 ++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/ConvertDataType.cs diff --git a/src/Bonsai.ML.Tensors/ConvertDataType.cs b/src/Bonsai.ML.Tensors/ConvertDataType.cs new file mode 100644 index 00000000..14b0db84 --- /dev/null +++ b/src/Bonsai.ML.Tensors/ConvertDataType.cs @@ -0,0 +1,32 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Converts the input tensor to the specified scalar type. + /// + [Combinator] + [Description("Converts the input tensor to the specified scalar type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ConvertDataType + { + /// + /// The scalar type to which to convert the input tensor. + /// + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Returns an observable sequence that converts the input tensor to the specified scalar type. + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => + { + return tensor.to_type(Type); + }); + } + } +} From 829545e3afdb9f635f7b89c4b693884f6aa63aad Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:03:22 +0100 Subject: [PATCH 17/75] Added create tensor method --- src/Bonsai.ML.Tensors/CreateTensor.cs | 245 ++++++++++++++++++++ src/Bonsai.ML.Tensors/Helpers/DataHelper.cs | 190 +++++++++++++++ 2 files changed, 435 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/CreateTensor.cs create mode 100644 src/Bonsai.ML.Tensors/Helpers/DataHelper.cs diff --git a/src/Bonsai.ML.Tensors/CreateTensor.cs b/src/Bonsai.ML.Tensors/CreateTensor.cs new file mode 100644 index 00000000..712c7243 --- /dev/null +++ b/src/Bonsai.ML.Tensors/CreateTensor.cs @@ -0,0 +1,245 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reactive.Linq; +using System.Reflection; +using System.Xml.Serialization; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". + /// + [Combinator] + [Description("Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: \"[[1, 2], [3, 4]]\".")] + [WorkflowElementCategory(ElementCategory.Source)] + public class CreateTensor : ExpressionBuilder + { + Range argumentRange = new Range(0, 1); + + /// + public override Range ArgumentRange => argumentRange; + + /// + /// The data type of the tensor elements. + /// + public TensorDataType Type + { + get => scalarType; + set => scalarType = value; + } + + private TensorDataType scalarType = TensorDataType.Float32; + + /// + /// The values of the tensor elements. Uses Python-like syntax to specify the tensor values. + /// + public string Values + { + get => values; + set + { + values = value.Replace("False", "false").Replace("True", "true"); + } + } + + private string values = "[0]"; + + /// + /// The device on which to create the tensor. + /// + [XmlIgnore] + public Device Device { get => device; set => device = value; } + + private Device device = null; + + private Expression BuildTensorFromArray(Array arrayValues, Type returnType) + { + var rank = arrayValues.Rank; + var lengths = new int[rank]; + for (int i = 0; i < rank; i++) + { + lengths[i] = arrayValues.GetLength(i); + } + + var arrayCreationExpression = Expression.NewArrayBounds(returnType, lengths.Select(len => Expression.Constant(len)).ToArray()); + var arrayVariable = Expression.Variable(arrayCreationExpression.Type, "array"); + var assignArray = Expression.Assign(arrayVariable, arrayCreationExpression); + + var assignments = new List(); + for (int i = 0; i < values.Length; i++) + { + var indices = new Expression[rank]; + int temp = i; + for (int j = rank - 1; j >= 0; j--) + { + indices[j] = Expression.Constant(temp % lengths[j]); + temp /= lengths[j]; + } + var value = Expression.Constant(arrayValues.GetValue(indices.Select(e => ((ConstantExpression)e).Value).Cast().ToArray())); + var arrayAccess = Expression.ArrayAccess(arrayVariable, indices); + var assignArrayValue = Expression.Assign(arrayAccess, value); + assignments.Add(assignArrayValue); + } + + var tensorDataInitializationBlock = Expression.Block( + arrayVariable, + assignArray, + Expression.Block(assignments), + arrayVariable + ); + + var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + "tensor", [ + arrayVariable.Type, + typeof(ScalarType?), + typeof(Device), + typeof(bool), + typeof(string).MakeArrayType() + ] + ); + + var tensorAssignment = Expression.Call( + tensorCreationMethodInfo, + tensorDataInitializationBlock, + Expression.Constant(scalarType, typeof(ScalarType?)), + Expression.Constant(device, typeof(Device)), + Expression.Constant(false, typeof(bool)), + Expression.Constant(null, typeof(string).MakeArrayType()) + ); + + var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); + var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); + + var buildTensor = Expression.Block( + tensorVariable, + assignTensor, + tensorVariable + ); + + return buildTensor; + } + + private Expression BuildTensorFromScalarValue(object scalarValue, Type returnType) + { + var valueVariable = Expression.Variable(returnType, "value"); + var assignValue = Expression.Assign(valueVariable, Expression.Constant(scalarValue, returnType)); + + var tensorDataInitializationBlock = Expression.Block( + valueVariable, + assignValue, + valueVariable + ); + + var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + "tensor", [ + valueVariable.Type, + typeof(Device), + typeof(bool) + ] + ); + + var tensorCreationMethodArguments = new Expression[] { + Expression.Constant(device, typeof(Device) ), + Expression.Constant(false, typeof(bool) ) + }; + + if (tensorCreationMethodInfo == null) + { + tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + "tensor", [ + valueVariable.Type, + typeof(ScalarType?), + typeof(Device), + typeof(bool) + ] + ); + + tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( + Expression.Constant(scalarType, typeof(ScalarType?)) + ).ToArray(); + } + + tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( + tensorDataInitializationBlock + ).ToArray(); + + var tensorAssignment = Expression.Call( + tensorCreationMethodInfo, + tensorCreationMethodArguments + ); + + var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); + var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); + + var buildTensor = Expression.Block( + tensorVariable, + assignTensor, + tensorVariable + ); + + return buildTensor; + } + + /// + public override Expression Build(IEnumerable arguments) + { + var returnType = Helpers.TensorDataTypeHelper.GetTypeFromTensorDataType(scalarType); + var argTypes = arguments.Select(arg => arg.Type).ToArray(); + + var methodInfoArgumentTypes = new Type[] { + typeof(Tensor) + }; + + var methods = typeof(CreateTensor).GetMethods(BindingFlags.Public | BindingFlags.Instance) + .Where(m => m.Name == "Process") + .ToArray(); + + var methodInfo = arguments.Count() > 0 ? methods.FirstOrDefault(m => m.IsGenericMethod) + .MakeGenericMethod( + arguments + .First() + .Type + .GetGenericArguments()[0] + ) : methods.FirstOrDefault(m => !m.IsGenericMethod); + + var tensorValues = Helpers.DataHelper.ParseString(values, returnType); + var buildTensor = tensorValues is Array arrayValues ? BuildTensorFromArray(arrayValues, returnType) : BuildTensorFromScalarValue(tensorValues, returnType); + var methodArguments = arguments.Count() == 0 ? [ buildTensor ] : arguments.Concat([ buildTensor ]); + + try + { + return Expression.Call( + Expression.Constant(this), + methodInfo, + methodArguments + ); + } + finally + { + values = Helpers.DataHelper.SerializeData(tensorValues).Replace("False", "false").Replace("True", "true"); + scalarType = Helpers.TensorDataTypeHelper.GetTensorDataTypeFromType(returnType); + } + } + + /// + /// Returns an observable sequence that creates a tensor from the specified values. + /// + public IObservable Process(Tensor tensor) + { + return Observable.Return(tensor); + } + + /// + /// Returns an observable sequence that creates a tensor from the specified values for each element in the input sequence. + /// + public IObservable Process(IObservable source, Tensor tensor) + { + return Observable.Select(source, (_) => tensor); + } + } +} diff --git a/src/Bonsai.ML.Tensors/Helpers/DataHelper.cs b/src/Bonsai.ML.Tensors/Helpers/DataHelper.cs new file mode 100644 index 00000000..1bbf3228 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Helpers/DataHelper.cs @@ -0,0 +1,190 @@ +using System; +using System.Text; +using System.Collections.Generic; +using System.Linq; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; + +namespace Bonsai.ML.Tensors.Helpers +{ + /// + /// Provides helper methods for parsing tensor data types. + /// + public static class DataHelper + { + + /// + /// Serializes the input data into a string representation. + /// + public static string SerializeData(object data) + { + if (data is Array array) + { + return SerializeArray(array); + } + else + { + return JsonConvert.SerializeObject(data); + } + } + + /// + /// Serializes the input array into a string representation. + /// + public static string SerializeArray(Array array) + { + StringBuilder sb = new StringBuilder(); + SerializeArrayRecursive(array, sb, [0]); + return sb.ToString(); + } + + private static void SerializeArrayRecursive(Array array, StringBuilder sb, int[] indices) + { + if (indices.Length < array.Rank) + { + sb.Append("["); + int length = array.GetLength(indices.Length); + for (int i = 0; i < length; i++) + { + int[] newIndices = new int[indices.Length + 1]; + indices.CopyTo(newIndices, 0); + newIndices[indices.Length] = i; + SerializeArrayRecursive(array, sb, newIndices); + if (i < length - 1) + { + sb.Append(", "); + } + } + sb.Append("]"); + } + else + { + object value = array.GetValue(indices); + sb.Append(value.ToString()); + } + } + + private static bool IsValidJson(string input) + { + int squareBrackets = 0; + foreach (char c in input) + { + if (c == '[') squareBrackets++; + else if (c == ']') squareBrackets--; + } + return squareBrackets == 0; + } + + /// + /// Parses the input string into an object of the specified type. + /// + public static object ParseString(string input, Type dtype) + { + if (!IsValidJson(input)) + { + throw new ArgumentException("JSON is invalid."); + } + var obj = JsonConvert.DeserializeObject(input); + int depth = ParseDepth(obj); + if (depth == 0) + { + return Convert.ChangeType(input, dtype); + } + int[] dimensions = ParseDimensions(obj, depth); + var resultArray = Array.CreateInstance(dtype, dimensions); + PopulateArray(obj, resultArray, [0], dtype); + return resultArray; + } + + private static int ParseDepth(JToken token, int currentDepth = 0) + { + if (token is JArray arr && arr.Count > 0) + { + return ParseDepth(arr[0], currentDepth + 1); + } + return currentDepth; + } + + private static int[] ParseDimensions(JToken token, int depth, int currentLevel = 0) + { + if (depth == 0 || !(token is JArray)) + { + return [0]; + } + + List dimensions = new List(); + JToken current = token; + + while (current != null && current is JArray) + { + JArray currentArray = current as JArray; + dimensions.Add(currentArray.Count); + if (currentArray.Count > 0) + { + if (currentArray.Any(item => !(item is JArray)) && currentArray.Any(item => item is JArray) || currentArray.All(item => item is JArray) && currentArray.Any(item => ((JArray)item).Count != ((JArray)currentArray.First()).Count)) + { + throw new Exception("Error parsing input. Dimensions are inconsistent."); + } + + if (!(currentArray.First() is JArray)) + { + if (!currentArray.All(item => double.TryParse(item.ToString(), out _)) && !currentArray.All(item => bool.TryParse(item.ToString(), out _))) + { + throw new Exception("Error parsing types. All values must be of the same type and only numeric or boolean types are supported."); + } + } + } + + current = currentArray.Count > 0 ? currentArray[0] : null; + } + + if (currentLevel > 0 && token is JArray arr && arr.All(x => x is JArray)) + { + var subArrayDimensions = new HashSet(); + foreach (JArray subArr in arr) + { + int[] subDims = ParseDimensions(subArr, depth - currentLevel, currentLevel + 1); + subArrayDimensions.Add(string.Join(",", subDims)); + } + + if (subArrayDimensions.Count > 1) + { + throw new ArgumentException("Inconsistent array dimensions."); + } + } + + return dimensions.ToArray(); + } + + private static void PopulateArray(JToken token, Array array, int[] indices, Type dtype) + { + if (token is JArray arr) + { + for (int i = 0; i < arr.Count; i++) + { + int[] newIndices = new int[indices.Length + 1]; + Array.Copy(indices, newIndices, indices.Length); + newIndices[newIndices.Length - 1] = i; + PopulateArray(arr[i], array, newIndices, dtype); + } + } + else + { + var values = ConvertType(token, dtype); + array.SetValue(values, indices); + } + } + + private static object ConvertType(object value, Type targetType) + { + try + { + return Convert.ChangeType(value, targetType); + } + catch (Exception ex) + { + throw new Exception("Error parsing type: ", ex); + } + } + } +} \ No newline at end of file From 65347bd3b5e31296063ee925d8343f1fc387ee29 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:03:45 +0100 Subject: [PATCH 18/75] Added index method and updated set method --- src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs | 91 ++++++++++++++++++++ src/Bonsai.ML.Tensors/Index.cs | 35 ++++++++ src/Bonsai.ML.Tensors/Set.cs | 4 +- 3 files changed, 128 insertions(+), 2 deletions(-) create mode 100644 src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs create mode 100644 src/Bonsai.ML.Tensors/Index.cs diff --git a/src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs b/src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs new file mode 100644 index 00000000..785eccea --- /dev/null +++ b/src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs @@ -0,0 +1,91 @@ +using System; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors.Helpers +{ + /// + /// Provides helper methods to parse tensor indexes. + /// + public static class IndexHelper + { + + /// + /// Parses the input string into an array of tensor indexes. + /// + /// + public static TensorIndex[] ParseString(string input) + { + if (string.IsNullOrEmpty(input)) + { + return [0]; + } + + var indexStrings = input.Split(','); + var indices = new TensorIndex[indexStrings.Length]; + + for (int i = 0; i < indexStrings.Length; i++) + { + var indexString = indexStrings[i].Trim(); + if (int.TryParse(indexString, out int intIndex)) + { + indices[i] = TensorIndex.Single(intIndex); + } + else if (indexString == ":") + { + indices[i] = TensorIndex.Colon; + } + else if (indexString == "None") + { + indices[i] = TensorIndex.None; + } + else if (indexString == "...") + { + indices[i] = TensorIndex.Ellipsis; + } + else if (indexString.ToLower() == "false" || indexString.ToLower() == "true") + { + indices[i] = TensorIndex.Bool(indexString.ToLower() == "true"); + } + else if (indexString.Contains(":")) + { + var rangeParts = indexString.Split(':'); + if (rangeParts.Length == 0) + { + indices[i] = TensorIndex.Slice(); + } + else if (rangeParts.Length == 1) + { + indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0])); + } + else if (rangeParts.Length == 2) + { + indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1])); + } + else if (rangeParts.Length == 3) + { + indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1]), int.Parse(rangeParts[2])); + } + else + { + throw new Exception($"Invalid index format: {indexString}"); + } + } + else + { + throw new Exception($"Invalid index format: {indexString}"); + } + } + return indices; + } + + /// + /// Serializes the input array of tensor indexes into a string representation. + /// + /// + /// + public static string SerializeIndexes(TensorIndex[] indexes) + { + return string.Join(", ", indexes); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Index.cs b/src/Bonsai.ML.Tensors/Index.cs new file mode 100644 index 00000000..3c1948f9 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Index.cs @@ -0,0 +1,35 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. + /// Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis. + /// + [Combinator] + [Description("Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Index + { + /// + /// The indices to use for indexing the tensor. + /// + public string Indexes { get; set; } = string.Empty; + + /// + /// Indexes the input tensor with the specified indices. + /// + /// + /// + public IObservable Process(IObservable source) + { + var index = Helpers.IndexHelper.ParseString(Indexes); + return source.Select(tensor => { + return tensor.index(index); + }); + } + } +} diff --git a/src/Bonsai.ML.Tensors/Set.cs b/src/Bonsai.ML.Tensors/Set.cs index 3f2a6f50..7f6f8b92 100644 --- a/src/Bonsai.ML.Tensors/Set.cs +++ b/src/Bonsai.ML.Tensors/Set.cs @@ -21,8 +21,8 @@ public class Set /// public string Index { - get => Helpers.IndexParser.SerializeIndexes(indexes); - set => indexes = Helpers.IndexParser.ParseString(value); + get => Helpers.IndexHelper.SerializeIndexes(indexes); + set => indexes = Helpers.IndexHelper.ParseString(value); } private TensorIndex[] indexes; From 5fea7bfabf5ebee5d1bd9f088b0d613dc1aafdb3 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:04:08 +0100 Subject: [PATCH 19/75] Defined tensor data types as subset of ScalarType --- src/Bonsai.ML.Tensors/TensorDataType.cs | 56 +++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/TensorDataType.cs diff --git a/src/Bonsai.ML.Tensors/TensorDataType.cs b/src/Bonsai.ML.Tensors/TensorDataType.cs new file mode 100644 index 00000000..a710a9ed --- /dev/null +++ b/src/Bonsai.ML.Tensors/TensorDataType.cs @@ -0,0 +1,56 @@ +using System; +using System.Text; +using System.Collections.Generic; +using System.Linq; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Represents the data type of the tensor elements. Contains currently supported data types. A subset of the available ScalarType data types in TorchSharp. + /// + public enum TensorDataType + { + /// + /// 8-bit unsigned integer. + /// + Byte = ScalarType.Byte, + + /// + /// 8-bit signed integer. + /// + Int8 = ScalarType.Int8, + + /// + /// 16-bit signed integer. + /// + Int16 = ScalarType.Int16, + + /// + /// 32-bit signed integer. + /// + Int32 = ScalarType.Int32, + + /// + /// 64-bit signed integer. + /// + Int64 = ScalarType.Int64, + + /// + /// 32-bit floating point. + /// + Float32 = ScalarType.Float32, + + /// + /// 64-bit floating point. + /// + Float64 = ScalarType.Float64, + + /// + /// Boolean. + /// + Bool = ScalarType.Bool + } +} \ No newline at end of file From 7d1abd4653b7eb6b22bf08ca95b40a71f181a216 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:04:33 +0100 Subject: [PATCH 20/75] Added to array method --- src/Bonsai.ML.Tensors/ToArray.cs | 73 ++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/ToArray.cs diff --git a/src/Bonsai.ML.Tensors/ToArray.cs b/src/Bonsai.ML.Tensors/ToArray.cs new file mode 100644 index 00000000..af35ab4f --- /dev/null +++ b/src/Bonsai.ML.Tensors/ToArray.cs @@ -0,0 +1,73 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.Linq.Expressions; +using System.Reflection; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Converts the input tensor into an array of the specified element type. + /// + [Combinator] + [Description("Converts the input tensor into an array of the specified element type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + public class ToArray : SingleArgumentExpressionBuilder + { + /// + /// Initializes a new instance of the class. + /// + public ToArray() + { + Type = new TypeMapping(); + } + + /// + /// Gets or sets the type mapping used to convert the input tensor into an array. + /// + public TypeMapping Type { get; set; } + + /// + public override Expression Build(IEnumerable arguments) + { + TypeMapping typeMapping = Type; + var returnType = typeMapping.GetType().GetGenericArguments()[0]; + MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); + methodInfo = methodInfo.MakeGenericMethod(returnType); + Expression sourceExpression = arguments.First(); + + return Expression.Call( + Expression.Constant(this), + methodInfo, + sourceExpression + ); + } + + /// + /// Converts the input tensor into an array of the specified element type. + /// + /// + /// + /// + public IObservable Process(IObservable source) where T : unmanaged + { + return source.Select(tensor => + { + return tensor.data().ToArray(); + }); + } + } +} \ No newline at end of file From 34f48578d8d57dd8c659e954e9afef9561ec6c00 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:04:47 +0100 Subject: [PATCH 21/75] Added tensor data type helper --- .../Helpers/TensorDataTypeHelper.cs | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs diff --git a/src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs b/src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs new file mode 100644 index 00000000..7ea03f65 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Bonsai.ML.Tensors.Helpers +{ + /// + /// Provides helper methods for working with tensor data types. + /// + public class TensorDataTypeHelper + { + private static readonly Dictionary _lookup = new Dictionary + { + { TensorDataType.Byte, (typeof(byte), "byte") }, + { TensorDataType.Int16, (typeof(short), "short") }, + { TensorDataType.Int32, (typeof(int), "int") }, + { TensorDataType.Int64, (typeof(long), "long") }, + { TensorDataType.Float32, (typeof(float), "float") }, + { TensorDataType.Float64, (typeof(double), "double") }, + { TensorDataType.Bool, (typeof(bool), "bool") }, + { TensorDataType.Int8, (typeof(sbyte), "sbyte") }, + }; + + /// + /// Returns the type corresponding to the specified tensor data type. + /// + /// + /// + public static Type GetTypeFromTensorDataType(TensorDataType type) => _lookup[type].Type; + + /// + /// Returns the string representation corresponding to the specified tensor data type. + /// + /// + /// + public static string GetStringFromTensorDataType(TensorDataType type) => _lookup[type].StringValue; + + /// + /// Returns the tensor data type corresponding to the specified string representation. + /// + /// + /// + public static TensorDataType GetTensorDataTypeFromString(string value) => _lookup.First(x => x.Value.StringValue == value).Key; + + /// + /// Returns the tensor data type corresponding to the specified type. + /// + /// + /// + public static TensorDataType GetTensorDataTypeFromType(Type type) => _lookup.First(x => x.Value.Type == type).Key; + } +} \ No newline at end of file From 97e2fe747d9e7920e48a52a1600329240e5d6665 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 23 Aug 2024 18:05:16 +0100 Subject: [PATCH 22/75] Added methods to convert OpenCV types --- src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs | 169 ++++++++++++++++++ src/Bonsai.ML.Tensors/ToImage.cs | 28 +++ src/Bonsai.ML.Tensors/ToMat.cs | 28 +++ src/Bonsai.ML.Tensors/ToTensor.cs | 134 ++++++++++++++ 4 files changed, 359 insertions(+) create mode 100644 src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs create mode 100644 src/Bonsai.ML.Tensors/ToImage.cs create mode 100644 src/Bonsai.ML.Tensors/ToMat.cs create mode 100644 src/Bonsai.ML.Tensors/ToTensor.cs diff --git a/src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs b/src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs new file mode 100644 index 00000000..265f2119 --- /dev/null +++ b/src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs @@ -0,0 +1,169 @@ +using System; +using System.Runtime.InteropServices; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors.Helpers +{ + /// + /// Helper class to convert between OpenCV mats, images and Torch tensors. + /// + public static class OpenCVHelper + { + private static Dictionary bitDepthLookup = new Dictionary { + { ScalarType.Byte, (IplDepth.U8, Depth.U8) }, + { ScalarType.Int16, (IplDepth.S16, Depth.S16) }, + { ScalarType.Int32, (IplDepth.S32, Depth.S32) }, + { ScalarType.Float32, (IplDepth.F32, Depth.F32) }, + { ScalarType.Float64, (IplDepth.F64, Depth.F64) }, + { ScalarType.Int8, (IplDepth.S8, Depth.S8) } + }; + + private static ConcurrentDictionary deleters = new ConcurrentDictionary(); + + internal delegate void GCHandleDeleter(IntPtr memory); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_data(IntPtr handle); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_new(IntPtr rawArray, GCHandleDeleter deleter, IntPtr dimensions, int numDimensions, sbyte type, sbyte dtype, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad); + + /// + /// Creates a tensor from a pointer to the data and the dimensions of the tensor. + /// + /// + /// + /// + /// + public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dimensions, ScalarType dtype = ScalarType.Byte) + { + var dataHandle = GCHandle.Alloc(tensorDataPtr, GCHandleType.Pinned); + var gchp = GCHandle.ToIntPtr(dataHandle); + GCHandleDeleter deleter = null; + + deleter = new GCHandleDeleter((IntPtr ptrHandler) => + { + GCHandle.FromIntPtr(gchp).Free(); + deleters.TryRemove(deleter, out deleter); + }); + deleters.TryAdd(deleter, deleter); + + fixed (long* dimensionsPtr = dimensions) + { + IntPtr tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); + if (tensorHandle == IntPtr.Zero) { + GC.Collect(); + GC.WaitForPendingFinalizers(); + tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); + } + if (tensorHandle == IntPtr.Zero) { CheckForErrors(); } + var output = Tensor.UnsafeCreateTensor(tensorHandle); + return output; + } + } + + /// + /// Converts an OpenCV image to a Torch tensor. + /// + /// + /// + public static Tensor ToTensor(IplImage image) + { + if (image == null) + { + return empty([ 0, 0, 0 ]); + } + + int width = image.Width; + int height = image.Height; + int channels = image.Channels; + + var iplDepth = image.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; + + IntPtr tensorDataPtr = image.ImageData; + long[] dimensions = [ height, width, channels ]; + if (tensorDataPtr == IntPtr.Zero) + { + return empty(dimensions); + } + return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + } + + /// + /// Converts an OpenCV mat to a Torch tensor. + /// + /// + /// + public static Tensor ToTensor(Mat mat) + { + if (mat == null) + { + return empty([0, 0, 0 ]); + } + + int width = mat.Size.Width; + int height = mat.Size.Height; + int channels = mat.Channels; + + var depth = mat.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.Depth == depth).Key; + + IntPtr tensorDataPtr = mat.Data; + long[] dimensions = [ height, width, channels ]; + if (tensorDataPtr == IntPtr.Zero) + { + return empty(dimensions); + } + return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + } + + /// + /// Converts a Torch tensor to an OpenCV image. + /// + /// + /// + public unsafe static IplImage ToImage(Tensor tensor) + { + var height = (int)tensor.shape[0]; + var width = (int)tensor.shape[1]; + var channels = (int)tensor.shape[2]; + + var tensorType = tensor.dtype; + var iplDepth = bitDepthLookup[tensorType].IplDepth; + + var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + + var res = THSTensor_data(new_tensor.Handle); + var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels, res); + + return image; + } + + /// + /// Converts a Torch tensor to an OpenCV mat. + /// + /// + /// + public unsafe static Mat ToMat(Tensor tensor) + { + var height = (int)tensor.shape[0]; + var width = (int)tensor.shape[1]; + var channels = (int)tensor.shape[2]; + + var tensorType = tensor.dtype; + var depth = bitDepthLookup[tensorType].Depth; + + var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + + var res = THSTensor_data(new_tensor.Handle); + var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels, res); + + return mat; + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToImage.cs b/src/Bonsai.ML.Tensors/ToImage.cs new file mode 100644 index 00000000..e29a3825 --- /dev/null +++ b/src/Bonsai.ML.Tensors/ToImage.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Converts the input tensor into an OpenCV image. + /// + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToImage + { + /// + /// Converts the input tensor into an OpenCV image. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToImage); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToMat.cs b/src/Bonsai.ML.Tensors/ToMat.cs new file mode 100644 index 00000000..8a22f408 --- /dev/null +++ b/src/Bonsai.ML.Tensors/ToMat.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Converts the input tensor into an OpenCV mat. + /// + [Combinator] + [Description("Converts the input tensor into an OpenCV mat.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToMat + { + /// + /// Converts the input tensor into an OpenCV mat. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToMat); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToTensor.cs b/src/Bonsai.ML.Tensors/ToTensor.cs new file mode 100644 index 00000000..083e2797 --- /dev/null +++ b/src/Bonsai.ML.Tensors/ToTensor.cs @@ -0,0 +1,134 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Tensors +{ + /// + /// Converts the input value into a tensor. + /// + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToTensor + { + /// + /// Converts an int into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a double into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a byte into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a bool into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a float into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a long into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a short into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts an array into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts an IplImage into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToTensor); + } + + /// + /// Converts a Mat into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToTensor); + } + } +} \ No newline at end of file From 9d932e5c9114cf75ca8f56b5afbf6a975c3117f2 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 27 Aug 2024 12:51:18 +0100 Subject: [PATCH 23/75] Refactored to torch namespace instead of tensors --- src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj | 17 ++ src/Bonsai.ML.Torch/Helpers/DataHelper.cs | 190 ++++++++++++++ src/Bonsai.ML.Torch/Helpers/IndexHelper.cs | 91 +++++++ src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs | 169 ++++++++++++ .../Helpers/TensorDataTypeHelper.cs | 53 ++++ .../Configuration/ModelConfiguration.cs | 5 + .../Configuration/ModuleConfiguration.cs | 5 + .../NeuralNets/LoadPretrainedModel.cs | 47 ++++ .../NeuralNets/ModelManager.cs | 5 + .../NeuralNets/Models/AlexNet.cs | 70 +++++ .../NeuralNets/Models/MNIST.cs | 61 +++++ .../NeuralNets/Models/MobileNet.cs | 72 +++++ .../NeuralNets/Models/PretrainedModels.cs | 9 + src/Bonsai.ML.Torch/Tensors/Arange.cs | 40 +++ src/Bonsai.ML.Torch/Tensors/Concat.cs | 45 ++++ .../Tensors/ConvertDataType.cs | 32 +++ src/Bonsai.ML.Torch/Tensors/CreateTensor.cs | 245 ++++++++++++++++++ src/Bonsai.ML.Torch/Tensors/Index.cs | 35 +++ .../Tensors/InitializeTorchDevice.cs | 35 +++ src/Bonsai.ML.Torch/Tensors/Linspace.cs | 40 +++ src/Bonsai.ML.Torch/Tensors/MeshGrid.cs | 33 +++ src/Bonsai.ML.Torch/Tensors/Ones.cs | 30 +++ src/Bonsai.ML.Torch/Tensors/Permute.cs | 33 +++ src/Bonsai.ML.Torch/Tensors/Reshape.cs | 32 +++ src/Bonsai.ML.Torch/Tensors/Set.cs | 48 ++++ src/Bonsai.ML.Torch/Tensors/TensorDataType.cs | 56 ++++ src/Bonsai.ML.Torch/Tensors/ToArray.cs | 73 ++++++ src/Bonsai.ML.Torch/Tensors/ToDevice.cs | 34 +++ src/Bonsai.ML.Torch/Tensors/ToImage.cs | 28 ++ src/Bonsai.ML.Torch/Tensors/ToMat.cs | 28 ++ src/Bonsai.ML.Torch/Tensors/ToTensor.cs | 134 ++++++++++ src/Bonsai.ML.Torch/Tensors/Zeros.cs | 30 +++ 32 files changed, 1825 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj create mode 100644 src/Bonsai.ML.Torch/Helpers/DataHelper.cs create mode 100644 src/Bonsai.ML.Torch/Helpers/IndexHelper.cs create mode 100644 src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs create mode 100644 src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Arange.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Concat.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/CreateTensor.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Index.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Linspace.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/MeshGrid.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Ones.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Permute.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Reshape.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Set.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/TensorDataType.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/ToArray.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/ToDevice.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/ToImage.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/ToMat.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/ToTensor.cs create mode 100644 src/Bonsai.ML.Torch/Tensors/Zeros.cs diff --git a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj new file mode 100644 index 00000000..9ed3c5d8 --- /dev/null +++ b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj @@ -0,0 +1,17 @@ + + + Bonsai.ML.Torch + A Bonsai package for TorchSharp tensor manipulations. + Bonsai Rx ML Tensors TorchSharp + net472;netstandard2.0 + true + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Helpers/DataHelper.cs b/src/Bonsai.ML.Torch/Helpers/DataHelper.cs new file mode 100644 index 00000000..ffed053a --- /dev/null +++ b/src/Bonsai.ML.Torch/Helpers/DataHelper.cs @@ -0,0 +1,190 @@ +using System; +using System.Text; +using System.Collections.Generic; +using System.Linq; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; + +namespace Bonsai.ML.Torch.Helpers +{ + /// + /// Provides helper methods for parsing tensor data types. + /// + public static class DataHelper + { + + /// + /// Serializes the input data into a string representation. + /// + public static string SerializeData(object data) + { + if (data is Array array) + { + return SerializeArray(array); + } + else + { + return JsonConvert.SerializeObject(data); + } + } + + /// + /// Serializes the input array into a string representation. + /// + public static string SerializeArray(Array array) + { + StringBuilder sb = new StringBuilder(); + SerializeArrayRecursive(array, sb, [0]); + return sb.ToString(); + } + + private static void SerializeArrayRecursive(Array array, StringBuilder sb, int[] indices) + { + if (indices.Length < array.Rank) + { + sb.Append("["); + int length = array.GetLength(indices.Length); + for (int i = 0; i < length; i++) + { + int[] newIndices = new int[indices.Length + 1]; + indices.CopyTo(newIndices, 0); + newIndices[indices.Length] = i; + SerializeArrayRecursive(array, sb, newIndices); + if (i < length - 1) + { + sb.Append(", "); + } + } + sb.Append("]"); + } + else + { + object value = array.GetValue(indices); + sb.Append(value.ToString()); + } + } + + private static bool IsValidJson(string input) + { + int squareBrackets = 0; + foreach (char c in input) + { + if (c == '[') squareBrackets++; + else if (c == ']') squareBrackets--; + } + return squareBrackets == 0; + } + + /// + /// Parses the input string into an object of the specified type. + /// + public static object ParseString(string input, Type dtype) + { + if (!IsValidJson(input)) + { + throw new ArgumentException("JSON is invalid."); + } + var obj = JsonConvert.DeserializeObject(input); + int depth = ParseDepth(obj); + if (depth == 0) + { + return Convert.ChangeType(input, dtype); + } + int[] dimensions = ParseDimensions(obj, depth); + var resultArray = Array.CreateInstance(dtype, dimensions); + PopulateArray(obj, resultArray, [0], dtype); + return resultArray; + } + + private static int ParseDepth(JToken token, int currentDepth = 0) + { + if (token is JArray arr && arr.Count > 0) + { + return ParseDepth(arr[0], currentDepth + 1); + } + return currentDepth; + } + + private static int[] ParseDimensions(JToken token, int depth, int currentLevel = 0) + { + if (depth == 0 || !(token is JArray)) + { + return [0]; + } + + List dimensions = new List(); + JToken current = token; + + while (current != null && current is JArray) + { + JArray currentArray = current as JArray; + dimensions.Add(currentArray.Count); + if (currentArray.Count > 0) + { + if (currentArray.Any(item => !(item is JArray)) && currentArray.Any(item => item is JArray) || currentArray.All(item => item is JArray) && currentArray.Any(item => ((JArray)item).Count != ((JArray)currentArray.First()).Count)) + { + throw new Exception("Error parsing input. Dimensions are inconsistent."); + } + + if (!(currentArray.First() is JArray)) + { + if (!currentArray.All(item => double.TryParse(item.ToString(), out _)) && !currentArray.All(item => bool.TryParse(item.ToString(), out _))) + { + throw new Exception("Error parsing types. All values must be of the same type and only numeric or boolean types are supported."); + } + } + } + + current = currentArray.Count > 0 ? currentArray[0] : null; + } + + if (currentLevel > 0 && token is JArray arr && arr.All(x => x is JArray)) + { + var subArrayDimensions = new HashSet(); + foreach (JArray subArr in arr) + { + int[] subDims = ParseDimensions(subArr, depth - currentLevel, currentLevel + 1); + subArrayDimensions.Add(string.Join(",", subDims)); + } + + if (subArrayDimensions.Count > 1) + { + throw new ArgumentException("Inconsistent array dimensions."); + } + } + + return dimensions.ToArray(); + } + + private static void PopulateArray(JToken token, Array array, int[] indices, Type dtype) + { + if (token is JArray arr) + { + for (int i = 0; i < arr.Count; i++) + { + int[] newIndices = new int[indices.Length + 1]; + Array.Copy(indices, newIndices, indices.Length); + newIndices[newIndices.Length - 1] = i; + PopulateArray(arr[i], array, newIndices, dtype); + } + } + else + { + var values = ConvertType(token, dtype); + array.SetValue(values, indices); + } + } + + private static object ConvertType(object value, Type targetType) + { + try + { + return Convert.ChangeType(value, targetType); + } + catch (Exception ex) + { + throw new Exception("Error parsing type: ", ex); + } + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Helpers/IndexHelper.cs b/src/Bonsai.ML.Torch/Helpers/IndexHelper.cs new file mode 100644 index 00000000..541ae443 --- /dev/null +++ b/src/Bonsai.ML.Torch/Helpers/IndexHelper.cs @@ -0,0 +1,91 @@ +using System; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Helpers +{ + /// + /// Provides helper methods to parse tensor indexes. + /// + public static class IndexHelper + { + + /// + /// Parses the input string into an array of tensor indexes. + /// + /// + public static TensorIndex[] ParseString(string input) + { + if (string.IsNullOrEmpty(input)) + { + return [0]; + } + + var indexStrings = input.Split(','); + var indices = new TensorIndex[indexStrings.Length]; + + for (int i = 0; i < indexStrings.Length; i++) + { + var indexString = indexStrings[i].Trim(); + if (int.TryParse(indexString, out int intIndex)) + { + indices[i] = TensorIndex.Single(intIndex); + } + else if (indexString == ":") + { + indices[i] = TensorIndex.Colon; + } + else if (indexString == "None") + { + indices[i] = TensorIndex.None; + } + else if (indexString == "...") + { + indices[i] = TensorIndex.Ellipsis; + } + else if (indexString.ToLower() == "false" || indexString.ToLower() == "true") + { + indices[i] = TensorIndex.Bool(indexString.ToLower() == "true"); + } + else if (indexString.Contains(":")) + { + var rangeParts = indexString.Split(':'); + if (rangeParts.Length == 0) + { + indices[i] = TensorIndex.Slice(); + } + else if (rangeParts.Length == 1) + { + indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0])); + } + else if (rangeParts.Length == 2) + { + indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1])); + } + else if (rangeParts.Length == 3) + { + indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1]), int.Parse(rangeParts[2])); + } + else + { + throw new Exception($"Invalid index format: {indexString}"); + } + } + else + { + throw new Exception($"Invalid index format: {indexString}"); + } + } + return indices; + } + + /// + /// Serializes the input array of tensor indexes into a string representation. + /// + /// + /// + public static string SerializeIndexes(TensorIndex[] indexes) + { + return string.Join(", ", indexes); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs b/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs new file mode 100644 index 00000000..4e90fa35 --- /dev/null +++ b/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs @@ -0,0 +1,169 @@ +using System; +using System.Runtime.InteropServices; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Helpers +{ + /// + /// Helper class to convert between OpenCV mats, images and Torch tensors. + /// + public static class OpenCVHelper + { + private static Dictionary bitDepthLookup = new Dictionary { + { ScalarType.Byte, (IplDepth.U8, Depth.U8) }, + { ScalarType.Int16, (IplDepth.S16, Depth.S16) }, + { ScalarType.Int32, (IplDepth.S32, Depth.S32) }, + { ScalarType.Float32, (IplDepth.F32, Depth.F32) }, + { ScalarType.Float64, (IplDepth.F64, Depth.F64) }, + { ScalarType.Int8, (IplDepth.S8, Depth.S8) } + }; + + private static ConcurrentDictionary deleters = new ConcurrentDictionary(); + + internal delegate void GCHandleDeleter(IntPtr memory); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_data(IntPtr handle); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_new(IntPtr rawArray, GCHandleDeleter deleter, IntPtr dimensions, int numDimensions, sbyte type, sbyte dtype, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad); + + /// + /// Creates a tensor from a pointer to the data and the dimensions of the tensor. + /// + /// + /// + /// + /// + public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dimensions, ScalarType dtype = ScalarType.Byte) + { + var dataHandle = GCHandle.Alloc(tensorDataPtr, GCHandleType.Pinned); + var gchp = GCHandle.ToIntPtr(dataHandle); + GCHandleDeleter deleter = null; + + deleter = new GCHandleDeleter((IntPtr ptrHandler) => + { + GCHandle.FromIntPtr(gchp).Free(); + deleters.TryRemove(deleter, out deleter); + }); + deleters.TryAdd(deleter, deleter); + + fixed (long* dimensionsPtr = dimensions) + { + IntPtr tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); + if (tensorHandle == IntPtr.Zero) { + GC.Collect(); + GC.WaitForPendingFinalizers(); + tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); + } + if (tensorHandle == IntPtr.Zero) { CheckForErrors(); } + var output = Tensor.UnsafeCreateTensor(tensorHandle); + return output; + } + } + + /// + /// Converts an OpenCV image to a Torch tensor. + /// + /// + /// + public static Tensor ToTensor(IplImage image) + { + if (image == null) + { + return empty([ 0, 0, 0 ]); + } + + int width = image.Width; + int height = image.Height; + int channels = image.Channels; + + var iplDepth = image.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; + + IntPtr tensorDataPtr = image.ImageData; + long[] dimensions = [ height, width, channels ]; + if (tensorDataPtr == IntPtr.Zero) + { + return empty(dimensions); + } + return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + } + + /// + /// Converts an OpenCV mat to a Torch tensor. + /// + /// + /// + public static Tensor ToTensor(Mat mat) + { + if (mat == null) + { + return empty([0, 0, 0 ]); + } + + int width = mat.Size.Width; + int height = mat.Size.Height; + int channels = mat.Channels; + + var depth = mat.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.Depth == depth).Key; + + IntPtr tensorDataPtr = mat.Data; + long[] dimensions = [ height, width, channels ]; + if (tensorDataPtr == IntPtr.Zero) + { + return empty(dimensions); + } + return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + } + + /// + /// Converts a Torch tensor to an OpenCV image. + /// + /// + /// + public unsafe static IplImage ToImage(Tensor tensor) + { + var height = (int)tensor.shape[0]; + var width = (int)tensor.shape[1]; + var channels = (int)tensor.shape[2]; + + var tensorType = tensor.dtype; + var iplDepth = bitDepthLookup[tensorType].IplDepth; + + var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + + var res = THSTensor_data(new_tensor.Handle); + var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels, res); + + return image; + } + + /// + /// Converts a Torch tensor to an OpenCV mat. + /// + /// + /// + public unsafe static Mat ToMat(Tensor tensor) + { + var height = (int)tensor.shape[0]; + var width = (int)tensor.shape[1]; + var channels = (int)tensor.shape[2]; + + var tensorType = tensor.dtype; + var depth = bitDepthLookup[tensorType].Depth; + + var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + + var res = THSTensor_data(new_tensor.Handle); + var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels, res); + + return mat; + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs b/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs new file mode 100644 index 00000000..91faf20b --- /dev/null +++ b/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Bonsai.ML.Torch.Tensors; + +namespace Bonsai.ML.Torch.Helpers +{ + /// + /// Provides helper methods for working with tensor data types. + /// + public class TensorDataTypeHelper + { + private static readonly Dictionary _lookup = new Dictionary + { + { TensorDataType.Byte, (typeof(byte), "byte") }, + { TensorDataType.Int16, (typeof(short), "short") }, + { TensorDataType.Int32, (typeof(int), "int") }, + { TensorDataType.Int64, (typeof(long), "long") }, + { TensorDataType.Float32, (typeof(float), "float") }, + { TensorDataType.Float64, (typeof(double), "double") }, + { TensorDataType.Bool, (typeof(bool), "bool") }, + { TensorDataType.Int8, (typeof(sbyte), "sbyte") }, + }; + + /// + /// Returns the type corresponding to the specified tensor data type. + /// + /// + /// + public static Type GetTypeFromTensorDataType(TensorDataType type) => _lookup[type].Type; + + /// + /// Returns the string representation corresponding to the specified tensor data type. + /// + /// + /// + public static string GetStringFromTensorDataType(TensorDataType type) => _lookup[type].StringValue; + + /// + /// Returns the tensor data type corresponding to the specified string representation. + /// + /// + /// + public static TensorDataType GetTensorDataTypeFromString(string value) => _lookup.First(x => x.Value.StringValue == value).Key; + + /// + /// Returns the tensor data type corresponding to the specified type. + /// + /// + /// + public static TensorDataType GetTensorDataTypeFromType(Type type) => _lookup.First(x => x.Value.Type == type).Key; + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs b/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs new file mode 100644 index 00000000..7628f72d --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs @@ -0,0 +1,5 @@ +namespace Bonsai.ML.Torch.NeuralNets.Configuration; + +public class ModelConfiguration +{ +} diff --git a/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs b/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs new file mode 100644 index 00000000..dfe56272 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs @@ -0,0 +1,5 @@ +namespace Bonsai.ML.Torch.NeuralNets.Configuration; + +public class ModuleConfiguration +{ +} diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs new file mode 100644 index 00000000..fb7722f2 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs @@ -0,0 +1,47 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using System.Xml.Serialization; +using static TorchSharp.torch.nn; +using Bonsai.Expressions; + +namespace Bonsai.ML.Torch.NeuralNets +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Source)] + public class LoadPretrainedModel + { + public Models.PretrainedModels ModelName { get; set; } + public Device Device { get; set; } + + private int numClasses = 10; + + public IObservable Process() + { + Module model = null; + var modelName = ModelName.ToString().ToLower(); + var device = Device; + + switch (modelName) + { + case "alexnet": + model = new Models.AlexNet(modelName, numClasses, device); + break; + case "mobilenet": + model = new Models.MobileNet(modelName, numClasses, device); + break; + case "mnist": + model = new Models.MNIST(modelName, device); + break; + default: + throw new ArgumentException($"Model {modelName} not supported."); + } + + return Observable.Defer(() => { + return Observable.Return(model); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs b/src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs new file mode 100644 index 00000000..035b3ca1 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs @@ -0,0 +1,5 @@ +namespace Bonsai.ML.Torch.NeuralNets; + +public class ModelManager +{ +} diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs new file mode 100644 index 00000000..4ca9f79c --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs @@ -0,0 +1,70 @@ +using System; +using System.IO; +using System.Linq; +using System.Collections.Generic; +using System.Diagnostics; + +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; +using static TorchSharp.torch.nn.functional; + +namespace Bonsai.ML.Torch.NeuralNets.Models +{ + /// + /// Modified version of original AlexNet to fix CIFAR10 32x32 images. + /// + public class AlexNet : Module + { + private readonly Module features; + private readonly Module avgPool; + private readonly Module classifier; + + public AlexNet(string name, int numClasses, Device device = null) : base(name) + { + features = Sequential( + ("c1", Conv2d(3, 64, kernelSize: 3, stride: 2, padding: 1)), + ("r1", ReLU(inplace: true)), + ("mp1", MaxPool2d(kernelSize: new long[] { 2, 2 })), + ("c2", Conv2d(64, 192, kernelSize: 3, padding: 1)), + ("r2", ReLU(inplace: true)), + ("mp2", MaxPool2d(kernelSize: new long[] { 2, 2 })), + ("c3", Conv2d(192, 384, kernelSize: 3, padding: 1)), + ("r3", ReLU(inplace: true)), + ("c4", Conv2d(384, 256, kernelSize: 3, padding: 1)), + ("r4", ReLU(inplace: true)), + ("c5", Conv2d(256, 256, kernelSize: 3, padding: 1)), + ("r5", ReLU(inplace: true)), + ("mp3", MaxPool2d(kernelSize: new long[] { 2, 2 }))); + + avgPool = AdaptiveAvgPool2d(new long[] { 2, 2 }); + + classifier = Sequential( + ("d1", Dropout()), + ("l1", Linear(256 * 2 * 2, 4096)), + ("r1", ReLU(inplace: true)), + ("d2", Dropout()), + ("l2", Linear(4096, 4096)), + ("r3", ReLU(inplace: true)), + ("d3", Dropout()), + ("l3", Linear(4096, numClasses)) + ); + + RegisterComponents(); + + if (device != null && device.type != DeviceType.CPU) + this.to(device); + } + + public override Tensor forward(Tensor input) + { + var f = features.forward(input); + var avg = avgPool.forward(f); + + var x = avg.view(new long[] { avg.shape[0], 256 * 2 * 2 }); + + return classifier.forward(x); + } + } + +} diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs new file mode 100644 index 00000000..b707e2d5 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs @@ -0,0 +1,61 @@ +using System; +using System.IO; +using System.Collections.Generic; +using System.Diagnostics; +using TorchSharp; +using static TorchSharp.torch; + +using static TorchSharp.torch.nn; +using static TorchSharp.torch.nn.functional; + +namespace Bonsai.ML.Torch.NeuralNets.Models +{ + public class MNIST : Module + { + private Module conv1 = Conv2d(1, 32, 3); + private Module conv2 = Conv2d(32, 64, 3); + private Module fc1 = Linear(9216, 128); + private Module fc2 = Linear(128, 10); + + private Module pool1 = MaxPool2d(kernelSize: new long[] { 2, 2 }); + + private Module relu1 = ReLU(); + private Module relu2 = ReLU(); + private Module relu3 = ReLU(); + + private Module dropout1 = Dropout(0.25); + private Module dropout2 = Dropout(0.5); + + private Module flatten = Flatten(); + private Module logsm = LogSoftmax(1); + + public MNIST(string name, Device device = null) : base(name) + { + RegisterComponents(); + + if (device != null && device.type != DeviceType.CPU) + this.to(device); + } + + public override Tensor forward(Tensor input) + { + var l11 = conv1.forward(input); + var l12 = relu1.forward(l11); + + var l21 = conv2.forward(l12); + var l22 = relu2.forward(l21); + var l23 = pool1.forward(l22); + var l24 = dropout1.forward(l23); + + var x = flatten.forward(l24); + + var l31 = fc1.forward(x); + var l32 = relu3.forward(l31); + var l33 = dropout2.forward(l32); + + var l41 = fc2.forward(l33); + + return logsm.forward(l41); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs new file mode 100644 index 00000000..e9d66038 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs @@ -0,0 +1,72 @@ +using System; +using System.Collections.Generic; +using Bonsai.ML.Torch.Tensors; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; + +namespace Bonsai.ML.Torch.NeuralNets.Models +{ + /// + /// Modified version of MobileNet to classify CIFAR10 32x32 images. + /// + /// + /// With an unaugmented CIFAR-10 data set, the author of this saw training converge + /// at roughly 75% accuracy on the test set, over the course of 1500 epochs. + /// + public class MobileNet : Module + { + // The code here is is loosely based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenet.py + // Licence and copypright notice at: https://github.com/kuangliu/pytorch-cifar/blob/master/LICENSE + + private readonly long[] planes = new long[] { 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024 }; + private readonly long[] strides = new long[] { 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1 }; + + private readonly Module layers; + + public MobileNet(string name, int numClasses, Device device = null) : base(name) + { + if (planes.Length != strides.Length) throw new ArgumentException("'planes' and 'strides' must have the same length."); + + var modules = new List<(string, Module)>(); + + modules.Add(($"conv2d-first", Conv2d(3, 32, kernelSize: 3, stride: 1, padding: 1, bias: false))); + modules.Add(($"bnrm2d-first", BatchNorm2d(32))); + modules.Add(($"relu-first", ReLU())); + MakeLayers(modules, 32); + modules.Add(("avgpool", AvgPool2d(new long[] { 2, 2 }))); + modules.Add(("flatten", Flatten())); + modules.Add(($"linear", Linear(planes[planes.Length-1], numClasses))); + + layers = Sequential(modules); + + RegisterComponents(); + + if (device != null && device.type != DeviceType.CPU) + this.to(device); + } + + private void MakeLayers(List<(string, Module)> modules, long in_planes) + { + + for (var i = 0; i < strides.Length; i++) { + var out_planes = planes[i]; + var stride = strides[i]; + + modules.Add(($"conv2d-{i}a", Conv2d(in_planes, in_planes, kernelSize: 3, stride: stride, padding: 1, groups: in_planes, bias: false))); + modules.Add(($"bnrm2d-{i}a", BatchNorm2d(in_planes))); + modules.Add(($"relu-{i}a", ReLU())); + modules.Add(($"conv2d-{i}b", Conv2d(in_planes, out_planes, kernelSize: 1L, stride: 1L, padding: 0L, bias: false))); + modules.Add(($"bnrm2d-{i}b", BatchNorm2d(out_planes))); + modules.Add(($"relu-{i}b", ReLU())); + + in_planes = out_planes; + } + } + + public override Tensor forward(Tensor input) + { + return layers.forward(input); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs new file mode 100644 index 00000000..a3c65bdc --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs @@ -0,0 +1,9 @@ +namespace Bonsai.ML.Torch.NeuralNets.Models +{ + public enum PretrainedModels + { + AlexNet, + MobileNet, + MNIST + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Arange.cs b/src/Bonsai.ML.Torch/Tensors/Arange.cs new file mode 100644 index 00000000..011d0708 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Arange.cs @@ -0,0 +1,40 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using TorchSharp; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Creates a 1-D tensor of values within a given range given the start, end, and step. + /// + [Combinator] + [Description("Creates a 1-D tensor of values within a given range given the start, end, and step.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Arange + { + /// + /// The start of the range. + /// + public int Start { get; set; } = 0; + + /// + /// The end of the range. + /// + public int End { get; set; } = 10; + + /// + /// The step of the range. + /// + public int Step { get; set; } = 1; + + /// + /// Generates an observable sequence of 1-D tensors created with the function. + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(arange(Start, End, Step))); + } + } +} diff --git a/src/Bonsai.ML.Torch/Tensors/Concat.cs b/src/Bonsai.ML.Torch/Tensors/Concat.cs new file mode 100644 index 00000000..52275bb7 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Concat.cs @@ -0,0 +1,45 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Concatenates tensors along a given dimension. + /// + [Combinator] + [Description("Concatenates tensors along a given dimension.")] + [WorkflowElementCategory(ElementCategory.Combinator)] + public class Concat + { + /// + /// The dimension along which to concatenate the tensors. + /// + public long Dimension { get; set; } = 0; + + /// + /// Takes any number of observable sequences of tensors and concatenates the input tensors along the specified dimension by zipping each tensor together. + /// + public IObservable Process(params IObservable[] sources) + { + return sources.Aggregate((current, next) => + current.Zip(next, (tensor1, tensor2) => + cat([tensor1, tensor2], Dimension))); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + var tensor1 = value.Item1; + var tensor2 = value.Item2; + return cat([tensor1, tensor2], Dimension); + }); + } + } +} diff --git a/src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs b/src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs new file mode 100644 index 00000000..3683d2a6 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs @@ -0,0 +1,32 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Converts the input tensor to the specified scalar type. + /// + [Combinator] + [Description("Converts the input tensor to the specified scalar type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ConvertDataType + { + /// + /// The scalar type to which to convert the input tensor. + /// + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Returns an observable sequence that converts the input tensor to the specified scalar type. + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => + { + return tensor.to_type(Type); + }); + } + } +} diff --git a/src/Bonsai.ML.Torch/Tensors/CreateTensor.cs b/src/Bonsai.ML.Torch/Tensors/CreateTensor.cs new file mode 100644 index 00000000..4585b70b --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/CreateTensor.cs @@ -0,0 +1,245 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reactive.Linq; +using System.Reflection; +using System.Xml.Serialization; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". + /// + [Combinator] + [Description("Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: \"[[1, 2], [3, 4]]\".")] + [WorkflowElementCategory(ElementCategory.Source)] + public class CreateTensor : ExpressionBuilder + { + Range argumentRange = new Range(0, 1); + + /// + public override Range ArgumentRange => argumentRange; + + /// + /// The data type of the tensor elements. + /// + public TensorDataType Type + { + get => scalarType; + set => scalarType = value; + } + + private TensorDataType scalarType = TensorDataType.Float32; + + /// + /// The values of the tensor elements. Uses Python-like syntax to specify the tensor values. + /// + public string Values + { + get => values; + set + { + values = value.Replace("False", "false").Replace("True", "true"); + } + } + + private string values = "[0]"; + + /// + /// The device on which to create the tensor. + /// + [XmlIgnore] + public Device Device { get => device; set => device = value; } + + private Device device = null; + + private Expression BuildTensorFromArray(Array arrayValues, Type returnType) + { + var rank = arrayValues.Rank; + var lengths = new int[rank]; + for (int i = 0; i < rank; i++) + { + lengths[i] = arrayValues.GetLength(i); + } + + var arrayCreationExpression = Expression.NewArrayBounds(returnType, lengths.Select(len => Expression.Constant(len)).ToArray()); + var arrayVariable = Expression.Variable(arrayCreationExpression.Type, "array"); + var assignArray = Expression.Assign(arrayVariable, arrayCreationExpression); + + var assignments = new List(); + for (int i = 0; i < values.Length; i++) + { + var indices = new Expression[rank]; + int temp = i; + for (int j = rank - 1; j >= 0; j--) + { + indices[j] = Expression.Constant(temp % lengths[j]); + temp /= lengths[j]; + } + var value = Expression.Constant(arrayValues.GetValue(indices.Select(e => ((ConstantExpression)e).Value).Cast().ToArray())); + var arrayAccess = Expression.ArrayAccess(arrayVariable, indices); + var assignArrayValue = Expression.Assign(arrayAccess, value); + assignments.Add(assignArrayValue); + } + + var tensorDataInitializationBlock = Expression.Block( + arrayVariable, + assignArray, + Expression.Block(assignments), + arrayVariable + ); + + var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + "tensor", [ + arrayVariable.Type, + typeof(ScalarType?), + typeof(Device), + typeof(bool), + typeof(string).MakeArrayType() + ] + ); + + var tensorAssignment = Expression.Call( + tensorCreationMethodInfo, + tensorDataInitializationBlock, + Expression.Constant(scalarType, typeof(ScalarType?)), + Expression.Constant(device, typeof(Device)), + Expression.Constant(false, typeof(bool)), + Expression.Constant(null, typeof(string).MakeArrayType()) + ); + + var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); + var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); + + var buildTensor = Expression.Block( + tensorVariable, + assignTensor, + tensorVariable + ); + + return buildTensor; + } + + private Expression BuildTensorFromScalarValue(object scalarValue, Type returnType) + { + var valueVariable = Expression.Variable(returnType, "value"); + var assignValue = Expression.Assign(valueVariable, Expression.Constant(scalarValue, returnType)); + + var tensorDataInitializationBlock = Expression.Block( + valueVariable, + assignValue, + valueVariable + ); + + var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + "tensor", [ + valueVariable.Type, + typeof(Device), + typeof(bool) + ] + ); + + var tensorCreationMethodArguments = new Expression[] { + Expression.Constant(device, typeof(Device) ), + Expression.Constant(false, typeof(bool) ) + }; + + if (tensorCreationMethodInfo == null) + { + tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + "tensor", [ + valueVariable.Type, + typeof(ScalarType?), + typeof(Device), + typeof(bool) + ] + ); + + tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( + Expression.Constant(scalarType, typeof(ScalarType?)) + ).ToArray(); + } + + tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( + tensorDataInitializationBlock + ).ToArray(); + + var tensorAssignment = Expression.Call( + tensorCreationMethodInfo, + tensorCreationMethodArguments + ); + + var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); + var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); + + var buildTensor = Expression.Block( + tensorVariable, + assignTensor, + tensorVariable + ); + + return buildTensor; + } + + /// + public override Expression Build(IEnumerable arguments) + { + var returnType = Helpers.TensorDataTypeHelper.GetTypeFromTensorDataType(scalarType); + var argTypes = arguments.Select(arg => arg.Type).ToArray(); + + var methodInfoArgumentTypes = new Type[] { + typeof(Tensor) + }; + + var methods = typeof(CreateTensor).GetMethods(BindingFlags.Public | BindingFlags.Instance) + .Where(m => m.Name == "Process") + .ToArray(); + + var methodInfo = arguments.Count() > 0 ? methods.FirstOrDefault(m => m.IsGenericMethod) + .MakeGenericMethod( + arguments + .First() + .Type + .GetGenericArguments()[0] + ) : methods.FirstOrDefault(m => !m.IsGenericMethod); + + var tensorValues = Helpers.DataHelper.ParseString(values, returnType); + var buildTensor = tensorValues is Array arrayValues ? BuildTensorFromArray(arrayValues, returnType) : BuildTensorFromScalarValue(tensorValues, returnType); + var methodArguments = arguments.Count() == 0 ? [ buildTensor ] : arguments.Concat([ buildTensor ]); + + try + { + return Expression.Call( + Expression.Constant(this), + methodInfo, + methodArguments + ); + } + finally + { + values = Helpers.DataHelper.SerializeData(tensorValues).Replace("False", "false").Replace("True", "true"); + scalarType = Helpers.TensorDataTypeHelper.GetTensorDataTypeFromType(returnType); + } + } + + /// + /// Returns an observable sequence that creates a tensor from the specified values. + /// + public IObservable Process(Tensor tensor) + { + return Observable.Return(tensor); + } + + /// + /// Returns an observable sequence that creates a tensor from the specified values for each element in the input sequence. + /// + public IObservable Process(IObservable source, Tensor tensor) + { + return Observable.Select(source, (_) => tensor); + } + } +} diff --git a/src/Bonsai.ML.Torch/Tensors/Index.cs b/src/Bonsai.ML.Torch/Tensors/Index.cs new file mode 100644 index 00000000..78024237 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Index.cs @@ -0,0 +1,35 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. + /// Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis. + /// + [Combinator] + [Description("Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Index + { + /// + /// The indices to use for indexing the tensor. + /// + public string Indexes { get; set; } = string.Empty; + + /// + /// Indexes the input tensor with the specified indices. + /// + /// + /// + public IObservable Process(IObservable source) + { + var index = Helpers.IndexHelper.ParseString(Indexes); + return source.Select(tensor => { + return tensor.index(index); + }); + } + } +} diff --git a/src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs b/src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs new file mode 100644 index 00000000..2258467f --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs @@ -0,0 +1,35 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using TorchSharp; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Initializes the Torch device with the specified device type. + /// + [Combinator] + [Description("Initializes the Torch device with the specified device type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class InitializeTorchDevice + { + /// + /// The device type to initialize. + /// + public DeviceType DeviceType { get; set; } + + /// + /// Initializes the Torch device with the specified device type. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => + { + InitializeDeviceType(DeviceType); + return Observable.Return(new Device(DeviceType)); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Linspace.cs b/src/Bonsai.ML.Torch/Tensors/Linspace.cs new file mode 100644 index 00000000..6e7495f8 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Linspace.cs @@ -0,0 +1,40 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count. + /// + [Combinator] + [Description("Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Linspace + { + /// + /// The start of the range. + /// + public int Start { get; set; } = 0; + + /// + /// The end of the range. + /// + public int End { get; set; } = 1; + + /// + /// The number of points to generate. + /// + public int Count { get; set; } = 10; + + /// + /// Generates an observable sequence of 1-D tensors created with the function. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(linspace(Start, End, Count))); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/MeshGrid.cs b/src/Bonsai.ML.Torch/Tensors/MeshGrid.cs new file mode 100644 index 00000000..77f4cecb --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/MeshGrid.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Collections.Generic; +using static TorchSharp.torch; +using System.Linq; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Creates a mesh grid from an observable sequence of enumerable of 1-D tensors. + /// + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Source)] + public class MeshGrid + { + /// + /// The indexing mode to use for the mesh grid. + /// + public string Indexing { get; set; } = "ij"; + + /// + /// Creates a mesh grid from the input tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(tensors => meshgrid(tensors, indexing: Indexing)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Ones.cs b/src/Bonsai.ML.Torch/Tensors/Ones.cs new file mode 100644 index 00000000..77768dd1 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Ones.cs @@ -0,0 +1,30 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Creates a tensor filled with ones. + /// + [Combinator] + [Description("Creates a tensor filled with ones.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Ones + { + /// + /// The size of the tensor. + /// + public long[] Size { get; set; } = [0]; + + /// + /// Generates an observable sequence of tensors filled with ones. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(ones(Size))); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Permute.cs b/src/Bonsai.ML.Torch/Tensors/Permute.cs new file mode 100644 index 00000000..317e34f8 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Permute.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Permutes the dimensions of the input tensor according to the specified permutation. + /// + [Combinator] + [Description("Permutes the dimensions of the input tensor according to the specified permutation.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Permute + { + /// + /// The permutation of the dimensions. + /// + public long[] Dimensions { get; set; } = [0]; + + /// + /// Returns an observable sequence that permutes the dimensions of the input tensor according to the specified permutation. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.permute(Dimensions); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Reshape.cs b/src/Bonsai.ML.Torch/Tensors/Reshape.cs new file mode 100644 index 00000000..5d3e9412 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Reshape.cs @@ -0,0 +1,32 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Reshapes the input tensor according to the specified dimensions. + /// + [Combinator] + [Description("Reshapes the input tensor according to the specified dimensions.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Reshape + { + /// + /// The dimensions of the reshaped tensor. + /// + public long[] Dimensions { get; set; } = [0]; + + /// + /// Reshapes the input tensor according to the specified dimensions. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => input.reshape(Dimensions)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Set.cs b/src/Bonsai.ML.Torch/Tensors/Set.cs new file mode 100644 index 00000000..a4d8b2d2 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Set.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Sets the value of the input tensor at the specified index. + /// + [Combinator] + [Description("Sets the value of the input tensor at the specified index.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Set + { + /// + /// The index at which to set the value. + /// + public string Index + { + get => Helpers.IndexHelper.SerializeIndexes(indexes); + set => indexes = Helpers.IndexHelper.ParseString(value); + } + + private TensorIndex[] indexes; + + /// + /// The value to set at the specified index. + /// + [XmlIgnore] + public Tensor Value { get; set; } = null; + + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.index_put_(Value, indexes); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/TensorDataType.cs b/src/Bonsai.ML.Torch/Tensors/TensorDataType.cs new file mode 100644 index 00000000..de1ba8d2 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/TensorDataType.cs @@ -0,0 +1,56 @@ +using System; +using System.Text; +using System.Collections.Generic; +using System.Linq; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Represents the data type of the tensor elements. Contains currently supported data types. A subset of the available ScalarType data types in TorchSharp. + /// + public enum TensorDataType + { + /// + /// 8-bit unsigned integer. + /// + Byte = ScalarType.Byte, + + /// + /// 8-bit signed integer. + /// + Int8 = ScalarType.Int8, + + /// + /// 16-bit signed integer. + /// + Int16 = ScalarType.Int16, + + /// + /// 32-bit signed integer. + /// + Int32 = ScalarType.Int32, + + /// + /// 64-bit signed integer. + /// + Int64 = ScalarType.Int64, + + /// + /// 32-bit floating point. + /// + Float32 = ScalarType.Float32, + + /// + /// 64-bit floating point. + /// + Float64 = ScalarType.Float64, + + /// + /// Boolean. + /// + Bool = ScalarType.Bool + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/ToArray.cs b/src/Bonsai.ML.Torch/Tensors/ToArray.cs new file mode 100644 index 00000000..70083ad2 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/ToArray.cs @@ -0,0 +1,73 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.Linq.Expressions; +using System.Reflection; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Converts the input tensor into an array of the specified element type. + /// + [Combinator] + [Description("Converts the input tensor into an array of the specified element type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + public class ToArray : SingleArgumentExpressionBuilder + { + /// + /// Initializes a new instance of the class. + /// + public ToArray() + { + Type = new TypeMapping(); + } + + /// + /// Gets or sets the type mapping used to convert the input tensor into an array. + /// + public TypeMapping Type { get; set; } + + /// + public override Expression Build(IEnumerable arguments) + { + TypeMapping typeMapping = Type; + var returnType = typeMapping.GetType().GetGenericArguments()[0]; + MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); + methodInfo = methodInfo.MakeGenericMethod(returnType); + Expression sourceExpression = arguments.First(); + + return Expression.Call( + Expression.Constant(this), + methodInfo, + sourceExpression + ); + } + + /// + /// Converts the input tensor into an array of the specified element type. + /// + /// + /// + /// + public IObservable Process(IObservable source) where T : unmanaged + { + return source.Select(tensor => + { + return tensor.data().ToArray(); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/ToDevice.cs b/src/Bonsai.ML.Torch/Tensors/ToDevice.cs new file mode 100644 index 00000000..4aa1b92a --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/ToDevice.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Moves the input tensor to the specified device. + /// + [Combinator] + [Description("Moves the input tensor to the specified device.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToDevice + { + /// + /// The device to which the input tensor should be moved. + /// + public Device Device { get; set; } + + /// + /// Returns the input tensor moved to the specified device. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tensor.to(Device); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/ToImage.cs b/src/Bonsai.ML.Torch/Tensors/ToImage.cs new file mode 100644 index 00000000..eebf8399 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/ToImage.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Converts the input tensor into an OpenCV image. + /// + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToImage + { + /// + /// Converts the input tensor into an OpenCV image. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToImage); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/ToMat.cs b/src/Bonsai.ML.Torch/Tensors/ToMat.cs new file mode 100644 index 00000000..756ac636 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/ToMat.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Converts the input tensor into an OpenCV mat. + /// + [Combinator] + [Description("Converts the input tensor into an OpenCV mat.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToMat + { + /// + /// Converts the input tensor into an OpenCV mat. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToMat); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/ToTensor.cs b/src/Bonsai.ML.Torch/Tensors/ToTensor.cs new file mode 100644 index 00000000..753d4422 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/ToTensor.cs @@ -0,0 +1,134 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Converts the input value into a tensor. + /// + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class ToTensor + { + /// + /// Converts an int into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a double into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a byte into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a bool into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a float into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a long into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts a short into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts an array into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return as_tensor(value); + }); + } + + /// + /// Converts an IplImage into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToTensor); + } + + /// + /// Converts a Mat into a tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Helpers.OpenCVHelper.ToTensor); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tensors/Zeros.cs b/src/Bonsai.ML.Torch/Tensors/Zeros.cs new file mode 100644 index 00000000..256a43ed --- /dev/null +++ b/src/Bonsai.ML.Torch/Tensors/Zeros.cs @@ -0,0 +1,30 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Tensors +{ + /// + /// Creates a tensor filled with zeros. + /// + [Combinator] + [Description("Creates a tensor filled with zeros.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class Zeros + { + /// + /// The size of the tensor. + /// + public long[] Size { get; set; } = [0]; + + /// + /// Generates an observable sequence of tensors filled with zeros. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => Observable.Return(ones(Size))); + } + } +} \ No newline at end of file From ba7f53b57a368bd172dd06cedc159f51c576bc03 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 27 Aug 2024 12:53:06 +0100 Subject: [PATCH 24/75] Removed previous Bonsai.ML.Tensors directory and contents --- src/Bonsai.ML.Tensors/Arange.cs | 40 --- .../Bonsai.ML.Tensors.csproj | 17 -- src/Bonsai.ML.Tensors/Concat.cs | 45 ---- src/Bonsai.ML.Tensors/ConvertDataType.cs | 32 --- src/Bonsai.ML.Tensors/CreateTensor.cs | 245 ------------------ src/Bonsai.ML.Tensors/Helpers/DataHelper.cs | 190 -------------- src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs | 91 ------- src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs | 169 ------------ .../Helpers/TensorDataTypeHelper.cs | 52 ---- src/Bonsai.ML.Tensors/Index.cs | 35 --- .../InitializeTorchDevice.cs | 35 --- src/Bonsai.ML.Tensors/Linspace.cs | 40 --- src/Bonsai.ML.Tensors/MeshGrid.cs | 33 --- src/Bonsai.ML.Tensors/Ones.cs | 30 --- src/Bonsai.ML.Tensors/Permute.cs | 33 --- src/Bonsai.ML.Tensors/Reshape.cs | 32 --- src/Bonsai.ML.Tensors/Set.cs | 48 ---- src/Bonsai.ML.Tensors/TensorDataType.cs | 56 ---- src/Bonsai.ML.Tensors/ToArray.cs | 73 ------ src/Bonsai.ML.Tensors/ToDevice.cs | 34 --- src/Bonsai.ML.Tensors/ToImage.cs | 28 -- src/Bonsai.ML.Tensors/ToMat.cs | 28 -- src/Bonsai.ML.Tensors/ToTensor.cs | 134 ---------- src/Bonsai.ML.Tensors/Zeros.cs | 30 --- 24 files changed, 1550 deletions(-) delete mode 100644 src/Bonsai.ML.Tensors/Arange.cs delete mode 100644 src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj delete mode 100644 src/Bonsai.ML.Tensors/Concat.cs delete mode 100644 src/Bonsai.ML.Tensors/ConvertDataType.cs delete mode 100644 src/Bonsai.ML.Tensors/CreateTensor.cs delete mode 100644 src/Bonsai.ML.Tensors/Helpers/DataHelper.cs delete mode 100644 src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs delete mode 100644 src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs delete mode 100644 src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs delete mode 100644 src/Bonsai.ML.Tensors/Index.cs delete mode 100644 src/Bonsai.ML.Tensors/InitializeTorchDevice.cs delete mode 100644 src/Bonsai.ML.Tensors/Linspace.cs delete mode 100644 src/Bonsai.ML.Tensors/MeshGrid.cs delete mode 100644 src/Bonsai.ML.Tensors/Ones.cs delete mode 100644 src/Bonsai.ML.Tensors/Permute.cs delete mode 100644 src/Bonsai.ML.Tensors/Reshape.cs delete mode 100644 src/Bonsai.ML.Tensors/Set.cs delete mode 100644 src/Bonsai.ML.Tensors/TensorDataType.cs delete mode 100644 src/Bonsai.ML.Tensors/ToArray.cs delete mode 100644 src/Bonsai.ML.Tensors/ToDevice.cs delete mode 100644 src/Bonsai.ML.Tensors/ToImage.cs delete mode 100644 src/Bonsai.ML.Tensors/ToMat.cs delete mode 100644 src/Bonsai.ML.Tensors/ToTensor.cs delete mode 100644 src/Bonsai.ML.Tensors/Zeros.cs diff --git a/src/Bonsai.ML.Tensors/Arange.cs b/src/Bonsai.ML.Tensors/Arange.cs deleted file mode 100644 index 2a1eda40..00000000 --- a/src/Bonsai.ML.Tensors/Arange.cs +++ /dev/null @@ -1,40 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; -using TorchSharp; - -namespace Bonsai.ML.Tensors -{ - /// - /// Creates a 1-D tensor of values within a given range given the start, end, and step. - /// - [Combinator] - [Description("Creates a 1-D tensor of values within a given range given the start, end, and step.")] - [WorkflowElementCategory(ElementCategory.Source)] - public class Arange - { - /// - /// The start of the range. - /// - public int Start { get; set; } = 0; - - /// - /// The end of the range. - /// - public int End { get; set; } = 10; - - /// - /// The step of the range. - /// - public int Step { get; set; } = 1; - - /// - /// Generates an observable sequence of 1-D tensors created with the function. - /// - public IObservable Process() - { - return Observable.Defer(() => Observable.Return(arange(Start, End, Step))); - } - } -} diff --git a/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj b/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj deleted file mode 100644 index 8d87ac9b..00000000 --- a/src/Bonsai.ML.Tensors/Bonsai.ML.Tensors.csproj +++ /dev/null @@ -1,17 +0,0 @@ - - - Bonsai.ML.Tensors - A Bonsai package for TorchSharp tensor manipulations. - Bonsai Rx ML Tensors TorchSharp - net472;netstandard2.0 - true - - - - - - - - - - \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Concat.cs b/src/Bonsai.ML.Tensors/Concat.cs deleted file mode 100644 index 1dd99b7b..00000000 --- a/src/Bonsai.ML.Tensors/Concat.cs +++ /dev/null @@ -1,45 +0,0 @@ -using System; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Concatenates tensors along a given dimension. - /// - [Combinator] - [Description("Concatenates tensors along a given dimension.")] - [WorkflowElementCategory(ElementCategory.Combinator)] - public class Concat - { - /// - /// The dimension along which to concatenate the tensors. - /// - public long Dimension { get; set; } = 0; - - /// - /// Takes any number of observable sequences of tensors and concatenates the input tensors along the specified dimension by zipping each tensor together. - /// - public IObservable Process(params IObservable[] sources) - { - return sources.Aggregate((current, next) => - current.Zip(next, (tensor1, tensor2) => - cat([tensor1, tensor2], Dimension))); - } - - /// - /// Concatenates the input tensors along the specified dimension. - /// - public IObservable Process(IObservable> source) - { - return source.Select(value => - { - var tensor1 = value.Item1; - var tensor2 = value.Item2; - return cat([tensor1, tensor2], Dimension); - }); - } - } -} diff --git a/src/Bonsai.ML.Tensors/ConvertDataType.cs b/src/Bonsai.ML.Tensors/ConvertDataType.cs deleted file mode 100644 index 14b0db84..00000000 --- a/src/Bonsai.ML.Tensors/ConvertDataType.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Converts the input tensor to the specified scalar type. - /// - [Combinator] - [Description("Converts the input tensor to the specified scalar type.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class ConvertDataType - { - /// - /// The scalar type to which to convert the input tensor. - /// - public ScalarType Type { get; set; } = ScalarType.Float32; - - /// - /// Returns an observable sequence that converts the input tensor to the specified scalar type. - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => - { - return tensor.to_type(Type); - }); - } - } -} diff --git a/src/Bonsai.ML.Tensors/CreateTensor.cs b/src/Bonsai.ML.Tensors/CreateTensor.cs deleted file mode 100644 index 712c7243..00000000 --- a/src/Bonsai.ML.Tensors/CreateTensor.cs +++ /dev/null @@ -1,245 +0,0 @@ -using System; -using System.Collections.Generic; -using System.ComponentModel; -using System.Linq; -using System.Linq.Expressions; -using System.Reactive.Linq; -using System.Reflection; -using System.Xml.Serialization; -using Bonsai.Expressions; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". - /// - [Combinator] - [Description("Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: \"[[1, 2], [3, 4]]\".")] - [WorkflowElementCategory(ElementCategory.Source)] - public class CreateTensor : ExpressionBuilder - { - Range argumentRange = new Range(0, 1); - - /// - public override Range ArgumentRange => argumentRange; - - /// - /// The data type of the tensor elements. - /// - public TensorDataType Type - { - get => scalarType; - set => scalarType = value; - } - - private TensorDataType scalarType = TensorDataType.Float32; - - /// - /// The values of the tensor elements. Uses Python-like syntax to specify the tensor values. - /// - public string Values - { - get => values; - set - { - values = value.Replace("False", "false").Replace("True", "true"); - } - } - - private string values = "[0]"; - - /// - /// The device on which to create the tensor. - /// - [XmlIgnore] - public Device Device { get => device; set => device = value; } - - private Device device = null; - - private Expression BuildTensorFromArray(Array arrayValues, Type returnType) - { - var rank = arrayValues.Rank; - var lengths = new int[rank]; - for (int i = 0; i < rank; i++) - { - lengths[i] = arrayValues.GetLength(i); - } - - var arrayCreationExpression = Expression.NewArrayBounds(returnType, lengths.Select(len => Expression.Constant(len)).ToArray()); - var arrayVariable = Expression.Variable(arrayCreationExpression.Type, "array"); - var assignArray = Expression.Assign(arrayVariable, arrayCreationExpression); - - var assignments = new List(); - for (int i = 0; i < values.Length; i++) - { - var indices = new Expression[rank]; - int temp = i; - for (int j = rank - 1; j >= 0; j--) - { - indices[j] = Expression.Constant(temp % lengths[j]); - temp /= lengths[j]; - } - var value = Expression.Constant(arrayValues.GetValue(indices.Select(e => ((ConstantExpression)e).Value).Cast().ToArray())); - var arrayAccess = Expression.ArrayAccess(arrayVariable, indices); - var assignArrayValue = Expression.Assign(arrayAccess, value); - assignments.Add(assignArrayValue); - } - - var tensorDataInitializationBlock = Expression.Block( - arrayVariable, - assignArray, - Expression.Block(assignments), - arrayVariable - ); - - var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( - "tensor", [ - arrayVariable.Type, - typeof(ScalarType?), - typeof(Device), - typeof(bool), - typeof(string).MakeArrayType() - ] - ); - - var tensorAssignment = Expression.Call( - tensorCreationMethodInfo, - tensorDataInitializationBlock, - Expression.Constant(scalarType, typeof(ScalarType?)), - Expression.Constant(device, typeof(Device)), - Expression.Constant(false, typeof(bool)), - Expression.Constant(null, typeof(string).MakeArrayType()) - ); - - var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); - var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); - - var buildTensor = Expression.Block( - tensorVariable, - assignTensor, - tensorVariable - ); - - return buildTensor; - } - - private Expression BuildTensorFromScalarValue(object scalarValue, Type returnType) - { - var valueVariable = Expression.Variable(returnType, "value"); - var assignValue = Expression.Assign(valueVariable, Expression.Constant(scalarValue, returnType)); - - var tensorDataInitializationBlock = Expression.Block( - valueVariable, - assignValue, - valueVariable - ); - - var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( - "tensor", [ - valueVariable.Type, - typeof(Device), - typeof(bool) - ] - ); - - var tensorCreationMethodArguments = new Expression[] { - Expression.Constant(device, typeof(Device) ), - Expression.Constant(false, typeof(bool) ) - }; - - if (tensorCreationMethodInfo == null) - { - tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( - "tensor", [ - valueVariable.Type, - typeof(ScalarType?), - typeof(Device), - typeof(bool) - ] - ); - - tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( - Expression.Constant(scalarType, typeof(ScalarType?)) - ).ToArray(); - } - - tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( - tensorDataInitializationBlock - ).ToArray(); - - var tensorAssignment = Expression.Call( - tensorCreationMethodInfo, - tensorCreationMethodArguments - ); - - var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); - var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); - - var buildTensor = Expression.Block( - tensorVariable, - assignTensor, - tensorVariable - ); - - return buildTensor; - } - - /// - public override Expression Build(IEnumerable arguments) - { - var returnType = Helpers.TensorDataTypeHelper.GetTypeFromTensorDataType(scalarType); - var argTypes = arguments.Select(arg => arg.Type).ToArray(); - - var methodInfoArgumentTypes = new Type[] { - typeof(Tensor) - }; - - var methods = typeof(CreateTensor).GetMethods(BindingFlags.Public | BindingFlags.Instance) - .Where(m => m.Name == "Process") - .ToArray(); - - var methodInfo = arguments.Count() > 0 ? methods.FirstOrDefault(m => m.IsGenericMethod) - .MakeGenericMethod( - arguments - .First() - .Type - .GetGenericArguments()[0] - ) : methods.FirstOrDefault(m => !m.IsGenericMethod); - - var tensorValues = Helpers.DataHelper.ParseString(values, returnType); - var buildTensor = tensorValues is Array arrayValues ? BuildTensorFromArray(arrayValues, returnType) : BuildTensorFromScalarValue(tensorValues, returnType); - var methodArguments = arguments.Count() == 0 ? [ buildTensor ] : arguments.Concat([ buildTensor ]); - - try - { - return Expression.Call( - Expression.Constant(this), - methodInfo, - methodArguments - ); - } - finally - { - values = Helpers.DataHelper.SerializeData(tensorValues).Replace("False", "false").Replace("True", "true"); - scalarType = Helpers.TensorDataTypeHelper.GetTensorDataTypeFromType(returnType); - } - } - - /// - /// Returns an observable sequence that creates a tensor from the specified values. - /// - public IObservable Process(Tensor tensor) - { - return Observable.Return(tensor); - } - - /// - /// Returns an observable sequence that creates a tensor from the specified values for each element in the input sequence. - /// - public IObservable Process(IObservable source, Tensor tensor) - { - return Observable.Select(source, (_) => tensor); - } - } -} diff --git a/src/Bonsai.ML.Tensors/Helpers/DataHelper.cs b/src/Bonsai.ML.Tensors/Helpers/DataHelper.cs deleted file mode 100644 index 1bbf3228..00000000 --- a/src/Bonsai.ML.Tensors/Helpers/DataHelper.cs +++ /dev/null @@ -1,190 +0,0 @@ -using System; -using System.Text; -using System.Collections.Generic; -using System.Linq; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; - -namespace Bonsai.ML.Tensors.Helpers -{ - /// - /// Provides helper methods for parsing tensor data types. - /// - public static class DataHelper - { - - /// - /// Serializes the input data into a string representation. - /// - public static string SerializeData(object data) - { - if (data is Array array) - { - return SerializeArray(array); - } - else - { - return JsonConvert.SerializeObject(data); - } - } - - /// - /// Serializes the input array into a string representation. - /// - public static string SerializeArray(Array array) - { - StringBuilder sb = new StringBuilder(); - SerializeArrayRecursive(array, sb, [0]); - return sb.ToString(); - } - - private static void SerializeArrayRecursive(Array array, StringBuilder sb, int[] indices) - { - if (indices.Length < array.Rank) - { - sb.Append("["); - int length = array.GetLength(indices.Length); - for (int i = 0; i < length; i++) - { - int[] newIndices = new int[indices.Length + 1]; - indices.CopyTo(newIndices, 0); - newIndices[indices.Length] = i; - SerializeArrayRecursive(array, sb, newIndices); - if (i < length - 1) - { - sb.Append(", "); - } - } - sb.Append("]"); - } - else - { - object value = array.GetValue(indices); - sb.Append(value.ToString()); - } - } - - private static bool IsValidJson(string input) - { - int squareBrackets = 0; - foreach (char c in input) - { - if (c == '[') squareBrackets++; - else if (c == ']') squareBrackets--; - } - return squareBrackets == 0; - } - - /// - /// Parses the input string into an object of the specified type. - /// - public static object ParseString(string input, Type dtype) - { - if (!IsValidJson(input)) - { - throw new ArgumentException("JSON is invalid."); - } - var obj = JsonConvert.DeserializeObject(input); - int depth = ParseDepth(obj); - if (depth == 0) - { - return Convert.ChangeType(input, dtype); - } - int[] dimensions = ParseDimensions(obj, depth); - var resultArray = Array.CreateInstance(dtype, dimensions); - PopulateArray(obj, resultArray, [0], dtype); - return resultArray; - } - - private static int ParseDepth(JToken token, int currentDepth = 0) - { - if (token is JArray arr && arr.Count > 0) - { - return ParseDepth(arr[0], currentDepth + 1); - } - return currentDepth; - } - - private static int[] ParseDimensions(JToken token, int depth, int currentLevel = 0) - { - if (depth == 0 || !(token is JArray)) - { - return [0]; - } - - List dimensions = new List(); - JToken current = token; - - while (current != null && current is JArray) - { - JArray currentArray = current as JArray; - dimensions.Add(currentArray.Count); - if (currentArray.Count > 0) - { - if (currentArray.Any(item => !(item is JArray)) && currentArray.Any(item => item is JArray) || currentArray.All(item => item is JArray) && currentArray.Any(item => ((JArray)item).Count != ((JArray)currentArray.First()).Count)) - { - throw new Exception("Error parsing input. Dimensions are inconsistent."); - } - - if (!(currentArray.First() is JArray)) - { - if (!currentArray.All(item => double.TryParse(item.ToString(), out _)) && !currentArray.All(item => bool.TryParse(item.ToString(), out _))) - { - throw new Exception("Error parsing types. All values must be of the same type and only numeric or boolean types are supported."); - } - } - } - - current = currentArray.Count > 0 ? currentArray[0] : null; - } - - if (currentLevel > 0 && token is JArray arr && arr.All(x => x is JArray)) - { - var subArrayDimensions = new HashSet(); - foreach (JArray subArr in arr) - { - int[] subDims = ParseDimensions(subArr, depth - currentLevel, currentLevel + 1); - subArrayDimensions.Add(string.Join(",", subDims)); - } - - if (subArrayDimensions.Count > 1) - { - throw new ArgumentException("Inconsistent array dimensions."); - } - } - - return dimensions.ToArray(); - } - - private static void PopulateArray(JToken token, Array array, int[] indices, Type dtype) - { - if (token is JArray arr) - { - for (int i = 0; i < arr.Count; i++) - { - int[] newIndices = new int[indices.Length + 1]; - Array.Copy(indices, newIndices, indices.Length); - newIndices[newIndices.Length - 1] = i; - PopulateArray(arr[i], array, newIndices, dtype); - } - } - else - { - var values = ConvertType(token, dtype); - array.SetValue(values, indices); - } - } - - private static object ConvertType(object value, Type targetType) - { - try - { - return Convert.ChangeType(value, targetType); - } - catch (Exception ex) - { - throw new Exception("Error parsing type: ", ex); - } - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs b/src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs deleted file mode 100644 index 785eccea..00000000 --- a/src/Bonsai.ML.Tensors/Helpers/IndexHelper.cs +++ /dev/null @@ -1,91 +0,0 @@ -using System; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors.Helpers -{ - /// - /// Provides helper methods to parse tensor indexes. - /// - public static class IndexHelper - { - - /// - /// Parses the input string into an array of tensor indexes. - /// - /// - public static TensorIndex[] ParseString(string input) - { - if (string.IsNullOrEmpty(input)) - { - return [0]; - } - - var indexStrings = input.Split(','); - var indices = new TensorIndex[indexStrings.Length]; - - for (int i = 0; i < indexStrings.Length; i++) - { - var indexString = indexStrings[i].Trim(); - if (int.TryParse(indexString, out int intIndex)) - { - indices[i] = TensorIndex.Single(intIndex); - } - else if (indexString == ":") - { - indices[i] = TensorIndex.Colon; - } - else if (indexString == "None") - { - indices[i] = TensorIndex.None; - } - else if (indexString == "...") - { - indices[i] = TensorIndex.Ellipsis; - } - else if (indexString.ToLower() == "false" || indexString.ToLower() == "true") - { - indices[i] = TensorIndex.Bool(indexString.ToLower() == "true"); - } - else if (indexString.Contains(":")) - { - var rangeParts = indexString.Split(':'); - if (rangeParts.Length == 0) - { - indices[i] = TensorIndex.Slice(); - } - else if (rangeParts.Length == 1) - { - indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0])); - } - else if (rangeParts.Length == 2) - { - indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1])); - } - else if (rangeParts.Length == 3) - { - indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1]), int.Parse(rangeParts[2])); - } - else - { - throw new Exception($"Invalid index format: {indexString}"); - } - } - else - { - throw new Exception($"Invalid index format: {indexString}"); - } - } - return indices; - } - - /// - /// Serializes the input array of tensor indexes into a string representation. - /// - /// - /// - public static string SerializeIndexes(TensorIndex[] indexes) - { - return string.Join(", ", indexes); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs b/src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs deleted file mode 100644 index 265f2119..00000000 --- a/src/Bonsai.ML.Tensors/Helpers/OpenCVHelper.cs +++ /dev/null @@ -1,169 +0,0 @@ -using System; -using System.Runtime.InteropServices; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using OpenCV.Net; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors.Helpers -{ - /// - /// Helper class to convert between OpenCV mats, images and Torch tensors. - /// - public static class OpenCVHelper - { - private static Dictionary bitDepthLookup = new Dictionary { - { ScalarType.Byte, (IplDepth.U8, Depth.U8) }, - { ScalarType.Int16, (IplDepth.S16, Depth.S16) }, - { ScalarType.Int32, (IplDepth.S32, Depth.S32) }, - { ScalarType.Float32, (IplDepth.F32, Depth.F32) }, - { ScalarType.Float64, (IplDepth.F64, Depth.F64) }, - { ScalarType.Int8, (IplDepth.S8, Depth.S8) } - }; - - private static ConcurrentDictionary deleters = new ConcurrentDictionary(); - - internal delegate void GCHandleDeleter(IntPtr memory); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_data(IntPtr handle); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_new(IntPtr rawArray, GCHandleDeleter deleter, IntPtr dimensions, int numDimensions, sbyte type, sbyte dtype, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad); - - /// - /// Creates a tensor from a pointer to the data and the dimensions of the tensor. - /// - /// - /// - /// - /// - public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dimensions, ScalarType dtype = ScalarType.Byte) - { - var dataHandle = GCHandle.Alloc(tensorDataPtr, GCHandleType.Pinned); - var gchp = GCHandle.ToIntPtr(dataHandle); - GCHandleDeleter deleter = null; - - deleter = new GCHandleDeleter((IntPtr ptrHandler) => - { - GCHandle.FromIntPtr(gchp).Free(); - deleters.TryRemove(deleter, out deleter); - }); - deleters.TryAdd(deleter, deleter); - - fixed (long* dimensionsPtr = dimensions) - { - IntPtr tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); - if (tensorHandle == IntPtr.Zero) { - GC.Collect(); - GC.WaitForPendingFinalizers(); - tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); - } - if (tensorHandle == IntPtr.Zero) { CheckForErrors(); } - var output = Tensor.UnsafeCreateTensor(tensorHandle); - return output; - } - } - - /// - /// Converts an OpenCV image to a Torch tensor. - /// - /// - /// - public static Tensor ToTensor(IplImage image) - { - if (image == null) - { - return empty([ 0, 0, 0 ]); - } - - int width = image.Width; - int height = image.Height; - int channels = image.Channels; - - var iplDepth = image.Depth; - var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; - - IntPtr tensorDataPtr = image.ImageData; - long[] dimensions = [ height, width, channels ]; - if (tensorDataPtr == IntPtr.Zero) - { - return empty(dimensions); - } - return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); - } - - /// - /// Converts an OpenCV mat to a Torch tensor. - /// - /// - /// - public static Tensor ToTensor(Mat mat) - { - if (mat == null) - { - return empty([0, 0, 0 ]); - } - - int width = mat.Size.Width; - int height = mat.Size.Height; - int channels = mat.Channels; - - var depth = mat.Depth; - var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.Depth == depth).Key; - - IntPtr tensorDataPtr = mat.Data; - long[] dimensions = [ height, width, channels ]; - if (tensorDataPtr == IntPtr.Zero) - { - return empty(dimensions); - } - return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); - } - - /// - /// Converts a Torch tensor to an OpenCV image. - /// - /// - /// - public unsafe static IplImage ToImage(Tensor tensor) - { - var height = (int)tensor.shape[0]; - var width = (int)tensor.shape[1]; - var channels = (int)tensor.shape[2]; - - var tensorType = tensor.dtype; - var iplDepth = bitDepthLookup[tensorType].IplDepth; - - var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); - - var res = THSTensor_data(new_tensor.Handle); - var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels, res); - - return image; - } - - /// - /// Converts a Torch tensor to an OpenCV mat. - /// - /// - /// - public unsafe static Mat ToMat(Tensor tensor) - { - var height = (int)tensor.shape[0]; - var width = (int)tensor.shape[1]; - var channels = (int)tensor.shape[2]; - - var tensorType = tensor.dtype; - var depth = bitDepthLookup[tensorType].Depth; - - var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); - - var res = THSTensor_data(new_tensor.Handle); - var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels, res); - - return mat; - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs b/src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs deleted file mode 100644 index 7ea03f65..00000000 --- a/src/Bonsai.ML.Tensors/Helpers/TensorDataTypeHelper.cs +++ /dev/null @@ -1,52 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; - -namespace Bonsai.ML.Tensors.Helpers -{ - /// - /// Provides helper methods for working with tensor data types. - /// - public class TensorDataTypeHelper - { - private static readonly Dictionary _lookup = new Dictionary - { - { TensorDataType.Byte, (typeof(byte), "byte") }, - { TensorDataType.Int16, (typeof(short), "short") }, - { TensorDataType.Int32, (typeof(int), "int") }, - { TensorDataType.Int64, (typeof(long), "long") }, - { TensorDataType.Float32, (typeof(float), "float") }, - { TensorDataType.Float64, (typeof(double), "double") }, - { TensorDataType.Bool, (typeof(bool), "bool") }, - { TensorDataType.Int8, (typeof(sbyte), "sbyte") }, - }; - - /// - /// Returns the type corresponding to the specified tensor data type. - /// - /// - /// - public static Type GetTypeFromTensorDataType(TensorDataType type) => _lookup[type].Type; - - /// - /// Returns the string representation corresponding to the specified tensor data type. - /// - /// - /// - public static string GetStringFromTensorDataType(TensorDataType type) => _lookup[type].StringValue; - - /// - /// Returns the tensor data type corresponding to the specified string representation. - /// - /// - /// - public static TensorDataType GetTensorDataTypeFromString(string value) => _lookup.First(x => x.Value.StringValue == value).Key; - - /// - /// Returns the tensor data type corresponding to the specified type. - /// - /// - /// - public static TensorDataType GetTensorDataTypeFromType(Type type) => _lookup.First(x => x.Value.Type == type).Key; - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Index.cs b/src/Bonsai.ML.Tensors/Index.cs deleted file mode 100644 index 3c1948f9..00000000 --- a/src/Bonsai.ML.Tensors/Index.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. - /// Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis. - /// - [Combinator] - [Description("Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Index - { - /// - /// The indices to use for indexing the tensor. - /// - public string Indexes { get; set; } = string.Empty; - - /// - /// Indexes the input tensor with the specified indices. - /// - /// - /// - public IObservable Process(IObservable source) - { - var index = Helpers.IndexHelper.ParseString(Indexes); - return source.Select(tensor => { - return tensor.index(index); - }); - } - } -} diff --git a/src/Bonsai.ML.Tensors/InitializeTorchDevice.cs b/src/Bonsai.ML.Tensors/InitializeTorchDevice.cs deleted file mode 100644 index dc9123f0..00000000 --- a/src/Bonsai.ML.Tensors/InitializeTorchDevice.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; -using TorchSharp; - -namespace Bonsai.ML.Tensors -{ - /// - /// Initializes the Torch device with the specified device type. - /// - [Combinator] - [Description("Initializes the Torch device with the specified device type.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class InitializeTorchDevice - { - /// - /// The device type to initialize. - /// - public DeviceType DeviceType { get; set; } - - /// - /// Initializes the Torch device with the specified device type. - /// - /// - public IObservable Process() - { - return Observable.Defer(() => - { - InitializeDeviceType(DeviceType); - return Observable.Return(new Device(DeviceType)); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Linspace.cs b/src/Bonsai.ML.Tensors/Linspace.cs deleted file mode 100644 index aa263500..00000000 --- a/src/Bonsai.ML.Tensors/Linspace.cs +++ /dev/null @@ -1,40 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count. - /// - [Combinator] - [Description("Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count.")] - [WorkflowElementCategory(ElementCategory.Source)] - public class Linspace - { - /// - /// The start of the range. - /// - public int Start { get; set; } = 0; - - /// - /// The end of the range. - /// - public int End { get; set; } = 1; - - /// - /// The number of points to generate. - /// - public int Count { get; set; } = 10; - - /// - /// Generates an observable sequence of 1-D tensors created with the function. - /// - /// - public IObservable Process() - { - return Observable.Defer(() => Observable.Return(linspace(Start, End, Count))); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/MeshGrid.cs b/src/Bonsai.ML.Tensors/MeshGrid.cs deleted file mode 100644 index 6b0a2c73..00000000 --- a/src/Bonsai.ML.Tensors/MeshGrid.cs +++ /dev/null @@ -1,33 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using System.Collections.Generic; -using static TorchSharp.torch; -using System.Linq; - -namespace Bonsai.ML.Tensors -{ - /// - /// Creates a mesh grid from an observable sequence of enumerable of 1-D tensors. - /// - [Combinator] - [Description("")] - [WorkflowElementCategory(ElementCategory.Source)] - public class MeshGrid - { - /// - /// The indexing mode to use for the mesh grid. - /// - public string Indexing { get; set; } = "ij"; - - /// - /// Creates a mesh grid from the input tensors. - /// - /// - /// - public IObservable Process(IObservable> source) - { - return source.Select(tensors => meshgrid(tensors, indexing: Indexing)); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Ones.cs b/src/Bonsai.ML.Tensors/Ones.cs deleted file mode 100644 index 499012bd..00000000 --- a/src/Bonsai.ML.Tensors/Ones.cs +++ /dev/null @@ -1,30 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Creates a tensor filled with ones. - /// - [Combinator] - [Description("Creates a tensor filled with ones.")] - [WorkflowElementCategory(ElementCategory.Source)] - public class Ones - { - /// - /// The size of the tensor. - /// - public long[] Size { get; set; } = [0]; - - /// - /// Generates an observable sequence of tensors filled with ones. - /// - /// - public IObservable Process() - { - return Observable.Defer(() => Observable.Return(ones(Size))); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Permute.cs b/src/Bonsai.ML.Tensors/Permute.cs deleted file mode 100644 index 7f037d79..00000000 --- a/src/Bonsai.ML.Tensors/Permute.cs +++ /dev/null @@ -1,33 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Permutes the dimensions of the input tensor according to the specified permutation. - /// - [Combinator] - [Description("Permutes the dimensions of the input tensor according to the specified permutation.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Permute - { - /// - /// The permutation of the dimensions. - /// - public long[] Dimensions { get; set; } = [0]; - - /// - /// Returns an observable sequence that permutes the dimensions of the input tensor according to the specified permutation. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => { - return tensor.permute(Dimensions); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Reshape.cs b/src/Bonsai.ML.Tensors/Reshape.cs deleted file mode 100644 index 4fef3d83..00000000 --- a/src/Bonsai.ML.Tensors/Reshape.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Reshapes the input tensor according to the specified dimensions. - /// - [Combinator] - [Description("Reshapes the input tensor according to the specified dimensions.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Reshape - { - /// - /// The dimensions of the reshaped tensor. - /// - public long[] Dimensions { get; set; } = [0]; - - /// - /// Reshapes the input tensor according to the specified dimensions. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(input => input.reshape(Dimensions)); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Set.cs b/src/Bonsai.ML.Tensors/Set.cs deleted file mode 100644 index 7f6f8b92..00000000 --- a/src/Bonsai.ML.Tensors/Set.cs +++ /dev/null @@ -1,48 +0,0 @@ -using System; -using System.Collections.Generic; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using System.Xml.Serialization; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Sets the value of the input tensor at the specified index. - /// - [Combinator] - [Description("Sets the value of the input tensor at the specified index.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Set - { - /// - /// The index at which to set the value. - /// - public string Index - { - get => Helpers.IndexHelper.SerializeIndexes(indexes); - set => indexes = Helpers.IndexHelper.ParseString(value); - } - - private TensorIndex[] indexes; - - /// - /// The value to set at the specified index. - /// - [XmlIgnore] - public Tensor Value { get; set; } = null; - - /// - /// Returns an observable sequence that sets the value of the input tensor at the specified index. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => { - return tensor.index_put_(Value, indexes); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/TensorDataType.cs b/src/Bonsai.ML.Tensors/TensorDataType.cs deleted file mode 100644 index a710a9ed..00000000 --- a/src/Bonsai.ML.Tensors/TensorDataType.cs +++ /dev/null @@ -1,56 +0,0 @@ -using System; -using System.Text; -using System.Collections.Generic; -using System.Linq; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Represents the data type of the tensor elements. Contains currently supported data types. A subset of the available ScalarType data types in TorchSharp. - /// - public enum TensorDataType - { - /// - /// 8-bit unsigned integer. - /// - Byte = ScalarType.Byte, - - /// - /// 8-bit signed integer. - /// - Int8 = ScalarType.Int8, - - /// - /// 16-bit signed integer. - /// - Int16 = ScalarType.Int16, - - /// - /// 32-bit signed integer. - /// - Int32 = ScalarType.Int32, - - /// - /// 64-bit signed integer. - /// - Int64 = ScalarType.Int64, - - /// - /// 32-bit floating point. - /// - Float32 = ScalarType.Float32, - - /// - /// 64-bit floating point. - /// - Float64 = ScalarType.Float64, - - /// - /// Boolean. - /// - Bool = ScalarType.Bool - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToArray.cs b/src/Bonsai.ML.Tensors/ToArray.cs deleted file mode 100644 index af35ab4f..00000000 --- a/src/Bonsai.ML.Tensors/ToArray.cs +++ /dev/null @@ -1,73 +0,0 @@ -using System; -using System.Collections.Generic; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using System.Xml.Serialization; -using System.Linq.Expressions; -using System.Reflection; -using Bonsai.Expressions; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Converts the input tensor into an array of the specified element type. - /// - [Combinator] - [Description("Converts the input tensor into an array of the specified element type.")] - [WorkflowElementCategory(ElementCategory.Transform)] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - [XmlInclude(typeof(TypeMapping))] - public class ToArray : SingleArgumentExpressionBuilder - { - /// - /// Initializes a new instance of the class. - /// - public ToArray() - { - Type = new TypeMapping(); - } - - /// - /// Gets or sets the type mapping used to convert the input tensor into an array. - /// - public TypeMapping Type { get; set; } - - /// - public override Expression Build(IEnumerable arguments) - { - TypeMapping typeMapping = Type; - var returnType = typeMapping.GetType().GetGenericArguments()[0]; - MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); - methodInfo = methodInfo.MakeGenericMethod(returnType); - Expression sourceExpression = arguments.First(); - - return Expression.Call( - Expression.Constant(this), - methodInfo, - sourceExpression - ); - } - - /// - /// Converts the input tensor into an array of the specified element type. - /// - /// - /// - /// - public IObservable Process(IObservable source) where T : unmanaged - { - return source.Select(tensor => - { - return tensor.data().ToArray(); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToDevice.cs b/src/Bonsai.ML.Tensors/ToDevice.cs deleted file mode 100644 index 574be5f3..00000000 --- a/src/Bonsai.ML.Tensors/ToDevice.cs +++ /dev/null @@ -1,34 +0,0 @@ -using System; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Moves the input tensor to the specified device. - /// - [Combinator] - [Description("Moves the input tensor to the specified device.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class ToDevice - { - /// - /// The device to which the input tensor should be moved. - /// - public Device Device { get; set; } - - /// - /// Returns the input tensor moved to the specified device. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => { - return tensor.to(Device); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToImage.cs b/src/Bonsai.ML.Tensors/ToImage.cs deleted file mode 100644 index e29a3825..00000000 --- a/src/Bonsai.ML.Tensors/ToImage.cs +++ /dev/null @@ -1,28 +0,0 @@ -using System; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using OpenCV.Net; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Converts the input tensor into an OpenCV image. - /// - [Combinator] - [Description("")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class ToImage - { - /// - /// Converts the input tensor into an OpenCV image. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(Helpers.OpenCVHelper.ToImage); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToMat.cs b/src/Bonsai.ML.Tensors/ToMat.cs deleted file mode 100644 index 8a22f408..00000000 --- a/src/Bonsai.ML.Tensors/ToMat.cs +++ /dev/null @@ -1,28 +0,0 @@ -using System; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using OpenCV.Net; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Converts the input tensor into an OpenCV mat. - /// - [Combinator] - [Description("Converts the input tensor into an OpenCV mat.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class ToMat - { - /// - /// Converts the input tensor into an OpenCV mat. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(Helpers.OpenCVHelper.ToMat); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/ToTensor.cs b/src/Bonsai.ML.Tensors/ToTensor.cs deleted file mode 100644 index 083e2797..00000000 --- a/src/Bonsai.ML.Tensors/ToTensor.cs +++ /dev/null @@ -1,134 +0,0 @@ -using System; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using OpenCV.Net; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Converts the input value into a tensor. - /// - [Combinator] - [Description("")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class ToTensor - { - /// - /// Converts an int into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts a double into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts a byte into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts a bool into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts a float into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts a long into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts a short into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts an array into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(value => { - return as_tensor(value); - }); - } - - /// - /// Converts an IplImage into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(Helpers.OpenCVHelper.ToTensor); - } - - /// - /// Converts a Mat into a tensor. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(Helpers.OpenCVHelper.ToTensor); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Tensors/Zeros.cs b/src/Bonsai.ML.Tensors/Zeros.cs deleted file mode 100644 index af220641..00000000 --- a/src/Bonsai.ML.Tensors/Zeros.cs +++ /dev/null @@ -1,30 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Tensors -{ - /// - /// Creates a tensor filled with zeros. - /// - [Combinator] - [Description("Creates a tensor filled with zeros.")] - [WorkflowElementCategory(ElementCategory.Source)] - public class Zeros - { - /// - /// The size of the tensor. - /// - public long[] Size { get; set; } = [0]; - - /// - /// Generates an observable sequence of tensors filled with zeros. - /// - /// - public IObservable Process() - { - return Observable.Defer(() => Observable.Return(ones(Size))); - } - } -} \ No newline at end of file From 10aa92b38023391f13d4d7f492e2d76c3726be71 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 27 Aug 2024 12:54:13 +0100 Subject: [PATCH 25/75] Updated solution to reflect change --- Bonsai.ML.sln | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index 22b8a35a..30c6b6f1 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -30,7 +30,7 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.LinearDynamicalSy EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.HiddenMarkovModels.Design", "src\Bonsai.ML.HiddenMarkovModels.Design\Bonsai.ML.HiddenMarkovModels.Design.csproj", "{FC395DDC-62A4-4E14-A198-272AB05B33C7}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Tensors", "src\Bonsai.ML.Tensors\Bonsai.ML.Tensors.csproj", "{06FCC9AF-CE38-44BB-92B3-0D451BE88537}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch", "src\Bonsai.ML.Torch\Bonsai.ML.Torch.csproj", "{06FCC9AF-CE38-44BB-92B3-0D451BE88537}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution From f7e372cc33166d56bd667f93fbd7ea7188d244d5 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 27 Aug 2024 13:30:34 +0100 Subject: [PATCH 26/75] Moved Tensors namespace to main Torch namespace --- src/Bonsai.ML.Torch/{Tensors => }/Arange.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Concat.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/ConvertDataType.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/CreateTensor.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Index.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/InitializeTorchDevice.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Linspace.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/MeshGrid.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Ones.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Permute.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Reshape.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Set.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/TensorDataType.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/ToArray.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/ToDevice.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/ToImage.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/ToMat.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/ToTensor.cs | 2 +- src/Bonsai.ML.Torch/{Tensors => }/Zeros.cs | 2 +- 19 files changed, 19 insertions(+), 19 deletions(-) rename src/Bonsai.ML.Torch/{Tensors => }/Arange.cs (97%) rename src/Bonsai.ML.Torch/{Tensors => }/Concat.cs (97%) rename src/Bonsai.ML.Torch/{Tensors => }/ConvertDataType.cs (96%) rename src/Bonsai.ML.Torch/{Tensors => }/CreateTensor.cs (99%) rename src/Bonsai.ML.Torch/{Tensors => }/Index.cs (97%) rename src/Bonsai.ML.Torch/{Tensors => }/InitializeTorchDevice.cs (96%) rename src/Bonsai.ML.Torch/{Tensors => }/Linspace.cs (97%) rename src/Bonsai.ML.Torch/{Tensors => }/MeshGrid.cs (96%) rename src/Bonsai.ML.Torch/{Tensors => }/Ones.cs (95%) rename src/Bonsai.ML.Torch/{Tensors => }/Permute.cs (96%) rename src/Bonsai.ML.Torch/{Tensors => }/Reshape.cs (96%) rename src/Bonsai.ML.Torch/{Tensors => }/Set.cs (97%) rename src/Bonsai.ML.Torch/{Tensors => }/TensorDataType.cs (97%) rename src/Bonsai.ML.Torch/{Tensors => }/ToArray.cs (98%) rename src/Bonsai.ML.Torch/{Tensors => }/ToDevice.cs (96%) rename src/Bonsai.ML.Torch/{Tensors => }/ToImage.cs (95%) rename src/Bonsai.ML.Torch/{Tensors => }/ToMat.cs (95%) rename src/Bonsai.ML.Torch/{Tensors => }/ToTensor.cs (99%) rename src/Bonsai.ML.Torch/{Tensors => }/Zeros.cs (95%) diff --git a/src/Bonsai.ML.Torch/Tensors/Arange.cs b/src/Bonsai.ML.Torch/Arange.cs similarity index 97% rename from src/Bonsai.ML.Torch/Tensors/Arange.cs rename to src/Bonsai.ML.Torch/Arange.cs index 011d0708..14e3259b 100644 --- a/src/Bonsai.ML.Torch/Tensors/Arange.cs +++ b/src/Bonsai.ML.Torch/Arange.cs @@ -4,7 +4,7 @@ using static TorchSharp.torch; using TorchSharp; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Creates a 1-D tensor of values within a given range given the start, end, and step. diff --git a/src/Bonsai.ML.Torch/Tensors/Concat.cs b/src/Bonsai.ML.Torch/Concat.cs similarity index 97% rename from src/Bonsai.ML.Torch/Tensors/Concat.cs rename to src/Bonsai.ML.Torch/Concat.cs index 52275bb7..b07b211d 100644 --- a/src/Bonsai.ML.Torch/Tensors/Concat.cs +++ b/src/Bonsai.ML.Torch/Concat.cs @@ -4,7 +4,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Concatenates tensors along a given dimension. diff --git a/src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs b/src/Bonsai.ML.Torch/ConvertDataType.cs similarity index 96% rename from src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs rename to src/Bonsai.ML.Torch/ConvertDataType.cs index 3683d2a6..59981adc 100644 --- a/src/Bonsai.ML.Torch/Tensors/ConvertDataType.cs +++ b/src/Bonsai.ML.Torch/ConvertDataType.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Converts the input tensor to the specified scalar type. diff --git a/src/Bonsai.ML.Torch/Tensors/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs similarity index 99% rename from src/Bonsai.ML.Torch/Tensors/CreateTensor.cs rename to src/Bonsai.ML.Torch/CreateTensor.cs index 4585b70b..0100f920 100644 --- a/src/Bonsai.ML.Torch/Tensors/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -9,7 +9,7 @@ using Bonsai.Expressions; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". diff --git a/src/Bonsai.ML.Torch/Tensors/Index.cs b/src/Bonsai.ML.Torch/Index.cs similarity index 97% rename from src/Bonsai.ML.Torch/Tensors/Index.cs rename to src/Bonsai.ML.Torch/Index.cs index 78024237..5b7b9192 100644 --- a/src/Bonsai.ML.Torch/Tensors/Index.cs +++ b/src/Bonsai.ML.Torch/Index.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. diff --git a/src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs similarity index 96% rename from src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs rename to src/Bonsai.ML.Torch/InitializeTorchDevice.cs index 2258467f..e82daa36 100644 --- a/src/Bonsai.ML.Torch/Tensors/InitializeTorchDevice.cs +++ b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs @@ -4,7 +4,7 @@ using static TorchSharp.torch; using TorchSharp; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Initializes the Torch device with the specified device type. diff --git a/src/Bonsai.ML.Torch/Tensors/Linspace.cs b/src/Bonsai.ML.Torch/Linspace.cs similarity index 97% rename from src/Bonsai.ML.Torch/Tensors/Linspace.cs rename to src/Bonsai.ML.Torch/Linspace.cs index 6e7495f8..ee6516cf 100644 --- a/src/Bonsai.ML.Torch/Tensors/Linspace.cs +++ b/src/Bonsai.ML.Torch/Linspace.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Creates a 1-D tensor of linearly interpolated values within a given range given the start, end, and count. diff --git a/src/Bonsai.ML.Torch/Tensors/MeshGrid.cs b/src/Bonsai.ML.Torch/MeshGrid.cs similarity index 96% rename from src/Bonsai.ML.Torch/Tensors/MeshGrid.cs rename to src/Bonsai.ML.Torch/MeshGrid.cs index 77f4cecb..725b12a9 100644 --- a/src/Bonsai.ML.Torch/Tensors/MeshGrid.cs +++ b/src/Bonsai.ML.Torch/MeshGrid.cs @@ -5,7 +5,7 @@ using static TorchSharp.torch; using System.Linq; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Creates a mesh grid from an observable sequence of enumerable of 1-D tensors. diff --git a/src/Bonsai.ML.Torch/Tensors/Ones.cs b/src/Bonsai.ML.Torch/Ones.cs similarity index 95% rename from src/Bonsai.ML.Torch/Tensors/Ones.cs rename to src/Bonsai.ML.Torch/Ones.cs index 77768dd1..52bf8732 100644 --- a/src/Bonsai.ML.Torch/Tensors/Ones.cs +++ b/src/Bonsai.ML.Torch/Ones.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Creates a tensor filled with ones. diff --git a/src/Bonsai.ML.Torch/Tensors/Permute.cs b/src/Bonsai.ML.Torch/Permute.cs similarity index 96% rename from src/Bonsai.ML.Torch/Tensors/Permute.cs rename to src/Bonsai.ML.Torch/Permute.cs index 317e34f8..a82107ba 100644 --- a/src/Bonsai.ML.Torch/Tensors/Permute.cs +++ b/src/Bonsai.ML.Torch/Permute.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Permutes the dimensions of the input tensor according to the specified permutation. diff --git a/src/Bonsai.ML.Torch/Tensors/Reshape.cs b/src/Bonsai.ML.Torch/Reshape.cs similarity index 96% rename from src/Bonsai.ML.Torch/Tensors/Reshape.cs rename to src/Bonsai.ML.Torch/Reshape.cs index 5d3e9412..ebdc8e41 100644 --- a/src/Bonsai.ML.Torch/Tensors/Reshape.cs +++ b/src/Bonsai.ML.Torch/Reshape.cs @@ -4,7 +4,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Reshapes the input tensor according to the specified dimensions. diff --git a/src/Bonsai.ML.Torch/Tensors/Set.cs b/src/Bonsai.ML.Torch/Set.cs similarity index 97% rename from src/Bonsai.ML.Torch/Tensors/Set.cs rename to src/Bonsai.ML.Torch/Set.cs index a4d8b2d2..14bf3dad 100644 --- a/src/Bonsai.ML.Torch/Tensors/Set.cs +++ b/src/Bonsai.ML.Torch/Set.cs @@ -6,7 +6,7 @@ using System.Xml.Serialization; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Sets the value of the input tensor at the specified index. diff --git a/src/Bonsai.ML.Torch/Tensors/TensorDataType.cs b/src/Bonsai.ML.Torch/TensorDataType.cs similarity index 97% rename from src/Bonsai.ML.Torch/Tensors/TensorDataType.cs rename to src/Bonsai.ML.Torch/TensorDataType.cs index de1ba8d2..fe8861f3 100644 --- a/src/Bonsai.ML.Torch/Tensors/TensorDataType.cs +++ b/src/Bonsai.ML.Torch/TensorDataType.cs @@ -6,7 +6,7 @@ using Newtonsoft.Json.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Represents the data type of the tensor elements. Contains currently supported data types. A subset of the available ScalarType data types in TorchSharp. diff --git a/src/Bonsai.ML.Torch/Tensors/ToArray.cs b/src/Bonsai.ML.Torch/ToArray.cs similarity index 98% rename from src/Bonsai.ML.Torch/Tensors/ToArray.cs rename to src/Bonsai.ML.Torch/ToArray.cs index 70083ad2..1c2c721a 100644 --- a/src/Bonsai.ML.Torch/Tensors/ToArray.cs +++ b/src/Bonsai.ML.Torch/ToArray.cs @@ -9,7 +9,7 @@ using Bonsai.Expressions; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Converts the input tensor into an array of the specified element type. diff --git a/src/Bonsai.ML.Torch/Tensors/ToDevice.cs b/src/Bonsai.ML.Torch/ToDevice.cs similarity index 96% rename from src/Bonsai.ML.Torch/Tensors/ToDevice.cs rename to src/Bonsai.ML.Torch/ToDevice.cs index 4aa1b92a..cb73f733 100644 --- a/src/Bonsai.ML.Torch/Tensors/ToDevice.cs +++ b/src/Bonsai.ML.Torch/ToDevice.cs @@ -4,7 +4,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Moves the input tensor to the specified device. diff --git a/src/Bonsai.ML.Torch/Tensors/ToImage.cs b/src/Bonsai.ML.Torch/ToImage.cs similarity index 95% rename from src/Bonsai.ML.Torch/Tensors/ToImage.cs rename to src/Bonsai.ML.Torch/ToImage.cs index eebf8399..894a9602 100644 --- a/src/Bonsai.ML.Torch/Tensors/ToImage.cs +++ b/src/Bonsai.ML.Torch/ToImage.cs @@ -5,7 +5,7 @@ using OpenCV.Net; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Converts the input tensor into an OpenCV image. diff --git a/src/Bonsai.ML.Torch/Tensors/ToMat.cs b/src/Bonsai.ML.Torch/ToMat.cs similarity index 95% rename from src/Bonsai.ML.Torch/Tensors/ToMat.cs rename to src/Bonsai.ML.Torch/ToMat.cs index 756ac636..fa50020c 100644 --- a/src/Bonsai.ML.Torch/Tensors/ToMat.cs +++ b/src/Bonsai.ML.Torch/ToMat.cs @@ -5,7 +5,7 @@ using OpenCV.Net; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Converts the input tensor into an OpenCV mat. diff --git a/src/Bonsai.ML.Torch/Tensors/ToTensor.cs b/src/Bonsai.ML.Torch/ToTensor.cs similarity index 99% rename from src/Bonsai.ML.Torch/Tensors/ToTensor.cs rename to src/Bonsai.ML.Torch/ToTensor.cs index 753d4422..5bb460de 100644 --- a/src/Bonsai.ML.Torch/Tensors/ToTensor.cs +++ b/src/Bonsai.ML.Torch/ToTensor.cs @@ -5,7 +5,7 @@ using OpenCV.Net; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Converts the input value into a tensor. diff --git a/src/Bonsai.ML.Torch/Tensors/Zeros.cs b/src/Bonsai.ML.Torch/Zeros.cs similarity index 95% rename from src/Bonsai.ML.Torch/Tensors/Zeros.cs rename to src/Bonsai.ML.Torch/Zeros.cs index 256a43ed..5af526d6 100644 --- a/src/Bonsai.ML.Torch/Tensors/Zeros.cs +++ b/src/Bonsai.ML.Torch/Zeros.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Tensors +namespace Bonsai.ML.Torch { /// /// Creates a tensor filled with zeros. From 3692cb39a6f2d6574189262f0b11a4f326a90d2d Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 27 Aug 2024 13:35:03 +0100 Subject: [PATCH 27/75] Updated to reflect new namespace --- src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs | 1 - src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs | 1 - 2 files changed, 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs b/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs index 91faf20b..66a5396b 100644 --- a/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs +++ b/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Linq; -using Bonsai.ML.Torch.Tensors; namespace Bonsai.ML.Torch.Helpers { diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs index e9d66038..f82a33f9 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using Bonsai.ML.Torch.Tensors; using TorchSharp; using static TorchSharp.torch; using static TorchSharp.torch.nn; From b51fb535aa41add455f766ef884f872709e5e7f8 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 27 Aug 2024 13:42:25 +0100 Subject: [PATCH 28/75] Removed unfinished classes --- .../NeuralNets/Configuration/ModelConfiguration.cs | 5 ----- .../NeuralNets/Configuration/ModuleConfiguration.cs | 5 ----- src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs | 5 ----- 3 files changed, 15 deletions(-) delete mode 100644 src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs delete mode 100644 src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs delete mode 100644 src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs diff --git a/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs b/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs deleted file mode 100644 index 7628f72d..00000000 --- a/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModelConfiguration.cs +++ /dev/null @@ -1,5 +0,0 @@ -namespace Bonsai.ML.Torch.NeuralNets.Configuration; - -public class ModelConfiguration -{ -} diff --git a/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs b/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs deleted file mode 100644 index dfe56272..00000000 --- a/src/Bonsai.ML.Torch/NeuralNets/Configuration/ModuleConfiguration.cs +++ /dev/null @@ -1,5 +0,0 @@ -namespace Bonsai.ML.Torch.NeuralNets.Configuration; - -public class ModuleConfiguration -{ -} diff --git a/src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs b/src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs deleted file mode 100644 index 035b3ca1..00000000 --- a/src/Bonsai.ML.Torch/NeuralNets/ModelManager.cs +++ /dev/null @@ -1,5 +0,0 @@ -namespace Bonsai.ML.Torch.NeuralNets; - -public class ModelManager -{ -} From 3c1838caec615a8f4888767ae8c16b11e975d58f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 4 Sep 2024 18:14:31 +0100 Subject: [PATCH 29/75] Updated to use common Bonsai.ML.Data project --- src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj | 1 + src/Bonsai.ML.Torch/Helpers/DataHelper.cs | 190 --------------------- 2 files changed, 1 insertion(+), 190 deletions(-) delete mode 100644 src/Bonsai.ML.Torch/Helpers/DataHelper.cs diff --git a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj index 9ed3c5d8..bb401adc 100644 --- a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj +++ b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj @@ -13,5 +13,6 @@ + \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Helpers/DataHelper.cs b/src/Bonsai.ML.Torch/Helpers/DataHelper.cs deleted file mode 100644 index ffed053a..00000000 --- a/src/Bonsai.ML.Torch/Helpers/DataHelper.cs +++ /dev/null @@ -1,190 +0,0 @@ -using System; -using System.Text; -using System.Collections.Generic; -using System.Linq; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; - -namespace Bonsai.ML.Torch.Helpers -{ - /// - /// Provides helper methods for parsing tensor data types. - /// - public static class DataHelper - { - - /// - /// Serializes the input data into a string representation. - /// - public static string SerializeData(object data) - { - if (data is Array array) - { - return SerializeArray(array); - } - else - { - return JsonConvert.SerializeObject(data); - } - } - - /// - /// Serializes the input array into a string representation. - /// - public static string SerializeArray(Array array) - { - StringBuilder sb = new StringBuilder(); - SerializeArrayRecursive(array, sb, [0]); - return sb.ToString(); - } - - private static void SerializeArrayRecursive(Array array, StringBuilder sb, int[] indices) - { - if (indices.Length < array.Rank) - { - sb.Append("["); - int length = array.GetLength(indices.Length); - for (int i = 0; i < length; i++) - { - int[] newIndices = new int[indices.Length + 1]; - indices.CopyTo(newIndices, 0); - newIndices[indices.Length] = i; - SerializeArrayRecursive(array, sb, newIndices); - if (i < length - 1) - { - sb.Append(", "); - } - } - sb.Append("]"); - } - else - { - object value = array.GetValue(indices); - sb.Append(value.ToString()); - } - } - - private static bool IsValidJson(string input) - { - int squareBrackets = 0; - foreach (char c in input) - { - if (c == '[') squareBrackets++; - else if (c == ']') squareBrackets--; - } - return squareBrackets == 0; - } - - /// - /// Parses the input string into an object of the specified type. - /// - public static object ParseString(string input, Type dtype) - { - if (!IsValidJson(input)) - { - throw new ArgumentException("JSON is invalid."); - } - var obj = JsonConvert.DeserializeObject(input); - int depth = ParseDepth(obj); - if (depth == 0) - { - return Convert.ChangeType(input, dtype); - } - int[] dimensions = ParseDimensions(obj, depth); - var resultArray = Array.CreateInstance(dtype, dimensions); - PopulateArray(obj, resultArray, [0], dtype); - return resultArray; - } - - private static int ParseDepth(JToken token, int currentDepth = 0) - { - if (token is JArray arr && arr.Count > 0) - { - return ParseDepth(arr[0], currentDepth + 1); - } - return currentDepth; - } - - private static int[] ParseDimensions(JToken token, int depth, int currentLevel = 0) - { - if (depth == 0 || !(token is JArray)) - { - return [0]; - } - - List dimensions = new List(); - JToken current = token; - - while (current != null && current is JArray) - { - JArray currentArray = current as JArray; - dimensions.Add(currentArray.Count); - if (currentArray.Count > 0) - { - if (currentArray.Any(item => !(item is JArray)) && currentArray.Any(item => item is JArray) || currentArray.All(item => item is JArray) && currentArray.Any(item => ((JArray)item).Count != ((JArray)currentArray.First()).Count)) - { - throw new Exception("Error parsing input. Dimensions are inconsistent."); - } - - if (!(currentArray.First() is JArray)) - { - if (!currentArray.All(item => double.TryParse(item.ToString(), out _)) && !currentArray.All(item => bool.TryParse(item.ToString(), out _))) - { - throw new Exception("Error parsing types. All values must be of the same type and only numeric or boolean types are supported."); - } - } - } - - current = currentArray.Count > 0 ? currentArray[0] : null; - } - - if (currentLevel > 0 && token is JArray arr && arr.All(x => x is JArray)) - { - var subArrayDimensions = new HashSet(); - foreach (JArray subArr in arr) - { - int[] subDims = ParseDimensions(subArr, depth - currentLevel, currentLevel + 1); - subArrayDimensions.Add(string.Join(",", subDims)); - } - - if (subArrayDimensions.Count > 1) - { - throw new ArgumentException("Inconsistent array dimensions."); - } - } - - return dimensions.ToArray(); - } - - private static void PopulateArray(JToken token, Array array, int[] indices, Type dtype) - { - if (token is JArray arr) - { - for (int i = 0; i < arr.Count; i++) - { - int[] newIndices = new int[indices.Length + 1]; - Array.Copy(indices, newIndices, indices.Length); - newIndices[newIndices.Length - 1] = i; - PopulateArray(arr[i], array, newIndices, dtype); - } - } - else - { - var values = ConvertType(token, dtype); - array.SetValue(values, indices); - } - } - - private static object ConvertType(object value, Type targetType) - { - try - { - return Convert.ChangeType(value, targetType); - } - catch (Exception ex) - { - throw new Exception("Error parsing type: ", ex); - } - } - } -} \ No newline at end of file From c548469d7bd287bf299cff8cf69320894dce7771 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 4 Sep 2024 18:15:11 +0100 Subject: [PATCH 30/75] Added additional overloads to process method --- src/Bonsai.ML.Torch/Concat.cs | 60 +++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.Torch/Concat.cs b/src/Bonsai.ML.Torch/Concat.cs index b07b211d..34adf731 100644 --- a/src/Bonsai.ML.Torch/Concat.cs +++ b/src/Bonsai.ML.Torch/Concat.cs @@ -2,6 +2,7 @@ using System.ComponentModel; using System.Linq; using System.Reactive.Linq; +using System.Collections.Generic; using static TorchSharp.torch; namespace Bonsai.ML.Torch @@ -36,9 +37,62 @@ public IObservable Process(IObservable> source) { return source.Select(value => { - var tensor1 = value.Item1; - var tensor2 = value.Item2; - return cat([tensor1, tensor2], Dimension); + return cat([value.Item1, value.Item2], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat([value.Item1, value.Item2, value.Item3], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat([value.Item1, value.Item2, value.Item3, value.Item4], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat([value.Item1, value.Item2, value.Item3, value.Item4, value.Item5], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat([value.Item1, value.Item2, value.Item3, value.Item4, value.Item5, value.Item6], Dimension); + }); + } + + /// + /// Concatenates the input tensors along the specified dimension. + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cat(value.ToList(), Dimension); }); } } From a594f0457844a86f96d3b552aa482588b78dcec4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 4 Sep 2024 18:15:54 +0100 Subject: [PATCH 31/75] Updated to use common data tools --- src/Bonsai.ML.Torch/CreateTensor.cs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index 0100f920..1fc8ef5e 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -8,6 +8,8 @@ using System.Xml.Serialization; using Bonsai.Expressions; using static TorchSharp.torch; +using Bonsai.ML.Data; +using Bonsai.ML.Torch.Helpers; namespace Bonsai.ML.Torch { @@ -188,7 +190,7 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp /// public override Expression Build(IEnumerable arguments) { - var returnType = Helpers.TensorDataTypeHelper.GetTypeFromTensorDataType(scalarType); + var returnType = TensorDataTypeLookup.GetTypeFromTensorDataType(scalarType); var argTypes = arguments.Select(arg => arg.Type).ToArray(); var methodInfoArgumentTypes = new Type[] { @@ -207,7 +209,7 @@ public override Expression Build(IEnumerable arguments) .GetGenericArguments()[0] ) : methods.FirstOrDefault(m => !m.IsGenericMethod); - var tensorValues = Helpers.DataHelper.ParseString(values, returnType); + var tensorValues = ArrayHelper.ParseString(values, returnType); var buildTensor = tensorValues is Array arrayValues ? BuildTensorFromArray(arrayValues, returnType) : BuildTensorFromScalarValue(tensorValues, returnType); var methodArguments = arguments.Count() == 0 ? [ buildTensor ] : arguments.Concat([ buildTensor ]); @@ -221,8 +223,8 @@ public override Expression Build(IEnumerable arguments) } finally { - values = Helpers.DataHelper.SerializeData(tensorValues).Replace("False", "false").Replace("True", "true"); - scalarType = Helpers.TensorDataTypeHelper.GetTensorDataTypeFromType(returnType); + values = ArrayHelper.SerializeToJson(tensorValues).ToLower(); + scalarType = TensorDataTypeLookup.GetTensorDataTypeFromType(returnType); } } From 3a15591899e4019f3dbcc0e6d342be2041a20786 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 4 Sep 2024 18:16:11 +0100 Subject: [PATCH 32/75] Updated tensor data type and helper --- src/Bonsai.ML.Torch/TensorDataType.cs | 6 ------ .../TensorDataTypeHelper.cs => TensorDataTypeLookup.cs} | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) rename src/Bonsai.ML.Torch/{Helpers/TensorDataTypeHelper.cs => TensorDataTypeLookup.cs} (98%) diff --git a/src/Bonsai.ML.Torch/TensorDataType.cs b/src/Bonsai.ML.Torch/TensorDataType.cs index fe8861f3..f76a04c1 100644 --- a/src/Bonsai.ML.Torch/TensorDataType.cs +++ b/src/Bonsai.ML.Torch/TensorDataType.cs @@ -1,9 +1,3 @@ -using System; -using System.Text; -using System.Collections.Generic; -using System.Linq; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; using static TorchSharp.torch; namespace Bonsai.ML.Torch diff --git a/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs b/src/Bonsai.ML.Torch/TensorDataTypeLookup.cs similarity index 98% rename from src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs rename to src/Bonsai.ML.Torch/TensorDataTypeLookup.cs index 66a5396b..6e2b1be0 100644 --- a/src/Bonsai.ML.Torch/Helpers/TensorDataTypeHelper.cs +++ b/src/Bonsai.ML.Torch/TensorDataTypeLookup.cs @@ -7,7 +7,7 @@ namespace Bonsai.ML.Torch.Helpers /// /// Provides helper methods for working with tensor data types. /// - public class TensorDataTypeHelper + public class TensorDataTypeLookup { private static readonly Dictionary _lookup = new Dictionary { From 279d6b713f2143b8165866f8a89f0b7298d11e6c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 4 Sep 2024 18:16:46 +0100 Subject: [PATCH 33/75] Updated formatting --- src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs b/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs index 4e90fa35..b0938f42 100644 --- a/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs +++ b/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs @@ -13,7 +13,8 @@ namespace Bonsai.ML.Torch.Helpers /// public static class OpenCVHelper { - private static Dictionary bitDepthLookup = new Dictionary { + private static Dictionary bitDepthLookup = new Dictionary + { { ScalarType.Byte, (IplDepth.U8, Depth.U8) }, { ScalarType.Int16, (IplDepth.S16, Depth.S16) }, { ScalarType.Int32, (IplDepth.S32, Depth.S32) }, @@ -55,12 +56,16 @@ public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dim fixed (long* dimensionsPtr = dimensions) { IntPtr tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); - if (tensorHandle == IntPtr.Zero) { + if (tensorHandle == IntPtr.Zero) + { GC.Collect(); GC.WaitForPendingFinalizers(); tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); } - if (tensorHandle == IntPtr.Zero) { CheckForErrors(); } + if (tensorHandle == IntPtr.Zero) + { + CheckForErrors(); + } var output = Tensor.UnsafeCreateTensor(tensorHandle); return output; } From 391086756a7bc20a1da358177150daee627ec9d2 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:01:06 +0100 Subject: [PATCH 34/75] Removed bonsai core dependency in favor of bonsai.ml dependency --- src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj index bb401adc..2a0c1d53 100644 --- a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj +++ b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj @@ -7,12 +7,12 @@ true - + \ No newline at end of file From f002c6f203fb0871c2ae5ad5b1f7934955b7ecbe Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:03:36 +0100 Subject: [PATCH 35/75] Fixed bugs with create tensor method and updated to use string formatter --- src/Bonsai.ML.Torch/CreateTensor.cs | 41 +++++++++++++++-------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index 1fc8ef5e..66509bbc 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -9,6 +9,7 @@ using Bonsai.Expressions; using static TorchSharp.torch; using Bonsai.ML.Data; +using Bonsai.ML.Python; using Bonsai.ML.Torch.Helpers; namespace Bonsai.ML.Torch @@ -45,7 +46,7 @@ public string Values get => values; set { - values = value.Replace("False", "false").Replace("True", "true"); + values = value.ToLower(); } } @@ -55,18 +56,20 @@ public string Values /// The device on which to create the tensor. /// [XmlIgnore] - public Device Device { get => device; set => device = value; } + public Device Device + { + get => device; + set => device = value; + } private Device device = null; private Expression BuildTensorFromArray(Array arrayValues, Type returnType) { var rank = arrayValues.Rank; - var lengths = new int[rank]; - for (int i = 0; i < rank; i++) - { - lengths[i] = arrayValues.GetLength(i); - } + var lengths = Enumerable.Range(0, rank) + .Select(arrayValues.GetLength) + .ToArray(); var arrayCreationExpression = Expression.NewArrayBounds(returnType, lengths.Select(len => Expression.Constant(len)).ToArray()); var arrayVariable = Expression.Variable(arrayCreationExpression.Type, "array"); @@ -89,7 +92,7 @@ private Expression BuildTensorFromArray(Array arrayValues, Type returnType) } var tensorDataInitializationBlock = Expression.Block( - arrayVariable, + [arrayVariable], assignArray, Expression.Block(assignments), arrayVariable @@ -108,7 +111,7 @@ private Expression BuildTensorFromArray(Array arrayValues, Type returnType) var tensorAssignment = Expression.Call( tensorCreationMethodInfo, tensorDataInitializationBlock, - Expression.Constant(scalarType, typeof(ScalarType?)), + Expression.Constant((ScalarType)scalarType, typeof(ScalarType?)), Expression.Constant(device, typeof(Device)), Expression.Constant(false, typeof(bool)), Expression.Constant(null, typeof(string).MakeArrayType()) @@ -118,7 +121,7 @@ private Expression BuildTensorFromArray(Array arrayValues, Type returnType) var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); var buildTensor = Expression.Block( - tensorVariable, + [tensorVariable], assignTensor, tensorVariable ); @@ -132,7 +135,7 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp var assignValue = Expression.Assign(valueVariable, Expression.Constant(scalarValue, returnType)); var tensorDataInitializationBlock = Expression.Block( - valueVariable, + [valueVariable], assignValue, valueVariable ); @@ -145,10 +148,10 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp ] ); - var tensorCreationMethodArguments = new Expression[] { - Expression.Constant(device, typeof(Device) ), - Expression.Constant(false, typeof(bool) ) - }; + Expression[] tensorCreationMethodArguments = [ + Expression.Constant(device, typeof(Device)), + Expression.Constant(false, typeof(bool)) + ]; if (tensorCreationMethodInfo == null) { @@ -193,9 +196,7 @@ public override Expression Build(IEnumerable arguments) var returnType = TensorDataTypeLookup.GetTypeFromTensorDataType(scalarType); var argTypes = arguments.Select(arg => arg.Type).ToArray(); - var methodInfoArgumentTypes = new Type[] { - typeof(Tensor) - }; + Type[] methodInfoArgumentTypes = [typeof(Tensor)]; var methods = typeof(CreateTensor).GetMethods(BindingFlags.Public | BindingFlags.Instance) .Where(m => m.Name == "Process") @@ -211,7 +212,7 @@ public override Expression Build(IEnumerable arguments) var tensorValues = ArrayHelper.ParseString(values, returnType); var buildTensor = tensorValues is Array arrayValues ? BuildTensorFromArray(arrayValues, returnType) : BuildTensorFromScalarValue(tensorValues, returnType); - var methodArguments = arguments.Count() == 0 ? [ buildTensor ] : arguments.Concat([ buildTensor ]); + var methodArguments = arguments.Count() == 0 ? [buildTensor] : arguments.Concat([buildTensor]); try { @@ -223,7 +224,7 @@ public override Expression Build(IEnumerable arguments) } finally { - values = ArrayHelper.SerializeToJson(tensorValues).ToLower(); + values = StringFormatter.FormatToPython(tensorValues).ToLower(); scalarType = TensorDataTypeLookup.GetTensorDataTypeFromType(returnType); } } From 9ffb9f0e3a92dddb7d3adacaacabc36837d6a158 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:04:58 +0100 Subject: [PATCH 36/75] Added empty tensor creator --- src/Bonsai.ML.Torch/Empty.cs | 38 ++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Empty.cs diff --git a/src/Bonsai.ML.Torch/Empty.cs b/src/Bonsai.ML.Torch/Empty.cs new file mode 100644 index 00000000..1c4f6af5 --- /dev/null +++ b/src/Bonsai.ML.Torch/Empty.cs @@ -0,0 +1,38 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Returns an empty tensor with the given data type and size. + /// + [Combinator] + [Description("Converts the input tensor into an OpenCV mat.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Empty + { + + /// + /// The size of the tensor. + /// + public long[] Size { get; set; } = [0]; + + /// + /// The data type of the tensor elements. + /// + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Returns an empty tensor with the given data type and size. + /// + public IObservable Process() + { + return Observable.Defer(() => + { + return Observable.Return(empty(Size, Type)); + }); + } + } +} From 200057a5afdedfec850152ed59e5f940361d04bb Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:05:25 +0100 Subject: [PATCH 37/75] Moved index helper to main library --- src/Bonsai.ML.Torch/Index.cs | 2 +- src/Bonsai.ML.Torch/{Helpers => }/IndexHelper.cs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) rename src/Bonsai.ML.Torch/{Helpers => }/IndexHelper.cs (94%) diff --git a/src/Bonsai.ML.Torch/Index.cs b/src/Bonsai.ML.Torch/Index.cs index 5b7b9192..818bb401 100644 --- a/src/Bonsai.ML.Torch/Index.cs +++ b/src/Bonsai.ML.Torch/Index.cs @@ -26,7 +26,7 @@ public class Index /// public IObservable Process(IObservable source) { - var index = Helpers.IndexHelper.ParseString(Indexes); + var index = IndexHelper.Parse(Indexes); return source.Select(tensor => { return tensor.index(index); }); diff --git a/src/Bonsai.ML.Torch/Helpers/IndexHelper.cs b/src/Bonsai.ML.Torch/IndexHelper.cs similarity index 94% rename from src/Bonsai.ML.Torch/Helpers/IndexHelper.cs rename to src/Bonsai.ML.Torch/IndexHelper.cs index 541ae443..2af466a0 100644 --- a/src/Bonsai.ML.Torch/Helpers/IndexHelper.cs +++ b/src/Bonsai.ML.Torch/IndexHelper.cs @@ -1,7 +1,7 @@ using System; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.Helpers +namespace Bonsai.ML.Torch { /// /// Provides helper methods to parse tensor indexes. @@ -13,7 +13,7 @@ public static class IndexHelper /// Parses the input string into an array of tensor indexes. /// /// - public static TensorIndex[] ParseString(string input) + public static TensorIndex[] Parse(string input) { if (string.IsNullOrEmpty(input)) { @@ -83,7 +83,7 @@ public static TensorIndex[] ParseString(string input) /// /// /// - public static string SerializeIndexes(TensorIndex[] indexes) + public static string Serialize(TensorIndex[] indexes) { return string.Join(", ", indexes); } From 0438eb286bf58caf9ae343342e92671db2ed479e Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:05:38 +0100 Subject: [PATCH 38/75] Moved opencv helper to main library --- src/Bonsai.ML.Torch/OpenCVHelper.cs | 174 ++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 src/Bonsai.ML.Torch/OpenCVHelper.cs diff --git a/src/Bonsai.ML.Torch/OpenCVHelper.cs b/src/Bonsai.ML.Torch/OpenCVHelper.cs new file mode 100644 index 00000000..1ca049c9 --- /dev/null +++ b/src/Bonsai.ML.Torch/OpenCVHelper.cs @@ -0,0 +1,174 @@ +using System; +using System.Runtime.InteropServices; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using OpenCV.Net; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Helper class to convert between OpenCV mats, images and Torch tensors. + /// + public static class OpenCVHelper + { + private static Dictionary bitDepthLookup = new Dictionary + { + { ScalarType.Byte, (IplDepth.U8, Depth.U8) }, + { ScalarType.Int16, (IplDepth.S16, Depth.S16) }, + { ScalarType.Int32, (IplDepth.S32, Depth.S32) }, + { ScalarType.Float32, (IplDepth.F32, Depth.F32) }, + { ScalarType.Float64, (IplDepth.F64, Depth.F64) }, + { ScalarType.Int8, (IplDepth.S8, Depth.S8) } + }; + + private static ConcurrentDictionary deleters = new ConcurrentDictionary(); + + internal delegate void GCHandleDeleter(IntPtr memory); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_data(IntPtr handle); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_new(IntPtr rawArray, GCHandleDeleter deleter, IntPtr dimensions, int numDimensions, sbyte type, sbyte dtype, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad); + + /// + /// Creates a tensor from a pointer to the data and the dimensions of the tensor. + /// + /// + /// + /// + /// + public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dimensions, ScalarType dtype = ScalarType.Byte) + { + var dataHandle = GCHandle.Alloc(tensorDataPtr, GCHandleType.Pinned); + var gchp = GCHandle.ToIntPtr(dataHandle); + GCHandleDeleter deleter = null; + + deleter = new GCHandleDeleter((IntPtr ptrHandler) => + { + GCHandle.FromIntPtr(gchp).Free(); + deleters.TryRemove(deleter, out deleter); + }); + deleters.TryAdd(deleter, deleter); + + fixed (long* dimensionsPtr = dimensions) + { + IntPtr tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); + if (tensorHandle == IntPtr.Zero) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); + } + if (tensorHandle == IntPtr.Zero) + { + CheckForErrors(); + } + var output = Tensor.UnsafeCreateTensor(tensorHandle); + return output; + } + } + + /// + /// Converts an OpenCV image to a Torch tensor. + /// + /// + /// + public static Tensor ToTensor(IplImage image) + { + if (image == null) + { + return empty([ 0, 0, 0 ]); + } + + int width = image.Width; + int height = image.Height; + int channels = image.Channels; + + var iplDepth = image.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; + + IntPtr tensorDataPtr = image.ImageData; + long[] dimensions = [ height, width, channels ]; + if (tensorDataPtr == IntPtr.Zero) + { + return empty(dimensions); + } + return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + } + + /// + /// Converts an OpenCV mat to a Torch tensor. + /// + /// + /// + public static Tensor ToTensor(Mat mat) + { + if (mat == null) + { + return empty([0, 0, 0 ]); + } + + int width = mat.Size.Width; + int height = mat.Size.Height; + int channels = mat.Channels; + + var depth = mat.Depth; + var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.Depth == depth).Key; + + IntPtr tensorDataPtr = mat.Data; + long[] dimensions = [ height, width, channels ]; + if (tensorDataPtr == IntPtr.Zero) + { + return empty(dimensions); + } + return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); + } + + /// + /// Converts a Torch tensor to an OpenCV image. + /// + /// + /// + public unsafe static IplImage ToImage(Tensor tensor) + { + var height = (int)tensor.shape[0]; + var width = (int)tensor.shape[1]; + var channels = (int)tensor.shape[2]; + + var tensorType = tensor.dtype; + var iplDepth = bitDepthLookup[tensorType].IplDepth; + + var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + + var res = THSTensor_data(new_tensor.Handle); + var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels, res); + + return image; + } + + /// + /// Converts a Torch tensor to an OpenCV mat. + /// + /// + /// + public unsafe static Mat ToMat(Tensor tensor) + { + var height = (int)tensor.shape[0]; + var width = (int)tensor.shape[1]; + var channels = (int)tensor.shape[2]; + + var tensorType = tensor.dtype; + var depth = bitDepthLookup[tensorType].Depth; + + var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + + var res = THSTensor_data(new_tensor.Handle); + var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels, res); + + return mat; + } + } +} \ No newline at end of file From 37acfc4023cb0d9a191cfb47494a473df8d8b154 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:06:20 +0100 Subject: [PATCH 39/75] Removed opencv helper from helpers subfolder --- src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs | 174 -------------------- 1 file changed, 174 deletions(-) delete mode 100644 src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs diff --git a/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs b/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs deleted file mode 100644 index b0938f42..00000000 --- a/src/Bonsai.ML.Torch/Helpers/OpenCVHelper.cs +++ /dev/null @@ -1,174 +0,0 @@ -using System; -using System.Runtime.InteropServices; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using OpenCV.Net; -using static TorchSharp.torch; - -namespace Bonsai.ML.Torch.Helpers -{ - /// - /// Helper class to convert between OpenCV mats, images and Torch tensors. - /// - public static class OpenCVHelper - { - private static Dictionary bitDepthLookup = new Dictionary - { - { ScalarType.Byte, (IplDepth.U8, Depth.U8) }, - { ScalarType.Int16, (IplDepth.S16, Depth.S16) }, - { ScalarType.Int32, (IplDepth.S32, Depth.S32) }, - { ScalarType.Float32, (IplDepth.F32, Depth.F32) }, - { ScalarType.Float64, (IplDepth.F64, Depth.F64) }, - { ScalarType.Int8, (IplDepth.S8, Depth.S8) } - }; - - private static ConcurrentDictionary deleters = new ConcurrentDictionary(); - - internal delegate void GCHandleDeleter(IntPtr memory); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_data(IntPtr handle); - - [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_new(IntPtr rawArray, GCHandleDeleter deleter, IntPtr dimensions, int numDimensions, sbyte type, sbyte dtype, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad); - - /// - /// Creates a tensor from a pointer to the data and the dimensions of the tensor. - /// - /// - /// - /// - /// - public static unsafe Tensor CreateTensorFromPtr(IntPtr tensorDataPtr, long[] dimensions, ScalarType dtype = ScalarType.Byte) - { - var dataHandle = GCHandle.Alloc(tensorDataPtr, GCHandleType.Pinned); - var gchp = GCHandle.ToIntPtr(dataHandle); - GCHandleDeleter deleter = null; - - deleter = new GCHandleDeleter((IntPtr ptrHandler) => - { - GCHandle.FromIntPtr(gchp).Free(); - deleters.TryRemove(deleter, out deleter); - }); - deleters.TryAdd(deleter, deleter); - - fixed (long* dimensionsPtr = dimensions) - { - IntPtr tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); - if (tensorHandle == IntPtr.Zero) - { - GC.Collect(); - GC.WaitForPendingFinalizers(); - tensorHandle = THSTensor_new(tensorDataPtr, deleter, (IntPtr)dimensionsPtr, dimensions.Length, (sbyte)dtype, (sbyte)dtype, 0, 0, false); - } - if (tensorHandle == IntPtr.Zero) - { - CheckForErrors(); - } - var output = Tensor.UnsafeCreateTensor(tensorHandle); - return output; - } - } - - /// - /// Converts an OpenCV image to a Torch tensor. - /// - /// - /// - public static Tensor ToTensor(IplImage image) - { - if (image == null) - { - return empty([ 0, 0, 0 ]); - } - - int width = image.Width; - int height = image.Height; - int channels = image.Channels; - - var iplDepth = image.Depth; - var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; - - IntPtr tensorDataPtr = image.ImageData; - long[] dimensions = [ height, width, channels ]; - if (tensorDataPtr == IntPtr.Zero) - { - return empty(dimensions); - } - return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); - } - - /// - /// Converts an OpenCV mat to a Torch tensor. - /// - /// - /// - public static Tensor ToTensor(Mat mat) - { - if (mat == null) - { - return empty([0, 0, 0 ]); - } - - int width = mat.Size.Width; - int height = mat.Size.Height; - int channels = mat.Channels; - - var depth = mat.Depth; - var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.Depth == depth).Key; - - IntPtr tensorDataPtr = mat.Data; - long[] dimensions = [ height, width, channels ]; - if (tensorDataPtr == IntPtr.Zero) - { - return empty(dimensions); - } - return CreateTensorFromPtr(tensorDataPtr, dimensions, tensorType); - } - - /// - /// Converts a Torch tensor to an OpenCV image. - /// - /// - /// - public unsafe static IplImage ToImage(Tensor tensor) - { - var height = (int)tensor.shape[0]; - var width = (int)tensor.shape[1]; - var channels = (int)tensor.shape[2]; - - var tensorType = tensor.dtype; - var iplDepth = bitDepthLookup[tensorType].IplDepth; - - var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); - - var res = THSTensor_data(new_tensor.Handle); - var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels, res); - - return image; - } - - /// - /// Converts a Torch tensor to an OpenCV mat. - /// - /// - /// - public unsafe static Mat ToMat(Tensor tensor) - { - var height = (int)tensor.shape[0]; - var width = (int)tensor.shape[1]; - var channels = (int)tensor.shape[2]; - - var tensorType = tensor.dtype; - var depth = bitDepthLookup[tensorType].Depth; - - var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); - - var res = THSTensor_data(new_tensor.Handle); - var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels, res); - - return mat; - } - } -} \ No newline at end of file From d52af6c62f350ad846e41b312432341320e8719e Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:06:39 +0100 Subject: [PATCH 40/75] Updated with new index helper --- src/Bonsai.ML.Torch/Set.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/Set.cs b/src/Bonsai.ML.Torch/Set.cs index 14bf3dad..6b0fd86b 100644 --- a/src/Bonsai.ML.Torch/Set.cs +++ b/src/Bonsai.ML.Torch/Set.cs @@ -21,8 +21,8 @@ public class Set /// public string Index { - get => Helpers.IndexHelper.SerializeIndexes(indexes); - set => indexes = Helpers.IndexHelper.ParseString(value); + get => IndexHelper.Serialize(indexes); + set => indexes = IndexHelper.Parse(value); } private TensorIndex[] indexes; From 15a586d946ea01618d5c5f1a507378c324c481d0 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:07:31 +0100 Subject: [PATCH 41/75] Updated with opencv helper --- src/Bonsai.ML.Torch/ToImage.cs | 2 +- src/Bonsai.ML.Torch/ToMat.cs | 2 +- src/Bonsai.ML.Torch/ToTensor.cs | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Bonsai.ML.Torch/ToImage.cs b/src/Bonsai.ML.Torch/ToImage.cs index 894a9602..0b9d8ccd 100644 --- a/src/Bonsai.ML.Torch/ToImage.cs +++ b/src/Bonsai.ML.Torch/ToImage.cs @@ -22,7 +22,7 @@ public class ToImage /// public IObservable Process(IObservable source) { - return source.Select(Helpers.OpenCVHelper.ToImage); + return source.Select(OpenCVHelper.ToImage); } } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ToMat.cs b/src/Bonsai.ML.Torch/ToMat.cs index fa50020c..1b1746ed 100644 --- a/src/Bonsai.ML.Torch/ToMat.cs +++ b/src/Bonsai.ML.Torch/ToMat.cs @@ -22,7 +22,7 @@ public class ToMat /// public IObservable Process(IObservable source) { - return source.Select(Helpers.OpenCVHelper.ToMat); + return source.Select(OpenCVHelper.ToMat); } } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ToTensor.cs b/src/Bonsai.ML.Torch/ToTensor.cs index 5bb460de..061a13cb 100644 --- a/src/Bonsai.ML.Torch/ToTensor.cs +++ b/src/Bonsai.ML.Torch/ToTensor.cs @@ -118,7 +118,7 @@ public IObservable Process(IObservable source) /// public IObservable Process(IObservable source) { - return source.Select(Helpers.OpenCVHelper.ToTensor); + return source.Select(OpenCVHelper.ToTensor); } /// @@ -128,7 +128,7 @@ public IObservable Process(IObservable source) /// public IObservable Process(IObservable source) { - return source.Select(Helpers.OpenCVHelper.ToTensor); + return source.Select(OpenCVHelper.ToTensor); } } } \ No newline at end of file From 30d5c67236d8341835fa39676f2e1df8d9dbb434 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 26 Sep 2024 15:07:49 +0100 Subject: [PATCH 42/75] Added process overload for generating on input --- src/Bonsai.ML.Torch/Zeros.cs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/Bonsai.ML.Torch/Zeros.cs b/src/Bonsai.ML.Torch/Zeros.cs index 5af526d6..69673d4a 100644 --- a/src/Bonsai.ML.Torch/Zeros.cs +++ b/src/Bonsai.ML.Torch/Zeros.cs @@ -26,5 +26,17 @@ public IObservable Process() { return Observable.Defer(() => Observable.Return(ones(Size))); } + + /// + /// Generates an observable sequence of tensors filled with zeros for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return ones(Size); + }); + } } } \ No newline at end of file From 3e6797b43f70d464fe9ddc584d0ea79693978ab4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 27 Sep 2024 12:39:06 +0100 Subject: [PATCH 43/75] Added xml ignore tag on device --- src/Bonsai.ML.Torch/ToDevice.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Bonsai.ML.Torch/ToDevice.cs b/src/Bonsai.ML.Torch/ToDevice.cs index cb73f733..531ff585 100644 --- a/src/Bonsai.ML.Torch/ToDevice.cs +++ b/src/Bonsai.ML.Torch/ToDevice.cs @@ -2,6 +2,7 @@ using System.ComponentModel; using System.Linq; using System.Reactive.Linq; +using System.Xml.Serialization; using static TorchSharp.torch; namespace Bonsai.ML.Torch @@ -17,6 +18,7 @@ public class ToDevice /// /// The device to which the input tensor should be moved. /// + [XmlIgnore] public Device Device { get; set; } /// From 9bfe4e3f9221ae37aef4b3a724ae9c78b6562732 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 16 Oct 2024 14:47:40 +0100 Subject: [PATCH 44/75] Updated to use shared module interface --- src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs index fb7722f2..28a3d57b 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs @@ -3,7 +3,6 @@ using System.Reactive.Linq; using static TorchSharp.torch; using System.Xml.Serialization; -using static TorchSharp.torch.nn; using Bonsai.Expressions; namespace Bonsai.ML.Torch.NeuralNets @@ -18,9 +17,9 @@ public class LoadPretrainedModel private int numClasses = 10; - public IObservable Process() + public IObservable> Process() { - Module model = null; + nn.Module model = null; var modelName = ModelName.ToString().ToLower(); var device = Device; From 0b22a0713acdbfa2fca43cd0d8cf99d765f751e0 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 30 Oct 2024 13:12:54 +0000 Subject: [PATCH 45/75] Modified to use Module interface --- src/Bonsai.ML.Torch/NeuralNets/Forward.cs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Forward.cs diff --git a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs new file mode 100644 index 00000000..03b2e10e --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs @@ -0,0 +1,23 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using System.Xml.Serialization; +using TorchSharp.Modules; + +namespace Bonsai.ML.Torch.NeuralNets +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Forward + { + [XmlIgnore] + public nn.Module Model { get; set; } + + public IObservable Process(IObservable source) + { + return source.Select(Model.forward); + } + } +} \ No newline at end of file From 804aaf92bce93874ea0c8e851cf5c320f95f9d9c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 30 Oct 2024 13:13:31 +0000 Subject: [PATCH 46/75] Removed unnecessary null string --- .../NeuralNets/LoadPretrainedModel.cs | 5 +++- .../NeuralNets/LoadScriptModule.cs | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs index 28a3d57b..a87bd744 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs @@ -15,9 +15,12 @@ public class LoadPretrainedModel public Models.PretrainedModels ModelName { get; set; } public Device Device { get; set; } + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string ModelWeightsPath { get; set; } + private int numClasses = 10; - public IObservable> Process() + public IObservable> Process() { nn.Module model = null; var modelName = ModelName.ToString().ToLower(); diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs new file mode 100644 index 00000000..9c73031f --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs @@ -0,0 +1,26 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.NeuralNets +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Source)] + public class LoadScriptModule + { + + [XmlIgnore] + public Device Device { get; set; } = CPU; + + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string ModelPath { get; set; } + + public IObservable> Process() + { + return Observable.Return((nn.IModule)jit.load(ModelPath, Device)); + } + } +} \ No newline at end of file From 77caa53282e91c2a7686fd26702594323904a586 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 1 Nov 2024 13:27:03 +0000 Subject: [PATCH 47/75] Added a common interface --- src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj | 3 ++- src/Bonsai.ML.Torch/NeuralNets/Forward.cs | 2 +- src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs | 9 +++++++++ src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs | 4 ++-- src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs | 4 ++-- 5 files changed, 16 insertions(+), 6 deletions(-) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs diff --git a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj index 2a0c1d53..97bfe18c 100644 --- a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj +++ b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj @@ -8,7 +8,8 @@ - + + diff --git a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs index 03b2e10e..e1a8a283 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs @@ -13,7 +13,7 @@ namespace Bonsai.ML.Torch.NeuralNets public class Forward { [XmlIgnore] - public nn.Module Model { get; set; } + public ITorchModule Model { get; set; } public IObservable Process(IObservable source) { diff --git a/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs new file mode 100644 index 00000000..1bfcdab3 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs @@ -0,0 +1,9 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.NeuralNets +{ + public interface ITorchModule + { + public Tensor forward(Tensor tensor); + } +} diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs index a87bd744..43443c24 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs @@ -20,7 +20,7 @@ public class LoadPretrainedModel private int numClasses = 10; - public IObservable> Process() + public IObservable Process() { nn.Module model = null; var modelName = ModelName.ToString().ToLower(); @@ -42,7 +42,7 @@ public class LoadPretrainedModel } return Observable.Defer(() => { - return Observable.Return(model); + return Observable.Return((ITorchModule)model); }); } } diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs index 9c73031f..7e5d53b0 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs @@ -18,9 +18,9 @@ public class LoadScriptModule [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] public string ModelPath { get; set; } - public IObservable> Process() + public IObservable Process() { - return Observable.Return((nn.IModule)jit.load(ModelPath, Device)); + return Observable.Return((ITorchModule)jit.load(ModelPath, Device)); } } } \ No newline at end of file From 40136134704218085a573f36d3594ef5c4d59b34 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Nov 2024 12:26:56 +0000 Subject: [PATCH 48/75] Added swap axes function --- src/Bonsai.ML.Torch/Swapaxes.cs | 40 +++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Swapaxes.cs diff --git a/src/Bonsai.ML.Torch/Swapaxes.cs b/src/Bonsai.ML.Torch/Swapaxes.cs new file mode 100644 index 00000000..4777e882 --- /dev/null +++ b/src/Bonsai.ML.Torch/Swapaxes.cs @@ -0,0 +1,40 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + [Combinator] + [Description("Swaps the axes of the input tensor.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Swapaxes + { + + /// + /// The value of axis 1. + /// + public long Axis1 { get; set; } = 0; + + /// + /// The value of axis 2. + /// + public long Axis2 { get; set; } = 1; + + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return swapaxes(tensor, Axis1, Axis2); + }); + } + } +} \ No newline at end of file From ee95aea0f3d0bc7be6248bfa4b5c2b602ee0bc0a Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Nov 2024 12:27:05 +0000 Subject: [PATCH 49/75] Added tile function --- src/Bonsai.ML.Torch/Tile.cs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Tile.cs diff --git a/src/Bonsai.ML.Torch/Tile.cs b/src/Bonsai.ML.Torch/Tile.cs new file mode 100644 index 00000000..1df78122 --- /dev/null +++ b/src/Bonsai.ML.Torch/Tile.cs @@ -0,0 +1,22 @@ +using static TorchSharp.torch; +using System; +using System.ComponentModel; +using System.Reactive.Linq; + +namespace Bonsai.ML.Torch +{ + [Combinator] + [Description("Constructs a tensor by repeating the elements of input. The Dimensions argument specifies the number of repetitions in each dimension.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Tile + { + public long[] Dimensions { get; set; } + + public IObservable Process(IObservable source) + { + return source.Select(tensor => { + return tile(tensor, Dimensions); + }); + } + } +} From 6b3d178048b3fcbb942ab52069e9fc9e85093aba Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Nov 2024 12:27:39 +0000 Subject: [PATCH 50/75] Updated model loading and model forward procedure --- src/Bonsai.ML.Torch/NeuralNets/Forward.cs | 2 ++ .../NeuralNets/ITorchModule.cs | 1 + .../NeuralNets/LoadPretrainedModel.cs | 17 +++++---- .../NeuralNets/LoadScriptModule.cs | 4 ++- .../NeuralNets/TorchModuleAdapter.cs | 36 +++++++++++++++++++ src/Bonsai.ML.Torch/Vision/Normalize.cs | 28 +++++++++++++++ 6 files changed, 81 insertions(+), 7 deletions(-) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs create mode 100644 src/Bonsai.ML.Torch/Vision/Normalize.cs diff --git a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs index e1a8a283..3aae4012 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs @@ -4,6 +4,7 @@ using static TorchSharp.torch; using System.Xml.Serialization; using TorchSharp.Modules; +using TorchSharp; namespace Bonsai.ML.Torch.NeuralNets { @@ -17,6 +18,7 @@ public class Forward public IObservable Process(IObservable source) { + Model.Module.eval(); return source.Select(Model.forward); } } diff --git a/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs index 1bfcdab3..e7ebf994 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs @@ -4,6 +4,7 @@ namespace Bonsai.ML.Torch.NeuralNets { public interface ITorchModule { + public nn.Module Module { get; } public Tensor forward(Tensor tensor); } } diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs index 43443c24..e4dddba1 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs @@ -3,7 +3,6 @@ using System.Reactive.Linq; using static TorchSharp.torch; using System.Xml.Serialization; -using Bonsai.Expressions; namespace Bonsai.ML.Torch.NeuralNets { @@ -13,6 +12,8 @@ namespace Bonsai.ML.Torch.NeuralNets public class LoadPretrainedModel { public Models.PretrainedModels ModelName { get; set; } + + [XmlIgnore] public Device Device { get; set; } [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] @@ -22,27 +23,31 @@ public class LoadPretrainedModel public IObservable Process() { - nn.Module model = null; + nn.Module module = null; var modelName = ModelName.ToString().ToLower(); var device = Device; switch (modelName) { case "alexnet": - model = new Models.AlexNet(modelName, numClasses, device); + module = new Models.AlexNet(modelName, numClasses, device); + if (ModelWeightsPath is not null) module.load(ModelWeightsPath); break; case "mobilenet": - model = new Models.MobileNet(modelName, numClasses, device); + module = new Models.MobileNet(modelName, numClasses, device); + if (ModelWeightsPath is not null) module.load(ModelWeightsPath); break; case "mnist": - model = new Models.MNIST(modelName, device); + module = new Models.MNIST(modelName, device); + if (ModelWeightsPath is not null) module.load(ModelWeightsPath); break; default: throw new ArgumentException($"Model {modelName} not supported."); } + var torchModule = new TorchModuleAdapter(module); return Observable.Defer(() => { - return Observable.Return((ITorchModule)model); + return Observable.Return((ITorchModule)torchModule); }); } } diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs index 7e5d53b0..7e6c73fe 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs @@ -20,7 +20,9 @@ public class LoadScriptModule public IObservable Process() { - return Observable.Return((ITorchModule)jit.load(ModelPath, Device)); + var scriptModule = jit.load(ModelPath, Device); + var torchModule = new TorchModuleAdapter(scriptModule); + return Observable.Return((ITorchModule)torchModule); } } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs b/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs new file mode 100644 index 00000000..a1c44d96 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs @@ -0,0 +1,36 @@ +using System; +using System.Reflection; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.NeuralNets +{ + public class TorchModuleAdapter : ITorchModule + { + private readonly nn.Module _module = null; + + private readonly jit.ScriptModule _scriptModule = null; + + private Func forwardFunc; + + public nn.Module Module { get; } + + public TorchModuleAdapter(nn.Module module) + { + _module = module; + forwardFunc = _module.forward; + Module = _module; + } + + public TorchModuleAdapter(jit.ScriptModule scriptModule) + { + _scriptModule = scriptModule; + forwardFunc = _scriptModule.forward; + Module = _scriptModule; + } + + public Tensor forward(Tensor input) + { + return forwardFunc(input); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Vision/Normalize.cs b/src/Bonsai.ML.Torch/Vision/Normalize.cs new file mode 100644 index 00000000..fee8a3b9 --- /dev/null +++ b/src/Bonsai.ML.Torch/Vision/Normalize.cs @@ -0,0 +1,28 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torchvision; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Vision +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Normalize + { + private ITransform inputTransform; + + public IObservable Process(IObservable source) + { + inputTransform = transforms.Normalize(new double[] { 0.1307 }, new double[] { 0.3081 }); + + return source.Select(tensor => { + return inputTransform.call(tensor); + }); + } + } +} \ No newline at end of file From 1e5f99fc9ee9f802d2e87e5955dc598f2b864af1 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 14 Nov 2024 13:12:55 +0000 Subject: [PATCH 51/75] Added backward function for running online gradient descent with specified loss and optimization --- src/Bonsai.ML.Torch/NeuralNets/Backward.cs | 61 +++++++++++++++++++++ src/Bonsai.ML.Torch/NeuralNets/Loss.cs | 13 +++++ src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs | 13 +++++ 3 files changed, 87 insertions(+) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Backward.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Loss.cs create mode 100644 src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs diff --git a/src/Bonsai.ML.Torch/NeuralNets/Backward.cs b/src/Bonsai.ML.Torch/NeuralNets/Backward.cs new file mode 100644 index 00000000..5fa4902e --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Backward.cs @@ -0,0 +1,61 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; +using static TorchSharp.torch.nn; +using static TorchSharp.torch.optim; + +namespace Bonsai.ML.Torch.NeuralNets +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Backward + { + public Optimizer Optimizer { get; set; } + + [XmlIgnore] + public ITorchModule Model { get; set; } + + public Loss Loss { get; set; } + + public IObservable Process(IObservable> source) + { + optim.Optimizer optimizer = null; + switch (Optimizer) + { + case Optimizer.Adam: + optimizer = Adam(Model.Module.parameters()); + break; + } + + Module loss = null; + switch (Loss) + { + case Loss.NLLLoss: + loss = NLLLoss(); + break; + } + + var scheduler = lr_scheduler.StepLR(optimizer, 1, 0.7); + Model.Module.train(); + + return source.Select((input) => { + var (data, target) = input; + using (_ = NewDisposeScope()) + { + optimizer.zero_grad(); + + var prediction = Model.forward(data); + var output = loss.forward(prediction, target); + + output.backward(); + + optimizer.step(); + return output.MoveToOuterDisposeScope(); + } + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Loss.cs b/src/Bonsai.ML.Torch/NeuralNets/Loss.cs new file mode 100644 index 00000000..ff003019 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Loss.cs @@ -0,0 +1,13 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.optim; + +namespace Bonsai.ML.Torch.NeuralNets +{ + public enum Loss + { + NLLLoss, + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs b/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs new file mode 100644 index 00000000..3c0d4bb7 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs @@ -0,0 +1,13 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.optim; + +namespace Bonsai.ML.Torch.NeuralNets +{ + public enum Optimizer + { + Adam, + } +} \ No newline at end of file From fe6eb7dc67702b5ca71cf66c3c7320c8542f865c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 14 Nov 2024 13:13:05 +0000 Subject: [PATCH 52/75] Added function to save model --- src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs | 28 +++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs diff --git a/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs new file mode 100644 index 00000000..314e9f3f --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using System.Xml.Serialization; +using TorchSharp.Modules; +using TorchSharp; + +namespace Bonsai.ML.Torch.NeuralNets +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Sink)] + public class SaveModel + { + [XmlIgnore] + public ITorchModule Model { get; set; } + + public string ModelPath { get; set; } + + public IObservable Process(IObservable source) + { + return source.Do(input => { + Model.Module.save(ModelPath); + }); + } + } +} \ No newline at end of file From 811d0314634da7a10c75f7e88ef2d1aa94a3799e Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 14:54:37 +0000 Subject: [PATCH 53/75] Changed name to indicate loading from an existing architecture --- .../NeuralNets/LoadModuleFromArchitecture.cs | 60 +++++++++++++++++++ .../NeuralNets/LoadPretrainedModel.cs | 54 ----------------- ...etrainedModels.cs => ModelArchitecture.cs} | 2 +- 3 files changed, 61 insertions(+), 55 deletions(-) create mode 100644 src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs delete mode 100644 src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs rename src/Bonsai.ML.Torch/NeuralNets/Models/{PretrainedModels.cs => ModelArchitecture.cs} (76%) diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs new file mode 100644 index 00000000..43d74544 --- /dev/null +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs @@ -0,0 +1,60 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.NeuralNets +{ + [Combinator] + [Description("")] + [WorkflowElementCategory(ElementCategory.Source)] + public class LoadModuleFromArchitecture + { + public Models.ModelArchitecture ModelArchitecture { get; set; } + + [XmlIgnore] + public Device Device { get; set; } + + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string ModelWeightsPath { get; set; } + + private int numClasses = 10; + public int NumClasses + { + get => numClasses; + set + { + if (value <= 0) + { + numClasses = 10; + } + else + { + numClasses = value; + } + } + } + + public IObservable Process() + { + var modelArchitecture = ModelArchitecture.ToString().ToLower(); + var device = Device; + + nn.Module module = modelArchitecture switch + { + "alexnet" => new Models.AlexNet(modelArchitecture, numClasses, device), + "mobilenet" => new Models.MobileNet(modelArchitecture, numClasses, device), + "mnist" => new Models.MNIST(modelArchitecture, device), + _ => throw new ArgumentException($"Model {modelArchitecture} not supported.") + }; + + if (ModelWeightsPath is not null) module.load(ModelWeightsPath); + + var torchModule = new TorchModuleAdapter(module); + return Observable.Defer(() => { + return Observable.Return((ITorchModule)torchModule); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs deleted file mode 100644 index e4dddba1..00000000 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadPretrainedModel.cs +++ /dev/null @@ -1,54 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; -using System.Xml.Serialization; - -namespace Bonsai.ML.Torch.NeuralNets -{ - [Combinator] - [Description("")] - [WorkflowElementCategory(ElementCategory.Source)] - public class LoadPretrainedModel - { - public Models.PretrainedModels ModelName { get; set; } - - [XmlIgnore] - public Device Device { get; set; } - - [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string ModelWeightsPath { get; set; } - - private int numClasses = 10; - - public IObservable Process() - { - nn.Module module = null; - var modelName = ModelName.ToString().ToLower(); - var device = Device; - - switch (modelName) - { - case "alexnet": - module = new Models.AlexNet(modelName, numClasses, device); - if (ModelWeightsPath is not null) module.load(ModelWeightsPath); - break; - case "mobilenet": - module = new Models.MobileNet(modelName, numClasses, device); - if (ModelWeightsPath is not null) module.load(ModelWeightsPath); - break; - case "mnist": - module = new Models.MNIST(modelName, device); - if (ModelWeightsPath is not null) module.load(ModelWeightsPath); - break; - default: - throw new ArgumentException($"Model {modelName} not supported."); - } - - var torchModule = new TorchModuleAdapter(module); - return Observable.Defer(() => { - return Observable.Return((ITorchModule)torchModule); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs similarity index 76% rename from src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs rename to src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs index a3c65bdc..0a221b5c 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/PretrainedModels.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs @@ -1,6 +1,6 @@ namespace Bonsai.ML.Torch.NeuralNets.Models { - public enum PretrainedModels + public enum ModelArchitecture { AlexNet, MobileNet, From 978194b2f1d756b33948f4f7163ec667f938634e Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 14:58:11 +0000 Subject: [PATCH 54/75] Added descriptions and documentation --- src/Bonsai.ML.Torch/Arange.cs | 5 +- src/Bonsai.ML.Torch/Concat.cs | 1 + src/Bonsai.ML.Torch/ConvertDataType.cs | 1 + src/Bonsai.ML.Torch/CreateTensor.cs | 33 +++++--- src/Bonsai.ML.Torch/Empty.cs | 23 ++++-- src/Bonsai.ML.Torch/InitializeTorchDevice.cs | 3 +- src/Bonsai.ML.Torch/Linspace.cs | 3 + src/Bonsai.ML.Torch/LoadTensor.cs | 33 ++++++++ src/Bonsai.ML.Torch/Mean.cs | 32 ++++++++ src/Bonsai.ML.Torch/MeshGrid.cs | 3 +- src/Bonsai.ML.Torch/Ones.cs | 13 +++ src/Bonsai.ML.Torch/Permute.cs | 1 + src/Bonsai.ML.Torch/Reshape.cs | 1 + src/Bonsai.ML.Torch/SaveTensor.cs | 34 ++++++++ src/Bonsai.ML.Torch/ScalarTypeLookup.cs | 53 +++++++++++++ src/Bonsai.ML.Torch/Set.cs | 11 +-- src/Bonsai.ML.Torch/Sum.cs | 28 +++++++ src/Bonsai.ML.Torch/Swapaxes.cs | 40 ---------- src/Bonsai.ML.Torch/TensorDataType.cs | 50 ------------ src/Bonsai.ML.Torch/TensorDataTypeLookup.cs | 52 ------------ src/Bonsai.ML.Torch/Tile.cs | 13 ++- src/Bonsai.ML.Torch/ToArray.cs | 1 + src/Bonsai.ML.Torch/ToDevice.cs | 1 + src/Bonsai.ML.Torch/ToImage.cs | 2 +- src/Bonsai.ML.Torch/ToNDArray.cs | 83 ++++++++++++++++++++ src/Bonsai.ML.Torch/ToTensor.cs | 2 +- src/Bonsai.ML.Torch/View.cs | 33 ++++++++ src/Bonsai.ML.Torch/Zeros.cs | 1 + 28 files changed, 383 insertions(+), 173 deletions(-) create mode 100644 src/Bonsai.ML.Torch/LoadTensor.cs create mode 100644 src/Bonsai.ML.Torch/Mean.cs create mode 100644 src/Bonsai.ML.Torch/SaveTensor.cs create mode 100644 src/Bonsai.ML.Torch/ScalarTypeLookup.cs create mode 100644 src/Bonsai.ML.Torch/Sum.cs delete mode 100644 src/Bonsai.ML.Torch/Swapaxes.cs delete mode 100644 src/Bonsai.ML.Torch/TensorDataType.cs delete mode 100644 src/Bonsai.ML.Torch/TensorDataTypeLookup.cs create mode 100644 src/Bonsai.ML.Torch/ToNDArray.cs create mode 100644 src/Bonsai.ML.Torch/View.cs diff --git a/src/Bonsai.ML.Torch/Arange.cs b/src/Bonsai.ML.Torch/Arange.cs index 14e3259b..fa80c08e 100644 --- a/src/Bonsai.ML.Torch/Arange.cs +++ b/src/Bonsai.ML.Torch/Arange.cs @@ -17,16 +17,19 @@ public class Arange /// /// The start of the range. /// + [Description("The start of the range.")] public int Start { get; set; } = 0; /// /// The end of the range. /// + [Description("The end of the range.")] public int End { get; set; } = 10; /// - /// The step of the range. + /// The step size between values. /// + [Description("The step size between values.")] public int Step { get; set; } = 1; /// diff --git a/src/Bonsai.ML.Torch/Concat.cs b/src/Bonsai.ML.Torch/Concat.cs index 34adf731..45402621 100644 --- a/src/Bonsai.ML.Torch/Concat.cs +++ b/src/Bonsai.ML.Torch/Concat.cs @@ -18,6 +18,7 @@ public class Concat /// /// The dimension along which to concatenate the tensors. /// + [Description("The dimension along which to concatenate the tensors.")] public long Dimension { get; set; } = 0; /// diff --git a/src/Bonsai.ML.Torch/ConvertDataType.cs b/src/Bonsai.ML.Torch/ConvertDataType.cs index 59981adc..efe3496b 100644 --- a/src/Bonsai.ML.Torch/ConvertDataType.cs +++ b/src/Bonsai.ML.Torch/ConvertDataType.cs @@ -16,6 +16,7 @@ public class ConvertDataType /// /// The scalar type to which to convert the input tensor. /// + [Description("The scalar type to which to convert the input tensor.")] public ScalarType Type { get; set; } = ScalarType.Float32; /// diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index 66509bbc..2bddaa46 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -10,19 +10,21 @@ using static TorchSharp.torch; using Bonsai.ML.Data; using Bonsai.ML.Python; -using Bonsai.ML.Torch.Helpers; +using TorchSharp; namespace Bonsai.ML.Torch { /// - /// Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". + /// Creates a tensor from the specified values. + /// Uses Python-like syntax to specify the tensor values. + /// For example, a 2x2 tensor can be created with the following values: "[[1, 2], [3, 4]]". /// [Combinator] [Description("Creates a tensor from the specified values. Uses Python-like syntax to specify the tensor values. For example, a 2x2 tensor can be created with the following values: \"[[1, 2], [3, 4]]\".")] [WorkflowElementCategory(ElementCategory.Source)] public class CreateTensor : ExpressionBuilder { - Range argumentRange = new Range(0, 1); + readonly Range argumentRange = new Range(0, 1); /// public override Range ArgumentRange => argumentRange; @@ -30,17 +32,21 @@ public class CreateTensor : ExpressionBuilder /// /// The data type of the tensor elements. /// - public TensorDataType Type + [Description("The data type of the tensor elements.")] + public ScalarType Type { get => scalarType; set => scalarType = value; } - private TensorDataType scalarType = TensorDataType.Float32; + private ScalarType scalarType = ScalarType.Float32; /// - /// The values of the tensor elements. Uses Python-like syntax to specify the tensor values. + /// The values of the tensor elements. + /// Uses Python-like syntax to specify the tensor values. + /// For example: "[[1, 2], [3, 4]]". /// + [Description("The values of the tensor elements. Uses Python-like syntax to specify the tensor values. For example: \"[[1, 2], [3, 4]]\".")] public string Values { get => values; @@ -56,6 +62,7 @@ public string Values /// The device on which to create the tensor. /// [XmlIgnore] + [Description("The device on which to create the tensor.")] public Device Device { get => device; @@ -98,7 +105,7 @@ private Expression BuildTensorFromArray(Array arrayValues, Type returnType) arrayVariable ); - var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + var tensorCreationMethodInfo = typeof(torch).GetMethod( "tensor", [ arrayVariable.Type, typeof(ScalarType?), @@ -111,7 +118,7 @@ private Expression BuildTensorFromArray(Array arrayValues, Type returnType) var tensorAssignment = Expression.Call( tensorCreationMethodInfo, tensorDataInitializationBlock, - Expression.Constant((ScalarType)scalarType, typeof(ScalarType?)), + Expression.Constant(scalarType, typeof(ScalarType?)), Expression.Constant(device, typeof(Device)), Expression.Constant(false, typeof(bool)), Expression.Constant(null, typeof(string).MakeArrayType()) @@ -140,7 +147,7 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp valueVariable ); - var tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + var tensorCreationMethodInfo = typeof(torch).GetMethod( "tensor", [ valueVariable.Type, typeof(Device), @@ -155,7 +162,7 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp if (tensorCreationMethodInfo == null) { - tensorCreationMethodInfo = typeof(TorchSharp.torch).GetMethod( + tensorCreationMethodInfo = typeof(torch).GetMethod( "tensor", [ valueVariable.Type, typeof(ScalarType?), @@ -193,7 +200,7 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp /// public override Expression Build(IEnumerable arguments) { - var returnType = TensorDataTypeLookup.GetTypeFromTensorDataType(scalarType); + var returnType = ScalarTypeLookup.GetTypeFromScalarType(scalarType); var argTypes = arguments.Select(arg => arg.Type).ToArray(); Type[] methodInfoArgumentTypes = [typeof(Tensor)]; @@ -225,7 +232,7 @@ public override Expression Build(IEnumerable arguments) finally { values = StringFormatter.FormatToPython(tensorValues).ToLower(); - scalarType = TensorDataTypeLookup.GetTensorDataTypeFromType(returnType); + scalarType = ScalarTypeLookup.GetScalarTypeFromType(returnType); } } @@ -242,7 +249,7 @@ public IObservable Process(Tensor tensor) /// public IObservable Process(IObservable source, Tensor tensor) { - return Observable.Select(source, (_) => tensor); + return source.Select(_ => tensor); } } } diff --git a/src/Bonsai.ML.Torch/Empty.cs b/src/Bonsai.ML.Torch/Empty.cs index 1c4f6af5..dafcee05 100644 --- a/src/Bonsai.ML.Torch/Empty.cs +++ b/src/Bonsai.ML.Torch/Empty.cs @@ -6,26 +6,27 @@ namespace Bonsai.ML.Torch { /// - /// Returns an empty tensor with the given data type and size. + /// Creates an empty tensor with the given data type and size. /// [Combinator] - [Description("Converts the input tensor into an OpenCV mat.")] - [WorkflowElementCategory(ElementCategory.Transform)] + [Description("Creates an empty tensor with the given data type and size.")] + [WorkflowElementCategory(ElementCategory.Source)] public class Empty { - /// /// The size of the tensor. /// + [Description("The size of the tensor.")] public long[] Size { get; set; } = [0]; /// /// The data type of the tensor elements. /// + [Description("The data type of the tensor elements.")] public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Returns an empty tensor with the given data type and size. + /// Creates an empty tensor with the given data type and size. /// public IObservable Process() { @@ -34,5 +35,17 @@ public IObservable Process() return Observable.Return(empty(Size, Type)); }); } + + /// + /// Generates an observable sequence of empty tensors for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return empty(Size, Type); + }); + } } } diff --git a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs index e82daa36..b8ce574c 100644 --- a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs +++ b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs @@ -11,12 +11,13 @@ namespace Bonsai.ML.Torch /// [Combinator] [Description("Initializes the Torch device with the specified device type.")] - [WorkflowElementCategory(ElementCategory.Transform)] + [WorkflowElementCategory(ElementCategory.Source)] public class InitializeTorchDevice { /// /// The device type to initialize. /// + [Description("The device type to initialize.")] public DeviceType DeviceType { get; set; } /// diff --git a/src/Bonsai.ML.Torch/Linspace.cs b/src/Bonsai.ML.Torch/Linspace.cs index ee6516cf..f7e27887 100644 --- a/src/Bonsai.ML.Torch/Linspace.cs +++ b/src/Bonsai.ML.Torch/Linspace.cs @@ -16,16 +16,19 @@ public class Linspace /// /// The start of the range. /// + [Description("The start of the range.")] public int Start { get; set; } = 0; /// /// The end of the range. /// + [Description("The end of the range.")] public int End { get; set; } = 1; /// /// The number of points to generate. /// + [Description("The number of points to generate.")] public int Count { get; set; } = 10; /// diff --git a/src/Bonsai.ML.Torch/LoadTensor.cs b/src/Bonsai.ML.Torch/LoadTensor.cs new file mode 100644 index 00000000..af1e7f05 --- /dev/null +++ b/src/Bonsai.ML.Torch/LoadTensor.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Loads a tensor from the specified file. + /// + [Combinator] + [Description("Loads a tensor from the specified file.")] + [WorkflowElementCategory(ElementCategory.Source)] + public class LoadTensor + { + /// + /// The path to the file containing the tensor. + /// + [FileNameFilter("Binary files(*.bin)|*.bin|All files|*.*")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the file containing the tensor.")] + public string Path { get; set; } + + /// + /// Loads a tensor from the specified file. + /// + /// + public IObservable Process() + { + return Observable.Return(Tensor.Load(Path)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Mean.cs b/src/Bonsai.ML.Torch/Mean.cs new file mode 100644 index 00000000..294edf31 --- /dev/null +++ b/src/Bonsai.ML.Torch/Mean.cs @@ -0,0 +1,32 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Takes the mean of the tensor along the specified dimensions. + /// + [Combinator] + [Description("Takes the mean of the tensor along the specified dimensions.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Mean + { + /// + /// The dimensions along which to compute the mean. + /// + [Description("The dimensions along which to compute the mean.")] + public long[] Dimensions { get; set; } + + /// + /// Takes the mean of the tensor along the specified dimensions. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => input.mean(Dimensions)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/MeshGrid.cs b/src/Bonsai.ML.Torch/MeshGrid.cs index 725b12a9..a32f9eca 100644 --- a/src/Bonsai.ML.Torch/MeshGrid.cs +++ b/src/Bonsai.ML.Torch/MeshGrid.cs @@ -11,13 +11,14 @@ namespace Bonsai.ML.Torch /// Creates a mesh grid from an observable sequence of enumerable of 1-D tensors. /// [Combinator] - [Description("")] + [Description("Creates a mesh grid from an observable sequence of enumerable of 1-D tensors.")] [WorkflowElementCategory(ElementCategory.Source)] public class MeshGrid { /// /// The indexing mode to use for the mesh grid. /// + [Description("The indexing mode to use for the mesh grid.")] public string Indexing { get; set; } = "ij"; /// diff --git a/src/Bonsai.ML.Torch/Ones.cs b/src/Bonsai.ML.Torch/Ones.cs index 52bf8732..77d26577 100644 --- a/src/Bonsai.ML.Torch/Ones.cs +++ b/src/Bonsai.ML.Torch/Ones.cs @@ -16,6 +16,7 @@ public class Ones /// /// The size of the tensor. /// + [Description("The size of the tensor.")] public long[] Size { get; set; } = [0]; /// @@ -26,5 +27,17 @@ public IObservable Process() { return Observable.Defer(() => Observable.Return(ones(Size))); } + + /// + /// Generates an observable sequence of tensors filled with ones for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => { + return ones(Size); + }); + } } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Permute.cs b/src/Bonsai.ML.Torch/Permute.cs index a82107ba..507d31d2 100644 --- a/src/Bonsai.ML.Torch/Permute.cs +++ b/src/Bonsai.ML.Torch/Permute.cs @@ -16,6 +16,7 @@ public class Permute /// /// The permutation of the dimensions. /// + [Description("The permutation of the dimensions.")] public long[] Dimensions { get; set; } = [0]; /// diff --git a/src/Bonsai.ML.Torch/Reshape.cs b/src/Bonsai.ML.Torch/Reshape.cs index ebdc8e41..fdd07fa5 100644 --- a/src/Bonsai.ML.Torch/Reshape.cs +++ b/src/Bonsai.ML.Torch/Reshape.cs @@ -17,6 +17,7 @@ public class Reshape /// /// The dimensions of the reshaped tensor. /// + [Description("The dimensions of the reshaped tensor.")] public long[] Dimensions { get; set; } = [0]; /// diff --git a/src/Bonsai.ML.Torch/SaveTensor.cs b/src/Bonsai.ML.Torch/SaveTensor.cs new file mode 100644 index 00000000..1a3c4772 --- /dev/null +++ b/src/Bonsai.ML.Torch/SaveTensor.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Saves the input tensor to the specified file. + /// + [Combinator] + [Description("Saves the input tensor to the specified file.")] + [WorkflowElementCategory(ElementCategory.Sink)] + public class SaveTensor + { + /// + /// The path to the file where the tensor will be saved. + /// + [FileNameFilter("Binary files(*.bin)|*.bin|All files|*.*")] + [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the file where the tensor will be saved.")] + public string Path { get; set; } = string.Empty; + + /// + /// Saves the input tensor to the specified file. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Do(tensor => tensor.save(Path)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ScalarTypeLookup.cs b/src/Bonsai.ML.Torch/ScalarTypeLookup.cs new file mode 100644 index 00000000..1e4c6c57 --- /dev/null +++ b/src/Bonsai.ML.Torch/ScalarTypeLookup.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Provides methods to look up tensor data types. + /// + public static class ScalarTypeLookup + { + private static readonly Dictionary _lookup = new() + { + { ScalarType.Byte, (typeof(byte), "byte") }, + { ScalarType.Int16, (typeof(short), "short") }, + { ScalarType.Int32, (typeof(int), "int") }, + { ScalarType.Int64, (typeof(long), "long") }, + { ScalarType.Float32, (typeof(float), "float") }, + { ScalarType.Float64, (typeof(double), "double") }, + { ScalarType.Bool, (typeof(bool), "bool") }, + { ScalarType.Int8, (typeof(sbyte), "sbyte") }, + }; + + /// + /// Returns the type corresponding to the specified tensor data type. + /// + /// + /// + public static Type GetTypeFromScalarType(ScalarType type) => _lookup[type].Type; + + /// + /// Returns the string representation corresponding to the specified tensor data type. + /// + /// + /// + public static string GetStringFromScalarType(ScalarType type) => _lookup[type].StringValue; + + /// + /// Returns the tensor data type corresponding to the specified string representation. + /// + /// + /// + public static ScalarType GetScalarTypeFromString(string value) => _lookup.First(x => x.Value.StringValue == value).Key; + + /// + /// Returns the tensor data type corresponding to the specified type. + /// + /// + /// + public static ScalarType GetScalarTypeFromType(Type type) => _lookup.First(x => x.Value.Type == type).Key; + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Set.cs b/src/Bonsai.ML.Torch/Set.cs index 6b0fd86b..0e2965c6 100644 --- a/src/Bonsai.ML.Torch/Set.cs +++ b/src/Bonsai.ML.Torch/Set.cs @@ -19,18 +19,14 @@ public class Set /// /// The index at which to set the value. /// - public string Index - { - get => IndexHelper.Serialize(indexes); - set => indexes = IndexHelper.Parse(value); - } - - private TensorIndex[] indexes; + [Description("The index at which to set the value.")] + public string Index { get; set; } = string.Empty; /// /// The value to set at the specified index. /// [XmlIgnore] + [Description("The value to set at the specified index.")] public Tensor Value { get; set; } = null; /// @@ -41,6 +37,7 @@ public string Index public IObservable Process(IObservable source) { return source.Select(tensor => { + var indexes = IndexHelper.Parse(Index); return tensor.index_put_(Value, indexes); }); } diff --git a/src/Bonsai.ML.Torch/Sum.cs b/src/Bonsai.ML.Torch/Sum.cs new file mode 100644 index 00000000..d01efb95 --- /dev/null +++ b/src/Bonsai.ML.Torch/Sum.cs @@ -0,0 +1,28 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + [Combinator] + [Description("Computes the sum of the input tensor elements along the specified dimensions.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Sum + { + /// + /// The dimensions along which to compute the sum. + /// + public long[] Dimensions { get; set; } + + /// + /// Computes the sum of the input tensor elements along the specified dimensions. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => input.sum(Dimensions)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Swapaxes.cs b/src/Bonsai.ML.Torch/Swapaxes.cs deleted file mode 100644 index 4777e882..00000000 --- a/src/Bonsai.ML.Torch/Swapaxes.cs +++ /dev/null @@ -1,40 +0,0 @@ -using System; -using System.Collections.Generic; -using System.ComponentModel; -using System.Linq; -using System.Reactive.Linq; -using System.Xml.Serialization; -using TorchSharp; -using static TorchSharp.torch; - -namespace Bonsai.ML.Torch -{ - [Combinator] - [Description("Swaps the axes of the input tensor.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Swapaxes - { - - /// - /// The value of axis 1. - /// - public long Axis1 { get; set; } = 0; - - /// - /// The value of axis 2. - /// - public long Axis2 { get; set; } = 1; - - /// - /// Returns an observable sequence that sets the value of the input tensor at the specified index. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => { - return swapaxes(tensor, Axis1, Axis2); - }); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/TensorDataType.cs b/src/Bonsai.ML.Torch/TensorDataType.cs deleted file mode 100644 index f76a04c1..00000000 --- a/src/Bonsai.ML.Torch/TensorDataType.cs +++ /dev/null @@ -1,50 +0,0 @@ -using static TorchSharp.torch; - -namespace Bonsai.ML.Torch -{ - /// - /// Represents the data type of the tensor elements. Contains currently supported data types. A subset of the available ScalarType data types in TorchSharp. - /// - public enum TensorDataType - { - /// - /// 8-bit unsigned integer. - /// - Byte = ScalarType.Byte, - - /// - /// 8-bit signed integer. - /// - Int8 = ScalarType.Int8, - - /// - /// 16-bit signed integer. - /// - Int16 = ScalarType.Int16, - - /// - /// 32-bit signed integer. - /// - Int32 = ScalarType.Int32, - - /// - /// 64-bit signed integer. - /// - Int64 = ScalarType.Int64, - - /// - /// 32-bit floating point. - /// - Float32 = ScalarType.Float32, - - /// - /// 64-bit floating point. - /// - Float64 = ScalarType.Float64, - - /// - /// Boolean. - /// - Bool = ScalarType.Bool - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/TensorDataTypeLookup.cs b/src/Bonsai.ML.Torch/TensorDataTypeLookup.cs deleted file mode 100644 index 6e2b1be0..00000000 --- a/src/Bonsai.ML.Torch/TensorDataTypeLookup.cs +++ /dev/null @@ -1,52 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; - -namespace Bonsai.ML.Torch.Helpers -{ - /// - /// Provides helper methods for working with tensor data types. - /// - public class TensorDataTypeLookup - { - private static readonly Dictionary _lookup = new Dictionary - { - { TensorDataType.Byte, (typeof(byte), "byte") }, - { TensorDataType.Int16, (typeof(short), "short") }, - { TensorDataType.Int32, (typeof(int), "int") }, - { TensorDataType.Int64, (typeof(long), "long") }, - { TensorDataType.Float32, (typeof(float), "float") }, - { TensorDataType.Float64, (typeof(double), "double") }, - { TensorDataType.Bool, (typeof(bool), "bool") }, - { TensorDataType.Int8, (typeof(sbyte), "sbyte") }, - }; - - /// - /// Returns the type corresponding to the specified tensor data type. - /// - /// - /// - public static Type GetTypeFromTensorDataType(TensorDataType type) => _lookup[type].Type; - - /// - /// Returns the string representation corresponding to the specified tensor data type. - /// - /// - /// - public static string GetStringFromTensorDataType(TensorDataType type) => _lookup[type].StringValue; - - /// - /// Returns the tensor data type corresponding to the specified string representation. - /// - /// - /// - public static TensorDataType GetTensorDataTypeFromString(string value) => _lookup.First(x => x.Value.StringValue == value).Key; - - /// - /// Returns the tensor data type corresponding to the specified type. - /// - /// - /// - public static TensorDataType GetTensorDataTypeFromType(Type type) => _lookup.First(x => x.Value.Type == type).Key; - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Tile.cs b/src/Bonsai.ML.Torch/Tile.cs index 1df78122..df25b8ac 100644 --- a/src/Bonsai.ML.Torch/Tile.cs +++ b/src/Bonsai.ML.Torch/Tile.cs @@ -5,13 +5,24 @@ namespace Bonsai.ML.Torch { + /// + /// Constructs a tensor by repeating the elements of input. + /// [Combinator] - [Description("Constructs a tensor by repeating the elements of input. The Dimensions argument specifies the number of repetitions in each dimension.")] + [Description("Constructs a tensor by repeating the elements of input.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Tile { + /// + /// The number of repetitions in each dimension. + /// public long[] Dimensions { get; set; } + /// + /// Constructs a tensor by repeating the elements of input along the specified dimensions. + /// + /// + /// public IObservable Process(IObservable source) { return source.Select(tensor => { diff --git a/src/Bonsai.ML.Torch/ToArray.cs b/src/Bonsai.ML.Torch/ToArray.cs index 1c2c721a..e9ca21f1 100644 --- a/src/Bonsai.ML.Torch/ToArray.cs +++ b/src/Bonsai.ML.Torch/ToArray.cs @@ -38,6 +38,7 @@ public ToArray() /// /// Gets or sets the type mapping used to convert the input tensor into an array. /// + [Description("Gets or sets the type mapping used to convert the input tensor into an array.")] public TypeMapping Type { get; set; } /// diff --git a/src/Bonsai.ML.Torch/ToDevice.cs b/src/Bonsai.ML.Torch/ToDevice.cs index 531ff585..0377df46 100644 --- a/src/Bonsai.ML.Torch/ToDevice.cs +++ b/src/Bonsai.ML.Torch/ToDevice.cs @@ -19,6 +19,7 @@ public class ToDevice /// The device to which the input tensor should be moved. /// [XmlIgnore] + [Description("The device to which the input tensor should be moved.")] public Device Device { get; set; } /// diff --git a/src/Bonsai.ML.Torch/ToImage.cs b/src/Bonsai.ML.Torch/ToImage.cs index 0b9d8ccd..70c8227e 100644 --- a/src/Bonsai.ML.Torch/ToImage.cs +++ b/src/Bonsai.ML.Torch/ToImage.cs @@ -11,7 +11,7 @@ namespace Bonsai.ML.Torch /// Converts the input tensor into an OpenCV image. /// [Combinator] - [Description("")] + [Description("Converts the input tensor into an OpenCV image.")] [WorkflowElementCategory(ElementCategory.Transform)] public class ToImage { diff --git a/src/Bonsai.ML.Torch/ToNDArray.cs b/src/Bonsai.ML.Torch/ToNDArray.cs new file mode 100644 index 00000000..89b7b1e1 --- /dev/null +++ b/src/Bonsai.ML.Torch/ToNDArray.cs @@ -0,0 +1,83 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.Linq.Expressions; +using System.Reflection; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Converts the input tensor into an array of the specified element type and rank. + /// + [Combinator] + [Description("Converts the input tensor into an array of the specified element type.")] + [WorkflowElementCategory(ElementCategory.Transform)] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + [XmlInclude(typeof(TypeMapping))] + public class ToNDArray : SingleArgumentExpressionBuilder + { + /// + /// Initializes a new instance of the class. + /// + public ToNDArray() + { + Type = new TypeMapping(); + } + + /// + /// Gets or sets the type mapping used to convert the input tensor into an array. + /// + [Description("Gets or sets the type mapping used to convert the input tensor into an array.")] + public TypeMapping Type { get; set; } + + /// + /// Gets or sets the rank of the output array. Must be greater than or equal to 1. + /// + [Description("Gets or sets the rank of the output array. Must be greater than or equal to 1.")] + public int Rank { get; set; } = 1; + + /// + public override Expression Build(IEnumerable arguments) + { + TypeMapping typeMapping = Type; + var returnType = typeMapping.GetType().GetGenericArguments()[0]; + MethodInfo methodInfo = GetType().GetMethod("Process", BindingFlags.Public | BindingFlags.Instance); + var lengths = new int[Rank]; + Type arrayType = Array.CreateInstance(returnType, lengths).GetType(); + methodInfo = methodInfo.MakeGenericMethod(returnType, arrayType); + Expression sourceExpression = arguments.First(); + + return Expression.Call( + Expression.Constant(this), + methodInfo, + sourceExpression + ); + } + + /// + /// Converts the input tensor into an array of the specified element type. + /// + /// + /// + /// + /// + public IObservable Process(IObservable source) where T : unmanaged + { + return source.Select(tensor => + { + return (TResult)(object)tensor.data().ToNDArray(); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/ToTensor.cs b/src/Bonsai.ML.Torch/ToTensor.cs index 061a13cb..7af26dc9 100644 --- a/src/Bonsai.ML.Torch/ToTensor.cs +++ b/src/Bonsai.ML.Torch/ToTensor.cs @@ -11,7 +11,7 @@ namespace Bonsai.ML.Torch /// Converts the input value into a tensor. /// [Combinator] - [Description("")] + [Description("Converts the input value into a tensor.")] [WorkflowElementCategory(ElementCategory.Transform)] public class ToTensor { diff --git a/src/Bonsai.ML.Torch/View.cs b/src/Bonsai.ML.Torch/View.cs new file mode 100644 index 00000000..65a409be --- /dev/null +++ b/src/Bonsai.ML.Torch/View.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch +{ + /// + /// Creates a new view of the input tensor with the specified dimensions. + /// + [Combinator] + [Description("Creates a new view of the input tensor with the specified dimensions.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class View + { + /// + /// The dimensions of the reshaped tensor. + /// + [Description("The dimensions of the reshaped tensor.")] + public long[] Dimensions { get; set; } = [0]; + + /// + /// Creates a new view of the input tensor with the specified dimensions. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => input.view(Dimensions)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Zeros.cs b/src/Bonsai.ML.Torch/Zeros.cs index 69673d4a..e4fb3c7a 100644 --- a/src/Bonsai.ML.Torch/Zeros.cs +++ b/src/Bonsai.ML.Torch/Zeros.cs @@ -16,6 +16,7 @@ public class Zeros /// /// The size of the tensor. /// + [Description("The size of the tensor.")] public long[] Size { get; set; } = [0]; /// From 6286f615d11bcb6ccd684d7a5aa51e024dc1bc4c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 17:21:29 +0000 Subject: [PATCH 55/75] Added some useful classes for linear algebra --- src/Bonsai.ML.Torch/LinearAlgebra/Cholesky.cs | 26 +++++++++++++ src/Bonsai.ML.Torch/LinearAlgebra/Det.cs | 26 +++++++++++++ src/Bonsai.ML.Torch/LinearAlgebra/Eig.cs | 29 +++++++++++++++ src/Bonsai.ML.Torch/LinearAlgebra/Inv.cs | 27 ++++++++++++++ src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs | 37 +++++++++++++++++++ src/Bonsai.ML.Torch/LinearAlgebra/SVD.cs | 34 +++++++++++++++++ src/Bonsai.ML.Torch/Vision/Normalize.cs | 14 +++---- 7 files changed, 186 insertions(+), 7 deletions(-) create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/Cholesky.cs create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/Det.cs create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/Eig.cs create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/Inv.cs create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/SVD.cs diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Cholesky.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Cholesky.cs new file mode 100644 index 00000000..6843779a --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Cholesky.cs @@ -0,0 +1,26 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. + /// + [Combinator] + [Description("Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Cholesky + { + /// + /// Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(linalg.cholesky); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Det.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Det.cs new file mode 100644 index 00000000..90c5a45d --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Det.cs @@ -0,0 +1,26 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the determinant of a square matrix. + /// + [Combinator] + [Description("Computes the determinant of a square matrix.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Det + { + /// + /// Computes the determinant of a square matrix. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(linalg.det); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Eig.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Eig.cs new file mode 100644 index 00000000..a94c8eb8 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Eig.cs @@ -0,0 +1,29 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the eigenvalue decomposition of a square matrix if it exists. + /// + [Combinator] + [Description("Computes the eigenvalue decomposition of a square matrix if it exists.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Eig + { + /// + /// Computes the eigenvalue decomposition of a square matrix if it exists. + /// + /// + /// + public IObservable> Process(IObservable source) + { + return source.Select(tensor => { + var (eigvals, eigvecs) = linalg.eig(tensor); + return Tuple.Create(eigvals, eigvecs); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Inv.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Inv.cs new file mode 100644 index 00000000..58bb4ce3 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Inv.cs @@ -0,0 +1,27 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the inverse of the input matrix. + /// + [Combinator] + [Description("Computes the inverse of the input matrix.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Inv + { + /// + /// Computes the inverse of the input matrix. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(inv); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs new file mode 100644 index 00000000..82914d39 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs @@ -0,0 +1,37 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes a matrix norm. + /// + [Combinator] + [Description("Computes a matrix norm.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class MatrixNorm + { + + /// + /// The dimensions along which to compute the matrix norm. + /// + public long[] Dimensions { get; set; } = null; + + /// + /// If true, the reduced dimensions are retained in the result as dimensions with size one. + /// + public bool Keepdim { get; set; } = false; + + /// + /// Computes a matrix norm. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => linalg.norm(tensor, dims: Dimensions, keepdim: Keepdim)); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/SVD.cs b/src/Bonsai.ML.Torch/LinearAlgebra/SVD.cs new file mode 100644 index 00000000..c722f53b --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/SVD.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LinearAlgebra +{ + /// + /// Computes the singular value decomposition (SVD) of a matrix. + /// + [Combinator] + [Description("Computes the singular value decomposition (SVD) of a matrix.")] + [WorkflowElementCategory(ElementCategory.Transform)] + public class SVD + { + /// + /// Whether to compute the full or reduced SVD. + /// + public bool FullMatrices { get; set; } = false; + + /// + /// Computes the singular value decomposition (SVD) of a matrix. + /// + /// + /// + public IObservable> Process(IObservable source) + { + return source.Select(tensor => { + var (u, s, v) = linalg.svd(tensor, fullMatrices: FullMatrices); + return Tuple.Create(u, s, v); + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Vision/Normalize.cs b/src/Bonsai.ML.Torch/Vision/Normalize.cs index fee8a3b9..f4fc80a3 100644 --- a/src/Bonsai.ML.Torch/Vision/Normalize.cs +++ b/src/Bonsai.ML.Torch/Vision/Normalize.cs @@ -5,23 +5,23 @@ using System.Reactive.Linq; using static TorchSharp.torch; using static TorchSharp.torchvision; -using System.Xml.Serialization; namespace Bonsai.ML.Torch.Vision { [Combinator] - [Description("")] + [Description("Normalizes the input tensor with the mean and standard deviation.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Normalize - { - private ITransform inputTransform; + { + public double[] Means { get; set; } = [ 0.1307 ]; + public double[] StdDevs { get; set; } = [ 0.3081 ]; + private ITransform transform = null; public IObservable Process(IObservable source) { - inputTransform = transforms.Normalize(new double[] { 0.1307 }, new double[] { 0.3081 }); - return source.Select(tensor => { - return inputTransform.call(tensor); + transform ??= transforms.Normalize(Means, StdDevs, tensor.dtype, tensor.device); + return transform.call(tensor); }); } } From c363b5ba722d05e74f7f02f8f1ff5fc2504fa716 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 17:48:20 +0000 Subject: [PATCH 56/75] Adding classes for creating tensor indexes --- src/Bonsai.ML.Torch/Index.cs | 35 -------- src/Bonsai.ML.Torch/Index/BooleanIndex.cs | 42 ++++++++++ src/Bonsai.ML.Torch/Index/ColonIndex.cs | 36 ++++++++ src/Bonsai.ML.Torch/Index/EllipsesIndex.cs | 37 +++++++++ src/Bonsai.ML.Torch/Index/Index.cs | 37 +++++++++ src/Bonsai.ML.Torch/Index/IndexHelper.cs | 97 ++++++++++++++++++++++ src/Bonsai.ML.Torch/Index/NoneIndex.cs | 36 ++++++++ src/Bonsai.ML.Torch/Index/SingleIndex.cs | 42 ++++++++++ src/Bonsai.ML.Torch/Index/SliceIndex.cs | 54 ++++++++++++ src/Bonsai.ML.Torch/Index/TensorIndex.cs | 26 ++++++ src/Bonsai.ML.Torch/IndexHelper.cs | 91 -------------------- 11 files changed, 407 insertions(+), 126 deletions(-) delete mode 100644 src/Bonsai.ML.Torch/Index.cs create mode 100644 src/Bonsai.ML.Torch/Index/BooleanIndex.cs create mode 100644 src/Bonsai.ML.Torch/Index/ColonIndex.cs create mode 100644 src/Bonsai.ML.Torch/Index/EllipsesIndex.cs create mode 100644 src/Bonsai.ML.Torch/Index/Index.cs create mode 100644 src/Bonsai.ML.Torch/Index/IndexHelper.cs create mode 100644 src/Bonsai.ML.Torch/Index/NoneIndex.cs create mode 100644 src/Bonsai.ML.Torch/Index/SingleIndex.cs create mode 100644 src/Bonsai.ML.Torch/Index/SliceIndex.cs create mode 100644 src/Bonsai.ML.Torch/Index/TensorIndex.cs delete mode 100644 src/Bonsai.ML.Torch/IndexHelper.cs diff --git a/src/Bonsai.ML.Torch/Index.cs b/src/Bonsai.ML.Torch/Index.cs deleted file mode 100644 index 818bb401..00000000 --- a/src/Bonsai.ML.Torch/Index.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; - -namespace Bonsai.ML.Torch -{ - /// - /// Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. - /// Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis. - /// - [Combinator] - [Description("Indexes a tensor with the specified indices. Indices are specified as a comma-separated values. Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Index - { - /// - /// The indices to use for indexing the tensor. - /// - public string Indexes { get; set; } = string.Empty; - - /// - /// Indexes the input tensor with the specified indices. - /// - /// - /// - public IObservable Process(IObservable source) - { - var index = IndexHelper.Parse(Indexes); - return source.Select(tensor => { - return tensor.index(index); - }); - } - } -} diff --git a/src/Bonsai.ML.Torch/Index/BooleanIndex.cs b/src/Bonsai.ML.Torch/Index/BooleanIndex.cs new file mode 100644 index 00000000..f854aa56 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/BooleanIndex.cs @@ -0,0 +1,42 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents a boolean index that can be used to select elements from a tensor. +/// +[Combinator] +[Description("Represents a boolean index that can be used to select elements from a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class BooleanIndex +{ + /// + /// Gets or sets the boolean value used to select elements from a tensor. + /// + [Description("The boolean value used to select elements from a tensor.")] + public bool Value { get; set; } = false; + + /// + /// Generates the boolean index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Bool(Value)); + } + + /// + /// Processes the input sequence and generates the boolean index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Bool(Value)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/ColonIndex.cs b/src/Bonsai.ML.Torch/Index/ColonIndex.cs new file mode 100644 index 00000000..bfd9ca7b --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/ColonIndex.cs @@ -0,0 +1,36 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents the colon index used to select all elements along a given dimension. +/// +[Combinator] +[Description("Represents the colon index used to select all elements along a given dimension.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class ColonIndex +{ + /// + /// Generates the colon index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Colon); + } + + /// + /// Processes the input sequence and generates the colon index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Colon); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/EllipsesIndex.cs b/src/Bonsai.ML.Torch/Index/EllipsesIndex.cs new file mode 100644 index 00000000..06207a8e --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/EllipsesIndex.cs @@ -0,0 +1,37 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that selects all dimensions of a tensor. +/// +[Combinator] +[Description("Represents an index that selects all dimensions of a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class EllipsisIndex +{ + + /// + /// Generates the ellipsis index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Ellipsis); + } + + /// + /// Processes the input sequence and generates the ellipsis index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Ellipsis); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/Index.cs b/src/Bonsai.ML.Torch/Index/Index.cs new file mode 100644 index 00000000..6846b8c7 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/Index.cs @@ -0,0 +1,37 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Indexes a tensor by parsing the specified indices. +/// Indices are specified as a comma-separated values. +/// Currently supports Python-style slicing syntax. +/// This includes numeric indices, None, slices, and ellipsis. +/// +[Combinator] +[Description("Indexes a tensor by parsing the specified indices. Indices are specified as a comma-separated values. Currently supports Python-style slicing syntax. This includes numeric indices, None, slices, and ellipsis.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Index +{ + /// + /// The indices to use for indexing the tensor. + /// + [Description("The indices to use for indexing the tensor. For example, '...,3:5,:'")] + public string Indexes { get; set; } = string.Empty; + + /// + /// Indexes the input tensor with the specified indices. + /// + /// + /// + public IObservable Process(IObservable source) + { + var index = IndexHelper.Parse(Indexes); + return source.Select(tensor => { + return tensor.index(index); + }); + } +} diff --git a/src/Bonsai.ML.Torch/Index/IndexHelper.cs b/src/Bonsai.ML.Torch/Index/IndexHelper.cs new file mode 100644 index 00000000..a6f60fb3 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/IndexHelper.cs @@ -0,0 +1,97 @@ +using System; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Provides helper methods to parse tensor indexes. +/// +public static class IndexHelper +{ + + /// + /// Parses the input string into an array of tensor indexes. + /// + /// + public static torch.TensorIndex[] Parse(string input) + { + if (string.IsNullOrEmpty(input)) + { + return [0]; + } + + var indexStrings = input.Split(','); + var indices = new torch.TensorIndex[indexStrings.Length]; + + for (int i = 0; i < indexStrings.Length; i++) + { + var indexString = indexStrings[i].Trim(); + if (int.TryParse(indexString, out int intIndex)) + { + indices[i] = torch.TensorIndex.Single(intIndex); + } + else if (indexString == ":") + { + indices[i] = torch.TensorIndex.Colon; + } + else if (indexString == "None") + { + indices[i] = torch.TensorIndex.None; + } + else if (indexString == "...") + { + indices[i] = torch.TensorIndex.Ellipsis; + } + else if (indexString.ToLower() == "false" || indexString.ToLower() == "true") + { + indices[i] = torch.TensorIndex.Bool(indexString.ToLower() == "true"); + } + else if (indexString.Contains(":")) + { + var rangeParts = indexString.Split(':'); + rangeParts = [.. rangeParts.Where(p => { + p = p.Trim(); + return !string.IsNullOrEmpty(p); + })]; + + if (rangeParts.Length == 0) + { + indices[i] = torch.TensorIndex.Slice(); + } + else if (rangeParts.Length == 1) + { + indices[i] = torch.TensorIndex.Slice(int.Parse(rangeParts[0])); + } + else if (rangeParts.Length == 2) + { + indices[i] = torch.TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1])); + } + else if (rangeParts.Length == 3) + { + indices[i] = torch.TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1]), int.Parse(rangeParts[2])); + } + else + { + throw new Exception($"Invalid index format: {indexString}"); + } + } + else + { + throw new Exception($"Invalid index format: {indexString}"); + } + } + return indices; + } + + /// + /// Serializes the input array of tensor indexes into a string representation. + /// + /// + /// + public static string Serialize(torch.TensorIndex[] indexes) + { + return string.Join(", ", indexes); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/NoneIndex.cs b/src/Bonsai.ML.Torch/Index/NoneIndex.cs new file mode 100644 index 00000000..b10c9d86 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/NoneIndex.cs @@ -0,0 +1,36 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that selects no elements of a tensor. +/// +[Combinator] +[Description("Represents an index that selects no elements of a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class NoneIndex +{ + /// + /// Generates the none index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.None); + } + + /// + /// Processes the input sequence and generates the none index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.None); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/SingleIndex.cs b/src/Bonsai.ML.Torch/Index/SingleIndex.cs new file mode 100644 index 00000000..9b3ec641 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/SingleIndex.cs @@ -0,0 +1,42 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that selects a single value of a tensor. +/// +[Combinator] +[Description("Represents an index that selects a single value of a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class CreateTensorIndexSingle +{ + /// + /// Gets or sets the index value used to select a single element from a tensor. + /// + [Description("The index value used to select a single element from a tensor.")] + public long Index { get; set; } = 0; + + /// + /// Generates the single index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Single(Index)); + } + + /// + /// Processes the input sequence and generates the single index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Single(Index)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/SliceIndex.cs b/src/Bonsai.ML.Torch/Index/SliceIndex.cs new file mode 100644 index 00000000..b31802a4 --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/SliceIndex.cs @@ -0,0 +1,54 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that selects a range of elements from a tensor. +/// +[Combinator] +[Description("Represents an index that selects a range of elements from a tensor.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class SliceIndex +{ + /// + /// Gets or sets the start index of the slice. + /// + [Description("The start index of the slice.")] + public long? Start { get; set; } = null; + + /// + /// Gets or sets the end index of the slice. + /// + [Description("The end index of the slice.")] + public long? End { get; set; } = null; + + /// + /// Gets or sets the step size of the slice. + /// + [Description("The step size of the slice.")] + public long? Step { get; set; } = null; + + /// + /// Generates the slice index. + /// + /// + public IObservable Process() + { + return Observable.Return(torch.TensorIndex.Slice(Start, End, Step)); + } + + /// + /// Processes the input sequence and generates the slice index. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => torch.TensorIndex.Slice(Start, End, Step)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Index/TensorIndex.cs b/src/Bonsai.ML.Torch/Index/TensorIndex.cs new file mode 100644 index 00000000..e3f6612d --- /dev/null +++ b/src/Bonsai.ML.Torch/Index/TensorIndex.cs @@ -0,0 +1,26 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Index; + +/// +/// Represents an index that is created from a tensor. +/// +[Combinator] +[Description("Represents an index that is created from a tensor.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class TensorIndex +{ + /// + /// Converts the input tensor into a tensor index. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(TorchSharp.torch.TensorIndex.Tensor); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/IndexHelper.cs b/src/Bonsai.ML.Torch/IndexHelper.cs deleted file mode 100644 index 2af466a0..00000000 --- a/src/Bonsai.ML.Torch/IndexHelper.cs +++ /dev/null @@ -1,91 +0,0 @@ -using System; -using static TorchSharp.torch; - -namespace Bonsai.ML.Torch -{ - /// - /// Provides helper methods to parse tensor indexes. - /// - public static class IndexHelper - { - - /// - /// Parses the input string into an array of tensor indexes. - /// - /// - public static TensorIndex[] Parse(string input) - { - if (string.IsNullOrEmpty(input)) - { - return [0]; - } - - var indexStrings = input.Split(','); - var indices = new TensorIndex[indexStrings.Length]; - - for (int i = 0; i < indexStrings.Length; i++) - { - var indexString = indexStrings[i].Trim(); - if (int.TryParse(indexString, out int intIndex)) - { - indices[i] = TensorIndex.Single(intIndex); - } - else if (indexString == ":") - { - indices[i] = TensorIndex.Colon; - } - else if (indexString == "None") - { - indices[i] = TensorIndex.None; - } - else if (indexString == "...") - { - indices[i] = TensorIndex.Ellipsis; - } - else if (indexString.ToLower() == "false" || indexString.ToLower() == "true") - { - indices[i] = TensorIndex.Bool(indexString.ToLower() == "true"); - } - else if (indexString.Contains(":")) - { - var rangeParts = indexString.Split(':'); - if (rangeParts.Length == 0) - { - indices[i] = TensorIndex.Slice(); - } - else if (rangeParts.Length == 1) - { - indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0])); - } - else if (rangeParts.Length == 2) - { - indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1])); - } - else if (rangeParts.Length == 3) - { - indices[i] = TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1]), int.Parse(rangeParts[2])); - } - else - { - throw new Exception($"Invalid index format: {indexString}"); - } - } - else - { - throw new Exception($"Invalid index format: {indexString}"); - } - } - return indices; - } - - /// - /// Serializes the input array of tensor indexes into a string representation. - /// - /// - /// - public static string Serialize(TensorIndex[] indexes) - { - return string.Join(", ", indexes); - } - } -} \ No newline at end of file From b2bebb80a0b515ad138bd1532cb5a3e883add9cd Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 17:49:02 +0000 Subject: [PATCH 57/75] Updated MNIST model architecture with correct fully connected layer size --- src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs index b707e2d5..994aca73 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs @@ -15,9 +15,9 @@ public class MNIST : Module private Module conv1 = Conv2d(1, 32, 3); private Module conv2 = Conv2d(32, 64, 3); private Module fc1 = Linear(9216, 128); - private Module fc2 = Linear(128, 10); - - private Module pool1 = MaxPool2d(kernelSize: new long[] { 2, 2 }); + private Module fc2 = Linear(128, 128); + + private Module pool1 = MaxPool2d(kernelSize: [2, 2]); private Module relu1 = ReLU(); private Module relu2 = ReLU(); From e2614259d2c76922f429100ddc6782b024c39185 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:09:43 +0000 Subject: [PATCH 58/75] Added class to explicitly create a clone of a tensor --- src/Bonsai.ML.Torch/Clone.cs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Clone.cs diff --git a/src/Bonsai.ML.Torch/Clone.cs b/src/Bonsai.ML.Torch/Clone.cs new file mode 100644 index 00000000..b8dc15fd --- /dev/null +++ b/src/Bonsai.ML.Torch/Clone.cs @@ -0,0 +1,25 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch; + +/// +/// Clones the input tensor. +/// +[Combinator] +[Description("Clones the input tensor.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Clone +{ + /// + /// Clones the input tensor. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => tensor.clone()); + } +} \ No newline at end of file From 98a65fda8bb1ddbe45cf38c041df012457245f79 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:10:39 +0000 Subject: [PATCH 59/75] Update to use correct width for an IplImage based on widthstep rather than width --- src/Bonsai.ML.Torch/OpenCVHelper.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/OpenCVHelper.cs b/src/Bonsai.ML.Torch/OpenCVHelper.cs index 1ca049c9..a45e2228 100644 --- a/src/Bonsai.ML.Torch/OpenCVHelper.cs +++ b/src/Bonsai.ML.Torch/OpenCVHelper.cs @@ -82,10 +82,10 @@ public static Tensor ToTensor(IplImage image) { return empty([ 0, 0, 0 ]); } - - int width = image.Width; + // int width = image.Width; int height = image.Height; int channels = image.Channels; + var width = image.WidthStep / channels; var iplDepth = image.Depth; var tensorType = bitDepthLookup.FirstOrDefault(x => x.Value.IplDepth == iplDepth).Key; From 4e2ae5378893b017a1cb863df0566f16c9089caa Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:16:14 +0000 Subject: [PATCH 60/75] Explicitly use static torch.Tensor type for defining expressions --- src/Bonsai.ML.Torch/CreateTensor.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index 2bddaa46..4436db0b 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -124,7 +124,7 @@ private Expression BuildTensorFromArray(Array arrayValues, Type returnType) Expression.Constant(null, typeof(string).MakeArrayType()) ); - var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); + var tensorVariable = Expression.Variable(typeof(torch.Tensor), "tensor"); var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); var buildTensor = Expression.Block( @@ -185,7 +185,7 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp tensorCreationMethodArguments ); - var tensorVariable = Expression.Variable(typeof(Tensor), "tensor"); + var tensorVariable = Expression.Variable(typeof(torch.Tensor), "tensor"); var assignTensor = Expression.Assign(tensorVariable, tensorAssignment); var buildTensor = Expression.Block( From 12ec59e5a3239e25df19d98cfba1dabc0d63c4ab Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:18:21 +0000 Subject: [PATCH 61/75] Update set with process overloads to handle passing in tensor index --- src/Bonsai.ML.Torch/Set.cs | 85 +++++++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 29 deletions(-) diff --git a/src/Bonsai.ML.Torch/Set.cs b/src/Bonsai.ML.Torch/Set.cs index 0e2965c6..18dcc02a 100644 --- a/src/Bonsai.ML.Torch/Set.cs +++ b/src/Bonsai.ML.Torch/Set.cs @@ -6,40 +6,67 @@ using System.Xml.Serialization; using static TorchSharp.torch; -namespace Bonsai.ML.Torch +namespace Bonsai.ML.Torch; + +/// +/// Sets the value of the input tensor at the specified index. +/// +[Combinator] +[Description("Sets the value of the input tensor at the specified index.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Set { /// - /// Sets the value of the input tensor at the specified index. + /// The index at which to set the value. + /// + [Description("The index at which to set the value.")] + public string Index { get; set; } = string.Empty; + + /// + /// The value to set at the specified index. /// - [Combinator] - [Description("Sets the value of the input tensor at the specified index.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Set + [XmlIgnore] + [Description("The value to set at the specified index.")] + public Tensor Value { get; set; } = null; + + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable source) { - /// - /// The index at which to set the value. - /// - [Description("The index at which to set the value.")] - public string Index { get; set; } = string.Empty; + return source.Select(tensor => { + var indexes = Torch.Index.IndexHelper.Parse(Index); + return tensor.index_put_(Value, indexes); + }); + } - /// - /// The value to set at the specified index. - /// - [XmlIgnore] - [Description("The value to set at the specified index.")] - public Tensor Value { get; set; } = null; + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => { + var tensor = input.Item1; + var index = input.Item2; + return tensor.index_put_(Value, index); + }); + } - /// - /// Returns an observable sequence that sets the value of the input tensor at the specified index. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => { - var indexes = IndexHelper.Parse(Index); - return tensor.index_put_(Value, indexes); - }); - } + /// + /// Returns an observable sequence that sets the value of the input tensor at the specified index. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => { + var tensor = input.Item1; + var indexes = input.Item2; + return tensor.index_put_(Value, indexes); + }); } } \ No newline at end of file From 5630f01ce5ebc942eb5cf42ccc17aff65616cc6c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:18:57 +0000 Subject: [PATCH 62/75] Fixed incorrectly generating ones instead of zeros --- src/Bonsai.ML.Torch/Zeros.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/Zeros.cs b/src/Bonsai.ML.Torch/Zeros.cs index e4fb3c7a..e99bdce6 100644 --- a/src/Bonsai.ML.Torch/Zeros.cs +++ b/src/Bonsai.ML.Torch/Zeros.cs @@ -25,7 +25,7 @@ public class Zeros /// public IObservable Process() { - return Observable.Defer(() => Observable.Return(ones(Size))); + return Observable.Defer(() => Observable.Return(zeros(Size))); } /// @@ -36,7 +36,7 @@ public IObservable Process() public IObservable Process(IObservable source) { return source.Select(value => { - return ones(Size); + return zeros(Size); }); } } From 69efcbf1f01d5bc0c230b82a49ed6fe4c5a34bb7 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:21:37 +0000 Subject: [PATCH 63/75] Updated to correctly parse colons in string --- src/Bonsai.ML.Torch/Index/IndexHelper.cs | 38 +++++++++--------------- 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/src/Bonsai.ML.Torch/Index/IndexHelper.cs b/src/Bonsai.ML.Torch/Index/IndexHelper.cs index a6f60fb3..b62c1c2c 100644 --- a/src/Bonsai.ML.Torch/Index/IndexHelper.cs +++ b/src/Bonsai.ML.Torch/Index/IndexHelper.cs @@ -1,6 +1,5 @@ using System; -using System.Linq; -using System.Reactive.Linq; +using System.Collections.Generic; using TorchSharp; namespace Bonsai.ML.Torch.Index; @@ -50,32 +49,23 @@ public static torch.TensorIndex[] Parse(string input) } else if (indexString.Contains(":")) { - var rangeParts = indexString.Split(':'); - rangeParts = [.. rangeParts.Where(p => { - p = p.Trim(); - return !string.IsNullOrEmpty(p); - })]; - - if (rangeParts.Length == 0) - { - indices[i] = torch.TensorIndex.Slice(); - } - else if (rangeParts.Length == 1) - { - indices[i] = torch.TensorIndex.Slice(int.Parse(rangeParts[0])); - } - else if (rangeParts.Length == 2) - { - indices[i] = torch.TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1])); - } - else if (rangeParts.Length == 3) + string[] rangeParts = [.. indexString.Split(':')]; + var argsList = new List([null, null, null]); + try { - indices[i] = torch.TensorIndex.Slice(int.Parse(rangeParts[0]), int.Parse(rangeParts[1]), int.Parse(rangeParts[2])); + for (int j = 0; j < rangeParts.Length; j++) + { + if (!string.IsNullOrEmpty(rangeParts[j])) + { + argsList[j] = long.Parse(rangeParts[j]); + } + } } - else + catch (Exception) { throw new Exception($"Invalid index format: {indexString}"); } + indices[i] = torch.TensorIndex.Slice(argsList[0], argsList[1], argsList[2]); } else { @@ -84,7 +74,7 @@ public static torch.TensorIndex[] Parse(string input) } return indices; } - + /// /// Serializes the input array of tensor indexes into a string representation. /// From 556e1adce57d90aba3382f616dcae5463181f960 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:30:36 +0000 Subject: [PATCH 64/75] Updated to use collection expressions --- src/Bonsai.ML.Torch/CreateTensor.cs | 12 ++++++------ src/Bonsai.ML.Torch/OpenCVHelper.cs | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index 4436db0b..46e1f611 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -24,7 +24,7 @@ namespace Bonsai.ML.Torch [WorkflowElementCategory(ElementCategory.Source)] public class CreateTensor : ExpressionBuilder { - readonly Range argumentRange = new Range(0, 1); + readonly Range argumentRange = new(0, 1); /// public override Range ArgumentRange => argumentRange; @@ -171,14 +171,14 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp ] ); - tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( + tensorCreationMethodArguments = [.. tensorCreationMethodArguments.Prepend( Expression.Constant(scalarType, typeof(ScalarType?)) - ).ToArray(); + )]; } - tensorCreationMethodArguments = tensorCreationMethodArguments.Prepend( - tensorDataInitializationBlock - ).ToArray(); + tensorCreationMethodArguments = [.. tensorCreationMethodArguments.Prepend( + tensorDataInitializationBlock + )]; var tensorAssignment = Expression.Call( tensorCreationMethodInfo, diff --git a/src/Bonsai.ML.Torch/OpenCVHelper.cs b/src/Bonsai.ML.Torch/OpenCVHelper.cs index a45e2228..9c7fc0e8 100644 --- a/src/Bonsai.ML.Torch/OpenCVHelper.cs +++ b/src/Bonsai.ML.Torch/OpenCVHelper.cs @@ -141,7 +141,7 @@ public unsafe static IplImage ToImage(Tensor tensor) var tensorType = tensor.dtype; var iplDepth = bitDepthLookup[tensorType].IplDepth; - var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + var new_tensor = zeros([height, width, channels], tensorType).copy_(tensor); var res = THSTensor_data(new_tensor.Handle); var image = new IplImage(new OpenCV.Net.Size(width, height), iplDepth, channels, res); @@ -163,7 +163,7 @@ public unsafe static Mat ToMat(Tensor tensor) var tensorType = tensor.dtype; var depth = bitDepthLookup[tensorType].Depth; - var new_tensor = zeros(new long[] { height, width, channels }, tensorType).copy_(tensor); + var new_tensor = zeros([height, width, channels], tensorType).copy_(tensor); var res = THSTensor_data(new_tensor.Handle); var mat = new Mat(new OpenCV.Net.Size(width, height), depth, channels, res); From f2ceb9d8fd9d38f9a51665860e84e5101b0cb9dc Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 20 Jan 2025 18:31:04 +0000 Subject: [PATCH 65/75] Added documentation --- src/Bonsai.ML.Torch/Sum.cs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Bonsai.ML.Torch/Sum.cs b/src/Bonsai.ML.Torch/Sum.cs index d01efb95..1e4c1a2c 100644 --- a/src/Bonsai.ML.Torch/Sum.cs +++ b/src/Bonsai.ML.Torch/Sum.cs @@ -5,6 +5,9 @@ namespace Bonsai.ML.Torch { + /// + /// Computes the sum of the input tensor elements along the specified dimensions. + /// [Combinator] [Description("Computes the sum of the input tensor elements along the specified dimensions.")] [WorkflowElementCategory(ElementCategory.Transform)] From ff8e902e4e83a6116004cb9fdb4bafca49c6a268 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 11:50:59 +0000 Subject: [PATCH 66/75] Updated to use collection expressions --- src/Bonsai.ML.Torch/CreateTensor.cs | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/Bonsai.ML.Torch/CreateTensor.cs b/src/Bonsai.ML.Torch/CreateTensor.cs index 46e1f611..52d8de1a 100644 --- a/src/Bonsai.ML.Torch/CreateTensor.cs +++ b/src/Bonsai.ML.Torch/CreateTensor.cs @@ -74,15 +74,13 @@ public Device Device private Expression BuildTensorFromArray(Array arrayValues, Type returnType) { var rank = arrayValues.Rank; - var lengths = Enumerable.Range(0, rank) - .Select(arrayValues.GetLength) - .ToArray(); + int[] lengths = [.. Enumerable.Range(0, rank).Select(arrayValues.GetLength)]; - var arrayCreationExpression = Expression.NewArrayBounds(returnType, lengths.Select(len => Expression.Constant(len)).ToArray()); + var arrayCreationExpression = Expression.NewArrayBounds(returnType, [.. lengths.Select(len => Expression.Constant(len))]); var arrayVariable = Expression.Variable(arrayCreationExpression.Type, "array"); var assignArray = Expression.Assign(arrayVariable, arrayCreationExpression); - var assignments = new List(); + List assignments = []; for (int i = 0; i < values.Length; i++) { var indices = new Expression[rank]; @@ -201,13 +199,11 @@ private Expression BuildTensorFromScalarValue(object scalarValue, Type returnTyp public override Expression Build(IEnumerable arguments) { var returnType = ScalarTypeLookup.GetTypeFromScalarType(scalarType); - var argTypes = arguments.Select(arg => arg.Type).ToArray(); + Type[] argTypes = [.. arguments.Select(arg => arg.Type)]; Type[] methodInfoArgumentTypes = [typeof(Tensor)]; - var methods = typeof(CreateTensor).GetMethods(BindingFlags.Public | BindingFlags.Instance) - .Where(m => m.Name == "Process") - .ToArray(); + MethodInfo[] methods = [.. typeof(CreateTensor).GetMethods(BindingFlags.Public | BindingFlags.Instance).Where(m => m.Name == "Process")]; var methodInfo = arguments.Count() > 0 ? methods.FirstOrDefault(m => m.IsGenericMethod) .MakeGenericMethod( From 189aff12a4641821800658351eb103ed37cdb126 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 11:51:24 +0000 Subject: [PATCH 67/75] Add process overload to initialize device on input --- src/Bonsai.ML.Torch/InitializeTorchDevice.cs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs index b8ce574c..a598b794 100644 --- a/src/Bonsai.ML.Torch/InitializeTorchDevice.cs +++ b/src/Bonsai.ML.Torch/InitializeTorchDevice.cs @@ -32,5 +32,17 @@ public IObservable Process() return Observable.Return(new Device(DeviceType)); }); } + + /// + /// Initializes the Torch device when the input sequence produces an element. + /// + /// + public IObservable Process(IObservable source) + { + return source.Select((_) => { + InitializeDeviceType(DeviceType); + return new Device(DeviceType); + }); + } } } \ No newline at end of file From 9bb600fa8ff31a222d2fd91119ceba89c7373bc2 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 11:52:22 +0000 Subject: [PATCH 68/75] Added documentation --- src/Bonsai.ML.Torch/NeuralNets/Backward.cs | 21 ++++++++- src/Bonsai.ML.Torch/NeuralNets/Forward.cs | 17 ++++++-- .../NeuralNets/ITorchModule.cs | 14 +++++- .../NeuralNets/LoadModuleFromArchitecture.cs | 26 ++++++++++- .../NeuralNets/LoadScriptModule.cs | 17 +++++++- src/Bonsai.ML.Torch/NeuralNets/Loss.cs | 12 +++--- .../NeuralNets/Models/AlexNet.cs | 28 ++++++------ .../NeuralNets/Models/MNIST.cs | 43 +++++++++++-------- .../NeuralNets/Models/MobileNet.cs | 38 +++++++++------- .../NeuralNets/Models/ModelArchitecture.cs | 14 ++++++ src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs | 12 +++--- src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs | 22 ++++++++-- .../NeuralNets/TorchModuleAdapter.cs | 25 ++++++++--- src/Bonsai.ML.Torch/Vision/Normalize.cs | 20 ++++++++- 14 files changed, 232 insertions(+), 77 deletions(-) diff --git a/src/Bonsai.ML.Torch/NeuralNets/Backward.cs b/src/Bonsai.ML.Torch/NeuralNets/Backward.cs index 5fa4902e..328c35ba 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Backward.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Backward.cs @@ -8,18 +8,35 @@ namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Trains a model using backpropagation. + /// [Combinator] - [Description("")] + [Description("Trains a model using backpropagation.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Backward { + /// + /// The optimizer to use for training. + /// public Optimizer Optimizer { get; set; } + /// + /// The model to train. + /// [XmlIgnore] public ITorchModule Model { get; set; } + /// + /// The loss function to use for training. + /// public Loss Loss { get; set; } + /// + /// Trains the model using backpropagation. + /// + /// + /// public IObservable Process(IObservable> source) { optim.Optimizer optimizer = null; @@ -47,7 +64,7 @@ public IObservable Process(IObservable> source) { optimizer.zero_grad(); - var prediction = Model.forward(data); + var prediction = Model.Forward(data); var output = loss.forward(prediction, target); output.backward(); diff --git a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs index 3aae4012..175ed3c0 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs @@ -3,23 +3,32 @@ using System.Reactive.Linq; using static TorchSharp.torch; using System.Xml.Serialization; -using TorchSharp.Modules; -using TorchSharp; namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Runs forward inference on the input tensor using the specified model. + /// [Combinator] - [Description("")] + [Description("Runs forward inference on the input tensor using the specified model.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Forward { + /// + /// The model to use for inference. + /// [XmlIgnore] public ITorchModule Model { get; set; } + /// + /// Runs forward inference on the input tensor using the specified model. + /// + /// + /// public IObservable Process(IObservable source) { Model.Module.eval(); - return source.Select(Model.forward); + return source.Select(Model.Forward); } } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs index e7ebf994..5cde6f73 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs @@ -2,9 +2,21 @@ namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Represents an interface for a Torch module. + /// public interface ITorchModule { + /// + /// The module. + /// public nn.Module Module { get; } - public Tensor forward(Tensor tensor); + + /// + /// Runs forward inference on the input tensor using the specified model. + /// + /// + /// + public Tensor Forward(Tensor tensor); } } diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs index 43d74544..8276156f 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs @@ -6,20 +6,39 @@ namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Loads a neural network module from a specified architecture. + /// [Combinator] - [Description("")] + [Description("Loads a neural network module from a specified architecture.")] [WorkflowElementCategory(ElementCategory.Source)] public class LoadModuleFromArchitecture { + /// + /// The model architecture to load. + /// + [Description("The model architecture to load.")] public Models.ModelArchitecture ModelArchitecture { get; set; } + /// + /// The device on which to load the model. + /// + [Description("The device on which to load the model.")] [XmlIgnore] public Device Device { get; set; } + /// + /// The optional path to the model weights file. + /// + [Description("The optional path to the model weights file.")] [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] public string ModelWeightsPath { get; set; } private int numClasses = 10; + /// + /// The number of classes in the dataset. + /// + [Description("The number of classes in the dataset.")] public int NumClasses { get => numClasses; @@ -36,6 +55,11 @@ public int NumClasses } } + /// + /// Loads the neural network module from the specified architecture. + /// + /// + /// public IObservable Process() { var modelArchitecture = ModelArchitecture.ToString().ToLower(); diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs index 7e6c73fe..fb3b2b78 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadScriptModule.cs @@ -6,18 +6,33 @@ namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Loads a TorchScript module from the specified file path. + /// [Combinator] - [Description("")] + [Description("Loads a TorchScript module from the specified file path.")] [WorkflowElementCategory(ElementCategory.Source)] public class LoadScriptModule { + /// + /// The device on which to load the model. + /// + [Description("The device on which to load the model.")] [XmlIgnore] public Device Device { get; set; } = CPU; + /// + /// The path to the TorchScript model file. + /// + [Description("The path to the TorchScript model file.")] [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] public string ModelPath { get; set; } + /// + /// Loads the TorchScript module from the specified file path. + /// + /// public IObservable Process() { var scriptModule = jit.load(ModelPath, Device); diff --git a/src/Bonsai.ML.Torch/NeuralNets/Loss.cs b/src/Bonsai.ML.Torch/NeuralNets/Loss.cs index ff003019..376139c1 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Loss.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Loss.cs @@ -1,13 +1,13 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; -using static TorchSharp.torch.optim; - namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Represents a loss function that computes the loss value for a given input and target tensor. + /// public enum Loss { + /// + /// Computes the negative log likelihood loss. + /// NLLLoss, } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs index 4ca9f79c..2ded685d 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs @@ -1,13 +1,6 @@ -using System; -using System.IO; -using System.Linq; -using System.Collections.Generic; -using System.Diagnostics; - using TorchSharp; using static TorchSharp.torch; using static TorchSharp.torch.nn; -using static TorchSharp.torch.nn.functional; namespace Bonsai.ML.Torch.NeuralNets.Models { @@ -20,24 +13,30 @@ public class AlexNet : Module private readonly Module avgPool; private readonly Module classifier; + /// + /// Constructs a new AlexNet model. + /// + /// + /// + /// public AlexNet(string name, int numClasses, Device device = null) : base(name) { features = Sequential( ("c1", Conv2d(3, 64, kernelSize: 3, stride: 2, padding: 1)), ("r1", ReLU(inplace: true)), - ("mp1", MaxPool2d(kernelSize: new long[] { 2, 2 })), + ("mp1", MaxPool2d(kernelSize: [ 2, 2 ])), ("c2", Conv2d(64, 192, kernelSize: 3, padding: 1)), ("r2", ReLU(inplace: true)), - ("mp2", MaxPool2d(kernelSize: new long[] { 2, 2 })), + ("mp2", MaxPool2d(kernelSize: [ 2, 2 ])), ("c3", Conv2d(192, 384, kernelSize: 3, padding: 1)), ("r3", ReLU(inplace: true)), ("c4", Conv2d(384, 256, kernelSize: 3, padding: 1)), ("r4", ReLU(inplace: true)), ("c5", Conv2d(256, 256, kernelSize: 3, padding: 1)), ("r5", ReLU(inplace: true)), - ("mp3", MaxPool2d(kernelSize: new long[] { 2, 2 }))); + ("mp3", MaxPool2d(kernelSize: [ 2, 2 ]))); - avgPool = AdaptiveAvgPool2d(new long[] { 2, 2 }); + avgPool = AdaptiveAvgPool2d([ 2, 2 ]); classifier = Sequential( ("d1", Dropout()), @@ -56,12 +55,17 @@ public AlexNet(string name, int numClasses, Device device = null) : base(name) this.to(device); } + /// + /// Forward pass of the AlexNet model. + /// + /// + /// public override Tensor forward(Tensor input) { var f = features.forward(input); var avg = avgPool.forward(f); - var x = avg.view(new long[] { avg.shape[0], 256 * 2 * 2 }); + var x = avg.view([ avg.shape[0], 256 * 2 * 2 ]); return classifier.forward(x); } diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs index 994aca73..32d4bf8a 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs @@ -1,34 +1,36 @@ -using System; -using System.IO; -using System.Collections.Generic; -using System.Diagnostics; using TorchSharp; using static TorchSharp.torch; - using static TorchSharp.torch.nn; -using static TorchSharp.torch.nn.functional; namespace Bonsai.ML.Torch.NeuralNets.Models { + /// + /// Represents a simple convolutional neural network for the MNIST dataset. + /// public class MNIST : Module { - private Module conv1 = Conv2d(1, 32, 3); - private Module conv2 = Conv2d(32, 64, 3); - private Module fc1 = Linear(9216, 128); - private Module fc2 = Linear(128, 128); + private readonly Module conv1 = Conv2d(1, 32, 3); + private readonly Module conv2 = Conv2d(32, 64, 3); + private readonly Module fc1 = Linear(9216, 128); + private readonly Module fc2 = Linear(128, 128); - private Module pool1 = MaxPool2d(kernelSize: [2, 2]); + private readonly Module pool1 = MaxPool2d(kernelSize: [2, 2]); - private Module relu1 = ReLU(); - private Module relu2 = ReLU(); - private Module relu3 = ReLU(); + private readonly Module relu1 = ReLU(); + private readonly Module relu2 = ReLU(); + private readonly Module relu3 = ReLU(); - private Module dropout1 = Dropout(0.25); - private Module dropout2 = Dropout(0.5); + private readonly Module dropout1 = Dropout(0.25); + private readonly Module dropout2 = Dropout(0.5); - private Module flatten = Flatten(); - private Module logsm = LogSoftmax(1); + private readonly Module flatten = Flatten(); + private readonly Module logsm = LogSoftmax(1); + /// + /// Constructs a new MNIST model. + /// + /// + /// public MNIST(string name, Device device = null) : base(name) { RegisterComponents(); @@ -37,6 +39,11 @@ public MNIST(string name, Device device = null) : base(name) this.to(device); } + /// + /// Forward pass of the MNIST model. + /// + /// + /// public override Tensor forward(Tensor input) { var l11 = conv1.forward(input); diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs index f82a33f9..6ede9818 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs @@ -7,33 +7,34 @@ namespace Bonsai.ML.Torch.NeuralNets.Models { /// - /// Modified version of MobileNet to classify CIFAR10 32x32 images. + /// MobileNet model. /// - /// - /// With an unaugmented CIFAR-10 data set, the author of this saw training converge - /// at roughly 75% accuracy on the test set, over the course of 1500 epochs. - /// public class MobileNet : Module { - // The code here is is loosely based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenet.py - // Licence and copypright notice at: https://github.com/kuangliu/pytorch-cifar/blob/master/LICENSE - - private readonly long[] planes = new long[] { 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024 }; - private readonly long[] strides = new long[] { 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1 }; + private readonly long[] planes = [ 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024 ]; + private readonly long[] strides = [ 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1 ]; private readonly Module layers; + /// + /// Constructs a new MobileNet model. + /// + /// + /// + /// + /// public MobileNet(string name, int numClasses, Device device = null) : base(name) { if (planes.Length != strides.Length) throw new ArgumentException("'planes' and 'strides' must have the same length."); - var modules = new List<(string, Module)>(); - - modules.Add(($"conv2d-first", Conv2d(3, 32, kernelSize: 3, stride: 1, padding: 1, bias: false))); - modules.Add(($"bnrm2d-first", BatchNorm2d(32))); - modules.Add(($"relu-first", ReLU())); + var modules = new List<(string, Module)> + { + ($"conv2d-first", Conv2d(3, 32, kernelSize: 3, stride: 1, padding: 1, bias: false)), + ($"bnrm2d-first", BatchNorm2d(32)), + ($"relu-first", ReLU()) + }; MakeLayers(modules, 32); - modules.Add(("avgpool", AvgPool2d(new long[] { 2, 2 }))); + modules.Add(("avgpool", AvgPool2d([2, 2]))); modules.Add(("flatten", Flatten())); modules.Add(($"linear", Linear(planes[planes.Length-1], numClasses))); @@ -63,6 +64,11 @@ private void MakeLayers(List<(string, Module)> modules, long in_ } } + /// + /// Forward pass of the MobileNet model. + /// + /// + /// public override Tensor forward(Tensor input) { return layers.forward(input); diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs index 0a221b5c..98a30216 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/ModelArchitecture.cs @@ -1,9 +1,23 @@ namespace Bonsai.ML.Torch.NeuralNets.Models { + /// + /// Represents the architecture of a neural network model. + /// public enum ModelArchitecture { + /// + /// The AlexNet model architecture. + /// AlexNet, + + /// + /// The MobileNet model architecture. + /// MobileNet, + + /// + /// The MNIST model architecture. + /// MNIST } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs b/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs index 3c0d4bb7..4ab09dbd 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Optimizer.cs @@ -1,13 +1,13 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using static TorchSharp.torch; -using static TorchSharp.torch.optim; - namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Represents an optimizer that updates the parameters of a neural network. + /// public enum Optimizer { + /// + /// Adam optimizer. + /// Adam, } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs index 314e9f3f..3d5ffd97 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs @@ -1,23 +1,37 @@ using System; using System.ComponentModel; using System.Reactive.Linq; -using static TorchSharp.torch; using System.Xml.Serialization; -using TorchSharp.Modules; -using TorchSharp; namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Saves the model to a file. + /// [Combinator] - [Description("")] + [Description("Saves the model to a file.")] [WorkflowElementCategory(ElementCategory.Sink)] public class SaveModel { + /// + /// The model to save. + /// + [Description("The model to save.")] [XmlIgnore] public ITorchModule Model { get; set; } + /// + /// The path to save the model. + /// + [Description("The path to save the model.")] public string ModelPath { get; set; } + /// + /// Saves the model to the specified file path. + /// + /// + /// + /// public IObservable Process(IObservable source) { return source.Do(input => { diff --git a/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs b/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs index a1c44d96..3ec35071 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/TorchModuleAdapter.cs @@ -4,33 +4,48 @@ namespace Bonsai.ML.Torch.NeuralNets { + /// + /// Represents a torch module adapter that wraps a torch module or script module. + /// public class TorchModuleAdapter : ITorchModule { private readonly nn.Module _module = null; private readonly jit.ScriptModule _scriptModule = null; - private Func forwardFunc; + private readonly Func _forwardFunc; + /// + /// The module. + /// public nn.Module Module { get; } + /// + /// Initializes a new instance of the class. + /// + /// public TorchModuleAdapter(nn.Module module) { _module = module; - forwardFunc = _module.forward; + _forwardFunc = _module.forward; Module = _module; } + /// + /// Initializes a new instance of the class. + /// + /// public TorchModuleAdapter(jit.ScriptModule scriptModule) { _scriptModule = scriptModule; - forwardFunc = _scriptModule.forward; + _forwardFunc = _scriptModule.forward; Module = _scriptModule; } - public Tensor forward(Tensor input) + /// + public Tensor Forward(Tensor input) { - return forwardFunc(input); + return _forwardFunc(input); } } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Vision/Normalize.cs b/src/Bonsai.ML.Torch/Vision/Normalize.cs index f4fc80a3..60b87c44 100644 --- a/src/Bonsai.ML.Torch/Vision/Normalize.cs +++ b/src/Bonsai.ML.Torch/Vision/Normalize.cs @@ -8,21 +8,39 @@ namespace Bonsai.ML.Torch.Vision { + /// + /// Normalizes the input tensor with the mean and standard deviation. + /// [Combinator] [Description("Normalizes the input tensor with the mean and standard deviation.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Normalize { + /// + /// The mean values for each channel. + /// + [Description("The mean values for each channel.")] public double[] Means { get; set; } = [ 0.1307 ]; + + /// + /// The standard deviation values for each channel. + /// + [Description("The standard deviation values for each channel.")] public double[] StdDevs { get; set; } = [ 0.3081 ]; + private ITransform transform = null; + /// + /// Normalizes the input tensor with the mean and standard deviation. + /// + /// + /// public IObservable Process(IObservable source) { return source.Select(tensor => { transform ??= transforms.Normalize(Means, StdDevs, tensor.dtype, tensor.device); return transform.call(tensor); - }); + }).Finally(() => transform = null); } } } \ No newline at end of file From 4d3297d45ac1f35355c67c594c786b8bc9ed6f32 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 12:57:40 +0000 Subject: [PATCH 69/75] Added file name editor attribute to model path --- src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs index 3d5ffd97..c426aedf 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/SaveModel.cs @@ -24,6 +24,7 @@ public class SaveModel /// The path to save the model. /// [Description("The path to save the model.")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] public string ModelPath { get; set; } /// From 6969a7caa4c887461def43a09b6794a1352d67d2 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 13:52:45 +0000 Subject: [PATCH 70/75] Updated to latest torchsharp version --- src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj | 4 ++-- src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs | 16 ++++++++-------- src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs | 2 +- .../NeuralNets/Models/MobileNet.cs | 6 +++--- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj index 97bfe18c..3a2f0298 100644 --- a/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj +++ b/src/Bonsai.ML.Torch/Bonsai.ML.Torch.csproj @@ -8,8 +8,8 @@ - - + + diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs index 2ded685d..c80a3d50 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs @@ -22,19 +22,19 @@ public class AlexNet : Module public AlexNet(string name, int numClasses, Device device = null) : base(name) { features = Sequential( - ("c1", Conv2d(3, 64, kernelSize: 3, stride: 2, padding: 1)), + ("c1", Conv2d(3, 64, kernel_size: 3, stride: 2, padding: 1)), ("r1", ReLU(inplace: true)), - ("mp1", MaxPool2d(kernelSize: [ 2, 2 ])), - ("c2", Conv2d(64, 192, kernelSize: 3, padding: 1)), + ("mp1", MaxPool2d(kernel_size: [ 2, 2 ])), + ("c2", Conv2d(64, 192, kernel_size: 3, padding: 1)), ("r2", ReLU(inplace: true)), - ("mp2", MaxPool2d(kernelSize: [ 2, 2 ])), - ("c3", Conv2d(192, 384, kernelSize: 3, padding: 1)), + ("mp2", MaxPool2d(kernel_size: [ 2, 2 ])), + ("c3", Conv2d(192, 384, kernel_size: 3, padding: 1)), ("r3", ReLU(inplace: true)), - ("c4", Conv2d(384, 256, kernelSize: 3, padding: 1)), + ("c4", Conv2d(384, 256, kernel_size: 3, padding: 1)), ("r4", ReLU(inplace: true)), - ("c5", Conv2d(256, 256, kernelSize: 3, padding: 1)), + ("c5", Conv2d(256, 256, kernel_size: 3, padding: 1)), ("r5", ReLU(inplace: true)), - ("mp3", MaxPool2d(kernelSize: [ 2, 2 ]))); + ("mp3", MaxPool2d(kernel_size: [ 2, 2 ]))); avgPool = AdaptiveAvgPool2d([ 2, 2 ]); diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs index 32d4bf8a..e5895d41 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs @@ -14,7 +14,7 @@ public class MNIST : Module private readonly Module fc1 = Linear(9216, 128); private readonly Module fc2 = Linear(128, 128); - private readonly Module pool1 = MaxPool2d(kernelSize: [2, 2]); + private readonly Module pool1 = MaxPool2d(kernel_size: [2, 2]); private readonly Module relu1 = ReLU(); private readonly Module relu2 = ReLU(); diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs index 6ede9818..0faa1062 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs @@ -29,7 +29,7 @@ public MobileNet(string name, int numClasses, Device device = null) : base(name) var modules = new List<(string, Module)> { - ($"conv2d-first", Conv2d(3, 32, kernelSize: 3, stride: 1, padding: 1, bias: false)), + ($"conv2d-first", Conv2d(3, 32, kernel_size: 3, stride: 1, padding: 1, bias: false)), ($"bnrm2d-first", BatchNorm2d(32)), ($"relu-first", ReLU()) }; @@ -53,10 +53,10 @@ private void MakeLayers(List<(string, Module)> modules, long in_ var out_planes = planes[i]; var stride = strides[i]; - modules.Add(($"conv2d-{i}a", Conv2d(in_planes, in_planes, kernelSize: 3, stride: stride, padding: 1, groups: in_planes, bias: false))); + modules.Add(($"conv2d-{i}a", Conv2d(in_planes, in_planes, kernel_size: 3, stride: stride, padding: 1, groups: in_planes, bias: false))); modules.Add(($"bnrm2d-{i}a", BatchNorm2d(in_planes))); modules.Add(($"relu-{i}a", ReLU())); - modules.Add(($"conv2d-{i}b", Conv2d(in_planes, out_planes, kernelSize: 1L, stride: 1L, padding: 0L, bias: false))); + modules.Add(($"conv2d-{i}b", Conv2d(in_planes, out_planes, kernel_size: 1L, stride: 1L, padding: 0L, bias: false))); modules.Add(($"bnrm2d-{i}b", BatchNorm2d(out_planes))); modules.Add(($"relu-{i}b", ReLU())); From e35980f896b03e0bea8185adaca8219946b2a829 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Jan 2025 14:09:12 +0000 Subject: [PATCH 71/75] Update class name to reflect file name --- src/Bonsai.ML.Torch/Index/SingleIndex.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.Torch/Index/SingleIndex.cs b/src/Bonsai.ML.Torch/Index/SingleIndex.cs index 9b3ec641..e2f5decd 100644 --- a/src/Bonsai.ML.Torch/Index/SingleIndex.cs +++ b/src/Bonsai.ML.Torch/Index/SingleIndex.cs @@ -12,7 +12,7 @@ namespace Bonsai.ML.Torch.Index; [Combinator] [Description("Represents an index that selects a single value of a tensor.")] [WorkflowElementCategory(ElementCategory.Source)] -public class CreateTensorIndexSingle +public class SingleIndex { /// /// Gets or sets the index value used to select a single element from a tensor. From 4465a3cf73c493448ace81b016040735801bebb9 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 23 Jan 2025 16:49:39 +0000 Subject: [PATCH 72/75] Added documentation to package --- README.md | 3 +++ docs/articles/Torch/torch-getting-started.md | 13 ++++++++++ docs/articles/Torch/torch-overview.md | 27 ++++++++++++++++++++ docs/articles/toc.yml | 7 ++++- 4 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 docs/articles/Torch/torch-getting-started.md create mode 100644 docs/articles/Torch/torch-overview.md diff --git a/README.md b/README.md index cd875504..18ff09bf 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,9 @@ Facilitates inference using Hidden Markov Models (HMMs). It interfaces with the ### Bonsai.ML.HiddenMarkovModels.Design Visualizers and editor features for the HiddenMarkovModels package. +### Bonsai.ML.Torch +Interfaces with the [TorchSharp](https://github.com/dotnet/TorchSharp) package, a C# wrapper around the torch library. Provides tooling for manipulating tensors, performing linear algebra, training and inference with deep neural networks, and more. + > [!NOTE] > Bonsai.ML packages can be installed through Bonsai's integrated package manager and are generally ready for immediate use. However, some packages may require additional installation steps. Refer to the specific package section for detailed installation guides and documentation. diff --git a/docs/articles/Torch/torch-getting-started.md b/docs/articles/Torch/torch-getting-started.md new file mode 100644 index 00000000..6a673b9b --- /dev/null +++ b/docs/articles/Torch/torch-getting-started.md @@ -0,0 +1,13 @@ +# Getting Started + +The aim of the `Bonsai.ML.Torch` package is to integrate the [TorchSharp](https://github.com/dotnet/TorchSharp) package, a C# wrapper around the powerful libtorch library, into Bonsai. In the current version, the package primarily provides tooling and functionality for users to interact with and manipulate `Tensor`s, the core data type of libtorch which underlies many of the advanced torch operations. Additionally, the package provides some capabilities for defining neural network architectures, running forward inference, and learning via back propogation. + +## Tensor Operations +The package provides several ways to work with tensors. Users can initialize tensors, (`Ones`, `Zeros`, etc.), create tensors from .NET data types, (`ToTensor`), and define custom tensors using Python-like syntax (`CreateTensor`). Tensors can be converted back to .NET types using the `ToArray` node (for flattening tensors into a single array) or the `ToNDArray` node (for preserving multidimensional array shapes). Furthermore, the `Tensor` data types contains many extension methods which can be used via scripting, such as using `ExpressionTransform` (for example, it.sum() to sum a tensor, or it.T to transpose), and works with overloaded operators, for example, `Zip` -> `Multiply`. Thus, `ExpressionTransform` can also be used to access individual elements of a tensor, using the syntax `it.data.ReadCpuT(0)` where `T` is a primitive .NET data type. + + +## Running on the GPU +Users must be explicit about running tensors on the GPU. First, the `InitializeDeviceType` node must run with a CUDA-compatible GPU. Afterwards, tensors are moved to the GPU using the `ToDevice` node. Converting tensors back to .NET data types requires moving the tensor back to the CPU before converting. + +## Neural Networks +The package provides initial support for working with torch `Module`s, the conventional object for deep neural networks. The `LoadModuleFromArchitecture` node allows users to select from a list of common architectures, and can optionally load in pretrained weights from disk. Additionally, the package supports loading `TorchScript` modules with the `LoadScriptModule` node, which enables users to use torch modules saved in the `.pt` file format. Users can then use the `Forward` node to run inference and the `Backward` node to run back propogation. \ No newline at end of file diff --git a/docs/articles/Torch/torch-overview.md b/docs/articles/Torch/torch-overview.md new file mode 100644 index 00000000..4ca6cd08 --- /dev/null +++ b/docs/articles/Torch/torch-overview.md @@ -0,0 +1,27 @@ +# Bonsai.ML.Torch Overview + +The Torch package provides a Bonsai interface to interact with the [TorchSharp](https://github.com/dotnet/TorchSharp) package, a C# implementation of the torch library. + +## General Guide + +The Bonsai.ML.Torch package can be installed through the Bonsai Package Manager and depends on the TorchSharp library. Additionally, running the package requires installing the specific torch DLLs needed for your desired application. The steps for installing are outlined below. + +### Running on the CPU +For running the package using the CPU, the `TorchSharp-cpu` library can be installed though the `nuget` package source. + +### Running on the GPU +To run torch on the GPU, you first need to ensure that you have a CUDA compatible device installed on your system. + +Next, you must follow the [CUDA installation guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) or the [guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html). Make sure to install the correct `CUDA v12.1` version [found here](https://developer.nvidia.com/cuda-12-1-0-download-archive). Ensure that you have the correct CUDA version (v12.1) installed, as `TorchSharp` currently only supports this version. + +Next, you need to install the `cuDNN v9` library following the [guide for Windows](https://docs.nvidia.com/deeplearning/cudnn/latest/installation/windows.html) or the [guide for Linux](https://docs.nvidia.com/deeplearning/cudnn/latest/installation/linux.html). Again, you need to ensure you have the correct version installed (v9). You should consult [nvidia's support matrix](https://docs.nvidia.com/deeplearning/cudnn/latest/reference/support-matrix.html) to ensure the versions of CUDA and cuDNN you installed are compatible with your specific OS, graphics driver, and hardware. + +Once complete, you need to install the cuda-compatible torch libraries and place them into the correct location. You can download the libraries from [the pytorch website](https://pytorch.org/get-started/locally/) with the following options selected: + +- PyTorch Build: Stable (2.5.1) +- OS: [Your OS] +- Package: LibTorch +- Language: C++/Java +- Compute Platform: CUDA 12.1 + +Finally, extract the zip folder and copy all of the DLLs into the `Extensions` folder of your bonsai installation directory. \ No newline at end of file diff --git a/docs/articles/toc.yml b/docs/articles/toc.yml index e22b0b80..625cfcc3 100644 --- a/docs/articles/toc.yml +++ b/docs/articles/toc.yml @@ -13,4 +13,9 @@ - name: Overview href: HiddenMarkovModels/hmm-overview.md - name: Getting Started - href: HiddenMarkovModels/hmm-getting-started.md \ No newline at end of file + href: HiddenMarkovModels/hmm-getting-started.md +- name: Torch +- name: Overview + href: Torch/torch-overview.md +- name: Getting Started + href: Torch/torch-getting-started.md \ No newline at end of file From da81e8bcfb86341c83837ad17e6048ead9601dd4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 23 Jan 2025 18:38:50 +0000 Subject: [PATCH 73/75] Fixed issue with MNIST model not accepting num classes --- .../NeuralNets/LoadModuleFromArchitecture.cs | 2 +- .../NeuralNets/Models/MNIST.cs | 44 +++++++++++++------ 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs index 8276156f..ac791d8d 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/LoadModuleFromArchitecture.cs @@ -69,7 +69,7 @@ public IObservable Process() { "alexnet" => new Models.AlexNet(modelArchitecture, numClasses, device), "mobilenet" => new Models.MobileNet(modelArchitecture, numClasses, device), - "mnist" => new Models.MNIST(modelArchitecture, device), + "mnist" => new Models.MNIST(modelArchitecture, numClasses, device), _ => throw new ArgumentException($"Model {modelArchitecture} not supported.") }; diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs index e5895d41..8a5f84db 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs @@ -9,30 +9,48 @@ namespace Bonsai.ML.Torch.NeuralNets.Models /// public class MNIST : Module { - private readonly Module conv1 = Conv2d(1, 32, 3); - private readonly Module conv2 = Conv2d(32, 64, 3); - private readonly Module fc1 = Linear(9216, 128); - private readonly Module fc2 = Linear(128, 128); + private readonly Module conv1; + private readonly Module conv2; + private readonly Module fc1; + private readonly Module fc2; - private readonly Module pool1 = MaxPool2d(kernel_size: [2, 2]); + private readonly Module pool1; - private readonly Module relu1 = ReLU(); - private readonly Module relu2 = ReLU(); - private readonly Module relu3 = ReLU(); + private readonly Module relu1; + private readonly Module relu2; + private readonly Module relu3; - private readonly Module dropout1 = Dropout(0.25); - private readonly Module dropout2 = Dropout(0.5); + private readonly Module dropout1; + private readonly Module dropout2; - private readonly Module flatten = Flatten(); - private readonly Module logsm = LogSoftmax(1); + private readonly Module flatten; + private readonly Module logsm; /// /// Constructs a new MNIST model. /// /// + /// /// - public MNIST(string name, Device device = null) : base(name) + public MNIST(string name, int numClasses, Device device = null) : base(name) { + conv1 = Conv2d(1, 32, 3); + conv2 = Conv2d(32, 64, 3); + fc1 = Linear(9216, 128); + fc2 = Linear(128, numClasses); + + pool1 = MaxPool2d(kernel_size: [2, 2]); + + relu1 = ReLU(); + relu2 = ReLU(); + relu3 = ReLU(); + + dropout1 = Dropout(0.25); + dropout2 = Dropout(0.5); + + flatten = Flatten(); + logsm = LogSoftmax(1); + RegisterComponents(); if (device != null && device.type != DeviceType.CPU) From d06a4c28765d0742f6da016158c590dabb9f6cef Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 24 Jan 2025 09:02:13 +0000 Subject: [PATCH 74/75] Made slight correction to GPU documentation --- docs/articles/Torch/torch-overview.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/articles/Torch/torch-overview.md b/docs/articles/Torch/torch-overview.md index 4ca6cd08..0884ea01 100644 --- a/docs/articles/Torch/torch-overview.md +++ b/docs/articles/Torch/torch-overview.md @@ -24,4 +24,4 @@ Once complete, you need to install the cuda-compatible torch libraries and place - Language: C++/Java - Compute Platform: CUDA 12.1 -Finally, extract the zip folder and copy all of the DLLs into the `Extensions` folder of your bonsai installation directory. \ No newline at end of file +Finally, extract the zip folder and copy the contents of the `lib` folder into the `Extensions` folder of your bonsai installation directory. \ No newline at end of file From 139488e0afd28f1460f8295219175949cdd32603 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 29 Jan 2025 19:07:00 +0000 Subject: [PATCH 75/75] Modified torch module classes to be internel --- src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs | 2 +- src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs | 2 +- src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs index c80a3d50..c3d19d55 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/AlexNet.cs @@ -7,7 +7,7 @@ namespace Bonsai.ML.Torch.NeuralNets.Models /// /// Modified version of original AlexNet to fix CIFAR10 32x32 images. /// - public class AlexNet : Module + internal class AlexNet : Module { private readonly Module features; private readonly Module avgPool; diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs index 8a5f84db..8bd3e0a4 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MNIST.cs @@ -7,7 +7,7 @@ namespace Bonsai.ML.Torch.NeuralNets.Models /// /// Represents a simple convolutional neural network for the MNIST dataset. /// - public class MNIST : Module + internal class MNIST : Module { private readonly Module conv1; private readonly Module conv2; diff --git a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs index 0faa1062..a5f7701a 100644 --- a/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs +++ b/src/Bonsai.ML.Torch/NeuralNets/Models/MobileNet.cs @@ -9,7 +9,7 @@ namespace Bonsai.ML.Torch.NeuralNets.Models /// /// MobileNet model. /// - public class MobileNet : Module + internal class MobileNet : Module { private readonly long[] planes = [ 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024 ]; private readonly long[] strides = [ 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1 ];