C# Linq源码分析之Take(五)

news2025/1/20 16:32:19


本文在《C# Linq源码分析之Take(四)》的基础上,继续从源码角度分析Take的优化方法,主要分析Where(...).Select(...).Take(...)这一调用链的使用案例。



 // Example query analysed in this article: filter, project to an anonymous type, take the first 3.
// Where(...).Select(...) over a List<Student> is fused into one WhereSelectListIterator,
// which Take(3) wraps in an EnumerablePartition; ToList() then pulls at most 3 projected items.
studentList.Where(x => x.MathResult >= 90).Select(x => new
{
    x.Name,
    x.MathResult
}).Take(3).ToList().ForEach(x => Console.WriteLine(x.Name + x.MathResult));




/// <summary>
/// Builds the sequence returned by Take for a strictly positive <paramref name="count"/>.
/// </summary>
/// <remarks>
/// Chooses the cheapest representation:
/// an existing partition is narrowed via its own <c>Take</c>;
/// an <see cref="IList{T}"/> is wrapped in a <c>ListPartition</c>;
/// any other sequence is wrapped in an <c>EnumerablePartition</c> covering source indices [0, count - 1].
/// </remarks>
/// <param name="source">The sequence to take from.</param>
/// <param name="count">Number of leading elements to keep; the caller guarantees it is &gt; 0.</param>
/// <returns>A partition yielding at most <paramref name="count"/> elements.</returns>
private static IEnumerable<TSource> takeIterator<TSource>(IEnumerable<TSource> source, int count)
{
    Debug.Assert(count > 0);

    return
        source is IPartition<TSource> partition ? partition.Take(count) :
        source is IList<TSource> sourceList ? new ListPartition<TSource>(sourceList, 0, count - 1) :
        new EnumerablePartition<TSource>(source, 0, count - 1);
}


  /// <summary>
  /// Materializes <paramref name="source"/> into a new <see cref="List{T}"/>.
  /// </summary>
  /// <remarks>
  /// Sequences implementing <c>IIListProvider&lt;TSource&gt;</c> (such as the partitions created
  /// by Take) build the list themselves, so they can stop early; every other sequence is copied
  /// through the <see cref="List{T}"/> collection constructor.
  /// </remarks>
  /// <param name="source">The sequence to materialize.</param>
  /// <returns>A new list containing the elements of <paramref name="source"/>.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="source"/> is <see langword="null"/>.</exception>
  public static List<TSource> ToList<TSource>(this IEnumerable<TSource> source)
  {
      if (source == null)
      {
          // Fail fast and report the caller's own parameter name.
          throw new ArgumentNullException(nameof(source));
      }

      return source is IIListProvider<TSource> listProvider ? listProvider.ToList() : new List<TSource>(source);
  }


/// <summary>
/// Copies the elements of this partition's window into a new list.
/// </summary>
/// <remarks>
/// In the article's example <c>_source</c> is the WhereSelectListIterator produced by
/// Where(...).Select(...), so each MoveNext yields one already-filtered, already-projected element.
/// </remarks>
/// <returns>A list with at most <c>Limit</c> elements (unbounded when there is no limit).</returns>
public List<TSource> ToList()
{
    var list = new List<TSource>();

    using (IEnumerator<TSource> en = _source.GetEnumerator())
    {
        // Skip the elements before _minIndexInclusive, then make sure there is at least one element.
        if (SkipBeforeFirst(en) && en.MoveNext())
        {
            int remaining = Limit - 1; // Max number of items left, not counting the current element.
            int comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.

            do
            {
                remaining--;
                list.Add(en.Current);
            }
            // The limit test runs first, so no extra element is pulled from the source once the limit is reached.
            while (remaining >= comparand && en.MoveNext());
        }
    }

    return list;
}
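  1. 获取迭代器en:此时的_source是WhereSelectListIterator对象(Where与Select已被合并为一个迭代器);
  2. 该ToList方法同样服务于Skip,因此先调用SkipBeforeFirst跳过_minIndexInclusive之前的元素;对于本例的Take(3),创建的是EnumerablePartition(source, 0, 2),_minIndexInclusive为0,不会跳过任何元素;
  3. 每次迭代,先从WhereSelectListIterator中取出一个满足过滤条件、且已经过Selector投影的元素,存入list;当list中已有三个元素(Limit = 3),或源序列已经耗尽时,停止迭代并返回list。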



本文涉及的源码见附录;关于WhereSelectListIterator的合并优化,详见《C# LINQ源码分析之Select》。



/// <summary>
/// Sample entity used by the article's query.
/// The filter reads <see cref="MathResult"/>; the projection reads <see cref="Name"/> and <see cref="MathResult"/>.
/// </summary>
public class Student {
    /// <summary>Student identifier.</summary>
    public string Id { get; set; }

    /// <summary>Student name.</summary>
    public string Name { get; set; }

    /// <summary>Classroom the student belongs to.</summary>
    public string Classroom { get; set; }

    /// <summary>Math score; the example query keeps students with a score of 90 or more.</summary>
    public int MathResult { get; set; }
}


/// <summary>
/// A sequence that knows how to materialize itself (and optionally count itself)
/// more efficiently than generic enumeration.
/// </summary>
/// <typeparam name="TElement">Element type of the sequence.</typeparam>
internal interface IIListProvider<TElement> : IEnumerable<TElement>
{
    /// <summary>Produces an array containing the elements of the sequence.</summary>
    TElement[] ToArray();

    /// <summary>Produces a list containing the elements of the sequence.</summary>
    List<TElement> ToList();

    /// <summary>Returns the number of elements in the sequence.</summary>
    /// <param name="onlyIfCheap">
    /// When true, an implementation may return -1 instead of doing expensive work
    /// (for example, EnumerablePartition returns -1 rather than enumerating its source).
    /// </param>
    int GetCount(bool onlyIfCheap);
}


/// <summary>
/// A sequence that supports cheap Skip/Take composition and positional lookups,
/// so chained operators can be merged into one object instead of stacking iterators.
/// </summary>
/// <typeparam name="TElement">Element type of the sequence.</typeparam>
internal interface IPartition<TElement> : IIListProvider<TElement>
{
    /// <summary>Returns a partition that skips the first <paramref name="count"/> elements of this one.</summary>
    IPartition<TElement> Skip(int count);

    /// <summary>Returns a partition holding at most the first <paramref name="count"/> elements of this one.</summary>
    IPartition<TElement> Take(int count);

    /// <summary>Gets the element at <paramref name="index"/>; <paramref name="found"/> reports whether it exists.</summary>
    TElement? TryGetElementAt(int index, out bool found);

    /// <summary>Gets the first element; <paramref name="found"/> reports whether the partition is non-empty.</summary>
    TElement? TryGetFirst(out bool found);

    /// <summary>Gets the last element; <paramref name="found"/> reports whether the partition is non-empty.</summary>
    TElement? TryGetLast(out bool found);
}


/// <summary>
/// A Skip/Take window over a source sequence that is not an <see cref="IList{T}"/>.
/// Yields the source elements whose zero-based indices lie in
/// [<c>_minIndexInclusive</c>, <c>_maxIndexInclusive</c>]. When <c>_maxIndexInclusive</c>
/// is -1, the window is open-ended and yields everything from <c>_minIndexInclusive</c> on.
/// </summary>
/// <remarks>
/// Further <see cref="Skip"/>/<see cref="Take"/> calls narrow the window by creating a new
/// partition over the same source instead of stacking iterators. The exceptions are the
/// index-overflow cases documented inside those methods.
/// </remarks>
private sealed class EnumerablePartition<TSource> : Iterator<TSource>, IPartition<TSource>
{
    private readonly IEnumerable<TSource> _source;
    private readonly int _minIndexInclusive;
    private readonly int _maxIndexInclusive; // -1 if we want everything past _minIndexInclusive.
                                             // If this is -1, it's impossible to set a limit on the count.
    private IEnumerator<TSource>? _enumerator; // Source enumerator opened by MoveNext; released in Dispose.

    /// <summary>Creates a window over <paramref name="source"/>.</summary>
    /// <param name="source">Sequence to partition; must not be null and must not be an <see cref="IList{T}"/>.</param>
    /// <param name="minIndexInclusive">First source index to yield (&gt;= 0).</param>
    /// <param name="maxIndexInclusive">Last source index to yield, or -1 for no upper bound.</param>
    internal EnumerablePartition(IEnumerable<TSource> source, int minIndexInclusive, int maxIndexInclusive)
    {
        Debug.Assert(source != null);
        Debug.Assert(!(source is IList<TSource>), $"The caller needs to check for {nameof(IList<TSource>)}.");
        Debug.Assert(minIndexInclusive >= 0);
        Debug.Assert(maxIndexInclusive >= -1);
        // Note that although maxIndexInclusive can't grow, it can still be int.MaxValue.
        // We support partitioning enumerables with > 2B elements. For example, e.Skip(1).Take(int.MaxValue) should work.
        // But if it is int.MaxValue, then minIndexInclusive must != 0. Otherwise, our count may overflow.
        Debug.Assert(maxIndexInclusive == -1 || (maxIndexInclusive - minIndexInclusive < int.MaxValue), $"{nameof(Limit)} will overflow!");
        Debug.Assert(maxIndexInclusive == -1 || minIndexInclusive <= maxIndexInclusive);

        _source = source;
        _minIndexInclusive = minIndexInclusive;
        _maxIndexInclusive = maxIndexInclusive;
    }

    // If this is true (e.g. at least one Take call was made), then we have an upper bound
    // on how many elements we can have.
    private bool HasLimit => _maxIndexInclusive != -1;

    private int Limit => unchecked((_maxIndexInclusive + 1) - _minIndexInclusive); // This is that upper bound.

    /// <summary>Returns a fresh, unstarted partition with the same window over the same source.</summary>
    public override Iterator<TSource> Clone() =>
        new EnumerablePartition<TSource>(_source, _minIndexInclusive, _maxIndexInclusive);

    /// <summary>
    /// Disposes the source enumerator (if one was opened), then runs the base iterator's Dispose.
    /// </summary>
    public override void Dispose()
    {
        if (_enumerator != null)
        {
            // Dropping the reference without disposing would leak whatever the source enumerator holds.
            _enumerator.Dispose();
            _enumerator = null;
        }

        // NOTE(review): assumes the base Iterator<TSource>.Dispose resets _state/_current
        // (as in dotnet/runtime), which is what makes MoveNext return false afterwards.
        base.Dispose();
    }

    /// <summary>Counts the elements this partition yields.</summary>
    /// <param name="onlyIfCheap">When true, returns -1 immediately, because counting always requires enumerating the source.</param>
    /// <returns>The element count, or -1 when <paramref name="onlyIfCheap"/> is true.</returns>
    /// <exception cref="OverflowException">The unbounded source holds more than <see cref="int.MaxValue"/> elements.</exception>
    public int GetCount(bool onlyIfCheap)
    {
        if (onlyIfCheap)
        {
            return -1;
        }

        if (!HasLimit)
        {
            // If HasLimit is false, we contain everything past _minIndexInclusive.
            // Therefore, we have to iterate the whole enumerable.
            int total = 0;
            using (IEnumerator<TSource> all = _source.GetEnumerator())
            {
                while (all.MoveNext())
                {
                    total = checked(total + 1); // Like Enumerable.Count: overflow throws instead of wrapping.
                }
            }

            return Math.Max(total - _minIndexInclusive, 0);
        }

        using (IEnumerator<TSource> en = _source.GetEnumerator())
        {
            // We only want to iterate up to _maxIndexInclusive + 1.
            // Past that, we know the enumerable will be able to fit this partition,
            // so the count will just be _maxIndexInclusive + 1 - _minIndexInclusive.

            // Note that it is possible for _maxIndexInclusive to be int.MaxValue here,
            // so + 1 may result in signed integer overflow. We need to handle this.
            // At the same time, however, we are guaranteed that our max count can fit
            // in an int because if that is true, then _minIndexInclusive must > 0.

            uint count = SkipAndCount((uint)_maxIndexInclusive + 1, en);
            Debug.Assert(count != (uint)int.MaxValue + 1 || _minIndexInclusive > 0, "Our return value will be incorrect.");
            return Math.Max((int)count - _minIndexInclusive, 0);
        }
    }

    /// <summary>Advances to the next element of the window.</summary>
    /// <remarks>
    /// State machine over <c>_state</c>:
    /// 1 = source enumerator not opened yet;
    /// 2 = opened, but the leading <c>_minIndexInclusive</c> elements are not skipped yet;
    /// 3 + n = n elements yielded so far (only advanced while a limit exists).
    /// A state below 1 means the iterator is not active, so enumeration ends at once.
    /// </remarks>
    public override bool MoveNext()
    {
        // Cases where GetEnumerator has not been called or Dispose has already
        // been called need to be handled explicitly, due to the default: clause.
        int taken = _state - 3;
        if (taken < -2)
        {
            Dispose();
            return false;
        }

        switch (_state)
        {
            case 1:
                _enumerator = _source.GetEnumerator();
                _state = 2;
                goto case 2;
            case 2:
                Debug.Assert(_enumerator != null);
                if (!SkipBeforeFirst(_enumerator))
                {
                    // Reached the end before we finished skipping.
                    break;
                }

                _state = 3;
                goto default;
            default:
                Debug.Assert(_enumerator != null);
                if ((!HasLimit || taken < Limit) && _enumerator.MoveNext())
                {
                    if (HasLimit)
                    {
                        // If we are taking an unknown number of elements, it's important not to increment _state.
                        // _state - 3 may eventually end up overflowing & we'll hit the Dispose branch even though
                        // we haven't finished enumerating.
                        _state++;
                    }

                    _current = _enumerator.Current;
                    return true;
                }

                break;
        }

        // Source exhausted or limit reached: release the source enumerator eagerly.
        Dispose();
        return false;
    }

    /// <summary>Fuses a following Select into a single partition-aware iterator.</summary>
    public override IEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector) =>
        new SelectIPartitionIterator<TSource, TResult>(this, selector);

    /// <summary>Returns a partition that additionally skips <paramref name="count"/> elements.</summary>
    public IPartition<TSource> Skip(int count)
    {
        int minIndex = unchecked(_minIndexInclusive + count);

        if (!HasLimit)
        {
            if (minIndex < 0)
            {
                // If we don't know our max count and minIndex can no longer fit in a positive int,
                // then we will need to wrap ourselves in another iterator.
                // This can happen, for example, during e.Skip(int.MaxValue).Skip(int.MaxValue).
                return new EnumerablePartition<TSource>(this, count, -1);
            }
        }
        else if ((uint)minIndex > (uint)_maxIndexInclusive)
        {
            // If minIndex overflows and we have an upper bound, we will go down this branch.
            // We know our upper bound must be smaller than minIndex, since our upper bound fits in an int.
            // This branch should not be taken if we don't have a bound.
            return EmptyPartition<TSource>.Instance;
        }

        Debug.Assert(minIndex >= 0, $"We should have taken care of all cases when {nameof(minIndex)} overflows.");
        return new EnumerablePartition<TSource>(_source, minIndex, _maxIndexInclusive);
    }

    /// <summary>Returns a partition limited to the first <paramref name="count"/> elements of this one.</summary>
    public IPartition<TSource> Take(int count)
    {
        int maxIndex = unchecked(_minIndexInclusive + count - 1);
        if (!HasLimit)
        {
            if (maxIndex < 0)
            {
                // If we don't know our max count and maxIndex can no longer fit in a positive int,
                // then we will need to wrap ourselves in another iterator.
                // Note that although maxIndex may be too large, the difference between it and
                // _minIndexInclusive (which is count - 1) must fit in an int.
                // Example: e.Skip(50).Take(int.MaxValue).

                return new EnumerablePartition<TSource>(this, 0, count - 1);
            }
        }
        else if (unchecked((uint)maxIndex >= (uint)_maxIndexInclusive))
        {
            // If we don't know our max count, we can't go down this branch.
            // It's always possible for us to contain more than count items, as the rest
            // of the enumerable past _minIndexInclusive can be arbitrarily long.
            return this;
        }

        Debug.Assert(maxIndex >= 0, $"We should have taken care of all cases when {nameof(maxIndex)} overflows.");
        return new EnumerablePartition<TSource>(_source, _minIndexInclusive, maxIndex);
    }

    /// <summary>Gets the element at window-relative <paramref name="index"/> by enumerating the source.</summary>
    public TSource? TryGetElementAt(int index, out bool found)
    {
        // If the index is negative or >= our max count, return early.
        if (index >= 0 && (!HasLimit || index < Limit))
        {
            using (IEnumerator<TSource> en = _source.GetEnumerator())
            {
                Debug.Assert(_minIndexInclusive + index >= 0, $"Adding {nameof(index)} caused {nameof(_minIndexInclusive)} to overflow.");

                if (SkipBefore(_minIndexInclusive + index, en) && en.MoveNext())
                {
                    found = true;
                    return en.Current;
                }
            }
        }

        found = false;
        return default;
    }

    /// <summary>Gets the first element of the window, if any.</summary>
    public TSource? TryGetFirst(out bool found)
    {
        using (IEnumerator<TSource> en = _source.GetEnumerator())
        {
            if (SkipBeforeFirst(en) && en.MoveNext())
            {
                found = true;
                return en.Current;
            }
        }

        found = false;
        return default;
    }

    /// <summary>Gets the last element of the window, if any, stopping at the limit.</summary>
    public TSource? TryGetLast(out bool found)
    {
        using (IEnumerator<TSource> en = _source.GetEnumerator())
        {
            if (SkipBeforeFirst(en) && en.MoveNext())
            {
                int remaining = Limit - 1; // Max number of items left, not counting the current element.
                int comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.
                TSource result;

                do
                {
                    remaining--;
                    result = en.Current;
                }
                while (remaining >= comparand && en.MoveNext());

                found = true;
                return result;
            }
        }

        found = false;
        return default;
    }

    /// <summary>Copies the window into a new list, pulling no more than <c>Limit</c> elements from the source.</summary>
    public List<TSource> ToList()
    {
        var list = new List<TSource>();

        using (IEnumerator<TSource> en = _source.GetEnumerator())
        {
            if (SkipBeforeFirst(en) && en.MoveNext())
            {
                int remaining = Limit - 1; // Max number of items left, not counting the current element.
                int comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.

                do
                {
                    remaining--;
                    list.Add(en.Current);
                }
                // The limit test runs first, so the source is not advanced past the last wanted element.
                while (remaining >= comparand && en.MoveNext());
            }
        }

        return list;
    }

    /// <summary>Skips the elements that precede <c>_minIndexInclusive</c>; false if the source ran out first.</summary>
    private bool SkipBeforeFirst(IEnumerator<TSource> en) => SkipBefore(_minIndexInclusive, en);

    /// <summary>Advances <paramref name="en"/> past <paramref name="index"/> elements; false if it ran out first.</summary>
    private static bool SkipBefore(int index, IEnumerator<TSource> en) => SkipAndCount(index, en) == index;

    /// <summary>Advances up to <paramref name="index"/> elements and returns how many were actually skipped.</summary>
    private static int SkipAndCount(int index, IEnumerator<TSource> en)
    {
        Debug.Assert(index >= 0);
        return (int)SkipAndCount((uint)index, en);
    }

    /// <summary>Unsigned overload so that <c>int.MaxValue + 1</c> skips can be expressed without overflow.</summary>
    private static uint SkipAndCount(uint index, IEnumerator<TSource> en)
    {
        Debug.Assert(en != null);

        for (uint i = 0; i < index; i++)
        {
            if (!en.MoveNext())
            {
                return i;
            }
        }

        return index;
    }

    /// <summary>Copies the window into a new array (built from <see cref="ToList"/>, so the source is enumerated once).</summary>
    public TSource[] ToArray() => ToList().ToArray();
}




