From fe8e84810d15f8b9f8e0f3a5aa577906fabeb351 Mon Sep 17 00:00:00 2001 From: sfc-gh-ext-simba-lf <115584722+sfc-gh-ext-simba-lf@users.noreply.github.com> Date: Mon, 28 Oct 2024 09:58:47 -0700 Subject: [PATCH] SNOW-1657238: Fix incorrect row count for rows loaded (#1044) --- .../IntegrationTests/SFDbCommandIT.cs | 90 +++++++++++++++++++ Snowflake.Data/Core/ResultSetUtil.cs | 8 +- 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs index 5aa01ee46..5950e4f9b 100755 --- a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs @@ -8,6 +8,8 @@ using System.Threading; using System.Threading.Tasks; using Snowflake.Data.Core; +using System.Linq; +using System.IO; namespace Snowflake.Data.Tests.IntegrationTests { @@ -1674,5 +1676,93 @@ public async Task TestCommandWithCommentEmbeddedAsync() Assert.AreEqual("--", reader.GetString(0)); } } + + [Test] + public void TestExecuteNonQueryReturnsCorrectRowCountForUploadWithMultipleFiles() + { + const int NumberOfFiles = 5; + const int NumberOfRows = 3; + const int ExpectedRowCount = NumberOfFiles * NumberOfRows; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + var tempFolder = $"{Path.GetTempPath()}Temp_{Guid.NewGuid()}"; + + try + { + // Arrange + Directory.CreateDirectory(tempFolder); + var data = string.Concat(Enumerable.Repeat(string.Join(",", "TestData") + "\n", NumberOfRows)); + for (int i = 0; i < NumberOfFiles; i++) + { + File.WriteAllText(Path.Combine(tempFolder, $"{TestContext.CurrentContext.Test.Name}_{i}.csv"), data); + } + CreateOrReplaceTable(conn, TableName, new[] { "COL1 STRING" }); + cmd.CommandText = $"PUT file://{Path.Combine(tempFolder, "*.csv")} @%{TableName} AUTO_COMPRESS=FALSE"; + var reader = cmd.ExecuteReader(); + + // Act + cmd.CommandText = $"COPY INTO {TableName} FROM @%{TableName} PATTERN='.*.csv' FILE_FORMAT=(TYPE=CSV)"; + int actualRowCount = cmd.ExecuteNonQuery(); + + // Assert + Assert.AreEqual(ExpectedRowCount, actualRowCount); + } + finally + { + Directory.Delete(tempFolder, true); + } + } + } + } + + [Test] + public async Task TestExecuteNonQueryAsyncReturnsCorrectRowCountForUploadWithMultipleFiles() + { + const int NumberOfFiles = 5; + const int NumberOfRows = 3; + const int ExpectedRowCount = NumberOfFiles * NumberOfRows; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + var tempFolder = $"{Path.GetTempPath()}Temp_{Guid.NewGuid()}"; + + try + { + // Arrange + Directory.CreateDirectory(tempFolder); + var data = string.Concat(Enumerable.Repeat(string.Join(",", "TestData") + "\n", NumberOfRows)); + for (int i = 0; i < NumberOfFiles; i++) + { + File.WriteAllText(Path.Combine(tempFolder, $"{TestContext.CurrentContext.Test.Name}_{i}.csv"), data); + } + CreateOrReplaceTable(conn, TableName, new[] { "COL1 STRING" }); + cmd.CommandText = $"PUT file://{Path.Combine(tempFolder, "*.csv")} @%{TableName} AUTO_COMPRESS=FALSE"; + var reader = cmd.ExecuteReader(); + + // Act + cmd.CommandText = $"COPY INTO {TableName} FROM @%{TableName} PATTERN='.*.csv' FILE_FORMAT=(TYPE=CSV)"; + int actualRowCount = await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); + + // Assert + Assert.AreEqual(ExpectedRowCount, actualRowCount); + } + finally + { + Directory.Delete(tempFolder, true); + } + } + } + } } } diff --git a/Snowflake.Data/Core/ResultSetUtil.cs b/Snowflake.Data/Core/ResultSetUtil.cs index 9d62a17d7..236efab9c 100755 --- a/Snowflake.Data/Core/ResultSetUtil.cs +++ b/Snowflake.Data/Core/ResultSetUtil.cs @@ -36,9 +36,11 @@ internal static int CalculateUpdateCount(this SFBaseResultSet resultSet) var index = resultSet.sfResultSetMetaData.GetColumnIndexByName("rows_loaded"); if (index >= 0) { - resultSet.Next(); - updateCount = resultSet.GetInt64(index); - resultSet.Rewind(); + while (resultSet.Next()) + { + updateCount += resultSet.GetInt64(index); + } + while (resultSet.Rewind()) {} } break; case SFStatementType.COPY_UNLOAD: