forked from lcdevelop/ChatBotCourse
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathword_vectors_loader.py
More file actions
78 lines (63 loc) · 1.99 KB
/
word_vectors_loader.py
File metadata and controls
78 lines (63 loc) · 1.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
Word vector loader.
"""
# coding:utf-8
import sys
import struct
import numpy as np
# Not referenced in the visible code; presumably a maximum word length in
# bytes (word2vec's tools use a similarly named limit) -- TODO confirm or remove.
MAX_W = 50
# Size in bytes of one serialized vector weight (a C float, struct format 'f').
FLOAT_SIZE = 4
def get_words_sizes(file_name):
    """Return the vocabulary size and vector dimension of a vectors file.

    Only the first line of the file is read. It must be a UTF-8 text header
    of the form ``"<words> <size>"``; any run of whitespace is accepted as
    the separator.

    Args:
        file_name: Path of the binary vectors file.

    Returns:
        A ``(words, size)`` tuple of ints.

    Raises:
        OSError: If the file cannot be opened.
        IndexError: If the header has fewer than two fields.
        ValueError: If a header field is not an integer.
    """
    # The context manager guarantees the handle is closed even when the
    # header turns out to be malformed and parsing raises.
    with open(file_name, "rb") as input_file:
        header = input_file.readline()

    # Decode and split once; split() without an argument tolerates repeated
    # spaces or tabs between the two numbers.
    fields = header.strip().decode('utf-8').split()
    words = int(fields[0])
    size = int(fields[1])
    return words, size
def load_vectors(file_name):
    """Load every word and its vector from a binary vectors file.

    Layout expected by the parser: a UTF-8 text header line
    ``"<words> <size>"``, then ``words`` records. Each record is the word's
    bytes, one space, and ``size`` native-endian C floats. Whitespace around
    a word (e.g. a newline written after the previous record) is stripped.

    Args:
        file_name: Path of the binary vectors file.

    Returns:
        A ``(word_vector_dict, word_id_dict)`` tuple. Both dicts are keyed by
        the word as ``bytes``. ``word_vector_dict`` maps a word to a float64
        numpy array of length ``size``; ``word_id_dict`` maps it to its
        0-based position in the file.

    Raises:
        OSError: If the file cannot be opened.
        IndexError: If the header has fewer than two fields.
        ValueError: If a header field is not an integer.
        EOFError: If the file ends while a word is being read, i.e. it holds
            fewer records than the header announces.
        struct.error: If the file ends in the middle of a vector.
    """
    print("begin load vectors")
    word_vector_dict = {}
    word_id_dict = {}

    # The context manager closes the file even if parsing fails part-way.
    with open(file_name, "rb") as input_file:
        # Header: vocabulary size and vector dimension.
        header = input_file.readline().strip().decode('utf-8').split()
        words = int(header[0])
        size = int(header[1])
        print("words =", words)
        print("size =", size)

        # One pre-compiled format for a whole vector: its byte length is
        # derived from the format itself and all weights decode in one call.
        vector_struct = struct.Struct('%df' % size)

        for word_id in range(words):
            # Read the word byte by byte up to the separating space.
            pieces = []
            while True:
                char = input_file.read(1)
                if not char:
                    # read() returns b'' at end of file; without this check
                    # a truncated file would make the loop spin forever.
                    raise EOFError(
                        "unexpected end of file while reading word %d of %d"
                        % (word_id + 1, words)
                    )
                if char == b' ':
                    break
                pieces.append(char)
            word = b''.join(pieces).strip()

            # Read the whole vector at once; a short read raises struct.error.
            raw = input_file.read(vector_struct.size)
            vector = np.array(vector_struct.unpack(raw), dtype=np.float64)

            # Store the word with its vector and with its position.
            word_vector_dict[word] = vector
            word_id_dict[word] = word_id

    print("load vectors finish")
    return word_vector_dict, word_id_dict
if __name__ == '__main__':
    # Script entry point: expects exactly one argument, the path of the
    # binary vectors file, then prints the vector stored for a sample word.
    if len(sys.argv) == 2:
        WORD_VECTOR_DICT, WORD_ID_DICT = load_vectors(sys.argv[1])
        # Keys are UTF-8 encoded bytes, so the lookup word is encoded first.
        print(WORD_VECTOR_DICT['数学'.encode('utf-8')])
    else:
        print("Usage: ", sys.argv[0], "vectors.bin")
        sys.exit(-1)