diff --git a/spring-ai-core/pom.xml b/spring-ai-core/pom.xml index e278f72cc3..d0aa33b6c7 100644 --- a/spring-ai-core/pom.xml +++ b/spring-ai-core/pom.xml @@ -54,18 +54,6 @@ ${jsonschema.version} - - org.springframework.cloud - spring-cloud-function-context - ${spring-cloud-function-context.version} - - - org.springframework.boot - spring-boot-autoconfigure - - - - org.antlr @@ -138,6 +126,13 @@ ${jackson.version} + + org.jetbrains.kotlin + kotlin-stdlib + ${kotlin.version} + true + + org.springframework.boot @@ -146,16 +141,16 @@ - org.jetbrains.kotlin - kotlin-stdlib - ${kotlin.version} + com.fasterxml.jackson.module + jackson-module-kotlin + ${jackson.version} test - com.fasterxml.jackson.module - jackson-module-kotlin - ${jackson.version} + io.mockk + mockk-jvm + 1.13.13 test diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java index ecbb9a4c1c..762d33969b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java @@ -16,20 +16,21 @@ package org.springframework.ai.model.function; -import java.lang.reflect.Type; import java.util.function.BiFunction; import java.util.function.Function; import com.fasterxml.jackson.annotation.JsonClassDescription; +import kotlin.jvm.functions.Function1; +import kotlin.jvm.functions.Function2; import org.springframework.ai.chat.model.ToolContext; import org.springframework.beans.BeansException; -import org.springframework.cloud.function.context.catalog.FunctionTypeUtils; -import org.springframework.cloud.function.context.config.FunctionContextUtils; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.context.annotation.Description; import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.KotlinDetector; +import org.springframework.core.ResolvableType; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; @@ -49,6 +50,7 @@ * * @author Christian Tzolov * @author Christopher Smith + * @author Sebastien Deleuze */ public class FunctionCallbackContext implements ApplicationContextAware { @@ -65,26 +67,13 @@ public void setApplicationContext(@NonNull ApplicationContext applicationContext this.applicationContext = (GenericApplicationContext) applicationContext; } - @SuppressWarnings({ "rawtypes", "unchecked" }) + @SuppressWarnings({ "unchecked" }) public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) { - Type beanType = FunctionContextUtils.findType(this.applicationContext.getBeanFactory(), beanName); + ResolvableType functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName); + ResolvableType functionInputType = TypeResolverHelper.getFunctionArgumentType(functionType, 0); - if (beanType == null) { - throw new IllegalArgumentException( - "Functional bean with name: " + beanName + " does not exist in the context."); - } - - if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType)) - && !BiFunction.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) { - throw new IllegalArgumentException( - "Function call Bean must be of type Function or BiFunction. Found: " + beanType.getTypeName()); - } - - Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0); - - Class functionInputClass = FunctionTypeUtils.getRawType(functionInputType); - String functionName = beanName; + Class functionInputClass = functionInputType.toClass(); String functionDescription = defaultDescription; if (!StringUtils.hasText(functionDescription)) { @@ -114,24 +103,42 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable Object bean = this.applicationContext.getBean(beanName); + if (KotlinDetector.isKotlinPresent()) { + if (KotlinDelegate.isKotlinFunction(functionType.toClass())) { + return FunctionCallbackWrapper.builder(KotlinDelegate.wrapKotlinFunction(bean)) + .withName(beanName) + .withSchemaType(this.schemaType) + .withDescription(functionDescription) + .withInputType(functionInputClass) + .build(); + } + else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) { + return FunctionCallbackWrapper.builder(KotlinDelegate.wrapKotlinBiFunction(bean)) + .withName(beanName) + .withSchemaType(this.schemaType) + .withDescription(functionDescription) + .withInputType(functionInputClass) + .build(); + } + } if (bean instanceof Function function) { return FunctionCallbackWrapper.builder(function) - .withName(functionName) + .withName(beanName) .withSchemaType(this.schemaType) .withDescription(functionDescription) .withInputType(functionInputClass) .build(); } - else if (bean instanceof BiFunction biFunction) { - return FunctionCallbackWrapper.builder((BiFunction) biFunction) - .withName(functionName) + else if (bean instanceof BiFunction) { + return FunctionCallbackWrapper.builder((BiFunction) bean) + .withName(beanName) .withSchemaType(this.schemaType) .withDescription(functionDescription) .withInputType(functionInputClass) .build(); } else { - throw new IllegalArgumentException("Bean must be of type Function"); + throw new IllegalStateException(); } } @@ -141,4 +148,26 @@ public enum SchemaType { } + private static class KotlinDelegate { + + public static boolean isKotlinFunction(Class clazz) { + return Function1.class.isAssignableFrom(clazz); + } + + @SuppressWarnings("unchecked") + public static Function wrapKotlinFunction(Object function) { + return t -> ((Function1) function).invoke(t); + } + + public static boolean isKotlinBiFunction(Class clazz) { + return Function2.class.isAssignableFrom(clazz); + } + + @SuppressWarnings("unchecked") + public static BiFunction wrapKotlinBiFunction(Object function) { + return (t, u) -> ((Function2) function).invoke(t, u); + } + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java index 8ff8584c4b..feaafd8013 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java @@ -16,21 +16,31 @@ package org.springframework.ai.model.function; -import java.lang.reflect.GenericArrayType; -import java.lang.reflect.ParameterizedType; -import java.lang.reflect.Type; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; import java.util.function.BiFunction; import java.util.function.Function; -import net.jodah.typetools.TypeResolver; +import kotlin.jvm.functions.Function1; +import kotlin.jvm.functions.Function2; -import org.springframework.cloud.function.context.catalog.FunctionTypeUtils; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.KotlinDetector; +import org.springframework.core.ResolvableType; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; /** * A utility class that provides methods for resolving types and classes related to * functions. * * @author Christian Tzolov + * @author Sebastien Dekeuze */ public abstract class TypeResolverHelper { @@ -68,12 +78,9 @@ public static Class getFunctionOutputClass(Class> fu * @return The class of the specified function argument. */ public static Class getFunctionArgumentClass(Class> functionClass, int argumentIndex) { - Type type = TypeResolver.reify(Function.class, functionClass); - - var argumentType = type instanceof ParameterizedType - ? ((ParameterizedType) type).getActualTypeArguments()[argumentIndex] : Object.class; - - return toRawClass(argumentType); + ResolvableType resolvableType = ResolvableType.forClass(functionClass).as(Function.class); + return (resolvableType == ResolvableType.NONE ? Object.class + : resolvableType.getGeneric(argumentIndex).toClass()); } /** @@ -84,80 +91,139 @@ public static Class getFunctionArgumentClass(Class> */ public static Class getBiFunctionArgumentClass(Class> biFunctionClass, int argumentIndex) { - Type type = TypeResolver.reify(BiFunction.class, biFunctionClass); - - Type argumentType = type instanceof ParameterizedType - ? ((ParameterizedType) type).getActualTypeArguments()[argumentIndex] : Object.class; - - return toRawClass(argumentType); + ResolvableType resolvableType = ResolvableType.forClass(biFunctionClass).as(BiFunction.class); + return (resolvableType == ResolvableType.NONE ? Object.class + : resolvableType.getGeneric(argumentIndex).toClass()); } /** - * Returns the input type of a given function class. - * @param functionClass The class of the function. - * @return The input type of the function. + * Resolve bean type, either directly with {@link BeanDefinition#getResolvableType()} + * or by resolving the factory method (duplicating + * {@code ConstructorResolver#resolveFactoryMethodIfPossible} logic as it is not + * public). + * @param applicationContext The application context. + * @param beanName The name of the bean to find a definition for. + * @return The resolved type. + * @throws IllegalArgumentException if the type of the bean definition is not + * resolvable. */ - public static Type getFunctionInputType(Class> functionClass) { - return getFunctionArgumentType(functionClass, 0); + public static ResolvableType resolveBeanType(GenericApplicationContext applicationContext, String beanName) { + BeanDefinition beanDefinition; + try { + beanDefinition = applicationContext.getBeanDefinition(beanName); + } + catch (NoSuchBeanDefinitionException ex) { + throw new IllegalArgumentException( + "Functional bean with name " + beanName + " does not exist in the context."); + } + ResolvableType functionType = beanDefinition.getResolvableType(); + Class resolvableClass = functionType.resolve(); + if (resolvableClass != null) { + return functionType; + } + if (beanDefinition instanceof RootBeanDefinition rootBeanDefinition) { + Class factoryClass; + boolean isStatic; + if (rootBeanDefinition.getFactoryBeanName() != null) { + factoryClass = applicationContext.getBeanFactory().getType(rootBeanDefinition.getFactoryBeanName()); + isStatic = false; + } + else { + factoryClass = rootBeanDefinition.getBeanClass(); + isStatic = true; + } + Assert.state(factoryClass != null, "Unresolvable factory class"); + factoryClass = ClassUtils.getUserClass(factoryClass); + + Method[] candidates = getCandidateMethods(factoryClass, rootBeanDefinition); + Method uniqueCandidate = null; + for (Method candidate : candidates) { + if ((!isStatic || isStaticCandidate(candidate, factoryClass)) + && rootBeanDefinition.isFactoryMethod(candidate)) { + if (uniqueCandidate == null) { + uniqueCandidate = candidate; + } + else if (isParamMismatch(uniqueCandidate, candidate)) { + uniqueCandidate = null; + break; + } + } + } + rootBeanDefinition.setResolvedFactoryMethod(uniqueCandidate); + return rootBeanDefinition.getResolvableType(); + } + throw new IllegalArgumentException("Impossible to resolve the type of bean " + beanName); } - /** - * Retrieves the output type of a given function class. - * @param functionClass The function class. - * @return The output type of the function. - */ - public static Type getFunctionOutputType(Class> functionClass) { - return getFunctionArgumentType(functionClass, 1); + static private Method[] getCandidateMethods(Class factoryClass, RootBeanDefinition mbd) { + return (mbd.isNonPublicAccessAllowed() ? ReflectionUtils.getUniqueDeclaredMethods(factoryClass) + : factoryClass.getMethods()); } - /** - * Retrieves the type of a specific argument in a given function class. - * @param functionClass The function class. - * @param argumentIndex The index of the argument whose type should be retrieved. - * @return The type of the specified function argument. - */ - public static Type getFunctionArgumentType(Class> functionClass, int argumentIndex) { - Type functionType = TypeResolver.reify(Function.class, functionClass); - return getFunctionArgumentType(functionType, argumentIndex); + static private boolean isStaticCandidate(Method method, Class factoryClass) { + return (Modifier.isStatic(method.getModifiers()) && method.getDeclaringClass() == factoryClass); + } + + static private boolean isParamMismatch(Method uniqueCandidate, Method candidate) { + int uniqueCandidateParameterCount = uniqueCandidate.getParameterCount(); + int candidateParameterCount = candidate.getParameterCount(); + return (uniqueCandidateParameterCount != candidateParameterCount + || !Arrays.equals(uniqueCandidate.getParameterTypes(), candidate.getParameterTypes())); } /** - * Retrieves the type of a specific argument in a given function type. + * Retrieves the type of a specific argument in a given function class. * @param functionType The function type. * @param argumentIndex The index of the argument whose type should be retrieved. * @return The type of the specified function argument. + * @throws IllegalArgumentException if functionType is not a supported type */ - public static Type getFunctionArgumentType(Type functionType, int argumentIndex) { - - // Resolves: https://github.com/spring-projects/spring-ai/issues/726 - if (!(functionType instanceof ParameterizedType)) { - Class functionalClass = FunctionTypeUtils.getRawType(functionType); - // Resolves: https://github.com/spring-projects/spring-ai/issues/1576 - if (BiFunction.class.isAssignableFrom(functionalClass)) { - functionType = TypeResolver.reify(BiFunction.class, (Class>) functionalClass); + public static ResolvableType getFunctionArgumentType(ResolvableType functionType, int argumentIndex) { + + Class resolvableClass = functionType.toClass(); + ResolvableType functionArgumentResolvableType = ResolvableType.NONE; + + if (Function.class.isAssignableFrom(resolvableClass)) { + functionArgumentResolvableType = functionType.as(Function.class); + } + else if (BiFunction.class.isAssignableFrom(resolvableClass)) { + functionArgumentResolvableType = functionType.as(BiFunction.class); + } + else if (KotlinDetector.isKotlinPresent()) { + if (KotlinDelegate.isKotlinFunction(resolvableClass)) { + functionArgumentResolvableType = KotlinDelegate.adaptToKotlinFunctionType(functionType); } - else { - functionType = FunctionTypeUtils.discoverFunctionTypeFromClass(functionalClass); + else if (KotlinDelegate.isKotlinBiFunction(resolvableClass)) { + functionArgumentResolvableType = KotlinDelegate.adaptToKotlinBiFunctionType(functionType); } } - var argumentType = functionType instanceof ParameterizedType - ? ((ParameterizedType) functionType).getActualTypeArguments()[argumentIndex] : Object.class; + if (functionArgumentResolvableType == ResolvableType.NONE) { + throw new IllegalArgumentException( + "Type must be a Function, BiFunction, Function1 or Function2. Found: " + functionType); + } - return argumentType; + return functionArgumentResolvableType.getGeneric(argumentIndex); } - /** - * Effectively converts {@link Type} which could be {@link ParameterizedType} to raw - * Class (no generics). - * @param type actual {@link Type} instance - * @return instance of {@link Class} as raw representation of the provided - * {@link Type} - */ - public static Class toRawClass(Type type) { - return type != null - ? TypeResolver.resolveRawClass(type instanceof GenericArrayType ? type : TypeResolver.reify(type), null) - : null; + private static class KotlinDelegate { + + public static boolean isKotlinFunction(Class clazz) { + return Function1.class.isAssignableFrom(clazz); + } + + public static ResolvableType adaptToKotlinFunctionType(ResolvableType resolvableType) { + return resolvableType.as(Function1.class); + } + + public static boolean isKotlinBiFunction(Class clazz) { + return Function2.class.isAssignableFrom(clazz); + } + + public static ResolvableType adaptToKotlinBiFunctionType(ResolvableType resolvableType) { + return resolvableType.as(Function2.class); + } + } } diff --git a/spring-ai-core/src/main/kotlin/org/springframework/ai/model/function/FunctionCallbackWrapperExtensions.kt b/spring-ai-core/src/main/kotlin/org/springframework/ai/model/function/FunctionCallbackWrapperExtensions.kt new file mode 100644 index 0000000000..d07dcbc2d2 --- /dev/null +++ b/spring-ai-core/src/main/kotlin/org/springframework/ai/model/function/FunctionCallbackWrapperExtensions.kt @@ -0,0 +1,26 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.function + +/** + * Extension for [FunctionCallbackWrapper.Builder.withInputType] providing a `withInputType()` + * variant. + * + * @author Sebastien Deleuze + */ +inline fun FunctionCallbackWrapper.Builder<*, *>.withInputType() = + withInputType(T::class.java) diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java index fb532d9ce3..4b6d4ad41d 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java @@ -16,18 +16,18 @@ package org.springframework.ai.model.function; -import java.lang.reflect.Type; import java.util.function.Function; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.cloud.function.context.config.FunctionContextUtils; import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.Configuration; import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.ResolvableType; import static org.assertj.core.api.Assertions.assertThat; @@ -38,15 +38,14 @@ class TypeResolverHelperIT { GenericApplicationContext applicationContext; @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "weatherClassDefinition", "weatherFunctionDefinition", "standaloneWeatherFunction" }) - void beanInputTypeResolutionTest(String beanName) { + @ValueSource(strings = { "weatherClassDefinition", "weatherFunctionDefinition", "standaloneWeatherFunction", + "scannedStandaloneWeatherFunction" }) + void beanInputTypeResolutionWithResolvableType(String beanName) { assertThat(this.applicationContext).isNotNull(); - Type beanType = FunctionContextUtils.findType(this.applicationContext.getBeanFactory(), beanName); - assertThat(beanType).isNotNull(); - Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0); - assertThat(functionInputType).isNotNull(); - assertThat(functionInputType.getTypeName()).isEqualTo(WeatherRequest.class.getName()); - + ResolvableType functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName); + Class functionInputClass = TypeResolverHelper.getFunctionArgumentType(functionType, 0).getRawClass(); + assertThat(functionInputClass).isNotNull(); + assertThat(functionInputClass.getTypeName()).isEqualTo(WeatherRequest.class.getName()); } public record WeatherRequest(String city) { @@ -70,7 +69,8 @@ public WeatherResponse apply(WeatherRequest weatherRequest) { } - @SpringBootConfiguration + @Configuration + @ComponentScan("org.springframework.ai.model.function.config") public static class TypeResolverHelperConfiguration { @Bean diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/config/TypeResolverHelperConfiguration.java b/spring-ai-core/src/test/java/org/springframework/ai/model/function/config/TypeResolverHelperConfiguration.java new file mode 100644 index 0000000000..588db50b51 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/function/config/TypeResolverHelperConfiguration.java @@ -0,0 +1,31 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.function.config; + +import org.springframework.ai.model.function.StandaloneWeatherFunction; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class TypeResolverHelperConfiguration { + + @Bean + StandaloneWeatherFunction scannedStandaloneWeatherFunction() { + return new StandaloneWeatherFunction(); + } + +} diff --git a/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/FunctionCallbackWrapperExtensionsTests.kt b/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/FunctionCallbackWrapperExtensionsTests.kt new file mode 100644 index 0000000000..531b3cc879 --- /dev/null +++ b/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/FunctionCallbackWrapperExtensionsTests.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.function + +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import org.junit.jupiter.api.Test + +class FunctionCallbackWrapperExtensionsTests { + + private val builder = mockk>() + + @Test + fun withInputType() { + every { builder.withInputType(any>()) } returns builder + builder.withInputType() + verify { builder.withInputType(WeatherRequest::class.java) } + } +} diff --git a/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/StandaloneWeatherKotlinFunction.kt b/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/StandaloneWeatherKotlinFunction.kt new file mode 100644 index 0000000000..97ccdbe5c4 --- /dev/null +++ b/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/StandaloneWeatherKotlinFunction.kt @@ -0,0 +1,8 @@ +package org.springframework.ai.model.function + +class StandaloneWeatherKotlinFunction : Function1 { + + override fun invoke(weatherRequest: WeatherRequest): WeatherResponse { + return WeatherResponse(42.0f) + } +} diff --git a/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/TypeResolverHelperKotlinIT.kt b/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/TypeResolverHelperKotlinIT.kt new file mode 100644 index 0000000000..8ced2732fc --- /dev/null +++ b/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/TypeResolverHelperKotlinIT.kt @@ -0,0 +1,80 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.function + +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.boot.test.context.SpringBootTest +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.ComponentScan +import org.springframework.context.annotation.Configuration +import org.springframework.context.support.GenericApplicationContext + +@SpringBootTest +class TypeResolverHelperKotlinIT { + + @Autowired + lateinit var applicationContext: GenericApplicationContext + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = ["weatherClassDefinition", "weatherFunctionDefinition", "standaloneWeatherFunction", "scannedStandaloneWeatherFunction"]) + fun beanInputTypeResolutionTest(beanName: String) { + assertThat(this.applicationContext).isNotNull() + val functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName); + val functionInputClass = TypeResolverHelper.getFunctionArgumentType(functionType, 0).rawClass; + assertThat(functionInputClass).isNotNull(); + assertThat(functionInputClass.typeName).isEqualTo(WeatherRequest::class.java.getName()); + } + + class Outer { + + class InnerWeatherFunction : Function1 { + + override fun invoke(weatherRequest: WeatherRequest): WeatherResponse { + return WeatherResponse(42.0f) + } + } + } + + @Configuration + @ComponentScan("org.springframework.ai.model.function.kotlinconfig") + open class TypeResolverHelperConfiguration { + + @Bean + open fun weatherClassDefinition(): Outer.InnerWeatherFunction { + return Outer.InnerWeatherFunction(); + } + + @Bean + open fun weatherFunctionDefinition(): Function1 { + return Outer.InnerWeatherFunction(); + } + + @Bean + open fun standaloneWeatherFunction(): StandaloneWeatherKotlinFunction { + return StandaloneWeatherKotlinFunction(); + } + + } + +} + +data class WeatherRequest(val city: String) + +data class WeatherResponse(val temperatureInCelsius: Float) diff --git a/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/kotlinconfig/TypeResolverHelperKotlinConfiguration.kt b/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/kotlinconfig/TypeResolverHelperKotlinConfiguration.kt new file mode 100644 index 0000000000..5724a0d2c6 --- /dev/null +++ b/spring-ai-core/src/test/kotlin/org/springframework/ai/model/function/kotlinconfig/TypeResolverHelperKotlinConfiguration.kt @@ -0,0 +1,30 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.function.kotlinconfig + +import org.springframework.ai.model.function.StandaloneWeatherKotlinFunction +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.Configuration + +@Configuration +open class TypeResolverHelperKotlinConfiguration { + + @Bean + open fun scannedStandaloneWeatherFunction(): StandaloneWeatherKotlinFunction { + return StandaloneWeatherKotlinFunction() + } +} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc index 8ee2e8d4f8..d23dea9ade 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc @@ -270,6 +270,11 @@ NOTE: Adhere to the OpenAI link:https://platform.openai.com/docs/guides/structur You can leverage existing xref::api/structured-output-converter.adoc#_bean_output_converter[BeanOutputConverter] utilities to automatically generate the JSON Schema from your domain objects and later convert the structured response into domain-specific instances: +-- +[tabs] +====== +Java:: ++ [source,java] ---- record MathReasoning( @@ -301,8 +306,41 @@ String content = this.response.getResult().getOutput().getContent(); MathReasoning mathReasoning = this.outputConverter.convert(this.content); ---- +Kotlin:: ++ +[source,kotlin] +---- +data class MathReasoning( + @get:JsonProperty(required = true, value = "steps") val steps: Steps, + @get:JsonProperty(required = true, value = "final_answer") val finalAnswer: String) { + + data class Steps(@get:JsonProperty(required = true, value = "items") val items: Array) { + + data class Items( + @get:JsonProperty(required = true, value = "explanation") val explanation: String, + @get:JsonProperty(required = true, value = "output") val output: String) + } +} + +val outputConverter = BeanOutputConverter(MathReasoning::class.java) + +val jsonSchema = outputConverter.jsonSchema; + +val prompt = Prompt("how can I solve 8x + 7 = -23", + OpenAiChatOptions.builder() + .withModel(ChatModel.GPT_4_O_MINI) + .withResponseFormat(ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, jsonSchema)) + .build()) + +val response = openAiChatModel.call(prompt) +val content = response.getResult().getOutput().getContent() + +val mathReasoning = outputConverter.convert(content) +---- +====== +-- -NOTE: Ensure you use the `@JsonProperty(required = true,...)` annotation. +NOTE: Ensure you use the `@JsonProperty(required = true,...)` annotation (`@get:JsonProperty(required = true,...)` with Kotlin in order to generate the annotation on the related getters, see link:https://kotlinlang.org/docs/annotations.html#annotation-use-site-targets[related documentation]). This is crucial for generating a schema that accurately marks fields as `required`. Although this is optional for JSON Schema, OpenAI link:https://platform.openai.com/docs/guides/structured-outputs/all-fields-must-be-required[mandates] it for the structured response to function correctly. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc index db7d51693c..21d6b7923a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc @@ -56,8 +56,13 @@ When the model needs to answer a question such as `"What’s the weather like in Our function calls some SaaS-based weather service API and returns the weather response back to the model to complete the conversation. In this example, we will use a simple implementation named `MockWeatherService` that hard-codes the temperature for various locations. -The following `MockWeatherService.java` represents the weather service API: +The following `MockWeatherService` class represents the weather service API: +-- +[tabs] +====== +Java:: ++ [source,java] ---- public class MockWeatherService implements Function { @@ -71,6 +76,20 @@ public class MockWeatherService implements Function { } } ---- +Kotlin:: ++ +[source,kotlin] +---- +class MockWeatherService : Function1 { + override fun invoke(request: Request) = Response(30.0, Unit.C) +} + +enum class Unit { C, F } +data class Request(val location: String, val unit: Unit) {} +data class Response(val temp: Double, val unit: Unit) {} +---- +====== +-- === Registering Functions as Beans @@ -78,13 +97,18 @@ Spring AI provides multiple ways to register custom functions as beans in the Sp We start by describing the most POJO-friendly options. -==== Plain Java Functions +==== Plain Functions In this approach, you define a `@Bean` in your application context as you would any other Spring managed object. Internally, Spring AI `ChatModel` will create an instance of a `FunctionCallbackWrapper` that adds the logic for it being invoked via the AI model. The name of the `@Bean` is used function name. +-- +[tabs] +====== +Java:: ++ [source,java] ---- @Configuration @@ -98,24 +122,63 @@ static class Config { } ---- +Kotlin:: ++ +[source,kotlin] +---- +@Configuration +class Config { + + @Bean + @Description("Get the weather in location") // function description + fun currentWeather(): (Request) -> Response = MockWeatherService() + +} +---- +====== +-- The `@Description` annotation is optional and provides a function description that helps the model understand when to call the function. It is an important property to set to help the AI model determine what client side function to invoke. Another option for providing the description of the function is to use the `@JsonClassDescription` annotation on the `MockWeatherService.Request`: +-- +[tabs] +====== +Java:: ++ [source,java] ---- @Configuration static class Config { + @Bean public Function currentWeather() { // bean name as function name return new MockWeatherService(); } } -@JsonClassDescription("Get the weather in location") // // function description +@JsonClassDescription("Get the weather in location") // function description public record Request(String location, Unit unit) {} ---- +Kotlin:: ++ +[source,kotlin] +---- +@Configuration +class Config { + + @Bean + fun currentWeather(): (Request) -> Response { // bean name as function name + return MockWeatherService() + } +} + +@JsonClassDescription("Get the weather in location") // function description +data class Request(val location: String, val unit: Unit) +---- +====== +-- It is a best practice to annotate the request object with information such that the generated JSON schema of that function is as descriptive as possible to help the AI model pick the correct function to invoke. @@ -123,6 +186,11 @@ It is a best practice to annotate the request object with information such that Another way to register a function is to create a `FunctionCallbackWrapper` like this: +-- +[tabs] +====== +Java:: ++ [source,java] ---- @Configuration @@ -138,6 +206,30 @@ static class Config { } } ---- +Kotlin:: ++ +[source,kotlin] +---- +import org.springframework.ai.model.function.withInputType + +@Configuration +class Config { + + @Bean + fun weatherFunctionInfo(): FunctionCallback { + + return FunctionCallbackWrapper.builder(MockWeatherService()) + .withName("CurrentWeather") // (1) function name + .withDescription("Get the weather in location") // (2) function description + // (3) Required due to Kotlin SAM conversion beeing an opaque lambda + .withInputType() + .build(); + } +} + +---- +====== +-- It wraps the 3rd party `MockWeatherService` function and registers it as a `CurrentWeather` function with the `ChatClient`. It also provides a description (2) and an optional response converter to convert the response into a text as expected by the model. diff --git a/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackContextKotlinIT.kt b/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackContextKotlinIT.kt new file mode 100644 index 0000000000..8866b29488 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackContextKotlinIT.kt @@ -0,0 +1,117 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.springframework.ai.autoconfigure.ollama.tool + +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.condition.DisabledIf +import org.slf4j.LoggerFactory +import org.springframework.ai.autoconfigure.ollama.BaseOllamaIT +import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration +import org.springframework.ai.chat.messages.UserMessage +import org.springframework.ai.chat.prompt.Prompt +import org.springframework.ai.model.function.FunctionCallingOptions +import org.springframework.ai.ollama.OllamaChatModel +import org.springframework.ai.ollama.api.OllamaOptions +import org.springframework.boot.autoconfigure.AutoConfigurations +import org.springframework.boot.test.context.runner.ApplicationContextRunner +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.Configuration +import org.springframework.context.annotation.Description +import org.testcontainers.junit.jupiter.Testcontainers + +@Testcontainers +@DisabledIf("isDisabled") +class FunctionCallbackContextKotlinIT : BaseOllamaIT() { + + private val logger = LoggerFactory.getLogger(FunctionCallbackContextKotlinIT::class.java) + + private val MODEL_NAME = "qwen2.5:3b" + + val contextRunner = buildConnectionWithModel(MODEL_NAME).let { baseUrl -> + ApplicationContextRunner().withPropertyValues( + "spring.ai.ollama.baseUrl=$baseUrl", + "spring.ai.ollama.chat.options.model=$MODEL_NAME", + "spring.ai.ollama.chat.options.temperature=0.5", + "spring.ai.ollama.chat.options.topK=10" + ) + .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration::class.java)) + .withUserConfiguration(Config::class.java) + } + + @Test + fun functionCallTest() { + this.contextRunner.run {context -> + + val chatModel = context.getBean(OllamaChatModel::class.java) + + val userMessage = UserMessage( + "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.") + + val response = chatModel + .call(Prompt(listOf(userMessage), OllamaOptions.builder().withFunction("weatherInfo").build())) + + logger.info("Response: " + response) + + assertThat(response.getResult().output.content).contains("30", "10", "15") + } + } + + @Test + fun functionCallWithPortableFunctionCallingOptions() { + this.contextRunner.run { context -> + + val chatModel = context.getBean(OllamaChatModel::class.java) + + // Test weatherFunction + val userMessage = UserMessage( + "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.") + + val functionOptions = FunctionCallingOptions.builder() + .withFunction("weatherInfo") + .build() + + val response = chatModel.call(Prompt(listOf(userMessage), functionOptions)); + + logger.info("Response: " + response.getResult().getOutput().getContent()); + + assertThat(response.getResult().output.content).contains("30", "10", "15"); + } + } + + @Configuration + open class Config { + + @Bean + @Description("Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") + open fun weatherInfo(): (KotlinRequest) -> KotlinResponse = { request -> + var temperature = 10.0 + if (request.location.contains("Paris")) { + temperature = 15.0 + } + else if (request.location.contains("Tokyo")) { + temperature = 10.0 + } + else if (request.location.contains("San Francisco")) { + temperature = 30.0 + } + KotlinResponse(temperature, 15.0, 20.0, 2.0, 53, 45, Unit.C); + } + + } +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperKotlinIT.kt b/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperKotlinIT.kt new file mode 100644 index 0000000000..4af4d38c7a --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackWrapperKotlinIT.kt @@ -0,0 +1,113 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.springframework.ai.autoconfigure.ollama.tool + +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.condition.DisabledIf +import org.slf4j.LoggerFactory +import org.springframework.ai.autoconfigure.ollama.BaseOllamaIT +import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration +import org.springframework.ai.chat.messages.UserMessage +import org.springframework.ai.chat.prompt.Prompt +import org.springframework.ai.model.function.FunctionCallback +import org.springframework.ai.model.function.FunctionCallbackWrapper +import org.springframework.ai.model.function.FunctionCallingOptions +import org.springframework.ai.ollama.OllamaChatModel +import org.springframework.ai.ollama.api.OllamaOptions +import org.springframework.boot.autoconfigure.AutoConfigurations +import org.springframework.boot.test.context.runner.ApplicationContextRunner +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.Configuration +import org.testcontainers.junit.jupiter.Testcontainers + +@Testcontainers +@DisabledIf("isDisabled") +class FunctionCallbackWrapperKotlinIT : BaseOllamaIT() { + + private val logger = LoggerFactory.getLogger(FunctionCallbackWrapperKotlinIT::class.java) + + private val MODEL_NAME = "qwen2.5:3b" + + val contextRunner = buildConnectionWithModel(MODEL_NAME).let { baseUrl -> + ApplicationContextRunner().withPropertyValues( + "spring.ai.ollama.baseUrl=$baseUrl", + "spring.ai.ollama.chat.options.model=$MODEL_NAME", + "spring.ai.ollama.chat.options.temperature=0.5", + "spring.ai.ollama.chat.options.topK=10" + ) + .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration::class.java)) + .withUserConfiguration(Config::class.java) + } + + @Test + fun functionCallTest() { + this.contextRunner.run {context -> + + val chatModel = context.getBean(OllamaChatModel::class.java) + + val userMessage = UserMessage( + "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.") + + val response = chatModel + .call(Prompt(listOf(userMessage), OllamaOptions.builder().withFunction("WeatherInfo").build())) + + logger.info("Response: " + response) + + assertThat(response.getResult().output.content).contains("30", "10", "15") + } + } + + @Test + fun functionCallWithPortableFunctionCallingOptions() { + this.contextRunner.run { context -> + + val chatModel = context.getBean(OllamaChatModel::class.java) + + // Test weatherFunction + val userMessage = UserMessage( + "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.") + + val functionOptions = FunctionCallingOptions.builder() + .withFunction("WeatherInfo") + .build() + + val response = chatModel.call(Prompt(listOf(userMessage), functionOptions)); + + logger.info("Response: " + response.getResult().getOutput().getContent()); + + assertThat(response.getResult().output.content).contains("30", "10", "15"); + } + } + + @Configuration + open class Config { + + @Bean + open fun weatherFunctionInfo(): FunctionCallback { + + return FunctionCallbackWrapper.builder(MockKotlinWeatherService()) + .withName("WeatherInfo") + .withInputType(KotlinRequest::class.java) + .withDescription( + "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") + .build(); + } + + } +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/MockKotlinWeatherService.kt b/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/MockKotlinWeatherService.kt new file mode 100644 index 0000000000..89795b4d12 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/kotlin/org/springframework/ai/autoconfigure/ollama/tool/MockKotlinWeatherService.kt @@ -0,0 +1,80 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.ollama.tool + +import com.fasterxml.jackson.annotation.JsonClassDescription +import com.fasterxml.jackson.annotation.JsonInclude +import com.fasterxml.jackson.annotation.JsonInclude.Include +import com.fasterxml.jackson.annotation.JsonProperty +import com.fasterxml.jackson.annotation.JsonPropertyDescription + +class MockKotlinWeatherService : Function1 { + + override fun invoke(kotlinRequest: KotlinRequest): KotlinResponse { + var temperature = 10.0 + if (kotlinRequest.location.contains("Paris")) { + temperature = 15.0 + } + else if (kotlinRequest.location.contains("Tokyo")) { + temperature = 10.0 + } + else if (kotlinRequest.location.contains("San Francisco")) { + temperature = 30.0 + } + + return KotlinResponse(temperature, 15.0, 20.0, 2.0, 53, 45, Unit.C); + } +} + +/** + * Temperature units. + */ +enum class Unit(val unitName: String) { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); +} + +/** + * Weather Function request. + */ +@JsonInclude(Include.NON_NULL) +@JsonClassDescription("Weather API request") +data class KotlinRequest(@get:JsonProperty(required = true, value = "location") @get:JsonPropertyDescription("The city and state e.g. San Francisco, CA") val location: String, + @get:JsonProperty(required = true, value = "lat") @get:JsonPropertyDescription("The city latitude") val lat: Double, + @get:JsonProperty(required = true, value = "lon") @get:JsonPropertyDescription("The city longitude") val lon: Double, + @get:JsonProperty(required = true, value = "unit") @get:JsonPropertyDescription("Temperature unit") val unit: Unit) { + +} + +/** + * Weather Function response. + */ +data class KotlinResponse(val temp: Double, + val feels_like: Double, + val temp_min: Double, + val temp_max: Double, + val pressure: Int, + val humidity: Int, + val unit: Unit +)