-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathTextFeatureExtractor.py
More file actions
44 lines (35 loc) · 1.47 KB
/
TextFeatureExtractor.py
File metadata and controls
44 lines (35 loc) · 1.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
from transformers import BertTokenizer, BertModel
import torch
import re
class TextFeatureExtractor:
    """Sentence-level text feature extractor using mBERT (multilingual BERT).

    The tokenizer and model are loaded once at construction time.
    :meth:`extract_text_features` then maps a text to the last-layer hidden
    state of its ``[CLS]`` token.
    """

    # Pretrained checkpoint used for both the tokenizer and the encoder.
    MODEL_NAME = 'bert-base-multilingual-cased'
    # Length of the feature vector returned by extract_text_features
    # (the mBERT output dimension).
    EMBEDDING_DIM = 768
    # Longer inputs are truncated to this many tokens by the tokenizer.
    MAX_TOKENS = 512

    def __init__(self):
        """Load the pretrained tokenizer and model and switch to eval mode."""
        self.tokenizer = BertTokenizer.from_pretrained(self.MODEL_NAME)
        self.model = BertModel.from_pretrained(self.MODEL_NAME)
        self.model.eval()  # Set to evaluation mode

    def _empty_features(self):
        """Return the all-zero fallback vector of length ``EMBEDDING_DIM``."""
        # float32 so the fallback has the same dtype as the embeddings the
        # model produces by default; mixing float64 zeros with float32
        # embeddings would silently upcast when callers stack the results.
        return np.zeros(self.EMBEDDING_DIM, dtype=np.float32)

    def extract_text_features(self, text_content):
        """Extract an mBERT-based feature vector from text.

        Args:
            text_content: The text to embed. ``None``, non-string values,
                and empty or whitespace-only strings are treated as
                "no text".

        Returns:
            A 1-D ``numpy.ndarray`` of shape ``(768,)``: the ``[CLS]`` token
            embedding, or an all-zero vector when there is no usable text or
            when feature extraction fails.
        """
        # Non-string input (e.g. a NaN float from a missing table cell) has
        # no .strip(); treat it like missing text instead of raising.
        if not isinstance(text_content, str) or not text_content.strip():
            return self._empty_features()

        try:
            # Collapse every run of whitespace into a single space.
            cleaned_text = re.sub(r'\s+', ' ', text_content).strip()

            # Tokenize input text
            inputs = self.tokenizer(
                cleaned_text,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=self.MAX_TOKENS,
                add_special_tokens=True,
            )

            # Inference only: no gradients are needed.
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Use the [CLS] token's embedding as the sentence representation:
            # first (and only) sequence in the batch, first token position.
            cls_embedding = outputs.last_hidden_state[0, 0, :].numpy()
            return cls_embedding  # Shape: (768,)
        except Exception as e:
            # Deliberate best-effort: report the failure and fall back to
            # the zero vector so one bad text does not abort the caller.
            print(f"Error extracting mBERT features: {e}")
            return self._empty_features()