From ec641769f35991d561d11b330b29a7abb2802524 Mon Sep 17 00:00:00 2001
From: corochann
Date: Tue, 1 Oct 2024 00:04:10 +0900
Subject: [PATCH 1/2] OpenAlex API support

---
 README.md                       | 14 +++++
 ai_scientist/generate_ideas.py  | 92 ++++++++++++++++++++++++---------
 ai_scientist/perform_writeup.py | 17 ++++--
 launch_scientist.py             | 26 +++++++---
 4 files changed, 113 insertions(+), 36 deletions(-)

diff --git a/README.md b/README.md
index fe570d8b..91525acd 100644
--- a/README.md
+++ b/README.md
@@ -124,6 +124,20 @@ export OPENAI_API_KEY="YOUR KEY HERE"
 export S2_API_KEY="YOUR KEY HERE"
 ```
 
+#### OpenAlex API (Literature Search Alternative)
+
+The OpenAlex API can be used as an alternative if you do not have a Semantic Scholar API Key.
+OpenAlex does not require an API key.
+
+```bash
+pip install pyalex
+export OPENALEX_MAIL_ADDRESS="YOUR EMAIL ADDRESS"
+```
+
+And specify `--engine openalex` when you execute the AI Scientist code.
+
+Note that this is experimental for those who do not have a Semantic Scholar API Key.
+
 ### Setup NanoGPT
 
 Here, and below, we give instructions for setting up the data and baseline evaluations for each template. You can only run setup steps for templates you are interested in. This is necessary to run on your machine as training times may vary depending on your hardware. 
diff --git a/ai_scientist/generate_ideas.py b/ai_scientist/generate_ideas.py index a8feedfe..df444dfb 100644 --- a/ai_scientist/generate_ideas.py +++ b/ai_scientist/generate_ideas.py @@ -281,31 +281,76 @@ def on_backoff(details): @backoff.on_exception( backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff ) -def search_for_papers(query, result_limit=10) -> Union[None, List[Dict]]: +def search_for_papers(query, result_limit=10, engine="semanticscholar") -> Union[None, List[Dict]]: if not query: return None - rsp = requests.get( - "https://api.semanticscholar.org/graph/v1/paper/search", - headers={"X-API-KEY": S2_API_KEY}, - params={ - "query": query, - "limit": result_limit, - "fields": "title,authors,venue,year,abstract,citationStyles,citationCount", - }, - ) - print(f"Response Status Code: {rsp.status_code}") - print( - f"Response Content: {rsp.text[:500]}" - ) # Print the first 500 characters of the response content - rsp.raise_for_status() - results = rsp.json() - total = results["total"] - time.sleep(1.0) - if not total: - return None + if engine == "semanticscholar": + rsp = requests.get( + "https://api.semanticscholar.org/graph/v1/paper/search", + headers={"X-API-KEY": S2_API_KEY} if S2_API_KEY else {}, + params={ + "query": query, + "limit": result_limit, + "fields": "title,authors,venue,year,abstract,citationStyles,citationCount", + }, + ) + print(f"Response Status Code: {rsp.status_code}") + print( + f"Response Content: {rsp.text[:500]}" + ) # Print the first 500 characters of the response content + rsp.raise_for_status() + results = rsp.json() + total = results["total"] + time.sleep(1.0) + if not total: + return None + + papers = results["data"] + return papers + elif engine == "openalex": + import pyalex + from pyalex import Work, Works + mail = os.environ.get("OPENALEX_MAIL_ADDRESS", None) + if mail is None: + print("[WARNING] Please set OPENALEX_MAIL_ADDRESS for better access to OpenAlex API!") + else: + pyalex.config.email = mail + + def 
extract_info_from_work(work: Work, max_abstract_length: int = 1000) -> dict[str, str]: + # "Unknown" is returned when venue is unknown... + venue = "Unknown" + for i, location in enumerate(work["locations"]): + if location["source"] is not None: + venue = location["source"]["display_name"] + if venue != "": + # print(f"[DEBUG] {i=} {venue=} found!") + break + title = work["title"] + abstract = work["abstract"] + if abstract is None: + abstract = "" + if len(abstract) > max_abstract_length: + # To avoid context length exceed error. + print(f"[WARNING] {title=}: {len(abstract)=} is too long! Use first {max_abstract_length} chars.") + abstract = abstract[:max_abstract_length] + authors_list = [author["author"]["display_name"] for author in work["authorships"]] + authors = " and ".join(authors_list) if len(authors_list) < 20 else f"{authors_list[0]} et al." + paper = dict( + title=title, + authors=authors, + venue=venue, + year=work["publication_year"], + abstract=abstract, + citationCount=work["cited_by_count"], + ) + return paper + + works: List[Dict] = Works().search(query).get(per_page=result_limit) + papers: List[Dict[str, str]] = [extract_info_from_work(work) for work in works] + return papers + else: + raise NotImplementedError(f"{engine=} not supported!") - papers = results["data"] - return papers novelty_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field. @@ -363,6 +408,7 @@ def check_idea_novelty( client, model, max_num_iterations=10, + engine="semanticscholar" ): with open(osp.join(base_dir, "experiment.py"), "r") as f: code = f.read() @@ -413,7 +459,7 @@ def check_idea_novelty( ## SEARCH FOR PAPERS query = json_output["Query"] - papers = search_for_papers(query, result_limit=10) + papers = search_for_papers(query, result_limit=10, engine=engine) if papers is None: papers_str = "No papers found." 
diff --git a/ai_scientist/perform_writeup.py b/ai_scientist/perform_writeup.py index c32565e9..c42fb39a 100644 --- a/ai_scientist/perform_writeup.py +++ b/ai_scientist/perform_writeup.py @@ -293,7 +293,7 @@ def compile_latex(cwd, pdf_file, timeout=30): def get_citation_aider_prompt( - client, model, draft, current_round, total_rounds + client, model, draft, current_round, total_rounds, engine="semanticscholar" ) -> Tuple[Optional[str], bool]: msg_history = [] try: @@ -314,7 +314,7 @@ def get_citation_aider_prompt( json_output = extract_json_between_markers(text) assert json_output is not None, "Failed to extract JSON from LLM output" query = json_output["Query"] - papers = search_for_papers(query) + papers = search_for_papers(query, engine=engine) except Exception as e: print(f"Error: {e}") return None, False @@ -398,7 +398,7 @@ def get_citation_aider_prompt( # PERFORM WRITEUP def perform_writeup( - idea, folder_name, coder, cite_client, cite_model, num_cite_rounds=20 + idea, folder_name, coder, cite_client, cite_model, num_cite_rounds=20, engine="semanticscholar" ): # CURRENTLY ASSUMES LATEX abstract_prompt = f"""We've provided the `latex/template.tex` file to the project. We will be filling it in section by section. 
@@ -465,7 +465,7 @@ def perform_writeup( with open(osp.join(folder_name, "latex", "template.tex"), "r") as f: draft = f.read() prompt, done = get_citation_aider_prompt( - cite_client, cite_model, draft, _, num_cite_rounds + cite_client, cite_model, draft, _, num_cite_rounds, engine=engine ) if done: break @@ -542,6 +542,13 @@ def perform_writeup( ], help="Model to use for AI Scientist.", ) + parser.add_argument( + "--engine", + type=str, + default="semanticscholar", + choices=["semanticscholar", "openalex"], + help="Scholar engine to use.", + ) args = parser.parse_args() if args.model == "claude-3-5-sonnet-20240620": import anthropic @@ -627,6 +634,6 @@ def perform_writeup( generate_latex(coder, args.folder, f"{args.folder}/test.pdf") else: try: - perform_writeup(idea, folder_name, coder, client, client_model) + perform_writeup(idea, folder_name, coder, client, client_model, engine=args.engine) except Exception as e: print(f"Failed to perform writeup: {e}") diff --git a/launch_scientist.py b/launch_scientist.py index 489c7fc8..313f1f89 100644 --- a/launch_scientist.py +++ b/launch_scientist.py @@ -95,6 +95,13 @@ def parse_arguments(): default=50, help="Number of ideas to generate", ) + parser.add_argument( + "--engine", + type=str, + default="semanticscholar", + choices=["semanticscholar", "openalex"], + help="Scholar engine to use.", + ) return parser.parse_args() @@ -230,7 +237,7 @@ def do_idea( edit_format="diff", ) try: - perform_writeup(idea, folder_name, coder, client, client_model) + perform_writeup(idea, folder_name, coder, client, client_model, engine=args.engine) except Exception as e: print(f"Failed to perform writeup: {e}") return False @@ -373,12 +380,14 @@ def do_idea( max_num_generations=args.num_ideas, num_reflections=NUM_REFLECTIONS, ) - ideas = check_idea_novelty( - ideas, - base_dir=base_dir, - client=client, - model=client_model, - ) + if not args.skip_novelty_check: + ideas = check_idea_novelty( + ideas, + base_dir=base_dir, + client=client, + 
model=client_model, + engine=args.engine, + ) with open(osp.join(base_dir, "ideas.json"), "w") as f: json.dump(ideas, f, indent=4) @@ -438,5 +447,6 @@ def do_idea( print(f"Completed idea: {idea['Name']}, Success: {success}") except Exception as e: print(f"Failed to evaluate idea {idea['Name']}: {str(e)}") - + import traceback + print(traceback.format_exc()) print("All ideas evaluated.") From 17d5c0896a1e50f026f9861fd5856af94c119303 Mon Sep 17 00:00:00 2001 From: corochann Date: Tue, 1 Oct 2024 00:13:49 +0900 Subject: [PATCH 2/2] OpenAlex API support --- ai_scientist/generate_ideas.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ai_scientist/generate_ideas.py b/ai_scientist/generate_ideas.py index df444dfb..b86bb323 100644 --- a/ai_scientist/generate_ideas.py +++ b/ai_scientist/generate_ideas.py @@ -323,7 +323,6 @@ def extract_info_from_work(work: Work, max_abstract_length: int = 1000) -> dict[ if location["source"] is not None: venue = location["source"]["display_name"] if venue != "": - # print(f"[DEBUG] {i=} {venue=} found!") break title = work["title"] abstract = work["abstract"]