一、Michael & Scott 原版伪代码实现
structure pointer_t {ptr: pointer to node_t, count: unsigned integer}
structure node_t {value: data type, next: pointer_t}
structure queue_t {Head: pointer_t, Tail: pointer_t}
initialize(Q: pointer to queue_t)
node = new_node() // Allocate a free node
node->next.ptr = NULL // Make it the only node in the linked list
Q->Head.ptr = Q->Tail.ptr = node // Both Head and Tail point to it
enqueue(Q: pointer to queue_t, value: data type)
E1: node = new_node() // Allocate a new node from the free list
E2: node->value = value // Copy enqueued value into node
E3: node->next.ptr = NULL // Set next pointer of node to NULL
E4: loop // Keep trying until Enqueue is done
E5: tail = Q->Tail // Read Tail.ptr and Tail.count together
E6: next = tail.ptr->next // Read next ptr and count fields together
E7: if tail == Q->Tail // Are tail and next consistent?
// Was Tail pointing to the last node?
E8: if next.ptr == NULL
// Try to link node at the end of the linked list
E9: if CAS(&tail.ptr->next, next, <node, next.count+1>)
E10: break // Enqueue is done. Exit loop
E11: endif
E12: else // Tail was not pointing to the last node
// Try to swing Tail to the next node
E13: CAS(&Q->Tail, tail, <next.ptr, tail.count+1>)
E14: endif
E15: endif
E16: endloop
// Enqueue is done. Try to swing Tail to the inserted node
E17: CAS(&Q->Tail, tail, <node, tail.count+1>)
dequeue(Q: pointer to queue_t, pvalue: pointer to data type): boolean
D1: loop // Keep trying until Dequeue is done
D2: head = Q->Head // Read Head
D3: tail = Q->Tail // Read Tail
D4: next = head.ptr->next // Read Head.ptr->next
D5: if head == Q->Head // Are head, tail, and next consistent?
D6: if head.ptr == tail.ptr // Is queue empty or Tail falling behind?
D7: if next.ptr == NULL // Is queue empty?
D8: return FALSE // Queue is empty, couldn't dequeue
D9: endif
// Tail is falling behind. Try to advance it
D10: CAS(&Q->Tail, tail, <next.ptr, tail.count+1>)
D11: else // No need to deal with Tail
// Read value before CAS
// Otherwise, another dequeue might free the next node
D12: *pvalue = next.ptr->value
// Try to swing Head to the next node
D13: if CAS(&Q->Head, head, <next.ptr, head.count+1>)
D14: break // Dequeue is done. Exit loop
D15: endif
D16: endif
D17: endif
D18: endloop
D19: free(head.ptr) // It is safe now to free the old node
D20: return TRUE // Queue was not empty, dequeue succeeded
二、C++实现
C++ 的实现可以直接对照上面的伪代码来编写;其中用到的一些 atomic 操作,也可以参考我之前写的博客。
/**
 * Unbounded FIFO queue for multiple producers and multiple consumers,
 * following the Michael & Scott non-blocking queue algorithm.
 *
 * The list always contains one sentinel node (the node `head` points to);
 * the first real element is `head->next`. Producers append after `tail`,
 * consumers advance `head` and take the payload of the new sentinel.
 *
 * Memory reclamation: instead of the counted pointers / free list of the
 * original algorithm, every enqueue/dequeue is counted in `active_ops`.
 * A node unlinked by dequeue is deleted immediately only when the calling
 * thread is the sole active operation; otherwise it is parked on `retired`
 * and freed later by an operation that finds itself alone (or by the
 * destructor). Under uninterrupted contention the retired list can therefore
 * grow until a quiet moment occurs.
 *
 * Requirements on T: copy-constructible (enqueue) and copy-assignable
 * (dequeue). The destructor must not run concurrently with any other call.
 */
template<typename T>
class LockFreeQueue {
private:
    // Queue node.
    struct Node {
        std::shared_ptr<T> data;   // payload; set before the node is published
        std::atomic<Node*> next;   // successor in the queue, nullptr for the last node
        Node* retired_next;        // link used only while the node is on the retired list
        Node() : next(nullptr), retired_next(nullptr) {}
    };

    std::atomic<Node*> head;            // current sentinel node
    std::atomic<Node*> tail;            // last node, or the one just before it
    std::atomic<unsigned> active_ops;   // number of enqueue/dequeue calls in progress
    std::atomic<Node*> retired;         // unlinked nodes waiting to be freed

    // Deletes every node of a retired chain that the caller owns exclusively.
    static void free_retired(Node* node) {
        while (node != nullptr) {
            Node* following = node->retired_next;
            delete node;
            node = following;
        }
    }

    // Pushes the chain first..last (linked through retired_next) onto `retired`.
    void push_retired(Node* first, Node* last) {
        last->retired_next = retired.load();
        // On failure the CAS refreshes last->retired_next with the current list head.
        while (!retired.compare_exchange_weak(last->retired_next, first)) {
        }
    }

    // Ends the current operation and disposes of `unlinked` (may be nullptr),
    // a node this thread has just removed from the queue.
    void finish_operation(Node* unlinked) {
        if (active_ops.load() == 1) {
            // We are the only active operation: nobody else can still hold
            // `unlinked`, and later arrivals cannot reach it any more.
            Node* pending = retired.exchange(nullptr);
            if (active_ops.fetch_sub(1) == 1) {
                // Nobody entered in the meantime, so the claimed nodes are unreferenced.
                free_retired(pending);
            } else if (pending != nullptr) {
                // Another operation started and may reference a claimed node: put them back.
                Node* last = pending;
                while (last->retired_next != nullptr) {
                    last = last->retired_next;
                }
                push_retired(pending, last);
            }
            delete unlinked;
        } else {
            // Other operations are running and may still dereference the node.
            if (unlinked != nullptr) {
                push_retired(unlinked, unlinked);
            }
            active_ops.fetch_sub(1);
        }
    }

public:
    // Creates an empty queue consisting of a single sentinel node.
    LockFreeQueue() : head(nullptr), tail(nullptr), active_ops(0), retired(nullptr) {
        Node* sentinel = new Node();
        head.store(sentinel);
        tail.store(sentinel);
    }

    // Frees the retired nodes, the sentinel and every element still queued.
    ~LockFreeQueue() {
        free_retired(retired.exchange(nullptr));
        Node* node = head.load();
        while (node != nullptr) {
            Node* following = node->next.load();
            delete node;
            node = following;
        }
    }

    // Copying is disabled.
    LockFreeQueue(const LockFreeQueue&) = delete;
    LockFreeQueue& operator=(const LockFreeQueue&) = delete;

    /**
     * Appends a copy of `value` to the back of the queue.
     * Allocation happens before the queue is touched, so an exception from
     * make_shared / new leaves the queue unchanged.
     */
    void enqueue(const T& value) {
        std::shared_ptr<T> payload(std::make_shared<T>(value));
        Node* new_node = new Node();
        // The payload must be in place before the node becomes reachable.
        new_node->data.swap(payload);

        active_ops.fetch_add(1);
        Node* old_tail;
        while (true) {
            old_tail = tail.load();
            Node* next = old_tail->next.load();
            if (old_tail != tail.load()) {
                continue;  // tail moved while we were reading; take a fresh snapshot
            }
            if (next == nullptr) {
                // old_tail really is the last node: try to link the new node behind it.
                if (old_tail->next.compare_exchange_weak(next, new_node)) {
                    break;
                }
            } else {
                // tail is lagging behind; help to advance it, then retry.
                tail.compare_exchange_strong(old_tail, next);
            }
        }
        // Swing tail to the inserted node. Failure means another thread already helped.
        tail.compare_exchange_strong(old_tail, new_node);
        finish_operation(nullptr);
    }

    /**
     * Removes the front element and copy-assigns it to `value`.
     * @return true if an element was removed, false if the queue was empty
     *         (in which case `value` is left untouched).
     * If T's copy assignment throws, the element has already been removed.
     */
    bool dequeue(T& value) {
        active_ops.fetch_add(1);
        Node* old_head;
        Node* next;
        while (true) {
            old_head = head.load();
            Node* old_tail = tail.load();
            next = old_head->next.load();
            if (old_head != head.load()) {
                continue;  // head moved while we were reading; take a fresh snapshot
            }
            if (old_head == old_tail) {
                if (next == nullptr) {
                    // Only the sentinel is left: the queue is empty.
                    finish_operation(nullptr);
                    return false;
                }
                // tail is lagging behind; help to advance it, then retry.
                tail.compare_exchange_strong(old_tail, next);
            } else if (head.compare_exchange_weak(old_head, next)) {
                break;  // `next` is now the sentinel and its payload belongs to us
            }
        }
        // Only the thread that installed `next` as head touches its payload,
        // and `next` cannot be freed while this operation is still counted.
        std::shared_ptr<T> payload;
        payload.swap(next->data);
        finish_operation(old_head);
        value = *payload;
        return true;
    }
};
三、测试
测试代码如下:
// Demo driver: two producers enqueue 100 integers each (0..99 and 100..199),
// pausing 100 ms between items, while two consumers drain the queue and print
// every value they receive. The consumers stop once all produced items have
// been taken, so every join() returns and the program terminates.
int main() {
    constexpr int kItemsPerProducer = 100;
    constexpr int kTotalItems = 2 * kItemsPerProducer;

    LockFreeQueue<int> queue;
    std::atomic<int> consumed{0};  // number of items dequeued so far, across both consumers

    // Enqueues the integers in [first, last).
    auto produce = [&queue](int first, int last) {
        for (int i = first; i < last; i++) {
            queue.enqueue(i);
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
        }
    };

    // Dequeues and prints items, prefixed with `tag`, until everything was consumed.
    auto consume = [&queue, &consumed](const char* tag) {
        while (consumed.load() < kTotalItems) {
            int i = 0;
            if (queue.dequeue(i)) {
                consumed.fetch_add(1);
                std::cout << tag << i << std::endl;
            } else {
                // Queue is momentarily empty: give the producers a chance to run.
                std::this_thread::yield();
            }
        }
    };

    std::thread t1(produce, 0, kItemsPerProducer);
    std::thread t2(produce, kItemsPerProducer, kTotalItems);
    std::thread t3(consume, "t3:");
    std::thread t4(consume, "t4:");

    t1.join();
    t2.join();
    t3.join();
    t4.join();
}
可以看到,多线程下可以顺利地插入与取出数据