From b23d06c5652084a38c5ea6ecec508d31ae7649a7 Mon Sep 17 00:00:00 2001
From: Jose Venegas Vega <jose.venegasvega@snowflake.com>
Date: Thu, 23 Nov 2023 11:26:40 -0600
Subject: [PATCH] SCT-7633 Adding
 com.snowflake.snowpark.Session.SessionBuilder.GetOrCreate

---
 .../com/snowflake/snowpark_java/SessionBuilder.java    | 10 ++++++++++
 src/main/scala/com/snowflake/snowpark/Session.scala    | 10 ++++++++++
 .../com/snowflake/snowpark_test/JavaSessionSuite.java  |  8 ++++++++
 .../com/snowflake/snowpark_test/SessionSuite.scala     |  5 +++++
 4 files changed, 33 insertions(+)

diff --git a/src/main/java/com/snowflake/snowpark_java/SessionBuilder.java b/src/main/java/com/snowflake/snowpark_java/SessionBuilder.java
index 344266c6..210a426e 100644
--- a/src/main/java/com/snowflake/snowpark_java/SessionBuilder.java
+++ b/src/main/java/com/snowflake/snowpark_java/SessionBuilder.java
@@ -70,4 +70,14 @@ public Session create() {
     this.builder.config("snowpark_enable_closure_cleaner", "never");
     return new Session(builder.create());
   }
+
+  /**
+   * Returns the existing session if already exists or create it if not.
+   *
+   * @return A {@code Session} object
+   * @since 1.10.0
+   */
+  public Session getOrCreate() {
+    return new Session(this.builder.getOrCreate());
+  }
 }
diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala
index ea1e5913..aa8058d0 100644
--- a/src/main/scala/com/snowflake/snowpark/Session.scala
+++ b/src/main/scala/com/snowflake/snowpark/Session.scala
@@ -1455,6 +1455,16 @@ object Session extends Logging {
       createInternal(None)
     }
 
+    /**
+     * Returns the existing session if already exists or create it if not.
+     *
+     * @return A [[Session]]
+     * @since 1.10.0
+     */
+    def getOrCreate: Session = {
+      Session.getActiveSession.getOrElse(create)
+    }
+
     private[snowpark] def createInternal(conn: Option[SnowflakeConnectionV1]): Session = {
       conn match {
         case Some(_) =>
diff --git a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java
index a0d7de57..068e9f5e 100644
--- a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java
+++ b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java
@@ -81,6 +81,14 @@ public void flatten() {
         new Row[] {Row.create("1"), Row.create("2")});
   }
 
+  @Test
+  public void getOrCreate()
+  {
+    String expectedSession = getSession().getSessionInfo();
+    String actualSession = Session.builder().getOrCreate().getSessionInfo();
+    assert(expectedSession.equals(actualSession));
+  }
+
   @Test
   public void getSessionInfo() {
     String result = getSession().getSessionInfo();
diff --git a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala
index 2803378c..71ed1aac 100644
--- a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala
@@ -49,6 +49,11 @@ class SessionSuite extends SNTestBase {
     t2.run()
   }
 
+  test("Test for get or create session") {
+    val actualSession = Session.builder.getOrCreate
+    assert(actualSession == session)
+  }
+
   test("Test for invalid configs") {
     val badSessionBuilder = Session.builder
       .configFile(defaultProfile)