[BUG]: After training the model, I am not able to merge the model and run inference on it #89

Open
kiranshivaraju opened this issue Aug 7, 2024 · 0 comments
Labels
bug Something isn't working

kiranshivaraju commented Aug 7, 2024

Python Version

Python 3.10.8

Pip Freeze

Python 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:23:14) [GCC 10.4.0]
sh-4.2$ pip freeze
absl-py==2.1.0
aiobotocore @ file:///home/conda/feedstock_root/build_artifacts/aiobotocore_1719300089447/work
aiofiles==22.1.0
aiohttp @ file:///home/conda/feedstock_root/build_artifacts/aiohttp_1713964853148/work
aioitertools @ file:///home/conda/feedstock_root/build_artifacts/aioitertools_1663521246073/work
aiosignal @ file:///home/conda/feedstock_root/build_artifacts/aiosignal_1667935791922/work
aiosqlite==0.20.0
annotated-types==0.7.0
anyio @ file:///home/conda/feedstock_root/build_artifacts/anyio_1717693030552/work
argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1692818318753/work
argon2-cffi-bindings @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi-bindings_1695386546427/work
arrow @ file:///home/conda/feedstock_root/build_artifacts/arrow_1696128962909/work
astroid==3.2.2
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work
async-timeout @ file:///home/conda/feedstock_root/build_artifacts/async-timeout_1691763562544/work
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1704011227531/work
autopep8==2.0.4
autovizwidget @ file:///home/conda/feedstock_root/build_artifacts/autovizwidget_1695512335638/work
awscli==1.33.24
Babel==2.15.0
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1705564648255/work
bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1696630167146/work
boto3==1.34.142
botocore==1.34.142
Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1695989787169/work
cached-property @ file:///home/conda/feedstock_root/build_artifacts/cached_property_1615209429212/work
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1720457958366/work/certifi
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1696001684923/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1698833585322/work
cloudpickle==2.2.1
colorama==0.4.6
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography-split_1672672382195/work
Cython @ file:///home/conda/feedstock_root/build_artifacts/cython_1711833953208/work
debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1719378659226/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work
dill==0.3.8
docker==7.1.0
docstring-to-markdown==0.15
docstring_parser==0.16
docutils==0.16
entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
environment-kernels==1.2.0
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1704921103267/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
fastjsonschema @ file:///home/conda/feedstock_root/build_artifacts/python-fastjsonschema_1718477020893/work/dist
filelock==3.15.4
fire==0.6.0
fqdn @ file:///home/conda/feedstock_root/build_artifacts/fqdn_1638810296540/work/dist
frozenlist @ file:///home/conda/feedstock_root/build_artifacts/frozenlist_1702645481127/work
fsspec @ file:///home/conda/feedstock_root/build_artifacts/fsspec_1719514913127/work
gitdb==4.0.11
GitPython==3.1.43
google-pasta==0.2.0
grpcio==1.65.4
gssapi==1.8.3
h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1634280454336/work
hdijupyterutils @ file:///home/conda/feedstock_root/build_artifacts/hdijupyterutils_1695512275947/work
hpack==4.0.0
hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1619110129307/work
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1713279365350/work
importlib-metadata==6.11.0
importlib_resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1711040877059/work
ipykernel==5.5.6
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1719582526268/work
ipython_genutils @ file:///home/conda/feedstock_root/build_artifacts/ipython_genutils_1716278396992/work
ipywidgets @ file:///home/conda/feedstock_root/build_artifacts/ipywidgets_1716897651763/work
isoduration @ file:///home/conda/feedstock_root/build_artifacts/isoduration_1638811571363/work/dist
isort==5.13.2
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1715127149914/work
jmespath @ file:///home/conda/feedstock_root/build_artifacts/jmespath_1655568249366/work
json5==0.9.25
jsonpointer @ file:///home/conda/feedstock_root/build_artifacts/jsonpointer_1718283388110/work
jsonschema==4.21.1
jsonschema-specifications @ file:///tmp/tmpkv1z7p57/src
jupyter @ file:///home/conda/feedstock_root/build_artifacts/jupyter_1696255489086/work
jupyter-console @ file:///home/conda/feedstock_root/build_artifacts/jupyter_console_1678118109161/work
jupyter-events @ file:///home/conda/feedstock_root/build_artifacts/jupyter_events_1710805637316/work
jupyter-lsp==2.2.5
jupyter-server-mathjax==0.2.6
jupyter-server-proxy @ git+https://github.com/jupyterhub/jupyter-server-proxy@2d7dd346bb595106b417476de870a348943f3c70
jupyter-ydoc==0.2.5
jupyter_client==7.4.9
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257277185/work
jupyter_server @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_1720529946765/work
jupyter_server_fileid==0.9.2
jupyter_server_terminals @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_terminals_1710262634903/work
jupyter_server_ydoc==0.8.0
jupyterlab==3.6.7
jupyterlab-git==0.41.0
jupyterlab-lsp==4.3.0
jupyterlab_pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1707149102966/work
jupyterlab_server==2.27.2
jupyterlab_widgets @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_widgets_1716891641122/work
krb5==0.5.1
Markdown==3.6
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1706899921127/work
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1713250518406/work
mccabe==0.7.0
mistral_common==1.3.3
mistral_inference==1.3.1
mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1675771498296/work
mpmath==1.3.0
multidict @ file:///home/conda/feedstock_root/build_artifacts/multidict_1707040698785/work
multiprocess==0.70.16
nb_conda @ file:///home/conda/feedstock_root/build_artifacts/nb_conda_1704789357480/work
nb_conda_kernels @ file:///home/conda/feedstock_root/build_artifacts/nb_conda_kernels_1708439411368/work
nbclassic @ file:///home/conda/feedstock_root/build_artifacts/nbclassic_1675369808718/work
nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1710317608672/work
nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert-meta_1674590374792/work
nbdime==3.2.1
nbexamples @ file:///opt/nbexamples
nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1712238998817/work
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
networkx==3.3
notebook @ file:///home/conda/feedstock_root/build_artifacts/notebook_1715848908871/work
notebook_shim @ file:///home/conda/feedstock_root/build_artifacts/notebook-shim_1707957777232/work
numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1707225380409/work/dist/numpy-1.26.4-cp310-cp310-linux_x86_64.whl#sha256=51131fd8fc130cd168aecaf1bc0ea85f92e8ffebf211772ceb16ac2e7f10d7ca
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.6.20
nvidia-nvtx-cu12==12.1.105
overrides @ file:///home/conda/feedstock_root/build_artifacts/overrides_1706394519472/work
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1718189413536/work
pandas==1.5.3
pandocfilters @ file:///home/conda/feedstock_root/build_artifacts/pandocfilters_1631603243851/work
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1712320355065/work
pathos==0.3.2
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
pid==3.0.4
pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1694617248815/work
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1715777629804/work
plotly @ file:///home/conda/feedstock_root/build_artifacts/plotly_1714829923649/work
pluggy==1.5.0
pox==0.3.4
ppft==1.7.6.8
prometheus_client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1707932675456/work
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1718047967974/work
protobuf==4.25.3
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1719274566094/work
psycopg2 @ file:///home/conda/feedstock_root/build_artifacts/psycopg2-split_1667025517155/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
py4j==0.10.9.5
pyasn1==0.6.0
pycodestyle==2.11.1
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1711811537435/work
pydantic==2.6.1
pydantic_core==2.16.2
pydocstyle==6.3.0
pyflakes==3.2.0
pygal==3.0.4
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1714846767233/work
pykerberos @ file:///home/conda/feedstock_root/build_artifacts/pykerberos_1671204518513/work
pylint==3.2.5
PyQt5==5.12.3
PyQt5_sip==4.19.18
PyQtChart==5.12
PyQtWebEngine==5.12.1
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
pyspark==3.3.0
pyspnego @ file:///home/conda/feedstock_root/build_artifacts/pyspnego_1720048480072/work
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1709299778482/work
python-json-logger @ file:///home/conda/feedstock_root/build_artifacts/python-json-logger_1677079630776/work
python-lsp-jsonrpc==1.1.2
python-lsp-server==1.11.0
pytoolconfig==1.3.1
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1706886791323/work
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1695373428874/work
pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1715024398995/work
qtconsole @ file:///home/conda/feedstock_root/build_artifacts/qtconsole-base_1714942934316/work
QtPy @ file:///home/conda/feedstock_root/build_artifacts/qtpy_1698112029416/work
referencing @ file:///home/conda/feedstock_root/build_artifacts/referencing_1714619483868/work
regex==2024.7.24
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1717057054362/work
requests-kerberos @ file:///home/conda/feedstock_root/build_artifacts/requests-kerberos_1697118166865/work
rfc3339-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3339-validator_1638811747357/work
rfc3986-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3986-validator_1598024191506/work
rope==1.13.0
rpds-py @ file:///home/conda/feedstock_root/build_artifacts/rpds-py_1720476547861/work
rsa==4.7.2
ruamel.yaml @ file:///home/conda/feedstock_root/build_artifacts/ruamel.yaml_1707298115475/work
ruamel.yaml.clib @ file:///home/conda/feedstock_root/build_artifacts/ruamel.yaml.clib_1707314473442/work
s3fs @ file:///home/conda/feedstock_root/build_artifacts/s3fs_1719518013855/work
s3transfer @ file:///home/conda/feedstock_root/build_artifacts/s3transfer_1719300139436/work
safetensors==0.4.4
sagemaker==2.225.0
sagemaker-experiments==0.1.45
sagemaker_nbi_agent @ file:///opt/sagemaker_nbi_agent
sagemaker_pyspark==1.4.5
schema==0.7.7
Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1712584999685/work
sentencepiece==0.2.0
simpervisor==1.0.0
simple_parsing==0.1.5
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
smdebug-rulesconfig==1.0.1
smmap==5.0.1
sniffio @ file:///home/conda/feedstock_root/build_artifacts/sniffio_1708952932303/work
snowballstemmer==2.2.0
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1693929250441/work
sparkmagic @ file:///home/conda/feedstock_root/build_artifacts/sparkmagic_1695511119398/work/sparkmagic
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
sympy==1.13.1
tblib==3.0.0
tenacity @ file:///home/conda/feedstock_root/build_artifacts/tenacity_1720351771156/work
tensorboard==2.17.0
tensorboard-data-server==0.7.2
termcolor==2.4.0
terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1710262609923/work
tiktoken==0.7.0
tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1713974937325/work
tomli==2.0.1
tomlkit==0.13.0
torch==2.2.0
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1717722796999/work
tqdm==4.66.4
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1713535121073/work
triton==2.2.0
types-python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/types-python-dateutil_1710589910274/work
typing-utils @ file:///home/conda/feedstock_root/build_artifacts/typing_utils_1622899189314/work
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1717802530399/work
ujson==5.10.0
uri-template @ file:///home/conda/feedstock_root/build_artifacts/uri-template_1688655812972/work/dist
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1719391292974/work
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
webcolors @ file:///home/conda/feedstock_root/build_artifacts/webcolors_1717667289718/work
webencodings @ file:///home/conda/feedstock_root/build_artifacts/webencodings_1694681268211/work
websocket-client @ file:///home/conda/feedstock_root/build_artifacts/websocket-client_1713923384721/work
Werkzeug==3.0.3
widgetsnbextension @ file:///home/conda/feedstock_root/build_artifacts/widgetsnbextension_1716891659446/work
wrapt @ file:///home/conda/feedstock_root/build_artifacts/wrapt_1699532811524/work
xformers==0.0.24
y-py==0.6.2
yarl @ file:///home/conda/feedstock_root/build_artifacts/yarl_1705508292061/work
ypy-websocket==0.8.4
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1718013267051/work
zstandard==0.22.0



### Reproduction Steps
NOTE: I am using an ml.g5.4xlarge notebook instance to train the Mistral v3 7B model.
1. Fine-tune the model as described in the README file.
2. This is my YAML file:

data:
  instruct_data: "/home/ec2-user/data/train_file.jsonl" # Fill
  data: "" # Optionally fill with pretraining data
  eval_instruct_data: "/home/ec2-user/data/test_file.jsonl" # Optionally fill

model_id_or_path: "/home/ec2-user/mistral_models/" # Change to downloaded path
lora:
  rank: 32

seq_len: 2048
batch_size: 1
max_steps: 300
optim:
  lr: 6.e-5
  weight_decay: 0.1
  pct_start: 0.05

seed: 0
log_freq: 1
eval_freq: 100
no_eval: False
ckpt_freq: 100

save_adapters: True # save only trained LoRA adapters. Set to False to merge LoRA adapter into the base model and save full fine-tuned model

run_dir: "/home/ec2-user/outputs" # Fill


3. I use this command to start the fine-tuning process:
 torchrun --nproc-per-node 1 --master_port $RANDOM -m train example/7B.yaml

4. After fine-tuning finishes, I download mistral-inference and run this:
mistral-chat /home/ec2-user/mistral_models/ --max_tokens 256 --temperature 1.0 --instruct --lora_path /home/ec2-user/outputs/checkpoints/checkpoint_000300/consolidated/lora.safetensors

This is the error I get:
Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/JupyterSystemEnv/bin/mistral-chat", line 8, in <module>
    sys.exit(mistral_chat())
  File "/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.10/site-packages/mistral_inference/main.py", line 207, in mistral_chat
    fire.Fire(interactive)
  File "/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.10/site-packages/mistral_inference/main.py", line 91, in interactive
    model.load_lora(Path(lora_path))
  File "/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.10/site-packages/mistral_inference/lora.py", line 101, in load_lora
    self._load_lora_state_dict(state_dict, scaling=scaling)
  File "/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.10/site-packages/mistral_inference/lora.py", line 135, in _load_lora_state_dict
    + (lora_state_dict[name + ".lora_B.weight"] @ lora_state_dict[name + ".lora_A.weight"])
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 112.00 MiB. GPU 0 has a total capacity of 21.99 GiB of which 99.38 MiB is free. Including non-PyTorch memory, this process has 21.88 GiB memory in use. Of the allocated memory 21.57 GiB is allocated by PyTorch, and 13.37 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
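
For context, here is a minimal sketch of what the failing expression in the traceback appears to compute. It is a plain-Python stand-in, not the mistral_inference implementation. The helpers `matmul`, `merge_lora`, and `matrix_mib` are hypothetical names. The `scaling` multiplication, the 2-byte element size (bfloat16), and the 14336 x 4096 feed-forward shape (taken from the published Mistral 7B configuration) are assumptions that do not appear in the output above. Under those assumptions, a single `lora_B @ lora_A` product for a feed-forward layer is a full-size dense matrix of 112 MiB, which matches the size of the failed allocation and exceeds the 99.38 MiB reported free.

```python
# Sketch only: plain-Python stand-in for the tensor expression in the traceback.
# `matmul`, `merge_lora`, and `matrix_mib` are hypothetical helpers,
# not part of mistral_inference.

def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def merge_lora(weight, lora_a, lora_b, scaling=1.0):
    """Return weight + scaling * (lora_b @ lora_a), the usual LoRA merge."""
    # The low-rank product is a full (out x in) matrix, the same shape as weight.
    delta = matmul(lora_b, lora_a)
    return [
        [w + scaling * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


def matrix_mib(rows, cols, bytes_per_element=2):
    """Size in MiB of a dense rows x cols matrix (2 bytes per element assumed)."""
    return rows * cols * bytes_per_element / 2**20


# Toy example: out=2, in=3, rank=1.
weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
lora_a = [[1.0, 2.0, 3.0]]  # rank x in
lora_b = [[0.5], [1.0]]     # out x rank
merged = merge_lora(weight, lora_a, lora_b, scaling=2.0)

# Assumed Mistral 7B feed-forward weight shape: 14336 x 4096.
delta_mib = matrix_mib(14336, 4096)
free_mib = 99.38  # "99.38 MiB is free" in the error message

print(merged)
print(delta_mib, delta_mib > free_mib)
```

If these assumptions hold, the error message's fragmentation hint is unlikely to help on its own: only 13.37 MiB is reserved but unallocated, which is far less than the temporary matrix needs.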




### Expected Behavior

I expect to be able to interact with the fine-tuned model.

### Additional Context

_No response_

### Suggested Solutions

_No response_