Skip to content

Commit

Permalink
feat : when apikey is demo modify headers
Browse files Browse the repository at this point in the history
  • Loading branch information
rajadilipkolli committed Mar 27, 2024
1 parent 54787f6 commit 14358b4
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 5 deletions.
6 changes: 6 additions & 0 deletions chatmodel-springai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.rest-assured</groupId>
<artifactId>rest-assured</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

<dependencyManagement>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package com.example.ai.config;

import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

import java.io.IOException;
import java.io.InputStream;
import java.util.Collections;

public class CustomClientHttpResponse implements ClientHttpResponse {

private final ClientHttpResponse originalResponse;
private final HttpHeaders headers;
public CustomClientHttpResponse(ClientHttpResponse originalResponse) {
this.originalResponse = originalResponse;
MultiValueMap<String, String> modifiedHeaders = new LinkedMultiValueMap<>(originalResponse.getHeaders());
modifiedHeaders.put(HttpHeaders.CONTENT_TYPE, Collections.singletonList(MediaType.APPLICATION_JSON_VALUE));
this.headers = new HttpHeaders(modifiedHeaders);
}

@Override
public HttpStatusCode getStatusCode() throws IOException {
return originalResponse.getStatusCode();
}

@Override
public String getStatusText() throws IOException {
return originalResponse.getStatusText();
}

@Override
public void close() {

}

@Override
public InputStream getBody() throws IOException {
return originalResponse.getBody();
}

@Override
public HttpHeaders getHeaders() {
return headers;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.example.ai.config;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpRequest;
import org.springframework.http.MediaType;
import org.springframework.http.client.BufferingClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.RestClient;

import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.List;

@Configuration(proxyBeanMethods = false)
@ConditionalOnProperty(value = "spring.ai.openai.api-key", havingValue = "demo")
public class LoggingConfig {

private final Logger log = LoggerFactory.getLogger(LoggingConfig.class);

@Bean
RestClient.Builder restClientBuilder() {
return RestClient.builder().requestFactory(new BufferingClientHttpRequestFactory(new HttpComponentsClientHttpRequestFactory()))
.requestInterceptor((request, body, execution) -> {
logRequest(request, body);
ClientHttpResponse response = execution.execute(request, body);
logResponse(response);
return new CustomClientHttpResponse(response);
}).defaultHeaders(httpHeaders -> {
httpHeaders.setContentType(MediaType.APPLICATION_JSON);
httpHeaders.setAccept(List.of(MediaType.ALL));
});
}

private void logResponse(ClientHttpResponse response) throws IOException {
log.info("============================response begin==========================================");
log.info("Status code : {}", response.getStatusCode());
log.info("Status text : {}", response.getStatusText());
log.info("Headers : {}", response.getHeaders());
log.info("Response body: {}", StreamUtils.copyToString(response.getBody(), Charset.defaultCharset()));
log.info("=======================response end=================================================");
}

private void logRequest(HttpRequest request, byte[] body) {

log.info("===========================request begin================================================");
log.info("URI : {}", request.getURI());
log.info("Method : {}", request.getMethod());
log.info("Headers : {}", request.getHeaders());
log.info("Request body: {}", new String(body, StandardCharsets.UTF_8));
log.info("==========================request end================================================");

}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.example.ai.controller;

import com.example.ai.model.response.AIChatResponse;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.SystemMessage;
Expand Down Expand Up @@ -31,21 +32,21 @@ Map<String,String> chat(@RequestParam String question) {
}

@GetMapping("/chat-with-prompt")
Map<String,String> chatWithPrompt(@RequestParam String subject) {
AIChatResponse chatWithPrompt(@RequestParam String subject) {
PromptTemplate promptTemplate = new PromptTemplate("Tell me a joke about {subject}");
Prompt prompt = promptTemplate.create(Map.of("subject", subject));
ChatResponse response = chatClient.call(prompt);
String answer = response.getResult().getOutput().getContent();
return Map.of( "answer", answer);
return new AIChatResponse(answer);
}

@GetMapping("/chat-with-system-prompt")
Map<String,String> chatWithSystemPrompt(@RequestParam String subject) {
AIChatResponse chatWithSystemPrompt(@RequestParam String subject) {
SystemMessage systemMessage = new SystemMessage("You are a sarcastic and funny chatbot");
UserMessage userMessage = new UserMessage("Tell me a joke about " + subject);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse response = chatClient.call(prompt);
String answer = response.getResult().getOutput().getContent();
return Map.of( "answer", answer);
return new AIChatResponse(answer);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package com.example.ai.model.response;

public record AIChatResponse(String answer) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ spring.ai.openai.chat.options.responseFormat=json_object
spring.ai.openai.embedding.enabled=false

##logging
logging.level.org.apache.hc.client5.http=DEBUG
logging.level.org.apache.hc.client5.http=INFO
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.example.ai.controller;

import org.hamcrest.Matchers;
import org.junit.jupiter.api.Test;

import io.restassured.RestAssured;

public class ChatControllerTest {

@Test
void testChat() {
RestAssured.given()
.param("question", "Hello?")
.when()
.get("http://localhost:8080/api/ai/chat")
.then()
.statusCode(200)
.body("question", Matchers.equalTo("Hello?"))
.body("answer", Matchers.equalTo("Hi!"));
}
}

0 comments on commit 14358b4

Please sign in to comment.