From 8a8406cddaa01711a6d5f8bfe37bd39fa1829572 Mon Sep 17 00:00:00 2001 From: Jose Venegas <126916083+sfc-gh-jvaenegasvega@users.noreply.github.com> Date: Tue, 28 Nov 2023 11:28:37 -0600 Subject: [PATCH] SCT-7633 Adding com.snowflake.snowpark.Session.SessionBuilder.getOrCreate (#67) * SCT-7633 Adding com.snowflake.snowpark.Session.SessionBuilder.GetOrCreate * SCT-7633 Adding com.snowflake.snowpark.Session.SessionBuilder.GetOrCreate * Updating java GetOrCreate test variable names * Updating getOrCreate for java * Solving scala failing test --- .../com/snowflake/snowpark_java/SessionBuilder.java | 10 ++++++++++ src/main/scala/com/snowflake/snowpark/Session.scala | 10 ++++++++++ .../com/snowflake/snowpark_test/JavaSessionSuite.java | 8 ++++++++ .../com/snowflake/snowpark_test/SessionSuite.scala | 6 ++++++ 4 files changed, 34 insertions(+) diff --git a/src/main/java/com/snowflake/snowpark_java/SessionBuilder.java b/src/main/java/com/snowflake/snowpark_java/SessionBuilder.java index 344266c6..210a426e 100644 --- a/src/main/java/com/snowflake/snowpark_java/SessionBuilder.java +++ b/src/main/java/com/snowflake/snowpark_java/SessionBuilder.java @@ -70,4 +70,14 @@ public Session create() { this.builder.config("snowpark_enable_closure_cleaner", "never"); return new Session(builder.create()); } + + /** + * Returns the existing session if already exists or create it if not. + * + * @return A {@code Session} object + * @since 1.10.0 + */ + public Session getOrCreate() { + return new Session(this.builder.getOrCreate()); + } } diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala index ea1e5913..aa8058d0 100644 --- a/src/main/scala/com/snowflake/snowpark/Session.scala +++ b/src/main/scala/com/snowflake/snowpark/Session.scala @@ -1455,6 +1455,16 @@ object Session extends Logging { createInternal(None) } + /** + * Returns the existing session if already exists or create it if not. + * + * @return A [[Session]] + * @since 1.10.0 + */ + def getOrCreate: Session = { + Session.getActiveSession.getOrElse(create) + } + private[snowpark] def createInternal(conn: Option[SnowflakeConnectionV1]): Session = { conn match { case Some(_) => diff --git a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java index a0d7de57..75bda890 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java @@ -81,6 +81,14 @@ public void flatten() { new Row[] {Row.create("1"), Row.create("2")}); } + @Test + public void getOrCreate() + { + String expectedSessionInfo = getSession().getSessionInfo(); + String actualSessionInfo = Session.builder().getOrCreate().getSessionInfo(); + assert(actualSessionInfo.equals(expectedSessionInfo)); + } + @Test public void getSessionInfo() { String result = getSession().getSessionInfo(); diff --git a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala index 2803378c..6ab7d5db 100644 --- a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala @@ -49,6 +49,12 @@ class SessionSuite extends SNTestBase { t2.run() } + test("Test for get or create session") { + val session1 = Session.builder.getOrCreate + val session2 = Session.builder.getOrCreate + assert(session1 == session2) + } + test("Test for invalid configs") { val badSessionBuilder = Session.builder .configFile(defaultProfile)