Skip to content

Commit 5a377da

Browse files
author
李闯
committed
chatbotv4
1 parent 3f74881 commit 5a377da

File tree

9 files changed

+2212
-0
lines changed

9 files changed

+2212
-0
lines changed

chatbotv4/data/test.input

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
你 好
2+
对 不 起
3+
谢 谢 你
4+
再 见

chatbotv4/data/test.output

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
你 也 好
2+
没 关 系
3+
不 客 气
4+
再 见

chatbotv4/data/train.input

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
你 好
2+
对 不 起
3+
谢 谢 你
4+
再 见
5+
明 天 见
6+
我 爱 你
7+
你 好
8+
对 不 起
9+
谢 谢 你
10+
再 见
11+
明 天 见
12+
我 爱 你
13+
你 好
14+
对 不 起
15+
谢 谢 你
16+
再 见
17+
明 天 见
18+
我 爱 你
19+
你 好
20+
对 不 起
21+
谢 谢 你
22+
再 见
23+
明 天 见
24+
我 爱 你

chatbotv4/data/train.output

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
你 也 好
2+
没 关 系
3+
不 客 气
4+
再 见
5+
明 天 见
6+
我 也 爱 你
7+
你 也 好
8+
没 关 系
9+
不 客 气
10+
再 见
11+
明 天 见
12+
我 也 爱 你
13+
你 也 好
14+
没 关 系
15+
不 客 气
16+
再 见
17+
明 天 见
18+
我 也 爱 你
19+
你 也 好
20+
没 关 系
21+
不 客 气
22+
再 见
23+
明 天 见
24+
我 也 爱 你

chatbotv4/data_utils.py

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""Utilities for downloading data from WMT, tokenizing, vocabularies."""
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import gzip
22+
import os
23+
import re
24+
import tarfile
25+
26+
from six.moves import urllib
27+
28+
from tensorflow.python.platform import gfile
29+
import tensorflow as tf
30+
31+
# Special vocabulary symbols - we always put them at the start.
_PAD = b"_PAD"  # padding symbol used to fill sequences to a fixed length
_GO = b"_GO"    # decoder start-of-sequence symbol
_EOS = b"_EOS"  # end-of-sequence symbol
_UNK = b"_UNK"  # stand-in for out-of-vocabulary tokens
_START_VOCAB = [_PAD, _GO, _EOS, _UNK]

# Fixed ids for the special symbols; they match the positions of the
# _START_VOCAB entries, which are always written first to the vocab file.
PAD_ID = 0
GO_ID = 1
EOS_ID = 2
UNK_ID = 3

# Regular expressions used to tokenize.
_WORD_SPLIT = re.compile(b"([.,!?\"':;)(])")  # splits punctuation off words
_DIGIT_RE = re.compile(br"\d")  # single digit, for 0-normalization
46+
47+
48+
def gunzip_file(gz_path, new_path):
    """Decompress the gzip archive at gz_path into the file new_path."""
    print("Unpacking %s to %s" % (gz_path, new_path))
    with gzip.open(gz_path, "rb") as gz_file:
        with open(new_path, "wb") as new_file:
            # Iterating the gzip handle yields decompressed lines; write
            # them all through in one call.
            new_file.writelines(gz_file)
55+
56+
57+
def get_wmt_enfr_train_set(directory):
    """Return the path prefix of the training set inside `directory`.

    NOTE(review): despite the WMT-derived name, nothing is downloaded here;
    the train.input/train.output files are expected to already exist.
    """
    return os.path.join(directory, "train")
61+
62+
63+
def get_wmt_enfr_dev_set(directory):
    """Return the path prefix of the dev (test) set inside `directory`.

    NOTE(review): despite the WMT-derived name, nothing is downloaded here;
    the test.input/test.output files are expected to already exist.
    """
    return os.path.join(directory, "test")
68+
69+
70+
def basic_tokenizer(sentence):
    """Very basic tokenizer: split `sentence` (bytes) into a list of tokens.

    Whitespace-separated fragments are further split on punctuation via
    _WORD_SPLIT; empty strings produced by the regex split are dropped.
    """
    return [token
            for fragment in sentence.strip().split()
            for token in _WORD_SPLIT.split(fragment)
            if token]
76+
77+
78+
def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size,
                      tokenizer=None, normalize_digits=True):
    """Create vocabulary file (if it does not exist yet) from data file.

    Data file is assumed to contain one sentence per line. Each sentence is
    tokenized and digits are normalized (if normalize_digits is set).
    The vocabulary keeps the most-frequent tokens up to max_vocabulary_size
    and is written to vocabulary_path one token per line, so the token on
    the first line gets id=0, the second line id=1, and so on.

    Args:
      vocabulary_path: path where the vocabulary will be created.
      data_path: data file that will be used to create vocabulary.
      max_vocabulary_size: limit on the size of the created vocabulary.
      tokenizer: a function to use to tokenize each data sentence;
        if None, basic_tokenizer will be used.
      normalize_digits: Boolean; if true, all digits are replaced by 0s.
    """
    if gfile.Exists(vocabulary_path):
        # Vocabulary already built on a previous run; nothing to do.
        return
    print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
    vocab = {}
    with gfile.GFile(data_path, mode="rb") as f:
        for counter, line in enumerate(f, start=1):
            if counter % 100000 == 0:
                print(" processing line %d" % counter)
            line = tf.compat.as_bytes(line)
            tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
            for token in tokens:
                word = _DIGIT_RE.sub(b"0", token) if normalize_digits else token
                vocab[word] = vocab.get(word, 0) + 1
    # Special symbols come first so their ids match PAD_ID/GO_ID/EOS_ID/UNK_ID;
    # the rest are ordered by descending frequency (stable for ties).
    vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
    del vocab_list[max_vocabulary_size:]
    with gfile.GFile(vocabulary_path, mode="wb") as vocab_file:
        for w in vocab_list:
            vocab_file.write(w + b"\n")
119+
120+
121+
def initialize_vocabulary(vocabulary_path):
    """Initialize vocabulary from file.

    We assume the vocabulary is stored one-item-per-line, so a file:
      dog
      cat
    will result in a vocabulary {"dog": 0, "cat": 1}, and this function will
    also return the reversed-vocabulary ["dog", "cat"].

    Args:
      vocabulary_path: path to the file containing the vocabulary.

    Returns:
      a pair: the vocabulary (a dictionary mapping string to integers), and
      the reversed vocabulary (a list, which reverses the vocabulary mapping).

    Raises:
      ValueError: if the provided vocabulary_path does not exist.
    """
    if not gfile.Exists(vocabulary_path):
        # Bug fix: the original passed vocabulary_path as a second positional
        # argument to ValueError instead of %-formatting it into the message,
        # producing an unformatted error string.
        raise ValueError("Vocabulary file %s not found." % vocabulary_path)
    rev_vocab = []
    with gfile.GFile(vocabulary_path, mode="rb") as f:
        rev_vocab.extend(f.readlines())
    rev_vocab = [tf.compat.as_bytes(line.strip()) for line in rev_vocab]
    # Token -> id, where id is the token's line number (0-based).
    vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
    return vocab, rev_vocab
149+
150+
151+
def sentence_to_token_ids(sentence, vocabulary,
                          tokenizer=None, normalize_digits=True):
    """Convert a string to a list of integers representing token-ids.

    For example, a sentence "I have a dog" may become tokenized into
    ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2,
    "a": 4, "dog": 7"} this function will return [1, 2, 4, 7].

    Args:
      sentence: the sentence in bytes format to convert to token-ids.
      vocabulary: a dictionary mapping tokens to integers.
      tokenizer: a function to use to tokenize each sentence;
        if None, basic_tokenizer will be used.
      normalize_digits: Boolean; if true, all digits are replaced by 0s.

    Returns:
      a list of integers, the token-ids for the sentence.
    """
    words = (tokenizer or basic_tokenizer)(sentence)
    if normalize_digits:
        # Normalize digits to 0 before looking words up in the vocabulary.
        words = [_DIGIT_RE.sub(b"0", w) for w in words]
    # Unknown tokens map to UNK_ID.
    return [vocabulary.get(w, UNK_ID) for w in words]
178+
179+
180+
def data_to_token_ids(data_path, target_path, vocabulary_path,
                      tokenizer=None, normalize_digits=True):
    """Tokenize data file and turn into token-ids using given vocabulary file.

    This function loads data line-by-line from data_path, calls
    sentence_to_token_ids on each line, and saves the result to target_path.
    See the comment for sentence_to_token_ids on the token-ids format.

    Args:
      data_path: path to the data file in one-sentence-per-line format.
      target_path: path where the file with token-ids will be created.
      vocabulary_path: path to the vocabulary file.
      tokenizer: a function to use to tokenize each sentence;
        if None, basic_tokenizer will be used.
      normalize_digits: Boolean; if true, all digits are replaced by 0s.
    """
    if gfile.Exists(target_path):
        # Already tokenized on a previous run; nothing to do.
        return
    print("Tokenizing data in %s" % data_path)
    vocab, _ = initialize_vocabulary(vocabulary_path)
    with gfile.GFile(data_path, mode="rb") as data_file, \
            gfile.GFile(target_path, mode="w") as tokens_file:
        counter = 0
        for line in data_file:
            counter += 1
            if counter % 100000 == 0:
                print(" tokenizing line %d" % counter)
            token_ids = sentence_to_token_ids(tf.compat.as_bytes(line), vocab,
                                              tokenizer, normalize_digits)
            tokens_file.write(" ".join(str(tok) for tok in token_ids) + "\n")
209+
210+
211+
def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size, tokenizer=None):
    """Get WMT data into data_dir, create vocabularies and tokenize data.

    Args:
      data_dir: directory in which the data sets will be stored.
      en_vocabulary_size: size of the English vocabulary to create and use.
      fr_vocabulary_size: size of the French vocabulary to create and use.
      tokenizer: a function to use to tokenize each data sentence;
        if None, basic_tokenizer will be used.

    Returns:
      A tuple of 6 elements:
        (1) path to the token-ids for English training data-set,
        (2) path to the token-ids for French training data-set,
        (3) path to the token-ids for English development data-set,
        (4) path to the token-ids for French development data-set,
        (5) path to the English vocabulary file,
        (6) path to the French vocabulary file.
    """
    # Resolve the train/dev path prefixes inside data_dir; the .input/.output
    # suffixes select the "from" and "to" sides of each pair file.
    train_path = get_wmt_enfr_train_set(data_dir)
    dev_path = get_wmt_enfr_dev_set(data_dir)
    return prepare_data(data_dir,
                        train_path + ".input", train_path + ".output",
                        dev_path + ".input", dev_path + ".output",
                        en_vocabulary_size, fr_vocabulary_size, tokenizer)
240+
241+
242+
def prepare_data(data_dir, from_train_path, to_train_path, from_dev_path, to_dev_path, from_vocabulary_size,
                 to_vocabulary_size, tokenizer=None):
    """Prepare all necessary files that are required for the training.

    Args:
      data_dir: directory in which the data sets will be stored.
      from_train_path: path to the file that includes "from" training samples.
      to_train_path: path to the file that includes "to" training samples.
      from_dev_path: path to the file that includes "from" dev samples.
      to_dev_path: path to the file that includes "to" dev samples.
      from_vocabulary_size: size of the "from language" vocabulary to create and use.
      to_vocabulary_size: size of the "to language" vocabulary to create and use.
      tokenizer: a function to use to tokenize each data sentence;
        if None, basic_tokenizer will be used.

    Returns:
      A tuple of 6 elements:
        (1) path to the token-ids for "from language" training data-set,
        (2) path to the token-ids for "to language" training data-set,
        (3) path to the token-ids for "from language" development data-set,
        (4) path to the token-ids for "to language" development data-set,
        (5) path to the "from language" vocabulary file,
        (6) path to the "to language" vocabulary file.
    """
    # Create vocabularies of the appropriate sizes.
    to_vocab_path = os.path.join(data_dir, "vocab%d.output" % to_vocabulary_size)
    from_vocab_path = os.path.join(data_dir, "vocab%d.input" % from_vocabulary_size)
    create_vocabulary(to_vocab_path, to_train_path, to_vocabulary_size, tokenizer)
    create_vocabulary(from_vocab_path, from_train_path, from_vocabulary_size, tokenizer)

    # Create token ids for the training data.
    to_train_ids_path = "%s.ids%d" % (to_train_path, to_vocabulary_size)
    from_train_ids_path = "%s.ids%d" % (from_train_path, from_vocabulary_size)
    data_to_token_ids(to_train_path, to_train_ids_path, to_vocab_path, tokenizer)
    data_to_token_ids(from_train_path, from_train_ids_path, from_vocab_path, tokenizer)

    # Create token ids for the development data.
    to_dev_ids_path = "%s.ids%d" % (to_dev_path, to_vocabulary_size)
    from_dev_ids_path = "%s.ids%d" % (from_dev_path, from_vocabulary_size)
    data_to_token_ids(to_dev_path, to_dev_ids_path, to_vocab_path, tokenizer)
    data_to_token_ids(from_dev_path, from_dev_ids_path, from_vocab_path, tokenizer)

    return (from_train_ids_path, to_train_ids_path,
            from_dev_ids_path, to_dev_ids_path,
            from_vocab_path, to_vocab_path)

chatbotv4/readme.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
训练:
2+
python ./translate.py
3+
4+
预测:
5+
python translate.py --decode True

0 commit comments

Comments
 (0)