Skip to content

Commit

Permalink
Merge pull request #146 from yuluo-yx/1128-yuluo/fix-image-options
Browse files Browse the repository at this point in the history
fix: fixed the bug that the image model parameter was empty when set  in yaml
  • Loading branch information
chickenlj authored Dec 2, 2024
2 parents 0cf7c70 + 40bb2f3 commit 0420bab
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 96 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
package com.alibaba.cloud.ai.dashscope.image;

import java.util.List;
import java.util.Objects;

import com.alibaba.cloud.ai.dashscope.api.DashScopeImageApi;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.image.*;

import org.springframework.ai.image.Image;
import org.springframework.ai.image.ImageGeneration;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.util.List;

/**
* @author nuocheng.lxm
Expand All @@ -21,6 +28,11 @@ public class DashScopeImageModel implements ImageModel {

private static final Logger logger = LoggerFactory.getLogger(DashScopeImageModel.class);

/**
* The default model used for the image completion requests.
*/
private static final String DEFAULT_MODEL = "wanx-v1";

/**
* Low-level access to the DashScope Image API.
*/
Expand All @@ -29,7 +41,7 @@ public class DashScopeImageModel implements ImageModel {
/**
* The default options used for the image completion requests.
*/
private final DashScopeImageOptions options;
private final DashScopeImageOptions defaultOptions;

/**
* The retry template used to retry the OpenAI Image API calls.
Expand All @@ -44,37 +56,42 @@ public DashScopeImageModel(DashScopeImageApi dashScopeImageApi) {

public DashScopeImageModel(DashScopeImageApi dashScopeImageApi, DashScopeImageOptions options,
RetryTemplate retryTemplate) {

Assert.notNull(dashScopeImageApi, "DashScopeImageApi must not be null");
Assert.notNull(options, "options must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");

this.dashScopeImageApi = dashScopeImageApi;
this.options = options;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
}

@Override
public ImageResponse call(ImagePrompt request) {

String taskId = submitImageGenTask(request);
if (taskId == null) {
return new ImageResponse(List.of());
}

int retryCount = 0;
while (true && retryCount < MAX_RETRY_COUNT) {
while (retryCount < MAX_RETRY_COUNT) {
DashScopeImageApi.DashScopeImageAsyncReponse getResultResponse = getImageGenTask(taskId);
if (getResultResponse != null) {
DashScopeImageApi.DashScopeImageAsyncReponse.DashScopeImageAsyncReponseOutput output = getResultResponse
.output();
String taskStatus = output.taskStatus();
switch (taskStatus) {
case "SUCCEEDED":
case "SUCCEEDED" -> {
return toImageResponse(output);
case "FAILED":
case "UNKNOWN":
}
case "FAILED", "UNKNOWN" -> {
return new ImageResponse(List.of());
}
}
}
try {
Thread.sleep(15000l);
Thread.sleep(15000L);
retryCount++;
}
catch (InterruptedException e) {
Expand All @@ -85,34 +102,42 @@ public ImageResponse call(ImagePrompt request) {
}

public String submitImageGenTask(ImagePrompt request) {
String instructions = request.getInstructions().get(0).getText();
DashScopeImageApi.DashScopeImageRequest imageRequest = null;
if (options != null) {
imageRequest = new DashScopeImageApi.DashScopeImageRequest(options.getModel(),
new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestInput(instructions,
options.getNegativePrompt(), options.getRefImg()),
new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestParameter(options.getStyle(),
options.getSize(), options.getN(), options.getSeed(), options.getRefStrength(),
options.getRefMode()));
}
if (request.getOptions() != null) {
DashScopeImageOptions options = toQianFanImageOptions(request.getOptions());
imageRequest = new DashScopeImageApi.DashScopeImageRequest(options.getModel(),
new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestInput(instructions,
options.getNegativePrompt(), options.getRefImg()),
new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestParameter(options.getStyle(),
options.getSize(), options.getN(), options.getSeed(), options.getRefStrength(),
options.getRefMode()));
}

DashScopeImageOptions imageOptions = toImageOptions(request.getOptions());
logger.debug("Image options: {}", imageOptions);

DashScopeImageApi.DashScopeImageRequest dashScopeImageRequest = constructImageRequest(request, imageOptions);

ResponseEntity<DashScopeImageApi.DashScopeImageAsyncReponse> submitResponse = dashScopeImageApi
.submitImageGenTask(imageRequest);
.submitImageGenTask(dashScopeImageRequest);

if (submitResponse == null || submitResponse.getBody() == null) {
logger.warn("Submit imageGen error,request: {}", request);
return null;
}

return submitResponse.getBody().output().taskId();
}

/**
* Merge Image options. Notice: Programmatically set options parameters take
* precedence
*/
private DashScopeImageOptions toImageOptions(ImageOptions runtimeOptions) {

// set default image model
var currentOptions = DashScopeImageOptions.builder().withModel(DEFAULT_MODEL).build();

if (Objects.nonNull(runtimeOptions)) {
currentOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, ImageOptions.class,
DashScopeImageOptions.class);
}

currentOptions = ModelOptionsUtils.merge(currentOptions, this.defaultOptions, DashScopeImageOptions.class);

return currentOptions;
}

public DashScopeImageApi.DashScopeImageAsyncReponse getImageGenTask(String taskId) {
ResponseEntity<DashScopeImageApi.DashScopeImageAsyncReponse> getImageGenResponse = dashScopeImageApi
.getImageGenTaskResult(taskId);
Expand All @@ -137,52 +162,16 @@ private ImageResponse toImageResponse(
return new ImageResponse(imageGenerationList);
}

private DashScopeImageOptions toQianFanImageOptions(ImageOptions runtimeImageOptions) {
DashScopeImageOptions.Builder builder = DashScopeImageOptions.builder();
if (runtimeImageOptions == null) {
return builder.build();
}
commonImageOptions(runtimeImageOptions, builder);
if (runtimeImageOptions instanceof DashScopeImageOptions dashScopeImageOptions) {
dashScopeSpecificOptions(dashScopeImageOptions, builder);
}
return builder.build();
}
private DashScopeImageApi.DashScopeImageRequest constructImageRequest(ImagePrompt imagePrompt,
DashScopeImageOptions options) {

private void commonImageOptions(ImageOptions runtimeImageOptions, DashScopeImageOptions.Builder builder) {
if (runtimeImageOptions.getN() != null) {
builder.withN(options.getN());
}
if (runtimeImageOptions.getModel() != null) {
builder.withModel(options.getModel());
}
if (runtimeImageOptions.getWidth() != null) {
builder.withWidth(options.getWidth());
}
if (runtimeImageOptions.getHeight() != null) {
builder.withHeight(options.getHeight());
}
}

private void dashScopeSpecificOptions(DashScopeImageOptions options, DashScopeImageOptions.Builder builder) {
if (options.getStyle() != null) {
builder.withStyle(options.getStyle());
}
if (options.getSeed() != null) {
builder.withSeed(options.getSeed());
}
if (options.getRefImg() != null) {
builder.withRefImg(options.getRefImg());
}
if (options.getRefMode() != null) {
builder.withRefMode(options.getRefMode());
}
if (options.getRefStrength() != null) {
builder.withRefStrength(options.getRefStrength());
}
if (options.getNegativePrompt() != null) {
builder.withNegativePrompt(options.getNegativePrompt());
}
return new DashScopeImageApi.DashScopeImageRequest(options.getModel(),
new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestInput(
imagePrompt.getInstructions().get(0).getText(), options.getNegativePrompt(),
options.getRefImg()),
new DashScopeImageApi.DashScopeImageRequest.DashScopeImageRequestParameter(options.getStyle(),
options.getSize(), options.getN(), options.getSeed(), options.getRefStrength(),
options.getRefMode()));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,15 @@ public void setNegativePrompt(String negativePrompt) {
this.negativePrompt = negativePrompt;
}

@Override
public String toString() {

return "DashScopeImageOptions{" + "model='" + model + '\'' + ", n=" + n + ", width=" + width + ", height="
+ height + ", size='" + size + '\'' + ", style='" + style + '\'' + ", seed=" + seed + ", refImg='"
+ refImg + '\'' + ", refStrength=" + refStrength + ", refMode='" + refMode + '\'' + ", negativePrompt='"
+ negativePrompt + '\'' + '}';
}

public static class Builder {

private final DashScopeImageOptions options;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,30 @@ public enum DocumentParser {

TEXT_PARSER("text", "org.springframework.ai.reader.TextReader", "Plain Text Document Reader"),
JSON_PARSER("json", "org.springframework.ai.reader.JsonReader", "Plain Json Document Reader"),
PAGE_PDF_PARSER("pagePdf", "org.springframework.ai.reader.pdf.PagePdfDocumentReader", "Groups the parsed PDF pages into {@link Document}s. You can group one or more pages into a single output document."),
PARAGRAPH_PDF_PARSER("paragraphPdf", "org.springframework.ai.reader.pdf.ParagraphPdfDocumentReader", "Uses the PDF catalog (e.g. TOC) information to split the input PDF into text paragraphs and output a single {@link Document} per paragraph"),
DOCX_PARSER("docx", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader", "Parses Microsoft Office file into a {@link Document} using Apache POI library"),
DOC_PARSER("doc", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader", "Parses Microsoft Office file into a {@link Document} using Apache POI library"),
PPTX_PARSER("pptx", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader", "Parses Microsoft Office file into a {@link Document} using Apache POI library"),
PPT_PARSER("ppt", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader", "Parses Microsoft Office file into a {@link Document} using Apache POI library"),
XLSX_PARSER("xlsx", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader", "Parses Microsoft Office file into a {@link Document} using Apache POI library"),
XLS_PARSER("xls", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader", "Parses Microsoft Office file into a {@link Document} using Apache POI library"),
CSV_PARSER("csv", "org.springframework.ai.reader.tika.TikaDocumentReader", "Uses tika to extract text content from CSV documents."),
HTML_PARSER("html", "org.springframework.ai.reader.tika.TikaDocumentReader", "Uses tika to extract text content from HTML documents."),
XML_PARSER("xml", "org.springframework.ai.reader.tika.TikaDocumentReader", "Uses tika to extract text content from XML documents."),
RTF_PARSER("rtf", "org.springframework.ai.reader.tika.TikaDocumentReader", "Uses tika to extract text content from Rich Text Format (RTF) documents.");
PAGE_PDF_PARSER("pagePdf", "org.springframework.ai.reader.pdf.PagePdfDocumentReader",
"Groups the parsed PDF pages into {@link Document}s. You can group one or more pages into a single output document."),
PARAGRAPH_PDF_PARSER("paragraphPdf", "org.springframework.ai.reader.pdf.ParagraphPdfDocumentReader",
"Uses the PDF catalog (e.g. TOC) information to split the input PDF into text paragraphs and output a single {@link Document} per paragraph"),
DOCX_PARSER("docx", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader",
"Parses Microsoft Office file into a {@link Document} using Apache POI library"),
DOC_PARSER("doc", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader",
"Parses Microsoft Office file into a {@link Document} using Apache POI library"),
PPTX_PARSER("pptx", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader",
"Parses Microsoft Office file into a {@link Document} using Apache POI library"),
PPT_PARSER("ppt", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader",
"Parses Microsoft Office file into a {@link Document} using Apache POI library"),
XLSX_PARSER("xlsx", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader",
"Parses Microsoft Office file into a {@link Document} using Apache POI library"),
XLS_PARSER("xls", "com.alibaba.cloud.ai.reader.poi.PoiDocumentReader",
"Parses Microsoft Office file into a {@link Document} using Apache POI library"),
CSV_PARSER("csv", "org.springframework.ai.reader.tika.TikaDocumentReader",
"Uses tika to extract text content from CSV documents."),
HTML_PARSER("html", "org.springframework.ai.reader.tika.TikaDocumentReader",
"Uses tika to extract text content from HTML documents."),
XML_PARSER("xml", "org.springframework.ai.reader.tika.TikaDocumentReader",
"Uses tika to extract text content from XML documents."),
RTF_PARSER("rtf", "org.springframework.ai.reader.tika.TikaDocumentReader",
"Uses tika to extract text content from Rich Text Format (RTF) documents.");

private final String parserType;

Expand Down
2 changes: 0 additions & 2 deletions spring-ai-alibaba-examples/chatmodel-example/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
<maven-deploy-plugin.version>3.1.1</maven-deploy-plugin.version>

<!-- Spring AI -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import jakarta.servlet.http.HttpServletResponse;

import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImageOptionsBuilder;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.http.MediaType;
Expand All @@ -46,11 +44,14 @@ public class ImageModelController {
@GetMapping("/image/{input}")
public void image(@PathVariable("input") String input, HttpServletResponse response) {

ImageOptions options = ImageOptionsBuilder.builder()
.withModel("wanx-v1")
.build();
// The options parameter set in this way takes precedence over the parameters in the yaml configuration file.
// The default image model is wanx-v1
// ImageOptions options = ImageOptionsBuilder.builder()
// .withModel("wax-2")
// .build();
// ImagePrompt imagePrompt = new ImagePrompt(input, options);

ImagePrompt imagePrompt = new ImagePrompt(input, options);
ImagePrompt imagePrompt = new ImagePrompt(input);
ImageResponse imageResponse = imageModel.call(imagePrompt);
String imageUrl = imageResponse.getResult().getOutput().getUrl();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ spring:
ai:
dashscope:
api-key: ${AI_DASHSCOPE_API_KEY}
image:
options:
model: wanx-v1

0 comments on commit 0420bab

Please sign in to comment.