Skip to content

Commit

Permalink
Merge branch 'main' into snow-966360
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bli committed Nov 28, 2023
2 parents 6e693fe + 8a8406c commit ddfe1f2
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/SessionBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,14 @@ public Session create() {
this.builder.config("snowpark_enable_closure_cleaner", "never");
return new Session(builder.create());
}

/**
* Returns the existing session if already exists or create it if not.
*
* @return A {@code Session} object
* @since 1.10.0
*/
public Session getOrCreate() {
return new Session(this.builder.getOrCreate());
}
}
10 changes: 10 additions & 0 deletions src/main/scala/com/snowflake/snowpark/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1455,6 +1455,16 @@ object Session extends Logging {
createInternal(None)
}

/**
* Returns the existing session if already exists or create it if not.
*
* @return A [[Session]]
* @since 1.10.0
*/
def getOrCreate: Session = {
Session.getActiveSession.getOrElse(create)
}

private[snowpark] def createInternal(conn: Option[SnowflakeConnectionV1]): Session = {
conn match {
case Some(_) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ public void flatten() {
new Row[] {Row.create("1"), Row.create("2")});
}

@Test
public void getOrCreate()
{
String expectedSessionInfo = getSession().getSessionInfo();
String actualSessionInfo = Session.builder().getOrCreate().getSessionInfo();
assert(actualSessionInfo.equals(expectedSessionInfo));
}

@Test
public void getSessionInfo() {
String result = getSession().getSessionInfo();
Expand Down
6 changes: 6 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ class SessionSuite extends SNTestBase {
t2.run()
}

test("Test for get or create session") {
val session1 = Session.builder.getOrCreate
val session2 = Session.builder.getOrCreate
assert(session1 == session2)
}

test("Test for invalid configs") {
val badSessionBuilder = Session.builder
.configFile(defaultProfile)
Expand Down

0 comments on commit ddfe1f2

Please sign in to comment.