CyclicBarrier - 同步屏障实现分析

958 阅读4分钟

CyclicBarrier

CyclicBarrier 是可循环使用的屏障,主要功能是让一组线程到达一个屏障时被阻塞,直到最后一个线程到达屏障时,屏障才会打开;所有被屏障拦截的线程才会继续执行。

使用示例

public class CyclicBarrierTest {

    // 线程个数
    private int parties = 3;

    private AtomicInteger atomicInteger = new AtomicInteger(parties);

    private CyclicBarrier cyclicBarrier;

    class Protector implements Runnable {

        @Override
        public void run() {
            try {
                System.out.println(Thread.currentThread().getName() + " - 到达屏障前");

                TimeUnit.SECONDS.sleep(2);

                cyclicBarrier.await();

                atomicInteger.decrementAndGet();

                System.out.println(Thread.currentThread().getName() + " - 到达屏障后");

            } catch (InterruptedException e) {
                System.out.println(Thread.currentThread().getName() + " - 等待中断");
            } catch (BrokenBarrierException e) {
                System.out.println(Thread.currentThread().getName() + " - 屏障被破坏");
            }
        }
    }


    @Before
    public void init() {
        cyclicBarrier = new CyclicBarrier(parties);
    }

    @Test
    public void allAwait() {
        for (int i = 0; i < parties; i++) {
            new Thread(new Protector(), "Thread-" + i).start();
        }

        while (true) {
            if (atomicInteger.get() == 0) {
                // 所有线程到达屏障后退出结束
                System.out.println("test over");
                break;
            }
        }
    }

    @Test
    public void oneAwaitInterrupted() throws InterruptedException {
        Thread threadA = new Thread(new Protector(), "Thread-A");
        Thread threadB = new Thread(new Protector(), "Thread-B");

        threadA.start();
        threadB.start();
        // 等待 3 秒,避免是 time sleep 触发中断异常
        TimeUnit.SECONDS.sleep(3);

        threadA.interrupt();

        while (true) {
            if (atomicInteger.get() == 0) {
                System.out.println("test over");
                break;
            }
            if (cyclicBarrier.isBroken()) {
                System.out.println("屏障中断退出");
                break;
            }
        }
    }
}
Thread-A - 到达屏障前
Thread-B - 到达屏障前
屏障中断退出
Thread-A - 等待中断
Thread-B - 屏障被破坏

Thread-0 - 到达屏障前
Thread-1 - 到达屏障前
Thread-2 - 到达屏障前
Thread-2 - 到达屏障后
Thread-0 - 到达屏障后
Thread-1 - 到达屏障后
test over

从 oneAwaitInterrupted 方法执行结果可以看出,当一个线程 A 执行中断时,另外一个线程 B 会抛出 BrokenBarrierException

构造

// 可以指定拦截线程个数
public CyclicBarrier(int parties) {
    this(parties, null);
}

// 指定拦截线程个数和所有线程到达屏障处后执行的动作
public CyclicBarrier(int parties, Runnable barrierAction) {
    if (parties <= 0) throw new IllegalArgumentException();
    this.parties = parties;
    this.count = parties;
    this.barrierCommand = barrierAction;
}

实现

概念
  • barrier : 屏障
  • parties : 为屏障拦截的线程数
  • tripped : 跳闸,可以理解为打开屏障
  • generation.broken : 屏障是否破损,当屏障被打开或被重置的时候会改变值

简单的理解就是,当线程都到达屏障的时候,会打开屏障。

await()

await 说明线程到达屏障

public int await() throws InterruptedException, BrokenBarrierException {
    try {
        return dowait(false, 0L);
    } catch (TimeoutException toe) {
        throw new Error(toe); // cannot happen
    }
}
private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
    final ReentrantLock lock = this.lock;
    // 获取排他锁
    lock.lock();
    try {
        final Generation g = generation;
        // 屏障被破坏则抛异常
        if (g.broken)
            throw new BrokenBarrierException();

        if (Thread.interrupted()) {
        	// 线程中断 则退出屏障
            breakBarrier();
            throw new InterruptedException();
        }

        // 到达屏障的计数减一
        int index = --count;      
        if (index == 0) {  // tripped
        	// index == 0, 说明指定 count 的线程均到达屏障
        	// 此时可以打开屏障
            boolean ranAction = false;
            try {
                final Runnable command = barrierCommand;
                if (command != null)
                	// 若指定了 barrierCommand 则执行
                    command.run();
                ranAction = true;
                // 唤醒阻塞在屏障的线程并重置 generation
                nextGeneration();
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }

        // loop until tripped, broken, interrupted, or timed out
        for (;;) {
            try {
                if (!timed)
                	// 若未指定阻塞在屏障处的等待时间,则一直等待;直至最后一个线程到达屏障处的时候被唤醒
                    trip.await();
                else if (nanos > 0L)
                	// 若指定了阻塞在屏障处的等待时间,则在指定时间到达时会返回
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                	// 若等待过程中,线程发生了中断,则退出屏障
                    breakBarrier();
                    throw ie;
                } else {
                    // We're about to finish waiting even if we had not
                    // been interrupted, so this interrupt is deemed to
                    // "belong" to subsequent execution.
                    Thread.currentThread().interrupt();
                }
            }

            // 屏障被破坏 则抛出异常
            if (g.broken)
                throw new BrokenBarrierException();

            if (g != generation)
            	// g != generation 说明所有线程均到达屏障处 可直接返回
            	// 因为所有线程到达屏障处的时候,会重置 generation
            	// 参考 nextGeneration
                return index;

            if (timed && nanos <= 0L) {
            	// 说明指定时间内,还有线程未到达屏障处,也就是等待超时
            	// 退出屏障
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}
private void nextGeneration() {
    // signal completion of last generation
    // 唤醒阻塞在等待队列的线程
    trip.signalAll();
    // set up next generation
    // 重置 count
    count = parties;
    // 重置 generation
    generation = new Generation();
}
private void breakBarrier() {
	// broken 设置为 true
    generation.broken = true;
    // 重置 count
    count = parties;
    // 唤醒等待队列的线程
    trip.signalAll();
}

如下图为 CyclicBarrier 实现效果图:

isBroken()

返回屏障是否被破坏,也是是否被中断

public boolean isBroken() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        return generation.broken;
    } finally {
        lock.unlock();
    }
}
reset()
public void reset() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        // 唤醒阻塞的线程
        breakBarrier();   // break the current generation
        // 重新设置 generation
        nextGeneration(); // start a new generation
    } finally {
        lock.unlock();
    }
}
getNumberWaiting

获取阻塞在屏障处的线程数

public int getNumberWaiting() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        // 拦截线程数 - 未到达屏障数
        return parties - count;
    } finally {
        lock.unlock();
    }
}

小结

CyclicBarrier 和 CountDownLatch 功能类似,不同之处在于 CyclicBarrier 支持重复利用,而 CountDownLatch 计数只能使用一次。