java AbstractQueuedSynchronizer 解析

308 阅读9分钟

AbstractQueuedSynchronizer(AQS)是Java中的一个重要的同步框架,位于java.util.concurrent包下,AQS提供了一种基本的框架,可以用来构建自定义的同步器,如 锁、信号量、计数器等。它是实现多线程并发控制的关键组件,广泛用于Java并发编程中。

  • 独占锁常常用于实现互斥访问,例如 ReentrantLock 就是一个独占锁。
  • 共享锁常常用于实现读写锁,例如 ReentrantReadWriteLock,它允许多个线程同时读取共享资源,但只有一个线程能够写入资源。

AQS 的核心思想是使用一个等待队列来管理线程的竞争和等待状态。它通过内部的状态变量来表示对象的状态,并使用FIFO(先进先出)等待队列来管理等待线程。

如果需要使用 AbstractQueuedSynchronizer 需要覆写以下方法

public class Sync extends AbstractQueuedSynchronizer {
    // 尝试获取独占同步状态
    @Override
    protected boolean tryAcquire(int arg) {
        return super.tryAcquire(arg);
    }
   // 尝试释放独占同步状态
    @Override
    protected boolean tryRelease(int arg) {
        return super.tryRelease(arg);
    }
    // 尝试获取共享同步状态
    @Override
    protected int tryAcquireShared(int arg) {
        return super.tryAcquireShared(arg);
    }
    // 尝试释放共享同步状态
    @Override
    protected boolean tryReleaseShared(int arg) {
        return super.tryReleaseShared(arg);
    }
   // 获取是否有线程持有独占同步状态
    @Override
    protected boolean isHeldExclusively() {
        return super.isHeldExclusively();
    }
}

使用示例

public class Mutex implements Lock {

    private static class Sync extends AbstractQueuedSynchronizer {
        // 判断是否有线程持有独占同步状态
        @Override
        protected boolean isHeldExclusively() {
            return getState() == 1;
        }

        // 尝试获取独占同步状态,如果当前状态为 0,说明没有被其他线程持有, 
        // CAS 把 state 设置为 1 成功 把当前线程设置为持有线程独占同步状态的线程,返回 true 。
        // 如果当前状态为 1 ,说明其他线程持有同步状态, 则 CAS 失败 返回 false 
        @Override
        public boolean tryAcquire(int acquires) {
            if (compareAndSetState(0, 1)) {
                setExclusiveOwnerThread(Thread.currentThread());
                return true;
            }
            return false;
        }

        //尝试释放独占同步状态
        @Override
        protected boolean tryRelease(int releases) {
            // 如果没有线程持有独占同步状态进行释放会抛出移除
            if (getState() == 0) throw new IllegalMonitorStateException();
            setExclusiveOwnerThread(null); 
            setState(0);
            return true;
        }

        // 提供一个 Condition 对象 
        Condition newCondition() {
            return new ConditionObject();
        }
    }

    // The sync object does all the hard work. We just forward to it.
    private final Sync sync = new Sync();

    @Override
    public void lock() {
        sync.acquire(1); // 获取独占同步状态
    }

    @Override
    public boolean tryLock() {
        return sync.tryAcquire(1); // 尝试获取独占同步状态
    }

    @Override
    public void unlock() {
        sync.release(1); //释放独占同步状态
    }

    @Override
    public Condition newCondition() {
        return sync.newCondition();
    }

    @Override
    public void lockInterruptibly() throws InterruptedException {
        sync.acquireInterruptibly(1);
    }

    @Override
    public boolean tryLock(long timeout, TimeUnit unit)
            throws InterruptedException {
        return sync.tryAcquireNanos(1, unit.toNanos(timeout));
    }

    private static int VALUE = 0;

    public static void main(String[] args) {
        Mutex mutex = new Mutex();
        // 使用 Mutex 来创建一个简单的互斥区域
        Runnable task = () -> {
            mutex.lock(); // 获取锁
            try {
                // 在这里执行互斥区域的代码
                for (int i = 0; i < 100000; i++) {
                    VALUE++;
                }
                System.out.println(Thread.currentThread().getName() + " finish");
            } finally {
                mutex.unlock(); // 释放锁
            }
        };
        // 创建多个线程并启动
        Thread thread1 = new Thread(task, "Thread_1");
        Thread thread2 = new Thread(task, "Thread_2");

        thread1.start();
        thread2.start();
        
        // 等待两个两个线程执行完成
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println("VALUE = " + VALUE);
    }
}

运行结果

Thread_1 finish
Thread_2 finish
VALUE = 200000

源码解析

AbstractQueuedSynchronizer 内部类 Node


static final class Node {
 
    // 表示节点在共享模式下等待
    static final Node SHARED = new Node();
    // 表示节点在独占模式下等待
    static final Node EXCLUSIVE = null;

    // waitStatus 的值表示线程已被取消
    static final int CANCELLED =  1;
    // waitStatus 的值表示线程被阻塞,需要被唤醒
    static final int SIGNAL    = -1;
   // waitStatus 的值表示线程正在等待条件
    static final int CONDITION = -2;
    // waitStatus 的值表示下一个 acquireShared 应该无条件传播
    static final int PROPAGATE = -3;
    // 表示当前节点的等待状态
    volatile int waitStatus;
    // 表示当前节点的上一个节点
    volatile Node prev;
    // 表示当前节点的下一个节点
    volatile Node next;
    // 当前节点的线程
    volatile Thread thread;
   
    // 链接到等待条件的下一个节点,或特殊值 SHARED。 
    Node nextWaiter;
}

AbstractQueuedSynchronizer 成员变量


// 队列的头节点
private transient volatile Node head;
// 队列的尾节点
private transient volatile Node tail;
// 队列的当前状态
private volatile int state;

独占同步状态

获取独占同步状态

// 获取独占同步状态
public final void acquire(int arg) {
    // 若子类实现 tryAcquire 返回为 true ,表明当前线程已获取独占同步状态, 
    // if 直接为 false,此方法什么都不做,就可以进入临界区
    // 返回 false ,则表明其他线程已获取独占同步状态,会调用 acquireQueued()方法
    if (!tryAcquire(arg) &&
        acquireQueued(addWaiter(Node.EXCLUSIVE), arg))
        selfInterrupt();
}

// 传入 Node.EXCLUSIVE 表明为独占模式
private Node addWaiter(Node mode) {
    // 创建一个 nextWaiter 为 Node.EXCLUSIVE ,thread 设置为当前线程 的节点
    Node node = new Node(mode);
	// 无限循环
    for (;;) {
        // 当队列未初始化时,tail 为 null ,进行队列初始化
        // 初始化之后 tail 不为 null
        Node oldTail = tail;
        if (oldTail != null) {
            // 把 node 的 prev 设置为 oldTail
            node.setPrevRelaxed(oldTail);
            // CAS node 设置队列的 tail 节点
            if (compareAndSetTail(oldTail, node)) {
                oldTail.next = node; // oldTail 的 next 设置为 node
                return node; // 返回 node
            }
        } else {
            // 初始化队列
            initializeSyncQueue();
        }
    }
}

private final void initializeSyncQueue() {
    Node h;
    // 新建一个节点,并赋值给 h ,CAS 判断头节点为空,并且把 h 赋值给头节点,如果赋值成功
    // 并把 h 赋值给 tail 尾节点 
    if (HEAD.compareAndSet(this, null, (h = new Node())))
        tail = h;
}


final boolean acquireQueued(final Node node, int arg) {
    boolean interrupted = false; 
    try {
        for (;;) {
            // 获取 node 的前一个节点
            final Node p = node.predecessor();
            // 如果前一个节点是 头节点,再次调用 tryAcquire 
            if (p == head && tryAcquire(arg)) {
                setHead(node); // 把当前节点设置为头节点,相当于删除了之前的头节点
                p.next = null; // help GC
                return interrupted;
            }
            // 第一轮循环时 , 返回 false ,第二次循环返回 true 
            if (shouldParkAfterFailedAcquire(p, node))
                interrupted |= parkAndCheckInterrupt(); // 阻塞线程
        }
    } catch (Throwable t) {
        cancelAcquire(node);
        if (interrupted)
            selfInterrupt();
        throw t;
    }
}

private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
    // 获取前一个节点的等待状态,若 node 是第一个插入的节点,pred 为 head的节点,ws 为 0
    // 再次调用时,ws 已设置为 Node.SIGNAL ,返回 true
    int ws = pred.waitStatus;
    if (ws == Node.SIGNAL)
        return true;
    if (ws > 0) { // 大于 0 只有 Node.CANCELLED
        // 删除队列中 waitStatus 状态为 Node.CANCELLED 的节点
        do {
            node.prev = pred = pred.prev; 
        } while (pred.waitStatus > 0); 
        pred.next = node;
    } else {
        // pred 节点的 waitStatus值设置为 Node.SIGNAL ,最后返回 false
        pred.compareAndSetWaitStatus(ws, Node.SIGNAL);
    }
    return false;
}

// 如果其他线程释放了获取独占同步状态 ,阻塞的线程被唤醒 ,返回线程是否被中断。
private final boolean parkAndCheckInterrupt() {
    LockSupport.park(this); //阻塞当前线程
    return Thread.interrupted();
}

释放独占同步状态

// 释放同步状态
public final boolean release(int arg) {
    if (tryRelease(arg)) { // 尝试释放同步状态,子类实现,如果能释放
        Node h = head;
        if (h != null && h.waitStatus != 0) // h.waitStatus == Node.SIGNAL
            unparkSuccessor(h); // 解锁当前线程
        return true;
    }
    return false;
}

private void unparkSuccessor(Node node) {

    int ws = node.waitStatus;
    if (ws < 0)
        node.compareAndSetWaitStatus(ws, 0); // 把 h.waitStatus 设置为 0 

    Node s = node.next; // 找到 h 的 next 节点
    if (s == null || s.waitStatus > 0) {// next 节点为 null, h.waitStatus == Node.CANCELLED 
        s = null; // s 赋值为 null
        // 从尾部开始,一直往前找一个 waitStatus <= 0 的节点 ,把当前节点赋值为 s 
        for (Node p = tail; p != node && p != null; p = p.prev) 
            if (p.waitStatus <= 0)
                s = p;
    }
    if (s != null)
        LockSupport.unpark(s.thread); // s 不为null,解锁 s 节点存储的 thread 线程
}

使用示例中AbstractQueuedSynchronizer内部状态流转图

flowchart TB
    subgraph Seven["head=node_thread_2,tail=node_thread_2(thread_2获取到独占同步状态,node_thread_1节点被移除"]
    direction LR
     H["节点node_thread_2 waitStatus=0
	    nextWaiter = Node.EXCLUSIVE "]	
    end
    
    subgraph Six["head=node_thread_1,tail=node_thread_2,(thread_1 调用unparkSuccessor解锁 thread_2)"]
    direction LR
     G1["节点 node_thread_1,waitStatus=0,
     	nextWaiter=Node.EXCLUSIVE"]	
     G2["节点node_thread_2,waitStatus=0,
     	nextWaiter=Node.EXCLUSIVE"]
     G1-->|next|G2
     G2-->|prev|G1
    end
    
    subgraph Five["head=node_thread_1,tail=node_thread_2,(thread_2 调用parkAndCheckInterrupt被阻塞)"]
    direction LR
     F1["节点 node_thread_1,waitStatus=Node.SIGNAL,
     	nextWaiter=Node.EXCLUSIVE"]	
     F2["节点node_thread_2,waitStatus=0,
     	nextWaiter=Node.EXCLUSIVE"]
     F1-->|next|F2
     F2-->|prev|F1
    end
    
    subgraph Four["head=node_thread_1,tail=node_thread_2"]
    direction LR
     D1["节点 node_thread_1,waitStatus=0,
     	nextWaiter=Node.EXCLUSIVE"]	
     D2["节点node_thread_2,waitStatus =0,
     	nextWaiter=Node.EXCLUSIVE"]
     D1-->|next|D2
     D2-->|prev|D1
    end
    
    subgraph Three["head=node_thread_1,tail=node_thread_1(thread_1获取到独占同步状态节点h被移除)"]
    direction LR
     C["节点node_thread_1 waitStatus=0
	    nextWaiter = Node.EXCLUSIVE "]	
    end
    
    subgraph Two["head = h,tail = node_thread_1"]
    direction LR
    B1["初始化队列的节点 h waitStatus=0"]
	B2["节点 node_thread_1 waitStatus=0
	   nextWaiter=Node.EXCLUSIVE "]			
     B1-->|next|B2
     B2-->|prev|B1
    end
    
    subgraph One["head = h,tail = h"]
    direction LR
    A["initializeSyncQueue(); 
      初始化队列的节点h waitStatus =0"]	
    end
    
One-->|"addWaiter(Thread_1)"|Two
-->|"acquireQueued(Thread_1)"|Three 
-->|"addWaiter(Thread_2)"|Four
-->|"acquireQueued(Thread_2)"|Five 
-->|"release(Thread_1)"|Six
-->|"acquireQueued(Thread_2)恢复运行"|Seven       

共享同步状态

ReentrantReadWriteLock 中的读锁中使用了共享同步状态

获取共享同步状态

// 读锁 lock 方法调用是 AQS acquireShared(1)
public void lock() {
    sync.acquireShared(1);
}
 // 调用 tryAcquireShared() 如果大于等于 0 说明可以进行读操作 
 // 如果小于 0 会调用 doAcquireShared() 阻塞 
public final void acquireShared(int arg) {
    if (tryAcquireShared(arg) < 0)
        doAcquireShared(arg);
}

// 关于读写锁的计算,用一个 int 类型 ,前 16 位为写锁 ,后 16 位为读锁 。
static final int SHARED_SHIFT   = 16;
static final int SHARED_UNIT    = (1 << SHARED_SHIFT);
static final int MAX_COUNT      = (1 << SHARED_SHIFT) - 1;
static final int EXCLUSIVE_MASK = (1 << SHARED_SHIFT) - 1;

// 获取读锁的次数(读锁是共享锁,同一个线程可以获取多次)
static int sharedCount(int c)    { return c >>> SHARED_SHIFT; }
// 获取写锁的个数 (写锁是独占锁)
static int exclusiveCount(int c) { return c & EXCLUSIVE_MASK; }

static final class HoldCounter {
    int count;          // 初始化为 0
    // 线程 id 
    final long tid = LockSupport.getThreadId(Thread.currentThread());
}
// 线程本地变量,当当前线程 readHolds没有存储 HoldCounter时 get 会新建一个 HoldCounter 对象  
static final class ThreadLocalHoldCounter
    extends ThreadLocal<HoldCounter> {
    public HoldCounter initialValue() { 
        return new HoldCounter();
    }
}
private transient ThreadLocalHoldCounter readHolds;

private transient HoldCounter cachedHoldCounter;

protected final int tryAcquireShared(int unused) {
    Thread current = Thread.currentThread();
    int c = getState();
    // 如果有其他线程获取了写锁,直接返回 -1 
    if (exclusiveCount(c) != 0 &&
        getExclusiveOwnerThread() != current)
        return -1;
    // 获取等待读锁的次数
    int r = sharedCount(c);
    // 不需要阻塞读锁,并且读锁次数小于读锁最大值,
    // 并且读锁数量加 1(读锁个数计算是 后 16 位,所以加的是 SHARED_UNIT) 。
    if (!readerShouldBlock() &&
        r < MAX_COUNT &&
        compareAndSetState(c, c + SHARED_UNIT)) {
        // 当之前的等待读锁的线程数为 0 ,当前线程是第一个获取读锁的  
        if (r == 0) {
            firstReader = current;  // 记录 firstReader 为当前线程
            firstReaderHoldCount = 1;  // 第一个线程获取读锁次数等于 1 
            // 当前线程是记录的第一个线程,第一个线程获取读锁次数 +1
        } else if (firstReader == current) {
            firstReaderHoldCount++;  
        } else {
            HoldCounter rh = cachedHoldCounter;
            // 当 rh 为 null 或者当前线程 id 和 rh 存储的 id 不相等 
            if (rh == null ||
                rh.tid != LockSupport.getThreadId(current))
                //  获取当前线程的  HoldCounter 
                cachedHoldCounter = rh = readHolds.get();
            // rh 的 count 为 0 时 ,把当前 cachedHoldCounter 设置为 readHolds
            else if (rh.count == 0)
                readHolds.set(rh); 
            // rh 计数器加 1    
            rh.count++;
        }
        return 1;
    }
    return fullTryAcquireShared(current);
}
// 读锁默认非公平锁 
static final class NonfairSync extends Sync {
    private static final long serialVersionUID = -8159625535654395037L;
    final boolean writerShouldBlock() {
        return false;
    }
    final boolean readerShouldBlock() {
        // 为避免写线程饥饿,当队列等待的第一个线程是写线程的时候,
        // 把读线程先阻塞让写线程先获取同步状态 
        return apparentlyFirstQueuedIsExclusive();
    }
}


// 如果第一个排队线程(如果存在)正在以独占模式等待,则返回 true。
// 如果此方法返回 true,并且当前线程正在尝试以共享模式获取
//(即,从 tryAcquireShared 调用此方法),则可以保证当前线程不是第一个排队的线程。
final boolean apparentlyFirstQueuedIsExclusive() {
    Node h, s;
    return (h = head) != null &&
        (s = h.next)  != null &&
        !s.isShared()         &&
        s.thread != null;
}

final int fullTryAcquireShared(Thread current) {
    HoldCounter rh = null;
    for (;;) {
        int c = getState();
        // 如果有其他线程获取了写锁,直接返回 -1 
        if (exclusiveCount(c) != 0) {
            if (getExclusiveOwnerThread() != current)
                return -1;
            // else 不能被阻塞在这里会导致死锁     
        } else if (readerShouldBlock()) {
            if (firstReader == current) {
                // assert firstReaderHoldCount > 0;
            } else {
                if (rh == null) {
                    rh = cachedHoldCounter;
                    if (rh == null ||
                        rh.tid != LockSupport.getThreadId(current)) {
                        rh = readHolds.get();
                        if (rh.count == 0)
                            readHolds.remove();
                    }
                }
                // rh.count == 0 返回 -1 
                if (rh.count == 0)
                    return -1;
            }
        }
        // 如果获取读锁的次数到达最大值抛出异常
        if (sharedCount(c) == MAX_COUNT)
            throw new Error("Maximum lock count exceeded");
        if (compareAndSetState(c, c + SHARED_UNIT)) {
            if (sharedCount(c) == 0) {
                firstReader = current;
                firstReaderHoldCount = 1;
            } else if (firstReader == current) {
                firstReaderHoldCount++;
            } else {
                if (rh == null)
                    rh = cachedHoldCounter;
                if (rh == null ||
                    rh.tid != LockSupport.getThreadId(current))
                    rh = readHolds.get();
                else if (rh.count == 0)
                    readHolds.set(rh);
                rh.count++;
                cachedHoldCounter = rh; // cache for release
            }
            return 1;
        }
    }
}
// 读锁阻塞 
private void doAcquireShared(int arg) {
    // 新建一个 nextWaiter 为 Node.SHARED 的节点 
    final Node node = addWaiter(Node.SHARED);
    boolean interrupted = false;
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    // 大于等于 0 说明可以进行读操作  
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    return;
                }
            }
            // 阻塞
            if (shouldParkAfterFailedAcquire(p, node))
                interrupted |= parkAndCheckInterrupt();
        }
    } catch (Throwable t) {
        cancelAcquire(node);
        throw t;
    } finally {
        if (interrupted)
            selfInterrupt();
    }
}

private void setHeadAndPropagate(Node node, int propagate) {
    Node h = head; // Record old head for check below
    setHead(node);
    if (propagate > 0 || h == null || h.waitStatus < 0 ||
        (h = head) == null || h.waitStatus < 0) {
        Node s = node.next;
        if (s == null || s.isShared())
            doReleaseShared(); // 解锁读操作线程 
    }
}

释放共享同步状态

public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}

protected final boolean tryReleaseShared(int unused) {
    Thread current = Thread.currentThread();
    if (firstReader == current) {
        // assert firstReaderHoldCount > 0;
        if (firstReaderHoldCount == 1)
            firstReader = null;
        else
            firstReaderHoldCount--;
    } else {
        HoldCounter rh = cachedHoldCounter;
        if (rh == null ||
            rh.tid != LockSupport.getThreadId(current))
            rh = readHolds.get();
        int count = rh.count;
        if (count <= 1) {
            readHolds.remove();
            if (count <= 0)
                throw unmatchedUnlockException();
        }
        --rh.count;
    }
    for (;;) {
        int c = getState();
        int nextc = c - SHARED_UNIT;
        if (compareAndSetState(c, nextc))
            // Releasing the read lock has no effect on readers,
            // but it may allow waiting writers to proceed if
            // both read and write locks are now free.
            return nextc == 0;
    }
}

private void doReleaseShared() {
    for (;;) {
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0))
                    continue;           
                unparkSuccessor(h); // 解锁
            }
            else if (ws == 0 &&
                     !h.compareAndSetWaitStatus(0, Node.PROPAGATE))
                continue;                
        }
        if (h == head)                   
            break;
    }
}