Skip to content

Commit

Permalink
SCT-7633 Adding com.snowflake.snowpark.Session.SessionBuilder.getOrCr…
Browse files Browse the repository at this point in the history
…eate (#67)

* SCT-7633 Adding com.snowflake.snowpark.Session.SessionBuilder.GetOrCreate

* SCT-7633 Adding com.snowflake.snowpark.Session.SessionBuilder.GetOrCreate

* Updating java GetOrCreate test variable names

* Updating getOrCreate for java

* Solving scala failing test
  • Loading branch information
sfc-gh-jvenegasvega-1 authored Nov 28, 2023
1 parent 729d50f commit 8a8406c
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/SessionBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,14 @@ public Session create() {
this.builder.config("snowpark_enable_closure_cleaner", "never");
return new Session(builder.create());
}

/**
* Returns the existing session if already exists or create it if not.
*
* @return A {@code Session} object
* @since 1.10.0
*/
public Session getOrCreate() {
return new Session(this.builder.getOrCreate());
}
}
10 changes: 10 additions & 0 deletions src/main/scala/com/snowflake/snowpark/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1455,6 +1455,16 @@ object Session extends Logging {
createInternal(None)
}

/**
* Returns the existing session if already exists or create it if not.
*
* @return A [[Session]]
* @since 1.10.0
*/
def getOrCreate: Session = {
Session.getActiveSession.getOrElse(create)
}

private[snowpark] def createInternal(conn: Option[SnowflakeConnectionV1]): Session = {
conn match {
case Some(_) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ public void flatten() {
new Row[] {Row.create("1"), Row.create("2")});
}

@Test
public void getOrCreate()
{
String expectedSessionInfo = getSession().getSessionInfo();
String actualSessionInfo = Session.builder().getOrCreate().getSessionInfo();
assert(actualSessionInfo.equals(expectedSessionInfo));
}

@Test
public void getSessionInfo() {
String result = getSession().getSessionInfo();
Expand Down
6 changes: 6 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ class SessionSuite extends SNTestBase {
t2.run()
}

test("Test for get or create session") {
val session1 = Session.builder.getOrCreate
val session2 = Session.builder.getOrCreate
assert(session1 == session2)
}

test("Test for invalid configs") {
val badSessionBuilder = Session.builder
.configFile(defaultProfile)
Expand Down

0 comments on commit 8a8406c

Please sign in to comment.