diff --git a/Kaitai.Struct.Runtime.Async.Tests/CancelableTestsBase.cs b/Kaitai.Struct.Runtime.Async.Tests/CancelableTestsBase.cs new file mode 100644 index 0000000..8904346 --- /dev/null +++ b/Kaitai.Struct.Runtime.Async.Tests/CancelableTestsBase.cs @@ -0,0 +1,54 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Kaitai.Struct.Runtime.Async.Tests +{ + public abstract class CancelableTestsBase + { + protected readonly CancellationToken CancellationToken; + + protected CancelableTestsBase(bool isTestingCancellation) + { + CancellationToken = new CancellationToken(isTestingCancellation); + } + + protected async Task Evaluate(Func assertFunc) + { + if (CancellationToken.IsCancellationRequested) + { + await Assert.ThrowsAsync(assertFunc); + } + else + { + await assertFunc(); + } + } + + protected async Task EvaluateMaybeCancelled(Func assertFunc) + { + try + { + await assertFunc(); + } + catch (TaskCanceledException) + { + } + } + + protected async Task Evaluate(Func assertFunc) where TExpectedException : Exception + { + try + { + await assertFunc(); + } + catch (TaskCanceledException) + { + } + catch (TExpectedException) + { + } + } + } +} \ No newline at end of file diff --git a/Kaitai.Struct.Runtime.Async.Tests/KaitaiAsyncStreamBaseTests.cs b/Kaitai.Struct.Runtime.Async.Tests/KaitaiAsyncStreamBaseTests.cs index 0acf91a..b1b5d77 100644 --- a/Kaitai.Struct.Runtime.Async.Tests/KaitaiAsyncStreamBaseTests.cs +++ b/Kaitai.Struct.Runtime.Async.Tests/KaitaiAsyncStreamBaseTests.cs @@ -1,4 +1,5 @@ -using System.IO; +using System; +using System.IO; using System.IO.Pipelines; using System.Threading.Tasks; using Kaitai.Async; @@ -8,17 +9,48 @@ namespace Kaitai.Struct.Runtime.Async.Tests { public class StreamKaitaiAsyncStreamBaseTests : KaitaiAsyncStreamBaseTests { + public StreamKaitaiAsyncStreamBaseTests() : base(false) + { + } + protected override KaitaiAsyncStream Create(byte[] data) => new KaitaiAsyncStream(data); } public class PipeReaderKaitaiAsyncStreamBaseTests : KaitaiAsyncStreamBaseTests { + public PipeReaderKaitaiAsyncStreamBaseTests() : base(false) + { + } + protected override KaitaiAsyncStream Create(byte[] data) => new KaitaiAsyncStream(PipeReader.Create(new MemoryStream(data))); } - public abstract class KaitaiAsyncStreamBaseTests + public class StreamKaitaiAsyncStreamBaseCancelledTests : KaitaiAsyncStreamBaseTests + { + public StreamKaitaiAsyncStreamBaseCancelledTests() : base(true) + { + } + + protected override KaitaiAsyncStream Create(byte[] data) => new KaitaiAsyncStream(data); + } + + public class PipeReaderKaitaiAsyncStreamBaseCancelledTests : KaitaiAsyncStreamBaseTests { + public PipeReaderKaitaiAsyncStreamBaseCancelledTests() : base(true) + { + } + + protected override KaitaiAsyncStream Create(byte[] data) => + new KaitaiAsyncStream(PipeReader.Create(new MemoryStream(data))); + } + + public abstract class KaitaiAsyncStreamBaseTests : CancelableTestsBase + { + protected KaitaiAsyncStreamBaseTests(bool isTestingCancellation) : base(isTestingCancellation) + { + } + protected abstract KaitaiAsyncStream Create(byte[] data); [Theory] @@ -35,19 +67,22 @@ public abstract class KaitaiAsyncStreamBaseTests public async Task Eof_Test(bool shouldBeEof, int streamSize, int readBitsAmount) { var kaitaiStreamSUT = Create(new byte[streamSize]); - await kaitaiStreamSUT.ReadBitsIntAsync(readBitsAmount); - long positionBeforeIsEof = kaitaiStreamSUT.Pos; - - if (shouldBeEof) + await EvaluateMaybeCancelled(async () => { - Assert.True(kaitaiStreamSUT.IsEof); - } - else - { - Assert.False(kaitaiStreamSUT.IsEof); - } - - Assert.Equal(positionBeforeIsEof, kaitaiStreamSUT.Pos); + await kaitaiStreamSUT.ReadBitsIntAsync(readBitsAmount); + long positionBeforeIsEof = kaitaiStreamSUT.Pos; + + if (shouldBeEof) + { + Assert.True(kaitaiStreamSUT.IsEof); + } + else + { + Assert.False(kaitaiStreamSUT.IsEof); + } + + Assert.Equal(positionBeforeIsEof, kaitaiStreamSUT.Pos); + }); } [Theory] @@ -57,9 +92,12 @@ public async Task Pos_ByRead_Test(int expectedPos, int readBitsAmount) { var kaitaiStreamSUT = Create(new byte[1]); - await kaitaiStreamSUT.ReadBytesAsync(readBitsAmount); + await EvaluateMaybeCancelled(async () => + { + await kaitaiStreamSUT.ReadBytesAsync(readBitsAmount); - Assert.Equal(expectedPos, kaitaiStreamSUT.Pos); + Assert.Equal(expectedPos, kaitaiStreamSUT.Pos); + }); } [Theory] @@ -69,9 +107,12 @@ public async Task Pos_BySeek_Test(int expectedPos, int position) { var kaitaiStreamSUT = Create(new byte[1]); - await kaitaiStreamSUT.SeekAsync(position); + await EvaluateMaybeCancelled(async () => + { + await kaitaiStreamSUT.SeekAsync(position); - Assert.Equal(expectedPos, kaitaiStreamSUT.Pos); + Assert.Equal(expectedPos, kaitaiStreamSUT.Pos); + }); } [Theory] diff --git a/Kaitai.Struct.Runtime.Async.Tests/ReadBytesAsyncTests.cs b/Kaitai.Struct.Runtime.Async.Tests/ReadBytesAsyncTests.cs index 1de942b..6b903c2 100644 --- a/Kaitai.Struct.Runtime.Async.Tests/ReadBytesAsyncTests.cs +++ b/Kaitai.Struct.Runtime.Async.Tests/ReadBytesAsyncTests.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.IO.Pipelines; using System.Linq; using System.Text; using System.Threading.Tasks; @@ -11,19 +12,48 @@ namespace Kaitai.Struct.Runtime.Async.Tests { public class StreamReadBytesAsyncTests : ReadBytesAsyncTests { + public StreamReadBytesAsyncTests() : base(false) + { + } + protected override KaitaiAsyncStream Create(byte[] data) => new KaitaiAsyncStream(data); } public class PipeReaderReadBytesAsyncTests : ReadBytesAsyncTests { + public PipeReaderReadBytesAsyncTests() : base(false) + { + } + protected override KaitaiAsyncStream Create(byte[] data) => - new KaitaiAsyncStream(System.IO.Pipelines.PipeReader.Create(new MemoryStream(data))); + new KaitaiAsyncStream(PipeReader.Create(new MemoryStream(data))); } - public abstract class ReadBytesAsyncTests + public class StreamReadBytesAsyncCancelledTests : ReadBytesAsyncTests { - protected abstract KaitaiAsyncStream Create(byte[] data); - + public StreamReadBytesAsyncCancelledTests() : base(true) + { + } + + protected override KaitaiAsyncStream Create(byte[] data) => new KaitaiAsyncStream(data); + } + + public class PipeReaderReadBytesAsyncCancelledTests : ReadBytesAsyncTests + { + public PipeReaderReadBytesAsyncCancelledTests() : base(true) + { + } + + protected override KaitaiAsyncStream Create(byte[] data) => + new KaitaiAsyncStream(PipeReader.Create(new MemoryStream(data))); + } + + public abstract class ReadBytesAsyncTests : CancelableTestsBase + { + protected ReadBytesAsyncTests(bool isTestingCancellation) : base(isTestingCancellation) + { + } + public static IEnumerable BytesData => new List<(byte[] streamContent, int bytesCount)> { @@ -33,6 +63,39 @@ public abstract class ReadBytesAsyncTests (new byte[] {0b_1101_0101, 0b_1101_0101}, 2) }.Select(t => new object[] {t.streamContent, t.bytesCount}); + public static IEnumerable StringData => + new List + { + "", + "ABC", + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + }.Select(t => new[] {Encoding.ASCII.GetBytes(t)}); + + public static IEnumerable StringWithTerminatorsData => + new List<(string streamContent, string expected, char terminator, bool isPresent, bool shouldInclude)> + { + ("", "", '\0', false, false), + ("", "", '\0', false, true), + + ("ABC", "ABC", '\0', false, false), + ("ABC", "ABC", '\0', false, true), + + ("ABC", "", 'A', true, false), + ("ABC", "A", 'A', true, true), + + ("ABC", "A", 'B', true, false), + ("ABC", "AB", 'B', true, true), + + ("ABC", "AB", 'C', true, false), + ("ABC", "ABC", 'C', true, true) + }.Select(t => new[] + { + Encoding.ASCII.GetBytes(t.streamContent), Encoding.ASCII.GetBytes(t.expected), (object) (byte) t.terminator, + t.isPresent, t.shouldInclude + }); + + protected abstract KaitaiAsyncStream Create(byte[] data); + [Theory] [MemberData(nameof(BytesData))] @@ -40,7 +103,9 @@ public async Task ReadBytesAsync_long_Test(byte[] streamContent, long bytesCount { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal(streamContent.Take((int) bytesCount), await kaitaiStreamSUT.ReadBytesAsync(bytesCount)); + await Evaluate(async () => + Assert.Equal(streamContent.Take((int) bytesCount), + await kaitaiStreamSUT.ReadBytesAsync(bytesCount, CancellationToken))); } [Theory] @@ -49,17 +114,11 @@ public async Task ReadBytesAsync_ulong_Test(byte[] streamContent, ulong bytesCou { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal(streamContent.Take((int) bytesCount), await kaitaiStreamSUT.ReadBytesAsync(bytesCount)); + await Evaluate(async () => + Assert.Equal(streamContent.Take((int) bytesCount), + await kaitaiStreamSUT.ReadBytesAsync(bytesCount, CancellationToken))); } - public static IEnumerable StringData => - new List - { - "", - "ABC", - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - }.Select(t => new[] {Encoding.ASCII.GetBytes(t)}); - [Theory] [MemberData(nameof(StringData))] @@ -67,7 +126,8 @@ public async Task ReadBytesFullAsync_Test(byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal(streamContent, await kaitaiStreamSUT.ReadBytesFullAsync()); + await Evaluate(async () => + Assert.Equal(streamContent, await kaitaiStreamSUT.ReadBytesFullAsync(CancellationToken))); } [Theory] @@ -76,7 +136,8 @@ public async Task EnsureFixedContentsAsync_Test(byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal(streamContent, await kaitaiStreamSUT.EnsureFixedContentsAsync(streamContent)); + await Evaluate(async () => + Assert.Equal(streamContent, await kaitaiStreamSUT.EnsureFixedContentsAsync(streamContent, CancellationToken))); } [Theory] @@ -93,32 +154,10 @@ public async Task EnsureFixedContentsAsync_ThrowsIfByteIsChanged(byte[] streamCo var expected = streamContent.ToArray(); expected[0] = (byte) ~expected[0]; - await Assert.ThrowsAsync(async () => await kaitaiStreamSUT.EnsureFixedContentsAsync(expected)); + await Evaluate(async () => + await kaitaiStreamSUT.EnsureFixedContentsAsync(expected, CancellationToken)); } - public static IEnumerable StringWithTerminatorsData => - new List<(string streamContent, string expected, char terminator, bool isPresent, bool shouldInclude)> - { - ("", "", '\0', false, false), - ("", "", '\0', false, true), - - ("ABC", "ABC", '\0', false, false), - ("ABC", "ABC", '\0', false, true), - - ("ABC", "", 'A', true, false), - ("ABC", "A", 'A', true, true), - - ("ABC", "A", 'B', true, false), - ("ABC", "AB", 'B', true, true), - - ("ABC", "AB", 'C', true, false), - ("ABC", "ABC", 'C', true, true) - }.Select(t => new[] - { - Encoding.ASCII.GetBytes(t.streamContent), Encoding.ASCII.GetBytes(t.expected), (object) (byte) t.terminator, - t.isPresent, t.shouldInclude - }); - [Theory] [MemberData(nameof(StringWithTerminatorsData))] public async Task ReadBytesTermAsync(byte[] streamContent, @@ -129,7 +168,8 @@ public async Task ReadBytesTermAsync(byte[] streamContent, { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal(expected, await kaitaiStreamSUT.ReadBytesTermAsync(terminator, shouldInclude, false, false)); + await Evaluate(async () => Assert.Equal(expected, + await kaitaiStreamSUT.ReadBytesTermAsync(terminator, shouldInclude, false, false, CancellationToken))); } [Theory] @@ -147,8 +187,8 @@ public async Task ReadBytesTermAsync_ThrowsIsTerminatorNotPresent(byte[] streamC return; } - await Assert.ThrowsAsync(async () => - await kaitaiStreamSUT.ReadBytesTermAsync(terminator, shouldInclude, false, true)); + await Evaluate(async () => + await kaitaiStreamSUT.ReadBytesTermAsync(terminator, shouldInclude, false, true, CancellationToken)); } [Theory] @@ -162,17 +202,20 @@ public async Task ReadBytesTermAsync_ShouldNotConsumeTerminator(byte[] streamCon //Arrange var kaitaiStreamSUT = Create(streamContent); - //Act - await kaitaiStreamSUT.ReadBytesTermAsync(terminator, shouldInclude, false, false); - - //Assert - int amountToConsume = expected.Length; - if (expected.Length > 0 && shouldInclude && terminatorIsPresent) + await Evaluate(async () => { - amountToConsume--; - } + //Act + await kaitaiStreamSUT.ReadBytesTermAsync(terminator, shouldInclude, false, false, CancellationToken); - Assert.Equal(amountToConsume, kaitaiStreamSUT.Pos); + //Assert + int amountToConsume = expected.Length; + if (expected.Length > 0 && shouldInclude && terminatorIsPresent) + { + amountToConsume--; + } + + Assert.Equal(amountToConsume, kaitaiStreamSUT.Pos); + }); } [Theory] @@ -186,17 +229,20 @@ public async Task ReadBytesTermAsync_ShouldConsumeTerminator(byte[] streamConten //Arrange var kaitaiStreamSUT = Create(streamContent); - //Act - await kaitaiStreamSUT.ReadBytesTermAsync(terminator, shouldInclude, true, false); - - //Assert - int amountToConsume = expected.Length; - if (!shouldInclude && terminatorIsPresent) + await Evaluate(async () => { - amountToConsume++; - } + //Act + await kaitaiStreamSUT.ReadBytesTermAsync(terminator, shouldInclude, true, false, CancellationToken); + + //Assert + int amountToConsume = expected.Length; + if (!shouldInclude && terminatorIsPresent) + { + amountToConsume++; + } - Assert.Equal(amountToConsume, kaitaiStreamSUT.Pos); + Assert.Equal(amountToConsume, kaitaiStreamSUT.Pos); + }); } [Fact] @@ -204,8 +250,7 @@ public async Task ReadBytesAsyncLong_LargerThanBufferInvoke_ThrowsArgumentOutOfR { var kaitaiStreamSUT = Create(new byte[0]); - await Assert.ThrowsAsync(async () => - await kaitaiStreamSUT.ReadBytesAsync(1)); + await Evaluate(async () => await kaitaiStreamSUT.ReadBytesAsync(1, CancellationToken)); } [Fact] @@ -213,8 +258,8 @@ public async Task ReadBytesAsyncLong_LargerThanInt32Invoke_ThrowsArgumentOutOfRa { var kaitaiStreamSUT = Create(new byte[0]); - await Assert.ThrowsAsync(async () => - await kaitaiStreamSUT.ReadBytesAsync((long) int.MaxValue + 1)); + await Evaluate(async () => + await kaitaiStreamSUT.ReadBytesAsync((long) int.MaxValue + 1, CancellationToken)); } [Fact] @@ -222,8 +267,8 @@ public async Task ReadBytesAsyncLong_NegativeInvoke_ThrowsArgumentOutOfRangeExce { var kaitaiStreamSUT = Create(new byte[0]); - await Assert.ThrowsAsync(async () => - await kaitaiStreamSUT.ReadBytesAsync(-1)); + await Evaluate(async () => + await kaitaiStreamSUT.ReadBytesAsync(-1, CancellationToken)); } [Fact] @@ -231,8 +276,8 @@ public async Task ReadBytesAsyncULong_LargerThanBufferInvoke_ThrowsArgumentOutOf { var kaitaiStreamSUT = Create(new byte[0]); - await Assert.ThrowsAsync(async () => - await kaitaiStreamSUT.ReadBytesAsync((ulong) 1)); + await Evaluate(async () => + await kaitaiStreamSUT.ReadBytesAsync((ulong) 1, CancellationToken)); } [Fact] @@ -240,8 +285,8 @@ public async Task ReadBytesAsyncULong_LargerThanInt32Invoke_ThrowsArgumentOutOfR { var kaitaiStreamSUT = Create(new byte[0]); - await Assert.ThrowsAsync(async () => - await kaitaiStreamSUT.ReadBytesAsync((ulong) int.MaxValue + 1)); + await Evaluate(async () => + await kaitaiStreamSUT.ReadBytesAsync((ulong) int.MaxValue + 1, CancellationToken)); } } } \ No newline at end of file diff --git a/Kaitai.Struct.Runtime.Async.Tests/ReadDecimalAsyncTests.cs b/Kaitai.Struct.Runtime.Async.Tests/ReadDecimalAsyncTests.cs index dc4f652..3891dea 100644 --- a/Kaitai.Struct.Runtime.Async.Tests/ReadDecimalAsyncTests.cs +++ b/Kaitai.Struct.Runtime.Async.Tests/ReadDecimalAsyncTests.cs @@ -1,4 +1,5 @@ using System.IO; +using System.IO.Pipelines; using System.Linq; using System.Threading.Tasks; using Kaitai.Async; @@ -8,17 +9,48 @@ namespace Kaitai.Struct.Runtime.Async.Tests { public class StreamReadDecimalAsyncTests : ReadDecimalAsyncTests { + public StreamReadDecimalAsyncTests() : base(false) + { + } + protected override KaitaiAsyncStream Create(byte[] data) => new KaitaiAsyncStream(data); } public class PipeReaderReadDecimalAsyncTests : ReadDecimalAsyncTests { + public PipeReaderReadDecimalAsyncTests() : base(false) + { + } + protected override KaitaiAsyncStream Create(byte[] data) => - new KaitaiAsyncStream(System.IO.Pipelines.PipeReader.Create(new MemoryStream(data))); + new KaitaiAsyncStream(PipeReader.Create(new MemoryStream(data))); } - public abstract class ReadDecimalAsyncTests + public class StreamReadDecimalAsyncCancelledTests : ReadDecimalAsyncTests { + public StreamReadDecimalAsyncCancelledTests() : base(true) + { + } + + protected override KaitaiAsyncStream Create(byte[] data) => new KaitaiAsyncStream(data); + } + + public class PipeReaderReadDecimalAsyncCancelledTests : ReadDecimalAsyncTests + { + public PipeReaderReadDecimalAsyncCancelledTests() : base(true) + { + } + + protected override KaitaiAsyncStream Create(byte[] data) => + new KaitaiAsyncStream(PipeReader.Create(new MemoryStream(data))); + } + + public abstract class ReadDecimalAsyncTests : CancelableTestsBase + { + protected ReadDecimalAsyncTests(bool isTestingCancellation) : base(isTestingCancellation) + { + } + protected abstract KaitaiAsyncStream Create(byte[] data); [Theory] @@ -27,7 +59,7 @@ public async Task ReadF4beAsync_Test(float expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent.Reverse().ToArray()); - Assert.Equal(expected, await kaitaiStreamSUT.ReadF4beAsync()); + await Evaluate(async () => Assert.Equal(expected, await kaitaiStreamSUT.ReadF4beAsync(CancellationToken))); } [Theory] @@ -36,7 +68,7 @@ public async Task ReadF4leAsync_Test(float expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal(expected, await kaitaiStreamSUT.ReadF4leAsync()); + await Evaluate(async () => Assert.Equal(expected, await kaitaiStreamSUT.ReadF4leAsync(CancellationToken))); } [Theory] @@ -45,7 +77,7 @@ public async Task ReadF8beAsync_Test(double expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent.Reverse().ToArray()); - Assert.Equal(expected, await kaitaiStreamSUT.ReadF8beAsync()); + await Evaluate(async () => Assert.Equal(expected, await kaitaiStreamSUT.ReadF8beAsync(CancellationToken))); } [Theory] @@ -54,7 +86,7 @@ public async Task ReadF8leAsync_Test(double expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal(expected, await kaitaiStreamSUT.ReadF8leAsync()); + await Evaluate(async () => Assert.Equal(expected, await kaitaiStreamSUT.ReadF8leAsync(CancellationToken))); } } } \ No newline at end of file diff --git a/Kaitai.Struct.Runtime.Async.Tests/ReadSignedAsyncTests.cs b/Kaitai.Struct.Runtime.Async.Tests/ReadSignedAsyncTests.cs index 64ba08b..b31f4c9 100644 --- a/Kaitai.Struct.Runtime.Async.Tests/ReadSignedAsyncTests.cs +++ b/Kaitai.Struct.Runtime.Async.Tests/ReadSignedAsyncTests.cs @@ -1,4 +1,5 @@ using System.IO; +using System.IO.Pipelines; using System.Linq; using System.Threading.Tasks; using Kaitai.Async; @@ -8,17 +9,48 @@ namespace Kaitai.Struct.Runtime.Async.Tests { public class StreamReadSignedAsyncTests : ReadSignedAsyncTests { + public StreamReadSignedAsyncTests() : base(false) + { + } + protected override KaitaiAsyncStream Create(byte[] data) => new KaitaiAsyncStream(data); } public class PipeReaderReadSignedAsyncTests : ReadSignedAsyncTests { + public PipeReaderReadSignedAsyncTests() : base(false) + { + } + protected override KaitaiAsyncStream Create(byte[] data) => - new KaitaiAsyncStream(System.IO.Pipelines.PipeReader.Create(new MemoryStream(data))); + new KaitaiAsyncStream(PipeReader.Create(new MemoryStream(data))); } - public abstract class ReadSignedAsyncTests + public class StreamReadSignedAsyncCancelledTests : ReadSignedAsyncTests { + public StreamReadSignedAsyncCancelledTests() : base(true) + { + } + + protected override KaitaiAsyncStream Create(byte[] data) => new KaitaiAsyncStream(data); + } + + public class PipeReaderReadSignedAsyncCancelledTests : ReadSignedAsyncTests + { + public PipeReaderReadSignedAsyncCancelledTests() : base(true) + { + } + + protected override KaitaiAsyncStream Create(byte[] data) => + new KaitaiAsyncStream(PipeReader.Create(new MemoryStream(data))); + } + + public abstract class ReadSignedAsyncTests : CancelableTestsBase + { + protected ReadSignedAsyncTests(bool isTestingCancellation) : base(isTestingCancellation) + { + } + protected abstract KaitaiAsyncStream Create(byte[] data); [Theory] @@ -27,7 +59,7 @@ public async Task ReadS1Async_Test(sbyte expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal(expected, await kaitaiStreamSUT.ReadS1Async()); + await Evaluate(async ()=>Assert.Equal(expected, await kaitaiStreamSUT.ReadS1Async(CancellationToken))); } [Theory] @@ -36,7 +68,7 @@ public async Task ReadS2beAsync_Test(short expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal(expected, await kaitaiStreamSUT.ReadS2beAsync()); + await Evaluate(async () => Assert.Equal(expected, await kaitaiStreamSUT.ReadS2beAsync(CancellationToken))); } [Theory] @@ -45,7 +77,7 @@ public async Task ReadS4beAsync_Test(int expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal(expected, await kaitaiStreamSUT.ReadS4beAsync()); + await Evaluate(async () => Assert.Equal(expected, await kaitaiStreamSUT.ReadS4beAsync(CancellationToken))); } [Theory] @@ -54,7 +86,7 @@ public async Task ReadS8beAsync_Test(long expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal(expected, await kaitaiStreamSUT.ReadS8beAsync()); + await Evaluate(async () => Assert.Equal(expected, await kaitaiStreamSUT.ReadS8beAsync(CancellationToken))); } [Theory] @@ -63,7 +95,7 @@ public async Task ReadS2leAsync_Test(short expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent.Reverse().ToArray()); - Assert.Equal(expected, await kaitaiStreamSUT.ReadS2leAsync()); + await Evaluate(async () => Assert.Equal(expected, await kaitaiStreamSUT.ReadS2leAsync(CancellationToken))); } [Theory] @@ -72,7 +104,7 @@ public async Task ReadS4leAsync_Test(int expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent.Reverse().ToArray()); - Assert.Equal(expected, await kaitaiStreamSUT.ReadS4leAsync()); + await Evaluate(async () => Assert.Equal(expected, await kaitaiStreamSUT.ReadS4leAsync(CancellationToken))); } [Theory] @@ -81,7 +113,7 @@ public async Task ReadS8leAsync_Test(long expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent.Reverse().ToArray()); - Assert.Equal(expected, await kaitaiStreamSUT.ReadS8leAsync()); + await Evaluate(async () => Assert.Equal(expected, await kaitaiStreamSUT.ReadS8leAsync(CancellationToken))); } } } \ No newline at end of file diff --git a/Kaitai.Struct.Runtime.Async.Tests/ReadUnSignedAsyncTests.cs b/Kaitai.Struct.Runtime.Async.Tests/ReadUnSignedAsyncTests.cs index e593961..2c0c441 100644 --- a/Kaitai.Struct.Runtime.Async.Tests/ReadUnSignedAsyncTests.cs +++ b/Kaitai.Struct.Runtime.Async.Tests/ReadUnSignedAsyncTests.cs @@ -1,4 +1,5 @@ using System.IO; +using System.IO.Pipelines; using System.Linq; using System.Threading.Tasks; using Kaitai.Async; @@ -8,17 +9,49 @@ namespace Kaitai.Struct.Runtime.Async.Tests { public class StreamReadUnSignedAsyncTests : ReadUnSignedAsyncTests { + public StreamReadUnSignedAsyncTests() : base(false) + { + } + + protected override KaitaiAsyncStream Create(byte[] data) => new KaitaiAsyncStream(data); + } + + public class StreamReadUnSignedAsyncCancelledTests : ReadUnSignedAsyncTests + { + public StreamReadUnSignedAsyncCancelledTests() : base(true) + { + } + protected override KaitaiAsyncStream Create(byte[] data) => new KaitaiAsyncStream(data); } public class PipeReaderReadUnSignedAsyncTests : ReadUnSignedAsyncTests { + public PipeReaderReadUnSignedAsyncTests() : base(false) + { + } + + protected override KaitaiAsyncStream Create(byte[] data) => + new KaitaiAsyncStream(PipeReader.Create(new MemoryStream(data))); + } + + public class PipeReaderReadUnSignedAsyncCancelledTests : ReadUnSignedAsyncTests + { + public PipeReaderReadUnSignedAsyncCancelledTests() : base(true) + { + } + protected override KaitaiAsyncStream Create(byte[] data) => - new KaitaiAsyncStream(System.IO.Pipelines.PipeReader.Create(new MemoryStream(data))); + new KaitaiAsyncStream(PipeReader.Create(new MemoryStream(data))); } - public abstract class ReadUnSignedAsyncTests + + public abstract class ReadUnSignedAsyncTests : CancelableTestsBase { + protected ReadUnSignedAsyncTests(bool isTestingCancellation) : base(isTestingCancellation) + { + } + protected abstract KaitaiAsyncStream Create(byte[] data); [Theory] @@ -27,7 +60,7 @@ public async Task ReadU1Async_Test( /*u*/ sbyte expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal((byte) expected, await kaitaiStreamSUT.ReadU1Async()); + await Evaluate(async () => Assert.Equal((byte) expected, await kaitaiStreamSUT.ReadU1Async(CancellationToken))); } [Theory] @@ -36,7 +69,8 @@ public async Task ReadU2beAsync_Test( /*u*/ short expected, byte[] streamContent { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal((ushort) expected, await kaitaiStreamSUT.ReadU2beAsync()); + await Evaluate(async () => + Assert.Equal((ushort) expected, await kaitaiStreamSUT.ReadU2beAsync(CancellationToken))); } [Theory] @@ -45,7 +79,8 @@ public async Task ReadU4beAsync_Test( /*u*/ int expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal((uint) expected, await kaitaiStreamSUT.ReadU4beAsync()); + await Evaluate(async () => + Assert.Equal((uint) expected, await kaitaiStreamSUT.ReadU4beAsync(CancellationToken))); } [Theory] @@ -54,7 +89,8 @@ public async Task ReadU8beAsync_Test( /*u*/ long expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent); - Assert.Equal((ulong) expected, await kaitaiStreamSUT.ReadU8beAsync()); + await Evaluate( + async () => Assert.Equal((ulong) expected, await kaitaiStreamSUT.ReadU8beAsync(CancellationToken))); } [Theory] @@ -63,7 +99,8 @@ public async Task ReadU2leAsync_Test( /*u*/ short expected, byte[] streamContent { var kaitaiStreamSUT = Create(streamContent.Reverse().ToArray()); - Assert.Equal((ushort) expected, await kaitaiStreamSUT.ReadU2leAsync()); + await Evaluate(async () => + Assert.Equal((ushort) expected, await kaitaiStreamSUT.ReadU2leAsync(CancellationToken))); } [Theory] @@ -72,7 +109,8 @@ public async Task ReadU4leAsync_Test( /*u*/ int expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent.Reverse().ToArray()); - Assert.Equal((uint) expected, await kaitaiStreamSUT.ReadU4leAsync()); + await Evaluate(async () => + Assert.Equal((uint) expected, await kaitaiStreamSUT.ReadU4leAsync(CancellationToken))); } [Theory] @@ -81,7 +119,8 @@ public async Task ReadU8leAsync_Test( /*u*/ long expected, byte[] streamContent) { var kaitaiStreamSUT = Create(streamContent.Reverse().ToArray()); - Assert.Equal((ulong) expected, await kaitaiStreamSUT.ReadU8leAsync()); + await Evaluate( + async () => Assert.Equal((ulong) expected, await kaitaiStreamSUT.ReadU8leAsync(CancellationToken))); } } } \ No newline at end of file diff --git a/Kaitai.Struct.Runtime.Async/Interface/IKaitaiAsyncStream.cs b/Kaitai.Struct.Runtime.Async/Interface/IKaitaiAsyncStream.cs index 4df5abb..5b58aa7 100644 --- a/Kaitai.Struct.Runtime.Async/Interface/IKaitaiAsyncStream.cs +++ b/Kaitai.Struct.Runtime.Async/Interface/IKaitaiAsyncStream.cs @@ -1,4 +1,5 @@ -using System.Threading.Tasks; +using System.Threading; +using System.Threading.Tasks; namespace Kaitai.Async { @@ -7,150 +8,153 @@ public interface IKaitaiAsyncStream : IKaitaiStreamBase /// /// Check if the stream position is at the end of the stream /// - ValueTask IsEofAsync(); + ValueTask IsEofAsync(CancellationToken cancellationToken = default); /// /// Get the total length of the stream (ie. file size) /// - ValueTask GetSizeAsync(); + ValueTask GetSizeAsync(CancellationToken cancellationToken = default); /// /// Seek to a specific position from the beginning of the stream /// /// The position to seek to - Task SeekAsync(long position); + /// + Task SeekAsync(long position, CancellationToken cancellationToken = default); /// /// Read a signed byte from the stream /// /// - Task ReadS1Async(); + Task ReadS1Async(CancellationToken cancellationToken = default); /// /// Read a signed short from the stream (big endian) /// /// - Task ReadS2beAsync(); + Task ReadS2beAsync(CancellationToken cancellationToken = default); /// /// Read a signed int from the stream (big endian) /// /// - Task ReadS4beAsync(); + Task ReadS4beAsync(CancellationToken cancellationToken = default); /// /// Read a signed long from the stream (big endian) /// /// - Task ReadS8beAsync(); + Task ReadS8beAsync(CancellationToken cancellationToken = default); /// /// Read a signed short from the stream (little endian) /// /// - Task ReadS2leAsync(); + Task ReadS2leAsync(CancellationToken cancellationToken = default); /// /// Read a signed int from the stream (little endian) /// /// - Task ReadS4leAsync(); + Task ReadS4leAsync(CancellationToken cancellationToken = default); /// /// Read a signed long from the stream (little endian) /// /// - Task ReadS8leAsync(); + Task ReadS8leAsync(CancellationToken cancellationToken = default); /// /// Read an unsigned byte from the stream /// /// - Task ReadU1Async(); + Task ReadU1Async(CancellationToken cancellationToken = default); /// /// Read an unsigned short from the stream (big endian) /// /// - Task ReadU2beAsync(); + Task ReadU2beAsync(CancellationToken cancellationToken = default); /// /// Read an unsigned int from the stream (big endian) /// /// - Task ReadU4beAsync(); + Task ReadU4beAsync(CancellationToken cancellationToken = default); /// /// Read an unsigned long from the stream (big endian) /// /// - Task ReadU8beAsync(); + Task ReadU8beAsync(CancellationToken cancellationToken = default); /// /// Read an unsigned short from the stream (little endian) /// /// - Task ReadU2leAsync(); + Task ReadU2leAsync(CancellationToken cancellationToken = default); /// /// Read an unsigned int from the stream (little endian) /// /// - Task ReadU4leAsync(); + Task ReadU4leAsync(CancellationToken cancellationToken = default); /// /// Read an unsigned long from the stream (little endian) /// /// - Task ReadU8leAsync(); + Task ReadU8leAsync(CancellationToken cancellationToken = default); /// /// Read a single-precision floating point value from the stream (big endian) /// /// - Task ReadF4beAsync(); + Task ReadF4beAsync(CancellationToken cancellationToken = default); /// /// Read a double-precision floating point value from the stream (big endian) /// /// - Task ReadF8beAsync(); + Task ReadF8beAsync(CancellationToken cancellationToken = default); /// /// Read a single-precision floating point value from the stream (little endian) /// /// - Task ReadF4leAsync(); + Task ReadF4leAsync(CancellationToken cancellationToken = default); /// /// Read a double-precision floating point value from the stream (little endian) /// /// - Task ReadF8leAsync(); + Task ReadF8leAsync(CancellationToken cancellationToken = default); - Task ReadBitsIntAsync(int n); + Task ReadBitsIntAsync(int n, CancellationToken cancellationToken = default); - Task ReadBitsIntLeAsync(int n); + Task ReadBitsIntLeAsync(int n, CancellationToken cancellationToken = default); /// /// Read a fixed number of bytes from the stream /// /// The number of bytes to read + /// /// - Task ReadBytesAsync(long count); + Task ReadBytesAsync(long count, CancellationToken cancellationToken = default); /// /// Read a fixed number of bytes from the stream /// /// The number of bytes to read + /// /// - Task ReadBytesAsync(ulong count); + Task ReadBytesAsync(ulong count, CancellationToken cancellationToken = default); /// /// Read all the remaining bytes from the stream until the end is reached /// /// - Task ReadBytesFullAsync(); + Task ReadBytesFullAsync(CancellationToken cancellationToken = default); /// /// Read a terminated string from the stream @@ -163,13 +167,15 @@ public interface IKaitaiAsyncStream : IKaitaiStreamBase Task ReadBytesTermAsync(byte terminator, bool includeTerminator, bool consumeTerminator, - bool eosError); + bool eosError, + CancellationToken cancellationToken = default); /// /// Read a specific set of bytes and assert that they are the same as an expected result /// /// The expected result + /// /// - Task EnsureFixedContentsAsync(byte[] expected); + Task EnsureFixedContentsAsync(byte[] expected, CancellationToken cancellationToken = default); } } \ No newline at end of file diff --git a/Kaitai.Struct.Runtime.Async/Interface/IReaderContext.cs b/Kaitai.Struct.Runtime.Async/Interface/IReaderContext.cs index 4eab4fb..d497968 100644 --- a/Kaitai.Struct.Runtime.Async/Interface/IReaderContext.cs +++ b/Kaitai.Struct.Runtime.Async/Interface/IReaderContext.cs @@ -1,15 +1,16 @@ -using System.Threading.Tasks; +using System.Threading; +using System.Threading.Tasks; namespace Kaitai.Async { public interface IReaderContext { long Position { get; } - ValueTask GetSizeAsync(); - ValueTask IsEofAsync(); - ValueTask SeekAsync(long position); - ValueTask ReadByteAsync(); - ValueTask ReadBytesAsync(long count); - ValueTask ReadBytesFullAsync(); + ValueTask GetSizeAsync(CancellationToken cancellationToken = default); + ValueTask IsEofAsync(CancellationToken cancellationToken = default); + ValueTask SeekAsync(long position, CancellationToken cancellationToken = default); + ValueTask ReadByteAsync(CancellationToken cancellationToken = default); + ValueTask ReadBytesAsync(long count, CancellationToken cancellationToken = default); + ValueTask ReadBytesFullAsync(CancellationToken cancellationToken = default); } } \ No newline at end of file diff --git a/Kaitai.Struct.Runtime.Async/KaitaiAsyncStream.cs b/Kaitai.Struct.Runtime.Async/KaitaiAsyncStream.cs index 560c0ea..193acc5 100644 --- a/Kaitai.Struct.Runtime.Async/KaitaiAsyncStream.cs +++ b/Kaitai.Struct.Runtime.Async/KaitaiAsyncStream.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.IO; using System.IO.Pipelines; +using System.Threading; using System.Threading.Tasks; namespace Kaitai.Async @@ -53,12 +54,12 @@ public KaitaiAsyncStream(byte[] bytes) : this(new MemoryStream(bytes)) public override bool IsEof => ReaderContext.IsEofAsync().GetAwaiter().GetResult() && _bitsLeft == 0; - public async ValueTask IsEofAsync() => await ReaderContext.IsEofAsync() && _bitsLeft == 0; + public async ValueTask IsEofAsync(CancellationToken cancellationToken = default) => await ReaderContext.IsEofAsync(cancellationToken) && _bitsLeft == 0; - public ValueTask GetSizeAsync() => ReaderContext.GetSizeAsync(); + public ValueTask GetSizeAsync(CancellationToken cancellationToken = default) => ReaderContext.GetSizeAsync(cancellationToken); - public virtual async Task SeekAsync(long position) => await ReaderContext.SeekAsync(position); - public virtual async Task SeekAsync(ulong position) => await SeekAsync((long)position); + public virtual async Task SeekAsync(long position, CancellationToken cancellationToken = default) => await ReaderContext.SeekAsync(position, cancellationToken); + public virtual async Task SeekAsync(ulong position, CancellationToken cancellationToken = default) => await SeekAsync((long)position, cancellationToken); public override long Pos => ReaderContext.Position; @@ -70,25 +71,25 @@ public KaitaiAsyncStream(byte[] bytes) : this(new MemoryStream(bytes)) #region Signed - public async Task ReadS1Async() => (sbyte) await ReadU1Async(); + public async Task ReadS1Async(CancellationToken cancellationToken = default) => (sbyte) await ReadU1Async(cancellationToken); #region Big-endian - public async Task ReadS2beAsync() => BitConverter.ToInt16(await ReadBytesNormalisedBigEndianAsync(2), 0); + public async Task ReadS2beAsync(CancellationToken cancellationToken = default) => BitConverter.ToInt16(await ReadBytesNormalisedBigEndianAsync(2, cancellationToken), 0); - public async Task ReadS4beAsync() => BitConverter.ToInt32(await ReadBytesNormalisedBigEndianAsync(4), 0); + public async Task ReadS4beAsync(CancellationToken cancellationToken = default) => BitConverter.ToInt32(await ReadBytesNormalisedBigEndianAsync(4, cancellationToken), 0); - public async Task ReadS8beAsync() => BitConverter.ToInt64(await ReadBytesNormalisedBigEndianAsync(8), 0); + public async Task ReadS8beAsync(CancellationToken cancellationToken = default) => BitConverter.ToInt64(await ReadBytesNormalisedBigEndianAsync(8, cancellationToken), 0); #endregion #region Little-endian - public async Task ReadS2leAsync() => BitConverter.ToInt16(await ReadBytesNormalisedLittleEndianAsync(2), 0); + public async Task ReadS2leAsync(CancellationToken cancellationToken = default) => BitConverter.ToInt16(await ReadBytesNormalisedLittleEndianAsync(2, cancellationToken), 0); - public async Task ReadS4leAsync() => BitConverter.ToInt32(await ReadBytesNormalisedLittleEndianAsync(4), 0); + public async Task ReadS4leAsync(CancellationToken cancellationToken = default) => BitConverter.ToInt32(await ReadBytesNormalisedLittleEndianAsync(4, cancellationToken), 0); - public async Task ReadS8leAsync() => BitConverter.ToInt64(await ReadBytesNormalisedLittleEndianAsync(8), 0); + public async Task ReadS8leAsync(CancellationToken cancellationToken = default) => BitConverter.ToInt64(await ReadBytesNormalisedLittleEndianAsync(8, cancellationToken), 0); #endregion @@ -96,26 +97,26 @@ public KaitaiAsyncStream(byte[] bytes) : this(new MemoryStream(bytes)) #region Unsigned - public async Task ReadU1Async() => await ReaderContext.ReadByteAsync(); + public async Task ReadU1Async(CancellationToken cancellationToken = default) => await ReaderContext.ReadByteAsync(cancellationToken); #region Big-endian - public async Task ReadU2beAsync() => BitConverter.ToUInt16(await ReadBytesNormalisedBigEndianAsync(2), 0); + public async Task ReadU2beAsync(CancellationToken cancellationToken = default) => BitConverter.ToUInt16(await ReadBytesNormalisedBigEndianAsync(2, cancellationToken), 0); - public async Task ReadU4beAsync() => BitConverter.ToUInt32(await ReadBytesNormalisedBigEndianAsync(4), 0); + public async Task ReadU4beAsync(CancellationToken cancellationToken = default) => BitConverter.ToUInt32(await ReadBytesNormalisedBigEndianAsync(4, cancellationToken), 0); - public async Task ReadU8beAsync() => BitConverter.ToUInt64(await ReadBytesNormalisedBigEndianAsync(8), 0); + public async Task ReadU8beAsync(CancellationToken cancellationToken = default) => BitConverter.ToUInt64(await ReadBytesNormalisedBigEndianAsync(8, cancellationToken), 0); #endregion #region Little-endian - public async Task ReadU2leAsync() => - BitConverter.ToUInt16(await ReadBytesNormalisedLittleEndianAsync(2), 0); + public async Task ReadU2leAsync(CancellationToken cancellationToken = default) => + BitConverter.ToUInt16(await ReadBytesNormalisedLittleEndianAsync(2, cancellationToken), 0); - public async Task ReadU4leAsync() => BitConverter.ToUInt32(await ReadBytesNormalisedLittleEndianAsync(4), 0); + public async Task ReadU4leAsync(CancellationToken cancellationToken = default) => BitConverter.ToUInt32(await ReadBytesNormalisedLittleEndianAsync(4, cancellationToken), 0); - public async Task ReadU8leAsync() => BitConverter.ToUInt64(await ReadBytesNormalisedLittleEndianAsync(8), 0); + public async Task ReadU8leAsync(CancellationToken cancellationToken = default) => BitConverter.ToUInt64(await ReadBytesNormalisedLittleEndianAsync(8, cancellationToken), 0); #endregion @@ -127,18 +128,18 @@ public async Task ReadU2leAsync() => #region Big-endian - public async Task ReadF4beAsync() => BitConverter.ToSingle(await ReadBytesNormalisedBigEndianAsync(4), 0); + public async Task ReadF4beAsync(CancellationToken cancellationToken = default) => BitConverter.ToSingle(await ReadBytesNormalisedBigEndianAsync(4, cancellationToken), 0); - public async Task ReadF8beAsync() => BitConverter.ToDouble(await ReadBytesNormalisedBigEndianAsync(8), 0); + public async Task ReadF8beAsync(CancellationToken cancellationToken = default) => BitConverter.ToDouble(await ReadBytesNormalisedBigEndianAsync(8, cancellationToken), 0); #endregion #region Little-endian - public async Task ReadF4leAsync() => BitConverter.ToSingle(await ReadBytesNormalisedLittleEndianAsync(4), 0); + public async Task ReadF4leAsync(CancellationToken cancellationToken = default) => BitConverter.ToSingle(await ReadBytesNormalisedLittleEndianAsync(4, cancellationToken), 0); - public async Task ReadF8leAsync() => - BitConverter.ToDouble(await ReadBytesNormalisedLittleEndianAsync(8), 0); + public async Task ReadF8leAsync(CancellationToken cancellationToken = default) => + BitConverter.ToDouble(await ReadBytesNormalisedLittleEndianAsync(8, cancellationToken), 0); #endregion @@ -152,7 +153,7 @@ public override void AlignToByte() _bitsLeft = 0; } - public async Task ReadBitsIntAsync(int n) + public async Task ReadBitsIntAsync(int n, CancellationToken cancellationToken = default) { int bitsNeeded = n - _bitsLeft; if (bitsNeeded > 0) @@ -161,7 +162,7 @@ public async Task ReadBitsIntAsync(int n) // 8 bits => 1 byte // 9 bits => 2 bytes int bytesNeeded = (bitsNeeded - 1) / 8 + 1; - var buf = await ReadBytesAsync(bytesNeeded); + var buf = await ReadBytesAsync(bytesNeeded, cancellationToken); for (var i = 0; i < buf.Length; i++) { _bits <<= 8; @@ -186,7 +187,7 @@ public async Task ReadBitsIntAsync(int n) } //Method ported from algorithm specified @ issue#155 - public async Task ReadBitsIntLeAsync(int n) + public async Task ReadBitsIntLeAsync(int n, CancellationToken cancellationToken = default) { int bitsNeeded = n - _bitsLeft; @@ -196,7 +197,7 @@ public async Task ReadBitsIntLeAsync(int n) // 8 bits => 1 byte // 9 bits => 2 bytes int bytesNeeded = (bitsNeeded - 1) / 8 + 1; - var buf = await ReadBytesAsync(bytesNeeded); + var buf = await ReadBytesAsync(bytesNeeded, cancellationToken); for (var i = 0; i < buf.Length; i++) { ulong v = (ulong) buf[i] << _bitsLeft; @@ -222,9 +223,9 @@ public async Task ReadBitsIntLeAsync(int n) #region Byte arrays - public async Task ReadBytesAsync(long count) => await ReaderContext.ReadBytesAsync(count); + public async Task ReadBytesAsync(long count, CancellationToken cancellationToken = default) => await ReaderContext.ReadBytesAsync(count, cancellationToken); - public async Task ReadBytesAsync(ulong count) + public async Task ReadBytesAsync(ulong count, CancellationToken cancellationToken = default) { if (count > int.MaxValue) { @@ -232,7 +233,7 @@ public async Task ReadBytesAsync(ulong count) $"requested {count} bytes, while only non-negative int32 amount of bytes possible"); } - return await ReadBytesAsync((long) count); + return await ReadBytesAsync((long) count, cancellationToken); } /// @@ -240,9 +241,9 @@ public async Task ReadBytesAsync(ulong count) /// /// The number of bytes to read /// An array of bytes that matches the endianness of the current platform - protected async Task ReadBytesNormalisedLittleEndianAsync(int count) + protected async Task ReadBytesNormalisedLittleEndianAsync(int count, CancellationToken cancellationToken = default) { - var bytes = await ReadBytesAsync(count); + var bytes = await ReadBytesAsync(count, cancellationToken); if (!IsLittleEndian) { Array.Reverse(bytes); @@ -256,9 +257,9 @@ protected async Task ReadBytesNormalisedLittleEndianAsync(int count) /// /// The number of bytes to read /// An array of bytes that matches the endianness of the current platform - protected async Task ReadBytesNormalisedBigEndianAsync(int count) + protected async Task ReadBytesNormalisedBigEndianAsync(int count, CancellationToken cancellationToken = default) { - var bytes = await ReadBytesAsync(count); + var bytes = await ReadBytesAsync(count, cancellationToken); if (IsLittleEndian) { Array.Reverse(bytes); @@ -271,7 +272,7 @@ protected async Task ReadBytesNormalisedBigEndianAsync(int count) /// Read all the remaining bytes from the stream until the end is reached /// /// - public virtual async Task ReadBytesFullAsync() => await ReaderContext.ReadBytesFullAsync(); + public virtual async Task ReadBytesFullAsync(CancellationToken cancellationToken = default) => await ReaderContext.ReadBytesFullAsync(cancellationToken); /// /// Read a terminated string from the stream @@ -280,16 +281,18 @@ protected async Task ReadBytesNormalisedBigEndianAsync(int count) /// True to include the terminator in the returned string /// True to consume the terminator byte before returning /// True to throw an error when the EOS was reached before the terminator + /// /// public async Task ReadBytesTermAsync(byte terminator, bool includeTerminator, bool consumeTerminator, - bool eosError) + bool eosError, + CancellationToken cancellationToken = default) { var bytes = new List(); while (true) { - if (IsEof) + if (await IsEofAsync(cancellationToken)) { if (eosError) { @@ -300,7 +303,7 @@ public async Task ReadBytesTermAsync(byte terminator, break; } - byte b = await ReadU1Async(); + byte b = await ReadU1Async(cancellationToken); if (b == terminator) { if (includeTerminator) @@ -310,7 +313,7 @@ public async Task ReadBytesTermAsync(byte terminator, if (!consumeTerminator) { - await SeekAsync(Pos - 1); + await SeekAsync(Pos - 1, cancellationToken); } break; @@ -326,10 +329,11 @@ public async Task ReadBytesTermAsync(byte terminator, /// Read a specific set of bytes and assert that they are the same as an expected result /// /// The expected result + /// /// - public async Task EnsureFixedContentsAsync(byte[] expected) + public async Task EnsureFixedContentsAsync(byte[] expected, CancellationToken cancellationToken = default) { - var bytes = await ReadBytesAsync(expected.Length); + var bytes = await ReadBytesAsync(expected.Length, cancellationToken); if (bytes.Length != expected.Length) //TODO Is this necessary? { diff --git a/Kaitai.Struct.Runtime.Async/ReaderContext/PipeReaderContext.cs b/Kaitai.Struct.Runtime.Async/ReaderContext/PipeReaderContext.cs index e4731b4..97f832f 100644 --- a/Kaitai.Struct.Runtime.Async/ReaderContext/PipeReaderContext.cs +++ b/Kaitai.Struct.Runtime.Async/ReaderContext/PipeReaderContext.cs @@ -2,6 +2,7 @@ using System.Buffers; using System.IO; using System.IO.Pipelines; +using System.Threading; using System.Threading.Tasks; namespace Kaitai.Async @@ -20,28 +21,28 @@ public PipeReaderContext(PipeReader pipeReader) public long Position { get; protected set; } - public virtual async ValueTask GetSizeAsync() + public virtual async ValueTask GetSizeAsync(CancellationToken cancellationToken = default) { - await FillReadResultBufferToTheEnd(); + await FillReadResultBufferToTheEndAsync(cancellationToken); return ReadResult.Buffer.Length; } - public virtual async ValueTask IsEofAsync() + public virtual async ValueTask IsEofAsync(CancellationToken cancellationToken = default) { - await EnsureReadResultIsNotDefault(); + await EnsureReadResultIsNotDefaultAsync(cancellationToken); if (Position >= ReadResult.Buffer.Length && !ReadResult.IsCompleted) { PipeReader.AdvanceTo(ReadResult.Buffer.Start, ReadResult.Buffer.GetPosition(Position)); - ReadResult = await PipeReader.ReadAsync(); + ReadResult = await PipeReader.ReadAsync(cancellationToken); } return Position >= ReadResult.Buffer.Length && ReadResult.IsCompleted; } - public virtual async ValueTask SeekAsync(long position) + public virtual async ValueTask SeekAsync(long position, CancellationToken cancellationToken = default) { if (position <= Position) { @@ -49,12 +50,12 @@ public virtual async ValueTask SeekAsync(long position) } else { - await EnsureReadResultIsNotDefault(); + await EnsureReadResultIsNotDefaultAsync(cancellationToken); while (ReadResult.Buffer.Length < position && !ReadResult.IsCompleted) { PipeReader.AdvanceTo(ReadResult.Buffer.Start, ReadResult.Buffer.End); - ReadResult = await PipeReader.ReadAsync(); + ReadResult = await PipeReader.ReadAsync(cancellationToken); } if (ReadResult.Buffer.Length <= position) @@ -71,15 +72,15 @@ public virtual async ValueTask SeekAsync(long position) } } - public virtual async ValueTask ReadByteAsync() + public virtual async ValueTask ReadByteAsync(CancellationToken cancellationToken = default) { - await EnsureReadResultIsNotDefault(); + await EnsureReadResultIsNotDefaultAsync(cancellationToken); var value = byte.MinValue; while (!TryReadByte(out value) && !ReadResult.IsCompleted) { PipeReader.AdvanceTo(ReadResult.Buffer.Start, ReadResult.Buffer.GetPosition(Position)); - ReadResult = await PipeReader.ReadAsync(); + ReadResult = await PipeReader.ReadAsync(cancellationToken); } Position += 1; @@ -93,7 +94,7 @@ bool TryReadByte(out byte readValue) } } - public virtual async ValueTask ReadBytesAsync(long count) + public virtual async ValueTask ReadBytesAsync(long count, CancellationToken cancellationToken = default) { if (count < 0 || count > int.MaxValue) { @@ -101,7 +102,7 @@ public virtual async ValueTask ReadBytesAsync(long count) $"requested {count} bytes, while only non-negative int32 amount of bytes possible"); } - await EnsureReadResultIsNotDefault(); + await EnsureReadResultIsNotDefaultAsync(cancellationToken); byte[] value = null; @@ -114,7 +115,7 @@ public virtual async ValueTask ReadBytesAsync(long count) } PipeReader.AdvanceTo(ReadResult.Buffer.Start, ReadResult.Buffer.GetPosition(Position)); - ReadResult = await PipeReader.ReadAsync(); + ReadResult = await PipeReader.ReadAsync(cancellationToken); } Position += count; @@ -133,9 +134,9 @@ bool TryRead(out byte[] readBytes, long readBytesCount) } } - public virtual async ValueTask ReadBytesFullAsync() + public virtual async ValueTask ReadBytesFullAsync(CancellationToken cancellationToken = default) { - await FillReadResultBufferToTheEnd(); + await FillReadResultBufferToTheEndAsync(cancellationToken); PipeReader.AdvanceTo(ReadResult.Buffer.Start, ReadResult.Buffer.End); var value = ReadResult.Buffer.Slice(Position, ReadResult.Buffer.End).ToArray(); @@ -143,22 +144,22 @@ public virtual async ValueTask ReadBytesFullAsync() return value; } - private async ValueTask FillReadResultBufferToTheEnd() + private async ValueTask FillReadResultBufferToTheEndAsync(CancellationToken cancellationToken = default) { - await EnsureReadResultIsNotDefault(); + await EnsureReadResultIsNotDefaultAsync(cancellationToken); while (!ReadResult.IsCompleted) { PipeReader.AdvanceTo(ReadResult.Buffer.Start, ReadResult.Buffer.End); - ReadResult = await PipeReader.ReadAsync(); + ReadResult = await PipeReader.ReadAsync(cancellationToken); } } - private async ValueTask EnsureReadResultIsNotDefault() + private async ValueTask EnsureReadResultIsNotDefaultAsync(CancellationToken cancellationToken = default) { if (ReadResult.Equals(default(ReadResult))) { - ReadResult = await PipeReader.ReadAsync(); + ReadResult = await PipeReader.ReadAsync(cancellationToken); } } } diff --git a/Kaitai.Struct.Runtime.Async/ReaderContext/StreamReaderContext.cs b/Kaitai.Struct.Runtime.Async/ReaderContext/StreamReaderContext.cs index 541680a..6eeddf5 100644 --- a/Kaitai.Struct.Runtime.Async/ReaderContext/StreamReaderContext.cs +++ b/Kaitai.Struct.Runtime.Async/ReaderContext/StreamReaderContext.cs @@ -1,5 +1,6 @@ using System; using System.IO; +using System.Threading; using System.Threading.Tasks; using Overby.Extensions.AsyncBinaryReaderWriter; @@ -17,20 +18,32 @@ public StreamReaderContext(Stream stream) } public long Position => _baseStream.Position; - public virtual ValueTask GetSizeAsync() => new ValueTask(_baseStream.Length); - public virtual ValueTask IsEofAsync() => - new ValueTask(_baseStream.Position >= _baseStream.Length); + public virtual async ValueTask GetSizeAsync(CancellationToken cancellationToken = default) + { + await CheckIsCancellationRequested(cancellationToken); + + return _baseStream.Length; + } + + public virtual async ValueTask IsEofAsync(CancellationToken cancellationToken = default) + { + await CheckIsCancellationRequested(cancellationToken); + + return _baseStream.Position >= _baseStream.Length; + } - public virtual ValueTask SeekAsync(long position) + public virtual async ValueTask SeekAsync(long position, CancellationToken cancellationToken = default) { + await CheckIsCancellationRequested(cancellationToken); + _baseStream.Seek(position, SeekOrigin.Begin); - return new ValueTask(); } - public virtual async ValueTask ReadByteAsync() => (byte) await AsyncBinaryReader.ReadSByteAsync(); + public virtual async ValueTask ReadByteAsync(CancellationToken cancellationToken = default) => + (byte) await AsyncBinaryReader.ReadSByteAsync(cancellationToken); - public virtual async ValueTask ReadBytesAsync(long count) + public virtual async ValueTask ReadBytesAsync(long count, CancellationToken cancellationToken = default) { if (count < 0 || count > int.MaxValue) { @@ -38,7 +51,9 @@ public virtual async ValueTask ReadBytesAsync(long count) $"requested {count} bytes, while only non-negative int32 amount of bytes possible"); } - var bytes = await AsyncBinaryReader.ReadBytesAsync((int) count); + await CheckIsCancellationRequested(cancellationToken); + + var bytes = await AsyncBinaryReader.ReadBytesAsync((int) count, cancellationToken); if (bytes.Length < count) { throw new EndOfStreamException($"requested {count} bytes, but got only {bytes.Length} bytes"); @@ -47,7 +62,15 @@ public virtual async ValueTask ReadBytesAsync(long count) return bytes; } - public virtual async ValueTask ReadBytesFullAsync() => - await ReadBytesAsync(_baseStream.Length - _baseStream.Position); + public virtual async ValueTask ReadBytesFullAsync(CancellationToken cancellationToken = default) => + await ReadBytesAsync(_baseStream.Length - _baseStream.Position, cancellationToken); + + private static async Task CheckIsCancellationRequested(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + await Task.FromCanceled(cancellationToken); + } + } } } \ No newline at end of file