Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
support large vocab sizes (i.e. llama3) for DraftRetriever datastore
  • Loading branch information
scandukuri committed Nov 13, 2024
commit 05f57b0026960cbd87e5ed7a40a554151a0c22a7
25 changes: 15 additions & 10 deletions DraftRetriever/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// The code for retrieval is adapted from https://github.com/Intsights/PySubstringSearch;
// The code for retrieval is adapted from https://github.com/Intsights/PySubstringSearch;
// The code for draft buffer is adapted from https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/utils.py#L31-L124
use ahash::AHashSet;
use byteorder::{ReadBytesExt, WriteBytesExt, ByteOrder, LittleEndian};
Expand Down Expand Up @@ -111,10 +111,11 @@ impl Writer {
return Ok(());
}

self.index_file.write_u32::<LittleEndian>((self.buffer.len() * 2) as u32)?;
self.index_file.write_u32::<LittleEndian>((self.buffer.len() * 4) as u32)?; // self.buffer.len() is the length of the buffer (in # of integers). This is variable because sometimes we dump_data() early; it's not always self.buffer.capacity().
// * 4 because this value tells us how much space this buffer occupies in the file, and we store each integer as 4 bytes

for &item in &self.buffer {
self.index_file.write_u16::<LittleEndian>(item as u16)?;
self.index_file.write_i32::<LittleEndian>(item as i32)?;
}

let suffix_array = construct_suffix_array(&self.buffer, self.vocab_size);
Expand All @@ -127,6 +128,10 @@ impl Writer {
Ok(())
}

// Some personal notes:
// (1) max_chunk_len can be chosen freely, but the draftretriever Reader() works fastest when it is large
// (2) vocab_size should be the size of the vocabulary + 1. This is used in the suffix array construction.

fn finalize(
&mut self,
) -> PyResult<()> {
Expand Down Expand Up @@ -188,8 +193,9 @@ impl Reader {

let mut data: Vec<i32> = Vec::new();

for i in (0..data_u8.len()).step_by(2) {
let int = LittleEndian::read_u16(&data_u8[i..i+2]) as i32;
// Step by 4 to read in each 4-byte int (i32) from index file
for i in (0..data_u8.len()).step_by(4) {
let int = LittleEndian::read_i32(&data_u8[i..i+4]) as i32;
data.push(int);
}

Expand Down Expand Up @@ -259,7 +265,7 @@ impl Reader {
if start_of_indices.is_none() {
return;
}

// this binary search finds the end of the matching suffixes
let mut right_anchor = sub_index.suffixes_file_end - 4;
while left_anchor <= right_anchor {
Expand Down Expand Up @@ -300,7 +306,7 @@ impl Reader {
let data_index = LittleEndian::read_i32(suffix);
if matches_ranges.insert(data_index) {
let sub_string_plus = &sub_index.data[data_index as usize + substring_i32.len() ..std::cmp::min(data_index as usize + substring_i32.len() + long as usize, sub_index.data.len())];

local_results.push(sub_string_plus.to_vec());
cnt += 1;
if cnt >= k as usize {
Expand Down Expand Up @@ -328,7 +334,7 @@ impl Reader {
*counter += 1;
}
}

let choices = choices.unwrap_or(64);
// The items in the heap must be a Trie.
let mut heap = BinaryHeap::new();
Expand All @@ -348,7 +354,7 @@ impl Reader {
let verified: Vec<_> = verified.into_iter().collect();

// Because multiple nodes in the Trie may have same weights around the threshold, the number of draft tokens may exceed choices
// We roughly cut nodes to be less than choices in most cases.
// We roughly cut nodes to be less than choices in most cases.
let paths = cut_to_choices(verified, choices);

let (draft_choices, max_branch) = get_draft_choices(paths.clone());
Expand Down Expand Up @@ -562,4 +568,3 @@ fn draftretriever(

Ok(())
}