Skip to content

Commit

Permalink
SCT-7633 Adding com.snowflake.snowpark.Session.SessionBuilder.GetOrCr…
Browse files Browse the repository at this point in the history
…eate
  • Loading branch information
sfc-gh-jvenegasvega-1 committed Nov 27, 2023
1 parent 729d50f commit b23d06c
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/SessionBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,14 @@ public Session create() {
this.builder.config("snowpark_enable_closure_cleaner", "never");
return new Session(builder.create());
}

/**
* Returns the existing session if already exists or create it if not.
*
* @return A {@code Session} object
* @since 1.10.0
*/
public Session getOrCreate() {
return new Session(this.builder.getOrCreate());
}
}
10 changes: 10 additions & 0 deletions src/main/scala/com/snowflake/snowpark/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1455,6 +1455,16 @@ object Session extends Logging {
createInternal(None)
}

/**
* Returns the existing session if already exists or create it if not.
*
* @return A [[Session]]
* @since 1.10.0
*/
def getOrCreate: Session = {
Session.getActiveSession.getOrElse(create)
}

private[snowpark] def createInternal(conn: Option[SnowflakeConnectionV1]): Session = {
conn match {
case Some(_) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ public void flatten() {
new Row[] {Row.create("1"), Row.create("2")});
}

@Test
public void getOrCreate()
{
String expectedSession = getSession().getSessionInfo();
String actualSession = Session.builder().getOrCreate().getSessionInfo();
assert(expectedSession.equals(actualSession));
}

@Test
public void getSessionInfo() {
String result = getSession().getSessionInfo();
Expand Down
5 changes: 5 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class SessionSuite extends SNTestBase {
t2.run()
}

test("Test for get or create session") {
val actualSession = Session.builder.getOrCreate
assert(actualSession == session)
}

test("Test for invalid configs") {
val badSessionBuilder = Session.builder
.configFile(defaultProfile)
Expand Down

0 comments on commit b23d06c

Please sign in to comment.