-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathtext.py
More file actions
64 lines (54 loc) · 2.29 KB
/
text.py
File metadata and controls
64 lines (54 loc) · 2.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright 2019 The Vearch Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# ==============================================================================
import os
import subprocess
import numpy as np
from keras_bert import extract_embeddings,get_checkpoint_paths,load_trained_model_from_checkpoint, load_vocabulary
# Pre-trained Chinese BERT checkpoint (chinese_L-12_H-768_A-12) published
# by Google; downloaded on first use when the configured model_path is absent.
MODEL_URL = 'https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip'
class Text(object):
    """Text-to-vector encoder backed by a pre-trained Chinese BERT model.

    On construction, downloads and unpacks the checkpoint into the
    configured directory if it is not already present, then loads the
    Keras model and its vocabulary.
    """

    def __init__(self, config):
        """Load the BERT checkpoint, fetching it first when missing.

        Args:
            config: dict with key "model_path" — directory that holds (or
                will hold) the unzipped chinese_L-12_H-768_A-12 checkpoint.

        Raises:
            subprocess.CalledProcessError: if the download or unzip step
                fails (previously these failures were silently ignored and
                surfaced later as an obscure checkpoint-loading error).
        """
        model_path = config["model_path"]
        if not os.path.exists(model_path):
            model_dir = os.path.dirname(model_path)
            os.makedirs(model_dir, exist_ok=True)
            # Argument lists instead of a shell=True f-string: the paths and
            # URL are never interpreted by a shell (no injection / quoting
            # bugs), and check=True makes failures loud.
            subprocess.run(["wget", "-P", model_dir, MODEL_URL], check=True)
            subprocess.run(["unzip", "chinese_L-12_H-768_A-12.zip"],
                           cwd=model_dir, check=True)
        paths = get_checkpoint_paths(model_path)
        self.model = load_trained_model_from_checkpoint(
            config_file=paths.config,
            checkpoint_file=paths.checkpoint,
            output_layer_num=1)
        self.vocabs = load_vocabulary(paths.vocab)

    def encode(self, texts):
        """Embed each text as one fixed-size vector.

        Args:
            texts: iterable of strings.

        Returns:
            list of plain Python lists (one per input text): the per-token
            BERT embeddings max-pooled over the token axis, converted with
            .tolist() so the result is JSON-serializable.
        """
        embeddings = extract_embeddings(self.model, texts, vocabs=self.vocabs)
        return [np.max(x, axis=0).tolist() for x in embeddings]
def load_model(config):
    """Plugin factory: build and return a Text encoder from *config*.

    Args:
        config: dict passed straight through to Text (must contain
            "model_path").

    Returns:
        A ready-to-use Text instance.
    """
    encoder = Text(config)
    return encoder
def test():
    """Interactive smoke test: read lines, embed them, print a distance.

    Each input line is split on "--" into a list of texts; the texts are
    encoded and the Euclidean distance between the first two embeddings is
    printed. Loops until interrupted.

    Fixes over the previous version: the old body recursively called
    ``test(model, ...)`` although ``test`` takes no arguments (a TypeError
    swallowed by the broad except on every iteration), and referenced an
    undefined ``texts`` variable outside the try block; an unused
    ``import time`` is also dropped.
    """
    model = load_model({"model_path": "../../model/chinese_L-12_H-768_A-12"})
    while True:
        try:
            print("input your question!\n")
            question = input()
            # "a--b" compares two texts; the distance needs at least two.
            texts = question.split("--")
            result = model.encode(texts)
            print(np.linalg.norm(np.array(result[0]) - np.array(result[1])))
        except Exception as err:
            # Best-effort REPL: report the problem and keep prompting.
            print(err)
if __name__ == "__main__":
    # NOTE(review): deliberately a no-op as written — running this file
    # directly does nothing. Presumably this was meant to call test();
    # confirm before changing, since callers may rely on import-only use.
    pass