Skip to content

Commit 11de767

Browse files
jamesqostephentoub
authored andcommitted
Recognize Skip/Take chains on lazy sequences. (dotnet/corefx#13628)
Commit migrated from dotnet/corefx@3554ed2
1 parent 8fc6df4 commit 11de767

5 files changed

Lines changed: 432 additions & 33 deletions

File tree

src/libraries/System.Linq/src/System/Linq/Partition.cs

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,5 +389,315 @@ public int GetCount(bool onlyIfCheap)
389389
return Count;
390390
}
391391
}
392+
393+
private sealed class EnumerablePartition<TSource> : Iterator<TSource>, IPartition<TSource>
394+
{
395+
private readonly IEnumerable<TSource> _source;
396+
private readonly int _minIndexInclusive;
397+
private readonly int _maxIndexInclusive; // -1 if we want everything past _minIndexInclusive.
398+
// If this is -1, it's impossible to set a limit on the count.
399+
private IEnumerator<TSource> _enumerator;
400+
401+
internal EnumerablePartition(IEnumerable<TSource> source, int minIndexInclusive, int maxIndexInclusive)
402+
{
403+
Debug.Assert(source != null);
404+
Debug.Assert(!(source is IList<TSource>), $"The caller needs to check for {nameof(IList<TSource>)}.");
405+
Debug.Assert(minIndexInclusive >= 0);
406+
Debug.Assert(maxIndexInclusive >= -1);
407+
// Note that although maxIndexInclusive can't grow, it can still be int.MaxValue.
408+
// We support partitioning enumerables with > 2B elements. For example, e.Skip(1).Take(int.MaxValue) should work.
409+
// But if it is int.MaxValue, then minIndexInclusive must != 0. Otherwise, our count may overflow.
410+
Debug.Assert(maxIndexInclusive == -1 || (maxIndexInclusive - minIndexInclusive < int.MaxValue), $"{nameof(Limit)} will overflow!");
411+
Debug.Assert(maxIndexInclusive == -1 || minIndexInclusive <= maxIndexInclusive);
412+
413+
_source = source;
414+
_minIndexInclusive = minIndexInclusive;
415+
_maxIndexInclusive = maxIndexInclusive;
416+
}
417+
418+
// If this is true (e.g. at least one Take call was made), then we have an upper bound
419+
// on how many elements we can have.
420+
private bool HasLimit => _maxIndexInclusive != -1;
421+
422+
private int Limit => (_maxIndexInclusive + 1) - _minIndexInclusive; // This is that upper bound.
423+
424+
public override Iterator<TSource> Clone()
425+
{
426+
return new EnumerablePartition<TSource>(_source, _minIndexInclusive, _maxIndexInclusive);
427+
}
428+
429+
public int GetCount(bool onlyIfCheap)
430+
{
431+
if (onlyIfCheap)
432+
{
433+
return -1;
434+
}
435+
436+
if (!HasLimit)
437+
{
438+
// If HasLimit is false, we contain everything past _minIndexInclusive.
439+
// Therefore, we have to iterate the whole enumerable.
440+
return Math.Max(_source.Count() - _minIndexInclusive, 0);
441+
}
442+
443+
using (IEnumerator<TSource> en = _source.GetEnumerator())
444+
{
445+
// We only want to iterate up to _maxIndexInclusive + 1.
446+
// Past that, we know the enumerable will be able to fit this partition,
447+
// so the count will just be _maxIndexInclusive + 1 - _minIndexInclusive.
448+
449+
// Note that it is possible for _maxIndexInclusive to be int.MaxValue here,
450+
// so + 1 may result in signed integer overflow. We need to handle this.
451+
// At the same time, however, we are guaranteed that our max count can fit
452+
// in an int because if that is true, then _minIndexInclusive must > 0.
453+
454+
uint count = SkipAndCount((uint)_maxIndexInclusive + 1, en);
455+
Debug.Assert(count != (uint)int.MaxValue + 1 || _minIndexInclusive > 0, "Our return value will be incorrect.");
456+
return Math.Max((int)count - _minIndexInclusive, 0);
457+
}
458+
459+
}
460+
461+
public override bool MoveNext()
462+
{
463+
// Cases where GetEnumerator has not been called or Dispose has already
464+
// been called need to be handled explicitly, due to the default: clause.
465+
int taken = _state - 3;
466+
if (taken < -2)
467+
{
468+
Dispose();
469+
return false;
470+
}
471+
472+
switch (_state)
473+
{
474+
case 1:
475+
_enumerator = _source.GetEnumerator();
476+
_state = 2;
477+
goto case 2;
478+
case 2:
479+
if (!SkipBeforeFirst(_enumerator))
480+
{
481+
// Reached the end before we finished skipping.
482+
break;
483+
}
484+
485+
_state = 3;
486+
goto default;
487+
default:
488+
if ((!HasLimit || taken < Limit) && _enumerator.MoveNext())
489+
{
490+
if (HasLimit)
491+
{
492+
// If we are taking an unknown number of elements, it's important not to increment _state.
493+
// _state - 3 may eventually end up overflowing & we'll hit the Dispose branch even though
494+
// we haven't finished enumerating.
495+
_state++;
496+
}
497+
_current = _enumerator.Current;
498+
return true;
499+
}
500+
501+
break;
502+
}
503+
504+
Dispose();
505+
return false;
506+
}
507+
508+
public override IEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector)
509+
{
510+
return new SelectIPartitionIterator<TSource, TResult>(this, selector);
511+
}
512+
513+
public IPartition<TSource> Skip(int count)
514+
{
515+
int minIndex = _minIndexInclusive + count;
516+
if (!HasLimit)
517+
{
518+
if (minIndex < 0)
519+
{
520+
// If we don't know our max count and minIndex can no longer fit in a positive int,
521+
// then we will need to wrap ourselves in another iterator.
522+
// This can happen, for example, during e.Skip(int.MaxValue).Skip(int.MaxValue).
523+
return new EnumerablePartition<TSource>(this, count, -1);
524+
}
525+
}
526+
else if ((uint)minIndex > (uint)_maxIndexInclusive)
527+
{
528+
// If minIndex overflows and we have an upper bound, we will go down this branch.
529+
// We know our upper bound must be smaller than minIndex, since our upper bound fits in an int.
530+
// This branch should not be taken if we don't have a bound.
531+
return EmptyPartition<TSource>.Instance;
532+
}
533+
534+
Debug.Assert(minIndex >= 0, $"We should have taken care of all cases when {nameof(minIndex)} overflows.");
535+
return new EnumerablePartition<TSource>(_source, minIndex, _maxIndexInclusive);
536+
}
537+
538+
public IPartition<TSource> Take(int count)
539+
{
540+
int maxIndex = _minIndexInclusive + count - 1;
541+
if (!HasLimit)
542+
{
543+
if (maxIndex < 0)
544+
{
545+
// If we don't know our max count and maxIndex can no longer fit in a positive int,
546+
// then we will need to wrap ourselves in another iterator.
547+
// Note that although maxIndex may be too large, the difference between it and
548+
// _minIndexInclusive (which is count - 1) must fit in an int.
549+
// Example: e.Skip(50).Take(int.MaxValue).
550+
551+
return new EnumerablePartition<TSource>(this, 0, count - 1);
552+
}
553+
}
554+
else if ((uint)maxIndex >= (uint)_maxIndexInclusive)
555+
{
556+
// If we don't know our max count, we can't go down this branch.
557+
// It's always possible for us to contain more than count items, as the rest
558+
// of the enumerable past _minIndexInclusive can be arbitrarily long.
559+
return this;
560+
}
561+
562+
Debug.Assert(maxIndex >= 0, $"We should have taken care of all cases when {nameof(maxIndex)} overflows.");
563+
return new EnumerablePartition<TSource>(_source, _minIndexInclusive, maxIndex);
564+
}
565+
566+
public TSource TryGetElementAt(int index, out bool found)
567+
{
568+
// If the index is negative or >= our max count, return early.
569+
if (index >= 0 && (!HasLimit || index < Limit))
570+
{
571+
using (IEnumerator<TSource> en = _source.GetEnumerator())
572+
{
573+
Debug.Assert(_minIndexInclusive + index >= 0, $"Adding {nameof(index)} caused {nameof(_minIndexInclusive)} to overflow.");
574+
575+
if (SkipBefore(_minIndexInclusive + index, en) && en.MoveNext())
576+
{
577+
found = true;
578+
return en.Current;
579+
}
580+
}
581+
}
582+
583+
found = false;
584+
return default(TSource);
585+
}
586+
587+
public TSource TryGetFirst(out bool found)
588+
{
589+
using (IEnumerator<TSource> en = _source.GetEnumerator())
590+
{
591+
if (SkipBeforeFirst(en) && en.MoveNext())
592+
{
593+
found = true;
594+
return en.Current;
595+
}
596+
}
597+
598+
found = false;
599+
return default(TSource);
600+
}
601+
602+
public TSource TryGetLast(out bool found)
603+
{
604+
using (IEnumerator<TSource> en = _source.GetEnumerator())
605+
{
606+
if (SkipBeforeFirst(en) && en.MoveNext())
607+
{
608+
int remaining = Limit - 1; // Max number of items left, not counting the current element.
609+
int comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.
610+
TSource result;
611+
612+
do
613+
{
614+
remaining--;
615+
result = en.Current;
616+
}
617+
while (remaining >= comparand && en.MoveNext());
618+
619+
found = true;
620+
return result;
621+
}
622+
}
623+
624+
found = false;
625+
return default(TSource);
626+
}
627+
628+
public TSource[] ToArray()
629+
{
630+
using (IEnumerator<TSource> en = _source.GetEnumerator())
631+
{
632+
if (SkipBeforeFirst(en) && en.MoveNext())
633+
{
634+
int remaining = Limit - 1; // Max number of items left, not counting the current element.
635+
int comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.
636+
637+
int maxCapacity = HasLimit ? Limit : int.MaxValue;
638+
var builder = new LargeArrayBuilder<TSource>(maxCapacity);
639+
640+
do
641+
{
642+
remaining--;
643+
builder.Add(en.Current);
644+
}
645+
while (remaining >= comparand && en.MoveNext());
646+
647+
return builder.ToArray();
648+
}
649+
}
650+
651+
return Array.Empty<TSource>();
652+
}
653+
654+
public List<TSource> ToList()
655+
{
656+
var list = new List<TSource>();
657+
658+
using (IEnumerator<TSource> en = _source.GetEnumerator())
659+
{
660+
if (SkipBeforeFirst(en) && en.MoveNext())
661+
{
662+
int remaining = Limit - 1; // Max number of items left, not counting the current element.
663+
int comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.
664+
665+
do
666+
{
667+
remaining--;
668+
list.Add(en.Current);
669+
}
670+
while (remaining >= comparand && en.MoveNext());
671+
}
672+
}
673+
674+
return list;
675+
}
676+
677+
private bool SkipBeforeFirst(IEnumerator<TSource> en) => SkipBefore(_minIndexInclusive, en);
678+
679+
private static bool SkipBefore(int index, IEnumerator<TSource> en) => SkipAndCount(index, en) == index;
680+
681+
private static int SkipAndCount(int index, IEnumerator<TSource> en)
682+
{
683+
Debug.Assert(index >= 0);
684+
return (int)SkipAndCount((uint)index, en);
685+
}
686+
687+
private static uint SkipAndCount(uint index, IEnumerator<TSource> en)
688+
{
689+
Debug.Assert(en != null);
690+
691+
for (uint i = 0; i < index; i++)
692+
{
693+
if (!en.MoveNext())
694+
{
695+
return i;
696+
}
697+
}
698+
699+
return index;
700+
}
701+
}
392702
}
393703
}

src/libraries/System.Linq/src/System/Linq/Skip.cs

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,7 @@ public static IEnumerable<TSource> Skip<TSource>(this IEnumerable<TSource> sourc
4141
return new ListPartition<TSource>(sourceList, count, int.MaxValue);
4242
}
4343

44-
return SkipIterator(source, count);
45-
}
46-
47-
private static IEnumerable<TSource> SkipIterator<TSource>(IEnumerable<TSource> source, int count)
48-
{
49-
using (IEnumerator<TSource> e = source.GetEnumerator())
50-
{
51-
while (count > 0 && e.MoveNext())
52-
{
53-
count--;
54-
}
55-
56-
if (count <= 0)
57-
{
58-
while (e.MoveNext())
59-
{
60-
yield return e.Current;
61-
}
62-
}
63-
}
44+
return new EnumerablePartition<TSource>(source, count, -1);
6445
}
6546

6647
public static IEnumerable<TSource> SkipWhile<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)

src/libraries/System.Linq/src/System/Linq/Take.cs

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,7 @@ public static IEnumerable<TSource> Take<TSource>(this IEnumerable<TSource> sourc
3232
return new ListPartition<TSource>(sourceList, 0, count - 1);
3333
}
3434

35-
return TakeIterator(source, count);
36-
}
37-
38-
private static IEnumerable<TSource> TakeIterator<TSource>(IEnumerable<TSource> source, int count)
39-
{
40-
foreach (TSource element in source)
41-
{
42-
yield return element;
43-
if (--count == 0)
44-
{
45-
break;
46-
}
47-
}
35+
return new EnumerablePartition<TSource>(source, 0, count - 1);
4836
}
4937

5038
public static IEnumerable<TSource> TakeWhile<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)

0 commit comments

Comments
 (0)