diff --git a/.github/workflows/test-webhook.yml b/.github/workflows/webhook-checks.yml
similarity index 75%
rename from .github/workflows/test-webhook.yml
rename to .github/workflows/webhook-checks.yml
index b34f47c..c10a871 100644
--- a/.github/workflows/test-webhook.yml
+++ b/.github/workflows/webhook-checks.yml
@@ -12,21 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-name: webhook tests
+name: webhook
 
 on:
   push:
     branches:
       - main
     paths:
-      - .github/workflows/test-webhook.yml
+      - .github/workflows/webhook-checks.yml
       - webhook/**
       - '*.tf'
   pull_request:
     branches:
       - main
     paths:
-      - .github/workflows/test-webhook.yml
+      - .github/workflows/webhook-checks.yml
       - webhook/**
      - '*.tf'
   workflow_dispatch:  # Manual runs
@@ -34,7 +34,38 @@ on:
     - cron: '0 9 * * *'
 
 jobs:
-  test-webhook:
+  lint:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: 3.12
+
+      - name: Cache dependencies
+        uses: actions/cache@v3
+        with:
+          path: webhook/env
+          key: ${{ runner.os }}-env-${{ hashFiles('webhook/requirements.txt', 'webhook/requirements-test.txt') }}
+
+      - name: Install dependencies
+        working-directory: webhook
+        run: |
+          python -m venv env
+          source env/bin/activate
+          pip install --upgrade pip
+          pip install ruff
+          pip check
+
+      - name: Run linter
+        working-directory: webhook
+        run: |
+          source env/bin/activate
+          ruff check .
+
+
+  test:
     runs-on: ubuntu-latest
 
     strategy:
@@ -87,6 +118,12 @@ jobs:
           pip install -r requirements.txt -r requirements-test.txt
           pip check
 
+      - name: Type check
+        working-directory: webhook
+        run: |
+          source env/bin/activate
+          mypy --non-interactive --install-types .
+
       - name: Run tests
         working-directory: webhook
         run: |
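Note: the cache key above is derived from both requirements files, so touching either one invalidates the cached virtualenv. As a rough illustration of how that key behaves, here is a hypothetical Python analogue of GitHub's hashFiles() expression; the real function combines per-file SHA-256 digests, so treat this as a sketch of the idea, not the exact algorithm.

    import hashlib
    from pathlib import Path


    def hash_files(*patterns: str) -> str:
        """Rough stand-in for GitHub Actions' hashFiles() (illustrative only)."""
        digest = hashlib.sha256()
        for pattern in patterns:
            for path in sorted(Path(".").glob(pattern)):
                # Assumption: whole-file contents feed the key; any edit changes it.
                digest.update(path.read_bytes())
        return digest.hexdigest()


    print(hash_files("webhook/requirements.txt", "webhook/requirements-test.txt"))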
diff --git a/main.tf b/main.tf
index b28c014..3ee5712 100644
--- a/main.tf
+++ b/main.tf
@@ -89,15 +89,13 @@ resource "google_cloudfunctions2_function" "webhook" {
     timeout_seconds       = 300 # 5 minutes
     service_account_email = google_service_account.webhook.email
     environment_variables = {
-      PROJECT_ID          = module.project_services.project_id
-      VERTEXAI_LOCATION   = var.region
-      OUTPUT_BUCKET       = google_storage_bucket.main.name
-      DOCAI_PROCESSOR     = google_document_ai_processor.ocr.id
-      DOCAI_LOCATION      = google_document_ai_processor.ocr.location
-      DATABASE            = google_firestore_database.main.name
-      INDEX_ID            = google_vertex_ai_index.docs.id
-      INDEX_ENDPOINT_ID   = google_vertex_ai_index_endpoint.docs.id
-      INDEX_CONTENTS_PATH = google_vertex_ai_index.docs.metadata[0].contents_delta_uri
+      PROJECT_ID        = module.project_services.project_id
+      VERTEXAI_LOCATION = var.region
+      OUTPUT_BUCKET     = google_storage_bucket.main.name
+      DOCAI_PROCESSOR   = google_document_ai_processor.ocr.id
+      DOCAI_LOCATION    = google_document_ai_processor.ocr.location
+      DATABASE          = google_firestore_database.main.name
+      INDEX_ID          = google_vertex_ai_index.docs.id
     }
   }
 }
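The INDEX_ENDPOINT_ID and INDEX_CONTENTS_PATH variables disappear here, which lines up with the process_document docstring change below: the function now appears to work from the index ID alone. For context, a minimal sketch of the consuming side, assuming the function reads these variables with the same os.environ pattern main.py already uses for DOCAI_LOCATION; the fallback values are illustrative assumptions, not values from the repo.

    import os

    # Injected by the environment_variables block in main.tf (names from the diff above).
    PROJECT_ID = os.environ.get("PROJECT_ID", "")
    VERTEXAI_LOCATION = os.environ.get("VERTEXAI_LOCATION", "us-central1")
    OUTPUT_BUCKET = os.environ.get("OUTPUT_BUCKET", "")
    DOCAI_PROCESSOR = os.environ.get("DOCAI_PROCESSOR", "")
    DOCAI_LOCATION = os.environ.get("DOCAI_LOCATION", "us")
    DATABASE = os.environ.get("DATABASE", "(default)")
    INDEX_ID = os.environ.get("INDEX_ID", "")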
diff --git a/webhook/main.py b/webhook/main.py
index 6698b08..434f79f 100644
--- a/webhook/main.py
+++ b/webhook/main.py
@@ -21,18 +21,18 @@
 from datetime import datetime
 
 import functions_framework
-import vertexai
+import vertexai  # type: ignore
 from cloudevents.http import CloudEvent
 from google.api_core.client_options import ClientOptions
 from google.cloud import aiplatform
 from google.cloud import documentai
-from google.cloud import firestore
-from google.cloud import storage
+from google.cloud import firestore  # type: ignore
+from google.cloud import storage  # type: ignore
 from google.cloud.aiplatform_v1.types import IndexDatapoint
 from retry import retry
-from timeout import timeout, TimeoutException
-from vertexai.language_models import TextEmbeddingModel
-from vertexai.preview.generative_models import GenerativeModel
+from timeout import timeout, TimeoutException  # type: ignore
+from vertexai.language_models import TextEmbeddingModel  # type: ignore
+from vertexai.preview.generative_models import GenerativeModel  # type: ignore
 
 DEPLOYED_INDEX_ID = "deployed_index"
 DOCAI_LOCATION = os.environ.get("DOCAI_LOCATION", "us")
@@ -69,6 +69,12 @@
 
 @timeout(duration=5)
 def deploy_index(index_id: str, index_endpoint_id: str) -> None:
+    """Deploy a Vector Search index to an endpoint.
+
+    Args:
+        index_id: ID of the Vector Search index.
+        index_endpoint_id: ID of the Vector Search index endpoint.
+    """
     index = aiplatform.MatchingEngineIndex(index_id)
     endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_id)
     if not any(index.id == DEPLOYED_INDEX_ID for index in endpoint.deployed_indexes):
@@ -137,9 +143,8 @@ def process_document(
         docai_processor_id: ID of the Document AI processor.
         database: Name of the Firestore database.
         output_bucket: Name of the output bucket.
-        index_contents_path: Path to the index contents file.
+        index_id: ID of the Vector Search index.
     """
-
     db = firestore.Client(database=database)
     doc = db.document("documents", filename.replace("/", "-"))
     event_entry = {
@@ -174,8 +179,7 @@ def process_document(
     print(f"🔍 {event_id}: Generating Q&As with model ({len(pages)} pages)")
     with multiprocessing.Pool(len(pages)) as pool:
         event_pages = [
-            {"filename": filename, "page_number": i, "text": page}
-            for i, page in enumerate(pages)
+            {"filename": filename, "page_number": i, "text": page} for i, page in enumerate(pages)
         ]
         results = pool.map(process_page, event_pages)
         entries = list(itertools.chain.from_iterable(results))
@@ -239,9 +243,7 @@ def get_document_text(
     """
     # You must set the `api_endpoint` if you use a location other than "us".
     client = documentai.DocumentProcessorServiceClient(
-        client_options=ClientOptions(
-            api_endpoint=f"{DOCAI_LOCATION}-documentai.googleapis.com"
-        )
+        client_options=ClientOptions(api_endpoint=f"{DOCAI_LOCATION}-documentai.googleapis.com")
     )
     response = client.process_document(
         request=documentai.ProcessRequest(
@@ -261,8 +263,7 @@
     ]
     return [
         "\n".join(
-            response.document.text[start_index:end_index]
-            for start_index, end_index in segments
+            response.document.text[start_index:end_index] for start_index, end_index in segments
         )
         for segments in page_segments
     ]
@@ -297,7 +298,7 @@ def index_pages(index_id: str, filename: str, pages: list[str]) -> None:
 
 @retry(tries=3)
 def generate_questions(text: str) -> list[dict[str, str]]:
-    """Extract questions & answers using a large language model (LLM)
+    """Extract questions & answers using a large language model (LLM).
 
     For more information, see:
         https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
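deploy_index above is wrapped in @timeout(duration=5) from the repo-local timeout module, which this diff only annotates with # type: ignore; the module's implementation is not part of the patch. Purely as a hedged sketch of that interface, a signal-based decorator could look like the following. This is an assumption about how such a helper might work, not the repo's actual code, and SIGALRM restricts it to the main thread on Unix.

    import signal
    from functools import wraps


    class TimeoutException(Exception):
        """Raised when the wrapped function exceeds its time budget."""


    def timeout(duration: int):
        """Fail the wrapped call if it runs longer than `duration` seconds."""

        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                def handler(signum, frame):
                    raise TimeoutException(f"{func.__name__} timed out after {duration}s")

                previous = signal.signal(signal.SIGALRM, handler)
                signal.alarm(duration)
                try:
                    return func(*args, **kwargs)
                finally:
                    signal.alarm(0)  # cancel any pending alarm
                    signal.signal(signal.SIGALRM, previous)

            return wrapper

        return decorator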
print(">> Checking output bucket") storage_client = storage.Client() - output_bucket = outputs["bucket_main_name"] + output_bucket = terraform_outputs["bucket_main_name"] with storage_client.get_bucket(output_bucket).blob("dataset.jsonl").open("r") as f: lines = [line.strip() for line in f] print(f"dataset {len(lines)=}") @@ -95,7 +111,7 @@ def test_end_to_end(outputs: dict[str, str]) -> None: # Make sure the Firestore database is populated. print(">> Checking Firestore database") - db = firestore.Client(database=outputs["firestore_database_name"]) + db = firestore.Client(database=terraform_outputs["firestore_database_name"]) entries = list(db.collection("dataset").stream()) print(f"database {len(entries)=}") assert len(entries) == len(lines), "database entries do not match the dataset"