Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds the new System.Numerics.Tensors as an input/output type when using dotnet 8.0 and up. #23261

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="MSBuild.Sdk.Extras/3.0.22">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<!--- packaging properties -->
<OrtPackageId Condition="'$(OrtPackageId)' == ''">Microsoft.ML.OnnxRuntime</OrtPackageId>
Expand Down Expand Up @@ -184,6 +184,10 @@
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="8.0.0" PrivateAssets="All" />
</ItemGroup>

<ItemGroup Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
<PackageReference Include="System.Numerics.Tensors" Version="9.0.0" />
</ItemGroup>

<!-- debug output - makes finding/fixing any issues with the the conditions easy. -->
<Target Name="DumpValues" BeforeTargets="PreBuildEvent">
<Message Text="SolutionName='$(SolutionName)'" />
Expand Down
156 changes: 156 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,17 @@
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;

#if NET8_0_OR_GREATER
using DotnetTensors = System.Numerics.Tensors;
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
#endif

namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
Expand Down Expand Up @@ -205,6 +213,34 @@ public ReadOnlySpan<T> GetTensorDataAsSpan<T>() where T : unmanaged
return MemoryMarshal.Cast<byte, T>(byteSpan);
}

#if NET8_0_OR_GREATER
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
/// <summary>
/// Returns a ReadOnlyTensorSpan<typeparamref name="T"/> over tensor native buffer that
/// provides a read-only view.
///
/// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
/// To get memory descriptor use GetTensorMemoryInfo().
///
/// OrtValue must contain a non-string tensor.
/// The span is valid as long as the OrtValue instance is alive (not disposed).
/// </summary>
/// <typeparam name="T">Unmanaged element type; must correspond to the tensor's element type.</typeparam>
/// <returns>ReadOnlyTensorSpan<typeparamref name="T"/> shaped like this tensor</returns>
/// <exception cref="OnnxRuntimeException"></exception>
public DotnetTensors.ReadOnlyTensorSpan<T> GetTensorDataAsTensorSpan<T>() where T : unmanaged
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
{
var byteSpan = GetTensorBufferRawData(typeof(T));

// Reinterpret the raw byte buffer as T elements, then shape the span with the
// tensor's dimensions (native shape is long[]; TensorSpan requires nint[]).
var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));

// Empty strides array ([]) lets the span compute dense strides from the shape.
return new DotnetTensors.ReadOnlyTensorSpan<T>(typeSpan, nArray, []);
}
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
#endif

/// <summary>
/// Returns a Span<typeparamref name="T"/> over tensor native buffer.
/// This enables you to safely and efficiently modify the underlying
Expand All @@ -225,6 +261,33 @@ public Span<T> GetTensorMutableDataAsSpan<T>() where T : unmanaged
return MemoryMarshal.Cast<byte, T>(byteSpan);
}

#if NET8_0_OR_GREATER
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
/// <summary>
/// Returns a TensorSpan<typeparamref name="T"/> over tensor native buffer.
///
/// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
/// To get memory descriptor use GetTensorMemoryInfo().
///
/// OrtValue must contain a non-string tensor.
/// The span is valid as long as the OrtValue instance is alive (not disposed).
/// </summary>
/// <typeparam name="T">Unmanaged element type; must correspond to the tensor's element type.</typeparam>
/// <returns>TensorSpan<typeparamref name="T"/> shaped like this tensor</returns>
/// <exception cref="OnnxRuntimeException"></exception>
public DotnetTensors.TensorSpan<T> GetTensorMutableDataAsTensorSpan<T>() where T : unmanaged
{
var byteSpan = GetTensorBufferRawData(typeof(T));

// Reinterpret the raw byte buffer as T elements, then shape the span with the
// tensor's dimensions (native shape is long[]; TensorSpan requires nint[]).
var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));

// Empty strides array ([]) lets the span compute dense strides from the shape.
return new DotnetTensors.TensorSpan<T>(typeSpan, nArray, []);
}
#pragma warning restore SYSLIB5001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
#endif

/// <summary>
/// Provides mutable raw native buffer access.
/// </summary>
Expand All @@ -234,6 +297,24 @@ public Span<byte> GetTensorMutableRawData()
return GetTensorBufferRawData(typeof(byte));
}

#if NET8_0_OR_GREATER
#pragma warning disable SYSLIB5001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
/// <summary>
/// Provides mutable raw native buffer access.
/// </summary>
/// <typeparam name="T">Element type used only to validate the buffer via GetTensorBufferRawData.</typeparam>
/// <returns>TensorSpan over the native buffer bytes</returns>
public DotnetTensors.TensorSpan<byte> GetTensorSpanMutableRawData<T>() where T : unmanaged
{
var byteSpan = GetTensorBufferRawData(typeof(T));

// Shape comes from the tensor's element dimensions, converted long[] -> nint[].
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));

// NOTE(review): byteSpan has (element count * sizeof(T)) bytes, but nArray describes
// element counts, not byte counts. For any T wider than one byte the span length and
// the shape product disagree — confirm whether this is only intended for T = byte,
// or whether the innermost dimension should be scaled by sizeof(T).
return new DotnetTensors.TensorSpan<byte>(byteSpan, nArray, []);
}
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
#endif

/// <summary>
/// Fetch string tensor element buffer pointer at the specified index,
/// convert/copy to UTF-16 char[] and return a ReadOnlyMemory{char} instance.
Expand Down Expand Up @@ -605,6 +686,81 @@ public static OrtValue CreateTensorValueFromMemory<T>(T[] data, long[] shape) wh
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, new Memory<T>(data), shape);
}

#if NET8_0_OR_GREATER
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
/// <summary>
/// This is a factory method that creates a native Onnxruntime OrtValue containing a tensor.
/// The method will attempt to pin managed memory so no copying occurs when data is passed down
/// to native code.
/// </summary>
/// <typeparam name="T">Unmanaged tensor element type; must be a type supported by TensorBase.GetTypeInfo.</typeparam>
/// <param name="tensor">Tensor object; if it is not contiguous and dense it is first copied
/// into a newly created dense tensor before its buffer is pinned.</param>
/// <returns>An instance of OrtValue constructed on top of the object</returns>
/// <exception cref="OnnxRuntimeException">Thrown when T is not a supported element type.</exception>
public static OrtValue CreateTensorValueFromDotnetTensorObject<T>(DotnetTensors.Tensor<T> tensor) where T : unmanaged
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
{
// Native code expects a dense, row-major contiguous buffer; copy when necessary.
if (!IsContiguousAndDense(tensor))
{
var newTensor = DotnetTensors.Tensor.Create<T>(tensor.Lengths);
tensor.CopyTo(newTensor);
tensor = newTensor;
}
unsafe
{
// NOTE(review): reaching into the private "_values" field via reflection is fragile —
// it depends on an internal implementation detail of Tensor<T> and FirstOrDefault()
// may return null on other library versions; confirm no public pinning API exists.
var field = tensor.GetType().GetFields(BindingFlags.Instance | BindingFlags.NonPublic).Where(x => x.Name == "_values").FirstOrDefault();
var backingData = (T[])field.GetValue(tensor);
// Pin the backing array so the GC cannot relocate it while native code holds the pointer.
GCHandle handle = GCHandle.Alloc(backingData, GCHandleType.Pinned);
var memHandle = new MemoryHandle(Unsafe.AsPointer(ref tensor.GetPinnableReference()), handle);

try
{
IntPtr dataBufferPointer = IntPtr.Zero;
unsafe
{
dataBufferPointer = (IntPtr)memHandle.Pointer;
}

var bufferLengthInBytes = tensor.FlattenedLength * sizeof(T);
// The native API takes the shape as long[]; Tensor<T>.Lengths is nint-based.
long[] shape = Array.ConvertAll(tensor.Lengths.ToArray(), new Converter<nint, long>(x => (long)x));

var typeInfo = TensorBase.GetTypeInfo(typeof(T)) ??
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Tensor of type: {typeof(T)} is not supported");

NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorWithDataAsOrtValue(
OrtMemoryInfo.DefaultInstance.Pointer,
dataBufferPointer,
(UIntPtr)(bufferLengthInBytes),
shape,
(UIntPtr)tensor.Rank,
typeInfo.ElementType,
out IntPtr nativeValue));

// The pin handle is handed to the OrtValue; presumably released on Dispose — verify ctor.
return new OrtValue(nativeValue, memHandle);
}
catch (Exception)
{
// Release the pin if native tensor creation fails; bare throw preserves the stack trace.
memHandle.Dispose();
michaelgsharp marked this conversation as resolved.
Show resolved Hide resolved
throw;
}
}
}

/// <summary>
/// Determines whether the tensor's layout is dense row-major: the innermost
/// stride is 1 and every outer stride equals the number of elements spanned
/// by all dimensions to its right.
/// </summary>
private static bool IsContiguousAndDense<T>(DotnetTensors.Tensor<T> tensor)
{
// A dense layout requires the innermost dimension to have stride 1.
if (tensor.Strides[^1] != 1)
return false;

// Walk outward, carrying the running product of the lengths to the right;
// each stride must match that product exactly.
nint expectedStride = tensor.Lengths[^1];
for (int dim = tensor.Rank - 2; dim >= 0; dim--)
{
if (tensor.Strides[dim] != expectedStride)
return false;
expectedStride *= tensor.Lengths[dim];
}

return true;
}
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
#endif

/// <summary>
/// The factory API creates an OrtValue with memory allocated using the given allocator
/// according to the specified shape and element type. The memory will be released when OrtValue
Expand Down
Loading
Loading