Skip to content

Commit

Permalink
refactor(2671): Moved retry logic from infer_type_name to wizard
Browse files Browse the repository at this point in the history
  • Loading branch information
harshtech123 committed Aug 13, 2024
1 parent c7c0b23 commit 71b6927
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 97 deletions.
53 changes: 11 additions & 42 deletions src/cli/llm/infer_type_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ impl InferTypeName {
pub fn new(secret: Option<String>) -> InferTypeName {
Self { secret }
}

pub async fn generate(&mut self, config: &Config) -> Result<HashMap<String, String>> {
let secret = self.secret.as_ref().map(|s| s.to_owned());

Expand All @@ -91,59 +92,27 @@ impl InferTypeName {
.filter(|(type_name, _)| !config.is_root_operation_type(type_name))
.collect::<Vec<_>>();

let total = types_to_be_processed.len();
for (i, (type_name, type_)) in types_to_be_processed.into_iter().enumerate() {
// convert type to sdl format.
for (type_name, type_) in types_to_be_processed.into_iter() {
let question = Question {
fields: type_
.fields
.iter()
.map(|(k, v)| (k.clone(), v.type_of.clone()))
@@ -102,48 +98,19 @@ impl InferTypeName {
.collect(),
};

let mut delay = 3;
loop {
let answer = wizard.ask(question.clone()).await;
match answer {
Ok(answer) => {
let name = &answer.suggestions.join(", ");
for name in answer.suggestions {
if config.types.contains_key(&name)
|| new_name_mappings.contains_key(&name)
{
continue;
}
match wizard.ask(question).await {
Ok(answer) => {
for name in answer.suggestions {
if !config.types.contains_key(&name)
&& !new_name_mappings.contains_key(&name) {
new_name_mappings.insert(name, type_name.to_owned());
break;
}
tracing::info!(
"Suggestions for {}: [{}] - {}/{}",
type_name,
name,
i + 1,
total
);

// TODO: case where suggested names are already used, then extend the base
// question with `suggest different names, we have already used following
// names: [names list]`
break;
}
Err(e) => {
// TODO: log errors after certain number of retries.
if let Error::GenAI(_) = e {
// TODO: retry only when it's required.
tracing::warn!(
"Unable to retrieve a name for the type '{}'. Retrying in {}s",
type_name,
delay
);
tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await;
delay *= std::cmp::min(delay * 2, 60);
}
}
}
Err(e) => {
tracing::error!("Failed to retrieve a name for the type '{}': {:?}", type_name, e);
}
}
}

Expand Down
40 changes: 29 additions & 11 deletions src/cli/llm/wizard.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use derive_setters::Setters;
use tokio_retry::strategy::{ExponentialBackoff, jitter};
use tokio_retry::Retry;
use genai::adapter::AdapterKind;
use genai::chat::{ChatOptions, ChatRequest, ChatResponse};
use genai::resolver::AuthResolver;
use genai::Client;

use super::Error;
use super::Result;
use crate::cli::llm::model::Model;
use derive_setters::Setters;

#[derive(Setters, Clone)]
pub struct Wizard<Q, A> {
Expand Down Expand Up @@ -39,15 +41,31 @@ impl<Q, A> Wizard<Q, A> {
}
}

pub async fn ask(&self, q: Q) -> Result<A>
pub async fn ask_with_retry(&self, q: Q) -> Result<A>
where
Q: TryInto<ChatRequest, Error = super::Error>,
A: TryFrom<ChatResponse, Error = super::Error>,
Q: TryInto<ChatRequest, Error = Error>,
A: TryFrom<ChatResponse, Error = Error>,
{
let response = self
.client
.exec_chat(self.model.as_str(), q.try_into()?, None)
.await?;
A::try_from(response)
let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(5);

let retry_future = Retry::spawn(strategy, || async {
let response = self
.client
.exec_chat(self.model.as_str(), q.clone().try_into()?, None)
.await;

match response {
Ok(res) => {
if res.status_code() == 429 {
Err(Error::GenAI("API rate limit exceeded".into()))
} else {
A::try_from(res).map_err(|e| Error::GenAI(e.to_string()))
}
},
Err(e) => Err(Error::GenAI(e.to_string())),
}
});

retry_future.await
}
}
}
6 changes: 3 additions & 3 deletions tailcall-fixtures/fixtures/configs/yaml-recursive-input.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ types:
type: Bar
graphql:
args:
- key: baz
value: '{{.args.baz}}'
- key: baz
value: "{{.args.baz}}"
baseURL: http://localhost
name: bars
Foo:
fields:
name:
type: String
type: String
81 changes: 40 additions & 41 deletions tailcall-wasm/example/browser/index.html
Original file line number Diff line number Diff line change
@@ -1,48 +1,47 @@
<!doctype html>
<html lang="en-US">
<head>
<meta charset="utf-8" />
<title>hello-wasm example</title>
</head>
<body>
<head>
<meta charset="utf-8" />
<title>hello-wasm example</title>
</head>
<body>
<div id="content">
<!-- Adding input field and button -->
<label for="queryInput"></label><input type="text" id="queryInput" placeholder="Enter your query here" />
<button id="btn">Run Query</button>
<p id="result"></p>
</div>

<div id="content">
<!-- Adding input field and button -->
<label for="queryInput"></label><input type="text" id="queryInput" placeholder="Enter your query here" />
<button id="btn">Run Query</button>
<p id="result"></p>
</div>
<script type="module">
import init, {TailcallBuilder} from "../../browser/pkg/tailcall_wasm.js"
await init()

<script type="module">
import init, { TailcallBuilder } from "../../browser/pkg/tailcall_wasm.js";
await init();
let executor // Making executor accessible

let executor; // Making executor accessible
async function setup() {
try {
const urlParams = new URLSearchParams(window.location.search)
let schemaUrl = urlParams.get("config")

async function setup() {
try {
const urlParams = new URLSearchParams(window.location.search);
let schemaUrl = urlParams.get("config");

let builder = new TailcallBuilder();
builder = await builder.with_config(schemaUrl);
executor = await builder.build();
let btn = document.getElementById("btn");
btn.addEventListener("click", runQuery);
} catch (error) {
alert("error: " + error);
}
}
async function runQuery() {
let query = document.getElementById("queryInput").value;
try {
document.getElementById("result").textContent = await executor.execute(query);
} catch (error) {
console.error("Error executing query: " + error);
document.getElementById("result").textContent = "Error: " + error;
}
}
setup();
</script>
</body>
let builder = new TailcallBuilder()
builder = await builder.with_config(schemaUrl)
executor = await builder.build()
let btn = document.getElementById("btn")
btn.addEventListener("click", runQuery)
} catch (error) {
alert("error: " + error)
}
}
async function runQuery() {
let query = document.getElementById("queryInput").value
try {
document.getElementById("result").textContent = await executor.execute(query)
} catch (error) {
console.error("Error executing query: " + error)
document.getElementById("result").textContent = "Error: " + error
}
}
setup()
</script>
</body>
</html>

0 comments on commit 71b6927

Please sign in to comment.