Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fixed the bug that the image model parameter was empty when set in yaml #146

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
package com.alibaba.cloud.ai.dashscope.image;

import java.util.List;
import java.util.Objects;

import com.alibaba.cloud.ai.dashscope.api.DashScopeImageApi;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.image.*;

import org.springframework.ai.image.Image;
import org.springframework.ai.image.ImageGeneration;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.util.List;

/**
* @author nuocheng.lxm
Expand All @@ -21,6 +28,11 @@ public class DashScopeImageModel implements ImageModel {

private static final Logger logger = LoggerFactory.getLogger(DashScopeImageModel.class);

/**
* The default model used for the image completion requests.
*/
private static final String DEFAULT_MODEL = "wanx-v1";

/**
* Low-level access to the DashScope Image API.
*/
Expand All @@ -29,7 +41,7 @@ public class DashScopeImageModel implements ImageModel {
/**
* The default options used for the image completion requests.
*/
private final DashScopeImageOptions options;
private final DashScopeImageOptions defaultOptions;

/**
* The retry template used to retry the OpenAI Image API calls.
Expand All @@ -44,37 +56,42 @@ public DashScopeImageModel(DashScopeImageApi dashScopeImageApi) {

public DashScopeImageModel(DashScopeImageApi dashScopeImageApi, DashScopeImageOptions options,
RetryTemplate retryTemplate) {

Assert.notNull(dashScopeImageApi, "DashScopeImageApi must not be null");
Assert.notNull(options, "options must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");

this.dashScopeImageApi = dashScopeImageApi;
this.options = options;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
}

@Override
public ImageResponse call(ImagePrompt request) {

String taskId = submitImageGenTask(request);
if (taskId == null) {
return new ImageResponse(List.of());
}

int retryCount = 0;
while (true && retryCount < MAX_RETRY_COUNT) {
while (retryCount < MAX_RETRY_COUNT) {
DashScopeImageApi.DashScopeImageAsyncReponse getResultResponse = getImageGenTask(taskId);
if (getResultResponse != null) {
DashScopeImageApi.DashScopeImageAsyncReponse.DashScopeImageAsyncReponseOutput output = getResultResponse
.output();
String taskStatus = output.taskStatus();
switch (taskStatus) {
case "SUCCEEDED":
case "SUCCEEDED" -> {
return toImageResponse(output);
case "FAILED":
case "UNKNOWN":
}
case "FAILED", "UNKNOWN" -> {
return new ImageResponse(List.of());
}
}
}
try {
Thread.sleep(15000l);
Thread.sleep(15000L);
retryCount++;
}
catch (InterruptedException e) {
Expand All @@ -85,34 +102,42 @@ public ImageResponse call(ImagePrompt request) {
}

public String submitImageGenTask(ImagePrompt request) {
String instructions = request.getInstructions().get(0).getText();
DashScopeImageApi.DashScopeImageRequest imageRequest = null;
if (options != null) {
imageRequest = new DashScopeImageApi.DashScopeImageRequest(options.getModel(),
new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestInput(instructions,
options.getNegativePrompt(), options.getRefImg()),
new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestParameter(options.getStyle(),
options.getSize(), options.getN(), options.getSeed(), options.getRefStrength(),
options.getRefMode()));
}
if (request.getOptions() != null) {
DashScopeImageOptions options = toQianFanImageOptions(request.getOptions());
imageRequest = new DashScopeImageApi.DashScopeImageRequest(options.getModel(),
new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestInput(instructions,
options.getNegativePrompt(), options.getRefImg()),
new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestParameter(options.getStyle(),
options.getSize(), options.getN(), options.getSeed(), options.getRefStrength(),
options.getRefMode()));
}

DashScopeImageOptions imageOptions = toImageOptions(request.getOptions());
logger.debug("Image options: {}", imageOptions);

DashScopeImageApi.DashScopeImageRequest dashScopeImageRequest = constructImageRequest(request, imageOptions);

ResponseEntity<DashScopeImageApi.DashScopeImageAsyncReponse> submitResponse = dashScopeImageApi
.submitImageGenTask(imageRequest);
.submitImageGenTask(dashScopeImageRequest);

if (submitResponse == null || submitResponse.getBody() == null) {
logger.warn("Submit imageGen error,request: {}", request);
return null;
}

return submitResponse.getBody().output().taskId();
}

/**
* Merge Image options. Notice: Programmatically set options parameters take
* precedence
*/
private DashScopeImageOptions toImageOptions(ImageOptions runtimeOptions) {

// set default image model
var currentOptions = DashScopeImageOptions.builder().withModel(DEFAULT_MODEL).build();

if (Objects.nonNull(runtimeOptions)) {
currentOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, ImageOptions.class,
DashScopeImageOptions.class);
}

currentOptions = ModelOptionsUtils.merge(currentOptions, this.defaultOptions, DashScopeImageOptions.class);

return currentOptions;
}

public DashScopeImageApi.DashScopeImageAsyncReponse getImageGenTask(String taskId) {
ResponseEntity<DashScopeImageApi.DashScopeImageAsyncReponse> getImageGenResponse = dashScopeImageApi
.getImageGenTaskResult(taskId);
Expand All @@ -137,52 +162,16 @@ private ImageResponse toImageResponse(
return new ImageResponse(imageGenerationList);
}

private DashScopeImageOptions toQianFanImageOptions(ImageOptions runtimeImageOptions) {
DashScopeImageOptions.Builder builder = DashScopeImageOptions.builder();
if (runtimeImageOptions == null) {
return builder.build();
}
commonImageOptions(runtimeImageOptions, builder);
if (runtimeImageOptions instanceof DashScopeImageOptions dashScopeImageOptions) {
dashScopeSpecificOptions(dashScopeImageOptions, builder);
}
return builder.build();
}
private DashScopeImageApi.DashScopeImageRequest constructImageRequest(ImagePrompt imagePrompt,
DashScopeImageOptions options) {

private void commonImageOptions(ImageOptions runtimeImageOptions, DashScopeImageOptions.Builder builder) {
if (runtimeImageOptions.getN() != null) {
builder.withN(options.getN());
}
if (runtimeImageOptions.getModel() != null) {
builder.withModel(options.getModel());
}
if (runtimeImageOptions.getWidth() != null) {
builder.withWidth(options.getWidth());
}
if (runtimeImageOptions.getHeight() != null) {
builder.withHeight(options.getHeight());
}
}

private void dashScopeSpecificOptions(DashScopeImageOptions options, DashScopeImageOptions.Builder builder) {
if (options.getStyle() != null) {
builder.withStyle(options.getStyle());
}
if (options.getSeed() != null) {
builder.withSeed(options.getSeed());
}
if (options.getRefImg() != null) {
builder.withRefImg(options.getRefImg());
}
if (options.getRefMode() != null) {
builder.withRefMode(options.getRefMode());
}
if (options.getRefStrength() != null) {
builder.withRefStrength(options.getRefStrength());
}
if (options.getNegativePrompt() != null) {
builder.withNegativePrompt(options.getNegativePrompt());
}
return new DashScopeImageApi.DashScopeImageRequest(options.getModel(),
new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestInput(
imagePrompt.getInstructions().get(0).getText(), options.getNegativePrompt(),
options.getRefImg()),
new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestParameter(options.getStyle(),
options.getSize(), options.getN(), options.getSeed(), options.getRefStrength(),
options.getRefMode()));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,15 @@ public void setNegativePrompt(String negativePrompt) {
this.negativePrompt = negativePrompt;
}

@Override
public String toString() {

return "DashScopeImageOptions{" + "model='" + model + '\'' + ", n=" + n + ", width=" + width + ", height="
+ height + ", size='" + size + '\'' + ", style='" + style + '\'' + ", seed=" + seed + ", refImg='"
+ refImg + '\'' + ", refStrength=" + refStrength + ", refMode='" + refMode + '\'' + ", negativePrompt='"
+ negativePrompt + '\'' + '}';
}

public static class Builder {

private final DashScopeImageOptions options;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,30 @@ public enum DocumentParser {

TEXT_PARSER("text", "org.springframework.ai.reader.TextReader", "Plain Text Document Reader"),
JSON_PARSER("json", "org.springframework.ai.reader.JsonReader", "Plain Json Document Reader"),
PAGE_PDF_PARSER("pagePdf", "org.springframework.ai.reader.pdf.PagePdfDocumentReader", "Groups the parsed PDF pages into {@link Document}s. You can group one or more pages into a single output document."),
PARAGRAPH_PDF_PARSER("paragraphPdf", "org.springframework.ai.reader.pdf.ParagraphPdfDocumentReader", "Uses the PDF catalog (e.g. TOC) information to split the input PDF into text paragraphs and output a single {@link Document} per paragraph"),
DOCX_PARSER("docx", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader", "Parses Microsoft Office file into a {@link Document} using Apache POI library"),
DOC_PARSER("doc", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader", "Parses Microsoft Office file into a {@link Document} using Apache POI library"),
PPTX_PARSER("pptx", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader", "Parses Microsoft Office file into a {@link Document} using Apache POI library"),
PPT_PARSER("ppt", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader", "Parses Microsoft Office file into a {@link Document} using Apache POI library"),
XLSX_PARSER("xlsx", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader", "Parses Microsoft Office file into a {@link Document} using Apache POI library"),
XLS_PARSER("xls", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader", "Parses Microsoft Office file into a {@link Document} using Apache POI library"),
CSV_PARSER("csv", "org.springframework.ai.reader.tika.TikaDocumentReader", "Uses tika to extract text content from CSV documents."),
HTML_PARSER("html", "org.springframework.ai.reader.tika.TikaDocumentReader", "Uses tika to extract text content from HTML documents."),
XML_PARSER("xml", "org.springframework.ai.reader.tika.TikaDocumentReader", "Uses tika to extract text content from XML documents."),
RTF_PARSER("rtf", "org.springframework.ai.reader.tika.TikaDocumentReader", "Uses tika to extract text content from Rich Text Format (RTF) documents.");
PAGE_PDF_PARSER("pagePdf", "org.springframework.ai.reader.pdf.PagePdfDocumentReader",
"Groups the parsed PDF pages into {@link Document}s. You can group one or more pages into a single output document."),
PARAGRAPH_PDF_PARSER("paragraphPdf", "org.springframework.ai.reader.pdf.ParagraphPdfDocumentReader",
"Uses the PDF catalog (e.g. TOC) information to split the input PDF into text paragraphs and output a single {@link Document} per paragraph"),
DOCX_PARSER("docx", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader",
"Parses Microsoft Office file into a {@link Document} using Apache POI library"),
DOC_PARSER("doc", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader",
"Parses Microsoft Office file into a {@link Document} using Apache POI library"),
PPTX_PARSER("pptx", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader",
"Parses Microsoft Office file into a {@link Document} using Apache POI library"),
PPT_PARSER("ppt", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader",
"Parses Microsoft Office file into a {@link Document} using Apache POI library"),
XLSX_PARSER("xlsx", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader",
"Parses Microsoft Office file into a {@link Document} using Apache POI library"),
XLS_PARSER("xls", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader",
"Parses Microsoft Office file into a {@link Document} using Apache POI library"),
CSV_PARSER("csv", "org.springframework.ai.reader.tika.TikaDocumentReader",
"Uses tika to extract text content from CSV documents."),
HTML_PARSER("html", "org.springframework.ai.reader.tika.TikaDocumentReader",
"Uses tika to extract text content from HTML documents."),
XML_PARSER("xml", "org.springframework.ai.reader.tika.TikaDocumentReader",
"Uses tika to extract text content from XML documents."),
RTF_PARSER("rtf", "org.springframework.ai.reader.tika.TikaDocumentReader",
"Uses tika to extract text content from Rich Text Format (RTF) documents.");

private final String parserType;

Expand Down
2 changes: 0 additions & 2 deletions spring-ai-alibaba-examples/chatmodel-example/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
<maven-deploy-plugin.version>3.1.1</maven-deploy-plugin.version>

<!-- Spring AI -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import jakarta.servlet.http.HttpServletResponse;

import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImageOptionsBuilder;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.http.MediaType;
Expand All @@ -46,11 +44,14 @@ public class ImageModelController {
@GetMapping("/image/{input}")
public void image(@PathVariable("input") String input, HttpServletResponse response) {

ImageOptions options = ImageOptionsBuilder.builder()
.withModel("wanx-v1")
.build();
// The options parameter set in this way takes precedence over the parameters in the yaml configuration file.
// The default image model is wanx-v1
// ImageOptions options = ImageOptionsBuilder.builder()
// .withModel("wax-2")
// .build();
// ImagePrompt imagePrompt = new ImagePrompt(input, options);

ImagePrompt imagePrompt = new ImagePrompt(input, options);
ImagePrompt imagePrompt = new ImagePrompt(input);
ImageResponse imageResponse = imageModel.call(imagePrompt);
String imageUrl = imageResponse.getResult().getOutput().getUrl();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ spring:
ai:
dashscope:
api-key: ${AI_DASHSCOPE_API_KEY}
image:
options:
model: wanx-v1