forked from allenai/olmocr
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path infinigram_count.py
More file actions
173 lines (146 loc) · 6.9 KB
/
infinigram_count.py
File metadata and controls
173 lines (146 loc) · 6.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#!/usr/bin/env python3
import argparse
import json
import random
import re
import time
import boto3
import requests
from tqdm import tqdm
from transformers import AutoTokenizer
# Whitelist for sampled n-gram text: ASCII letters and digits, the space
# character (no tabs/newlines), and the punctuation ".,!?()". The `+` requires
# at least one character, so the empty string never matches. The ^/$ anchors
# are redundant when the pattern is used with fullmatch(), but harmless.
ALLOWED_RE = re.compile(r"^[A-Za-z0-9\.,!?() ]+$")
def get_random_line_from_s3(bucket, key, s3_client=None):
    """
    Return one uniformly random non-empty line from an S3 object.

    The object is streamed with ``iter_lines()`` and a single line is kept via
    reservoir sampling (reservoir of size 1), so the object is never held in
    memory as a whole.

    Args:
        bucket: Name of the S3 bucket.
        key: Key of the object inside ``bucket``.
        s3_client: Optional S3 client to reuse. When omitted, a new client is
            created with ``boto3.client("s3")`` (the previous behavior).

    Returns:
        The chosen line decoded as UTF-8, or ``None`` if the object yields no
        non-empty lines.

    Raises:
        Whatever ``get_object`` raises (e.g. for a missing key), and
        ``UnicodeDecodeError`` if the chosen line is not valid UTF-8.
    """
    s3 = s3_client if s3_client is not None else boto3.client("s3")
    body = s3.get_object(Bucket=bucket, Key=key)["Body"]

    chosen = None
    seen = 0
    try:
        for raw in body.iter_lines():
            if not raw:
                continue  # skip blank lines
            seen += 1
            # Replace the current pick with probability 1/seen; after the last
            # line every non-empty line has been chosen with equal probability.
            if random.randint(1, seen) == 1:
                chosen = raw
    finally:
        # Release the underlying HTTP stream even if iteration fails midway.
        body.close()

    # Decode only the winner: a malformed line elsewhere in the object no
    # longer aborts the whole file.
    return chosen.decode("utf-8") if chosen is not None else None
def query_infinigram(ngram, index="v4_rpj_llama_s4", retries=3):
    """
    Return the infini-gram occurrence count of ``ngram`` in ``index``.

    Sends a ``count`` query to the infini-gram API. An attempt fails on a
    network error or timeout, a non-200 status, a body that is not valid JSON,
    or a JSON body without a ``count`` field. Failed attempts are retried,
    with a one-second pause before each retry.

    Args:
        ngram: The query string.
        index: Name of the infini-gram index to query.
        retries: Total number of attempts (not additional retries); ``0`` or
            less sends no request at all.

    Returns:
        The ``count`` value from the first successful response, or ``0`` if no
        attempt succeeded. NOTE: a failed lookup is therefore indistinguishable
        from a genuine zero count.
    """
    url = "https://api.infini-gram.io/"
    payload = {
        "index": index,
        "query_type": "count",
        "query": ngram,
    }
    for attempt in range(retries):
        if attempt:
            # Back off before every retry (including after non-200 replies),
            # but never sleep after the final attempt.
            time.sleep(1)
        try:
            response = requests.post(url, json=payload, timeout=10)
            if response.status_code != 200:
                continue
            result = response.json()
        except (requests.RequestException, ValueError):
            # Connection error / timeout, or a response body that is not JSON.
            continue
        # Guard the shape: `"count" in <str>` would be a substring test.
        if isinstance(result, dict) and "count" in result:
            return result["count"]
    return 0
def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llama_s4"):
    """
    Sample random token n-grams from a document and look each up in infini-gram.

    The document text is tokenized without special tokens. Candidate start
    positions are drawn uniformly at random *without replacement*, so the same
    n-gram position is never queried or counted twice. A candidate is kept only
    if its decoded text:
      1. consists solely of ASCII letters, digits, spaces and ".,!?()"
         (``ALLOWED_RE``), and
      2. is longer than ``3 * ngram_size`` characters after stripping.
    Each kept n-gram is sent to ``query_infinigram``.

    NOTE: no word-boundary restriction is applied; every token position that
    leaves room for a full n-gram is a candidate. At most ``10 * num_samples``
    candidates are inspected, so fewer than ``num_samples`` n-grams may be
    returned (e.g. for short or mostly non-ASCII documents).

    Args:
        doc: Parsed JSON document (dict). ``doc["text"]`` is the text (a
            missing or null value is treated as empty) and ``doc["id"]`` its
            identifier (``"Unknown"`` when absent).
        tokenizer: Hugging Face-style tokenizer: callable returning a mapping
            with ``"input_ids"`` and providing ``decode``.
        ngram_size: Number of tokens per n-gram; values < 1 yield no samples.
        num_samples: Desired number of accepted n-grams; values < 1 yield none.
        index: infini-gram index passed to ``query_infinigram``.

    Returns:
        ``(doc_id, match_count, sample_count, details)`` where ``details`` is a
        list of ``(flag, ngram_string)``; ``flag`` is ``"YES"`` when the API
        count is > 0 and ``"NO"`` otherwise. ``match_count`` is the number of
        ``"YES"`` entries and ``sample_count == len(details)``.
    """
    text = doc.get("text") or ""  # tolerate a missing or null "text" field
    doc_id = doc.get("id", "Unknown")

    if ngram_size < 1 or num_samples < 1:
        return doc_id, 0, 0, []

    # Offsets are not needed (no word-boundary filtering), so don't request
    # them; return_offsets_mapping is only supported by fast tokenizers.
    token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    num_positions = len(token_ids) - ngram_size + 1
    if num_positions < 1:
        return doc_id, 0, 0, []

    max_attempts = num_samples * 10  # bound the work spent on rejected candidates
    candidates = random.sample(range(num_positions), min(num_positions, max_attempts))

    details = []
    for idx in candidates:
        ngram_str = tokenizer.decode(token_ids[idx : idx + ngram_size], clean_up_tokenization_spaces=True)
        if ALLOWED_RE.fullmatch(ngram_str) and len(ngram_str.strip()) > ngram_size * 3:
            count = query_infinigram(ngram_str, index=index)
            details.append(("YES" if count > 0 else "NO", ngram_str))
            if len(details) >= num_samples:
                break

    match_count = sum(1 for flag, _ in details if flag == "YES")
    return doc_id, match_count, len(details), details
def _split_s3_uri(uri):
    """Split ``s3://bucket/prefix`` into ``(bucket, prefix)``; the prefix may be empty."""
    bucket, _, prefix = uri[len("s3://"):].partition("/")
    return bucket, prefix


def _list_jsonl_keys(s3, bucket, prefix):
    """Return every key under ``prefix`` ending in ``.jsonl``, following S3 pagination."""
    # A single list_objects_v2 call returns only one page of keys; the
    # paginator walks all pages so large prefixes are not silently truncated.
    paginator = s3.get_paginator("list_objects_v2")
    return [
        obj["Key"]
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix)
        for obj in page.get("Contents", [])
        if obj["Key"].endswith(".jsonl")
    ]


def main():
    """
    Command-line entry point.

    Picks ``N`` random ``.jsonl`` objects under ``s3_path``, parses one random
    line of each as a JSON document, samples ``--num_ngrams`` n-grams of
    ``--ngram_size`` Llama-2 tokens from it, looks them up in the infini-gram
    ``--index``, and prints per-document and overall match rates.
    Invalid arguments terminate via ``argparse`` with a usage message.
    """
    parser = argparse.ArgumentParser(description="Infini-gram n-gram matching script with Llama2 tokenization.")
    parser.add_argument("N", type=int, help="Number of random .jsonl files to process")
    parser.add_argument("s3_path", type=str, help="S3 path to a prefix containing .jsonl files (e.g., s3://my-bucket/my-prefix/)")
    # %(default)s keeps the help in sync with the real default (the old text
    # advertised v4_rpj_llama_s4 while the default was v4_dolma-v1_7_llama).
    parser.add_argument("--index", type=str, default="v4_dolma-v1_7_llama", help="Infini-gram index to use (default: %(default)s)")
    parser.add_argument("--ngram_size", type=int, default=10, help="Size of the n-gram to sample (default: %(default)s)")
    parser.add_argument("--num_ngrams", type=int, default=100, help="Number of random n-grams to sample from each document (default: %(default)s)")
    args = parser.parse_args()

    if not args.s3_path.startswith("s3://"):
        parser.error("s3_path must start with 's3://'")
    bucket, prefix = _split_s3_uri(args.s3_path)
    if not bucket:
        parser.error("s3_path must include a bucket name (e.g., s3://my-bucket/)")
    if args.N < 1:
        parser.error("N must be a positive integer")
    if args.ngram_size < 1 or args.num_ngrams < 1:
        parser.error("--ngram_size and --num_ngrams must be positive integers")

    print("Listing .jsonl files from S3...")
    s3 = boto3.client("s3")
    files = _list_jsonl_keys(s3, bucket, prefix)
    if not files:
        print("No .jsonl files found in the given prefix.")
        return

    num_files = args.N
    if num_files > len(files):
        print(f"Requested {args.N} files, but only found {len(files)}. Processing all available files.")
        num_files = len(files)
    random_files = random.sample(files, num_files)

    print("Loading Llama2 tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

    total_matches = 0
    total_ngrams_sampled = 0
    # Inside the progress-bar loop, output goes through tqdm.write so it is
    # printed above the bar instead of corrupting it.
    for key in tqdm(random_files, desc="Processing files"):
        line = get_random_line_from_s3(bucket, key)
        if not line:
            tqdm.write(f"Skipping {key}: No valid lines found.")
            continue
        try:
            doc = json.loads(line)
        except json.JSONDecodeError as e:
            tqdm.write(f"Error parsing JSON in {key}: {e}")
            continue
        if not isinstance(doc, dict):
            # process_document reads fields with .get(), so it needs an object.
            tqdm.write(f"Skipping {key}: line is not a JSON object.")
            continue

        doc_id, match_count, sample_count, details = process_document(doc, tokenizer, args.ngram_size, args.num_ngrams, index=args.index)

        # Per-document summary: flag in a fixed 4-character column, then the
        # repr of the n-gram, then the document's match rate.
        report = [f"\nDocument ID: {doc_id}"]
        report.extend(f"{flag:4} {ngram!r}" for flag, ngram in details)
        percentage = (match_count / sample_count * 100) if sample_count else 0
        report.append(f"Matched n-grams: {match_count}/{sample_count} ({percentage:.2f}%)")
        tqdm.write("\n".join(report))

        total_matches += match_count
        total_ngrams_sampled += sample_count

    overall_percentage = (total_matches / total_ngrams_sampled * 100) if total_ngrams_sampled else 0
    print(f"\nTotal matched n-grams: {total_matches}/{total_ngrams_sampled} ({overall_percentage:.2f}%)")
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()