Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Kotlin functions #1666

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions spring-ai-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,6 @@
<version>${jsonschema.version}</version>
</dependency>

<dependency>
<groupId>org.springframework.cloud</groupId>
<artifactId>spring-cloud-function-context</artifactId>
<version>${spring-cloud-function-context.version}</version>
<exclusions>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-autoconfigure</artifactId>
</exclusion>
</exclusions>
</dependency>

<!-- production dependencies -->
<dependency>
<groupId>org.antlr</groupId>
Expand Down Expand Up @@ -138,6 +126,13 @@
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-stdlib</artifactId>
<version>${kotlin.version}</version>
<optional>true</optional>
</dependency>

<!-- test dependencies -->
<dependency>
<groupId>org.springframework.boot</groupId>
Expand All @@ -146,16 +141,16 @@
</dependency>

<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-stdlib</artifactId>
<version>${kotlin.version}</version>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-kotlin</artifactId>
<version>${jackson.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-kotlin</artifactId>
<version>${jackson.version}</version>
<groupId>io.mockk</groupId>
<artifactId>mockk-jvm</artifactId>
<version>1.13.13</version>
<scope>test</scope>
</dependency>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@

package org.springframework.ai.model.function;

import java.lang.reflect.Type;
import java.util.function.BiFunction;
import java.util.function.Function;

import com.fasterxml.jackson.annotation.JsonClassDescription;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.beans.BeansException;
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
import org.springframework.cloud.function.context.config.FunctionContextUtils;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Description;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.KotlinDetector;
import org.springframework.core.ResolvableType;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
Expand All @@ -49,6 +50,7 @@
*
* @author Christian Tzolov
* @author Christopher Smith
* @author Sebastien Deleuze
*/
public class FunctionCallbackContext implements ApplicationContextAware {

Expand All @@ -65,26 +67,13 @@ public void setApplicationContext(@NonNull ApplicationContext applicationContext
this.applicationContext = (GenericApplicationContext) applicationContext;
}

@SuppressWarnings({ "rawtypes", "unchecked" })
@SuppressWarnings({ "unchecked" })
public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) {

Type beanType = FunctionContextUtils.findType(this.applicationContext.getBeanFactory(), beanName);
ResolvableType functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName);
ResolvableType functionInputType = TypeResolverHelper.getFunctionArgumentType(functionType, 0);

if (beanType == null) {
throw new IllegalArgumentException(
"Functional bean with name: " + beanName + " does not exist in the context.");
}

if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))
&& !BiFunction.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) {
throw new IllegalArgumentException(
"Function call Bean must be of type Function or BiFunction. Found: " + beanType.getTypeName());
}

Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0);

Class<?> functionInputClass = FunctionTypeUtils.getRawType(functionInputType);
String functionName = beanName;
Class<?> functionInputClass = functionInputType.toClass();
String functionDescription = defaultDescription;

if (!StringUtils.hasText(functionDescription)) {
Expand Down Expand Up @@ -114,24 +103,42 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable

Object bean = this.applicationContext.getBean(beanName);

if (KotlinDetector.isKotlinPresent()) {
if (KotlinDelegate.isKotlinFunction(functionType.toClass())) {
return FunctionCallbackWrapper.builder(KotlinDelegate.wrapKotlinFunction(bean))
.withName(beanName)
.withSchemaType(this.schemaType)
.withDescription(functionDescription)
.withInputType(functionInputClass)
.build();
}
else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
return FunctionCallbackWrapper.builder(KotlinDelegate.wrapKotlinBiFunction(bean))
.withName(beanName)
.withSchemaType(this.schemaType)
.withDescription(functionDescription)
.withInputType(functionInputClass)
.build();
}
}
if (bean instanceof Function<?, ?> function) {
return FunctionCallbackWrapper.builder(function)
.withName(functionName)
.withName(beanName)
.withSchemaType(this.schemaType)
.withDescription(functionDescription)
.withInputType(functionInputClass)
.build();
}
else if (bean instanceof BiFunction<?, ?, ?> biFunction) {
return FunctionCallbackWrapper.builder((BiFunction<?, ToolContext, ?>) biFunction)
.withName(functionName)
else if (bean instanceof BiFunction<?, ?, ?>) {
return FunctionCallbackWrapper.builder((BiFunction<?, ToolContext, ?>) bean)
.withName(beanName)
.withSchemaType(this.schemaType)
.withDescription(functionDescription)
.withInputType(functionInputClass)
.build();
}
else {
throw new IllegalArgumentException("Bean must be of type Function");
throw new IllegalStateException();
}
}

Expand All @@ -141,4 +148,26 @@ public enum SchemaType {

}

private static class KotlinDelegate {

public static boolean isKotlinFunction(Class<?> clazz) {
return Function1.class.isAssignableFrom(clazz);
}

@SuppressWarnings("unchecked")
public static Function<?, ?> wrapKotlinFunction(Object function) {
return t -> ((Function1<Object, Object>) function).invoke(t);
}

public static boolean isKotlinBiFunction(Class<?> clazz) {
return Function2.class.isAssignableFrom(clazz);
}

@SuppressWarnings("unchecked")
public static BiFunction<?, ToolContext, ?> wrapKotlinBiFunction(Object function) {
return (t, u) -> ((Function2<Object, ToolContext, Object>) function).invoke(t, u);
}

}

}
Loading