Skip to content

Commit 0132f30

Browse files
committed
CNN Text classification.
1 parent c976b81 commit 0132f30

4 files changed

Lines changed: 124 additions & 0 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflow
7373
* [Image Recognition](test/TensorFlowNET.Examples/ImageRecognition.cs)
7474
* [Linear Regression](test/TensorFlowNET.Examples/LinearRegression.cs)
7575
* [Text Classification](test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs)
76+
* [CNN Text Classification](test/TensorFlowNET.Examples/CnnTextClassification.cs)
7677
* [Naive Bayes Classification](test/TensorFlowNET.Examples/NaiveBayesClassifier.cs)
7778
* [Named Entity Recognition](test/TensorFlowNET.Examples/NamedEntityRecognition.cs)
7879

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using NumSharp.Core;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using Tensorflow;
6+
7+
namespace TensorFlowNET.Examples.CnnTextClassification
8+
{
9+
/// <summary>
/// Training entry point of the CNN text classification example.
/// Holds the hyper-parameters of the run and starts data preprocessing.
/// </summary>
public class CnnTextTrain : Python, IExample
{
    // ---------------------------------------------------------------
    // Data loading parameters
    // ---------------------------------------------------------------

    /// <summary>Share of the training data set aside for validation.</summary>
    private float dev_sample_percentage = 0.1f;

    /// <summary>Location the positive examples are fetched from.</summary>
    private string positive_data_file = "https://raw.githubusercontent.com/dennybritz/cnn-text-classification-tf/master/data/rt-polaritydata/rt-polarity.pos";

    /// <summary>Location the negative examples are fetched from.</summary>
    private string negative_data_file = "https://raw.githubusercontent.com/dennybritz/cnn-text-classification-tf/master/data/rt-polaritydata/rt-polarity.neg";

    // ---------------------------------------------------------------
    // Model hyper-parameters
    // ---------------------------------------------------------------

    /// <summary>Dimensionality of the embedding (default: 128).</summary>
    private int embedding_dim = 128;

    /// <summary>Filter sizes as a comma-separated list (default: "3,4,5").</summary>
    private string filter_sizes = "3,4,5";

    /// <summary>How many filters are used for each filter size (default: 128).</summary>
    private int num_filters = 128;

    /// <summary>Keep probability applied by dropout (default: 0.5).</summary>
    private float dropout_keep_prob = 0.5f;

    /// <summary>Lambda of the L2 regularization term (default: 0.0).</summary>
    private float l2_reg_lambda = 0.0f;

    // ---------------------------------------------------------------
    // Training parameters
    // ---------------------------------------------------------------

    /// <summary>Size of one training batch (default: 64).</summary>
    private int batch_size = 64;

    /// <summary>Number of epochs to train for (default: 200).</summary>
    private int num_epochs = 200;

    /// <summary>The dev set is evaluated every this many steps (default: 100).</summary>
    private int evaluate_every = 100;

    /// <summary>The model is saved every this many steps (default: 100).</summary>
    private int checkpoint_every = 100;

    /// <summary>How many checkpoints are kept (default: 5).</summary>
    private int num_checkpoints = 5;

    // ---------------------------------------------------------------
    // Device parameters
    // ---------------------------------------------------------------

    /// <summary>Whether soft device placement is allowed.</summary>
    private bool allow_soft_placement = true;

    /// <summary>Whether the placement of ops on devices is logged.</summary>
    private bool log_device_placement = false;

    /// <summary>
    /// Runs the example. At the moment this only invokes <see cref="preprocess"/>
    /// and unpacks its five results; nothing is done with them yet.
    /// </summary>
    public void Run()
    {
        (NDArray x_train, NDArray y_train, NDArray vocab_processor, NDArray x_dev, NDArray y_dev) = preprocess();
    }

    /// <summary>
    /// Prepares the data for training. The method calls
    /// <c>DataHelpers.load_data_and_labels</c> with the configured data
    /// locations and then throws, because the remaining steps are not written yet.
    /// </summary>
    /// <returns>
    /// Five arrays which <see cref="Run"/> unpacks as
    /// (x_train, y_train, vocab_processor, x_dev, y_dev).
    /// </returns>
    /// <exception cref="NotImplementedException">Always thrown in the current state.</exception>
    public (NDArray, NDArray, NDArray, NDArray, NDArray) preprocess()
    {
        // The loaded data is not consumed yet, so the result is discarded.
        _ = DataHelpers.load_data_and_labels(positive_data_file, negative_data_file);

        throw new NotImplementedException("");
    }
}
53+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using NumSharp.Core;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.IO;
5+
using System.Linq;
6+
using System.Text;
7+
using System.Text.RegularExpressions;
8+
9+
namespace TensorFlowNET.Examples.CnnTextClassification
10+
{
11+
/// <summary>
/// Data loading and text normalization helpers for the CNN text
/// classification example.
/// </summary>
public class DataHelpers
{
    /// <summary>
    /// Ordered substitution rules applied by <see cref="clean_str"/>.
    /// The order matters: unwanted characters are blanked out first, then
    /// contractions and punctuation are split off, and finally runs of
    /// whitespace are collapsed.
    /// </summary>
    private static readonly (string Pattern, string Replacement)[] clean_rules =
    {
        (@"[^A-Za-z0-9(),!?\'\`]", " "),
        (@"\'s", " 's"),
        (@"\'ve", " 've"),
        (@"n\'t", " n't"),
        (@"\'re", " 're"),
        (@"\'d", " 'd"),
        (@"\'ll", " 'll"),
        (@",", " , "),
        (@"!", " ! "),
        (@"\(", " ( "),
        (@"\)", " ) "),
        (@"\?", " ? "),
        (@"\s{2,}", " "),
    };

    /// <summary>
    /// Loads MR polarity data, cleans every sentence and generates one-hot labels.
    /// Both sources are downloaded into the local "CnnTextClassification"
    /// directory and then read line by line.
    /// </summary>
    /// <param name="positive_data_file">Source passed to <c>Utility.Web.Download</c> for the positive examples.</param>
    /// <param name="negative_data_file">Source passed to <c>Utility.Web.Download</c> for the negative examples.</param>
    /// <returns>
    /// Intended to return the cleaned sentences and their labels; the
    /// conversion to <c>NDArray</c> is not written yet.
    /// </returns>
    /// <exception cref="NotImplementedException">
    /// Always thrown after the data has been downloaded and read.
    /// </exception>
    public static (NDArray, NDArray) load_data_and_labels(string positive_data_file, string negative_data_file)
    {
        // Single definition of the cache locations instead of repeated literals.
        const string data_dir = "CnnTextClassification";
        const string positive_path = data_dir + "/rt-polarity.pos";
        const string negative_path = data_dir + "/rt-polarity.neg";

        Directory.CreateDirectory(data_dir);
        Utility.Web.Download(positive_data_file, positive_path);
        Utility.Web.Download(negative_data_file, negative_path);

        // Load data from files
        var positive_examples = File.ReadAllLines(positive_path)
            .Select(x => x.Trim())
            .ToArray();

        var negative_examples = File.ReadAllLines(negative_path)
            .Select(x => x.Trim())
            .ToArray();

        // Positive sentences first, then negative ones, each normalized.
        var x_text = positive_examples
            .Concat(negative_examples)
            .Select(clean_str)
            .ToList();

        // One-hot labels: [0, 1] marks a positive example, [1, 0] a negative one.
        var positive_labels = positive_examples.Select(x => new int[2] { 0, 1 }).ToArray();
        var negative_labels = negative_examples.Select(x => new int[2] { 1, 0 }).ToArray();

        // TODO: concatenate the labels and return (x_text, y) as NDArrays.
        throw new NotImplementedException("load_data_and_labels");
    }

    /// <summary>
    /// Normalizes a sentence for tokenization: characters outside
    /// <c>A-Za-z0-9(),!?'`</c> become spaces, common English contractions
    /// and the punctuation marks <c>, ! ( ) ?</c> are separated by spaces,
    /// repeated whitespace is collapsed, and the result is trimmed and
    /// lower-cased (culture-invariant).
    /// </summary>
    /// <param name="str">Raw sentence; must not be null.</param>
    /// <returns>The cleaned, lower-cased sentence.</returns>
    private static string clean_str(string str)
    {
        foreach (var (pattern, replacement) in clean_rules)
        {
            str = Regex.Replace(str, pattern, replacement);
        }

        return str.Trim().ToLowerInvariant();
    }
}
54+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow;
5+
6+
namespace TensorFlowNET.Examples.CnnTextClassification
7+
{
8+
/// <summary>
/// A convolutional neural network (CNN) used for classifying text.
/// Reference implementation: https://github.com/dennybritz/cnn-text-classification-tf
/// </summary>
/// <remarks>
/// The class is currently an empty placeholder; it declares no members yet.
/// </remarks>
public class TextCNN : Python
{
}
16+
}

0 commit comments

Comments
 (0)