Skip to content

Commit

Permalink
Merge pull request #52 from snowflakedb/mock-session-renew
Browse files Browse the repository at this point in the history
Add mock tests to test session renew
  • Loading branch information
howryu authored Jul 19, 2018
2 parents 1437de6 + 5f76f73 commit 470c1d3
Show file tree
Hide file tree
Showing 8 changed files with 319 additions and 5 deletions.
120 changes: 120 additions & 0 deletions Snowflake.Data.Tests/Mock/MockRestSessionExpired.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using System.Threading;
using System.Net.Http;

namespace Snowflake.Data.Tests.Mock
{
using Snowflake.Data.Core;

class MockRestSessionExpired : IRestRequest
{
static private readonly String EXPIRED_SESSION_TOKEN="session_expired_token";

static private readonly String TOKEN_FMT = "Snowflake Token=\"{0}\"";

static private readonly int SESSION_EXPIRED_CODE = 390112;

public MockRestSessionExpired() { }

public Task<T> PostAsync<T>(SFRestRequest postRequest, CancellationToken cancellationToken)
{
if (postRequest.jsonBody is AuthnRequest)
{
AuthnResponse authnResponse = new AuthnResponse
{
data = new AuthnResponseData()
{
token = EXPIRED_SESSION_TOKEN,
masterToken = "master_token",
authResponseSessionInfo = new SessionInfo(),
nameValueParameter = new List<NameValueParameter>()
},
success = true
};

// login request return success
return Task.FromResult<T>((T)(object)authnResponse);
}
else if (postRequest.jsonBody is QueryRequest)
{
if (postRequest.authorizationToken.Equals(String.Format(TOKEN_FMT, EXPIRED_SESSION_TOKEN)))
{
QueryExecResponse queryExecResponse = new QueryExecResponse
{
success = false,
code = SESSION_EXPIRED_CODE
};
return Task.FromResult<T>((T)(object)queryExecResponse);
}
else if (postRequest.authorizationToken.Equals(String.Format(TOKEN_FMT, "new_session_token")))
{
QueryExecResponse queryExecResponse = new QueryExecResponse
{
success = true,
data = new QueryExecResponseData
{
rowSet = new string[,] { { "1" } },
rowType = new List<ExecResponseRowType>()
{
new ExecResponseRowType
{
name = "colone",
type = "FIXED"
}
},
parameters = new List<NameValueParameter>()
}
};
return Task.FromResult<T>((T)(object)queryExecResponse);
}
else
{
QueryExecResponse queryExecResponse = new QueryExecResponse
{
success = false,
code = 1
};
return Task.FromResult<T>((T)(object)queryExecResponse);
}
}
else if (postRequest.jsonBody is RenewSessionRequest)
{
return Task.FromResult<T>((T)(object)new RenewSessionResponse
{
success = true,
data = new RenewSessionResponseData()
{
sessionToken = "new_session_token"
}
});
}
else
{
return Task.FromResult<T>((T)(object)null);
}
}

public T Post<T>(SFRestRequest postRequest)
{
return Task.Run(async () => await PostAsync<T>(postRequest, CancellationToken.None)).Result;
}

public T Get<T>(SFRestRequest request)
{
return Task.Run(async () => await GetAsync<T>(request, CancellationToken.None)).Result;
}

public Task<T> GetAsync<T>(SFRestRequest request, CancellationToken cancellationToken)
{
return Task.FromResult<T>((T)(object)null);
}

public Task<HttpResponseMessage> GetAsync(S3DownloadRequest request, CancellationToken cancellationToken)
{
return Task.FromResult<HttpResponseMessage>(null);
}
}
}
139 changes: 139 additions & 0 deletions Snowflake.Data.Tests/Mock/MockRestSessionExpiredInQueryExec.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using System.Net.Http;

namespace Snowflake.Data.Tests.Mock
{
using Snowflake.Data.Core;

class MockRestSessionExpiredInQueryExec : IRestRequest
{
static private readonly int QUERY_IN_EXEC_CODE = 333333;

static private readonly int SESSION_EXPIRED_CODE = 390112;

private int getResultCallCount = 0;

public MockRestSessionExpiredInQueryExec() { }

public Task<T> PostAsync<T>(SFRestRequest postRequest, CancellationToken cancellationToken)
{
if (postRequest.jsonBody is AuthnRequest)
{
AuthnResponse authnResponse = new AuthnResponse
{
data = new AuthnResponseData()
{
token = "session_token",
masterToken = "master_token",
authResponseSessionInfo = new SessionInfo(),
nameValueParameter = new List<NameValueParameter>()
},
success = true
};

// login request return success
return Task.FromResult<T>((T)(object)authnResponse);
}
else if (postRequest.jsonBody is QueryRequest)
{
QueryExecResponse queryExecResponse = new QueryExecResponse
{
success = false,
code = QUERY_IN_EXEC_CODE
};
return Task.FromResult<T>((T)(object)queryExecResponse);

}
else if (postRequest.jsonBody is RenewSessionRequest)
{
return Task.FromResult<T>((T)(object)new RenewSessionResponse
{
success = true,
data = new RenewSessionResponseData()
{
sessionToken = "new_session_token"
}
});
}
else
{
return Task.FromResult<T>((T)(object)null);
}
}

public T Post<T>(SFRestRequest postRequest)
{
return Task.Run(async () => await PostAsync<T>(postRequest, CancellationToken.None)).Result;
}

public T Get<T>(SFRestRequest request)
{
return Task.Run(async () => await GetAsync<T>(request, CancellationToken.None)).Result;
}

public Task<T> GetAsync<T>(SFRestRequest request, CancellationToken cancellationToken)
{
if (getResultCallCount == 0)
{
getResultCallCount++;
QueryExecResponse queryExecResponse = new QueryExecResponse
{
success = false,
code = QUERY_IN_EXEC_CODE
};
return Task.FromResult<T>((T)(object)queryExecResponse);
}
else if (getResultCallCount == 1)
{
getResultCallCount++;
QueryExecResponse queryExecResponse = new QueryExecResponse
{
success = false,
code = SESSION_EXPIRED_CODE
};
return Task.FromResult<T>((T)(object)queryExecResponse);
}
else if (getResultCallCount == 2 &&
request.authorizationToken.Equals("Snowflake Token=\"new_session_token\""))
{
getResultCallCount++;
QueryExecResponse queryExecResponse = new QueryExecResponse
{
success = true,
data = new QueryExecResponseData
{
rowSet = new string[,] { { "1" } },
rowType = new List<ExecResponseRowType>()
{
new ExecResponseRowType
{
name = "colone",
type = "FIXED"
}
},
parameters = new List<NameValueParameter>()
}
};
return Task.FromResult<T>((T)(object)queryExecResponse);
}
else
{
QueryExecResponse queryExecResponse = new QueryExecResponse
{
success = false,
code = 1
};
return Task.FromResult<T>((T)(object)queryExecResponse);
}
}

public Task<HttpResponseMessage> GetAsync(S3DownloadRequest request, CancellationToken cancellationToken)
{
return Task.FromResult<HttpResponseMessage>(null);
}
}
}

41 changes: 41 additions & 0 deletions Snowflake.Data.Tests/SFStatementTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (c) 2012-2017 Snowflake Computing Inc. All rights reserved.
*/

namespace Snowflake.Data.Tests
{
using Snowflake.Data.Core;
using NUnit.Framework;
using System;

/**
* Mock rest request to test session renew
*/
[TestFixture]
class SFStatementTest
{
[Test]
public void TestSessionRenew()
{
Mock.MockRestSessionExpired rest = new Mock.MockRestSessionExpired();
SFSession sfSession = new SFSession("account=test;user=test;password=test", null, rest);
sfSession.Open();
SFStatement statement = new SFStatement(sfSession, rest);
SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false);
Assert.AreEqual(true, resultSet.Next());
Assert.AreEqual("1", resultSet.GetString(0));
}

[Test]
public void TestSessionRenewDuringQueyrExec()
{
Mock.MockRestSessionExpiredInQueryExec rest = new Mock.MockRestSessionExpiredInQueryExec();
SFSession sfSession = new SFSession("account=test;user=test;password=test", null, rest);
sfSession.Open();
SFStatement statement = new SFStatement(sfSession, rest);
SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false);
Assert.AreEqual(true, resultSet.Next());
Assert.AreEqual("1", resultSet.GetString(0));
}
}
}
5 changes: 5 additions & 0 deletions Snowflake.Data.Tests/Snowflake.Data.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
<Folder Include="Properties\" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net46'">
<Reference Include="System.Net.Http" />
<Reference Include="System.Web" />
</ItemGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<DebugType>full</DebugType>
<DebugSymbols>True</DebugSymbols>
Expand Down
9 changes: 7 additions & 2 deletions Snowflake.Data/Core/SFSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,14 @@ static SFSession()
/// Constructor
/// </summary>
/// <param name="connectionString">A string in the form of "key1=value1;key2=value2"</param>
internal SFSession(String connectionString, SecureString password)
internal SFSession(String connectionString, SecureString password) :
this(connectionString, password, RestRequestImpl.Instance)
{
restRequest = RestRequestImpl.Instance;
}

internal SFSession(String connectionString, SecureString password, IRestRequest restRequest)
{
this.restRequest = restRequest;
properties = SFSessionProperties.parseConnectionString(connectionString, password);

parameterMap = new Dictionary<string, string>();
Expand Down
7 changes: 5 additions & 2 deletions Snowflake.Data/Core/SFStatement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,15 @@ class SFStatement
// Cancel callback will be registered under token issued by this source.
private CancellationTokenSource _linkedCancellationTokenSouce;

internal SFStatement(SFSession session)
internal SFStatement(SFSession session, IRestRequest rest)
{
SfSession = session;
_restRequest = RestRequestImpl.Instance;
_restRequest = rest;
}

internal SFStatement(SFSession session) : this(session, RestRequestImpl.Instance)
{ }

private void AssignQueryRequestId()
{
lock (_requestIdLock)
Expand Down
2 changes: 2 additions & 0 deletions Snowflake.Data/Core/Snowflake.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
using System.Runtime.CompilerServices;
[assembly: InternalsVisibleTo("Snowflake.Data.Tests")]
1 change: 0 additions & 1 deletion Snowflake.Data/Snowflake.Data.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@

<ItemGroup>
<Folder Include="Properties\" />
<Folder Include="Core\" />
<Folder Include="Client\" />
</ItemGroup>
</Project>

0 comments on commit 470c1d3

Please sign in to comment.