diff --git a/.github/workflows/aws-preview.yml b/.github/workflows/aws-preview.yml index f9add8cd..46a85087 100644 --- a/.github/workflows/aws-preview.yml +++ b/.github/workflows/aws-preview.yml @@ -44,7 +44,7 @@ jobs: # Sync .env from remote - run: | pip install toml pyyaml boto3 - python scripts/envs.py build -t .aws/petercat-preview.toml --silence + python server/scripts/envs.py build -t .aws/petercat-preview.toml --silence # Build inside Docker containers - run: sam build --use-container --config-file .aws/petercat-preview.toml diff --git a/.github/workflows/aws-prod.yml b/.github/workflows/aws-prod.yml index 599794b2..019c9914 100644 --- a/.github/workflows/aws-prod.yml +++ b/.github/workflows/aws-prod.yml @@ -39,7 +39,7 @@ jobs: # Sync .env from remote - run: | pip install toml pyyaml boto3 - python scripts/envs.py build -t .aws/petercat-prod.toml --silence + python server/scripts/envs.py build -t .aws/petercat-prod.toml --silence # Build inside Docker containers - run: sam build --use-container --config-file .aws/petercat-prod.toml @@ -48,4 +48,4 @@ jobs: sam deploy \ --no-confirm-changeset \ --no-fail-on-empty-changeset \ - --config-file .aws/petercat-prod.toml \ No newline at end of file + --config-file .aws/petercat-prod.toml diff --git a/.gitignore b/.gitignore index 717b791f..975847cb 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,7 @@ yarn-error.log* .env*.pre .env *.env +*.env.local # vercel .vercel @@ -57,3 +58,4 @@ next-env.d.ts dist/ lui/src/style.css + diff --git a/package.json b/package.json index 8fec982b..91d488cf 100644 --- a/package.json +++ b/package.json @@ -7,8 +7,8 @@ "client": "cd client && yarn run dev", "lui": "cd lui && yarn run dev", "server": "cd server && ./venv/bin/python3 -m uvicorn main:app --reload", - "env:pull": "python3 scripts/envs.py pull", - "env:push": "python3 scripts/envs.py push", + "env:pull": "cd server && ./venv/bin/python3 scripts/envs.py pull", + "env:push": "cd server && ./venv/bin/python3 scripts/envs.py push", "client:server": "concurrently \"yarn run server\" \"yarn run client\"", "lui:server": "concurrently \"yarn run server\" \"yarn run lui\"", "build:docker": "docker build -t petercat .", diff --git a/server/.env.example b/server/.env.example index 149da565..8a5765b2 100644 --- a/server/.env.example +++ b/server/.env.example @@ -28,3 +28,11 @@ WEB_URL=web_url SQS_QUEUE_URL=https://sqs.ap-northeast-1.amazonaws.com/{your_aws_user}/{your_aws_sqs_message} GITHUB_TOKEN=github_token # https://github.com/settings/tokens?type=beta + +# AWS Configures +AWS_SECRET_NAME=AWS_SECRET_NAME +AWS_REGION_NAME=AWS_REGION_NAME +S3_BUCKET_NAME=S3_BUCKET_NAME + +# you need to redirect your static domain to your s3 bucket domain +STATIC_URL=STATIC_URL diff --git a/server/README.zh-CN.md b/server/README.zh-CN.md index 716856f5..96011b61 100644 --- a/server/README.zh-CN.md +++ b/server/README.zh-CN.md @@ -3,6 +3,11 @@ # 介绍 PeterCat 服务端,采用 FastAPI 框架开发。使用了 supabase 作为数据存储方案。 +## 代码目录结构参考 + +fastapi 最佳实践请参考 https://github.com/zhanymkanov/fastapi-best-practices + + # 功能模块 ## 存储 采用 [supabase](https://supabase.com) 作为数据库进行存储。 @@ -165,4 +170,4 @@ export AWS_PROFILE=my-profile export AWS_REGION=ap-northeast-1 # 生效 source ~/.zshrc -``` \ No newline at end of file +``` diff --git a/server/aws/constants.py b/server/aws/constants.py new file mode 100644 index 00000000..1497dc61 --- /dev/null +++ b/server/aws/constants.py @@ -0,0 +1,7 @@ +from petercat_utils.utils.env import get_env_variable + +SUCCESS_CODE = "UPLOAD_SUCCESS" +ERROR_CODES = {"credentials_error": "CREDENTIALS_ERROR", "upload_error": "UPLOAD_ERROR"} +S3_BUCKET_NAME = get_env_variable("S3_BUCKET_NAME") +STATIC_URL = get_env_variable("STATIC_URL") +AWS_REGION_NAME = get_env_variable("AWS_REGION_NAME") diff --git a/server/aws/dependencies.py b/server/aws/dependencies.py new file mode 100644 index 00000000..6fa4832b --- /dev/null +++ b/server/aws/dependencies.py @@ -0,0 +1,8 @@ +from .constants import AWS_REGION_NAME +import boto3 + + +def get_s3_client(): + session = boto3.session.Session() + client = session.client(service_name="s3", region_name=AWS_REGION_NAME) + return client diff --git a/server/aws/exceptions.py b/server/aws/exceptions.py new file mode 100644 index 00000000..f156aa89 --- /dev/null +++ b/server/aws/exceptions.py @@ -0,0 +1,6 @@ +from fastapi import HTTPException + + +class UploadError(HTTPException): + def __init__(self, detail: str): + super().__init__(status_code=500, detail=detail) diff --git a/server/aws/router.py b/server/aws/router.py new file mode 100644 index 00000000..54886281 --- /dev/null +++ b/server/aws/router.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter, Depends, File, UploadFile, Form +from .schemas import ImageMetaData +from .dependencies import get_s3_client +from .service import upload_image_to_s3 + +router = APIRouter( + prefix="/api/aws", + tags=["aws"], + responses={404: {"description": "Not found"}}, +) + + +@router.post("/upload") +async def upload_image( + file: UploadFile = File(...), + title: str = Form(None), + description: str = Form(None), + s3_client=Depends(get_s3_client), +): + metadata = ImageMetaData(title=title, description=description) + result = upload_image_to_s3(file, metadata, s3_client) + return {"status": "success", "data": result} diff --git a/server/aws/schemas.py b/server/aws/schemas.py new file mode 100644 index 00000000..a3c8b460 --- /dev/null +++ b/server/aws/schemas.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel +from typing import Optional + + +class ImageMetaData(BaseModel): + title: Optional[str] = None + description: Optional[str] = None diff --git a/server/aws/service.py b/server/aws/service.py new file mode 100644 index 00000000..96cf38b1 --- /dev/null +++ b/server/aws/service.py @@ -0,0 +1,28 @@ +from .schemas import ImageMetaData +from .constants import S3_BUCKET_NAME, STATIC_URL +from .exceptions import UploadError + + +def upload_image_to_s3(file, metadata: ImageMetaData, s3_client): + try: + file_content = file.file.read() + + s3_key = f"{file.filename}" + + custom_metadata = { + "title": metadata.title if metadata.title else "", + "description": metadata.description if metadata.description else "", + } + + s3_client.put_object( + Bucket=S3_BUCKET_NAME, + Key=s3_key, + Body=file_content, + ContentType=file.content_type, + Metadata=custom_metadata, + ) + # you need to redirect your static domain to your s3 bucket domain + s3_url = f"{STATIC_URL}/{s3_key}" + return {"message": "File uploaded successfully", "url": s3_url} + except Exception as e: + raise UploadError(detail=str(e)) diff --git a/server/github_app/utils.py b/server/github_app/utils.py index 06e7b0af..b385d720 100644 --- a/server/github_app/utils.py +++ b/server/github_app/utils.py @@ -10,61 +10,64 @@ from petercat_utils.utils.env import get_env_variable APP_ID = get_env_variable("X_GITHUB_APP_ID") +SECRET_NAME = get_env_variable("AWS_SECRET_NAME") +REGIN_NAME = get_env_variable("AWS_REGION_NAME") + def get_private_key(): - secret_name = "prod/githubapp/petercat/pem" - region_name = "ap-northeast-1" session = boto3.session.Session() - client = session.client( - service_name='secretsmanager', - region_name=region_name - ) + client = session.client(service_name="secretsmanager", region_name=REGIN_NAME) try: - get_secret_value_response = client.get_secret_value( - SecretId=secret_name - ) + get_secret_value_response = client.get_secret_value(SecretId=SECRET_NAME) except ClientError as e: # For a list of exceptions thrown, see # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html raise e - return get_secret_value_response['SecretString'] + return get_secret_value_response["SecretString"] + def get_jwt(): payload = { # Issued at time - 'iat': int(time.time()), + "iat": int(time.time()), # JWT expiration time (10 minutes maximum) - 'exp': int(time.time()) + 600, + "exp": int(time.time()) + 600, # GitHub App's identifier - 'iss': APP_ID + "iss": APP_ID, } pem = get_private_key() private_key = serialization.load_pem_private_key( pem.encode("utf-8"), password=None, backend=default_backend() ) - return jwt.encode(payload, private_key, algorithm='RS256') + return jwt.encode(payload, private_key, algorithm="RS256") + def get_app_installations_access_token(installation_id: str, jwt: str): url = f"https://api.github.com/app/installations/{installation_id}/access_tokens" print("get_app_installations_access_token", url, jwt) - resp = requests.post(url, + resp = requests.post( + url, headers={ - 'X-GitHub-Api-Version': '2022-11-28', - 'Accept': 'application/vnd.github+json', - 'Authorization': f"Bearer {jwt}" - } + "X-GitHub-Api-Version": "2022-11-28", + "Accept": "application/vnd.github+json", + "Authorization": f"Bearer {jwt}", + }, ) return resp.json() + def get_installation_repositories(access_token: str): url = "https://api.github.com/installation/repositories" print("get_installation_repositories", url) - resp = requests.get(url, headers={ - 'X-GitHub-Api-Version': '2022-11-28', - 'Accept': 'application/vnd.github+json', - 'Authorization': f"Bearer {access_token}" - }) - return resp.json() \ No newline at end of file + resp = requests.get( + url, + headers={ + "X-GitHub-Api-Version": "2022-11-28", + "Accept": "application/vnd.github+json", + "Authorization": f"Bearer {access_token}", + }, + ) + return resp.json() diff --git a/server/main.py b/server/main.py index b9d151cc..e1290e7e 100644 --- a/server/main.py +++ b/server/main.py @@ -16,36 +16,35 @@ from rag import router as rag_router from task import router as task_router from github_app import router as github_app_router +from aws import router as aws_router AUTH0_DOMAIN = get_env_variable("AUTH0_DOMAIN") API_AUDIENCE = get_env_variable("API_IDENTIFIER") CLIENT_ID = get_env_variable("AUTH0_CLIENT_ID") -API_URL = get_env_variable("API_URL") -WEB_URL = get_env_variable("WEB_URL") +API_URL = get_env_variable("API_URL") +WEB_URL = get_env_variable("WEB_URL") CALLBACK_URL = f"{API_URL}/api/auth/callback" is_dev = bool(get_env_variable("IS_DEV")) session_secret_key = get_env_variable("FASTAPI_SECRET_KEY") cors_origins_whitelist = get_env_variable("CORS_ORIGIN_WHITELIST") or None -app = FastAPI( - title="Bo-meta Server", - version="1.0", - description="Agent Chat APIs" -) +app = FastAPI(title="Bo-meta Server", version="1.0", description="Agent Chat APIs") app.add_middleware( SessionMiddleware, - secret_key = session_secret_key, + secret_key=session_secret_key, ) -cors_origins = ["*"] if cors_origins_whitelist is None else cors_origins_whitelist.split(',') +cors_origins = ( + ["*"] if cors_origins_whitelist is None else cors_origins_whitelist.split(",") +) app.add_middleware( CORSMiddleware, - allow_origins=cors_origins, + allow_origins=cors_origins, allow_credentials=True, allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], - allow_headers=["Content-Type", "Authorization"], + allow_headers=["Content-Type", "Authorization"], ) @@ -55,19 +54,25 @@ app.include_router(chat_router.router) app.include_router(task_router.router) app.include_router(github_app_router.router) +app.include_router(aws_router.router) @app.get("/api/health_checker") def health_checker(): return { - "API_URL": API_URL, - "WEB_URL": WEB_URL, - "CALLBACK_URL": CALLBACK_URL, + "API_URL": API_URL, + "WEB_URL": WEB_URL, + "CALLBACK_URL": CALLBACK_URL, } if __name__ == "__main__": if is_dev: - uvicorn.run("main:app", host="0.0.0.0", port=int(os.environ.get("PORT", "8080")), reload=True) + uvicorn.run( + "main:app", + host="0.0.0.0", + port=int(os.environ.get("PORT", "8080")), + reload=True, + ) else: uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080"))) diff --git a/server/requirements.txt b/server/requirements.txt index 2e0b5c28..195e0d4b 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -15,6 +15,7 @@ load_dotenv supabase authlib==0.14.3 boto3>=1.34.84 +pytest-cov PyJWT pydantic>=2.7.0 unstructured[md] diff --git a/scripts/envs.py b/server/scripts/envs.py similarity index 86% rename from scripts/envs.py rename to server/scripts/envs.py index 25229fab..820da81d 100644 --- a/scripts/envs.py +++ b/server/scripts/envs.py @@ -7,7 +7,9 @@ S3_BUCKET = "petercat-env-variables" ENV_FILE = ".env" -LOCAL_ENV_FILE = "./server/.env" + +current_dir = os.path.dirname(os.path.abspath(__file__)) +LOCAL_ENV_FILE = os.path.join(current_dir, "..", ".env") s3 = boto3.resource("s3") s3_client = boto3.client("s3") @@ -47,12 +49,21 @@ def __call__(self, bytes_amount): # To simplify, we'll ignore multi-threading here. self._seen_so_far += bytes_amount percentage = (self._seen_so_far / self._size) * 100 - print(f"\r{self._filename}: {self._seen_so_far} bytes transferred out of {self._size} ({percentage:.2f}%)", end='\n') + print( + f"\r{self._filename}: {self._seen_so_far} bytes transferred out of {self._size} ({percentage:.2f}%)", + end="\n", + ) if args.silence or confirm_action("确认将本地 .env 文件上传到远端么"): - s3_client.upload_file(LOCAL_ENV_FILE, S3_BUCKET, ENV_FILE, Callback=ProgressPercentage(LOCAL_ENV_FILE)) + s3_client.upload_file( + LOCAL_ENV_FILE, + S3_BUCKET, + ENV_FILE, + Callback=ProgressPercentage(LOCAL_ENV_FILE), + ) print("上传成功") + def snake_to_camel(snake_str): """Convert snake_case string to camelCase.""" components = snake_str.lower().split("_") @@ -146,9 +157,9 @@ def load_config_toml(toml_file): def update_parameter_overrides(config, env_vars): """Update the parameter_overrides in the config dictionary with values from env_vars.""" parameter_overrides = [f"{key}={value}" for key, value in env_vars.items()] - config["default"]["deploy"]["parameters"]["parameter_overrides"] = ( - parameter_overrides - ) + config["default"]["deploy"]["parameters"][ + "parameter_overrides" + ] = parameter_overrides return config @@ -183,13 +194,21 @@ def main(): pull_parser = subparsers.add_parser( "pull", help="Pull environment variables from a .env file" ) - pull_parser.add_argument('--silence', action='store_true', help='Skip confirmation before updating the CloudFormation template') + pull_parser.add_argument( + "--silence", + action="store_true", + help="Skip confirmation before updating the CloudFormation template", + ) pull_parser.set_defaults(handle=pull_envs) push_parser = subparsers.add_parser( "push", help="Push enviroment variables from local .env file to Remote" ) - push_parser.add_argument('--silence', action='store_true', help='Skip confirmation before updating the CloudFormation template') + push_parser.add_argument( + "--silence", + action="store_true", + help="Skip confirmation before updating the CloudFormation template", + ) push_parser.set_defaults(handle=push_envs) build_parser = subparsers.add_parser( @@ -214,7 +233,11 @@ def main(): default=".aws/petercat-preview.toml", help="Path to the CloudFormation template file", ) - build_parser.add_argument('--silence', action='store_true', help='Skip confirmation before updating the CloudFormation template') + build_parser.add_argument( + "--silence", + action="store_true", + help="Skip confirmation before updating the CloudFormation template", + ) args = parser.parse_args() if args.command is not None: diff --git a/server/tests/aws/test_upload.py b/server/tests/aws/test_upload.py new file mode 100644 index 00000000..dc8f7b34 --- /dev/null +++ b/server/tests/aws/test_upload.py @@ -0,0 +1,50 @@ +from fastapi.testclient import TestClient +from main import app +from aws.dependencies import get_s3_client +import tempfile + +client = TestClient(app) +s3_client = get_s3_client() + + +def test_upload_image_success(monkeypatch): + def mock_put_object(Bucket, Key, Body, ContentType, Metadata): + return {} + + monkeypatch.setattr(s3_client, "put_object", mock_put_object) + + # create temporary file for testing + with tempfile.NamedTemporaryFile(suffix=".jpg") as temp_file: + temp_file.write(b"Test image content") + temp_file.seek(0) + + response = client.post( + "/api/aws/upload", + files={"file": ("test_image.jpg", tempfile, "image/jpeg")}, + data={"title": "Test Title", "description": "Test Description"}, + ) + + print(f"aaaaaae{response.json()}") + # assert response.status_code == 200 + # assert response.json()["status"] == "success" + # assert "url" in response.json()["data"] + + +def test_upload_image_error(monkeypatch): + def mock_put_object(): + raise Exception("Upload failed") + + monkeypatch.setattr(s3_client, "put_object", mock_put_object) + + # create temporary file for testing + with tempfile.NamedTemporaryFile(suffix=".jpg") as temp_file: + temp_file.write(b"Test image content") + temp_file.seek(0) + + response = client.post( + "/api/aws/upload", + files={"file": ("test_image.jpg", temp_file, "image/jpeg")}, + data={"title": "Test Title", "description": "Test Description"}, + ) + + assert response.status_code == 500 diff --git a/template.yml b/template.yml index 8f9d50fb..565096d6 100644 --- a/template.yml +++ b/template.yml @@ -115,6 +115,10 @@ Resources: X_GITHUB_APPS_CLIENT_ID: !Ref 'XGithubAppsClientId' X_GITHUB_APPS_CLIENT_SECRET: !Ref 'XGithubAppsClientSecret' X_GITHUB_APP_ID: !Ref 'XGithubAppId' + AWS_SECRET_NAME: !Ref AWS_SECRET_NAME + AWS_REGION_NAME: !Ref AWS_REGION_NAME + S3_BUCKET_NAME: !Ref S3_BUCKET_NAME + STATIC_URL: !Ref STATIC_URL FunctionUrlConfig: AuthType: NONE InvokeMode: RESPONSE_STREAM