Skip to content

Commit caeb0e3

Browse files
committed
overloade session.run(), make syntax simpler.
1 parent d515c81 commit caeb0e3

17 files changed

Lines changed: 87 additions & 87 deletions

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,34 @@ public BaseSession(string target = "", Graph g = null, SessionOptions opts = nul
5454
status.Check(true);
5555
}
5656

57+
public virtual void run(Operation op, params FeedItem[] feed_dict)
58+
{
59+
_run(op, feed_dict);
60+
}
61+
62+
public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict)
63+
{
64+
return _run(fetche, feed_dict)[0];
65+
}
66+
67+
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
68+
{
69+
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
70+
return (results[0], results[1], results[2], results[3]);
71+
}
72+
73+
public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
74+
{
75+
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict);
76+
return (results[0], results[1], results[2]);
77+
}
78+
79+
public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
80+
{
81+
var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict);
82+
return (results[0], results[1]);
83+
}
84+
5785
public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict)
5886
{
5987
return _run(fetches, feed_dict);

src/TensorFlowNET.Core/Sessions/FeedItem.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,8 @@ public FeedItem(object key, object val)
1313
Key = key;
1414
Value = val;
1515
}
16+
17+
public static implicit operator FeedItem((object, object) feed)
18+
=> new FeedItem(feed.Item1, feed.Item2);
1619
}
1720
}

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ public static NDArray _eval_using_default_session(Tensor tensor, FeedItem[] feed
377377
"`eval(session=sess)`.");
378378
}
379379

380-
return session.run(tensor, feed_dict)[0];
380+
return session.run(tensor, feed_dict);
381381
}
382382

383383
/// <summary>

test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,24 +91,24 @@ public bool Run()
9191
{
9292
var c = sess.run(cost,
9393
new FeedItem(X, train_X),
94-
new FeedItem(Y, train_Y))[0];
95-
Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)[0]} b={sess.run(b)[0]}");
94+
new FeedItem(Y, train_Y));
95+
Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}");
9696
}
9797
}
9898

9999
Console.WriteLine("Optimization Finished!");
100100
var training_cost = sess.run(cost,
101101
new FeedItem(X, train_X),
102-
new FeedItem(Y, train_Y))[0];
103-
Console.WriteLine($"Training cost={training_cost} W={sess.run(W)[0]} b={sess.run(b)[0]}");
102+
new FeedItem(Y, train_Y));
103+
Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}");
104104

105105
// Testing example
106106
var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f);
107107
var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f);
108108
Console.WriteLine("Testing... (Mean square loss Comparison)");
109109
var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]),
110110
new FeedItem(X, test_X),
111-
new FeedItem(Y, test_Y))[0];
111+
new FeedItem(Y, test_Y));
112112
Console.WriteLine($"Testing cost={testing_cost}");
113113
var diff = Math.Abs((float)training_cost - (float)testing_cost);
114114
Console.WriteLine($"Absolute mean square loss difference: {diff}");

test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,10 @@ public bool Run()
9090
{
9191
var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(batch_size);
9292
// Run optimization op (backprop) and cost op (to get loss value)
93-
var result = sess.run(new object[] { optimizer, cost },
94-
new FeedItem(x, batch_xs),
95-
new FeedItem(y, batch_ys));
93+
(_, float c) = sess.run((optimizer, cost),
94+
(x, batch_xs),
95+
(y, batch_ys));
9696

97-
float c = result[1];
9897
// Compute average loss
9998
avg_cost += c / total_batch;
10099
}
@@ -115,7 +114,7 @@ public bool Run()
115114
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
116115
// Calculate accuracy
117116
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
118-
float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels));
117+
float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels));
119118
print($"Accuracy: {acc.ToString("F4")}");
120119

121120
return acc > 0.9;

test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,15 @@ public bool Run()
6464
foreach(int i in range(Xte.shape[0]))
6565
{
6666
// Get nearest neighbor
67-
long nn_index = sess.run(pred, new FeedItem(xtr, Xtr), new FeedItem(xte, Xte[i]))[0];
67+
long nn_index = sess.run(pred, (xtr, Xtr), (xte, Xte[i]));
6868
// Get nearest neighbor class label and compare it to its true label
6969
int index = (int)nn_index;
7070

7171
if (i % 10 == 0 || i == 0)
7272
print($"Test {i} Prediction: {np.argmax(Ytr[index])} True Class: {np.argmax(Yte[i])}");
7373

7474
// Calculate accuracy
75-
if ((int)np.argmax(Ytr[index]) == (int)np.argmax(Yte[i]))
75+
if (np.argmax(Ytr[index]) == np.argmax(Yte[i]))
7676
accuracy += 1f/ Xte.shape[0];
7777
}
7878

test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,8 @@ private float RunWithImportedGraph()
103103
// [train_op, gs, loss],
104104
// feed_dict={features: xy, labels: y_}
105105
// )
106-
var result = sess.run(new ITensorOrOperation[] { train_op, global_step, loss }, new FeedItem(features, data), new FeedItem(labels, y_));
107-
loss_value = result[2];
108-
step = result[1];
109-
if (step % 1000 == 0)
106+
(_, step, loss_value) = sess.run((train_op, global_step, loss), (features, data), (labels, y_));
107+
if (step == 1 || step % 1000 == 0)
110108
Console.WriteLine($"Step {step} loss: {loss_value}");
111109
}
112110
Console.WriteLine($"Final loss: {loss_value}");
@@ -136,10 +134,8 @@ private float RunWithBuiltGraph()
136134
var y_ = np.array(new int[] { 1, 0, 0, 1 }, dtype: np.int32);
137135
while (step < num_steps)
138136
{
139-
var result = sess.run(new ITensorOrOperation[] { train_op, gs, loss }, new FeedItem(features, data), new FeedItem(labels, y_));
140-
loss_value = result[2];
141-
step = result[1];
142-
if (step % 1000 == 0)
137+
(_, step, loss_value) = sess.run((train_op, gs, loss), (features, data), (labels, y_));
138+
if (step == 1 || step % 1000 == 0)
143139
Console.WriteLine($"Step {step} loss: {loss_value}");
144140
}
145141
Console.WriteLine($"Final loss: {loss_value}");

test/TensorFlowNET.Examples/BasicOperations.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ public bool Run()
5353
new FeedItem(b, (short)3)
5454
};
5555
// Run every operation with variable input
56-
Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)[0]}");
57-
Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)[0]}");
56+
Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}");
57+
Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}");
5858
}
5959

6060
// ----------------
@@ -91,7 +91,7 @@ public bool Run()
9191
// The output of the op is returned in 'result' as a numpy `ndarray` object.
9292
using (sess = tf.Session())
9393
{
94-
var result = sess.run(product)[0];
94+
var result = sess.run(product);
9595
Console.WriteLine(result.ToString()); // ==> [[ 12.]]
9696
};
9797

@@ -136,7 +136,7 @@ public bool Run()
136136
var checkTensor = np.array<float>(0, 6, 0, 15, 0, 24, 3, 1, 6, 4, 9, 7, 6, 0, 15, 0, 24, 0);
137137
using (var sess = tf.Session())
138138
{
139-
var result = sess.run(batchMul)[0];
139+
var result = sess.run(batchMul);
140140
Console.WriteLine(result.ToString());
141141
//
142142
// ==> array([[[0, 6],

test/TensorFlowNET.Examples/HelloWorld.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ of the Constant op. */
2828
using (var sess = tf.Session())
2929
{
3030
// Run the op
31-
var result = sess.run(hello)[0];
31+
var result = sess.run(hello);
3232
Console.WriteLine(result.ToString());
3333
return result.ToString().Equals(str);
3434
}

test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ public void Train(Session sess)
160160
var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);
161161

162162
// Run optimization op (backprop)
163-
sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
163+
sess.run(optimizer, (x, x_batch), (y, y_batch));
164164

165165
if (iteration % display_freq == 0)
166166
{
@@ -174,9 +174,7 @@ public void Train(Session sess)
174174
}
175175

176176
// Run validation after every epoch
177-
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_valid), new FeedItem(y, y_valid));
178-
loss_val = results1[0];
179-
accuracy_val = results1[1];
177+
(loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_valid), (y, y_valid));
180178
print("---------------------------------------------------------");
181179
print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
182180
print("---------------------------------------------------------");
@@ -185,9 +183,7 @@ public void Train(Session sess)
185183

186184
public void Test(Session sess)
187185
{
188-
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_test), new FeedItem(y, y_test));
189-
loss_test = result[0];
190-
accuracy_test = result[1];
186+
(loss_test, accuracy_test) = sess.run((loss, accuracy), (x, x_test), (y, y_test));
191187
print("---------------------------------------------------------");
192188
print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
193189
print("---------------------------------------------------------");

0 commit comments

Comments
 (0)