diff --git a/.github/workflows/builds.yml b/.github/workflows/builds.yml index 616cef5..37d39e0 100644 --- a/.github/workflows/builds.yml +++ b/.github/workflows/builds.yml @@ -32,7 +32,7 @@ jobs: POSTGRES_USER: postgres POSTGRES_PASSWORD: pg_pwd OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - run: docker-compose up -d --build + run: docker-compose -f docker-compose.test.yml up -d --build - name: Docker sanity check run: sleep 20 && nc -vz localhost 8050 - name: Debug diff --git a/README.md b/README.md index f1dbc35..44cb558 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,27 @@ In order to stop the service, run: make stop ``` +### Latency benchmark + +You crave perfect code suggestions, but you don't know whether a model fits your latency needs? +In the table below, you will find a latency benchmark for all tested LLMs from Ollama: + +| Model | Ingestion mean (std) | Generation mean (std) | +| ------------------------------------------------------------ | ---------------------- | --------------------- | +| [tinyllama:1.1b-chat-v1-q4_0](https://ollama.com/library/tinyllama:1.1b-chat-v1-q4_0) | 2014.63 tok/s (±12.62) | 227.13 tok/s (±2.26) | +| [dolphin-phi:2.7b-v2.6-q4_0](https://ollama.com/library/dolphin-phi:2.7b-v2.6-q4_0) | 684.07 tok/s (±3.85) | 122.25 tok/s (±0.87) | +| [dolphin-mistral:7b-v2.6](https://ollama.com/library/dolphin-mistral:7b-v2.6) | 291.94 tok/s (±0.4) | 60.56 tok/s (±0.15) | + + +This benchmark was performed over 20 iterations on the same input sequence, on a **laptop** to better reflect the performance that typical users can expect. The hardware setup includes an [Intel(R) Core(TM) i7-12700H](https://ark.intel.com/content/www/us/en/ark/products/132228/intel-core-i7-12700h-processor-24m-cache-up-to-4-70-ghz.html) for the CPU, and an [NVIDIA GeForce RTX 3060](https://www.nvidia.com/fr-fr/geforce/graphics-cards/30-series/rtx-3060-3060ti/) for the laptop GPU. + +You can run this latency benchmark for any Ollama model on your hardware as follows: +```bash +python scripts/evaluate_ollama_latency.py dolphin-mistral:7b-v2.6-dpo-laser-q4_0 --endpoint http://localhost:3000 +``` + +*All script arguments can be checked using `python scripts/evaluate_ollama_latency.py --help`* + ### How is the database organized @@ -88,23 +109,23 @@ The back-end core feature is to interact with the metadata tables. For the servi The project was designed so that everything runs with Docker orchestration (standalone virtual environment), so you won't need to install any additional libraries. -## Configuration +### Configuration In order to run the project, you will need to specify some information, which can be done using a `.env` file.
This file will have to hold the following information: +- `POSTGRES_DB`*: a name for the [PostgreSQL](https://www.postgresql.org/) database that will be created +- `POSTGRES_USER`*: a login for the PostgreSQL database +- `POSTGRES_PASSWORD`*: a password for the PostgreSQL database - `SUPERADMIN_GH_PAT`: the GitHub token of the initial admin access (Generate a new token on [GitHub](https://github.com/settings/tokens?type=beta), with no extra permissions = read-only) - `SUPERADMIN_PWD`*: the password of the initial admin access - `GH_OAUTH_ID`: the Client ID of the GitHub Oauth app (Create an OAuth app on [GitHub](https://github.com/settings/applications/new), pointing to your Quack dashboard w/ callback URL) - `GH_OAUTH_SECRET`: the secret of the GitHub Oauth app (Generate a new client secret on the created OAuth app) -- `POSTGRES_DB`*: a name for the [PostgreSQL](https://www.postgresql.org/) database that will be created -- `POSTGRES_USER`*: a login for the PostgreSQL database -- `POSTGRES_PASSWORD`*: a password for the PostgreSQL database -- `OPENAI_API_KEY`: your API key for Open AI (Create new secret key on [OpenAI](https://platform.openai.com/api-keys)) _* marks the values where you can pick what you want._ Optionally, the following information can be added: - `SECRET_KEY`*: if set, tokens can be reused between sessions. All instances sharing the same secret key can use the same token. +- `OLLAMA_MODEL`: the model tag in the [Ollama library](https://ollama.com/library) that will be used for the API. - `SENTRY_DSN`: the DSN for your [Sentry](https://sentry.io/) project, which monitors back-end errors and reports them back. - `SERVER_NAME`*: the server tag that will be used to report events to Sentry. - `POSTHOG_KEY`: the project API key for [PostHog](https://eu.posthog.com/settings/project-details). @@ -112,6 +133,9 @@ Optionally, the following information can be added: - `SLACK_CHANNEL`: the Slack channel where your bot will post events (defaults to `#general`, you have to invite the App to your channel). - `SUPPORT_EMAIL`: the email used for support of your API. - `DEBUG`: if set to false, silence debug logs. +- `OPENAI_API_KEY`**: your API key for OpenAI (Create a new secret key on [OpenAI](https://platform.openai.com/api-keys)) + +_** marks the deprecated values._ So your `.env` file should look similar to [`.env.example`](.env.example). The file should be placed in the same folder as your `./docker-compose.yml`. diff --git a/docker-compose.ollama.yml b/docker-compose.ollama.yml deleted file mode 100644 index 82ffec3..0000000 --- a/docker-compose.ollama.yml +++ /dev/null @@ -1,91 +0,0 @@ -version: '3.7' - -services: - backend: - build: .
- command: uvicorn app.main:app --reload --host 0.0.0.0 --port 8050 --proxy-headers - volumes: - - ./src/:/app/ - ports: - - "8050:8050" - environment: - - SUPERADMIN_GH_PAT=${SUPERADMIN_GH_PAT} - - SUPERADMIN_PWD=${SUPERADMIN_PWD} - - GH_OAUTH_ID=${GH_OAUTH_ID} - - GH_OAUTH_SECRET=${GH_OAUTH_SECRET} - - POSTGRES_URL=postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db/${POSTGRES_DB} - - OPENAI_API_KEY=${OPENAI_API_KEY} - - SECRET_KEY=${SECRET_KEY} - - SENTRY_DSN=${SENTRY_DSN} - - SERVER_NAME=${SERVER_NAME} - - POSTHOG_KEY=${POSTHOG_KEY} - - SLACK_API_TOKEN=${SLACK_API_TOKEN} - - SLACK_CHANNEL=${SLACK_CHANNEL} - - OLLAMA_ENDPOINT=http://ollama:11434 - - OLLAMA_MODEL=${OLLAMA_MODEL} - - SUPPORT_EMAIL=${SUPPORT_EMAIL} - - DEBUG=true - depends_on: - db: - condition: service_healthy - ollama: - condition: service_healthy - - ollama: - image: ollama/ollama:latest - command: serve && ollama pull ${OLLAMA_MODEL} - volumes: - - "$HOME/.ollama:/root/.ollama" - expose: - - 11434 - healthcheck: - test: ["CMD-SHELL", "ollama list | grep -q '${OLLAMA_MODEL}'"] - interval: 10s - timeout: 3s - retries: 3 - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: 1 - capabilities: [gpu] - - # ollama-webui: - # image: ghcr.io/ollama-webui/ollama-webui:main - # container_name: ollama-webui - # depends_on: - # - ollama - # # Uncomment below for WIP: Auth support - # # - ollama-webui-db - # ports: - # - 3000:8080 - # environment: - # - "OLLAMA_API_BASE_URL=http://ollama:11434/api" - # # Uncomment below for WIP: Auth support - # # - "WEBUI_AUTH=TRUE" - # # - "WEBUI_DB_URL=mongodb://root:example@ollama-webui-db:27017/" - # # - "WEBUI_JWT_SECRET_KEY=SECRET_KEY" - # # extra_hosts: - # # - host.docker.internal:host-gateway - # # restart: unless-stopped - - db: - image: postgres:15-alpine - volumes: - - postgres_data:/var/lib/postgresql/data/ - expose: - - 5432 - environment: - - POSTGRES_USER=${POSTGRES_USER} - - POSTGRES_PASSWORD=${POSTGRES_PASSWORD} - - POSTGRES_DB=${POSTGRES_DB} - healthcheck: - test: ["CMD-SHELL", "sh -c 'pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}'"] - interval: 10s - timeout: 3s - retries: 3 - -volumes: - postgres_data: - ollama: diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 4111ff0..c4b8d1b 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -9,21 +9,24 @@ services: ports: - "8050:8050" environment: + - POSTGRES_URL=postgresql+asyncpg://dummy_login:dummy_pwd@test_db/dummy_db + - OLLAMA_ENDPOINT=http://ollama:11434 + - OLLAMA_MODEL=tinydolphin:1.1b-v2.8-q4_0 - SUPERADMIN_GH_PAT=${SUPERADMIN_GH_PAT} - SUPERADMIN_PWD=superadmin_pwd - GH_OAUTH_ID=${GH_OAUTH_ID} - GH_OAUTH_SECRET=${GH_OAUTH_SECRET} - - POSTGRES_URL=postgresql+asyncpg://dummy_login:dummy_pwd@test_db/dummy_db - - OPENAI_API_KEY=${OPENAI_API_KEY} - DEBUG=true depends_on: test_db: condition: service_healthy + ollama: + condition: service_healthy test_db: image: postgres:15-alpine - ports: - - "5432:5432" + expose: + - 5432 environment: - POSTGRES_USER=dummy_login - POSTGRES_PASSWORD=dummy_pwd @@ -33,3 +36,26 @@ services: interval: 10s timeout: 3s retries: 3 + + ollama: + image: ollama/ollama:0.1.25 + command: serve + volumes: + - "$HOME/.ollama:/root/.ollama" + expose: + - 11434 + healthcheck: + test: ["CMD-SHELL", "ollama pull 'tinydolphin:1.1b-v2.8-q4_0'"] + interval: 5s + timeout: 1m + retries: 3 + # deploy: + # resources: + # reservations: + # devices: + # - driver: nvidia + # count: 1 + # capabilities: [gpu] + +volumes: + ollama: diff --git 
a/docker-compose.yml b/docker-compose.yml index 6ae316e..1b73e03 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,23 +9,41 @@ services: ports: - "8050:8050" environment: + - POSTGRES_URL=postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db/${POSTGRES_DB} + - OLLAMA_ENDPOINT=http://ollama:11434 + - OLLAMA_MODEL=${OLLAMA_MODEL} + - SECRET_KEY=${SECRET_KEY} - SUPERADMIN_GH_PAT=${SUPERADMIN_GH_PAT} - SUPERADMIN_PWD=${SUPERADMIN_PWD} - GH_OAUTH_ID=${GH_OAUTH_ID} - GH_OAUTH_SECRET=${GH_OAUTH_SECRET} - - POSTGRES_URL=postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db/${POSTGRES_DB} - - OPENAI_API_KEY=${OPENAI_API_KEY} - - SECRET_KEY=${SECRET_KEY} - - SENTRY_DSN=${SENTRY_DSN} - - SERVER_NAME=${SERVER_NAME} - - POSTHOG_KEY=${POSTHOG_KEY} - - SLACK_API_TOKEN=${SLACK_API_TOKEN} - - SLACK_CHANNEL=${SLACK_CHANNEL} - SUPPORT_EMAIL=${SUPPORT_EMAIL} - DEBUG=true depends_on: db: condition: service_healthy + ollama: + condition: service_healthy + + ollama: + image: ollama/ollama:0.1.25 + command: serve + volumes: + - "$HOME/.ollama:/root/.ollama" + expose: + - 11434 + healthcheck: + test: ["CMD-SHELL", "ollama pull '${OLLAMA_MODEL}'"] + interval: 5s + timeout: 1m + retries: 3 + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] db: image: postgres:15-alpine @@ -71,3 +89,4 @@ services: volumes: postgres_data: + ollama: diff --git a/scripts/evaluate_ollama_latency.py b/scripts/evaluate_ollama_latency.py new file mode 100644 index 0000000..b4627e2 --- /dev/null +++ b/scripts/evaluate_ollama_latency.py @@ -0,0 +1,156 @@ +# Copyright (C) 2024, Quack AI. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import time +from typing import Any, Dict + +import numpy as np +import requests +from tqdm import tqdm + + +def _generate( + endpoint: str, model: str, system: str, prompt: str, temperature: float = 0.0, timeout: int = 60 +) -> Dict[str, Any]: + return requests.post( + f"{endpoint}/api/generate", + json={ + "model": model, + "stream": False, + "options": {"temperature": temperature}, + "system": system, + "prompt": prompt, + }, + timeout=timeout, + ) + + +def _chat_completion( + endpoint: str, model: str, system: str, prompt: str, temperature: float = 0.0, timeout: int = 60 +) -> Dict[str, Any]: + return requests.post( + f"{endpoint}/api/chat", + json={ + "model": model, + "stream": False, + "options": {"temperature": temperature}, + "messages": [{"role": "system", "content": system}, {"role": "user", "content": prompt}], + }, + timeout=timeout, + ) + + +def _format_response(response, system, prompt) -> Dict[str, Any]: + assert response.status_code == 200 + json_response = response.json() + return { + "duration": { + "model": json_response.get("load_duration"), + "input": json_response.get("prompt_eval_duration"), + "output": json_response.get("eval_duration"), + "total": json_response.get("total_duration"), + }, + "tokens": {"input": json_response.get("prompt_eval_count"), "output": json_response.get("eval_count")}, + "chars": { + "input": len(system) + len(prompt), + "output": len(json_response.get("response") or json_response["message"]["content"]), + }, + } + + +def main(args): + print(args) + + # Healthcheck on endpoint & model + assert requests.get(f"{args.endpoint}/api/tags", timeout=2).status_code == 200 + response = requests.post(f"{args.endpoint}/api/pull", json={"name": args.model, "stream": False}, timeout=10) + assert response.status_code == 200 + assert 
response.json()["status"] == "success" + + # Speed + speed_system = ( + "You are a helpful assistant, you will be given a coding task. Answer correctly, otherwise someone will die." + ) + speed_prompt = "Write a Python function to compute the n-th fibonacci number" + # Warmup + for _ in range(args.warmup): + _generate(args.endpoint, args.model, speed_system, speed_prompt) + + # Run + timings = [] + input_chars, output_chars = [], [] + input_tokens, output_tokens = [], [] + load_duration, input_duration, output_duration, total_duration = [], [], [], [] + for _ in tqdm(range(args.it)): + start_ts = time.perf_counter() + response = _generate(args.endpoint, args.model, speed_system, speed_prompt) + timings.append(time.perf_counter() - start_ts) + inference = _format_response(response, speed_system, speed_prompt) + input_chars.append(inference["chars"]["input"]) + output_chars.append(inference["chars"]["output"]) + input_tokens.append(inference["tokens"]["input"]) + output_tokens.append(inference["tokens"]["output"]) + load_duration.append(inference["duration"]["model"]) + input_duration.append(inference["duration"]["input"]) + output_duration.append(inference["duration"]["output"]) + total_duration.append(inference["duration"]["total"]) + + print(f"{args.model} ({args.it} runs)") + timings = np.array(timings) + load_duration = np.array(load_duration, dtype=int) + input_duration = np.array(input_duration, dtype=int) + output_duration = np.array(output_duration, dtype=int) + total_duration = np.array(total_duration, dtype=int) + print(f"Model load duration: mean {load_duration.mean() / 1e6:.2f}ms, std {load_duration.std() / 1e6:.2f}ms") + # Tokens (np.float64 to handle NaNs) + input_tokens = np.array(input_tokens, dtype=np.float64) + output_tokens = np.array(output_tokens, dtype=np.float64) + input_chars = np.array(input_chars, dtype=np.float64) + output_chars = np.array(output_chars, dtype=np.float64) + print( + f"Input processing: mean {1e9 * input_tokens.sum() / input_duration.sum():.2f} tok/s, std {1e9 * (input_tokens / input_duration).std():.2f} tok/s" + ) + print( + f"Output generation: mean {1e9 * output_tokens.sum() / output_duration.sum():.2f} tok/s, std {1e9 * (output_tokens / output_duration).std():.2f} tok/s" + ) + + # Chars + print( + f"Input processing: mean {1e9 * input_chars.sum() / input_duration.sum():.2f} char/s, std {1e9 * (input_chars / input_duration).std():.2f} char/s" + ) + print( + f"Output generation: mean {1e9 * output_chars.sum() / output_duration.sum():.2f} char/s, std {1e9 * (output_chars / output_duration).std():.2f} char/s" + ) + print(f"Overall latency (ollama): mean {total_duration.mean() / 1e6:.2f}ms, std {total_duration.std() / 1e6:.2f}ms") + print(f"Overall latency (HTTP): mean {1000 * timings.mean():.2f}ms, std {1000 * timings.std():.2f}ms") + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser( + description="Ollama latency evaluation", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + # Data & model + group = parser.add_argument_group("Data & model") + group.add_argument("model", type=str, help="model to use") + group.add_argument("--endpoint", default="http://localhost:11434/api", type=str, help="Ollama endpoint") + + # Inference params + group = parser.add_argument_group("Inference params") + group.add_argument("--temperature", default=0, type=float, help="Temperature to use for model inference") + + # Inference params + group = parser.add_argument_group("Evaluation") + group.add_argument("--it", type=int, default=20, 
help="Number of iterations to run") + group.add_argument("--warmup", type=int, default=5, help="Number of iterations for warmup") + + return parser + + +if __name__ == "__main__": + args = get_parser().parse_args() + main(args) diff --git a/scripts/latency.csv b/scripts/latency.csv new file mode 100644 index 0000000..7332694 --- /dev/null +++ b/scripts/latency.csv @@ -0,0 +1,4 @@ +model,hardware,ingestion_mean (tok/s),ingestion_std (tok/s),generation_mean (tok/s),generation_std (tok/s) +dolphin-mistral:7b-v2.6,NVIDIA RTX 3060 (laptop),291.94,0.4,60.56,0.15 +dolphin-phi:2.7b-v2.6-q4_0,NVIDIA RTX 3060 (laptop),684.07,3.85,122.25,0.87 +tinyllama:1.1b-chat-v1-q4_0,NVIDIA RTX 3060 (laptop),2014.63,12.62,227.13,2.26 diff --git a/src/app/api/api_v1/endpoints/code.py b/src/app/api/api_v1/endpoints/code.py new file mode 100644 index 0000000..91f807e --- /dev/null +++ b/src/app/api/api_v1/endpoints/code.py @@ -0,0 +1,29 @@ +# Copyright (C) 2024, Quack AI. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from fastapi import APIRouter, Security, status +from fastapi.responses import StreamingResponse + +from app.api.dependencies import get_current_user +from app.models import User, UserScope +from app.schemas.code import ChatMessage +from app.services.ollama import ollama_client +from app.services.telemetry import telemetry_client + +router = APIRouter() + + +@router.post("/chat", status_code=status.HTTP_200_OK, summary="Chat with our code model") +async def chat( + payload: ChatMessage, + user: User = Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), +) -> StreamingResponse: + telemetry_client.capture(user.id, event="compute-chat") + # Run analysis + return StreamingResponse( + ollama_client.chat(payload.content).iter_content(chunk_size=8192), + media_type="text/event-stream", + ) diff --git a/src/app/api/api_v1/endpoints/compute.py b/src/app/api/api_v1/endpoints/compute.py deleted file mode 100644 index ad1aceb..0000000 --- a/src/app/api/api_v1/endpoints/compute.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (C) 2023-2024, Quack AI. - -# This program is licensed under the Apache License 2.0. -# See LICENSE or go to for full license details. 
- -from typing import List, cast - -from fastapi import APIRouter, Depends, Path, Security, status - -from app.api.dependencies import get_current_user, get_guideline_crud, get_repo_crud -from app.crud import GuidelineCRUD, RepositoryCRUD -from app.models import Guideline, User, UserScope -from app.schemas.compute import ComplianceResult, Snippet -from app.services.openai import ExecutionMode, openai_client -from app.services.telemetry import telemetry_client - -router = APIRouter() - - -@router.post( - "/analyze/{repo_id}", - status_code=status.HTTP_200_OK, - summary="Check code against the guidelines of a given repository", -) -async def check_code_against_repo_guidelines( - payload: Snippet, - repo_id: int = Path(..., gt=0), - repos: RepositoryCRUD = Depends(get_repo_crud), - guidelines: GuidelineCRUD = Depends(get_guideline_crud), - user: User = Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), -) -> List[ComplianceResult]: - telemetry_client.capture(user.id, event="compute-analyze", properties={"repo_id": repo_id}) - # Check repo - await repos.get(repo_id, strict=True) - # Fetch guidelines - guideline_list = [elt for elt in await guidelines.fetch_all(("repo_id", repo_id))] - # Run analysis - return openai_client.check_code_against_guidelines( - payload.code, guideline_list, mode=ExecutionMode.MULTI, user_id=str(user.id) - ) - - -@router.post("/check/{guideline_id}", status_code=status.HTTP_200_OK, summary="Check code against a specific guideline") -async def check_code_against_guideline( - payload: Snippet, - guideline_id: int = Path(..., gt=0), - guidelines: GuidelineCRUD = Depends(get_guideline_crud), - user: User = Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), -) -> ComplianceResult: - # Check repo - guideline = cast(Guideline, await guidelines.get(guideline_id, strict=True)) - telemetry_client.capture( - user.id, event="compute-check", properties={"repo_id": guideline.repo_id, "guideline_id": guideline_id} - ) - # Run analysis - return openai_client.check_code(payload.code, guideline, user_id=str(user.id)) diff --git a/src/app/api/api_v1/endpoints/guidelines.py b/src/app/api/api_v1/endpoints/guidelines.py index 8a9e49e..d7fe050 100644 --- a/src/app/api/api_v1/endpoints/guidelines.py +++ b/src/app/api/api_v1/endpoints/guidelines.py @@ -14,17 +14,14 @@ from app.schemas.base import OptionalGHToken from app.schemas.guidelines import ( ContentUpdate, - ExampleRequest, - GuidelineContent, GuidelineCreate, GuidelineCreation, GuidelineEdit, - GuidelineExample, OrderUpdate, - TextContent, ) from app.services.github import gh_client -from app.services.openai import openai_client + +# from app.services.openai import openai_client from app.services.telemetry import telemetry_client router = APIRouter() @@ -115,23 +112,23 @@ async def delete_guideline( await guidelines.delete(guideline_id) -@router.post("/parse", status_code=status.HTTP_200_OK, summary="Extract guidelines from a text corpus") -async def parse_guidelines_from_text( - payload: TextContent, - user: User = Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), -) -> List[GuidelineContent]: - telemetry_client.capture(user.id, event="guideline-parse") - # Analyze with LLM - return openai_client.parse_guidelines_from_text(payload.content, user_id=str(user.id)) - # return ollama_client.parse_guidelines_from_text(payload.content) - - -@router.post("/examples", status_code=status.HTTP_200_OK, summary="Request examples for a guideline") -async def generate_examples_for_text( - 
payload: ExampleRequest, - user: User = Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), -) -> GuidelineExample: - telemetry_client.capture(user.id, event="guideline-examples") - # Analyze with LLM - return openai_client.generate_examples_for_instruction(payload.content, payload.language, user_id=str(user.id)) - # return ollama_client.generate_examples_for_instruction(payload.content, payload.language) +# @router.post("/parse", status_code=status.HTTP_200_OK, summary="Extract guidelines from a text corpus") +# async def parse_guidelines_from_text( +# payload: TextContent, +# user: User = Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), +# ) -> List[GuidelineContent]: +# telemetry_client.capture(user.id, event="guideline-parse") +# # Analyze with LLM +# return openai_client.parse_guidelines_from_text(payload.content, user_id=str(user.id)) +# # return ollama_client.parse_guidelines_from_text(payload.content) + + +# @router.post("/examples", status_code=status.HTTP_200_OK, summary="Request examples for a guideline") +# async def generate_examples_for_text( +# payload: ExampleRequest, +# user: User = Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), +# ) -> GuidelineExample: +# telemetry_client.capture(user.id, event="guideline-examples") +# # Analyze with LLM +# return openai_client.generate_examples_for_instruction(payload.content, payload.language, user_id=str(user.id)) +# # return ollama_client.generate_examples_for_instruction(payload.content, payload.language) diff --git a/src/app/api/api_v1/endpoints/repos.py b/src/app/api/api_v1/endpoints/repos.py index f0ecb36..4d17263 100644 --- a/src/app/api/api_v1/endpoints/repos.py +++ b/src/app/api/api_v1/endpoints/repos.py @@ -4,9 +4,7 @@ # See LICENSE or go to for full license details. 
import logging -from base64 import b64decode from datetime import datetime -from functools import partial from typing import List, cast from fastapi import APIRouter, Depends, HTTPException, Path, Security, status @@ -15,13 +13,13 @@ from app.crud import GuidelineCRUD, RepositoryCRUD from app.models import Guideline, Repository, User, UserScope from app.schemas.base import OptionalGHToken -from app.schemas.guidelines import OrderUpdate, ParsedGuideline +from app.schemas.guidelines import OrderUpdate from app.schemas.repos import GuidelineOrder, RepoCreate, RepoCreation, RepoUpdate from app.services.github import gh_client -from app.services.openai import openai_client + +# from app.services.openai import openai_client from app.services.slack import slack_client from app.services.telemetry import telemetry_client -from app.services.utils import execute_in_parallel logger = logging.getLogger("uvicorn.error") router = APIRouter() @@ -167,75 +165,75 @@ async def fetch_guidelines_from_repo( return [elt for elt in await guidelines.fetch_all(("repo_id", repo_id))] -@router.post( - "/{repo_id}/parse", status_code=status.HTTP_200_OK, summary="Extracts the guidelines from a GitHub repository" -) -async def parse_guidelines_from_github( - payload: OptionalGHToken, - repo_id: int = Path(..., gt=0), - repos: RepositoryCRUD = Depends(get_repo_crud), - user: User = Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), -) -> List[ParsedGuideline]: - telemetry_client.capture(user.id, event="repo-parse-guidelines", properties={"repo_id": repo_id}) - # Sanity check - repo = cast(Repository, await repos.get(repo_id, strict=True)) - # Stage all the text sources - sources = [] - # Parse CONTRIBUTING (README if CONTRIBUTING doesn't exist) - contributing = gh_client.get_file(repo.full_name, "CONTRIBUTING.md", payload.github_token) - readme = gh_client.get_readme(repo.full_name, payload.github_token) if contributing is None else None - if contributing is not None: - sources.append((contributing["path"], b64decode(contributing["content"]).decode())) - if readme is not None: - sources.append((readme["path"], b64decode(readme["content"]).decode())) - # Pull request comments (!= review comments/threads) - pull_comments = [ - pull - for pull in gh_client.fetch_pull_comments_from_repo(repo.full_name, token=payload.github_token) - if len(pull["comments"]) > 0 - ] - if len(pull_comments) > 0: - # Keep: body, user/id, reactions/total_count - corpus = "# Pull request comments\n\n\n\n\n\n".join([ - f"PULL REQUEST {pull['pull']['number']} from user {pull['pull']['user_id']}\n\n" - + "\n\n".join(f"[User {comment['user_id']}] {comment['body']}" for comment in pull["comments"]) - for pull in pull_comments - ]) - sources.append(("pull_request_comments", corpus)) - # Review threads - review_comments = [ - pull - for pull in gh_client.fetch_reviews_from_repo(repo.full_name, token=payload.github_token) - if len(pull["threads"]) > 0 - ] - # Ideas: filter on pulls with highest amount of comments recently, add the review output rejection/etc - if len(review_comments) > 0: - # Keep: code, body, user/id, reactions/total_count - corpus = "# Code review history\n\n\n\n\n\n".join([ - f"PULL: {pull['pull']['number']} from user {pull['pull']['user_id']}\n\n" - + "\n\n".join( - f"[Code diff]\n```{thread[0]['code']}\n```\n" - + "\n".join(f"[User {comment['user_id']}] {comment['body']}" for comment in thread) - for thread in pull["threads"] - ) - for pull in review_comments - ]) - sources.append(("review_comments", corpus)) - # 
If not enough information, raise error - if len(sources) == 0: - raise HTTPException(status.HTTP_404_NOT_FOUND, detail="No useful information is accessible in the repository") - # Process all sources in parallel - responses = execute_in_parallel( - partial(openai_client.parse_guidelines_from_text, user_id=str(user.id)), - (corpus for _, corpus in sources), - num_threads=len(sources), - ) - guidelines = [ - ParsedGuideline(**guideline.dict(), repo_id=repo_id, source=source) - for (source, _), response in zip(sources, responses) - for guideline in response - ] - return guidelines +# @router.post( +# "/{repo_id}/parse", status_code=status.HTTP_200_OK, summary="Extracts the guidelines from a GitHub repository" +# ) +# async def parse_guidelines_from_github( +# payload: OptionalGHToken, +# repo_id: int = Path(..., gt=0), +# repos: RepositoryCRUD = Depends(get_repo_crud), +# user: User = Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), +# ) -> List[ParsedGuideline]: +# telemetry_client.capture(user.id, event="repo-parse-guidelines", properties={"repo_id": repo_id}) +# # Sanity check +# repo = cast(Repository, await repos.get(repo_id, strict=True)) +# # Stage all the text sources +# sources = [] +# # Parse CONTRIBUTING (README if CONTRIBUTING doesn't exist) +# contributing = gh_client.get_file(repo.full_name, "CONTRIBUTING.md", payload.github_token) +# readme = gh_client.get_readme(repo.full_name, payload.github_token) if contributing is None else None +# if contributing is not None: +# sources.append((contributing["path"], b64decode(contributing["content"]).decode())) +# if readme is not None: +# sources.append((readme["path"], b64decode(readme["content"]).decode())) +# # Pull request comments (!= review comments/threads) +# pull_comments = [ +# pull +# for pull in gh_client.fetch_pull_comments_from_repo(repo.full_name, token=payload.github_token) +# if len(pull["comments"]) > 0 +# ] +# if len(pull_comments) > 0: +# # Keep: body, user/id, reactions/total_count +# corpus = "# Pull request comments\n\n\n\n\n\n".join([ +# f"PULL REQUEST {pull['pull']['number']} from user {pull['pull']['user_id']}\n\n" +# + "\n\n".join(f"[User {comment['user_id']}] {comment['body']}" for comment in pull["comments"]) +# for pull in pull_comments +# ]) +# sources.append(("pull_request_comments", corpus)) +# # Review threads +# review_comments = [ +# pull +# for pull in gh_client.fetch_reviews_from_repo(repo.full_name, token=payload.github_token) +# if len(pull["threads"]) > 0 +# ] +# # Ideas: filter on pulls with highest amount of comments recently, add the review output rejection/etc +# if len(review_comments) > 0: +# # Keep: code, body, user/id, reactions/total_count +# corpus = "# Code review history\n\n\n\n\n\n".join([ +# f"PULL: {pull['pull']['number']} from user {pull['pull']['user_id']}\n\n" +# + "\n\n".join( +# f"[Code diff]\n```{thread[0]['code']}\n```\n" +# + "\n".join(f"[User {comment['user_id']}] {comment['body']}" for comment in thread) +# for thread in pull["threads"] +# ) +# for pull in review_comments +# ]) +# sources.append(("review_comments", corpus)) +# # If not enough information, raise error +# if len(sources) == 0: +# raise HTTPException(status.HTTP_404_NOT_FOUND, detail="No useful information is accessible in the repository") +# # Process all sources in parallel +# responses = execute_in_parallel( +# partial(openai_client.parse_guidelines_from_text, user_id=str(user.id)), +# (corpus for _, corpus in sources), +# num_threads=len(sources), +# ) +# guidelines = [ +# 
ParsedGuideline(**guideline.dict(), repo_id=repo_id, source=source) +# for (source, _), response in zip(sources, responses) +# for guideline in response +# ] +# return guidelines @router.post("/{repo_id}/waitlist", status_code=status.HTTP_200_OK, summary="Add a GitHub repository to the waitlist") diff --git a/src/app/api/api_v1/router.py b/src/app/api/api_v1/router.py index 78fb5c9..a3c8dd0 100644 --- a/src/app/api/api_v1/router.py +++ b/src/app/api/api_v1/router.py @@ -5,11 +5,11 @@ from fastapi import APIRouter -from app.api.api_v1.endpoints import compute, guidelines, login, repos, users +from app.api.api_v1.endpoints import code, guidelines, login, repos, users api_router = APIRouter() api_router.include_router(login.router, prefix="/login", tags=["login"]) api_router.include_router(users.router, prefix="/users", tags=["users"]) api_router.include_router(repos.router, prefix="/repos", tags=["repos"]) api_router.include_router(guidelines.router, prefix="/guidelines", tags=["guidelines"]) -api_router.include_router(compute.router, prefix="/compute", tags=["compute"]) +api_router.include_router(code.router, prefix="/code", tags=["code"]) diff --git a/src/app/core/config.py b/src/app/core/config.py index edca884..fdd8b3f 100644 --- a/src/app/core/config.py +++ b/src/app/core/config.py @@ -10,8 +10,6 @@ from pydantic import BaseSettings, validator -from app.schemas.services import OpenAIModel - __all__ = ["settings"] @@ -45,11 +43,9 @@ def sqlachmey_uri(cls, v: str) -> str: ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 ACCESS_TOKEN_UNLIMITED_MINUTES: int = 60 * 24 * 365 JWT_ENCODING_ALGORITHM: str = "HS256" - # Compute - OPENAI_API_KEY: str = os.environ["OPENAI_API_KEY"] - OPENAI_MODEL: OpenAIModel = OpenAIModel.GPT3_5_TURBO - # OLLAMA_ENDPOINT: str = os.environ["OLLAMA_ENDPOINT"] - # OLLAMA_MODEL: str = os.environ.get("OLLAMA_MODEL", "starling-lm:7b-alpha") + # LLM Compute + OLLAMA_ENDPOINT: str = os.environ["OLLAMA_ENDPOINT"] + OLLAMA_MODEL: str = os.environ.get("OLLAMA_MODEL", "dolphin-mistral:7b-v2.6-dpo-laser-q4_0") # Error monitoring SENTRY_DSN: Union[str, None] = os.environ.get("SENTRY_DSN") diff --git a/src/app/schemas/compute.py b/src/app/schemas/code.py similarity index 55% rename from src/app/schemas/compute.py rename to src/app/schemas/code.py index f82f9be..7a9bee4 100644 --- a/src/app/schemas/compute.py +++ b/src/app/schemas/code.py @@ -3,10 +3,11 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. 
+from enum import Enum from pydantic import BaseModel, Field -__all__ = ["ComplianceResult", "Snippet"] +__all__ = ["ChatMessage", "ComplianceResult", "Snippet"] class Snippet(BaseModel): @@ -17,3 +18,14 @@ class ComplianceResult(BaseModel): guideline_id: int = Field(..., gt=0) is_compliant: bool comment: str + + +class ChatRole(str, Enum): + SYSTEM: str = "system" + USER: str = "user" + ASSISTANT: str = "assistant" + + +class ChatMessage(BaseModel): + role: ChatRole = Field(ChatRole.USER, example=ChatRole.USER) + content: str = Field(..., min_length=1) diff --git a/src/app/services/ollama.py b/src/app/services/ollama.py index 5923543..2fd950a 100644 --- a/src/app/services/ollama.py +++ b/src/app/services/ollama.py @@ -11,12 +11,13 @@ import requests from fastapi import HTTPException, status +from app.core.config import settings from app.schemas.guidelines import GuidelineContent, GuidelineExample logger = logging.getLogger("uvicorn.error") ValidationOut = TypeVar("ValidationOut") -# __all__ = ["ollama_client"] +__all__ = ["ollama_client"] EXAMPLE_PROMPT = ( @@ -55,6 +56,12 @@ def validate_example_response(response: str) -> Dict[str, str]: PARSING_PATTERN = r"\{\s*\"title\":\s+\"(?P.*?)\",\s+\"details\":\s+\"(?P<details>.*?)\"\s*\}" +CHAT_PROMPT = ( + "You are an AI programming assistant, developed by the company Quack AI, and you only answer questions related to computer science " + "(refuse to answer for the rest)." +) + + def validate_parsing_response(response: str) -> List[Dict[str, str]]: guideline_list = json.loads(response.strip()) if not isinstance(guideline_list, list) or any( @@ -71,12 +78,12 @@ def __init__(self, endpoint: str, model_name: str, temperature: float = 0.0) -> # Check endpoint response = requests.get(f"{self.endpoint}/api/tags", timeout=2) if response.status_code != 200: - raise HTTPException(status_code=status.HTTP_404, detail="Unavailable endpoint") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Unavailable endpoint") # Pull model logger.info("Loading Ollama model...") response = requests.post(f"{self.endpoint}/api/pull", json={"name": model_name, "stream": False}, timeout=10) if response.status_code != 200 or response.json()["status"] != "success": - raise HTTPException(status_code=status.HTTP_404, detail="Unavailable model") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Unavailable model") self.temperature = temperature self.model_name = model_name logger.info(f"Using Ollama model: {self.model_name}") @@ -97,6 +104,8 @@ def _request( "options": {"temperature": self.temperature}, "system": system_prompt, "prompt": message, + "format": "json", + "keep_alive": "30s", }, timeout=timeout, ) @@ -109,6 +118,35 @@ def _request( logger.info(response.json()["response"].strip()) return validate_fn(response.json()["response"]) + def _chat( + self, + system_prompt: str, + message: str, + timeout: int = 20, + ) -> requests.Response: + return requests.post( + f"{self.endpoint}/api/chat", + json={ + "model": self.model_name, + "stream": True, + "options": {"temperature": self.temperature}, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": message}, + ], + "keep_alive": "30s", + }, + stream=True, + timeout=timeout, + ) + + def chat( + self, + message: str, + **kwargs, + ) -> requests.Response: + return self._chat(CHAT_PROMPT, message, **kwargs) + def parse_guidelines_from_text( self, corpus: str, @@ -158,4 +196,4 @@ def generate_examples_for_instruction( ) -# ollama_client = 
OllamaClient(settings.OLLAMA_ENDPOINT, settings.OLLAMA_MODEL) +ollama_client = OllamaClient(settings.OLLAMA_ENDPOINT, settings.OLLAMA_MODEL) diff --git a/src/app/services/openai.py b/src/app/services/openai.py index 4b46fd5..1f879fb 100644 --- a/src/app/services/openai.py +++ b/src/app/services/openai.py @@ -14,9 +14,8 @@ from fastapi import HTTPException, status from pydantic import ValidationError -from app.core.config import settings from app.models import Guideline -from app.schemas.compute import ComplianceResult +from app.schemas.code import ComplianceResult from app.schemas.guidelines import GuidelineContent, GuidelineExample from app.schemas.services import ( ArraySchema, @@ -31,7 +30,7 @@ logger = logging.getLogger("uvicorn.error") -__all__ = ["openai_client"] +# __all__ = ["openai_client"] class ExecutionMode(str, Enum): @@ -203,7 +202,7 @@ def __init__( ) -> None: self.headers = self._get_headers(api_key) # Validate model - model_card = requests.get(f"https://api.openai.com/v1/models/{model}", headers=self.headers, timeout=2) + model_card = requests.get(f"https://api.openai.com/v1/models/{model}", headers=self.headers, timeout=5) if model_card.status_code != 200: raise HTTPException(status_code=model_card.status_code, detail=model_card.json()["error"]["message"]) self.model = model @@ -389,4 +388,4 @@ def generate_examples_for_instruction( ) -openai_client = OpenAIClient(settings.OPENAI_API_KEY, settings.OPENAI_MODEL) +# openai_client = OpenAIClient(settings.OPENAI_API_KEY, settings.OPENAI_MODEL) diff --git a/src/tests/endpoints/test_code.py b/src/tests/endpoints/test_code.py new file mode 100644 index 0000000..16c6d57 --- /dev/null +++ b/src/tests/endpoints/test_code.py @@ -0,0 +1,48 @@ +from typing import Any, Dict, Union + +import pytest +import pytest_asyncio +from httpx import AsyncClient +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models import User + +USER_TABLE = [ + {"id": 1, "login": "first_login", "hashed_password": "hashed_first_pwd", "scope": "admin"}, + {"id": 2, "login": "second_login", "hashed_password": "hashed_second_pwd", "scope": "user"}, +] + + +@pytest_asyncio.fixture(scope="function") +async def code_session(async_session: AsyncSession, monkeypatch): + for entry in USER_TABLE: + async_session.add(User(**entry)) + await async_session.commit() + yield async_session + + +@pytest.mark.parametrize( + ("user_idx", "payload", "status_code", "status_detail"), + [ + (None, {"role": "user", "content": "Is Python 3.11 faster than 3.10?"}, 401, "Not authenticated"), + (0, {"role": "alien", "content": "Is Python 3.11 faster than 3.10?"}, 422, None), + (0, {"role": "user", "content": "Is Python 3.11 faster than 3.10?"}, 200, None), + ], +) +@pytest.mark.asyncio() +async def test_chat( + async_client: AsyncClient, + code_session: AsyncSession, + user_idx: Union[int, None], + payload: Dict[str, Any], + status_code: int, + status_detail: Union[str, None], +): + auth = None + if isinstance(user_idx, int): + auth = await pytest.get_token(USER_TABLE[user_idx]["id"], USER_TABLE[user_idx]["scope"].split()) + + response = await async_client.post("/code/chat", json=payload, headers=auth) + assert response.status_code == status_code, print(response.__dict__) + if isinstance(status_detail, str): + assert response.json()["detail"] == status_detail diff --git a/src/tests/endpoints/test_compute.py b/src/tests/endpoints/test_compute.py index 1817e4a..e9d65de 100644 --- a/src/tests/endpoints/test_compute.py +++ b/src/tests/endpoints/test_compute.py @@ -1,129 
+1,129 @@ -from typing import Any, Dict, Union +# from typing import Any, Dict, Union -import pytest -import pytest_asyncio -from httpx import AsyncClient -from sqlmodel import text -from sqlmodel.ext.asyncio.session import AsyncSession +# import pytest +# import pytest_asyncio +# from httpx import AsyncClient +# from sqlmodel import text +# from sqlmodel.ext.asyncio.session import AsyncSession -from app.models import Guideline, Repository, User +# from app.models import Guideline, Repository, User -USER_TABLE = [ - {"id": 1, "login": "first_login", "hashed_password": "hashed_first_pwd", "scope": "admin"}, - {"id": 2, "login": "second_login", "hashed_password": "hashed_second_pwd", "scope": "user"}, -] +# USER_TABLE = [ +# {"id": 1, "login": "first_login", "hashed_password": "hashed_first_pwd", "scope": "admin"}, +# {"id": 2, "login": "second_login", "hashed_password": "hashed_second_pwd", "scope": "user"}, +# ] -REPO_TABLE = [ - { - "id": 12345, - "full_name": "quack-ai/dummy-repo", - "installed_by": 1, - "owner_id": 1, - "installed_at": "2023-11-07T15:07:19.226673", - "removed_at": None, - }, - { - "id": 123456, - "full_name": "quack-ai/another-repo", - "installed_by": 2, - "owner_id": 2, - "installed_at": "2023-11-07T15:07:19.226673", - "removed_at": None, - }, -] +# REPO_TABLE = [ +# { +# "id": 12345, +# "full_name": "quack-ai/dummy-repo", +# "installed_by": 1, +# "owner_id": 1, +# "installed_at": "2023-11-07T15:07:19.226673", +# "removed_at": None, +# }, +# { +# "id": 123456, +# "full_name": "quack-ai/another-repo", +# "installed_by": 2, +# "owner_id": 2, +# "installed_at": "2023-11-07T15:07:19.226673", +# "removed_at": None, +# }, +# ] -GUIDELINE_TABLE = [ - { - "id": 1, - "repo_id": 12345, - "title": "Object naming", - "details": "Ensure function and class/instance methods have a meaningful & informative name", - "order": 1, - "created_at": "2023-11-07T15:08:19.226673", - "updated_at": "2023-11-07T15:08:19.226673", - }, - { - "id": 2, - "repo_id": 12345, - "title": "Docstrings", - "details": "All functions and methods need to have a docstring", - "order": 2, - "created_at": "2023-11-07T15:08:20.226673", - "updated_at": "2023-11-07T15:08:20.226673", - }, -] +# GUIDELINE_TABLE = [ +# { +# "id": 1, +# "repo_id": 12345, +# "title": "Object naming", +# "details": "Ensure function and class/instance methods have a meaningful & informative name", +# "order": 1, +# "created_at": "2023-11-07T15:08:19.226673", +# "updated_at": "2023-11-07T15:08:19.226673", +# }, +# { +# "id": 2, +# "repo_id": 12345, +# "title": "Docstrings", +# "details": "All functions and methods need to have a docstring", +# "order": 2, +# "created_at": "2023-11-07T15:08:20.226673", +# "updated_at": "2023-11-07T15:08:20.226673", +# }, +# ] -@pytest_asyncio.fixture(scope="function") -async def compute_session(async_session: AsyncSession, monkeypatch): - for entry in USER_TABLE: - async_session.add(User(**entry)) - await async_session.commit() - for entry in REPO_TABLE: - async_session.add(Repository(**entry)) - await async_session.commit() - for entry in GUIDELINE_TABLE: - async_session.add(Guideline(**entry)) - await async_session.commit() - # Update the guideline index count - await async_session.execute( - text(f"ALTER SEQUENCE guideline_id_seq RESTART WITH {max(entry['id'] for entry in GUIDELINE_TABLE) + 1}") - ) - await async_session.commit() - yield async_session +# @pytest_asyncio.fixture(scope="function") +# async def compute_session(async_session: AsyncSession, monkeypatch): +# for entry in USER_TABLE: +# 
async_session.add(User(**entry)) +# await async_session.commit() +# for entry in REPO_TABLE: +# async_session.add(Repository(**entry)) +# await async_session.commit() +# for entry in GUIDELINE_TABLE: +# async_session.add(Guideline(**entry)) +# await async_session.commit() +# # Update the guideline index count +# await async_session.execute( +# text(f"ALTER SEQUENCE guideline_id_seq RESTART WITH {max(entry['id'] for entry in GUIDELINE_TABLE) + 1}") +# ) +# await async_session.commit() +# yield async_session -@pytest.mark.parametrize( - ("user_idx", "guideline_id", "payload", "status_code", "status_detail"), - [ - (None, 1, {"code": "def hello_world():\n\tprint('hello')"}, 401, "Not authenticated"), - (0, 1, {"code": ""}, 422, None), - (0, 100, {"code": "def hello_world():\n\tprint('hello')"}, 404, "Table Guideline has no corresponding entry."), - ], -) -@pytest.mark.asyncio() -async def test_check_code_against_guideline( - async_client: AsyncClient, - compute_session: AsyncSession, - user_idx: Union[int, None], - guideline_id: int, - payload: Dict[str, Any], - status_code: int, - status_detail: Union[str, None], -): - auth = None - if isinstance(user_idx, int): - auth = await pytest.get_token(USER_TABLE[user_idx]["id"], USER_TABLE[user_idx]["scope"].split()) +# @pytest.mark.parametrize( +# ("user_idx", "guideline_id", "payload", "status_code", "status_detail"), +# [ +# (None, 1, {"code": "def hello_world():\n\tprint('hello')"}, 401, "Not authenticated"), +# (0, 1, {"code": ""}, 422, None), +# (0, 100, {"code": "def hello_world():\n\tprint('hello')"}, 404, "Table Guideline has no corresponding entry."), +# ], +# ) +# @pytest.mark.asyncio() +# async def test_check_code_against_guideline( +# async_client: AsyncClient, +# compute_session: AsyncSession, +# user_idx: Union[int, None], +# guideline_id: int, +# payload: Dict[str, Any], +# status_code: int, +# status_detail: Union[str, None], +# ): +# auth = None +# if isinstance(user_idx, int): +# auth = await pytest.get_token(USER_TABLE[user_idx]["id"], USER_TABLE[user_idx]["scope"].split()) - response = await async_client.post(f"/compute/check/{guideline_id}", json=payload, headers=auth) - assert response.status_code == status_code, print(response.__dict__) - if isinstance(status_detail, str): - assert response.json()["detail"] == status_detail +# response = await async_client.post(f"/compute/check/{guideline_id}", json=payload, headers=auth) +# assert response.status_code == status_code, print(response.__dict__) +# if isinstance(status_detail, str): +# assert response.json()["detail"] == status_detail -@pytest.mark.parametrize( - ("user_idx", "repo_id", "payload", "status_code", "status_detail"), - [ - (None, 12345, {"code": "def hello_world():\n\tprint('hello')"}, 401, "Not authenticated"), - (0, 12345, {"code": ""}, 422, None), - (0, 100, {"code": "def hello_world():\n\tprint('hello')"}, 404, "Table Repository has no corresponding entry."), - ], -) -@pytest.mark.asyncio() -async def test_check_code_against_repo_guidelines( - async_client: AsyncClient, - compute_session: AsyncSession, - user_idx: Union[int, None], - repo_id: int, - payload: Dict[str, Any], - status_code: int, - status_detail: Union[str, None], -): - auth = None - if isinstance(user_idx, int): - auth = await pytest.get_token(USER_TABLE[user_idx]["id"], USER_TABLE[user_idx]["scope"].split()) +# @pytest.mark.parametrize( +# ("user_idx", "repo_id", "payload", "status_code", "status_detail"), +# [ +# (None, 12345, {"code": "def hello_world():\n\tprint('hello')"}, 401, "Not 
authenticated"), +# (0, 12345, {"code": ""}, 422, None), +# (0, 100, {"code": "def hello_world():\n\tprint('hello')"}, 404, "Table Repository has no corresponding entry."), +# ], +# ) +# @pytest.mark.asyncio() +# async def test_check_code_against_repo_guidelines( +# async_client: AsyncClient, +# compute_session: AsyncSession, +# user_idx: Union[int, None], +# repo_id: int, +# payload: Dict[str, Any], +# status_code: int, +# status_detail: Union[str, None], +# ): +# auth = None +# if isinstance(user_idx, int): +# auth = await pytest.get_token(USER_TABLE[user_idx]["id"], USER_TABLE[user_idx]["scope"].split()) - response = await async_client.post(f"/compute/analyze/{repo_id}", json=payload, headers=auth) - assert response.status_code == status_code, print(response.__dict__) - if isinstance(status_detail, str): - assert response.json()["detail"] == status_detail +# response = await async_client.post(f"/compute/analyze/{repo_id}", json=payload, headers=auth) +# assert response.status_code == status_code, print(response.__dict__) +# if isinstance(status_detail, str): +# assert response.json()["detail"] == status_detail diff --git a/src/tests/endpoints/test_repos.py b/src/tests/endpoints/test_repos.py index 7284991..c0be0e6 100644 --- a/src/tests/endpoints/test_repos.py +++ b/src/tests/endpoints/test_repos.py @@ -381,27 +381,27 @@ async def test_add_repo_to_waitlist( assert response.json()["detail"] == status_detail -@pytest.mark.parametrize( - ("user_idx", "repo_id", "status_code", "status_detail"), - [ - (None, 12345, 401, "Not authenticated"), - (0, 100, 404, "Table Repository has no corresponding entry."), - ], -) -@pytest.mark.asyncio() -async def test_parse_guidelines_from_github( - async_client: AsyncClient, - guideline_session: AsyncSession, - user_idx: Union[int, None], - repo_id: int, - status_code: int, - status_detail: Union[str, None], -): - auth = None - if isinstance(user_idx, int): - auth = await pytest.get_token(USER_TABLE[user_idx]["id"], USER_TABLE[user_idx]["scope"].split()) - - response = await async_client.post(f"/repos/{repo_id}/parse", json={}, headers=auth) - assert response.status_code == status_code, print(response.json()) - if isinstance(status_detail, str): - assert response.json()["detail"] == status_detail +# @pytest.mark.parametrize( +# ("user_idx", "repo_id", "status_code", "status_detail"), +# [ +# (None, 12345, 401, "Not authenticated"), +# (0, 100, 404, "Table Repository has no corresponding entry."), +# ], +# ) +# @pytest.mark.asyncio() +# async def test_parse_guidelines_from_github( +# async_client: AsyncClient, +# guideline_session: AsyncSession, +# user_idx: Union[int, None], +# repo_id: int, +# status_code: int, +# status_detail: Union[str, None], +# ): +# auth = None +# if isinstance(user_idx, int): +# auth = await pytest.get_token(USER_TABLE[user_idx]["id"], USER_TABLE[user_idx]["scope"].split()) + +# response = await async_client.post(f"/repos/{repo_id}/parse", json={}, headers=auth) +# assert response.status_code == status_code, print(response.json()) +# if isinstance(status_detail, str): +# assert response.json()["detail"] == status_detail