
Commit

query_embed is now sequential and padding has been changed to batchlongest
joshniemela committed Jan 8, 2024
1 parent e301b28 commit 7e0526a
Showing 1 changed file with 52 additions and 5 deletions.
57 changes: 52 additions & 5 deletions src/lib.rs
@@ -315,7 +315,8 @@ impl FlagEmbedding {

         let mut tokenizer = tokenizer
             .with_padding(Some(PaddingParams {
-                strategy: PaddingStrategy::Fixed(max_length),
+                // TODO: the user should be able to choose the padding strategy
+                strategy: PaddingStrategy::BatchLongest,
                 pad_token,
                 pad_id,
                 ..Default::default()
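
Context for the hunk above: PaddingStrategy::Fixed(max_length) pads every encoding to the model's maximum sequence length, while PaddingStrategy::BatchLongest only pads to the longest sequence in the current batch, so short inputs carry far less padding. Below is a minimal standalone sketch of that behaviour with the tokenizers crate; the tokenizer.json path and the example strings are placeholders, not part of this commit.

use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load any HuggingFace-style tokenizer file (placeholder path).
    let mut tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // Pad each batch only up to its longest member, as the commit does.
    tokenizer.with_padding(Some(PaddingParams {
        strategy: PaddingStrategy::BatchLongest,
        ..Default::default()
    }));

    // Both encodings are padded to the length of the longer sentence,
    // not to the model's full maximum length.
    let encodings = tokenizer.encode_batch(vec!["short query", "a noticeably longer example passage"], true)?;
    for encoding in &encodings {
        println!("padded length: {}", encoding.len());
    }
    Ok(())
}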
@@ -474,11 +475,57 @@ impl<S: AsRef<str> + Send + Sync> EmbeddingBase<S> for FlagEmbedding {
         self.embed(passages, batch_size)
     }
 
-    // Method implementation for query embeddings. Prefixed with "query"
+    // Method implementation for query embeddings. Prefixed with "query" and made sequential for performance
     fn query_embed(&self, query: S) -> Result<Embedding> {
-        let query = format!("query: {}", query.as_ref());
-        let query_embedding = self.embed(vec![&query], None);
-        Ok(query_embedding?[0].to_owned())
+        let text = format!("query: {}", query.as_ref());
+
+        // Encode the query text
+        let encoding = self.tokenizer.encode(text.as_str(), true).unwrap();
+
+        // Extract the encoding length
+        let encoding_length = encoding.len();
+
+        // Preallocate arrays with the encoding length
+        let mut ids_array = Vec::with_capacity(encoding_length);
+        let mut mask_array = Vec::with_capacity(encoding_length);
+        let mut typeids_array = Vec::with_capacity(encoding_length);
+
+        // Pull the ids, attention mask and type ids out of the encoding
+        let ids = encoding.get_ids();
+        let mask = encoding.get_attention_mask();
+        let typeids = encoding.get_type_ids();
+
+        // Extend the preallocated arrays with the encoding values, widened to i64
+        ids_array.extend(ids.iter().map(|x| *x as i64));
+        mask_array.extend(mask.iter().map(|x| *x as i64));
+        typeids_array.extend(typeids.iter().map(|x| *x as i64));
+
+        // Create (1, encoding_length) arrays from the vectors
+        let inputs_ids_array = Array::from_shape_vec((1, encoding_length), ids_array)?;
+
+        let attention_mask_array = Array::from_shape_vec((1, encoding_length), mask_array)?;
+
+        let token_type_ids_array = Array::from_shape_vec((1, encoding_length), typeids_array)?;
+
+        // Run the model with the inputs
+        let outputs = self.session.run(ort::inputs![
+            "input_ids" => Value::from_array(inputs_ids_array)?,
+            "attention_mask" => Value::from_array(attention_mask_array)?,
+            "token_type_ids" => Value::from_array(token_type_ids_array)?,
+        ]?)?;
+        // Extract and normalize the embedding
+        let output_data = outputs["last_hidden_state"].extract_tensor::<f32>()?;
+        let view = output_data.view();
+        let shape = view.shape();
+        let flattened = view.as_slice().unwrap();
+        let data = get_embeddings(flattened, shape);
+        let embeddings: Embedding = data
+            .into_iter()
+            .map(|mut d| normalize(&mut d))
+            .next()
+            .unwrap();
+
+        Ok(embeddings)
     }
 }

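For reference, a hedged usage sketch of the new single-query path. Only query_embed, embed, the EmbeddingBase trait and the Result<Embedding> return type appear in the diff above; the crate name, FlagEmbedding::try_new and InitOptions::default() below are assumptions about the surrounding library and may differ from its actual API.

// Hypothetical caller; constructor details are assumed, not taken from this commit.
use fastembed::{EmbeddingBase, FlagEmbedding, InitOptions};

fn main() -> anyhow::Result<()> {
    let model = FlagEmbedding::try_new(InitOptions::default())?;

    // Single query: goes through the new sequential query_embed path,
    // which tokenizes once and runs the ONNX session directly.
    let query_embedding = model.query_embed("what is an embedding?")?;
    println!("query dimensions: {}", query_embedding.len());

    // Batches still go through the batched embed path, now with BatchLongest padding.
    let passages = vec!["passage one", "a second, somewhat longer passage"];
    let passage_embeddings = model.embed(passages, None)?;
    println!("passages embedded: {}", passage_embeddings.len());
    Ok(())
}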
