diff --git a/src/main/java/com/snowflake/snowpark_java/SessionBuilder.java b/src/main/java/com/snowflake/snowpark_java/SessionBuilder.java index 344266c6..210a426e 100644 --- a/src/main/java/com/snowflake/snowpark_java/SessionBuilder.java +++ b/src/main/java/com/snowflake/snowpark_java/SessionBuilder.java @@ -70,4 +70,14 @@ public Session create() { this.builder.config("snowpark_enable_closure_cleaner", "never"); return new Session(builder.create()); } + + /** + * Returns the existing session if already exists or create it if not. + * + * @return A {@code Session} object + * @since 1.10.0 + */ + public Session getOrCreate() { + return new Session(this.builder.getOrCreate()); + } } diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala index ea1e5913..aa8058d0 100644 --- a/src/main/scala/com/snowflake/snowpark/Session.scala +++ b/src/main/scala/com/snowflake/snowpark/Session.scala @@ -1455,6 +1455,16 @@ object Session extends Logging { createInternal(None) } + /** + * Returns the existing session if already exists or create it if not. + * + * @return A [[Session]] + * @since 1.10.0 + */ + def getOrCreate: Session = { + Session.getActiveSession.getOrElse(create) + } + private[snowpark] def createInternal(conn: Option[SnowflakeConnectionV1]): Session = { conn match { case Some(_) => diff --git a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java index a0d7de57..068e9f5e 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java @@ -81,6 +81,14 @@ public void flatten() { new Row[] {Row.create("1"), Row.create("2")}); } + @Test + public void getOrCreate() + { + String expectedSession = getSession().getSessionInfo(); + String actualSession = Session.builder().getOrCreate().getSessionInfo(); + assert(expectedSession.equals(actualSession)); + } + @Test public void getSessionInfo() { String result = getSession().getSessionInfo(); diff --git a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala index 2803378c..71ed1aac 100644 --- a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala @@ -49,6 +49,11 @@ class SessionSuite extends SNTestBase { t2.run() } + test("Test for get or create session") { + val actualSession = Session.builder.getOrCreate + assert(actualSession == session) + } + test("Test for invalid configs") { val badSessionBuilder = Session.builder .configFile(defaultProfile)