try:
    import cPickle
except ImportError:  # Python 3: cPickle was folded into the pickle module
    import pickle as cPickle
import gzip
import os
import sys
import time

import numpy
import theano
import theano.tensor as T


def prepare_data(seqs, labels, maxlen=None):
    """Pad a minibatch of sequences into a time-major matrix plus a mask.

    :param seqs: list of sequences; each element must be assignable into an
        int64 array (presumably word indices).
    :param labels: list of labels, parallel to ``seqs``.
    :param maxlen: if not None, only sequences with length strictly less than
        ``maxlen`` are kept, together with their labels.
    :returns: ``(x, x_mask, labels)``. ``x`` is an int64 array of shape
        ``(longest_length, n_samples)`` with sequence ``i`` written into
        column ``i`` and zero padding after it. ``x_mask`` is a float32 array
        of the same shape, 1. where ``x`` holds real data and 0. on padding.
        Returns ``(None, None, None)`` when no sequence is left.
    """
    lengths = [len(s) for s in seqs]

    if maxlen is not None:
        # NOTE(review): the strict "<" also drops sequences whose length is
        # exactly maxlen; kept as-is because callers may depend on it.
        kept = [(length, s, y)
                for length, s, y in zip(lengths, seqs, labels)
                if length < maxlen]
        lengths = [length for length, _, _ in kept]
        seqs = [s for _, s, _ in kept]
        labels = [y for _, _, y in kept]

    # Checked for both branches: numpy.max([]) would raise on an empty batch.
    if len(lengths) < 1:
        return None, None, None

    n_samples = len(seqs)
    longest = numpy.max(lengths)

    x = numpy.zeros((longest, n_samples), dtype='int64')
    x_mask = numpy.zeros((longest, n_samples), dtype='float32')
    for idx, (length, s) in enumerate(zip(lengths, seqs)):
        # Time-major layout: sequence idx fills the first `length` rows of
        # column idx; the remaining rows stay as zero padding.
        x[:length, idx] = s
        x_mask[:length, idx] = 1.

    return x, x_mask, labels


def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1):
    """Load the pickled IMDB dataset and split off a validation set.

    The file at ``path`` must contain two consecutive pickles: the training
    set and then the test set, each a ``(sequences, labels)`` pair.

    :param path: path to the pickle file (default ``"imdb.pkl"``).
    :param n_words: vocabulary size; every word index ``>= n_words`` is
        replaced by 1 (presumably the unknown-word index; confirm against the
        dataset's dictionary).
    :param valid_portion: fraction of the training set that is moved, after a
        random permutation, into the validation set.
    :returns: ``(train, valid, test)``, each a ``(sequences, labels)`` tuple
        of lists.
    """
    #############
    # LOAD DATA #
    #############
    print('... loading data')

    # SECURITY: unpickling can execute arbitrary code; only load trusted files.
    # The `with` block closes the file even if unpickling fails.
    with open(path, 'rb') as f:
        train_set = cPickle.load(f)
        test_set = cPickle.load(f)

    # Split the training set into train / validation using a random
    # permutation drawn from numpy's global RNG.
    train_set_x, train_set_y = train_set
    n_samples = len(train_set_x)
    sidx = numpy.random.permutation(n_samples)
    n_train = int(numpy.round(n_samples * (1. - valid_portion)))
    valid_set_x = [train_set_x[s] for s in sidx[n_train:]]
    valid_set_y = [train_set_y[s] for s in sidx[n_train:]]
    train_set_x = [train_set_x[s] for s in sidx[:n_train]]
    train_set_y = [train_set_y[s] for s in sidx[:n_train]]

    def remove_unk(x):
        # Replace every out-of-vocabulary index (>= n_words) with 1.
        return [[1 if w >= n_words else w for w in sen] for sen in x]

    test_set_x, test_set_y = test_set

    train = (remove_unk(train_set_x), train_set_y)
    valid = (remove_unk(valid_set_x), valid_set_y)
    test = (remove_unk(test_set_x), test_set_y)

    return train, valid, test