概要
本文主要对Take的优化方法进行源码分析,分析Take在配合Select、Where等常用的Linq扩展方法使用时,如何实现优化处理。
本文涉及Select、Where和Take三个方法的源码分析,其中Select、Where、Take更详尽的源码分析,请参考我之前写的文章。
源码分析
我们之前对Take的源码分析,主要是针对数据序列对象直接调用Take方法的情况。本文介绍的Take优化方法,主要是针对多个Linq方法配合使用的情况,例如xx.Where.Select.Take或者xx.Select.Take的情况。
Take的优化方法定义在Take.SpeedOpt.cs文件中,体现在下面的TakeIterator方法中。Take方法对TakeIterator方法的具体调用方式,请参考《C# Linq源码分析之Take方法》。
/// <summary>
/// Picks the cheapest Take implementation for <paramref name="source"/>.
/// </summary>
/// <param name="source">Sequence to take from (asserted non-null).</param>
/// <param name="count">Number of leading elements to keep (asserted positive).</param>
/// <returns>
/// The source's own <c>Take</c> result when it already implements <c>IPartition</c>;
/// otherwise an index-based partition for lists, or a streaming partition for any other
/// sequence. Both fallbacks cover the inclusive index range [0, count - 1].
/// </returns>
private static IEnumerable<TSource> TakeIterator<TSource>(IEnumerable<TSource> source, int count)
{
    Debug.Assert(source != null);
    Debug.Assert(count > 0);

    // The fallback partitions take an inclusive upper bound, not a length.
    int lastIndex = count - 1;

    return source switch
    {
        IPartition<TSource> partition => partition.Take(count),
        IList<TSource> list => new ListPartition<TSource>(list, 0, lastIndex),
        _ => new EnumerablePartition<TSource>(source, 0, lastIndex),
    };
}
该优化方法的基本逻辑如下:
- 如果source序列实现了IPartition接口,则调用IPartition中的Take方法;
- 如果source序列实现了IList接口,则返回ListPartition对象;
- 否则,返回EnumerablePartition对象。
案例分析
Select.Take
该场景模拟现实中将EF中与数据库关联的对象,转换成Web前端需要的对象并分页的情况。
将Student对象中的学生姓名(Name)和数学成绩(MathResult)取出并分页(这里取第一页的3条数据)。Student类的定义请见附录(studentList的初始化代码本文未给出)。
// Project each Student to an anonymous { Name, MathResult } object and keep the first 3.
// Note: C# is case-sensitive — the LINQ operator is `Take`, not `take`.
studentList.Select(x => new
{
    x.Name,
    x.MathResult
}).Take(3).ToList();
执行流程如下:
- 进入Select方法后返回一个SelectListIterator对象,该对象存储了source序列数据和selector;
- 进入Take方法后,因为SelectListIterator实现了IPartition接口,所以会调用它自己的Take方法:使用source、selector以及由count(Take方法的参数)换算出的索引范围[0, count - 1],实例化SelectListPartitionIterator对象;
/// <summary>
/// <c>SelectListIterator.Take</c>: rather than wrapping the projection in another iterator,
/// returns a <c>SelectListPartitionIterator</c> that applies the same selector only to the
/// inclusive index range [0, count - 1] of the source list.
/// </summary>
/// <param name="count">Number of leading elements to keep (asserted positive).</param>
/// <returns>A partition iterator over the first <paramref name="count"/> projected elements.</returns>
public IPartition<TResult> Take(int count)
{
    Debug.Assert(count > 0);
    // A plain Select over a list always starts at index 0, so the range is simply [0, count - 1].
    return new SelectListPartitionIterator<TSource, TResult>(_source, _selector, 0, count - 1);
}
- 进入ToList方法后,根据下面的代码,因为SelectListPartitionIterator对象实现了IPartition接口,而IPartition又继承了IIListProvider接口,所以listProvider.ToList()被执行。
/// <summary>
/// Materializes <paramref name="source"/> into a new <see cref="List{T}"/>.
/// </summary>
/// <remarks>
/// Sequences implementing <c>IIListProvider&lt;TSource&gt;</c> (such as the Select/Take
/// partition iterators) build the list themselves; anything else is copied through the
/// <see cref="List{T}"/> constructor. A null <paramref name="source"/> is reported via
/// <c>ThrowHelper.ThrowArgumentNullException</c> (presumably an ArgumentNullException — confirm).
/// </remarks>
public static List<TSource> ToList<TSource>(this IEnumerable<TSource> source)
{
    if (source is null)
    {
        ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
    }

    // Let specialized iterators produce the list directly (e.g. by index) instead of generic enumeration.
    if (source is IIListProvider<TSource> provider)
    {
        return provider.ToList();
    }

    return new List<TSource>(source);
}
listProvider.ToList()被执行,即SelectListPartitionIterator对象内定义的ToList()方法被执行。该ToList方法可以将Select的投影操作和Take的取值操作同时执行,代码如下:
/// <summary>
/// Builds a list containing the projected elements of this partition's index range.
/// </summary>
/// <returns>
/// An empty list when the source has no elements inside the range; otherwise a list sized
/// exactly to <c>Count</c>, with the selector applied once per element in index order.
/// </returns>
public List<TResult> ToList()
{
    int remaining = Count;
    if (remaining == 0)
    {
        return new List<TResult>();
    }

    var projected = new List<TResult>(remaining);
    // Walk offsets relative to the start of the range; each one maps to one source index.
    for (int offset = 0; offset < remaining; offset++)
    {
        projected.Add(_selector(_source[_minIndexInclusive + offset]));
    }

    return projected;
}
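- 如果计算得到的元素个数Count为0(即源列表的元素个数不大于起始索引_minIndexInclusive,例如源列表为空),直接返回空List,不会执行Select的投影操作。注意:这里的Count为0并不是指Take(0),因为传入这些Take实现的count总是大于0(见Debug.Assert(count > 0));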
- 否则,按照Count(即Take指定的数量与源列表实际可取元素个数中的较小值),依次取出相应索引的元素,对其执行投影操作后加入结果列表;
- 返回操作结果。
这样,Linq只需按索引访问一次源List序列source中所需的元素,就同时完成了投影和取值(截取)两个操作。既不会访问多余的元素,也不会创建中间序列,从而实现了优化。
附录
Student类
/// <summary>
/// Sample entity used by the example: a student together with a math exam result.
/// </summary>
public class Student
{
    /// <summary>Student identifier.</summary>
    public string Id { get; set; }

    /// <summary>Student name.</summary>
    public string Name { get; set; }

    /// <summary>Classroom the student belongs to.</summary>
    public string Classroom { get; set; }

    /// <summary>Math exam result.</summary>
    public int MathResult { get; set; }
}
IIListProvider接口
/// <summary>
/// A sequence that can materialize itself into an array or list and report its element count.
/// </summary>
/// <typeparam name="TElement">Element type of the sequence.</typeparam>
internal interface IIListProvider<TElement> : IEnumerable<TElement>
{
    /// <summary>Produces a new array containing the elements of the sequence.</summary>
    TElement[] ToArray();

    /// <summary>Produces a new list containing the elements of the sequence.</summary>
    List<TElement> ToList();

    /// <summary>
    /// Returns the number of elements. When <paramref name="onlyIfCheap"/> is true,
    /// implementations may skip side-effecting work such as running a selector.
    /// </summary>
    int GetCount(bool onlyIfCheap);
}
IPartition接口
/// <summary>
/// An <c>IIListProvider</c> that can also be sliced by index (<c>Skip</c>/<c>Take</c>)
/// and queried for individual elements without full enumeration.
/// </summary>
/// <typeparam name="TElement">Element type of the partition.</typeparam>
internal interface IPartition<TElement> : IIListProvider<TElement>
{
    /// <summary>Returns a partition with the first <paramref name="count"/> elements removed.</summary>
    IPartition<TElement> Skip(int count);

    /// <summary>Returns a partition limited to the first <paramref name="count"/> elements.</summary>
    IPartition<TElement> Take(int count);

    /// <summary>Gets the element at <paramref name="index"/>; <paramref name="found"/> reports whether it exists.</summary>
    TElement? TryGetElementAt(int index, out bool found);

    /// <summary>Gets the first element; <paramref name="found"/> reports whether one exists.</summary>
    TElement? TryGetFirst(out bool found);

    /// <summary>Gets the last element; <paramref name="found"/> reports whether one exists.</summary>
    TElement? TryGetLast(out bool found);
}
SelectListPartitionIterator类
// Projects, through _selector, the elements of an IList<TSource> whose indices fall in the
// inclusive range [_minIndexInclusive, _maxIndexInclusive]. Produced by Skip/Take on a Select
// over a list, so projection and slicing happen in one pass by index.
// The upper bound may exceed the list's size (e.g. int.MaxValue after Skip); every accessor
// clips against _source.Count at the time it runs.
private sealed class SelectListPartitionIterator<TSource, TResult> : Iterator<TResult>, IPartition<TResult>
{
private readonly IList<TSource> _source;
private readonly Func<TSource, TResult> _selector;
private readonly int _minIndexInclusive;
private readonly int _maxIndexInclusive;
public SelectListPartitionIterator(IList<TSource> source, Func<TSource, TResult> selector, int minIndexInclusive, int maxIndexInclusive)
{
Debug.Assert(source != null);
Debug.Assert(selector != null);
Debug.Assert(minIndexInclusive >= 0);
Debug.Assert(minIndexInclusive <= maxIndexInclusive);
_source = source;
_selector = selector;
_minIndexInclusive = minIndexInclusive;
_maxIndexInclusive = maxIndexInclusive;
}
// Returns a new iterator over the same list, selector and index range.
public override Iterator<TResult> Clone() =>
new SelectListPartitionIterator<TSource, TResult>(_source, _selector, _minIndexInclusive, _maxIndexInclusive);
// Advances to the next element of the range and stores its projection in _current.
// Returns false (after calling Dispose) once the offset leaves the range or passes the
// list's current Count.
public override bool MoveNext()
{
// _state - 1 is the zero-based offset from _minIndexInclusive (not an absolute list index);
// the list index read is _minIndexInclusive + offset.
// Having a separate field for the offset would be more readable. However, we save it
// into _state with a bias to minimize field size of the iterator.
// A negative offset (_state <= 0) becomes a huge value after the uint cast and fails the check.
int index = _state - 1;
if (unchecked((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive) && index < _source.Count - _minIndexInclusive))
{
_current = _selector(_source[_minIndexInclusive + index]);
++_state;
return true;
}
Dispose();
return false;
}
// Select after Select: keep the same list and range, and pass CombineSelectors(_selector, selector)
// as the single selector (CombineSelectors is defined elsewhere; presumably it composes the two).
public override IEnumerable<TResult2> Select<TResult2>(Func<TResult, TResult2> selector) =>
new SelectListPartitionIterator<TSource, TResult2>(_source, CombineSelectors(_selector, selector), _minIndexInclusive, _maxIndexInclusive);
// Moves the lower bound forward by count. If the new lower bound is past the upper bound,
// the result is the shared empty partition. Assuming unchecked arithmetic, an int overflow
// in the addition wraps negative, which the uint cast turns into a huge value, so the
// overflow case also yields the empty partition.
public IPartition<TResult> Skip(int count)
{
Debug.Assert(count > 0);
int minIndex = _minIndexInclusive + count;
return (uint)minIndex > (uint)_maxIndexInclusive ? EmptyPartition<TResult>.Instance : new SelectListPartitionIterator<TSource, TResult>(_source, _selector, minIndex, _maxIndexInclusive);
}
// Narrows the upper bound to _minIndexInclusive + count - 1. If that does not shrink the
// current range (including the wrapped-overflow case under unchecked arithmetic), this
// same instance is returned.
public IPartition<TResult> Take(int count)
{
Debug.Assert(count > 0);
int maxIndex = _minIndexInclusive + count - 1;
return (uint)maxIndex >= (uint)_maxIndexInclusive ? this : new SelectListPartitionIterator<TSource, TResult>(_source, _selector, _minIndexInclusive, maxIndex);
}
// index is relative to the start of the range. found is false when it is negative, beyond
// the range, or beyond the list's current Count.
public TResult? TryGetElementAt(int index, out bool found)
{
if ((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive) && index < _source.Count - _minIndexInclusive)
{
found = true;
return _selector(_source[_minIndexInclusive + index]);
}
found = false;
return default;
}
// Projects the element at _minIndexInclusive, if the list is long enough to contain it.
public TResult? TryGetFirst(out bool found)
{
if (_source.Count > _minIndexInclusive)
{
found = true;
return _selector(_source[_minIndexInclusive]);
}
found = false;
return default;
}
// Projects the last element of the range. The index is clipped to the list's last index
// when the list ends before _maxIndexInclusive.
public TResult? TryGetLast(out bool found)
{
int lastIndex = _source.Count - 1;
if (lastIndex >= _minIndexInclusive)
{
found = true;
return _selector(_source[Math.Min(lastIndex, _maxIndexInclusive)]);
}
found = false;
return default;
}
// Number of elements this partition yields for the list's current size: 0 when the list has
// no element at _minIndexInclusive, otherwise min(lastListIndex, _maxIndexInclusive) - _minIndexInclusive + 1.
private int Count
{
get
{
int count = _source.Count;
if (count <= _minIndexInclusive)
{
return 0;
}
return Math.Min(count - 1, _maxIndexInclusive) - _minIndexInclusive + 1;
}
}
// Materializes the projected range into an exactly-sized array (shared empty array when Count is 0).
public TResult[] ToArray()
{
int count = Count;
if (count == 0)
{
return Array.Empty<TResult>();
}
TResult[] array = new TResult[count];
for (int i = 0, curIdx = _minIndexInclusive; i != array.Length; ++i, ++curIdx)
{
array[i] = _selector(_source[curIdx]);
}
return array;
}
// Materializes the projected range into a list pre-sized to Count, calling the selector once per element.
public List<TResult> ToList()
{
int count = Count;
if (count == 0)
{
return new List<TResult>();
}
List<TResult> list = new List<TResult>(count);
int end = _minIndexInclusive + count;
for (int i = _minIndexInclusive; i != end; ++i)
{
list.Add(_selector(_source[i]));
}
return list;
}
// Returns Count. When onlyIfCheap is false, also runs the selector once per element in the
// range and discards the results.
public int GetCount(bool onlyIfCheap)
{
// In case someone uses Count() to force evaluation of
// the selector, run it provided `onlyIfCheap` is false.
int count = Count;
if (!onlyIfCheap)
{
int end = _minIndexInclusive + count;
for (int i = _minIndexInclusive; i != end; ++i)
{
_selector(_source[i]);
}
}
return count;
}
}
SelectListIterator类 Select.cs
/// <summary>
/// Lazily projects every element of a <see cref="List{T}"/> through a selector, enumerating the
/// list with its struct enumerator.
/// </summary>
private sealed partial class SelectListIterator<TSource, TResult> : Iterator<TResult>
{
    private readonly List<TSource> _source;
    private readonly Func<TSource, TResult> _selector;
    // Mutable struct enumerator kept in a field so MoveNext advances it in place (never copy it to a local).
    private List<TSource>.Enumerator _enumerator;

    public SelectListIterator(List<TSource> source, Func<TSource, TResult> selector)
    {
        Debug.Assert(source != null);
        Debug.Assert(selector != null);
        _source = source;
        _selector = selector;
    }

    private int CountForDebugger => _source.Count;

    /// <summary>Creates a fresh iterator over the same list and selector.</summary>
    public override Iterator<TResult> Clone() => new SelectListIterator<TSource, TResult>(_source, _selector);

    /// <summary>
    /// Advances to the next list element and stores its projection in <c>_current</c>.
    /// State 1 acquires the list enumerator and moves to state 2; state 2 advances it;
    /// any other state reports the end of the sequence.
    /// </summary>
    public override bool MoveNext()
    {
        if (_state == 1)
        {
            _enumerator = _source.GetEnumerator();
            _state = 2;
        }

        if (_state != 2)
        {
            return false;
        }

        if (_enumerator.MoveNext())
        {
            _current = _selector(_enumerator.Current);
            return true;
        }

        Dispose();
        return false;
    }

    /// <summary>Fuses a follow-up Select into this iterator via <c>CombineSelectors</c>.</summary>
    public override IEnumerable<TResult2> Select<TResult2>(Func<TResult, TResult2> selector) =>
        new SelectListIterator<TSource, TResult2>(_source, CombineSelectors(_selector, selector));
}
Select.SpeedOpt.cs
/// <summary>
/// Partition fast paths for <c>SelectListIterator</c>. The source is a <see cref="List{T}"/>,
/// so every operation indexes into it directly instead of enumerating.
/// </summary>
private sealed partial class SelectListIterator<TSource, TResult> : IPartition<TResult>
{
    /// <summary>Projects every list element into a new array (shared empty array for an empty list).</summary>
    public TResult[] ToArray()
    {
        int length = _source.Count;
        if (length == 0)
        {
            return Array.Empty<TResult>();
        }

        var projected = new TResult[length];
        for (int position = 0; position != length; position++)
        {
            projected[position] = _selector(_source[position]);
        }

        return projected;
    }

    /// <summary>Projects every list element into a new list pre-sized to the source count.</summary>
    public List<TResult> ToList()
    {
        int length = _source.Count;
        var projected = new List<TResult>(length);
        for (int position = 0; position != length; position++)
        {
            projected.Add(_selector(_source[position]));
        }

        return projected;
    }

    /// <summary>
    /// Returns the source count. When <paramref name="onlyIfCheap"/> is false, also runs the
    /// selector once per element, because Count() may be used to force its side effects.
    /// </summary>
    public int GetCount(bool onlyIfCheap)
    {
        int length = _source.Count;
        if (onlyIfCheap)
        {
            return length;
        }

        for (int position = 0; position != length; position++)
        {
            _selector(_source[position]);
        }

        return length;
    }

    /// <summary>Drops the first <paramref name="count"/> elements; the range stays open-ended.</summary>
    public IPartition<TResult> Skip(int count)
    {
        Debug.Assert(count > 0);
        return new SelectListPartitionIterator<TSource, TResult>(_source, _selector, count, int.MaxValue);
    }

    /// <summary>Keeps the inclusive index range [0, count - 1].</summary>
    public IPartition<TResult> Take(int count)
    {
        Debug.Assert(count > 0);
        return new SelectListPartitionIterator<TSource, TResult>(_source, _selector, 0, count - 1);
    }

    /// <summary>Projects the element at <paramref name="index"/> when it is a valid list index.</summary>
    public TResult? TryGetElementAt(int index, out bool found)
    {
        // The uint cast folds "index >= 0" into the upper-bound check.
        found = unchecked((uint)index < (uint)_source.Count);
        return found ? _selector(_source[index]) : default;
    }

    /// <summary>Projects the first element when the list is non-empty.</summary>
    public TResult? TryGetFirst(out bool found)
    {
        found = _source.Count != 0;
        return found ? _selector(_source[0]) : default;
    }

    /// <summary>Projects the last element when the list is non-empty.</summary>
    public TResult? TryGetLast(out bool found)
    {
        int lastIndex = _source.Count - 1;
        found = lastIndex >= 0;
        return found ? _selector(_source[lastIndex]) : default;
    }
}