目录
1、源码及框架分析
2、模拟实现map和set
2.1 复用的红黑树框架及Insert
2.2 iterator的实现
2.2.1 iterator的核心源码
2.2.2 iterator的实现思路
2.3 map支持[ ]
2.4 map和set的代码实现
2.4.1 MyMap.h
2.4.2 MySet.h
2.4.3 RBTree.h
2.4.4 Test.cpp
1、源码及框架分析
SGI-STL30版本源代码,map和set的源代码在map/set/stl_map.h/stl_set.h/stl_tree.h等几个头文件中。 map和set的实现结构框架核心部分截取出来如下:
// set
#ifndef __SGI_STL_INTERNAL_TREE_H
#include <stl_tree.h>
#endif
#include <stl_set.h>
#include <stl_multiset.h>
// map
#ifndef __SGI_STL_INTERNAL_TREE_H
#include <stl_tree.h>
#endif
#include <stl_map.h>
#include <stl_multimap.h>
// stl_set.h
// Excerpt of the SGI STL `set` class template: only the key/value typedefs
// and the underlying tree member are shown.
template <class Key, class Compare = less<Key>, class Alloc = alloc>
class set {
public:
// typedefs:
// For a set the stored value is the key itself, so both aliases name Key.
typedef Key key_type;
typedef Key value_type;
private:
// The tree is instantiated with identity<value_type> as its KeyOfValue argument.
// NOTE(review): key_compare is not declared in this excerpt - presumably a typedef of Compare.
typedef rb_tree<key_type, value_type,
identity<value_type>, key_compare, Alloc> rep_type;
rep_type t; // red-black tree representing set
};
// stl_map.h
// Excerpt of the SGI STL `map` class template: only the key/mapped/value
// typedefs and the underlying tree member are shown.
template <class Key, class T, class Compare = less<Key>,
class Alloc = alloc>
class map {
public:
// typedefs:
typedef Key key_type;
typedef T mapped_type;
// The stored element is a pair whose key half is const.
typedef pair<const Key, T> value_type;
private:
// The tree is instantiated with select1st<value_type> as its KeyOfValue argument.
// NOTE(review): key_compare is not declared in this excerpt - presumably a typedef of Compare.
typedef rb_tree<key_type, value_type,
select1st<value_type>, key_compare, Alloc> rep_type;
rep_type t; // red-black tree representing map
};
// stl_tree.h
// Value-independent part of a tree node: a colour plus parent/left/right links.
struct __rb_tree_node_base {
typedef __rb_tree_color_type color_type;
typedef __rb_tree_node_base* base_ptr;
color_type color;
base_ptr parent;
base_ptr left;
base_ptr right;
};
// stl_tree.h
// Excerpt of the SGI STL red-black tree class template that both set and map
// instantiate. Only a few typedefs, three member declarations and the two data
// members are shown.
template <class Key, class Value, class KeyOfValue, class Compare,
class Alloc = alloc>
class rb_tree {
protected:
typedef void* void_pointer;
typedef __rb_tree_node_base* base_ptr;
typedef __rb_tree_node<Value> rb_tree_node;
typedef rb_tree_node* link_type;
typedef Key key_type;
typedef Value value_type;
public:
// insert
// Takes a whole value_type and returns an iterator paired with a bool.
// NOTE(review): `iterator` and `size_type` are not declared in this excerpt.
pair<iterator, bool> insert_unique(const value_type& x);
// erase and find
// Both take a key_type only, not a full value_type.
size_type erase(const key_type& x);
iterator find(const key_type& x);
protected:
size_type node_count; // keeps track of size of tree
link_type header;
};
// Full tree node: inherits the colour and links from __rb_tree_node_base and
// adds the stored value.
template <class Value>
struct __rb_tree_node : public __rb_tree_node_base {
typedef __rb_tree_node<Value>* link_type;
Value value_field;
};
template <class Key, class Value, class KeyOfValue, class Compare,class Alloc = alloc>
删除和查找使用Key,插入使用Value;KeyOfValue是一个仿函数,用于从Value中取出Key(set传的是identity,map传的是select1st)。
2、模拟实现map和set
2.1 复用的红黑树框架及Insert
1. 这里相比源码调整一下,key参数就用K,value参数就用V,红黑树中的数据类型,我们使用T。
2. 源码中pair的<比较会同时比较key和value,但是红黑树只需要比较key,所以MyMap和MySet各自实现了一个从T中取出key的仿函数(KfromT),红黑树用取出的key来比较。MySet中的T本身就是key,实现这个仿函数只是为了和MyMap共用同一个红黑树模板。
3. const保证了不能修改key。
// MyMap's tree member: the stored element is pair<const K, V>, keyed through MapKfromT.
RBTree<K, pair<const K, V>, MapKfromT> _t;
// MySet's tree member: the stored element is const K, keyed through SetKfromT.
RBTree<K, const K, SetKfromT> _t;
// Shape of the shared tree template: K is the key type, T the stored element
// type, and KfromT the functor type used to obtain a K from a T.
template<class K, class T, class KfromT>
class RBTree{};
// 源码中 pair 支持的 < 重载
//template <class T1, class T2>
//bool operator<(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs) {
// return lhs.first < rhs.first || (!(rhs.first < lhs.first) && lhs.second < rhs.second);
//}
// Mymap.h
namespace Lzc
{
    // Ordered key/value container built on RBTree.
    // Elements are stored as std::pair<const K, V>; the const key prevents
    // callers from modifying a key that is already placed in the tree.
    template<class K, class V>
    class MyMap
    {
        // Key extractor handed to RBTree: returns the key half of a stored pair.
        // The call operator is const so the tree can also invoke it through a
        // const functor object.
        struct MapKfromT
        {
            const K& operator()(const std::pair<const K, V>& kv) const
            {
                return kv.first;
            }
        };
    public:
        // Forwards the pair to the tree and returns the tree's Insert result.
        bool insert(const std::pair<const K, V>& kv)
        {
            return _t.Insert(kv);
        }
    private:
        RBTree<K, std::pair<const K, V>, MapKfromT> _t;
    };
}
// Myset.h
namespace Lzc
{
    // Ordered set of keys built on RBTree.
    // Elements are stored as const K so a key cannot be modified once inserted.
    template<class K>
    class MySet
    {
        // Key extractor handed to RBTree: the stored element is its own key.
        // The call operator is const so the tree can also invoke it through a
        // const functor object.
        struct SetKfromT
        {
            const K& operator()(const K& k) const
            {
                return k;
            }
        };
    public:
        // Forwards the key to the tree and returns the tree's Insert result.
        bool insert(const K& k)
        {
            return _t.Insert(k);
        }
    private:
        RBTree<K, const K, SetKfromT> _t;
    };
}
// RBTree.h
namespace Lzc
{
    // Node colour used by the red-black balancing rules.
    enum Color
    {
        RED,
        BLACK
    };

    // A single tree node: the stored value, three links and a colour.
    template<class T>
    struct RBTreeNode
    {
        T _data;
        RBTreeNode<T>* _left;
        RBTreeNode<T>* _right;
        RBTreeNode<T>* _parent;
        Color _col;

        // A freshly created node is RED and not linked to anything.
        RBTreeNode(const T& data)
            :_data(data)
            , _left(nullptr)
            , _right(nullptr)
            , _parent(nullptr)
            , _col(RED)
        { }
    };

    // Red-black tree storing elements of type T, ordered by the key of type K
    // that the functor KfromT obtains from each element. Keys are compared
    // with both operator< and operator>. Duplicate keys are rejected.
    //
    // The tree owns its nodes. Copy construction and copy assignment make a
    // deep copy, so two trees never share (and never double-delete) nodes.
    template<class K, class T, class KfromT>
    class RBTree
    {
        typedef RBTreeNode<T> Node;
    public:
        KfromT KfT;

        RBTree() = default;

        // Deep-copies every node of `other`, preserving shape and colours.
        RBTree(const RBTree& other)
            : KfT(other.KfT)
            , _root(Copy(other._root, nullptr))
        { }

        // Replaces the contents with a deep copy of `other`.
        RBTree& operator=(const RBTree& other)
        {
            if (this != &other)
            {
                // Build the copy first so *this is untouched if copying throws.
                Node* clone = Copy(other._root, nullptr);
                Destroy(_root);
                _root = clone;
                KfT = other.KfT;
            }
            return *this;
        }

        // Inserts `data` unless an element with the same key already exists.
        // Returns true when a node was added, false when the key was present.
        bool Insert(const T& data)
        {
            // Empty tree: the new node becomes the (black) root.
            if (_root == nullptr)
            {
                _root = new Node(data);
                _root->_col = BLACK;
                return true;
            }

            // Ordinary binary-search descent to find the attachment point.
            Node* parent = nullptr;
            Node* cur = _root;
            while (cur)
            {
                if (KfT(data) > KfT(cur->_data))
                {
                    parent = cur;
                    cur = cur->_right;
                }
                else if (KfT(data) < KfT(cur->_data))
                {
                    parent = cur;
                    cur = cur->_left;
                }
                else
                {
                    return false;
                }
            }

            cur = new Node(data);
            if (KfT(data) > KfT(parent->_data))
                parent->_right = cur;
            else
                parent->_left = cur;
            cur->_parent = parent;

            // Repair red-red violations. A red parent is never the root
            // (the root is kept black), so grandfather is non-null here.
            while (parent && parent->_col == RED)
            {
                Node* grandfather = parent->_parent;
                Node* uncle;
                if (parent == grandfather->_left)
                {
                    //     g
                    //   p   u
                    uncle = grandfather->_right;
                    if (uncle && uncle->_col == RED)
                    {
                        // Red uncle: recolour and continue from the grandfather.
                        parent->_col = uncle->_col = BLACK;
                        grandfather->_col = RED;
                        cur = grandfather;
                        parent = cur->_parent;
                    }
                    else
                    {
                        // Missing or black uncle: one or two rotations finish the repair.
                        if (cur == parent->_left)
                        {
                            RotateR(grandfather);
                            parent->_col = BLACK;
                            grandfather->_col = RED;
                        }
                        else
                        {
                            RotateL(parent);
                            RotateR(grandfather);
                            cur->_col = BLACK;
                            grandfather->_col = RED;
                        }
                        break;
                    }
                }
                else
                {
                    //     g
                    //   u   p
                    uncle = grandfather->_left;
                    if (uncle && uncle->_col == RED)
                    {
                        parent->_col = uncle->_col = BLACK;
                        grandfather->_col = RED;
                        cur = grandfather;
                        parent = cur->_parent;
                    }
                    else
                    {
                        if (cur == parent->_right)
                        {
                            RotateL(grandfather);
                            parent->_col = BLACK;
                            grandfather->_col = RED;
                        }
                        else
                        {
                            RotateR(parent);
                            RotateL(grandfather);
                            cur->_col = BLACK;
                            grandfather->_col = RED;
                        }
                        break;
                    }
                }
            }

            // Recolouring may have climbed to the root and painted it red.
            if (parent == nullptr)
                _root->_col = BLACK;
            return true;
        }

        // Right rotation around `parent`: its left child takes its place.
        void RotateR(Node* parent)
        {
            Node* pParent = parent->_parent;
            Node* subL = parent->_left;
            Node* subLR = subL->_right;

            parent->_left = subLR;
            if (subLR)
                subLR->_parent = parent;

            subL->_right = parent;
            parent->_parent = subL;

            subL->_parent = pParent;
            if (pParent == nullptr) // parent was the root
            {
                _root = subL;
            }
            else
            {
                if (pParent->_left == parent)
                    pParent->_left = subL;
                else
                    pParent->_right = subL;
            }
        }

        // Left rotation around `parent`: its right child takes its place.
        void RotateL(Node* parent)
        {
            Node* pParent = parent->_parent;
            Node* subR = parent->_right;
            Node* subRL = subR->_left;

            parent->_right = subRL;
            if (subRL)
                subRL->_parent = parent;

            subR->_left = parent;
            parent->_parent = subR;

            subR->_parent = pParent;
            if (pParent == nullptr) // parent was the root
                _root = subR;
            else
            {
                if (pParent->_left == parent)
                    pParent->_left = subR;
                else
                    pParent->_right = subR;
            }
        }

        // Returns the node whose key equals `key`, or nullptr when absent.
        Node* Find(const K& key)
        {
            Node* cur = _root;
            while (cur)
            {
                if (key > KfT(cur->_data))
                    cur = cur->_right;
                else if (key < KfT(cur->_data))
                    cur = cur->_left;
                else
                    return cur;
            }
            return nullptr;
        }

        ~RBTree()
        {
            Destroy(_root);
            _root = nullptr;
        }

        // Deletes the subtree rooted at `root` in post-order.
        void Destroy(Node* root)
        {
            if (root == nullptr)
                return;
            Destroy(root->_left);
            Destroy(root->_right);
            delete root;
        }

    private:
        // Recursively clones the subtree at `src`, attaching it under `parent`.
        static Node* Copy(const Node* src, Node* parent)
        {
            if (src == nullptr)
                return nullptr;
            Node* dst = new Node(src->_data);
            dst->_col = src->_col;
            dst->_parent = parent;
            dst->_left = Copy(src->_left, dst);
            dst->_right = Copy(src->_right, dst);
            return dst;
        }

        Node* _root = nullptr;
    };
}
2.2 iterator的实现
2.2.1 iterator的核心源码
// Node colours are encoded as a bool: red is false, black is true.
typedef bool __rb_tree_color_type;
const __rb_tree_color_type __rb_tree_red = false;
const __rb_tree_color_type __rb_tree_black = true;
// Value-independent iterator core: holds a node pointer and knows how to step
// it forwards (increment) and backwards (decrement) through the tree links.
struct __rb_tree_base_iterator {
typedef __rb_tree_node_base::base_ptr base_ptr;
base_ptr node;
// Moves `node` to its in-order successor.
void increment() {
if (node->right != 0) {
// There is a right subtree: the successor is its leftmost node.
node = node->right;
while (node->left != 0)
node = node->left;
} else {
// No right subtree: climb while we are our parent's right child.
base_ptr y = node->parent;
while (node == y->right) {
node = y;
y = y->parent;
}
// NOTE(review): this guard presumably handles the header/end sentinel
// (see rb_tree::header) - confirm against the full stl_tree.h.
if (node->right != y)
node = y;
}
}
// Moves `node` to its in-order predecessor.
void decrement() {
// A red node that is its own grandparent.
// NOTE(review): presumably this identifies the header (end()) node, whose
// right link then leads to the last element - confirm against stl_tree.h.
if (node->color == __rb_tree_red && node->parent->parent == node) {
node = node->right;
} else if (node->left != 0) {
// There is a left subtree: the predecessor is its rightmost node.
base_ptr y = node->left;
while (y->right != 0)
y = y->right;
node = y;
} else {
// No left subtree: climb while we are our parent's left child.
base_ptr y = node->parent;
while (node == y->left) {
node = y;
y = y->parent;
}
node = y;
}
}
};
// Typed iterator layered on __rb_tree_base_iterator. Ref and Ptr select the
// reference and pointer types returned by operator* and operator->.
// NOTE(review): this is a trimmed excerpt - `self` and `link_type` are used
// below but not declared in the lines shown, and the two comparison operators
// appear before the closing brace here.
template <class Value, class Ref, class Ptr>
struct __rb_tree_iterator : public __rb_tree_base_iterator {
typedef Value value_type;
typedef Ref reference;
typedef Ptr pointer;
typedef __rb_tree_iterator<Value, Value&, Value*> iterator;
__rb_tree_iterator() {}
__rb_tree_iterator(link_type x) { node = x; }
// Accepts the mutable `iterator` instantiation and copies its node pointer.
__rb_tree_iterator(const iterator& it) { node = it.node; }
// Casts the base node pointer to link_type to reach value_field.
reference operator*() const { return link_type(node)->value_field; }
#ifndef __SGI_STL_NO_ARROW_OPERATOR
pointer operator->() const { return &(operator*()); }
#endif /* __SGI_STL_NO_ARROW_OPERATOR */
// Pre-increment: delegates to the base increment().
self& operator++() {
increment();
return *this;
}
// Pre-decrement: delegates to the base decrement().
self& operator--() {
decrement();
return *this;
}
// Two iterators compare equal exactly when they hold the same node pointer.
inline bool operator==(const __rb_tree_base_iterator& x, const __rb_tree_base_iterator& y) {
return x.node == y.node;
}
inline bool operator!=(const __rb_tree_base_iterator& x, const __rb_tree_base_iterator& y) {
return x.node != y.node;
}
};
2.2.2 iterator的实现思路
1. 整体思路与list的iterator一致,封装节点的指针,迭代器类模板多传Ref和Ptr两个参数,一份模板实现iterator和const_iterator。
2. 重点是operator++和operator--的实现。operator++走中序遍历,左中右,
当左为空,表示左访问完了,访问中(其实只能访问中,给的节点就是访问完的中节点),
如果右不为空,在右子树中进行,左中右,访问右子树的最左节点,
如果右为空,说明以当前节点为根的子树已经访问完了:沿着父指针向上走,只要当前节点是父亲的右孩子,就说明父亲所在的子树也访问完了,继续向上;直到当前节点是父亲的左孩子,此时这个父亲的左子树访问完了,下一个要访问的就是这个父亲(祖先)。如果一直走到根都没有找到这样的父亲,说明整棵树已经遍历完,迭代器变为end()。
然后更新迭代器中的节点指针,返回*this。
operator--就是走右中左,基本相同。
3. begin和end。begin就给最左节点,end给nullptr,但是,--end()呢?
所以给迭代器类模板增加一个成员变量_root(红黑树的根节点),这样--end()就可以从_root出发找到最右节点。
2.3 map支持[ ]
map要支持[ ]主要需要修改insert返回值,
修改RBTree中Insert的返回值,变为pair<Iterator,bool> Insert(const T& data)。
插入失败(key已存在)时,operator[]返回已有的key对应的value的引用。
插入成功时,operator[]返回新插入的key对应的value(默认构造的值)的引用。
2.4 map和set的代码实现
2.4.1 MyMap.h
#pragma once
#include "RBTree.h"
namespace Lzc
{
template<class K, class V>
class MyMap
{
struct MapKfromT
{
const K& operator()(const pair<const K, V>& kv)
{
return kv.first;
}
};
public:
typedef typename RBTree<K, pair<const K, V>, MapKfromT>::Iterator iterator;
typedef typename RBTree<K, pair<const K, V>, MapKfromT>::ConstIterator const_iterator;
pair<iterator, bool> insert(const pair<const K, V>& kv)
{
return _t.Insert(kv);
}
V& operator[](const K& k)
{
iterator ret = _t.Insert({ k, V() }).first;
return ret->second;
}
iterator begin()
{
return _t.Begin();
}
iterator end()
{
return _t.End();
}
const_iterator begin() const
{
return _t.Begin();
}
const_iterator end() const
{
return _t.End();
}
private:
RBTree<K, pair<const K, V>, MapKfromT> _t;
};
}
2.4.2 MySet.h
#pragma once
#include "RBTree.h"
namespace Lzc
{
template<class K>
class MySet
{
struct SetKfromT
{
const K& operator()(const K& k)
{
return k;
}
};
public:
typedef typename RBTree<K, const K, SetKfromT>::Iterator iterator;
typedef typename RBTree<K, const K, SetKfromT>::ConstIterator const_iterator;
pair<iterator, bool> insert(const K& k)
{
return _t.Insert(k);
}
iterator begin()
{
return _t.Begin();
}
iterator end()
{
return _t.End();
}
const_iterator begin() const
{
return _t.Begin();
}
const_iterator end() const
{
return _t.End();
}
private:
RBTree<K, const K, SetKfromT> _t;
};
}
2.4.3 RBTree.h
#pragma once
#include <assert.h>
#include <iostream>
#include <utility>
using namespace std;
namespace Lzc
{
    // Node colour used by the red-black balancing rules.
    enum Color
    {
        RED,
        BLACK
    };

    // A single tree node: the stored value, three links and a colour.
    template<class T>
    struct RBTreeNode
    {
        T _data;
        RBTreeNode<T>* _left;
        RBTreeNode<T>* _right;
        RBTreeNode<T>* _parent;
        Color _col;

        // A freshly created node is RED and not linked to anything.
        RBTreeNode(const T& data)
            :_data(data)
            , _left(nullptr)
            , _right(nullptr)
            , _parent(nullptr)
            , _col(RED)
        { }
    };

    // Bidirectional in-order iterator over RBTreeNode<T>.
    // Ref/Ptr choose the types returned by operator* / operator->, so the same
    // template serves as both the mutable and the const iterator.
    // end() is represented by _node == nullptr.
    //
    // NOTE: _root is a snapshot of the tree's root taken when the iterator was
    // created. It is only used to step back from end(); if an Insert rotates
    // the root after an end() iterator was obtained, decrementing that older
    // iterator starts from the stale root.
    template<class T, class Ref, class Ptr>
    struct RBTreeIterator
    {
        typedef RBTreeNode<T> Node;
        typedef RBTreeIterator<T, Ref, Ptr> Self;

        Node* _node;
        Node* _root;

        RBTreeIterator(Node* node, Node* root)
            :_node(node)
            , _root(root)
        {}

        // Advances to the in-order successor; becomes end() after the last node.
        // Incrementing end() leaves the iterator at end().
        Self& operator++()
        {
            if (_node == nullptr)
                return *this;

            if (_node->_right)
            {
                // Successor is the leftmost node of the right subtree.
                Node* cur = _node->_right;
                while (cur->_left)
                {
                    cur = cur->_left;
                }
                _node = cur;
            }
            else
            {
                // Climb until we arrive from a left child; that parent is next.
                Node* cur = _node;
                Node* parent = cur->_parent;
                while (parent && cur == parent->_right)
                {
                    cur = parent;
                    parent = cur->_parent;
                }
                _node = parent;
            }
            return *this;
        }

        // Steps back to the in-order predecessor.
        // From end() this moves to the rightmost node reachable from _root;
        // on an empty tree the iterator stays at end().
        Self& operator--()
        {
            if (_node == nullptr)
            {
                if (_root == nullptr)
                    return *this;

                Node* MostRight = _root;
                while (MostRight->_right)
                {
                    MostRight = MostRight->_right;
                }
                _node = MostRight;
            }
            else if (_node->_left)
            {
                // Predecessor is the rightmost node of the left subtree.
                Node* cur = _node->_left;
                while (cur->_right)
                {
                    cur = cur->_right;
                }
                _node = cur;
            }
            else
            {
                // Climb until we arrive from a right child; that parent is next.
                Node* cur = _node;
                Node* parent = cur->_parent;
                while (parent && cur == parent->_left)
                {
                    cur = parent;
                    parent = cur->_parent;
                }
                _node = parent;
            }
            return *this;
        }

        Ref operator*() const
        {
            return _node->_data;
        }

        Ptr operator->() const
        {
            return &(_node->_data);
        }

        // Iterators compare by node pointer only; _root is not considered.
        bool operator!=(const Self& s) const
        {
            return _node != s._node;
        }

        bool operator==(const Self& s) const
        {
            return _node == s._node;
        }
    };

    // Red-black tree storing elements of type T, ordered by the key of type K
    // that the functor KfromT obtains from each element. Keys are compared
    // with both operator< and operator>. Duplicate keys are rejected.
    //
    // The tree owns its nodes. Copy construction and copy assignment make a
    // deep copy, so two trees never share (and never double-delete) nodes.
    template<class K, class T, class KfromT>
    class RBTree
    {
        typedef RBTreeNode<T> Node;
    public:
        typedef RBTreeIterator<T, T&, T*> Iterator;
        typedef RBTreeIterator<T, const T&, const T*> ConstIterator;

        // Iterator to the leftmost (smallest-key) node, or End() when empty.
        Iterator Begin()
        {
            Node* cur = _root;
            while (cur && cur->_left)
            {
                cur = cur->_left;
            }
            return { cur,_root };
        }

        Iterator End()
        {
            return { nullptr,_root };
        }

        ConstIterator Begin() const
        {
            Node* cur = _root;
            while (cur && cur->_left)
            {
                cur = cur->_left;
            }
            return { cur,_root };
        }

        ConstIterator End() const
        {
            return { nullptr,_root };
        }

        KfromT KfT;

        RBTree() = default;

        // Deep-copies every node of `other`, preserving shape and colours.
        RBTree(const RBTree& other)
            : KfT(other.KfT)
            , _root(Copy(other._root, nullptr))
        { }

        // Replaces the contents with a deep copy of `other`.
        RBTree& operator=(const RBTree& other)
        {
            if (this != &other)
            {
                // Build the copy first so *this is untouched if copying throws.
                Node* clone = Copy(other._root, nullptr);
                Destroy(_root);
                _root = clone;
                KfT = other.KfT;
            }
            return *this;
        }

        // Inserts `data` unless an element with the same key already exists.
        // Returns {iterator to the element with that key, true} when a node
        // was added, or {iterator to the existing element, false} otherwise.
        std::pair<Iterator, bool> Insert(const T& data)
        {
            // Empty tree: the new node becomes the (black) root.
            if (_root == nullptr)
            {
                _root = new Node(data);
                _root->_col = BLACK;
                return { Iterator(_root,_root),true };
            }

            // Ordinary binary-search descent to find the attachment point.
            Node* parent = nullptr;
            Node* cur = _root;
            while (cur)
            {
                if (KfT(data) > KfT(cur->_data))
                {
                    parent = cur;
                    cur = cur->_right;
                }
                else if (KfT(data) < KfT(cur->_data))
                {
                    parent = cur;
                    cur = cur->_left;
                }
                else
                {
                    return { Iterator(cur,_root),false };
                }
            }

            cur = new Node(data);
            Node* newnode = cur; // cur may be moved up the tree during rebalancing
            if (KfT(data) > KfT(parent->_data))
                parent->_right = cur;
            else
                parent->_left = cur;
            cur->_parent = parent;

            // Repair red-red violations. A red parent is never the root
            // (the root is kept black), so grandfather is non-null here.
            while (parent && parent->_col == RED)
            {
                Node* grandfather = parent->_parent;
                Node* uncle;
                if (parent == grandfather->_left)
                {
                    //     g
                    //   p   u
                    uncle = grandfather->_right;
                    if (uncle && uncle->_col == RED)
                    {
                        // Red uncle: recolour and continue from the grandfather.
                        parent->_col = uncle->_col = BLACK;
                        grandfather->_col = RED;
                        cur = grandfather;
                        parent = cur->_parent;
                    }
                    else
                    {
                        // Missing or black uncle: one or two rotations finish the repair.
                        if (cur == parent->_left)
                        {
                            RotateR(grandfather);
                            parent->_col = BLACK;
                            grandfather->_col = RED;
                        }
                        else
                        {
                            RotateL(parent);
                            RotateR(grandfather);
                            cur->_col = BLACK;
                            grandfather->_col = RED;
                        }
                        break;
                    }
                }
                else
                {
                    //     g
                    //   u   p
                    uncle = grandfather->_left;
                    if (uncle && uncle->_col == RED)
                    {
                        parent->_col = uncle->_col = BLACK;
                        grandfather->_col = RED;
                        cur = grandfather;
                        parent = cur->_parent;
                    }
                    else
                    {
                        if (cur == parent->_right)
                        {
                            RotateL(grandfather);
                            parent->_col = BLACK;
                            grandfather->_col = RED;
                        }
                        else
                        {
                            RotateR(parent);
                            RotateL(grandfather);
                            cur->_col = BLACK;
                            grandfather->_col = RED;
                        }
                        break;
                    }
                }
            }

            // Recolouring may have climbed to the root and painted it red.
            if (parent == nullptr)
                _root->_col = BLACK;
            return { Iterator(newnode,_root),true };
        }

        // Right rotation around `parent`: its left child takes its place.
        void RotateR(Node* parent)
        {
            Node* pParent = parent->_parent;
            Node* subL = parent->_left;
            Node* subLR = subL->_right;

            parent->_left = subLR;
            if (subLR)
                subLR->_parent = parent;

            subL->_right = parent;
            parent->_parent = subL;

            subL->_parent = pParent;
            if (pParent == nullptr) // parent was the root
            {
                _root = subL;
            }
            else
            {
                if (pParent->_left == parent)
                    pParent->_left = subL;
                else
                    pParent->_right = subL;
            }
        }

        // Left rotation around `parent`: its right child takes its place.
        void RotateL(Node* parent)
        {
            Node* pParent = parent->_parent;
            Node* subR = parent->_right;
            Node* subRL = subR->_left;

            parent->_right = subRL;
            if (subRL)
                subRL->_parent = parent;

            subR->_left = parent;
            parent->_parent = subR;

            subR->_parent = pParent;
            if (pParent == nullptr) // parent was the root
                _root = subR;
            else
            {
                if (pParent->_left == parent)
                    pParent->_left = subR;
                else
                    pParent->_right = subR;
            }
        }

        // Returns the node whose key equals `key`, or nullptr when absent.
        Node* Find(const K& key)
        {
            Node* cur = _root;
            while (cur)
            {
                if (key > KfT(cur->_data))
                    cur = cur->_right;
                else if (key < KfT(cur->_data))
                    cur = cur->_left;
                else
                    return cur;
            }
            return nullptr;
        }

        ~RBTree()
        {
            Destroy(_root);
            _root = nullptr;
        }

        // Deletes the subtree rooted at `root` in post-order.
        void Destroy(Node* root)
        {
            if (root == nullptr)
                return;
            Destroy(root->_left);
            Destroy(root->_right);
            delete root;
        }

    private:
        // Recursively clones the subtree at `src`, attaching it under `parent`.
        static Node* Copy(const Node* src, Node* parent)
        {
            if (src == nullptr)
                return nullptr;
            Node* dst = new Node(src->_data);
            dst->_col = src->_col;
            dst->_parent = parent;
            dst->_left = Copy(src->_left, dst);
            dst->_right = Copy(src->_right, dst);
            return dst;
        }

        Node* _root = nullptr;
    };
}
2.4.4 Test.cpp
#include "MySet.h"
#include "MyMap.h"
// 遍历 MyMap
void TestMapIterator()
{
Lzc::MyMap<int, string> map;
map.insert({ 1, "one" });
map.insert({ 2, "two" });
map.insert({ 3, "three" });
cout << "Testing MyMap iterator:" << endl;
for (auto it = map.begin(); it != map.end(); ++it)
{
cout << "Key: " << it->first << ", Value: " << it->second << endl;
}
cout << "-----------------------------" << endl;
}
// 反向遍历 MyMap
void TestMapReverseIterator()
{
Lzc::MyMap<int, string> map;
map.insert({ 5, "five" });
map.insert({ 3, "three" });
map.insert({ 7, "seven" });
auto it = map.end();
--it; // 移动到最后一个元素
cout << "Testing MyMap reverse iterator:" << endl;
while (it != map.begin())
{
cout << "Key: " << it->first << ", Value: " << it->second << endl;
--it;
}
cout << "Key: " << it->first << ", Value: " << it->second << endl; // 打印第一个元素
cout << "-----------------------------" << endl;
}
// 测试 operator[] 和迭代器
void TestMapOperatorBracket()
{
Lzc::MyMap<int, string> map;
map[1] = "one";
map[2] = "two";
map[3] = "three";
cout << "Testing MyMap operator[] and iterator:" << endl;
for (auto it = map.begin(); it != map.end(); ++it)
{
cout << "Key: " << it->first << ", Value: " << it->second << endl;
}
cout << "-----------------------------" << endl;
}
// 遍历 MySet
void TestSetIterator()
{
Lzc::MySet<int> set;
set.insert(10);
set.insert(20);
set.insert(30);
cout << "Testing MySet iterator:" << endl;
for (auto it = set.begin(); it != set.end(); ++it)
{
cout << "Key: " << *it << endl;
}
cout << "-----------------------------" << endl;
}
// 反向遍历 MySet
void TestSetReverseIterator()
{
Lzc::MySet<int> set;
set.insert(50);
set.insert(30);
set.insert(70);
auto it = set.end();
--it; // 移动到最后一个元素
cout << "Testing MySet reverse iterator:" << endl;
while (it != set.begin())
{
cout << "Key: " << *it << endl;
--it;
}
cout << "Key: " << *it << endl; // 打印第一个元素
cout << "-----------------------------" << endl;
}
// 测试空 MySet 的迭代器
void TestEmptySetIterator()
{
Lzc::MySet<int> set;
auto it = set.begin();
auto end = set.end();
cout << "Testing empty MySet iterator:" << endl;
if (it == end)
{
cout << "Set is empty, begin() == end()" << endl;
}
else
{
cout << "Set is not empty" << endl;
}
cout << "-----------------------------" << endl;
}
// Runs every iterator test in a fixed order.
void RunIteratorTests()
{
    typedef void (*TestFn)();
    const TestFn tests[] = {
        TestMapIterator,
        TestMapReverseIterator,
        TestMapOperatorBracket,
        TestSetIterator,
        TestSetReverseIterator,
        TestEmptySetIterator,
    };
    for (TestFn run : tests)
    {
        run();
    }
}
// Program entry point: runs the iterator test suite and exits with status 0.
int main()
{
    RunIteratorTests();
}