diff --git a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs index 58bf90b46..6e92c0aac 100755 --- a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs @@ -60,6 +60,51 @@ public void TestExecAsyncAPI() } } + [Test] + public void TestExecAsyncAPIParallel() + { + SnowflakeDbConnectionPool.ClearAllPools(); + using (DbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + + Task connectTask = conn.OpenAsync(CancellationToken.None); + connectTask.Wait(); + Assert.AreEqual(ConnectionState.Open, conn.State); + + Task[] taskArray = new Task[5]; + for (int i = 0; i < taskArray.Length; i++) + { + taskArray[i] = Task.Factory.StartNew(() => + { + using (DbCommand cmd = conn.CreateCommand()) + { + long queryResult = 0; + cmd.CommandText = "select count(seq4()) from table(generator(timelimit => 3)) v"; + Task execution = cmd.ExecuteReaderAsync(); + Task readCallback = execution.ContinueWith((t) => + { + using (DbDataReader reader = t.Result) + { + Assert.IsTrue(reader.Read()); + queryResult = reader.GetInt64(0); + Assert.IsFalse(reader.Read()); + } + }); + // query is not finished yet, result is still 0; + Assert.AreEqual(0, queryResult); + // block till query finished + readCallback.Wait(); + // queryResult should be updated by callback + Assert.AreNotEqual(0, queryResult); + } + }); + } + Task.WaitAll(taskArray); + conn.Close(); + } + } + [Test] public void TestCancelExecuteAsync() { diff --git a/Snowflake.Data.Tests/UnitTests/QueryContextCacheTest.cs b/Snowflake.Data.Tests/UnitTests/QueryContextCacheTest.cs index 407ef9c0a..a83293335 100644 --- a/Snowflake.Data.Tests/UnitTests/QueryContextCacheTest.cs +++ b/Snowflake.Data.Tests/UnitTests/QueryContextCacheTest.cs @@ -133,6 +133,25 @@ public void TestMoreThanCapacity() AssertCacheData(); } + [Test] + public void TestChangingCapacity() + { + InitCacheWithData(); + + // Add one more element at the end + int i = MaxCapacity; + _qcc.SetCapacity(MaxCapacity + 1); + _qcc.Merge(BaseId + i, BaseReadTimestamp + i, BasePriority + i, Context); + _qcc.SyncPriorityMap(); + _qcc.CheckCacheCapacity(); + Assert.IsTrue(_qcc.GetSize() == MaxCapacity + 1); + + // reduce the capacity back + _qcc.SetCapacity(MaxCapacity); + // Compare elements + AssertCacheData(); + } + [Test] public void TestUpdateTimestamp() { diff --git a/Snowflake.Data/Core/QueryContextCache.cs b/Snowflake.Data/Core/QueryContextCache.cs index ac2ba6d01..74d041d2b 100644 --- a/Snowflake.Data/Core/QueryContextCache.cs +++ b/Snowflake.Data/Core/QueryContextCache.cs @@ -69,6 +69,7 @@ public int Compare(QueryContextElement x, QueryContextElement y) internal class QueryContextCache { + private readonly object _qccLock; private int _capacity; // Capacity of the cache private Dictionary _idMap; // Map for id and QCC private Dictionary _priorityMap; // Map for priority and QCC @@ -78,6 +79,7 @@ internal class QueryContextCache public QueryContextCache(int capacity) { + _qccLock = new object(); _capacity = capacity; _idMap = new Dictionary(); _priorityMap = new Dictionary(); @@ -192,11 +194,16 @@ public void SetCapacity(int cap) // check without locking first for performance reason if (_capacity == cap) return; + lock (_qccLock) + { + if (_capacity == cap) + return; - _logger.Debug($"set capacity from {_capacity} to {cap}"); - _capacity = cap; - CheckCacheCapacity(); - LogCacheEntries(); + _logger.Debug($"set capacity from {_capacity} to {cap}"); + _capacity = cap; + CheckCacheCapacity(); + LogCacheEntries(); + } } /** @@ -221,26 +228,29 @@ public int GetSize() */ public void Update(ResponseQueryContext queryContext) { - // Log existing cache entries - LogCacheEntries(); - - if (queryContext == null || queryContext.Entries == null) - { - // Clear the cache - ClearCache(); - return; - } - foreach (ResponseQueryContextElement entry in queryContext.Entries) + lock(_qccLock) { - Merge(entry.Id, entry.ReadTimestamp, entry.Priority, entry.Context); - } + // Log existing cache entries + LogCacheEntries(); - SyncPriorityMap(); + if (queryContext == null || queryContext.Entries == null) + { + // Clear the cache + ClearCache(); + return; + } + foreach (ResponseQueryContextElement entry in queryContext.Entries) + { + Merge(entry.Id, entry.ReadTimestamp, entry.Priority, entry.Context); + } + + SyncPriorityMap(); - // After merging all entries, truncate to capacity - CheckCacheCapacity(); - // Log existing cache entries - LogCacheEntries(); + // After merging all entries, truncate to capacity + CheckCacheCapacity(); + // Log existing cache entries + LogCacheEntries(); + } } /** @@ -251,10 +261,13 @@ public RequestQueryContext GetQueryContextRequest() { RequestQueryContext reqQCC = new RequestQueryContext(); reqQCC.Entries = new List(); - foreach (QueryContextElement elem in _cacheSet) + lock(_qccLock) { - RequestQueryContextElement reqElem = new RequestQueryContextElement(elem); - reqQCC.Entries.Add(reqElem); + foreach (QueryContextElement elem in _cacheSet) + { + RequestQueryContextElement reqElem = new RequestQueryContextElement(elem); + reqQCC.Entries.Add(reqElem); + } } return reqQCC; @@ -268,10 +281,13 @@ public ResponseQueryContext GetQueryContextResponse() { ResponseQueryContext rspQCC = new ResponseQueryContext(); rspQCC.Entries = new List(); - foreach (QueryContextElement elem in _cacheSet) + lock (_qccLock) { - ResponseQueryContextElement rspElem = new ResponseQueryContextElement(elem); - rspQCC.Entries.Add(rspElem); + foreach (QueryContextElement elem in _cacheSet) + { + ResponseQueryContextElement rspElem = new ResponseQueryContextElement(elem); + rspQCC.Entries.Add(rspElem); + } } return rspQCC; @@ -321,12 +337,13 @@ private void ReplaceQCE(QueryContextElement oldQCE, QueryContextElement newQCE) /** Debugging purpose, log the all entries in the cache. */ private void LogCacheEntries() { -#if DEBUG - foreach (QueryContextElement elem in _cacheSet) + if (_logger.IsDebugEnabled()) { - _logger.Debug($"Cache Entry: id: {elem.Id} readTimestamp: {elem.ReadTimestamp} priority: {elem.Priority}"); + foreach (QueryContextElement elem in _cacheSet) + { + _logger.Debug($"Cache Entry: id: {elem.Id} readTimestamp: {elem.ReadTimestamp} priority: {elem.Priority}"); + } } -#endif } } }