diff --git a/crates/aim-downloader/src/hash.rs b/crates/aim-downloader/src/hash.rs index d7544705d626..8a2470828531 100644 --- a/crates/aim-downloader/src/hash.rs +++ b/crates/aim-downloader/src/hash.rs @@ -9,7 +9,7 @@ impl HashChecker { pub fn check(filename: &str, expected_hash: &str) -> Result<(), ValidateError> { let mut result = Ok(()); if filename != "stdout" && (!expected_hash.is_empty()) { - let actual_hash = HashChecker::sha256sum(filename); + let actual_hash = HashChecker::sha256sum(filename)?; if actual_hash != expected_hash { result = Err(ValidateError::Sha256Mismatch); } @@ -22,15 +22,21 @@ impl HashChecker { result } - fn sha256sum(filename: &str) -> String { + fn sha256sum(filename: &str) -> Result { let mut hasher = Sha256::new(); - let mut file = fs::File::open(filename).unwrap(); - - io::copy(&mut file, &mut hasher).unwrap(); + let mut file = fs::File::open(filename).map_err(|e| { + println!("Can not open {filename}:\n {e}"); + ValidateError::Sha256Mismatch + })?; + + io::copy(&mut file, &mut hasher).map_err(|e| { + println!("Can not read {filename}:\n {e}"); + ValidateError::Sha256Mismatch + })?; let computed_hash = hasher.finalize(); drop(file); - format!("{computed_hash:x}") + Ok(format!("{computed_hash:x}")) } } @@ -61,7 +67,7 @@ mod tests { fn test_sha256sum_api() { let expected = "21d7847124bfb9d9a9d44af6f00d8003006c44b9ef9ba458b5d4d3fc5f81bde5"; - let actual = HashChecker::sha256sum("LICENCE.md"); + let actual = HashChecker::sha256sum("LICENCE.md").unwrap(); assert_eq!(actual, expected); } diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs index 753d93496739..f0d5c59efa96 100644 --- a/crates/tabby-download/src/lib.rs +++ b/crates/tabby-download/src/lib.rs @@ -85,34 +85,45 @@ async fn download_model_impl( ); } - if !prefer_local_file { - info!("Checking model integrity.."); - - let mut sha256_matched = true; - for (index, url) in urls.iter().enumerate() { - if HashChecker::check( - partitioned_file_name!(index + 1, urls.len()).as_str(), - &url.1, - ) - .is_err() - { - sha256_matched = false; - break; - } + let mut model_existed = true; + for (index, _) in urls.iter().enumerate() { + if fs::metadata( + registry + .get_model_store_dir(name) + .join(partitioned_file_name!(index, urls.len())), + ) + .is_err() + { + model_existed = false; + break; } + } - if sha256_matched { - return Ok(()); - } + if model_existed && prefer_local_file { + return Ok(()); + } - warn!( - "Checksum doesn't match for <{}/{}>, re-downloading...", - registry.name, name - ); + info!("Checking model integrity.."); - fs::remove_dir_all(registry.get_model_dir(name))?; + let mut sha256_matched = true; + for (index, url) in urls.iter().enumerate() { + if HashChecker::check(partitioned_file_name!(index, urls.len()).as_str(), &url.1).is_err() { + sha256_matched = false; + break; + } + } + + if sha256_matched { + return Ok(()); } + warn!( + "Checksum doesn't match for <{}/{}>, re-downloading...", + registry.name, name + ); + + fs::remove_dir_all(registry.get_model_dir(name))?; + // prepare for download let dir = registry.get_model_store_dir(name); fs::create_dir_all(dir)?; @@ -123,7 +134,7 @@ async fn download_model_impl( .get_model_store_dir(name) .to_string_lossy() .into_owned(); - let filename: String = partitioned_file_name!(index + 1, urls.len()); + let filename: String = partitioned_file_name!(index, urls.len()); let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2); Retry::spawn(strategy, move || {