OpenAlex API support #135

Open
wants to merge 2 commits into base: main
14 changes: 14 additions & 0 deletions README.md
@@ -124,6 +124,20 @@ export OPENAI_API_KEY="YOUR KEY HERE"
export S2_API_KEY="YOUR KEY HERE"
```

#### OpenAlex API (Literature Search Alternative)

The OpenAlex API can be used as an alternative literature search engine if you do not have a Semantic Scholar API key.
OpenAlex does not require an API key.

```bash
pip install pyalex
export OPENALEX_MAIL_ADDRESS="YOUR EMAIL ADDRESS"
```

Then specify `--engine openalex` when you run the AI Scientist code.

Note that OpenAlex support is experimental and intended for users who do not have a Semantic Scholar API key.
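
For example, a full run that uses OpenAlex for the literature search might look like the following (the model, template, and idea count below are illustrative placeholders; only `--engine openalex` is specific to this option):

```bash
python launch_scientist.py --model "gpt-4o-2024-05-13" --experiment nanoGPT_lite --num-ideas 2 --engine openalex
```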

### Setup NanoGPT

Here, and below, we give instructions for setting up the data and baseline evaluations for each template. You only need to run the setup steps for the templates you are interested in. This needs to be run on your own machine, since training times may vary depending on your hardware.
91 changes: 68 additions & 23 deletions ai_scientist/generate_ideas.py
@@ -281,31 +281,75 @@ def on_backoff(details):
@backoff.on_exception(
backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff
)
def search_for_papers(query, result_limit=10) -> Union[None, List[Dict]]:
def search_for_papers(query, result_limit=10, engine="semanticscholar") -> Union[None, List[Dict]]:
if not query:
return None
rsp = requests.get(
"https://api.semanticscholar.org/graph/v1/paper/search",
headers={"X-API-KEY": S2_API_KEY},
params={
"query": query,
"limit": result_limit,
"fields": "title,authors,venue,year,abstract,citationStyles,citationCount",
},
)
print(f"Response Status Code: {rsp.status_code}")
print(
f"Response Content: {rsp.text[:500]}"
) # Print the first 500 characters of the response content
rsp.raise_for_status()
results = rsp.json()
total = results["total"]
time.sleep(1.0)
if not total:
return None
if engine == "semanticscholar":
rsp = requests.get(
"https://api.semanticscholar.org/graph/v1/paper/search",
headers={"X-API-KEY": S2_API_KEY} if S2_API_KEY else {},
params={
"query": query,
"limit": result_limit,
"fields": "title,authors,venue,year,abstract,citationStyles,citationCount",
},
)
print(f"Response Status Code: {rsp.status_code}")
print(
f"Response Content: {rsp.text[:500]}"
) # Print the first 500 characters of the response content
rsp.raise_for_status()
results = rsp.json()
total = results["total"]
time.sleep(1.0)
if not total:
return None

papers = results["data"]
return papers
elif engine == "openalex":
import pyalex
from pyalex import Work, Works
mail = os.environ.get("OPENALEX_MAIL_ADDRESS", None)
if mail is None:
print("[WARNING] Please set OPENALEX_MAIL_ADDRESS for better access to OpenAlex API!")
else:
pyalex.config.email = mail

        def extract_info_from_work(work: Work, max_abstract_length: int = 1000) -> Dict:
            # Fall back to "Unknown" when no venue can be resolved from the work's locations.
            venue = "Unknown"
for i, location in enumerate(work["locations"]):
if location["source"] is not None:
venue = location["source"]["display_name"]
if venue != "":
break
title = work["title"]
abstract = work["abstract"]
if abstract is None:
abstract = ""
if len(abstract) > max_abstract_length:
                # Truncate the abstract to avoid exceeding the LLM context length.
print(f"[WARNING] {title=}: {len(abstract)=} is too long! Use first {max_abstract_length} chars.")
abstract = abstract[:max_abstract_length]
authors_list = [author["author"]["display_name"] for author in work["authorships"]]
authors = " and ".join(authors_list) if len(authors_list) < 20 else f"{authors_list[0]} et al."
paper = dict(
title=title,
authors=authors,
venue=venue,
year=work["publication_year"],
abstract=abstract,
citationCount=work["cited_by_count"],
)
return paper

works: List[Dict] = Works().search(query).get(per_page=result_limit)
        papers: List[Dict] = [extract_info_from_work(work) for work in works]
return papers
else:
raise NotImplementedError(f"{engine=} not supported!")

papers = results["data"]
return papers


novelty_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
@@ -363,6 +407,7 @@ def check_idea_novelty(
client,
model,
max_num_iterations=10,
engine="semanticscholar"
):
with open(osp.join(base_dir, "experiment.py"), "r") as f:
code = f.read()
@@ -413,7 +458,7 @@ def check_idea_novelty(

## SEARCH FOR PAPERS
query = json_output["Query"]
papers = search_for_papers(query, result_limit=10)
papers = search_for_papers(query, result_limit=10, engine=engine)
if papers is None:
papers_str = "No papers found."

17 changes: 12 additions & 5 deletions ai_scientist/perform_writeup.py
@@ -293,7 +293,7 @@ def compile_latex(cwd, pdf_file, timeout=30):


def get_citation_aider_prompt(
client, model, draft, current_round, total_rounds
client, model, draft, current_round, total_rounds, engine="semanticscholar"
) -> Tuple[Optional[str], bool]:
msg_history = []
try:
@@ -314,7 +314,7 @@ def get_citation_aider_prompt(
json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"
query = json_output["Query"]
papers = search_for_papers(query)
papers = search_for_papers(query, engine=engine)
except Exception as e:
print(f"Error: {e}")
return None, False
@@ -398,7 +398,7 @@ def get_citation_aider_prompt(

# PERFORM WRITEUP
def perform_writeup(
idea, folder_name, coder, cite_client, cite_model, num_cite_rounds=20
idea, folder_name, coder, cite_client, cite_model, num_cite_rounds=20, engine="semanticscholar"
):
# CURRENTLY ASSUMES LATEX
abstract_prompt = f"""We've provided the `latex/template.tex` file to the project. We will be filling it in section by section.
@@ -465,7 +465,7 @@ def perform_writeup(
with open(osp.join(folder_name, "latex", "template.tex"), "r") as f:
draft = f.read()
prompt, done = get_citation_aider_prompt(
cite_client, cite_model, draft, _, num_cite_rounds
cite_client, cite_model, draft, _, num_cite_rounds, engine=engine
)
if done:
break
@@ -542,6 +542,13 @@ def perform_writeup(
],
help="Model to use for AI Scientist.",
)
parser.add_argument(
"--engine",
type=str,
default="semanticscholar",
choices=["semanticscholar", "openalex"],
help="Scholar engine to use.",
)
args = parser.parse_args()
if args.model == "claude-3-5-sonnet-20240620":
import anthropic
@@ -627,6 +634,6 @@ def perform_writeup(
generate_latex(coder, args.folder, f"{args.folder}/test.pdf")
else:
try:
perform_writeup(idea, folder_name, coder, client, client_model)
perform_writeup(idea, folder_name, coder, client, client_model, engine=args.engine)
except Exception as e:
print(f"Failed to perform writeup: {e}")
26 changes: 18 additions & 8 deletions launch_scientist.py
@@ -95,6 +95,13 @@ def parse_arguments():
default=50,
help="Number of ideas to generate",
)
parser.add_argument(
"--engine",
type=str,
default="semanticscholar",
choices=["semanticscholar", "openalex"],
help="Scholar engine to use.",
)
return parser.parse_args()


@@ -230,7 +237,7 @@ def do_idea(
edit_format="diff",
)
try:
perform_writeup(idea, folder_name, coder, client, client_model)
perform_writeup(idea, folder_name, coder, client, client_model, engine=args.engine)
except Exception as e:
print(f"Failed to perform writeup: {e}")
return False
@@ -373,12 +380,14 @@ def do_idea(
max_num_generations=args.num_ideas,
num_reflections=NUM_REFLECTIONS,
)
ideas = check_idea_novelty(
ideas,
base_dir=base_dir,
client=client,
model=client_model,
)
if not args.skip_novelty_check:
ideas = check_idea_novelty(
ideas,
base_dir=base_dir,
client=client,
model=client_model,
engine=args.engine,
)

with open(osp.join(base_dir, "ideas.json"), "w") as f:
json.dump(ideas, f, indent=4)
@@ -438,5 +447,6 @@ def do_idea(
print(f"Completed idea: {idea['Name']}, Success: {success}")
except Exception as e:
print(f"Failed to evaluate idea {idea['Name']}: {str(e)}")

import traceback
print(traceback.format_exc())
print("All ideas evaluated.")