diff --git a/spring-ai-core/pom.xml b/spring-ai-core/pom.xml
index e278f72cc3..d0aa33b6c7 100644
--- a/spring-ai-core/pom.xml
+++ b/spring-ai-core/pom.xml
@@ -54,18 +54,6 @@
${jsonschema.version}
-
- org.springframework.cloud
- spring-cloud-function-context
- ${spring-cloud-function-context.version}
-
-
- org.springframework.boot
- spring-boot-autoconfigure
-
-
-
-
org.antlr
@@ -138,6 +126,13 @@
${jackson.version}
+
+ org.jetbrains.kotlin
+ kotlin-stdlib
+ ${kotlin.version}
+ true
+
+
org.springframework.boot
@@ -146,16 +141,16 @@
- org.jetbrains.kotlin
- kotlin-stdlib
- ${kotlin.version}
+ com.fasterxml.jackson.module
+ jackson-module-kotlin
+ ${jackson.version}
test
- com.fasterxml.jackson.module
- jackson-module-kotlin
- ${jackson.version}
+ io.mockk
+ mockk-jvm
+ 1.13.13
test
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java
index ecbb9a4c1c..762d33969b 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java
@@ -16,20 +16,21 @@
package org.springframework.ai.model.function;
-import java.lang.reflect.Type;
import java.util.function.BiFunction;
import java.util.function.Function;
import com.fasterxml.jackson.annotation.JsonClassDescription;
+import kotlin.jvm.functions.Function1;
+import kotlin.jvm.functions.Function2;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.beans.BeansException;
-import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
-import org.springframework.cloud.function.context.config.FunctionContextUtils;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Description;
import org.springframework.context.support.GenericApplicationContext;
+import org.springframework.core.KotlinDetector;
+import org.springframework.core.ResolvableType;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
@@ -49,6 +50,7 @@
*
* @author Christian Tzolov
* @author Christopher Smith
+ * @author Sebastien Deleuze
*/
public class FunctionCallbackContext implements ApplicationContextAware {
@@ -65,26 +67,13 @@ public void setApplicationContext(@NonNull ApplicationContext applicationContext
this.applicationContext = (GenericApplicationContext) applicationContext;
}
- @SuppressWarnings({ "rawtypes", "unchecked" })
+ @SuppressWarnings({ "unchecked" })
public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) {
- Type beanType = FunctionContextUtils.findType(this.applicationContext.getBeanFactory(), beanName);
+ ResolvableType functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName);
+ ResolvableType functionInputType = TypeResolverHelper.getFunctionArgumentType(functionType, 0);
- if (beanType == null) {
- throw new IllegalArgumentException(
- "Functional bean with name: " + beanName + " does not exist in the context.");
- }
-
- if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))
- && !BiFunction.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) {
- throw new IllegalArgumentException(
- "Function call Bean must be of type Function or BiFunction. Found: " + beanType.getTypeName());
- }
-
- Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0);
-
- Class> functionInputClass = FunctionTypeUtils.getRawType(functionInputType);
- String functionName = beanName;
+ Class> functionInputClass = functionInputType.toClass();
String functionDescription = defaultDescription;
if (!StringUtils.hasText(functionDescription)) {
@@ -114,24 +103,42 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable
Object bean = this.applicationContext.getBean(beanName);
+ if (KotlinDetector.isKotlinPresent()) {
+ if (KotlinDelegate.isKotlinFunction(functionType.toClass())) {
+ return FunctionCallbackWrapper.builder(KotlinDelegate.wrapKotlinFunction(bean))
+ .withName(beanName)
+ .withSchemaType(this.schemaType)
+ .withDescription(functionDescription)
+ .withInputType(functionInputClass)
+ .build();
+ }
+ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
+ return FunctionCallbackWrapper.builder(KotlinDelegate.wrapKotlinBiFunction(bean))
+ .withName(beanName)
+ .withSchemaType(this.schemaType)
+ .withDescription(functionDescription)
+ .withInputType(functionInputClass)
+ .build();
+ }
+ }
if (bean instanceof Function, ?> function) {
return FunctionCallbackWrapper.builder(function)
- .withName(functionName)
+ .withName(beanName)
.withSchemaType(this.schemaType)
.withDescription(functionDescription)
.withInputType(functionInputClass)
.build();
}
- else if (bean instanceof BiFunction, ?, ?> biFunction) {
- return FunctionCallbackWrapper.builder((BiFunction, ToolContext, ?>) biFunction)
- .withName(functionName)
+ else if (bean instanceof BiFunction, ?, ?>) {
+ return FunctionCallbackWrapper.builder((BiFunction, ToolContext, ?>) bean)
+ .withName(beanName)
.withSchemaType(this.schemaType)
.withDescription(functionDescription)
.withInputType(functionInputClass)
.build();
}
else {
- throw new IllegalArgumentException("Bean must be of type Function");
+ throw new IllegalStateException();
}
}
@@ -141,4 +148,26 @@ public enum SchemaType {
}
+ private static class KotlinDelegate {
+
+ public static boolean isKotlinFunction(Class> clazz) {
+ return Function1.class.isAssignableFrom(clazz);
+ }
+
+ @SuppressWarnings("unchecked")
+ public static Function, ?> wrapKotlinFunction(Object function) {
+ return t -> ((Function1