From 499868637650dfbeb48bfd3feb5ba3ff865a6c73 Mon Sep 17 00:00:00 2001
From: Prashanth R
Date: Mon, 8 Jul 2024 23:04:08 -0700
Subject: [PATCH] Add Data LLM (#269)

* Add Data LLM

* Various API changes

* Drop underscore

* Rename

* Update imports

* PR comments

* PR comments

* PR comments

* PR comments

* PR Comments

---
 llm/README.md                 |  17 ++
 llm/data_gemma/__init__.py    |  36 +++
 llm/data_gemma/base.py        | 180 +++++++
 llm/data_gemma/baseline.py    |  45 ++++
 llm/data_gemma/datacommons.py | 172 +++++++++
 llm/data_gemma/google_api.py  | 125 +++++++++
 llm/data_gemma/prompts.py     | 468 ++++++++++++++++++++++++++++++++++
 llm/data_gemma/rag.py         | 130 ++++++++++
 llm/data_gemma/rig.py         | 190 ++++++++++++++
 llm/data_gemma/utils.py       | 109 ++++++++
 llm/data_gemma/validate.py    |  83 ++++++
 llm/setup.py                  |  52 ++++
 12 files changed, 1607 insertions(+)
 create mode 100644 llm/README.md
 create mode 100644 llm/data_gemma/__init__.py
 create mode 100644 llm/data_gemma/base.py
 create mode 100644 llm/data_gemma/baseline.py
 create mode 100644 llm/data_gemma/datacommons.py
 create mode 100644 llm/data_gemma/google_api.py
 create mode 100644 llm/data_gemma/prompts.py
 create mode 100644 llm/data_gemma/rag.py
 create mode 100644 llm/data_gemma/rig.py
 create mode 100644 llm/data_gemma/utils.py
 create mode 100644 llm/data_gemma/validate.py
 create mode 100644 llm/setup.py

diff --git a/llm/README.md b/llm/README.md
new file mode 100644
index 0000000..a4089a8
--- /dev/null
+++ b/llm/README.md
@@ -0,0 +1,17 @@
+# Data LLM
+
+This directory contains code to run inference on LLMs integrated with Data Commons.
+
+It includes a Python package called `data_gemma` that can be used to run
+inference with Gemma 2 (or other) LLMs fine-tuned for integration with Data
+Commons.
+
+## Install `data_gemma`
+
+```
+pip install git+https://github.com/datacommonsorg/tools#subdirectory=llm
+```
+
+## Examples
+
+TODO: Add links
diff --git a/llm/data_gemma/__init__.py b/llm/data_gemma/__init__.py
new file mode 100644
index 0000000..17c5651
--- /dev/null
+++ b/llm/data_gemma/__init__.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from data_gemma import base
+from data_gemma import baseline
+from data_gemma import datacommons
+from data_gemma import google_api
+from data_gemma import rag
+from data_gemma import rig
+
+# LLM related classes.
+LLM = base.LLM
+LLMCall = base.LLMCall
+GoogleAIStudio = google_api.GoogleAIStudio
+
+# Data Commons related classes.
+DataCommons = datacommons.DataCommons
+DataCommonsCall = base.DataCommonsCall
+
+# Flow related classes.
+Flow = base.Flow +FlowResponse = base.FlowResponse +BaselineFlow = baseline.BaselineFlow +RAGFlow = rag.RAGFlow +RIGFlow = rig.RIGFlow diff --git a/llm/data_gemma/base.py b/llm/data_gemma/base.py new file mode 100644 index 0000000..e3aa5f9 --- /dev/null +++ b/llm/data_gemma/base.py @@ -0,0 +1,180 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base Types.""" + +import dataclasses +from typing import Any, Protocol + + +DC = '__DC__' + + +@dataclasses.dataclass(frozen=True) +class Options: + """Common options for all APIs.""" + # Print messages to stdout. + verbose: bool = True + + def vlog(self, msg: str) -> None: + if self.verbose: + print(msg) + + +@dataclasses.dataclass(frozen=True) +class LLMCall: + prompt: str + response: str + duration_secs: float + error: str | None = None + + def debug(self, i: int = 0) -> str: + return ( + f'\n### Prompt {i} ###\n{self.prompt}\n' + f'### Response {i} ###\n{self.response}\n' + f'### LLM Duration {i} {self.duration_secs}s ###\n' + ) + + +@dataclasses.dataclass +class DataCommonsCall: + """A single request and response from Data Commons.""" + + id: int = 0 + query: str = '' + + # For point: val and date is set + val: str = '' + date: str = '' + # For table: table is set + table: str = '' + + unit: str = '' + title: str = '' + src: str = '' + url: str = '' + var: str = '' + score: float = 0.0 + + # The original LLM Value in case of RIG. + llm_val: str = '' + + def footnote(self) -> str: + return ( + f'Per {self.src}, value was {self.val}{self._dunit()} in {self.date}.' + f' See more at {self.url}' + ) + + def debug(self) -> str: + if not self.title: + return '' + if self.table: + return self.answer() + return ( + f'"{self.title}" was {self.val}{self._dunit()} in' + f' {self.date} per {self.src} ({self.var}:{self.score})' + ) + + def answer(self) -> str: + if self.table: + return f'{self.header()}\n{self.table}' + else: + return ( + f'According to {self.src}, "{self.title}" was' + f' {self.val}{self._dunit()} in {self.date}.' 
+ ) + + def header(self) -> str: + if self.table: + if self.unit: + header = f'{self.title} (unit: {self.unit})' + else: + header = f'{self.title}' + return f'{header}, according to {self.src}' + + return self.title + + def val_and_unit(self) -> str: + return f'{self.val}{self._dunit()}' + + def _dunit(self) -> str: + return ' ' + self.unit if self.unit else '' + + +@dataclasses.dataclass(frozen=True) +class FlowResponse: + """A response from Flow.""" + + main_text: str = '' + footnotes: str = '' + tables_str: str = '' + + llm_calls: list[LLMCall] = dataclasses.field(default_factory=list) + dc_calls: list[DataCommonsCall] = dataclasses.field(default_factory=list) + dc_duration_secs: float = 0.0 + + def duration_secs(self) -> float: + return ( + sum([r.duration_secs for r in self.llm_calls]) + self.dc_duration_secs + ) + + def answer(self, include_aux: bool = True) -> str: + """Returns a string representation of the response.""" + + lines = [] + lines.append(self.main_text) + + if include_aux and self.footnotes: + lines.append('\n#### FOOTNOTES ####') + lines.append(self.footnotes) + + if include_aux and self.tables_str: + lines.append('\n#### TABLES ####') + lines.append(self.tables_str) + + return '\n'.join(lines) + + def debug(self) -> str: + """Returns a string representation of the response.""" + + lines = [] + lines.append('\n\n## LLM INTERACTIONS ##\n') + for i, llm_response in enumerate(self.llm_calls): + lines.append(llm_response.debug(i)) + + lines.append('\n\n## DC INTERACTIONS ##\n') + for dc_response in self.dc_calls: + dbg = dc_response.debug() + if dbg: + lines.append(dbg) + lines.append(f'\n\n## DC Duration {self.dc_duration_secs} ##') + lines.append(f'\n\n## Total Duration {self.duration_secs()} ##') + + return '\n'.join(lines) + + def json(self) -> dict[str, Any]: + return dataclasses.asdict(self) + + +class LLM(Protocol): + + def query(self, prompt: str) -> LLMCall: + ... + + +class Flow(Protocol): + """A Flow integrates LLMs with DC in a certain way.""" + + def query(self, query: str) -> FlowResponse: + ... diff --git a/llm/data_gemma/baseline.py b/llm/data_gemma/baseline.py new file mode 100644 index 0000000..5cdf6ba --- /dev/null +++ b/llm/data_gemma/baseline.py @@ -0,0 +1,45 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Basic Flow.""" + +from data_gemma import base + + +class BaselineFlow(base.Flow): + """Baseline Flow.""" + + def __init__( + self, + llm: base.LLM, + verbose: bool = True, + ): + self.llm = llm + self.options = base.Options(verbose=verbose) + + def query( + self, + query: str, + in_context: bool = False, + prompt1: str = '', + prompt2: str = '', + ) -> base.FlowResponse: + self.options.vlog('... 
[DEFAULT] Calling BASE model') + resp = self.llm.query(query) + return base.FlowResponse( + main_text=resp.response, + llm_calls=[resp], + dc_duration_secs=0, + dc_calls=[], + ) diff --git a/llm/data_gemma/datacommons.py b/llm/data_gemma/datacommons.py new file mode 100644 index 0000000..ff7e4ab --- /dev/null +++ b/llm/data_gemma/datacommons.py @@ -0,0 +1,172 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data Commons.""" + +import concurrent.futures +import csv +import io +from typing import Any, Callable + +import requests + +from data_gemma import base +from data_gemma import utils + + +_BASE_URL = 'https://{env}.datacommons.org/nodejs/query' + +# Do not allow topics, use higher threshold (0.8). +_POINT_PARAMS = 'allCharts=1&mode=toolformer_rig&idx=base_uae_mem' + +# Allow topics, use lower threshold (0.7) +_TABLE_PARAMS = 'allCharts=1&mode=toolformer_rag&client=table&idx=base_uae_mem' + + +class DataCommons: + """Data Commons.""" + + def __init__( + self, + api_key: str, + verbose: bool = True, + num_threads: int = 10, + env: str = 'dev', + session: requests.Session | None = None, + ): + self.options = base.Options(verbose=verbose) + self.num_threads = num_threads + self.env = env + self.api_key = api_key + if not session: + session = requests.Session() + self.session = session + + def point(self, query: str) -> base.DataCommonsCall: + """Calls Data Commons API.""" + + self.options.vlog(f'... calling DC with "{query}"') + response = self._call_api(query, _POINT_PARAMS) + # Get the first LINE chart. + chart = None + for c in response.get('charts', []): + if c.get('type') == 'LINE': + chart = c + break + if not chart: + return base.DataCommonsCall(query=query) + + v = str(chart.get('highlight', {}).get('value', '')) + v = utils.round_float(v) + if not v: + return base.DataCommonsCall(query=query) + + u = chart.get('unit', '') + d = chart.get('highlight', {}).get('date') + s = _src(chart) + t = chart.get('title', '') + + svm = response.get('debug', {}).get('debug', {}).get('sv_matching', {}) + score = svm.get('CosineScore', [-1])[0] + var = svm.get('SV', [''])[0] + url = chart.get('dcUrl', '') + return base.DataCommonsCall( + query=query, + val=v, + unit=u, + title=t, + date=d, + src=s, + url=url, + var=var, + score=score, + ) + + def table(self, query: str) -> base.DataCommonsCall: + """Calls Data Commons API.""" + + self.options.vlog(f'... calling DC for table with "{query}"') + response = self._call_api(query, _TABLE_PARAMS) + # Get the first chart. 
+ charts = response.get('charts') + if not charts: + return base.DataCommonsCall(query=query) + chart = charts[0] + + data_csv = chart.get('data_csv', '') + rows = list(csv.reader(io.StringIO(data_csv))) + if not data_csv or not rows: + return base.DataCommonsCall(query=query) + + u = chart.get('unit', '') + s = _src(chart) + t = chart.get('title', '') + + parts = [] + parts.append(' | '.join(rows[0])) + parts.append('-' * len(parts[-1])) + for row in rows[1:]: + row = [utils.round_float(v) for v in row] + parts.append(' | '.join(row)) + parts.append('\n') + table_str = '\n'.join(parts) + + svm = response.get('debug', {}).get('debug', {}).get('sv_matching', {}) + score = svm.get('CosineScore', [-1])[0] + var = svm.get('SV', [''])[0] + url = chart.get('dcUrl', '') + return base.DataCommonsCall( + query=query, + unit=u, + title=t, + src=s, + table=table_str, + url=url, + var=var, + score=score, + ) + + def calln( + self, queries: list[str], func: Callable[[str], base.DataCommonsCall] + ) -> dict[str, base.DataCommonsCall]: + """Calls Data Commons API in parallel if needed.""" + + if self.num_threads == 1: + results = [func(q) for q in queries] + else: + # TODO: Check why this ~breaks in Colab Borg runtime + with concurrent.futures.ThreadPoolExecutor(self.num_threads) as executor: + futures = [executor.submit(func, query) for query in queries] + results = [f.result() for f in futures] + + q2resp: dict[str, base.DataCommonsCall] = {} + for i, (q, r) in enumerate(zip(queries, results)): + r.id = i + 1 + q2resp[q] = r + return q2resp + + def _call_api(self, query: str, extra_params: str) -> Any: + query = query.strip().replace(' ', '+') + url = _BASE_URL.format(env=self.env) + f'?&q={query}&{extra_params}' + if self.api_key: + url = f'{url}&apikey={self.api_key}' + # print(f'DC: Calling {url}') + return self.session.get(url).json() + + +def _src(chart: dict[str, Any]) -> str: + srcs = chart.get('srcs', [{}]) + if not srcs: + return '' + return srcs[0].get('name', '') diff --git a/llm/data_gemma/google_api.py b/llm/data_gemma/google_api.py new file mode 100644 index 0000000..33acbdc --- /dev/null +++ b/llm/data_gemma/google_api.py @@ -0,0 +1,125 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""LLM Interface.""" + +import json +import logging +import time +from typing import Any + +import requests + +from data_gemma import base + + +_REQ_DATA = { + 'contents': [{ + 'parts': [{ + 'text': '', + }], + }], + 'generationConfig': { + 'temperature': 0.1, + }, + 'safetySettings': [ + {'category': 'HARM_CATEGORY_HARASSMENT', 'threshold': 'BLOCK_NONE'}, + {'category': 'HARM_CATEGORY_HATE_SPEECH', 'threshold': 'BLOCK_NONE'}, + { + 'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', + 'threshold': 'BLOCK_NONE', + }, + { + 'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', + 'threshold': 'BLOCK_NONE', + }, + ], +} + + +class GoogleAIStudio(base.LLM): + """Google AI Studio.""" + + def __init__( + self, + model: str, + api_keys: list[str], + verbose: bool = True, + session: requests.Session | None = None, + ): + self.keys = api_keys + if not session: + session = requests.Session() + self.session: requests.Session = session + self.next_key_idx = 0 + self.options = base.Options(verbose=verbose) + self.model = model + + def query(self, prompt: str) -> base.LLMCall: + req_data = _REQ_DATA.copy() + + # set the params. + req_data['generationConfig']['temperature'] = 0.1 + req_data['contents'][0]['parts'][0]['text'] = prompt + + # Make API request. + req = json.dumps(req_data) + + start = time.time() + self.options.vlog( + f'... calling AIStudio {self.model} "{prompt[:50].strip()}..."' + ) + resp = _call_api(self.session, self.model, self._get_key(), req) + t = round(time.time() - start, 3) + ans = '' + err = '' + if ( + 'candidates' in resp + and resp['candidates'] + and 'content' in resp['candidates'][0] + and 'parts' in resp['candidates'][0]['content'] + and resp['candidates'][0]['content']['parts'] + and 'text' in resp['candidates'][0]['content']['parts'][0] + ): + ans = resp['candidates'][0]['content']['parts'][0]['text'] + elif 'error' not in resp: + err = 'Got empty response' + logging.warning(err) + else: + err = json.dumps(resp) + logging.error('%s', err) + + return base.LLMCall(prompt=prompt, response=ans, duration_secs=t, error=err) + + def _get_key(self): + key = self.keys[self.next_key_idx] + self.next_key_idx += 1 + if self.next_key_idx >= len(self.keys): + self.next_key_idx = 0 + return key + + +_BASE_URL = 'https://generativelanguage.googleapis.com/v1beta/models' +_API_HEADER = {'content-type': 'application/json'} + + +def _call_api( + session: requests.Session, model: str, key: str, req_data: str +) -> Any: + r = session.post( + f'{_BASE_URL}/{model}:generateContent?key={key}', + data=req_data, + headers=_API_HEADER, + ) + return r.json() diff --git a/llm/data_gemma/prompts.py b/llm/data_gemma/prompts.py new file mode 100644 index 0000000..d0f114a --- /dev/null +++ b/llm/data_gemma/prompts.py @@ -0,0 +1,468 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""List of all prompts used in Data Gemma.""" + +RIO_IN_CONTEXT_PROMPT = """ +Your task is to annotate every statistic in the given text with a `__DC__` +query that can retrieve the statistic. 
The query should be about metrics +on topics like demographics, economy, education, health, and so on that are +associated with geographical places (like USA, California, Miami, etc.). + +Concretely, every occurrence of a statistical value for a metric in a place +should be replaced with `[__DC__("query") --> "stat"]`, where "query" +must include a metric, a place name and optional date. And "stat" is the +statistical value that originally occurred in the text. Do not annotate +values that are dates ("founded in 1760") and ranks ("10th largest by area"). + +The `__DC__()` calls MUST be in place of the statistical value in the text. +And DO NOT modify sentences that have no statistical data. + +Below is an example of an INPUT and the corresponding annotated OUTPUT. + +INPUT: + +Question:- Tell me one statistic about California, San Francisco, Alabama and the US. +Answer:- +California is 1st as the nation's most populous state, with about 39 million people in 2020. +In San Francisco, the diabetes rate is 9.2 cases per 10000 people. +San Francisco and the surrounding San Francisco Bay Area are a global center of economic activity and the arts and sciences. +In 1861, Alabama seceded from the United States to become part of the Confederate States of America. +As of 2022, the United States receives approximately 81% of its energy from fossil fuel and the largest source of the country's energy came from petroleum (35.8%), followed by natural gas (33.4%) and renewable sources (13.3%). + +OUTPUT: + +Question:- Tell me one statistic about California, San Francisco, Alabama and the US. +Answer:- +California is 1st as the nation's most populous state, with about [__DC__("what was the population of california in 2020?") --> "39 million"] people. +In San Francisco, the diabetes rate is [__DC__("what is the prevalence of diabetes in San Francisco?") --> "9.2 cases per 10000 people"]. +San Francisco and the surrounding San Francisco Bay Area are a global center of economic activity and the arts and sciences. +In 1861, Alabama seceded from the United States to become part of the Confederate States of America. +As of 2022, the United States receives approximately [__DC__("what percentage of energy in the United States comes from fossil fuels in 2022?") --> "81%"] of its energy from fossil fuel and the largest source of the country's energy came from petroleum ([__DC__("what percentage of energy in the United States comes from petroleum in 2022?") --> "35.8%"]), followed by natural gas ([__DC__("what percentage of energy in the United States comes from natural gas in 2022?") --> "33.4%"]) and renewable sources ([__DC__("what percentage of energy in the United States comes from renewable sources in 2022?") --> "13.3%"]). + +INPUT: + +Question:- {question} +Answer:- +{answer} + +OUTPUT: + +""" + + +RIG_IN_CONTEXT_PROMPT = """ +Help me prepare a training set for the purpose of fine tuning an LLM with the +intent of looking up recent statistical facts from a knowledge base. + +Specifically, I want you to attempt responding to question normally in a few +paragraphs, using data and numbers when necessary. You MUST annotate the +sub-strings representing the numbers and statistics with questions asked to the +Data Commons API. The responses to those questions from Data Commons should help +fill in the accurate numeric value. 
+ +Every intended occurrence of a numeric fact should be annotated with +`[__DC__("question") --> "answer"]`, where "question" is the question to +the Knowledge Base, and "answer" is a numeric fact with units. + +Below are a couple of examples. + +### Example 1 ### + +``` +Question: Tell me about health outcomes in California + +Answer: +California's health landscape, much like its diverse geography, showcases +contrasting realities. While boasting pockets of excellence, stark disparities +paint a complex picture. Let's delve into the data: + +Strengths: + +* Longevity: Californians enjoy a high life expectancy of [__DC__("What is the life expectancy in California?") --> "81 years"], exceeding the national average. +* Healthy Habits: Low smoking rates ([__DC__("What is the prevalence of smoking in California?") --> "5%"]) and obesity prevalence ([__DC__("What is the prevalence of obesity in California?") --> "23.3%"]) contribute to better health outcomes. +* Safety: Violent crime has significantly decreased ([__DC__("How much has violent crime reduced in California since 1990?") --> "54%"] since 1990), creating safer communities. +* Investment: California spends the most per capita on healthcare nationally, indicating commitment to health. + +Challenges: + +* Inequities: The picture darkens when considering racial and ethnic disparities. + * Black Californians have the lowest life expectancy ([__DC__("What is the life expectancy of black people in California?") --> "75.1 years"]), significantly lower than Asians ([__DC__("What is the life expectancy of asian people in California?") --> "86.3 years"]). + * Communities of color experience higher rates of chronic diseases, premature death, and limited healthcare access. +* Access: Despite high spending, [__DC__("What is the percentage of uninsured population in California?") --> "7.1%"] of residents remain uninsured, exceeding the national average. +* Environment: Air pollution, especially in urban areas, poses a major health risk. + * Droughts and wildfires are becoming increasingly frequent, impacting mental and physical well-being. + +Moving forward: + +* Bridging the healthcare gap: Expanding access to quality care, particularly in underserved communities, is crucial. +* Addressing social determinants: Tackling poverty, education inequality, and poor housing conditions holds the key to lasting change. +* Environmental protection: Combating air pollution and mitigating the effects of climate change are essential for future health. + +California's journey towards a healthier future requires acknowledging and addressing these contrasting realities. +By closing the gap between its strengths and challenges, the Golden State can ensure a brighter future for all its residents, regardless of background or zip code. +``` + +### Example 2 ### + +``` +Question: Tell me about the education system in New York state and the overall school districts. Also, how does it compare with other regions in the US? +Answer: New York State's education system is a complex and diverse landscape, encompassing over [__DC__("How many schools districts are there in New York state?") --> "700"] school districts and serving nearly [__DC__("How many students are served by the schools districts in New York state?") --> "2.5 million"] students. 
To provide a comprehensive overview, let's delve into key aspects and compare them with national averages: + +Funding: + +* New York ranks among the top states in terms of education spending, with an average expenditure of [__DC__("What is the average education spending per pupil in New York?") --> "$23,000"] per pupil, significantly higher than the national average of [__DC__("What is the average education spending per pupil in the US?") --> "$12,000"]. + +Student Performance: + +* New York students consistently perform above the national average on standardized tests. In 2019, [__DC__("What percentage of New York students are proficient in math in 2019?") --> "62%"] of students were proficient in math, compared to the national average of [__DC__("What percentage of US students are proficient in math in 2019?") --> "37%"]. +* Similarly, [__DC__("What percentage of New York students are proficient in reading?") --> "66%"] of students were proficient in reading, exceeding the national average of [__DC__("What percentage of US students are proficient in reading?") --> "35%"]. + +Graduation Rates: + +* New York's graduation rate has steadily increased over the past decade, reaching [__DC__("What is the graduation rate in New York in 2019?") --> "85%"] in 2019. This surpasses the national average of [__DC__("What is the graduation rate in the US?") --> "84%"]. + +Teacher Quality: + +* New York has a rigorous teacher certification process, ensuring that educators meet high standards. The state also invests in professional development opportunities for teachers, contributing to their effectiveness. + +Challenges: + +* Despite these strengths, New York faces challenges, including persistent achievement gaps between different student groups and a shortage of qualified teachers in certain subjects. + +Comparison with Other Regions: + +* New York's education system compares favorably with other regions in the US. Its funding levels, student performance, and graduation rates are generally higher than the national average. +* However, there is still room for improvement, particularly in addressing equity issues and ensuring that all students have access to high-quality education. + +Overall, New York State's education system is well-funded and produces strong student outcomes. While there are challenges to address, the state's commitment to education and its students is evident. +``` + +### Caveats ### + +AVOID the following bugs in your annotated responses. + +1. Do not annotate dates. For example: + +`In 2019, India saw wet bulb temperatures reach [__DC__("What was the max wet bulb temperature in 2019 in India?") --> "37 degrees Celsius"].` + +2. Do not skip place names from the main text, even if they are included in the "question". For example: + +`Life expectancy at birth has increased significantly in many African countries. For example, in Nigeria, life expectancy has increased from [__DC__("What was the life expectancy in Nigeria in 2000?") --> "46.6 years"] to [__DC__("What is the current life expectancy in Nigeria?") --> "55.4 years"].` + +3. Do not skip dates from the main text. For example: + +`By 2050, an estimated [__DC__("How many people in Europe will be affected by coastal flooding by 2100?") --> "3 million"] people in Europe will be affected by coastal flooding annually.` + +4. Do not repeat stats or other words that appear in the "answer" again in the main text. 
For example: + +`By 2000, the widowed population in San Francisco had grown to [__DC__("What was the widowed population in San Francisco in 2020?") --> "70,000 people"].` + + +### Answer this question ### + +Question: {sentence} +Answer: +""" + + +RAG_IN_CONTEXT_PROMPT = """ +Given a QUERY below, your task is to come up with a maximum of 25 +STATISTICAL QUESTIONS that help in answering QUERY. + +Here are the only forms of STATISTICAL QUESTIONS you can generate: + +1. "What is $METRIC in $PLACE?" +2. "What is $METRIC in $PLACE $PLACE_TYPE?" +3. "How has $METRIC changed over time in $PLACE $PLACE_TYPE?" + +where: +- $METRIC should a publicly accessible metric on societal topics around + demographics, economy, health, education, environment, etc. Examples are + unemployment rate, life expectancy, etc. +- $PLACE is the name of a place like California, World, Chennai, etc. +- $PLACE_TYPE is an immediate child type within $PLACE, like counties, states, + districts, etc. + +Your response should only include the questions, one per line without any +numbering or bullet! If you cannot come up with statistical questions to ask, +return an empty response. + +NOTE: Do not repeat questions. Limit the number of questions to 25. + +If QUERY asks about multiple concepts (e.g., income and diseases), make sure +the questions cover all the concepts. + +[Start of Examples] + +QUERY: Which grades in the middle school have the lowest enrollment in Palo Alto? +STATISTICAL QUESTIONS: +What is the number of students enrolled in Grade 6 in Palo Alto schools? +What is the number of students enrolled in Grade 7 in Palo Alto schools? +What is the number of students enrolled in Grade 8 in Palo Alto schools? + +QUERY: Which industries have grown the most in California? +STATISTICAL QUESTIONS: +How have jobs in agriculture changed over time in California? +How has GDP of agriculture sector changed over time in California? +How have jobs in information and technology changed over time in California? +How has GDP of information and technology sector changed over time in California? +How have jobs in the government changed over time in California? +How has GDP of the government sector changed over time in California? +How have jobs in healthcare changed over time in California? +How has GDP of healthcare sector changed over time in California? +How have jobs in entertainment changed over time in California? +How has GDP of entertainment sector changed over time in California? +How have jobs in retail trade changed over time in California? +How has GDP of retail trade sector changed over time in California? +How have jobs in manufacturing changed over time in California? +How has GDP of manufacturing sector changed over time in California? +How have jobs in education services changed over time in California? +How has GDP of education services sector changed over time in California? + +QUERY: Which state in the US has the most asian population? +STATISTICAL QUESTIONS: +What is the number of asian people in US states? + +QUERY: Do specific health conditions affect the richer California counties? +STATISTICAL QUESTIONS: +What is the median income among California counties? +What is the median house price among California counties? +What is the prevalence of obesity in California counties? +What is the prevalence of diabetes in California counties? +What is the prevalence of heart disease in California counties? +What is the prevalence of arthritis in California counties? +What is the prevalence of asthma in California counties? 
+What is the prevalence of chronic kidney disease in California counties? +What is the prevalence of chronic obstructive pulmonary disease in California counties? +What is the prevalence of coronary heart disease in California counties? +What is the prevalence of high blood pressure in California counties? +What is the prevalence of high cholesterol in California counties? +What is the prevalence of stroke in California counties? +What is the prevalence of poor mental health in California counties? +What is the prevalence of poor physical health in California counties? + + +[End of Examples] + +QUERY: {sentence} +STATISTICAL QUESTIONS: +""" + + +RAG_IN_CONTEXT_PROMPT_WITH_VARS = """ +Given a 'Query' below, your task is to come up with a maximum of 25 +'Statistical Questions' that relate to 'Query'. + +Here are the only forms of 'Statistical Questions' you can generate: + +1. What is $METRIC in $PLACE? +2. What is $METRIC in $PLACE $PLACE_TYPE? +3. How has $METRIC changed over time in $PLACE $PLACE_TYPE? + +Where: +- $METRIC should only be from the 'Metrics List' given below. +- $PLACE is the name of a place like California, World, Chennai, etc. +- $PLACE_TYPE is first-level child type within $PLACE, like counties or + districts if $PLACE is a state, states if $PLACE is a country, etc. + +Your response should only include the questions, one per line, without any +numbering or bullets! If you cannot come up with 'Statistical Questions' only +using the 'Metrics List' below, return an empty response. + +NOTE: Do not repeat questions. Limit the number of questions to 25 and +order the questions from most relevant to least relevant. + +If "Query" asks about multiple concepts (e.g., income and diseases), make sure +the questions cover all the concepts. + +[Start of Examples] + +Query: Tell me about life expectancy. +Statistical Questions: +What is the people life expectancy in the world? +How has people life expectancy changed over time in the world countries? + +Query: Which state in the US has the most asian population? +Statistical Questions: +What is the number of asian people in US states? +How has the number of asian people changed over time in US states? + +Query: Which grades in the middle school have the lowest enrollment in Palo Alto? +Statistical Questions: +What is the number of students enrolled in Grade 6 in Palo Alto schools? +What is the number of students enrolled in Grade 7 in Palo Alto schools? +What is the number of students enrolled in Grade 8 in Palo Alto schools? + +QUERY: Do specific health conditions affect the richer California counties? +STATISTICAL QUESTIONS: +What is the median income among California counties? +What is the median house price among California counties? +What is the prevalence of obesity in California counties? +What is the prevalence of diabetes in California counties? +What is the prevalence of heart disease in California counties? +What is the prevalence of arthritis in California counties? +What is the prevalence of asthma in California counties? +What is the prevalence of chronic kidney disease in California counties? +What is the prevalence of chronic obstructive pulmonary disease in California counties? +What is the prevalence of coronary heart disease in California counties? +What is the prevalence of high blood pressure in California counties? +What is the prevalence of high cholesterol in California counties? +What is the prevalence of stroke in California counties? +What is the prevalence of poor mental health in California counties? 
+What is the prevalence of poor physical health in California counties? + +[End of Examples] + + +[Start of Metrics List] + +Here is a list of possible METRIC values: + +``` +{metrics_list} +``` + +[End of Metrics List] + + +Query: {sentence} +Statistical Questions: +""" + + +RAG_FINE_TUNED_PROMPT = """" +Your role is that of a Question Generator. Given Query below, come up with a +maximum of 25 Statistical Questions that help in answering Query. + +These are the only forms of Statistical Questions you can generate: +1. What is $METRIC in $PLACE? +2. What is $METRIC in $PLACE $PLACE_TYPE? +3. How has $METRIC changed over time in $PLACE $PLACE_TYPE? + +where, +- $METRIC should a metric on societal topics like demographics, economy, health, + education, environment, etc. Examples are unemployment rate and + life expectancy. +- $PLACE is the name of a place like California, World, Chennai, etc. +- $PLACE_TYPE is an immediate child type within $PLACE, like counties, states, + districts, etc. + +Your response should only have questions, one per line, without any numbering +or bullet. + +If you cannot come up with Statistical Questions to ask for a Query, return an +empty response. + +Query: {sentence} +Statistical Questions: +""" + + +RAG_FINAL_ANSWER_PROMPT = """ +Using statistics from the tables below, respond to the query: "{sentence}" + +In your response, when using statistics from a table, please cite the table +by its ID, for example, "Table 1". + +If necessary to answer the query, perform simple calculations on the statistics, +like adding or subtracting statistics, computing growth rates from statistics +over time, etc. + +If you cannot answer the query based on the provided tables, start your response with: +"The tables do not have the relevant information to answer the query." + +``` +{table_str} +``` + +So now, using statistics from the tables above, respond to the query: "{sentence}" +""" + + +DC_QA_VALIDATION = """ +You will be provided with a list of up to 20 question-answer pairs, each +identified by an ID like [[QA1]]. You must return each ID whose answer is +relevant to its question, one per line. If none of the answers are relevant, +return `[[EMPTY]]`. + +Here is an example INPUT and OUTPUT: + +## INPUT ## +[[QA1]]: + Question: "What is the average education spending per pupil in New York?" + Answer: "% Govt Expenditure on Education in United States" +[[QA2]] + Question: "What is the Gini coefficient in Chile?" + Answer: "Gini Index of Economic Activity of a Population in Chile" +[[QA3]] + Question: "How many people work in health care jobs in Nevada?" + Answer: "Population of Health Care Workers in Nevada" + +## OUTPUT ## +[[QA2]] +[[QA3]] + + +## INPUT ## +{input} + +## OUTPUT ## +""" + + +LLM_JUDGE_PROMPT = """ +[System] + +Please act as an impartial judge and evaluate the quality of the response +provided by an AI Assistant that annotates statistical numbers with questions. + +In the text, you will find patterns of the form: [__DC__("QUERY") --> "ANS"]. +Where, `ANS` is a statistical value, and `QUERY` is a query that can be answered +with `ANS`. + +Your evaluation should consider whether the `__DC__` annotations follow these +constraints: + +(C1) `QUERY` can refer to a very wide variety of measures related to demographics, + economy, education, health, etc. However, it should accurately describe the + statistic involved. `QUERY` must include a place name of a city, state, + country, continent, etc, or words that represent the world (like global). 
+ +(C2) `ANS` must not be empty or have the the word "stat". `ANS` must have a + numeric value, may include percentages, and may additionally have + non-numeric letters for units, currency symbols, etc. + +Provide a classification of the answer quality like "[[]]", +where the can be: +- GOOD: The annotations in the answer adhere to the above rules. +- BAD: If some annotations do not follow the above rules. + +NOTE: You do not need to judge the correctness of the `ANS` value. Duplicate +annotations are fine. + +For example, if the answer is GOOD, the first line of response should be +"[[GOOD]]". + +Then, list only the bad `[__DC__("QUERY") --> "ANS"]` values, one per line, +and concisely point out what is wrong. + +You don't have to provide a revised version of the answer. + +[Start of Assistant's Answer] + +{answer} + +[End of Assistant's Answer] +""" diff --git a/llm/data_gemma/rag.py b/llm/data_gemma/rag.py new file mode 100644 index 0000000..0f2c42a --- /dev/null +++ b/llm/data_gemma/rag.py @@ -0,0 +1,130 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RAG Flow.""" + +import logging +import time + +from data_gemma import base +from data_gemma import datacommons +from data_gemma import prompts +from data_gemma import validate + +_MAX_QUESTIONS = 25 + + +class RAGFlow(base.Flow): + """Retrieval Augmented Generation.""" + + def __init__( + self, + llm_question: base.LLM, + llm_answer: base.LLM, + data_fetcher: datacommons.DataCommons, + verbose: bool = True, + in_context: bool = False, + validate_dc_responses: bool = False, + metrics_list: str = '', + ): + self.llm_question = llm_question + self.llm_answer = llm_answer + self.data_fetcher = data_fetcher + self.options = base.Options(verbose=verbose) + self.in_context = in_context + self.validate_dc_responses = validate_dc_responses + self.metrics_list = metrics_list + + def query( + self, + query: str, + ) -> base.FlowResponse: + + # + # First call FT or V LLM model to get questions for Retrieval + # + if self.in_context: + if self.metrics_list: + prompt = prompts.RAG_IN_CONTEXT_PROMPT_WITH_VARS + self.options.vlog( + '... [RAG] Calling UNTUNED model for DC ' + 'questions with all DC vars in prompt' + ) + ques_resp = self.llm_question.query( + prompt.format(metrics_list=self.metrics_list, sentence=query) + ) + else: + prompt = prompts.RAG_IN_CONTEXT_PROMPT + self.options.vlog('... [RAG] Calling UNTUNED model for DC questions') + ques_resp = self.llm_question.query(prompt.format(sentence=query)) + else: + prompt = prompts.RAG_FINE_TUNED_PROMPT + self.options.vlog('... [RAG] Calling FINETUNED model for DC questions') + ques_resp = self.llm_question.query(prompt.format(sentence=query)) + llm_calls = [ques_resp] + if not ques_resp.response: + return base.FlowResponse(llm_calls=llm_calls) + + questions = [q.strip() for q in ques_resp.response.split('\n') if q.strip()] + questions = list(set(questions))[:_MAX_QUESTIONS] + + self.options.vlog('... 
[RAG] Making DC Calls') + start = time.time() + try: + q2resp = self.data_fetcher.calln(questions, self.data_fetcher.table) + except Exception as e: + logging.warning(e) + q2resp = {} + pass + dc_duration = time.time() - start + + if self.validate_dc_responses: + q2resp = validate.run_validation( + q2resp, self.llm_answer, self.options, llm_calls + ) + + table_parts: list[str] = [] + table_titles = set() + dc_calls = [] + for resp in q2resp.values(): + tidx = len(dc_calls) + 1 + if resp.table and resp.title not in table_titles: + table_parts.append(f'Table {tidx}: {resp.answer()}') + table_titles.add(resp.title) + resp.id = tidx + dc_calls.append(resp) + if table_parts: + prompt = prompts.RAG_FINAL_ANSWER_PROMPT + tables_str = '\n'.join(table_parts) + final_prompt = prompt.format(sentence=query, table_str=tables_str) + else: + self.options.vlog('... [RAG] No stats found!') + final_prompt = query + tables_str = '' + + self.options.vlog('... [RAG] Calling UNTUNED model for final response') + ans_resp = self.llm_answer.query(final_prompt) + llm_calls.append(ans_resp) + if not ans_resp.response: + return base.FlowResponse( + llm_calls=llm_calls, dc_duration_secs=dc_duration + ) + + return base.FlowResponse( + main_text=ans_resp.response, + tables_str=tables_str, + llm_calls=llm_calls, + dc_duration_secs=dc_duration, + dc_calls=dc_calls, + ) diff --git a/llm/data_gemma/rig.py b/llm/data_gemma/rig.py new file mode 100644 index 0000000..6bc9dae --- /dev/null +++ b/llm/data_gemma/rig.py @@ -0,0 +1,190 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RIG Flow.""" + +import copy +import logging +import re +import time + +from data_gemma import base +from data_gemma import datacommons +from data_gemma import prompts +from data_gemma import validate + + +_DC_PATTERN = r'\[__DC__\("([^"]+)"\) --> "([^"]*)"\]?' + +# 5% threshold +_DIFF_THRESHOLD = 0.05 + + +class RIGFlow(base.Flow): + """Retrieval Interleaved Answering.""" + + def __init__( + self, + llm: base.LLM, + data_fetcher: datacommons.DataCommons, + verbose: bool = True, + in_context: bool = False, + validate_dc_responses: bool = False, + ): + self.llm = llm + self.data_fetcher = data_fetcher + self.options = base.Options(verbose=verbose) + self.in_context = in_context + self.validate_dc_responses = validate_dc_responses + + def query( + self, + query: str, + ) -> base.FlowResponse: + + if self.in_context: + self.options.vlog('... [RIG] Calling UNTUNED Model') + prompt = prompts.RIG_IN_CONTEXT_PROMPT + llm_resp = self.llm.query(prompt.format(sentence=query)) + else: + self.options.vlog('... [RIG] Calling FINETUNED Model') + llm_resp = self.llm.query(query) + if not llm_resp.response: + logging.error('FAILED: %s', query) + return base.FlowResponse(llm_calls=[llm_resp]) + + # Make DC calls. + llm_text = llm_resp.response + q2llmval, q2resp, dc_duration = self._call_dc(llm_text) + llm_calls = [llm_resp] + + # Sanity check DC call and response using LLM, and keep only the "good" + # ones. 
+ if self.validate_dc_responses: + q2resp = validate.run_validation( + q2resp, self.llm, self.options, llm_calls + ) + + self.options.vlog('... [RIG] Calling DC Evaluate') + llm_text, footnotes, dc_calls = self._evaluate(llm_text, q2llmval, q2resp) + + return base.FlowResponse( + main_text=llm_text, + footnotes='\n'.join(footnotes), + llm_calls=llm_calls, + dc_duration_secs=dc_duration, + dc_calls=dc_calls, + ) + + def _call_dc( + self, llm_text: str + ) -> tuple[dict[str, list[str]], dict[str, base.DataCommonsCall], float]: + """Calls DC.""" + + start = time.time() + + q2llmval: dict[str, list[str]] = {} + for match in re.findall(_DC_PATTERN, llm_text): + q2llmval.setdefault(match[0], []).append(match[1]) + + try: + q2resp = self.data_fetcher.calln( + list(q2llmval.keys()), self.data_fetcher.point + ) + except Exception as e: + logging.warning(e) + q2resp = {} + pass + + return q2llmval, q2resp, time.time() - start + + def _evaluate( + self, + text: str, + q2llmval: dict[str, list[str]], + q2resp: dict[str, base.DataCommonsCall], + ) -> tuple[str, list[str], list[base.DataCommonsCall]]: + """Evaluates a text contained DC Calls.""" + + def _rtag(txt: str, r: base.DataCommonsCall) -> str: + return f'[{base.DC}#{r.id}({txt})]' + + dc_calls = [] + footnote_map = {} + for q, orig_resp in q2resp.items(): + llm_vals = q2llmval[q] + + for llmval in llm_vals: + resp = copy.deepcopy(orig_resp) + + resp.id = len(dc_calls) + 1 + resp.llm_val = llmval + dcval = resp.val_and_unit() + + idx = -1 + if dcval: + idx = len(footnote_map) + 1 + if q not in footnote_map: + footnote_map[q] = (idx, f'[{idx}] - {resp.footnote()}') + else: + idx = footnote_map[q][0] + + orig = f'[__DC__("{q}") --> "{llmval}"]' + if not llmval: + # If LLM answer was empty! + if dcval: + new = f'{dcval} [{idx}] ||' + else: + new = '--- || ---' + text = text.replace(orig, _rtag(new, resp), 1) + elif dcval: + if _flag_value(resp.val, llmval): + new = f'{dcval} [{idx}]* || {llmval}' + else: + new = f'{dcval} [{idx}] || {llmval}' + text = text.replace(orig, _rtag(new, resp), 1) + else: + new = f'|| {llmval}' + text = text.replace(orig, _rtag(new, resp), 1) + + dc_calls.append(resp) + + footnotes = [ + v[1] for v in sorted(footnote_map.values(), key=lambda x: x[0]) + ] + + return text, footnotes, dc_calls + + +def _clean_float(text: str) -> float: + return float(re.sub(r'[^0-9.]', '', text)) + + +def _flag_value(dcv: str, llmv: str) -> bool: + """Compares dc and llm values and flags beyond a threshold.""" + try: + for t, v in [ + (' million', 1000000), + (' billion', 1000000000), + (' trillion', 1000000000000), + ]: + if t in llmv: + llmv = str(_clean_float(llmv.replace(' million', '')) * v) + break + llmv = _clean_float(llmv) + dcv = float(dcv) + pct_diff = ((dcv - llmv) / llmv) if llmv != 0 else 1.0 + except: + return False + return pct_diff > _DIFF_THRESHOLD or pct_diff < -_DIFF_THRESHOLD diff --git a/llm/data_gemma/utils.py b/llm/data_gemma/utils.py new file mode 100644 index 0000000..b8780d0 --- /dev/null +++ b/llm/data_gemma/utils.py @@ -0,0 +1,109 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils.""" + +import csv +import os +import textwrap + + +# Use a larger field size limit since we can have longer text in training +# data CSVs. +_LARGE_FIELD_SIZE = 1048576 + + +def get_header(in_file): + with open(in_file, 'r') as f: + csvr = csv.reader(f) + header = next(csvr) + return header + + +def round_float(v: str) -> str: + try: + v = int(v) + return str(v) + except Exception: + try: + v = float(v) + return str(round(v, 4)) + except Exception: + return v + + +# +# Returns IDs from links_file that match the given statuses. +# +def get_matched_ids( + links_file: str, statuses: set[str], id_col: str, status_col: str +) -> set[str]: + if not links_file or not statuses: + return set() + matched_ids = set() + with open(links_file, 'r') as f: + for row in csv.DictReader(f): + s = row.get(status_col, '') + if s in statuses: + matched_ids.add(row[id_col]) + return matched_ids + + +def load_csv( + csv_file: str, id_column: str, aux_id_column: str = '' +) -> dict[str, dict[str, str]]: + """Loads an ID keyed csv file.""" + + csv.field_size_limit(_LARGE_FIELD_SIZE) + results = {} + if os.path.exists(csv_file): + with open(csv_file, 'r') as f: + results = {} + for row in csv.DictReader(f): + k = row[id_column].strip() + if aux_id_column: + k = f'{k}/{row[aux_id_column].strip()}' + if k: + results[k] = row + return results + + +def checkpoint_csv( + csv_file: str, key2row: dict[str, dict[str, str]], header: list[str] +) -> None: + """Checkpoint an ID keyed csv file.""" + with open(csv_file, 'w', newline='') as f: + csvw = csv.DictWriter(f, fieldnames=header) + csvw.writeheader() + csvw.writerows([key2row[k] for k in sorted(key2row.keys())]) + + +def clean_rig_in_context_response(text: str) -> str: + parts = text.split('Answer:-', 1) + if len(parts) > 1: + return parts[1].strip() + return parts[0].strip() + + +def narrow_print(text: str) -> str: + wrapper = textwrap.TextWrapper( + width=80, break_long_words=False, break_on_hyphens=False + ) + parts = [] + for line in text.split('\n'): + if not line: + parts.append('') + else: + parts.append(wrapper.fill(line)) + return '\n'.join(parts) diff --git a/llm/data_gemma/validate.py b/llm/data_gemma/validate.py new file mode 100644 index 0000000..da81230 --- /dev/null +++ b/llm/data_gemma/validate.py @@ -0,0 +1,83 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validation Flow.""" + +import logging + +from data_gemma import base +from data_gemma import prompts + + +def run_validation( + q2resp: dict[str, base.DataCommonsCall], + llm: base.LLM, + options: base.Options, + llm_calls: list[base.LLMCall], +) -> dict[str, base.DataCommonsCall]: + """Runs DC QA validation.""" + queries, input_text = _dc_qa_validation_input( + {q: r.title for q, r in q2resp.items()} + ) + if queries: + llm_resp2 = llm.query(prompts.DC_QA_VALIDATION.format(input=input_text)) + options.vlog(f'... 
[Validate] {input_text}\n{llm_resp2.response}') + if not llm_resp2.response: + logging.error('FAILED: %s', input_text) + queries = [] + else: + llm_calls.append(llm_resp2) + try: + onum = len(queries) + queries = _dc_qa_validation_check(llm_resp2.response, queries) + if len(queries) < onum: + options.vlog( + f'... [Validate] Dropped answers: {onum} --> {len(queries)}' + ) + except: + logging.error('FAILED: %s', llm_resp2.response) + queries = [] + else: + options.vlog('... [Validate] empty queries!') + + queries = set(queries) + q2resp = { + q: r if q in queries else base.DataCommonsCall(query=q) + for q, r in q2resp.items() + } + return q2resp + + +def _dc_qa_validation_input(q2a: dict[str, str]) -> tuple[list[str], str]: + """Returns a list of questions and a prompt for DC QA validation.""" + parts = [] + queries = [] + i = 1 + for q, a in q2a.items(): + if not a.strip(): + continue + queries.append(q) + parts.append(f'[[QA{i}]]:\n Question: {q}\n Answer: {a}') + i += 1 + return queries, '\n'.join(parts) + + +def _dc_qa_validation_check(llm_resp: str, queries: list[str]) -> list[str]: + """Checks the DC QA validation response.""" + out_queries = [] + for qaid in llm_resp.strip().split('\n'): + if qaid.startswith('[[QA') and qaid.endswith(']]'): + idx = int(qaid.replace('[[QA', '').replace(']]', '')) - 1 + out_queries.append(queries[idx]) + return out_queries diff --git a/llm/setup.py b/llm/setup.py new file mode 100644 index 0000000..38587e5 --- /dev/null +++ b/llm/setup.py @@ -0,0 +1,52 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Build and distribute the data_gemma package.""" +import os +from setuptools import setup + + +# Package metadata. +NAME = 'data_gemma' +DESCRIPTION = 'A library to integrate with Data Gemma models and Data Commons.' +URL = 'https://github.com/datacommonsorg/api-python' +EMAIL = 'support@datacommons.org' +AUTHOR = 'datacommons.org' +REQUIRES_PYTHON = '>=3.10' +VERSION = '0.0.1' +REQUIRED = ['requests'] +PACKAGES = ['data_gemma'] + +setup( + name=NAME, + version=VERSION, + description=DESCRIPTION, + author=AUTHOR, + author_email=EMAIL, + maintainer=AUTHOR, + maintainer_email=EMAIL, + python_requires=REQUIRES_PYTHON, + url=URL, + packages=PACKAGES, + install_requires=REQUIRED, + include_package_data=True, + license='Apache 2.0', + classifiers=[ + 'Intended Audience :: Developers', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: Implementation :: CPython', + 'Topic :: Software Development', + ], +)
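For illustration, a minimal usage sketch of the `data_gemma` package added by this patch. It wires a `GoogleAIStudio` LLM and a `DataCommons` fetcher into the RIG and RAG flows; the Gemini model name, the API key strings, and the example queries are placeholders (assumptions, not part of the patch), and `in_context=True` selects the in-context prompts rather than a fine-tuned model:

```
from data_gemma import DataCommons, GoogleAIStudio, RAGFlow, RIGFlow

# Placeholder credentials and model name; substitute real values.
AISTUDIO_API_KEY = 'your-ai-studio-key'
DC_API_KEY = 'your-data-commons-key'

llm = GoogleAIStudio(model='gemini-1.5-pro', api_keys=[AISTUDIO_API_KEY])
dc = DataCommons(api_key=DC_API_KEY)

# RIG: the LLM drafts an answer with __DC__ annotations, and each annotated
# statistic is looked up against Data Commons.
rig = RIGFlow(llm=llm, data_fetcher=dc, in_context=True)
print(rig.query('Tell me about health outcomes in California').answer())

# RAG: statistical questions are generated first, answered with Data Commons
# tables, and the tables are passed back to the LLM for the final response.
rag = RAGFlow(llm_question=llm, llm_answer=llm, data_fetcher=dc, in_context=True)
print(rag.query('How has the unemployment rate changed in US states?').answer())
```

Both flows return a `FlowResponse`, so `debug()` can be used instead of `answer()` to inspect the individual LLM and Data Commons calls; `BaselineFlow(llm)` provides a plain LLM answer with no Data Commons calls for comparison.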