• Rust source analysis: the mpsc queue inside channel


    https://zhuanlan.zhihu.com/p/50176724

    This post continues the earlier discussion of channel upgrades.

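    First, the memory reclamation during the earlier upgrade deserves some attention. Because the Receiver now points to shared::Packet, the temporary new_port has to be destroyed, which means its drop function runs. Here is the Drop implementation: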

    impl<T> Drop for Receiver<T> {
        fn drop(&mut self) {
            match *unsafe { self.inner() } {
                Flavor::Oneshot(ref p) => p.drop_port(),
                Flavor::Stream(ref p) => p.drop_port(),
                Flavor::Shared(ref p) => p.drop_port(),
                Flavor::Sync(ref p) => p.drop_port(),
            }
        }
    }

    Because of the earlier swap, the receiver being dropped holds Flavor::Oneshot, so the Oneshot path is taken:

        pub fn drop_port(&self) {
            match self.state.swap(DISCONNECTED, Ordering::SeqCst) {
                // An empty channel has nothing to do, and a remotely disconnected
                // channel also has nothing to do b/c we're about to run the drop
                // glue
                DISCONNECTED | EMPTY => {}
    
                // There's data on the channel, so make sure we destroy it promptly.
                // This is why not using an arc is a little difficult (need the box
                // to stay valid while we take the data).
                DATA => unsafe { (&mut *self.data.get()).take().unwrap(); },
    
                // We're the only ones that can block on this port
                _ => unreachable!()
            }
        }

    Once again, DISCONNECTED is simply swapped in for DISCONNECTED; nothing else happens.

    The oneshot::Packet, which is no longer needed, is destroyed as well:

    impl<T> Drop for Packet<T> {
        fn drop(&mut self) {
            assert_eq!(self.state.load(Ordering::SeqCst), DISCONNECTED);
        }
    }

    This is only an assertion that the state is DISCONNECTED.

    So both Sender and Receiver now hold Flavor::Shared(Arc<shared::Packet<T>>), while the earlier Flavor::Oneshot(Arc<oneshot::Packet<T>>) and the temporary Sender/Receiver no longer exist.
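    The end result of this teardown is visible through the public API: once the Receiver is gone, send fails and hands the value back, which corresponds to the port_dropped / DISCONNECTED checks at the top of shared::Packet::send. A minimal sketch using only std::sync::mpsc; the helper name send_after_receiver_dropped is made up for illustration:

```rust
use std::sync::mpsc::{self, SendError};

/// Hypothetical helper: clone a Sender, drop the Receiver, then try to send.
/// Returns the value that send() handed back, if the send failed.
fn send_after_receiver_dropped(value: i32) -> Option<i32> {
    let (tx, rx) = mpsc::channel::<i32>();
    // In the implementation analyzed here, cloning a Sender upgrades the
    // channel to the shared flavor.
    let tx2 = tx.clone();
    tx2.send(1).unwrap(); // succeeds: the receiver is still alive
    drop(rx); // runs Receiver's Drop impl
    match tx.send(value) {
        Ok(()) => None,
        // The unsent value is returned to the caller inside SendError.
        Err(SendError(v)) => Some(v),
    }
}
```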

    Concurrent queue

    Next we turn to the underlying data structure, tracing the following functions:

    • Sender::send(&self, t: T)
    • Receiver::recv(&self)
    • Receiver::recv_timeout(&self, timeout: Duration)
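    All three functions can be exercised through the public std::sync::mpsc API. The sketch below (helper names are made up) gives several threads their own cloned Sender, drains the channel with recv, and shows the two error cases of recv_timeout. It only demonstrates observable behavior: as far as I know, since Rust 1.67 std::sync::mpsc is built on code ported from crossbeam-channel, so a current toolchain no longer runs the flavor-based internals analyzed here.

```rust
use std::sync::mpsc::{self, RecvTimeoutError};
use std::thread;
use std::time::Duration;

/// Hypothetical helper: `producers` threads, each owning a cloned Sender,
/// send distinct numbers; the receiving thread collects them with recv().
fn collect_from_producers(producers: usize, per_producer: usize) -> Vec<usize> {
    let (tx, rx) = mpsc::channel();
    let mut handles = Vec::new();
    for p in 0..producers {
        let tx = tx.clone(); // several Senders: multi-producer use
        handles.push(thread::spawn(move || {
            for i in 0..per_producer {
                tx.send(p * per_producer + i).unwrap();
            }
        }));
    }
    drop(tx); // only the clones remain, so recv() fails once all threads finish
    let mut out = Vec::new();
    while let Ok(v) = rx.recv() {
        out.push(v); // recv() blocks while the channel is empty
    }
    for h in handles {
        h.join().unwrap();
    }
    out.sort();
    out
}

/// recv_timeout with a live but idle Sender: the call times out.
fn times_out_when_idle() -> bool {
    let (_tx, rx) = mpsc::channel::<u8>(); // `_tx` keeps the Sender alive
    matches!(rx.recv_timeout(Duration::from_millis(10)), Err(RecvTimeoutError::Timeout))
}

/// recv_timeout after every Sender is gone: disconnection is reported.
fn disconnected_when_no_senders() -> bool {
    let (tx, rx) = mpsc::channel::<u8>();
    drop(tx);
    matches!(rx.recv_timeout(Duration::from_millis(10)), Err(RecvTimeoutError::Disconnected))
}
```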

    Sender::send(&self, t: T):

        pub fn send(&self, t: T) -> Result<(), SendError<T>> {
            let (new_inner, ret) = match *unsafe { self.inner() } {
                Flavor::Oneshot(ref p) => {
                    if !p.sent() {
                        return p.send(t).map_err(SendError);
                    } else {
                        let a = Arc::new(stream::Packet::new());
                        let rx = Receiver::new(Flavor::Stream(a.clone()));
                        match p.upgrade(rx) {
                            oneshot::UpSuccess => {
                                let ret = a.send(t);
                                (a, ret)
                            }
                            oneshot::UpDisconnected => (a, Err(t)),
                            oneshot::UpWoke(token) => {
                                // This send cannot panic because the thread is
                                // asleep (we're looking at it), so the receiver
                                // can't go away.
                                a.send(t).ok().unwrap();
                                token.signal();
                                (a, Ok(()))
                            }
                        }
                    }
                }
                Flavor::Stream(ref p) => return p.send(t).map_err(SendError),
                Flavor::Shared(ref p) => return p.send(t).map_err(SendError),
                Flavor::Sync(..) => unreachable!(),
            };
    
            unsafe {
                let tmp = Sender::new(Flavor::Stream(new_inner));
                mem::swap(self.inner_mut(), tmp.inner_mut());
            }
            ret.map_err(SendError)
        }

    For our case, only one line actually matters:

     Flavor::Shared(ref p) => return p.send(t).map_err(SendError),

    Here p is a reference to Arc<shared::Packet<T>>. Let's continue with p.send(t):

        pub fn send(&self, t: T) -> Result<(), T> {
            // See Port::drop for what's going on
            if self.port_dropped.load(Ordering::SeqCst) { return Err(t) }
    
            if self.cnt.load(Ordering::SeqCst) < DISCONNECTED + FUDGE {
                return Err(t)
            }
    
            self.queue.push(t);
            match self.cnt.fetch_add(1, Ordering::SeqCst) {
                -1 => {
                    self.take_to_wake().signal();
                }
    
                n if n < DISCONNECTED + FUDGE => {
                    // see the comment in 'try' for a shared channel for why this
                    // window of "not disconnected" is ok.
                    self.cnt.store(DISCONNECTED, Ordering::SeqCst);
    
                    if self.sender_drain.fetch_add(1, Ordering::SeqCst) == 0 {
                        loop {
                            // drain the queue, for info on the thread yield see the
                            // discussion in try_recv
                            loop {
                                match self.queue.pop() {
                                    mpsc::Data(..) => {}
                                    mpsc::Empty => break,
                                    mpsc::Inconsistent => thread::yield_now(),
                                }
                            }
    
                            if self.sender_drain.fetch_sub(1, Ordering::SeqCst) == 1 {
                                break
                            }
                        }
                    }
                }
    
                // Can't make any assumptions about this case like in the SPSC case.
                _ => {}
            }
    
            Ok(())
        }

    Here is the shared::Packet data structure together with its initialization:

    const DISCONNECTED: isize = isize::MIN;
    const FUDGE: isize = 1024;
    
    pub struct Packet<T> {
        queue: mpsc::Queue<T>,
        cnt: AtomicIsize, // How many items are on this channel
        steals: UnsafeCell<isize>, // How many times has a port received without blocking?
        to_wake: AtomicUsize, // SignalToken for wake up
    
        channels: AtomicUsize,
    
        port_dropped: AtomicBool,
        sender_drain: AtomicIsize,
    
        select_lock: Mutex<()>,
    }
    
    impl<T> Packet<T> {
        pub fn new() -> Packet<T> {
            Packet {
                queue: mpsc::Queue::new(),
                cnt: AtomicIsize::new(0),
                steals: UnsafeCell::new(0),
                to_wake: AtomicUsize::new(0),
                channels: AtomicUsize::new(2),
                port_dropped: AtomicBool::new(false),
                sender_drain: AtomicIsize::new(0),
                select_lock: Mutex::new(()),
            }
        }
    }

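    From this we can see that:

    • port_dropped records whether the receiving end has been dropped.
    • cnt counts how many items are currently in the channel. It is also compared against DISCONNECTED (strictly, DISCONNECTED + FUDGE) to detect whether the consumer has disconnected.
    • If send finds that the consumer has disconnected, it pops all remaining items itself so that they are dropped; sender_drain ensures that only one sender performs this draining at a time.
    • The actual work is done by self.queue.push(t).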
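    The counter checks in send can be studied apart from the queue. The following is a hypothetical model, not std code: only cnt is kept (as an AtomicIsize), DISCONNECTED and FUDGE are copied from the source above, and SendOutcome is an invented enum that names the branches. A previous value of -1 means the receiver has already subtracted one from cnt before blocking, which is why that branch wakes it.

```rust
use std::sync::atomic::{AtomicIsize, Ordering};

// Constants copied from the shared::Packet source quoted above.
const DISCONNECTED: isize = isize::MIN;
const FUDGE: isize = 1024;

/// Hypothetical names for the branches of shared::Packet::send.
#[derive(Debug, PartialEq)]
enum SendOutcome {
    Rejected,       // receiver already gone: the value is handed back
    WakeReceiver,   // cnt was -1: a blocked receiver must be signalled
    DrainAfterRace, // receiver disconnected while this send was in flight
    Queued,         // the ordinary case
}

/// Models only the counter logic of send; the real queue push,
/// the wake-up and the draining loop are left out.
fn model_send(cnt: &AtomicIsize) -> SendOutcome {
    if cnt.load(Ordering::SeqCst) < DISCONNECTED + FUDGE {
        return SendOutcome::Rejected;
    }
    // (the real code calls self.queue.push(t) here)
    match cnt.fetch_add(1, Ordering::SeqCst) {
        -1 => SendOutcome::WakeReceiver,
        n if n < DISCONNECTED + FUDGE => {
            cnt.store(DISCONNECTED, Ordering::SeqCst);
            SendOutcome::DrainAfterRace
        }
        _ => SendOutcome::Queued,
    }
}
```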

    So how is self.queue implemented? Here is its code, from the file sync/mpsc/mpsc_queue.rs:

    pub struct Queue<T> {
        head: AtomicPtr<Node<T>>,
        tail: UnsafeCell<*mut Node<T>>,
    }
    unsafe impl<T: Send> Send for Queue<T> { }
    unsafe impl<T: Send> Sync for Queue<T> { }
    impl<T> Queue<T> {
        pub fn new() -> Queue<T> {
            let stub = unsafe { Node::new(None) };
            Queue {
                head: AtomicPtr::new(stub),
                tail: UnsafeCell::new(stub),
            }
        }
    
        pub fn push(&self, t: T) {
            unsafe {
                let n = Node::new(Some(t));
                let prev = self.head.swap(n, Ordering::AcqRel);
                (*prev).next.store(n, Ordering::Release);
            }
        }
    
        pub fn pop(&self) -> PopResult<T> {
            unsafe {
                let tail = *self.tail.get();
                let next = (*tail).next.load(Ordering::Acquire);
    
                if !next.is_null() {
                    *self.tail.get() = next;
                    assert!((*tail).value.is_none());
                    assert!((*next).value.is_some());
                    let ret = (*next).value.take().unwrap();
                    let _: Box<Node<T>> = Box::from_raw(tail);
                    return Data(ret);
                }
    
                if self.head.load(Ordering::Acquire) == tail {Empty} else {Inconsistent}
            }
        }
        ............
    }

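    It implements Dmitry Vyukov's non-intrusive MPSC node-based queue algorithm, building a singly linked list shared by multiple producers and a single consumer; readers interested in the details can consult the original description of the algorithm.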

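    The strength of this algorithm:

    • push: highly concurrent and wait-free, costing essentially a single swap (an XCHG instruction). Each push first swaps its new node into head and then links the previous head to it (prev_head.next = new node), which grows the list.

    Its weakness:

    • Non-linearizability: the queue is not linearizable, and a push in progress can block pop. If pop finds head != tail while tail.next has not yet become non-null, it has observed the queue in an inconsistent state; in that case this implementation returns Inconsistent.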
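    The Inconsistent state is easiest to see when the two steps of push are pulled apart. This is a hypothetical single-threaded model, not the std code: nodes live in a Vec, next links are indices, and push is split into push_swap (swap the new node into head) and push_link (set prev.next). Between the two calls, pop sees a null tail.next while head != tail.

```rust
/// The three outcomes pop can report (redefined for this model).
#[derive(Debug, PartialEq)]
enum PopResult<T> {
    Data(T),
    Empty,
    Inconsistent,
}

/// Hypothetical model: nodes are (next, value) pairs stored in a Vec.
struct Model<T> {
    nodes: Vec<(Option<usize>, Option<T>)>,
    head: usize, // most recently pushed node (producers swap this)
    tail: usize, // current stub node (only the consumer moves this)
}

impl<T> Model<T> {
    fn new() -> Self {
        // Start with a single stub node that holds no value.
        Model { nodes: vec![(None, None)], head: 0, tail: 0 }
    }

    /// Step 1 of push: create the node and swap it into `head`.
    fn push_swap(&mut self, t: T) -> (usize, usize) {
        self.nodes.push((None, Some(t)));
        let n = self.nodes.len() - 1;
        let prev = std::mem::replace(&mut self.head, n);
        (prev, n)
    }

    /// Step 2 of push: link the previous head to the new node.
    fn push_link(&mut self, prev: usize, n: usize) {
        self.nodes[prev].0 = Some(n);
    }

    fn pop(&mut self) -> PopResult<T> {
        match self.nodes[self.tail].0 {
            Some(next) => {
                self.tail = next; // `next` becomes the new stub
                PopResult::Data(self.nodes[next].1.take().unwrap())
            }
            None if self.head == self.tail => PopResult::Empty,
            None => PopResult::Inconsistent, // push_swap done, push_link pending
        }
    }
}
```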

    Now look at the Node code:

    struct Node<T> {
        next: AtomicPtr<Node<T>>,
        value: Option<T>,
    }
    impl<T> Node<T> {
        unsafe fn new(v: Option<T>) -> *mut Node<T> {
            Box::into_raw(box Node {
                next: AtomicPtr::new(ptr::null_mut()),
                value: v,
            })
        }
    }

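    Unlike a typical constructor, new returns a raw *mut Node<T>: Box::into_raw hands ownership of the allocation to the caller, who becomes responsible for freeing the Node (pop does this with Box::from_raw).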
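    Putting Queue and Node together, here is a compilable sketch of the same algorithm for current stable Rust. Assumptions of this sketch: Box::new replaces the old unstable `box` syntax, pop is marked unsafe to make the single-consumer requirement explicit, and the Drop impl and the run_producers helper are additions (the quoted code elides the rest of the impl).

```rust
use std::cell::UnsafeCell;
use std::ptr;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::Arc;
use std::thread;

#[derive(Debug, PartialEq)]
pub enum PopResult<T> { Data(T), Empty, Inconsistent }

struct Node<T> {
    next: AtomicPtr<Node<T>>,
    value: Option<T>,
}

impl<T> Node<T> {
    // into_raw leaks the Box; the queue decides when to free the node.
    fn new(v: Option<T>) -> *mut Node<T> {
        Box::into_raw(Box::new(Node { next: AtomicPtr::new(ptr::null_mut()), value: v }))
    }
}

pub struct Queue<T> {
    head: AtomicPtr<Node<T>>,       // producers swap new nodes in here
    tail: UnsafeCell<*mut Node<T>>, // touched only by the single consumer
}

unsafe impl<T: Send> Send for Queue<T> {}
unsafe impl<T: Send> Sync for Queue<T> {}

impl<T> Queue<T> {
    pub fn new() -> Queue<T> {
        let stub = Node::new(None);
        Queue { head: AtomicPtr::new(stub), tail: UnsafeCell::new(stub) }
    }

    /// Any number of threads may push concurrently.
    pub fn push(&self, t: T) {
        let n = Node::new(Some(t));
        let prev = self.head.swap(n, Ordering::AcqRel);
        // Window: head already points at n, but prev.next is still null.
        unsafe { (*prev).next.store(n, Ordering::Release) };
    }

    /// Safety: only one thread may call pop at a time (single consumer).
    pub unsafe fn pop(&self) -> PopResult<T> {
        unsafe {
            let tail = *self.tail.get();
            let next = (*tail).next.load(Ordering::Acquire);
            if !next.is_null() {
                *self.tail.get() = next; // next becomes the new stub
                let ret = (*next).value.take().unwrap();
                drop(Box::from_raw(tail)); // free the old stub
                return PopResult::Data(ret);
            }
            if self.head.load(Ordering::Acquire) == tail {
                PopResult::Empty
            } else {
                PopResult::Inconsistent
            }
        }
    }
}

impl<T> Drop for Queue<T> {
    // Walk from the stub and free every remaining node along with its value.
    fn drop(&mut self) {
        let mut cur = *self.tail.get_mut();
        while !cur.is_null() {
            let next = unsafe { (*cur).next.load(Ordering::Relaxed) };
            drop(unsafe { Box::from_raw(cur) });
            cur = next;
        }
    }
}

/// Hypothetical usage helper: several producer threads, the caller consumes.
fn run_producers(producers: usize, per_producer: usize) -> Vec<usize> {
    let q = Arc::new(Queue::new());
    let handles: Vec<_> = (0..producers)
        .map(|p| {
            let q = Arc::clone(&q);
            thread::spawn(move || for i in 0..per_producer { q.push(p * per_producer + i) })
        })
        .collect();
    let mut out = Vec::new();
    while out.len() < producers * per_producer {
        // Safety: this thread is the only consumer.
        match unsafe { q.pop() } {
            PopResult::Data(v) => out.push(v),
            PopResult::Empty | PopResult::Inconsistent => thread::yield_now(),
        }
    }
    for h in handles { h.join().unwrap(); }
    out.sort();
    out
}
```

    Marking pop as unsafe fn moves the single-consumer obligation onto the caller, whereas the quoted std version relies on the channel code to guarantee it.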

    On the other side, when Receiver.recv() finds no data in the channel it has to wait, so let's look at the relevant code:

        pub fn recv(&self) -> Result<T, RecvError> {
            loop {
                let new_port = match *unsafe { self.inner() } {
                    Flavor::Oneshot(ref p) => {
                        match p.recv(None) {
                            Ok(t) => return Ok(t),
                            Err(oneshot::Disconnected) => return Err(RecvError),
                            Err(oneshot::Upgraded(rx)) => rx,
                            Err(oneshot::Empty) => unreachable!(),
                        }
                    }
                    Flavor::Stream(ref p) => {
                        match p.recv(None) {
                            Ok(t) => return Ok(t),
                            Err(stream::Disconnected) => return Err(RecvError),
                            Err(stream::Upgraded(rx)) => rx,
                            Err(stream::Empty) => unreachable!(),
                        }
                    }
                    Flavor::Shared(ref p) => {
                        match p.recv(None) {
                            Ok(t) => return Ok(t),
                            Err(shared::Disconnected) => return Err(RecvError),
                            Err(shared::Empty) => unreachable!(),
                        }
                    }
                    Flavor::Sync(ref p) => return p.recv(None).map_err(|_| RecvError),
                };
                unsafe {
                    mem::swap(self.inner_mut(), new_port.inner_mut());
                }
            }
        }

    Only this part matters:

        pub fn recv(&self) -> Result<T, RecvError> {
            loop {
                let new_port = match *unsafe { self.inner() } {
                    .........
                    Flavor::Shared(ref p) => {
                        match p.recv(None) {
                            Ok(t) => return Ok(t),
                            Err(shared::Disconnected) => return Err(RecvError),
                            Err(shared::Empty) => unreachable!(),
                        }
                    }
                };
                ...........
            }
        }

    Next comes p.recv(), whose return value determines the result of the call:

        pub fn recv(&self, deadline: Option<Instant>) -> Result<T, Failure> {
            // This code is essentially the exact same as that found in the stream
            // case (see stream.rs)
            match self.try_recv() {
                Err(Empty) => {}
                data => return data,
            }
    
            let (wait_token, signal_token) = blocking::tokens();
            if self.decrement(signal_token) == Installed {
                if let Some(deadline) = deadline {
                    let timed_out = !wait_token.wait_max_until(deadline);
                    if timed_out {
                        self.abort_selection(false);
                    }
                } else {
                    wait_token.wait();
                }
            }
    
            match self.try_recv() {
                data @ Ok(..) => unsafe { *self.steals.get() -= 1; data },
                data => data,
            }
        }

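    The logic is as follows. If the first self.try_recv() returns data, that data is returned immediately. Otherwise the channel is most likely empty, so blocking::tokens() creates the wait/signal token pair the Receiver needs in order to block, and decrement checks once more whether data has arrived before the receiver actually blocks. Here is decrement: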

        fn decrement(&self, token: SignalToken) -> StartResult {
            unsafe {
                assert_eq!(self.to_wake.load(Ordering::SeqCst), 0);
                let ptr = token.cast_to_usize();
                self.to_wake.store(ptr, Ordering::SeqCst);
    
                let steals = ptr::replace(self.steals.get(), 0);
    
                match self.cnt.fetch_sub(1 + steals, Ordering::SeqCst) {
                    DISCONNECTED => { self.cnt.store(DISCONNECTED, Ordering::SeqCst); }
                    n => {
                        assert!(n >= 0);
                        if n - steals <= 0 { return Installed }
                    }
                }
                self.to_wake.store(0, Ordering::SeqCst);
                drop(SignalToken::cast_from_usize(ptr));
                Abort
            }
        }

    As shown above, the pointer of token: SignalToken is stored in to_wake so that a sender can later use it to wake the receiver.

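    decrement then subtracts 1 + steals from self.cnt to decide whether the queue is empty. The reason is that cnt is not decremented once per pop: presumably for performance, the number of items popped is accumulated in the steals field, and cnt is only adjusted once steals grows large enough or when the receiver is about to block on an empty channel. fetch_sub returns the previous value n of cnt, and n - steals is compared with 0: if n - steals <= 0 there is no unconsumed data and decrement returns Installed; otherwise it resets to_wake to 0, drops the SignalToken, and returns Abort. If cnt was DISCONNECTED, it stores DISCONNECTED back and also returns Abort.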
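    The counter arithmetic is easier to follow with concrete numbers. Below is a hypothetical single-threaded model of what decrement does to cnt and steals: plain isize values stand in for the atomic and the UnsafeCell, StartResult is redefined locally, and all SignalToken handling is omitted.

```rust
/// Mirrors the two StartResult outcomes used above (redefined for this model).
#[derive(Debug, PartialEq)]
enum StartResult {
    Installed,
    Abort,
}

const DISCONNECTED: isize = isize::MIN;

/// Hypothetical model of decrement's arithmetic only: `cnt` stands for the
/// shared counter, `steals` for the receiver-private count of pops that
/// have not yet been subtracted from `cnt`.
fn model_decrement(cnt: &mut isize, steals: &mut isize) -> StartResult {
    let s = std::mem::replace(steals, 0); // ptr::replace(self.steals.get(), 0)
    let n = *cnt; // the previous value that fetch_sub would return
    if n == DISCONNECTED {
        // The real code subtracts and then stores DISCONNECTED back.
        return StartResult::Abort;
    }
    assert!(n >= 0);
    *cnt = n - (1 + s);
    // n - s is the number of sent items this receiver has not consumed yet.
    if n - s <= 0 {
        StartResult::Installed
    } else {
        StartResult::Abort
    }
}
```

    With cnt = 0 and steals = 0 the model returns Installed and leaves cnt at -1, which is exactly the value the Sender's fetch_add checks for before calling take_to_wake().signal().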

    First, what happens after Installed:

            if self.decrement(signal_token) == Installed {
                if let Some(deadline) = deadline {
                    let timed_out = !wait_token.wait_max_until(deadline);
                    if timed_out {
                        self.abort_selection(false);
                    }
                } else {
                    wait_token.wait();
                }
            }

    In our case (no deadline) it simply calls wait_token.wait(), whose code is:

    impl WaitToken {
        pub fn wait(self) {
            while !self.inner.woken.load(Ordering::SeqCst) {
                thread::park()
            }
        }
        ...........

    It checks woken first and only then calls park(). This pairs with the Sender's send operation shown earlier:

        pub fn send(&self, t: T) -> Result<(), T> {
            .............
            self.queue.push(t);
            match self.cnt.fetch_add(1, Ordering::SeqCst) {
                -1 => {
                    self.take_to_wake().signal();
                }
          ..........

    Here is the related code:

        fn take_to_wake(&self) -> SignalToken {
            let ptr = self.to_wake.load(Ordering::SeqCst);
            self.to_wake.store(0, Ordering::SeqCst);
            assert!(ptr != 0);
            unsafe { SignalToken::cast_from_usize(ptr) }
        }
    
    impl SignalToken {
        pub fn signal(&self) -> bool {
            let wake = !self.inner.woken.compare_and_swap(false, true, Ordering::SeqCst);
            if wake {
                self.inner.thread.unpark();
            }
            wake
        }
      ....
    }

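    signal sets woken first and then calls unpark(). Because an unpark() issued before the receiver parks makes its next park() return immediately, and wait() rechecks woken in a loop, the waiting Receiver can never sleep forever.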
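    The same handshake can be reproduced with nothing but std's park/unpark. This is a hypothetical re-creation, not std's blocking module: tokens, WaitToken and SignalToken only mirror the names above, and compare_exchange stands in for compare_and_swap, which has been deprecated since Rust 1.50.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread::{self, Thread};

/// Shared state: one `woken` flag plus the thread to unpark.
struct Inner {
    woken: AtomicBool,
    thread: Thread,
}

pub struct WaitToken { inner: Arc<Inner> }
pub struct SignalToken { inner: Arc<Inner> }

/// Hypothetical constructor; must be called on the thread that will wait.
pub fn tokens() -> (WaitToken, SignalToken) {
    let inner = Arc::new(Inner { woken: AtomicBool::new(false), thread: thread::current() });
    (WaitToken { inner: inner.clone() }, SignalToken { inner })
}

impl SignalToken {
    /// Set `woken` first, then unpark. Only the first signal returns true.
    pub fn signal(&self) -> bool {
        let wake = self
            .inner
            .woken
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_ok();
        if wake {
            self.inner.thread.unpark();
        }
        wake
    }
}

impl WaitToken {
    /// Re-check `woken` around every park(): park may return spuriously, and
    /// an unpark that arrived earlier makes the next park return at once.
    pub fn wait(self) {
        while !self.inner.woken.load(Ordering::SeqCst) {
            thread::park();
        }
    }
}

/// Usage: another thread signals; the current thread blocks until it does.
fn wait_for_signal() -> bool {
    let (wait, signal) = tokens();
    let h = thread::spawn(move || signal.signal());
    wait.wait();
    h.join().unwrap()
}
```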

    Now the case where decrement returns Abort:

        pub fn recv(&self, deadline: Option<Instant>) -> Result<T, Failure> {
            match self.try_recv() {
                Err(Empty) => {}
                data => return data,
            }
    
            let (wait_token, signal_token) = blocking::tokens();
            if self.decrement(signal_token) == Installed {
                  .............
            }
    
            match self.try_recv() {
                data @ Ok(..) => unsafe { *self.steals.get() -= 1; data },
                data => data,
            }
        }

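    It simply calls self.try_recv() again. The *self.steals.get() -= 1 comes from try_recv itself: on success it always does steals += 1. The item taken by this second try_recv, however, has already been accounted for in cnt (decrement subtracted 1 + steals from it), so it must not also count as a steal; the -1 cancels that increment: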

        pub fn try_recv(&self) -> Result<T, Failure> {
            let ret = match self.queue.pop() {
                mpsc::Data(t) => Some(t),
                mpsc::Empty => None,
                mpsc::Inconsistent => {
                    let data;
                    loop {
                        thread::yield_now();
                        match self.queue.pop() {
                            mpsc::Data(t) => { data = t; break }
                            mpsc::Empty => panic!("inconsistent => empty"),
                            mpsc::Inconsistent => {}
                        }
                    }
                    Some(data)
                }
            };
            match ret {
                Some(data) => unsafe {
                    if *self.steals.get() > MAX_STEALS {
                        match self.cnt.swap(0, Ordering::SeqCst) {
                            DISCONNECTED => {
                                self.cnt.store(DISCONNECTED, Ordering::SeqCst);
                            }
                            n => {
                                let m = cmp::min(n, *self.steals.get());
                                *self.steals.get() -= m;
                                self.bump(n - m);
                            }
                        }
                        assert!(*self.steals.get() >= 0);
                    }
                    *self.steals.get() += 1;
                    Ok(data)
                },
                None => {
                    match self.cnt.load(Ordering::SeqCst) {
                        n if n != DISCONNECTED => Err(Empty),
                        _ => {
                            match self.queue.pop() {
                                mpsc::Data(t) => Ok(t),
                                mpsc::Empty => Err(Disconnected),
                                // with no senders, an inconsistency is impossible.
                                mpsc::Inconsistent => unreachable!(),
                            }
                        }
                    }
                }
            }
        }

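    As the code shows, if pop() yields data it is returned directly. If pop() returns Empty, ret is None, and as long as the channel is not disconnected try_recv returns Err(Empty), which lets the Receiver go on to block. If pop() returns Inconsistent, the queue is in the transient state left by a push that has swapped head but not yet linked next; the remedy is to call thread::yield_now() and retry pop() until it yields data (an Empty at that point would indicate a bug, hence the panic!).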
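    The retry loop for Inconsistent can be written as a small standalone helper. This sketch mirrors only the first half of try_recv: the helper name pop_or_wait_for_push is made up, PopResult is redefined locally, and pop is passed in as a closure so the behavior can be tested with a scripted sequence of results instead of real threads.

```rust
use std::thread;

/// The three outcomes of mpsc_queue's pop (redefined for this sketch).
#[derive(Debug, PartialEq)]
pub enum PopResult<T> {
    Data(T),
    Empty,
    Inconsistent,
}

/// Hypothetical helper: Empty maps to None; Inconsistent means a push is
/// half-finished, so yield and retry until its data becomes visible.
fn pop_or_wait_for_push<T>(mut pop: impl FnMut() -> PopResult<T>) -> Option<T> {
    match pop() {
        PopResult::Data(t) => Some(t),
        PopResult::Empty => None,
        PopResult::Inconsistent => loop {
            thread::yield_now();
            match pop() {
                PopResult::Data(t) => break Some(t),
                // A half-finished push always completes, so this is a bug.
                PopResult::Empty => panic!("inconsistent => empty"),
                PopResult::Inconsistent => {}
            }
        },
    }
}
```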

    Steals are indeed accumulated first and only folded into cnt once they exceed the MAX_STEALS constant:

            match ret {
                Some(data) => unsafe {
                    if *self.steals.get() > MAX_STEALS {
                        match self.cnt.swap(0, Ordering::SeqCst) {
                            DISCONNECTED => {
                                self.cnt.store(DISCONNECTED, Ordering::SeqCst);
                            }
                            n => {
                                let m = cmp::min(n, *self.steals.get());
                                *self.steals.get() -= m;
                                self.bump(n - m);
                            }
                        }
                        assert!(*self.steals.get() >= 0);
                    }
                    *self.steals.get() += 1;
                    Ok(data)
                },
                ...............
            }

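    Only when steals exceeds MAX_STEALS is cnt swapped to 0 and compared with steals: m = min(n, steals) is subtracted from steals, and the remainder n - m is added back to cnt via bump.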
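    To check that this folding never loses count, here is a hypothetical single-threaded model of the bookkeeping. max_steals is a parameter of the sketch rather than std's constant, and the DISCONNECTED branch is left out. pending() = cnt - steals is the number of items sent but not yet received, and it should drop by exactly one per received item whether or not a fold happens.

```rust
/// Hypothetical model: plain integers instead of AtomicIsize/UnsafeCell.
struct Counters {
    cnt: isize,        // sends minus receives already folded in
    steals: isize,     // receives not yet subtracted from cnt
    max_steals: isize, // fold threshold (stand-in for MAX_STEALS)
}

impl Counters {
    /// Sender side: each send does cnt += 1.
    fn on_send(&mut self) {
        self.cnt += 1;
    }

    /// Receiver side, after pop() returned data (the Some(data) arm).
    fn on_data(&mut self) {
        if self.steals > self.max_steals {
            let n = std::mem::replace(&mut self.cnt, 0); // cnt.swap(0)
            let m = n.min(self.steals);
            self.steals -= m;
            self.cnt += n - m; // bump(n - m)
            assert!(self.steals >= 0);
        }
        self.steals += 1;
    }

    /// Items sent but not yet received.
    fn pending(&self) -> isize {
        self.cnt - self.steals
    }
}
```

    With max_steals = 3 and ten queued items, the fold runs on the 5th and 9th receive, and steals never exceeds 4.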

  • Source: https://www.cnblogs.com/dhcn/p/12957397.html