Skip to content

Commit

Permalink
fix dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgjs committed Dec 6, 2023
1 parent 424f8b6 commit 4b40f34
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"transformers>=4.35.2",
"typer",
"scipy",
# This works, while installing from pytorch and cuda from conda does not",
Expand All @@ -42,17 +41,18 @@ dev = [
"nbsphinx", # Used to integrate Python notebooks into Sphinx documentation
"ipython", # Also used in building notebooks into Sphinx
"matplotlib", # Used in sample notebook intro_notebook.ipynb
"numpy", # Used in sample notebook intro_notebook.ipynb
"ipykernel",
]
train = [
"jsonlines",
"mlflow",
"azureml-mlflow",
]
azure = [
"azureml-core",
"jsonlines",
"azureml-mlflow",
]

train_cuda = [
cuda = [
"transformers>=4.35.2",
"bitsandbytes>=0.41.2.post2",
"accelerate>=0.24.1",
"xformers",
Expand Down
8 changes: 5 additions & 3 deletions src/autora/doc/pipelines/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from timeit import default_timer as timer
from typing import List

import jsonlines
import mlflow
import torch
import typer

Expand All @@ -20,6 +18,11 @@

@app.command()
def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts) -> List[str]:
import jsonlines
import mlflow

mlflow.autolog()

run = mlflow.active_run()

sys_prompt = SYS[sys_id]
Expand Down Expand Up @@ -82,5 +85,4 @@ def import_model(model_name: str) -> None:
if __name__ == "__main__":
logger.info(f"Torch version: {torch.__version__} , Cuda available: {torch.cuda.is_available()}")

mlflow.autolog()
app()

0 comments on commit 4b40f34

Please sign in to comment.