diff --git a/lit_llama/utils.py b/lit_llama/utils.py index 1d26d2c6..d9eddf1b 100644 --- a/lit_llama/utils.py +++ b/lit_llama/utils.py @@ -520,23 +520,7 @@ def get_packages(pkgs): def check_python_packages(): - d = { - 'torch': '2.0.0', - 'lightning': '2.1.0.dev0', - } - - versions = get_packages(d.keys()) - - wrong_package_triggered = False - for (pkg_name, suggested_ver), actual_ver in zip(d.items(), versions): - if actual_ver == 'N/A': - continue - actual_ver, suggested_ver = version_parse(actual_ver), version_parse(suggested_ver) - if actual_ver < suggested_ver: - print(f'[FAIL] {pkg_name} {actual_ver}, please upgrade to {suggested_ver}') - wrong_package_triggered = True - else: - print(f'[OK] {pkg_name} {actual_ver}') - - if wrong_package_triggered: + torch_ = RequirementCache('torch>=2.0.0') + lit_ = RequirementCache('lightning>=2.1.0') + if not bool(torch_) or not bool(lit_): raise ImportError("Wrong package version(s) installed.") \ No newline at end of file