
Commit

Fixing examples.
Narsil committed Aug 1, 2023
1 parent 2bf7a63 commit def0e5a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions candle-book/src/inference/hub.md
@@ -58,20 +58,20 @@ Now that we have our weights, we can use them in our bert architecture:
# extern crate candle_nn;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
# use candle::Device;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());
#
# let weights = repo.get("model.safetensors").unwrap();
use candle::{Device, Tensor, DType};
use candle_nn::Linear;

-let weights = candle::safetensors::load(weights, &Device::Cpu);
+let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap();

let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();

-let linear = Linear::new(weight, Some(bias));
+let linear = Linear::new(weight.clone(), Some(bias.clone()));

let input_ids = Tensor::zeros((3, 7680), DType::F32, &Device::Cpu).unwrap();
let output = linear.forward(&input_ids);
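
For reference, below is a minimal sketch of the full example as it reads after this commit, reassembled from the hidden (#-prefixed) doc-test lines and the changed lines above into a standalone program. It assumes the candle, candle_nn and hf_hub APIs exactly as used in the diff: candle::safetensors::load returns a Result wrapping a map of named tensors, hence the added .unwrap(), and Linear::new takes its tensors by value, hence the added .clone() calls on the entries borrowed from that map.

// Sketch of the corrected hub.md example, assembled into a standalone program.
// Assumes the candle / candle_nn / hf_hub APIs exactly as used in the diff above.
use candle::{DType, Device, Tensor};
use candle_nn::Linear;
use hf_hub::api::sync::Api;

fn main() {
    // Fetch the BERT weights from the Hugging Face Hub (the hidden setup lines in the book).
    let api = Api::new().unwrap();
    let repo = api.model("bert-base-uncased".to_string());
    let weights = repo.get("model.safetensors").unwrap();

    // load(...) returns a Result over a map of named tensors, hence the added .unwrap().
    let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap();

    let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
    let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();

    // Linear::new takes tensors by value, hence the added .clone() on the borrowed entries.
    let linear = Linear::new(weight.clone(), Some(bias.clone()));

    // Forward pass; the input shape is kept as in the original example.
    let input_ids = Tensor::zeros((3, 7680), DType::F32, &Device::Cpu).unwrap();
    let output = linear.forward(&input_ids);
}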