diff --git a/README.md b/README.md index aa8fc0e..4b3ad64 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ The default installation dependencies, as defined in the `pyproject.toml`, are s > Users have run this codebase with Python 3.9,3.10 and cuda_12, cuda-11.8 ``` -> pip install . +> pip install -e . ``` Developers should set up `pre-commit` as well with `pre-commit install`. diff --git a/ml_mdm/clis/generate_sample.py b/ml_mdm/clis/generate_sample.py index db44701..975ab31 100644 --- a/ml_mdm/clis/generate_sample.py +++ b/ml_mdm/clis/generate_sample.py @@ -20,7 +20,13 @@ from ml_mdm.config import get_arguments, get_model, get_pipeline from ml_mdm.language_models import factory -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device( + "cuda" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" +) # Note that it is called add_arguments, not add_argument. logging.basicConfig( diff --git a/pyproject.toml b/pyproject.toml index ade6b1d..9fe16f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "imageio[ffmpeg]", "matplotlib", "mlx-data", - "numpy", + "numpy<2", "pytorch-model-summary", "rotary-embedding-torch", "simple-parsing==0.1.5",