概要
本文在C# Linq源码分析之Take(四)的基础上继续从源码角度分析Take的优化方法,主要分析Where.Select.Take的使用案例。
Where.Select.Take的案例分析
该场景模拟我们在现实开发中,将EF中与数据库关联的对象进行过滤,然后转换成Web前端需要的对象,并进行分页的情况。
// Keep the students whose math result is at least 90, project each one to an
// anonymous (Name, MathResult) pair, cap the sequence at three items and
// materialize it into a list.
var topThree = studentList
    .Where(student => student.MathResult >= 90)
    .Select(student => new { student.Name, student.MathResult })
    .Take(3)
    .ToList();

// Print the name immediately followed by the math result for every selected student.
topThree.ForEach(row => Console.WriteLine(row.Name + row.MathResult));
找到数学成绩在90分及以上的学生,获取学生的姓名和数学成绩,只取前三个学生,并将学生信息打印。Student类的代码请见附录。
源码流程分析
第一步进入Where方法,返回WhereListIterator对象;
第二步进入Select方法,将Where和Select两个操作合并,返回WhereSelectListIterator对象;
第三步进入Take方法,调用takeIterator方法;由于WhereSelectListIterator并没有实现IPartition接口和IList接口,所以无法再进行操作合并,只能返回EnumerablePartition对象。
/// <summary>
/// Chooses the partition object used to take the first <paramref name="count"/>
/// elements of <paramref name="source"/>.
/// </summary>
/// <param name="source">The sequence to take elements from.</param>
/// <param name="count">Number of elements to take; asserted to be positive.</param>
/// <returns>A sequence limited to the indices 0 through <c>count - 1</c> of the source.</returns>
private static IEnumerable<TSource> takeIterator<TSource>(IEnumerable<TSource> source, int count)
{
    Debug.Assert(count > 0);

    // A source that is already a partition narrows itself.
    if (source is IPartition<TSource> partition)
    {
        return partition.Take(count);
    }

    // A list-backed source gets the list-specific partition.
    if (source is IList<TSource> sourceList)
    {
        return new ListPartition<TSource>(sourceList, 0, count - 1);
    }

    // Everything else is wrapped in the general enumerable partition.
    return new EnumerablePartition<TSource>(source, 0, count - 1);
}
第四步进入ToList方法
/// <summary>
/// Materializes <paramref name="source"/> into a <see cref="List{T}"/>.
/// </summary>
/// <param name="source">The sequence to copy; a null value is reported through <c>ThrowHelper</c>.</param>
/// <returns>
/// The list produced by the source itself when it implements <c>IIListProvider</c>,
/// otherwise a new list constructed from the sequence.
/// </returns>
public static List<TSource> ToList<TSource>(this IEnumerable<TSource> source)
{
    if (source is null)
    {
        ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
    }

    // Let the source build the list itself when it knows how to.
    if (source is IIListProvider<TSource> listProvider)
    {
        return listProvider.ToList();
    }

    return new List<TSource>(source);
}
此时的source是EnumerablePartition对象,它实现了IPartition接口,而IPartition接口继承了IIListProvider接口,所以可以调用自己的ToList方法;
/// <summary>
/// Copies the elements of this partition into a new list.
/// </summary>
/// <returns>A list with the partition's elements; empty when the source ends before the first one.</returns>
public List<TSource> ToList()
{
    var list = new List<TSource>();
    using (IEnumerator<TSource> en = _source.GetEnumerator())
    {
        // Nothing to collect if the source runs out while skipping, or right after it.
        if (!SkipBeforeFirst(en) || !en.MoveNext())
        {
            return list;
        }

        // Number of further elements allowed after the one the enumerator is positioned on.
        int budget = Limit - 1;
        // Without an upper bound, compare against int.MinValue so the budget check never stops the loop.
        int floor = HasLimit ? 0 : int.MinValue;

        while (true)
        {
            budget--;
            list.Add(en.Current);

            if (budget < floor || !en.MoveNext())
            {
                break;
            }
        }
    }
    return list;
}
- 定义迭代器en,此时的_source是WhereSelectListIterator对象;
- 该ToList方法同样支持Skip,所以要判断迭代器的起始位置是不是从第一个开始;
- 每次迭代,首先从WhereSelectListIterator迭代器中返回一个符合过滤条件,并完成Selector操作的元素,存入list,直到list中包含三个元素,返回执行结果。
虽然WhereSelectListIterator没有实现IPartition接口,不能实现一次迭代,完成全部操作,但是现有的流程性能并不差,因为WhereSelectListIterator迭代器本身已经合并了过滤和投影操作,而且并不需要遍历所有元素,只要找到3个符合条件的元素即可。
我认为如果代码需要用到Take方法,尽量把它放到Linq的最后。这样做的好处是前面的Linq操作并不需要遍历全部的序列元素,只要得到Take方法中需要的元素个数即可。
本文中涉及的源码请见附录,关于WhereSelectListIterator的合并优化操作,更多详细内容,请参考C# LINQ源码分析之Select
附录
Student类
/// <summary>
/// Sample student record used by the LINQ examples.
/// </summary>
public class Student
{
    /// <summary>Identifier of the student.</summary>
    public string Id { get; set; }

    /// <summary>Name of the student.</summary>
    public string Name { get; set; }

    /// <summary>Classroom label (presumably the class the student belongs to).</summary>
    public string Classroom { get; set; }

    /// <summary>Math result of the student.</summary>
    public int MathResult { get; set; }
}
IIListProvider接口
/// <summary>
/// An enumerable that can itself produce an array, a list, or a count of its elements.
/// The <c>ToList</c> extension method shown in this file calls <see cref="ToList"/> on
/// sources that implement this interface instead of copying them element by element.
/// </summary>
/// <typeparam name="TElement">Type of the elements in the sequence.</typeparam>
internal interface IIListProvider<TElement> : IEnumerable<TElement>
{
/// <summary>Returns the elements of this sequence as an array.</summary>
TElement[] ToArray();
/// <summary>Returns the elements of this sequence as a list.</summary>
List<TElement> ToList();
/// <summary>Returns the number of elements in this sequence.</summary>
/// <param name="onlyIfCheap">
/// NOTE(review): the EnumerablePartition implementation in this file returns -1 when this is
/// true; presumably -1 means "count not cheaply available" for all implementers - confirm.
/// </param>
int GetCount(bool onlyIfCheap);
}
IPartition接口
/// <summary>
/// A sequence that can be narrowed further with <see cref="Skip"/> and <see cref="Take"/>,
/// and that offers direct lookups of single elements.
/// </summary>
/// <typeparam name="TElement">Type of the elements in the sequence.</typeparam>
internal interface IPartition<TElement> : IIListProvider<TElement>
{
/// <summary>Returns a partition that omits the first <paramref name="count"/> elements of this one.</summary>
IPartition<TElement> Skip(int count);
/// <summary>Returns a partition limited to the first <paramref name="count"/> elements of this one.</summary>
IPartition<TElement> Take(int count);
/// <summary>Tries to fetch the element at <paramref name="index"/>.</summary>
/// <param name="index">Position of the requested element (presumably zero-based and relative to this partition - confirm).</param>
/// <param name="found">Set to whether an element was found.</param>
TElement? TryGetElementAt(int index, out bool found);
/// <summary>Tries to fetch the first element.</summary>
/// <param name="found">Set to whether an element was found.</param>
TElement? TryGetFirst(out bool found);
/// <summary>Tries to fetch the last element.</summary>
/// <param name="found">Set to whether an element was found.</param>
TElement? TryGetLast(out bool found);
}
EnumerablePartition类
/// <summary>
/// An iterator over the elements of an arbitrary enumerable whose zero-based positions lie
/// between a minimum index and an optional maximum index (both inclusive).
/// </summary>
/// <typeparam name="TSource">Type of the elements in the source sequence.</typeparam>
private sealed class EnumerablePartition<TSource> : Iterator<TSource>, IPartition<TSource>
{
    private readonly IEnumerable<TSource> _source;
    private readonly int _minIndexInclusive;
    private readonly int _maxIndexInclusive; // -1 if we want everything past _minIndexInclusive.
                                             // If this is -1, it's impossible to set a limit on the count.
    private IEnumerator<TSource>? _enumerator;

    /// <summary>
    /// Creates a partition over <paramref name="source"/>.
    /// </summary>
    /// <param name="source">The sequence to partition; must not be an <see cref="IList{T}"/>.</param>
    /// <param name="minIndexInclusive">Index of the first element to yield.</param>
    /// <param name="maxIndexInclusive">Index of the last element to yield, or -1 for no upper bound.</param>
    internal EnumerablePartition(IEnumerable<TSource> source, int minIndexInclusive, int maxIndexInclusive)
    {
        Debug.Assert(source != null);
        Debug.Assert(!(source is IList<TSource>), $"The caller needs to check for {nameof(IList<TSource>)}.");
        Debug.Assert(minIndexInclusive >= 0);
        Debug.Assert(maxIndexInclusive >= -1);
        // Note that although maxIndexInclusive can't grow, it can still be int.MaxValue.
        // We support partitioning enumerables with > 2B elements. For example, e.Skip(1).Take(int.MaxValue) should work.
        // But if it is int.MaxValue, then minIndexInclusive must != 0. Otherwise, our count may overflow.
        Debug.Assert(maxIndexInclusive == -1 || (maxIndexInclusive - minIndexInclusive < int.MaxValue), $"{nameof(Limit)} will overflow!");
        Debug.Assert(maxIndexInclusive == -1 || minIndexInclusive <= maxIndexInclusive);

        _source = source;
        _minIndexInclusive = minIndexInclusive;
        _maxIndexInclusive = maxIndexInclusive;
    }

    // If this is true (e.g. at least one Take call was made), then we have an upper bound
    // on how many elements we can have.
    private bool HasLimit => _maxIndexInclusive != -1;

    private int Limit => unchecked((_maxIndexInclusive + 1) - _minIndexInclusive); // This is that upper bound.

    /// <summary>Creates a fresh partition over the same source and index range.</summary>
    public override Iterator<TSource> Clone() =>
        new EnumerablePartition<TSource>(_source, _minIndexInclusive, _maxIndexInclusive);

    /// <summary>Disposes the source enumerator, if one was obtained, and then the base iterator.</summary>
    public override void Dispose()
    {
        if (_enumerator != null)
        {
            _enumerator.Dispose();
            _enumerator = null;
        }

        base.Dispose();
    }

    /// <summary>
    /// Returns the number of elements in this partition.
    /// </summary>
    /// <param name="onlyIfCheap">When true, no enumeration is performed and -1 is returned.</param>
    /// <returns>The element count, or -1 when <paramref name="onlyIfCheap"/> is true.</returns>
    public int GetCount(bool onlyIfCheap)
    {
        if (onlyIfCheap)
        {
            return -1;
        }

        if (!HasLimit)
        {
            // If HasLimit is false, we contain everything past _minIndexInclusive.
            // Therefore, we have to iterate the whole enumerable.
            // (Previously this branch returned a hard-coded 0, which reported
            // every Skip-only partition as empty.)
            return Math.Max(CountSource() - _minIndexInclusive, 0);
        }

        using (IEnumerator<TSource> en = _source.GetEnumerator())
        {
            // We only want to iterate up to _maxIndexInclusive + 1.
            // Past that, we know the enumerable will be able to fit this partition,
            // so the count will just be _maxIndexInclusive + 1 - _minIndexInclusive.

            // Note that it is possible for _maxIndexInclusive to be int.MaxValue here,
            // so + 1 may result in signed integer overflow. We need to handle this.
            // At the same time, however, we are guaranteed that our max count can fit
            // in an int because if that is true, then _minIndexInclusive must > 0.
            uint count = SkipAndCount((uint)_maxIndexInclusive + 1, en);
            Debug.Assert(count != (uint)int.MaxValue + 1 || _minIndexInclusive > 0, "Our return value will be incorrect.");
            return Math.Max((int)count - _minIndexInclusive, 0);
        }
    }

    /// <summary>
    /// Counts every element of the underlying source. Uses the collection's own count when
    /// the source is an <see cref="ICollection{T}"/>, otherwise enumerates it once.
    /// </summary>
    /// <exception cref="OverflowException">The source holds more than <see cref="int.MaxValue"/> elements.</exception>
    private int CountSource()
    {
        if (_source is ICollection<TSource> collection)
        {
            return collection.Count;
        }

        int count = 0;
        using (IEnumerator<TSource> en = _source.GetEnumerator())
        {
            while (en.MoveNext())
            {
                // Fail loudly instead of silently wrapping around.
                checked
                {
                    count++;
                }
            }
        }

        return count;
    }

    /// <summary>
    /// Advances to the next element of the partition.
    /// </summary>
    /// <returns>True when an element is available in the current slot; false once the partition is exhausted.</returns>
    public override bool MoveNext()
    {
        // Cases where GetEnumerator has not been called or Dispose has already
        // been called need to be handled explicitly, due to the default: clause.
        int taken = _state - 3;
        if (taken < -2)
        {
            Dispose();
            return false;
        }

        switch (_state)
        {
            case 1:
                _enumerator = _source.GetEnumerator();
                _state = 2;
                goto case 2;
            case 2:
                Debug.Assert(_enumerator != null);
                if (!SkipBeforeFirst(_enumerator))
                {
                    // Reached the end before we finished skipping.
                    break;
                }

                _state = 3;
                goto default;
            default:
                Debug.Assert(_enumerator != null);
                if ((!HasLimit || taken < Limit) && _enumerator.MoveNext())
                {
                    if (HasLimit)
                    {
                        // If we are taking an unknown number of elements, it's important not to increment _state.
                        // _state - 3 may eventually end up overflowing & we'll hit the Dispose branch even though
                        // we haven't finished enumerating.
                        _state++;
                    }
                    _current = _enumerator.Current;
                    return true;
                }

                break;
        }

        Dispose();
        return false;
    }

    /// <summary>Wraps this partition in a projection that applies <paramref name="selector"/> to each element.</summary>
    public override IEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector) =>
        new SelectIPartitionIterator<TSource, TResult>(this, selector);

    /// <summary>
    /// Returns a partition that omits the first <paramref name="count"/> elements of this one.
    /// </summary>
    public IPartition<TSource> Skip(int count)
    {
        int minIndex = unchecked(_minIndexInclusive + count);

        if (!HasLimit)
        {
            if (minIndex < 0)
            {
                // If we don't know our max count and minIndex can no longer fit in a positive int,
                // then we will need to wrap ourselves in another iterator.
                // This can happen, for example, during e.Skip(int.MaxValue).Skip(int.MaxValue).
                return new EnumerablePartition<TSource>(this, count, -1);
            }
        }
        else if ((uint)minIndex > (uint)_maxIndexInclusive)
        {
            // If minIndex overflows and we have an upper bound, we will go down this branch.
            // We know our upper bound must be smaller than minIndex, since our upper bound fits in an int.
            // This branch should not be taken if we don't have a bound.
            return EmptyPartition<TSource>.Instance;
        }

        Debug.Assert(minIndex >= 0, $"We should have taken care of all cases when {nameof(minIndex)} overflows.");
        return new EnumerablePartition<TSource>(_source, minIndex, _maxIndexInclusive);
    }

    /// <summary>
    /// Returns a partition limited to the first <paramref name="count"/> elements of this one.
    /// </summary>
    public IPartition<TSource> Take(int count)
    {
        int maxIndex = unchecked(_minIndexInclusive + count - 1);

        if (!HasLimit)
        {
            if (maxIndex < 0)
            {
                // If we don't know our max count and maxIndex can no longer fit in a positive int,
                // then we will need to wrap ourselves in another iterator.
                // Note that although maxIndex may be too large, the difference between it and
                // _minIndexInclusive (which is count - 1) must fit in an int.
                // Example: e.Skip(50).Take(int.MaxValue).
                return new EnumerablePartition<TSource>(this, 0, count - 1);
            }
        }
        else if (unchecked((uint)maxIndex >= (uint)_maxIndexInclusive))
        {
            // If we don't know our max count, we can't go down this branch.
            // It's always possible for us to contain more than count items, as the rest
            // of the enumerable past _minIndexInclusive can be arbitrarily long.
            return this;
        }

        Debug.Assert(maxIndex >= 0, $"We should have taken care of all cases when {nameof(maxIndex)} overflows.");
        return new EnumerablePartition<TSource>(_source, _minIndexInclusive, maxIndex);
    }

    /// <summary>
    /// Tries to fetch the element at <paramref name="index"/>, counted from the start of this partition.
    /// </summary>
    /// <param name="index">Zero-based position inside the partition.</param>
    /// <param name="found">Set to true when the element exists, otherwise false.</param>
    /// <returns>The element, or the default value when <paramref name="found"/> is false.</returns>
    public TSource? TryGetElementAt(int index, out bool found)
    {
        // If the index is negative or >= our max count, return early.
        if (index >= 0 && (!HasLimit || index < Limit))
        {
            using (IEnumerator<TSource> en = _source.GetEnumerator())
            {
                Debug.Assert(_minIndexInclusive + index >= 0, $"Adding {nameof(index)} caused {nameof(_minIndexInclusive)} to overflow.");

                if (SkipBefore(_minIndexInclusive + index, en) && en.MoveNext())
                {
                    found = true;
                    return en.Current;
                }
            }
        }

        found = false;
        return default;
    }

    /// <summary>
    /// Tries to fetch the first element of this partition.
    /// </summary>
    /// <param name="found">Set to true when the partition has at least one element.</param>
    /// <returns>The first element, or the default value when <paramref name="found"/> is false.</returns>
    public TSource? TryGetFirst(out bool found)
    {
        using (IEnumerator<TSource> en = _source.GetEnumerator())
        {
            if (SkipBeforeFirst(en) && en.MoveNext())
            {
                found = true;
                return en.Current;
            }
        }

        found = false;
        return default;
    }

    /// <summary>
    /// Tries to fetch the last element of this partition by walking it to its end.
    /// </summary>
    /// <param name="found">Set to true when the partition has at least one element.</param>
    /// <returns>The last element, or the default value when <paramref name="found"/> is false.</returns>
    public TSource? TryGetLast(out bool found)
    {
        using (IEnumerator<TSource> en = _source.GetEnumerator())
        {
            if (SkipBeforeFirst(en) && en.MoveNext())
            {
                int remaining = Limit - 1; // Max number of items left, not counting the current element.
                int comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.
                TSource result;

                do
                {
                    remaining--;
                    result = en.Current;
                }
                while (remaining >= comparand && en.MoveNext());

                found = true;
                return result;
            }
        }

        found = false;
        return default;
    }

    /// <summary>
    /// Copies the elements of this partition into a new list.
    /// </summary>
    /// <returns>A list with the partition's elements; empty when the source ends before the first one.</returns>
    public List<TSource> ToList()
    {
        var list = new List<TSource>();

        using (IEnumerator<TSource> en = _source.GetEnumerator())
        {
            if (SkipBeforeFirst(en) && en.MoveNext())
            {
                int remaining = Limit - 1; // Max number of items left, not counting the current element.
                int comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.

                do
                {
                    remaining--;
                    list.Add(en.Current);
                }
                while (remaining >= comparand && en.MoveNext());
            }
        }

        return list;
    }

    /// <summary>
    /// Copies the elements of this partition into a new array.
    /// </summary>
    /// <returns>An array with the partition's elements; empty when the partition has none.</returns>
    public TSource[] ToArray()
    {
        // Reuse the list-building walk so both materializers apply the same skip/limit rules.
        // (Previously this member threw NotImplementedException.)
        return ToList().ToArray();
    }

    // Moves the enumerator past every element that precedes the partition's first index.
    private bool SkipBeforeFirst(IEnumerator<TSource> en) => SkipBefore(_minIndexInclusive, en);

    // True when exactly `index` elements could be skipped, i.e. the source did not end early.
    private static bool SkipBefore(int index, IEnumerator<TSource> en) => SkipAndCount(index, en) == index;

    // Signed convenience overload of the uint version below.
    private static int SkipAndCount(int index, IEnumerator<TSource> en)
    {
        Debug.Assert(index >= 0);
        return (int)SkipAndCount((uint)index, en);
    }

    // Advances the enumerator up to `index` times and returns how many advances succeeded.
    private static uint SkipAndCount(uint index, IEnumerator<TSource> en)
    {
        Debug.Assert(en != null);

        for (uint i = 0; i < index; i++)
        {
            if (!en.MoveNext())
            {
                return i;
            }
        }

        return index;
    }
}