-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_embedding.py
More file actions
91 lines (76 loc) · 3.01 KB
/
Copy pathcreate_embedding.py
File metadata and controls
91 lines (76 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
# Configuration
# Name handed to SentenceTransformer(); also embedded in the embeddings file name.
MODEL_NAME = "msmarco-MiniLM-L6-cos-v5"
# Tab-separated index, one document per line:
# doc_id \t url \t term_count \t offset \t len
DOCS_INDEX_PATH = "index/docs.txt"
# Corpus file read in binary mode; offset/len from the index are used to seek
# to and read each document's record.
MSMARCO_PATH = "./data/msmarco-docs.tsv"
# Destination for the encoded corpus (written with torch.save).
OUTPUT_EMBEDDINGS = f"corpus_embeddings_{MODEL_NAME}.pt"
# Destination for the int32 tensor of doc IDs that were actually embedded.
OUTPUT_DOC_IDS = "embedding_doc_ids.pt"
# batch_size passed to model.encode().
BATCH_SIZE = 4096
# Number of CUDA devices to encode on: cuda:0 .. cuda:(NUM_GPUS - 1).
NUM_GPUS = 7
# Assigned to model.max_seq_length before encoding.
MAX_SEQ_LENGTH = 256
def _load_doc_metadata(index_path):
    """Parse the docs index into a list of ``(doc_id, offset, length)`` tuples.

    Each index line is expected to be
    ``doc_id \\t url \\t term_count \\t offset \\t len``; lines with fewer
    than five tab-separated fields are ignored.

    Raises:
        ValueError: If the id, offset or length field of a line with at
            least five fields is not an integer.
    """
    doc_metadata = []
    with open(index_path, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="Parsing docs.txt"):
            parts = line.strip().split("\t")
            if len(parts) >= 5:
                doc_metadata.append(
                    (int(parts[0]), int(parts[3]), int(parts[4]))
                )
    return doc_metadata


def _extract_body(line):
    """Return the text to embed from one corpus record, or ``""`` if none.

    Records are expected to look like ``id \\t url \\t title \\t body``.
    With four or more fields everything from the third field on is joined
    with spaces; with two or three fields the last field is used; a record
    with a single field yields an empty string.
    """
    parts = line.strip().split("\t")
    if len(parts) >= 4:
        return " ".join(parts[2:])
    if len(parts) >= 2:
        return parts[-1]
    return ""


def _read_texts(corpus_path, doc_metadata):
    """Read the body text of every indexed document from the corpus file.

    Args:
        corpus_path: Path of the corpus file; opened in binary mode because
            the offsets and lengths are used with ``seek``/``read`` on bytes.
        doc_metadata: Iterable of ``(doc_id, offset, length)`` tuples.

    Returns:
        ``(texts, doc_ids, undecodable)`` where ``texts[i]`` belongs to
        ``doc_ids[i]`` and ``undecodable`` counts the records skipped
        because their bytes were not valid UTF-8. Records with an empty
        body are skipped as well (and not counted).
    """
    texts = []
    doc_ids = []
    undecodable = 0
    with open(corpus_path, "rb") as f:
        for doc_id, offset, length in tqdm(doc_metadata, desc="Reading text"):
            f.seek(offset)
            line_bytes = f.read(length)
            # Only decoding is expected to fail here; catching just that
            # keeps genuine bugs from being swallowed silently.
            try:
                line = line_bytes.decode("utf-8")
            except UnicodeDecodeError:
                undecodable += 1
                continue
            body = _extract_body(line)
            if body:
                texts.append(body)
                doc_ids.append(doc_id)
    return texts, doc_ids, undecodable


def create_embeddings():
    """Embed every indexed document and save the doc IDs and embeddings.

    Steps:
      1. Parse ``DOCS_INDEX_PATH`` for each document's id, offset and length.
      2. Seek into ``MSMARCO_PATH`` to read each record and extract its text.
      3. Save the IDs of the documents that yielded text to
         ``OUTPUT_DOC_IDS`` as an int32 tensor.
      4. Encode the texts with ``MODEL_NAME`` on ``NUM_GPUS`` CUDA devices
         and save the result to ``OUTPUT_EMBEDDINGS``.

    The i-th saved doc ID corresponds to the i-th text passed to the encoder.
    All texts are held in memory at once before encoding.
    """
    print(f"Loading {DOCS_INDEX_PATH} to get document offsets...")
    doc_metadata = _load_doc_metadata(DOCS_INDEX_PATH)
    print(f"Found {len(doc_metadata)} documents in index.")

    # Sort by offset to optimize disk reads (though they should be sorted already)
    doc_metadata.sort(key=lambda x: x[1])

    print(f"Reading text from {MSMARCO_PATH}...")
    texts, valid_doc_ids, undecodable = _read_texts(MSMARCO_PATH, doc_metadata)
    print(f"Extracted {len(texts)} valid texts.")
    if undecodable:
        print(f"Skipped {undecodable} documents that were not valid UTF-8.")

    # Save DocIDs immediately, before the long-running encoding step.
    doc_ids_tensor = torch.tensor(valid_doc_ids, dtype=torch.int32)
    torch.save(doc_ids_tensor, OUTPUT_DOC_IDS)
    print(f"Saved {len(doc_ids_tensor)} DocIDs to {OUTPUT_DOC_IDS}")

    # Multi-GPU Encoding
    print(f"Encoding with {NUM_GPUS} GPUs...")
    model = SentenceTransformer(MODEL_NAME)
    model.max_seq_length = MAX_SEQ_LENGTH

    devices = [f"cuda:{i}" for i in range(NUM_GPUS)]

    # Passing a list of devices asks encode() to spread the work over
    # multiple processes.
    # NOTE(review): this relies on the installed sentence-transformers
    # release accepting a list for `device` -- confirm the pinned version.
    corpus_embeddings = model.encode(
        texts,
        device=devices,
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
        convert_to_tensor=True,
    )

    print(f"Saving embeddings to {OUTPUT_EMBEDDINGS}...")
    torch.save(corpus_embeddings, OUTPUT_EMBEDDINGS)
    print("Done!")
# Run the pipeline only when this file is executed as a script, not on import.
# NOTE(review): the guard presumably also matters for the multi-process
# encoding step, where worker processes may re-import this module -- confirm.
if __name__ == "__main__":
    create_embeddings()