Skip to content

Commit 2bf72e9

Browse files
committed
nest.map_structure: fixed bug when merging multiple structures
1 parent ade4ef7 commit 2bf72e9

3 files changed

Lines changed: 58 additions & 64 deletions

File tree

src/TensorFlowNET.Core/Python.cs

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -131,44 +131,6 @@ public static float time()
131131
}
132132
}
133133

134-
/// <summary>
135-
/// Untyped implementation of zip for arbitrary data
136-
///
137-
/// Converts an list of lists or arrays [[1,2,3], [4,5,6], [7,8,9]] into a list of arrays
138-
/// representing tuples of the same index of all source arrays [[1,4,7], [2,5,9], [3,6,9]]
139-
/// </summary>
140-
/// <param name="lists">one or multiple sequences to be zipped</param>
141-
/// <returns></returns>
142-
public static IEnumerable<object[]> zip(params object[] lists)
143-
{
144-
if (lists.Length == 0)
145-
yield break;
146-
var first = lists[0];
147-
if (first == null)
148-
yield break;
149-
var arity = (first as IEnumerable).OfType<object>().Count();
150-
for (int i = 0; i < arity; i++)
151-
{
152-
var array= new object[lists.Length];
153-
for (int j = 0; j < lists.Length; j++)
154-
array[j] = GetSequenceElementAt(lists[j], i);
155-
yield return array;
156-
}
157-
}
158-
159-
private static object GetSequenceElementAt(object sequence, int i)
160-
{
161-
switch (sequence)
162-
{
163-
case Array array:
164-
return array.GetValue(i);
165-
case IList list:
166-
return list[i];
167-
default:
168-
return (sequence as IEnumerable).OfType<object>().Skip(Math.Max(0, i)).FirstOrDefault();
169-
}
170-
}
171-
172134
public static IEnumerable<(int, T)> enumerate<T>(IList<T> values)
173135
{
174136
for (int i = 0; i < values.Count; i++)

src/TensorFlowNET.Core/Util/nest.py.cs

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,44 @@ namespace Tensorflow.Util
2323
public static class nest
2424
{
2525

26-
public static IEnumerable<object[]> zip(params object[] structures)
27-
=> Python.zip(structures);
26+
27+
/// <summary>
28+
/// Untyped implementation of zip for arbitrary data
29+
///
30+
/// Converts an list of lists or arrays [[1,2,3], [4,5,6], [7,8,9]] into a list of arrays
31+
/// representing tuples of the same index of all source arrays [[1,4,7], [2,5,9], [3,6,9]]
32+
/// </summary>
33+
/// <param name="lists">one or multiple sequences to be zipped</param>
34+
/// <returns></returns>
35+
public static IEnumerable<object[]> zip_many(params IEnumerable<object>[] lists)
36+
{
37+
if (lists.Length == 0)
38+
yield break;
39+
var first = lists[0];
40+
if (first == null)
41+
yield break;
42+
var arity = first.Count();
43+
for (int i = 0; i < arity; i++)
44+
{
45+
var array = new object[lists.Length];
46+
for (int j = 0; j < lists.Length; j++)
47+
array[j] = GetSequenceElementAt(lists[j], i);
48+
yield return array;
49+
}
50+
}
51+
52+
private static object GetSequenceElementAt(object sequence, int i)
53+
{
54+
switch (sequence)
55+
{
56+
case Array array:
57+
return array.GetValue(i);
58+
case IList list:
59+
return list[i];
60+
default:
61+
return _yield_value(sequence).Skip(Math.Max(0, i)).FirstOrDefault();
62+
}
63+
}
2864

2965
public static IEnumerable<(T1, T2)> zip<T1, T2>(IEnumerable<T1> e1, IEnumerable<T2> e2)
3066
=> Python.zip(e1, e2);
@@ -40,9 +76,9 @@ public static Dictionary<string, object> ConvertToDict(object dyn)
4076
/// <summary>
4177
/// Returns a sorted list of the dict keys, with error if keys not sortable.
4278
/// </summary>
43-
private static IEnumerable<string> _sorted(IDictionary dict_)
79+
private static IEnumerable<object> _sorted(IDictionary dict_)
4480
{
45-
return dict_.Keys.OfType<string>().OrderBy(x => x);
81+
return dict_.Keys.OfType<object>().OrderBy(x => x);
4682
}
4783

4884

@@ -86,7 +122,7 @@ private static object _sequence_like(object instance, IEnumerable<object> args)
86122
{
87123
case Hashtable hash:
88124
var result = new Hashtable();
89-
foreach ((object key, object value) in zip(_sorted(hash).OfType<object>(), args))
125+
foreach ((object key, object value) in zip<object, object>(_sorted(hash), args))
90126
result[key] = value;
91127
return result;
92128
}
@@ -370,13 +406,13 @@ private static (int new_index, List<object> child) _packed_nest_with_indices(obj
370406
/// <returns> `flat_sequence` converted to have the same recursive structure as
371407
/// `structure`.
372408
/// </returns>
373-
public static object pack_sequence_as<T>(object structure, IEnumerable<T> flat_sequence)
409+
public static object pack_sequence_as(object structure, IEnumerable<object> flat_sequence)
374410
{
375411
List<object> flat = null;
376412
if (flat_sequence is List<object>)
377413
flat = flat_sequence as List<object>;
378414
else
379-
flat=new List<object>(flat_sequence.OfType<object>());
415+
flat=new List<object>(flat_sequence);
380416
if (flat_sequence==null)
381417
throw new ArgumentException("flat_sequence must not be null");
382418
// if not is_sequence(flat_sequence):
@@ -403,7 +439,7 @@ public static object pack_sequence_as<T>(object structure, IEnumerable<T> flat_s
403439
var flat_structure = flatten(structure);
404440
if (len(flat_structure) != len(flat))
405441
{
406-
throw new ValueError("Could not pack sequence. Structure had %d elements, but " +
442+
throw new ValueError("Could not pack sequence. Structure had {len(structure)} elements, but " +
407443
$"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat)}");
408444
}
409445
return _sequence_like(structure, packed);
@@ -413,7 +449,7 @@ public static object pack_sequence_as<T>(object structure, IEnumerable<T> flat_s
413449
var flat_structure = flatten(structure);
414450
if (len(flat_structure) != len(flat))
415451
{
416-
throw new ValueError("Could not pack sequence. Structure had %d elements, but " +
452+
throw new ValueError("Could not pack sequence. Structure had {len(structure)} elements, but " +
417453
$"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat)}");
418454
}
419455
return _sequence_like(structure, packed);
@@ -427,10 +463,8 @@ public static object pack_sequence_as<T>(object structure, IEnumerable<T> flat_s
427463
/// `structure[i]`. All structures in `structure` must have the same arity,
428464
/// and the return value will contain the results in the same structure.
429465
/// </summary>
430-
/// <typeparam name="T">the type of the elements of the output structure (object if diverse)</typeparam>
431466
/// <param name="func"> A callable that accepts as many arguments as there are structures.</param>
432-
/// <param name="structures">scalar, or tuple or list of constructed scalars and/or other
433-
/// tuples/lists, or scalars. Note: numpy arrays are considered as scalars.</param>
467+
/// <param name="structures">one or many IEnumerable of object</param>
434468
/// <param name="check_types">If set to
435469
/// `True` (default) the types of iterables within the structures have to be
436470
/// same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError`
@@ -444,23 +478,22 @@ public static object pack_sequence_as<T>(object structure, IEnumerable<T> flat_s
444478
/// `check_types` is `False` the sequence types of the first structure will be
445479
/// used.
446480
/// </returns>
447-
public static IEnumerable<object> map_structure(Func<object[], object> func, object structure, params object[] more_structures)
481+
public static IEnumerable<object> map_structure(Func<object[], object> func, params IEnumerable<object>[] structure)
448482
{
449483
// TODO: check structure and types
450484
// for other in structure[1:]:
451485
// assert_same_structure(structure[0], other, check_types=check_types)
452486

453-
if (more_structures.Length==0)
487+
if (structure.Length==1)
454488
{
455489
// we don't need to zip if we have only one structure
456-
return map_structure(a => func(new object[]{a}), structure);
490+
return map_structure(a => func(new object[]{a}), structure[0]);
457491
}
458-
var flat_structures = new List<object>() { flatten(structure) };
459-
flat_structures.AddRange(more_structures.Select(flatten));
460-
var entries = zip(flat_structures);
492+
var flat_structures = structure.Select(flatten).ToArray(); // ToArray is important here!
493+
var entries = zip_many(flat_structures);
461494
var mapped_flat_structure = entries.Select(func);
462495

463-
return (pack_sequence_as(structure, mapped_flat_structure) as IEnumerable).OfType<object>();
496+
return _yield_value(pack_sequence_as(structure[0], mapped_flat_structure)).ToList();
464497
}
465498

466499
/// <summary>
@@ -469,7 +502,7 @@ public static IEnumerable<object> map_structure(Func<object[], object> func, obj
469502
/// <param name="func"></param>
470503
/// <param name="structure"></param>
471504
/// <returns></returns>
472-
public static IEnumerable<object> map_structure(Func<object, object> func, object structure)
505+
public static IEnumerable<object> map_structure(Func<object, object> func, IEnumerable<object> structure)
473506
{
474507
// TODO: check structure and types
475508
// for other in structure[1:]:
@@ -478,7 +511,7 @@ public static IEnumerable<object> map_structure(Func<object, object> func, objec
478511
var flat_structure = flatten(structure);
479512
var mapped_flat_structure = flat_structure.Select(func).ToList();
480513

481-
return (pack_sequence_as(structure, mapped_flat_structure) as IEnumerable).OfType<object>();
514+
return _yield_value(pack_sequence_as(structure, mapped_flat_structure)).ToList();
482515
}
483516

484517
//def map_structure_with_paths(func, *structure, **kwargs):

test/TensorFlowNET.UnitTest/nest_test/NestTest.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -387,11 +387,10 @@ public void testMapStructure()
387387
// nest.assert_same_structure(structure1, structure1_plus1)
388388
self.assertAllEqual( nest.flatten(structure1_plus1), new object[] { 2, 3, 4, 5, 6, 7 });
389389
self.assertAllEqual(nest.flatten(structure1_strings), new object[] { "1", "2", "3", "4", "5", "6" });
390-
// structure1_plus_structure2 = nest.map_structure(
391-
// lambda x, y: x + y, structure1, structure2)
392-
// self.assertEqual(
393-
// (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
394-
// structure1_plus_structure2)
390+
var structure1_plus_structure2 = nest.map_structure(x => (int)(x[0]) + (int)(x[1]), structure1, structure2);
391+
self.assertEqual(
392+
new object[] { new object[] { new object[] { 1 + 7, 2 + 8}, 3 + 9}, 4 + 10, new object[] { 5 + 11, 6 + 12}},
393+
structure1_plus_structure2);
395394

396395
// self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
397396

0 commit comments

Comments
 (0)