Rust async FFI

3,181 阅读7分钟

前言

Rust FFI(Foreign Function Interface)允许 Rust 与其他语言“交互”。近期在项目开发中,由于某些原因,同一个程序的部分模块是用 C++ 写的,部分模块是用 Rust 写的,Rust 需要调用 C++ 接口,并且是异步调用。查了一圈资料,发现大多只介绍同步调用,于是自行摸索了一下,总结成这篇文档,供有需要的人参考。

同步调用

网上同步调用的例子很多,这里就简单提一下,不做过多的阐述,一些详细内容可以看文末参考资料。

简单的rust call c

代码位置:github.com/kkk-imm/Rus…

// 目录结构:
/*
├── README.md
└── example01
    ├── Cargo.lock
    ├── Cargo.toml
    ├── libcallee.so
    ├── callee.c
    ├── src
    │   └── main.rs
    ├── build.rs
    ├── target
    │   ├── CACHEDIR.TAG
    │   └── debug
    └── callee.o
*/ 
// example01/src/main.rs
use std::os::raw::c_int;

// Links against libcallee (libcallee.so on Linux); build.rs adds the search path.
#[link(name="callee")]
extern "C" {
    /// Declared to match `int sum(int a, int b)` in callee.c.
    fn sum(a: c_int, b: c_int) -> c_int;
}

fn main() {
    let k = unsafe { sum(2, 4) };
    println!("sum result is {}",k);
}
// example01/callee.c
// Build steps (build.rs links the shared library libcallee.so, not the object file):
// gcc -c -Wall -Werror -fpic callee.c
//   gcc -shared -o libcallee.so callee.o
// Returns a + b. Signed int overflow is undefined behavior in C, so callers must keep
// the result within the range of int.
int sum(int a, int b) { return a + b; }
// build.rs
/// Cargo build script: tells rustc where to find libcallee.so and to link it dynamically.
fn main() {
    // Use an absolute path anchored at the package directory (where libcallee.so
    // lives, next to Cargo.toml) rather than ".", so the search path does not depend
    // on the working directory rustc happens to be invoked from (e.g. a workspace root).
    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
        .expect("Cargo sets CARGO_MANIFEST_DIR for build scripts");
    println!("cargo:rustc-link-search=native={}", manifest_dir);
    println!("cargo:rustc-link-lib=dylib=callee");
    // This only affects link time; at run time the dynamic loader still has to find
    // the library, hence `LD_LIBRARY_PATH=.` when running the binary.
}

执行:(注意要带上 LD_LIBRARY_PATH=.,否则程序运行时动态链接器找不到当前目录下的 libcallee.so)

LD_LIBRARY_PATH=. cargo run

异步FFI

这里所说的异步FFI是指,rust使用await语义等待c接口返回结果。在我们项目开发过程中,rust模块位于上层,调用c接口,c接口内部会做一些耗时的IO操作,因此我们不能同步阻塞地等待c接口返回结果(那样会一直占住async runtime的工作线程)。下面介绍两种实现方式。

自己封装future

假设我们的场景是,需要获取某个学生的信息,c代码对应的是存储引擎,rust侧则是查询模块。这里我们选用tokio作为rust async runtime。 代码位置:github.com/kkk-imm/Rus…

// src/main.rs
use std::future::Future;
use std::os::raw::{c_int, c_void};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll};
use std::thread::sleep;
use std::time;

// Record: student info (id, height).
// #[repr(C)] gives this struct the same field order and layout as the C `Record`
// (`int id; int height;`) so it can be passed across the FFI boundary; do not reorder fields.
#[repr(C)]
#[derive(Debug, Default, Clone, PartialEq)]
pub struct Record {
    id: c_int,
    height: c_int,
}

// Completion callback type passed to C's `query`.
// Argument 1 is the result record (a `Record` in practice); argument 2 is the opaque
// closure pointer originally handed to `query`. C invokes it as `fnptr(record, closure)`,
// so the record is the argument the closure is executed with.
pub type GetRecord<T> = unsafe extern "C" fn(*mut T, *mut c_void);

extern "C" {
    /// Starts a lookup of `id` in the C storage engine (dbmanager.c) and returns
    /// immediately. In the C implementation shown, a worker thread later calls
    /// `fnptr(record, closure)` exactly once with the result.
    pub fn query(id: c_int, fnptr: GetRecord<Record>, closure: *mut c_void);
}

/// C-ABI trampoline with the same signature as [`GetRecord`].
///
/// It rebuilds the boxed closure of type `F` from `closure` and calls it with `record`.
/// It then frees the box. C keeps ownership of `record` and frees it after this
/// returns, so the closure should copy what it needs instead of keeping the pointer.
///
/// # Safety
/// `closure` must come from `Box::<F>::into_raw` / `Box::leak` for exactly this `F`,
/// and this function must be called at most once per such pointer.
unsafe extern "C" fn hook<F, T>(record: *mut T, closure: *mut c_void)
where
    F: FnOnce(*mut T),
{
    // SAFETY: per the contract above, `closure` is a leaked `Box<F>` we now own again.
    let callback: Box<F> = unsafe { Box::from_raw(closure.cast::<F>()) };
    // The box is released once the (consuming) call finishes.
    callback(record)
}

/// Returns [`hook`] instantiated for the closure type `F`, as a C function pointer.
///
/// The reference is only there so the compiler can infer `F` (closure types cannot
/// be named); its value is never read.
pub fn get_callback<F, T>(_closure: &F) -> GetRecord<T>
where
    F: FnOnce(*mut T),
{
    hook::<F, T> as GetRecord<T>
}

struct QueryFuture {
    query_id: c_int, // 用于传给c接口,模拟作为查询参数
    state: AtomicBool, // 用于通知Future是否结束
    result: Option<Record>, // 用于返回查询结果
}

impl Future for QueryFuture {
    type Output = Record;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Record> {
        if self.state.load(Ordering::Relaxed) {
            Poll::Ready(self.result.take().unwrap())
        } else {
            let waker = cx.waker().clone();
            let q_id = self.query_id;
            unsafe {
                let closure = Box::new(move |r: *mut Record| {
                    self.result = Some(Record {
                        id: (*r).id,
                        height: (*r).height,
                    });
                    self.state.store(true, Ordering::Relaxed);
                    println!("next to wake...");
                    waker.wake(); // 继续poll task
                    sleep(time::Duration::from_secs(2));
                    println!("wake success");
                });
                let closure_ptr = Box::leak(closure); // 这里leak的box,会在hock() 里free
                let callback = get_callback(closure_ptr);
                query(q_id, callback, closure_ptr as *mut _ as *mut c_void);
            }
            Poll::Pending
        }
    }
}

#[tokio::main]
async fn main() {
    let future = QueryFuture {
        query_id: 3,
        state: AtomicBool::new(false),
        result: None,
    };

    let out = future.await;
    println!("get record is :{:#?} {:#?}", out.id, out.height);
    sleep(time::Duration::from_secs(3));
}
/* 输出:
queryfn start
query fn end
thfn start
next to wake...
get record is :5123 34
wake success
thfn done
*/
// src/dbmanager.c
// gcc -fPIC -shared -o libdbmanager.so dbmanager.c
#include <stdio.h>
#include <stdlib.h> /* malloc, free */
#include <unistd.h> /* sleep */
#include <pthread.h>

typedef struct Record
{
  int id;
  int height;
} Record;

/* Completion callback supplied by Rust: receives a borrowed Record plus the opaque
 * closure pointer. It must be called exactly once, because the Rust side frees the
 * closure inside it. */
typedef void (*GetRecord)(Record *rec, void *closure);

// TArgs: arguments handed to the worker thread (heap-allocated by query, freed by thfn).
typedef struct TArgs 
{
  void *closure;
  GetRecord fnptr;
} Targs;

/* Simulates the slow storage lookup and reports the result through the callback. */
static void run_query(GetRecord fnptr, void *closure)
{
  sleep(3);
  printf("thfn start\n");
  /* The record lives in this stack frame. Rust copies the fields out during the
   * callback, so C keeps ownership of its own memory. Rust and C may use different
   * allocators, so memory must never be freed by the other side. */
  Record r;
  r.height = 34;
  r.id = 5123;
  fnptr(&r, closure);
  printf("thfn done\n");
}

/* pthread entry point: unpacks and frees the heap-allocated arguments, then runs the query. */
void *thfn(void *arg)
{
  Targs *targs = (Targs *)arg;
  GetRecord fnptr = targs->fnptr;
  void *closure = targs->closure;
  free(targs);
  run_query(fnptr, closure);
  return NULL;
}

/* Starts the lookup on a detached thread and returns immediately.
 * `id` is not used by this mock; it always reports the same record. */
void query(int id, GetRecord fnptr, void *closure)
{
  (void)id;
  printf("queryfn start\n");
  pthread_t thread1;
  Targs *targs = (Targs *)(malloc(sizeof(Targs)));
  if (targs != NULL)
  {
    targs->closure = closure;
    targs->fnptr = fnptr;
    // The C side simply uses a pthread to do the work concurrently.
    if (pthread_create(&thread1, NULL, thfn, (void *)targs) == 0)
    {
      pthread_detach(thread1);
      printf("query fn end\n");
      return;
    }
    free(targs);
  }
  /* No worker thread available: run inline so the callback still fires exactly once.
   * Otherwise the awaiting Rust future would never complete and its closure would leak. */
  fprintf(stderr, "query: could not start worker thread, running inline\n");
  run_query(fnptr, closure);
  printf("query fn end\n");
}

这种实现方法比较麻烦,每个ffi接口都要自己封装一个future,下面提供一种简单点的设计。

基于notify的回调

代码位置: github.com/kkk-imm/Rus…

// example03/src/main.rs
use std::os::raw::{c_int, c_void};
use std::sync::Arc;
use std::thread::sleep;
use std::time;

use tokio::sync::Notify;
/// C-ABI callback that C's `query2` invokes from its worker thread after the results
/// have been written; its only argument is the opaque closure pointer given to `query2`.
pub type WakeCallbackExecutor = unsafe extern "C" fn(*mut c_void);

// Record: student info (id, height).
// #[repr(C)] matches the C `Record` (`int id; int height;`), so C can write into it
// through a `*mut Record`; do not reorder fields.
#[repr(C)]
#[derive(Debug, Default, Clone, PartialEq)]
pub struct Record {
    id: c_int,
    height: c_int,
}

extern "C" {
    /// Starts a lookup of `id` in dbmanager.c and returns immediately.
    ///
    /// In the C implementation shown, a worker thread later writes into `*res_int`
    /// and `*res`. It then calls `callback_executor(callback_ptr)` once to signal
    /// completion. Both result pointers must therefore stay valid until that call.
    pub fn query2(
        id: c_int,
        res_int: *mut c_int,
        res: *mut Record,
        callback_executor: WakeCallbackExecutor,
        callback_ptr: *mut c_void,
    );
}

/// C-ABI trampoline with the same signature as [`WakeCallbackExecutor`].
///
/// It rebuilds the boxed closure of type `F` from `closure`, runs it, and frees it.
///
/// # Safety
/// `closure` must come from `Box::<F>::into_raw` / `Box::leak` for exactly this `F`,
/// and this function must be called at most once per such pointer.
unsafe extern "C" fn hook2<F>(closure: *mut c_void)
where
    F: FnOnce(),
{
    // SAFETY: per the contract above, `closure` is a leaked `Box<F>` we now own again.
    let callback = unsafe { Box::<F>::from_raw(closure.cast::<F>()) };
    callback()
}

/// Returns [`hook2`] instantiated for the closure type `F`, as a C function pointer.
///
/// The reference only drives type inference for `F`; its value is never read.
pub fn get_callback2<F>(_closure: &F) -> WakeCallbackExecutor
where
    F: FnOnce(),
{
    hook2::<F> as WakeCallbackExecutor
}

pub async fn async_query2(id: c_int, res_int: SendPtr<*mut c_int>, res: SendPtr<*mut Record>) {
    let notify = Arc::new(Notify::new());
    let notify2 = notify.clone();
    let closure = Box::new(move || {
        notify2.notify_one(); // c侧的callback,唤醒async_query2
    });
    let closure_ptr = Box::leak(closure);
    let callback = get_callback2(closure_ptr);
    unsafe {
        query2(
            id,
            res_int.0,
            res.0,
            callback,
            closure_ptr as *mut _ as *mut c_void,
        );
    }
    notify.notified().await;
    println!("notify success");
}

/// Wrapper that lets a raw pointer be moved into a `Send` future (e.g. one passed to
/// `tokio::spawn`); raw pointers are not `Send` on their own.
///
/// Only raw-pointer payloads are marked `Send`. Wrapping arbitrary `T` (such as an
/// `Rc`) must not silently make it `Send`. Whoever dereferences the pointer remains
/// responsible for making that cross-thread access sound.
#[derive(Debug)]
#[repr(transparent)]
pub struct SendPtr<T>(T);

// SAFETY: a raw pointer carries no ownership, so moving it to another thread is harmless
// by itself. The eventual dereference must be justified where it happens. In this file,
// C writes through the pointer and signals completion before Rust reads the pointee.
unsafe impl<T: ?Sized> Send for SendPtr<*mut T> {}
// SAFETY: as above, for const pointers.
unsafe impl<T: ?Sized> Send for SendPtr<*const T> {}

#[tokio::main]
async fn main() {
    tokio::spawn(async {
        let mut record = Record::default();
        let mut res_int: c_int = 0;
        let send_ptr_res_int = SendPtr(&mut res_int as *mut c_int);
        let send_ptr_record = SendPtr(&mut record as *mut Record);
        async_query2(4, send_ptr_res_int, send_ptr_record).await;
        println!("res_int {:#?}", res_int); // 要保证 res_id 和 record 跨await,不然会有segment fault 的风险
        println!("record {:#?}", record);
    });
    sleep(time::Duration::from_secs(5));
}

/*
    输出:
    query2 start
    query2 fn end
    thfn2 start
    thfn2 done
    notify success
    res_int 4
    record Record {
        id: 11233,
        height: 99999,
    }
*/

// example03/src/dbmanager.c
#include <stdio.h>
#include <stdlib.h> /* malloc, free */
#include <unistd.h> /* sleep */
#include <pthread.h>

typedef struct Record
{
  int id;
  int height;
} Record;

/* Callback supplied by Rust; called once with the opaque closure pointer to signal
 * completion (the Rust side frees the closure inside it). */
typedef void (*WakeCallbackExecutor)(void *closure);

/* Arguments handed to the worker thread (heap-allocated by query2, freed by thfn2). */
typedef struct TArgs2
{
  int *res_int;
  Record *record;
  void *callback_ptr;
  WakeCallbackExecutor callback_executor;
} TArgs2;

/* Simulates the slow lookup: fills the caller's result buffers, then signals completion.
 * The results are written before the callback runs because Rust reads them as soon
 * as it is woken. */
static void run_query2(const TArgs2 *args)
{
  sleep(2);
  printf("thfn2 start\n");
  *args->res_int = 4;
  args->record->id = 11233;
  args->record->height = 99999;
  args->callback_executor(args->callback_ptr);
  printf("thfn2 done\n");
}

/* pthread entry point: copies the arguments out, frees the heap block, runs the query. */
void *thfn2(void *arg)
{
  TArgs2 args = *(TArgs2 *)arg;
  free(arg);
  run_query2(&args);
  return NULL;
}

/* Starts the lookup on a detached thread and returns immediately.
 * `id` is not used by this mock; it always reports the same values. */
void query2(int id, int *res_int, Record *res, WakeCallbackExecutor callback_executor, void *callback_ptr)
{
  (void)id;
  printf("query2 start\n");
  TArgs2 args = {
      .res_int = res_int,
      .record = res,
      .callback_ptr = callback_ptr,
      .callback_executor = callback_executor,
  };
  pthread_t thread1;
  TArgs2 *targs = (TArgs2 *)(malloc(sizeof(TArgs2)));
  if (targs != NULL)
  {
    *targs = args;
    if (pthread_create(&thread1, NULL, thfn2, (void *)targs) == 0)
    {
      pthread_detach(thread1);
      printf("query2 fn end\n");
      return;
    }
    free(targs);
  }
  /* No worker thread available: run inline so the completion callback still fires
   * exactly once; otherwise the awaiting Rust future would never be woken. */
  fprintf(stderr, "query2: could not start worker thread, running inline\n");
  run_query2(&args);
  printf("query2 fn end\n");
}

其他方案

channel通知

参考一些其他不同语言的异步FFI设计方案。这里简单描述一下模型。

  1. rust 生成一个future,如这个future是获取某一行数据。同时将future 自身的waker.wake()封装成一个回调函数,存在一个全局的 map中,key为future_id。
  2. 调用c提供的接口,将查询参数、future_id、结果指针等传递给该接口(下图假设这个获取某一行数据的接口名称为 submit_task),随后future自身返回pending。
  3. c在处理完任务之后,将future_id 发往一个结果队列。
  4. 同时,rust runtime中存在一个常驻的worker,它的作用是同步阻塞地调用c提供的 recv() 接口,该接口会返回已完成任务的future id。由于是阻塞调用,这个worker应运行在独立线程上(例如 tokio 的 spawn_blocking),以免阻塞async工作线程。
  5. 根据拿到的future id 从全局map中获取future的回调函数。
  6. 利用回调函数唤醒future。

这么设计的一个优点是保持rust、c侧尽可能解耦,wake up操作都放在rust侧来做。

image.png

cxx-async

github.com/pcwalton/cx…

cxx 是 Rust/C++ FFI 中比较常用的开源库,不过它并不支持 async FFI。有一个开源库(cxx-async)在其基础上实现了一套,但它基于 C++20 coroutine,而我们用的是 C++17,所以最终没有采用。

C实现一套Future+Waker的abi,C里直接wake

github.com/oxalica/asy…

这是一个开源库里的实现。开源库实现了一个带泛型的future,需要在c里面手动构造Future 与rust Future abi对应。与方案一的区别本质上是把poll,wake等方法实现由rust转移到c里面。

好处是rust侧代码简洁,相应的坏处是c侧的代码复杂,且需要理解rust中poll/wake这一套逻辑。下面从开源库里摘了一个例子。

use anyhow::{Context as _, Result};
use async_ffi::FfiFuture;
use libloading::{Library, Symbol};
use std::time::{Duration, Instant};

// Some plugin fn.
// Signature expected for the `plugin_run` symbol exported by the plugin library:
// takes two u32 and returns an FFI-safe future (async_ffi::FfiFuture) resolving to a u32.
type PluginRunFn = unsafe extern "C" fn(a: u32, b: u32) -> FfiFuture<u32>;

/// Loads the plugin named on the command line, awaits its `plugin_run(42, 1)` future
/// and checks both the result and that the call actually took some time.
#[tokio::main]
async fn main() -> Result<()> {
    let path = std::env::args().nth(1).context("Missing argument")?;

    // SAFETY: loading runs the library's initialisers; the plugin path is trusted input.
    let lib = unsafe { Library::new(&path) }.context("Cannot load library")?;
    // SAFETY: the exported `plugin_run` symbol is assumed to have the PluginRunFn signature.
    let plugin_run: Symbol<PluginRunFn> = unsafe { lib.get(b"plugin_run") }?;

    let started = Instant::now();
    // SAFETY: calls into the plugin; `lib` stays loaded until after the future completes.
    let sum = unsafe { plugin_run(42, 1) }.await;
    // The plugin is expected to sleep asynchronously before producing its result.
    assert!(started.elapsed() > Duration::from_millis(500));
    assert_eq!(sum, 43);
    println!("42 + 1 = {}", sum);

    Ok(())
}
#include <assert.h>
#include <pthread.h>
#include <signal.h>
#include <stdatomic.h>
#include <stdint.h>
#include <stdlib.h>
#include <unistd.h>

/* C-side view of async_ffi's FFI types. Values cross the boundary without
 * translation, so field order and types must match the Rust definitions exactly.
 * NOTE(review): verify against the async_ffi version actually in use. */

/* Function table for a waker: clone / wake / wake_by_ref / drop, in that order. */
struct FfiWakerVTable {
    struct FfiWaker const *(*clone)(struct FfiWaker const *);
    void (*wake)(struct FfiWaker const *);
    void (*wake_by_ref)(struct FfiWaker const *);
    void (*drop)(struct FfiWaker const *);
};

/* Opaque waker handle; every operation goes through its vtable. */
struct FfiWaker {
    struct FfiWakerVTable const *vtable;
};

/* Context passed to poll. It only borrows the waker; fut_poll clones it to keep it. */
struct FfiContext {
    struct FfiWaker const *waker_ref;
};

/* Poll result for a u32 future: is_pending == 1 means Pending; 0 means Ready with
 * `value` holding the result (as produced by fut_poll). */
struct PollU32 {
    uint8_t is_pending;
    union { uint32_t value; };
};

/* C counterpart of FfiFuture<u32>, the return type of PluginRunFn on the Rust side:
 * an opaque state pointer plus its poll and drop functions. */
struct FfiFutureU32 {
    void *fut;
    struct PollU32 (*poll)(void *fut, struct FfiContext *context);
    void (*drop)(void *fut);
};

/* Per-call state behind plugin_run's future; freed by fut_drop. */
struct my_data {
    /* 0 = not started, 1 = worker running, 2 = result ready. Shared with the worker
     * thread and accessed via atomic_load/atomic_store, so it must be _Atomic. */
    _Atomic uint32_t state;
    uint32_t a, b, ret;
    pthread_t handle;              /* worker thread, or 0 when there is none to join */
    struct FfiWaker const *waker;  /* waker cloned on the first poll, used by the worker */
};

/* Worker thread: simulates slow work, publishes the result, then wakes the task. */
static void *handler (void *data_raw) {
    struct my_data *data = (struct my_data *)data_raw;
    usleep(500000);
    data->ret = data->a + data->b;
    atomic_store(&data->state, 2);
    /* Presumably `wake` consumes the waker (like Rust's Waker::wake), releasing the
     * clone taken in fut_poll -- confirm against async_ffi's FfiWakerVTable. */
    (data->waker->vtable->wake)(data->waker);
    /* A pthread start routine returns void *; nobody reads the value. */
    return NULL;
}

static struct PollU32 fut_poll (void *fut, struct FfiContext *context) {
    struct my_data *data = (struct my_data *)fut;
    pthread_t handle;
    switch (atomic_load(&data->state)) {
        case 0:
            data->waker = (context->waker_ref->vtable->clone)(context->waker_ref);
            data->state = 1;
            pthread_create(&data->handle, NULL, handler, (void *)data);
        case 1:
            return (struct PollU32) { .is_pending = 1 };
        case 2:
            pthread_join(data->handle, NULL);
            data->handle = 0;
            return (struct PollU32) { .is_pending = 0, .value = data->ret };
    }
}

/* Drop function for plugin_run's future: frees the state once no thread uses it. */
static void fut_drop(void *fut) {
    struct my_data *data = (struct my_data *)fut;
    if (data->handle != 0) {
        /* The worker still references `data`, so wait for it before freeing.
         * Signalling it instead does not work: SIGKILL cannot be delivered to a single
         * thread and pthread_kill(..., SIGKILL) terminates the whole process. */
        pthread_join(data->handle, NULL);
    }
    free(data);
}

/* Exported entry point: returns a future that yields a + b after the worker's delay. */
struct FfiFutureU32 plugin_run (uint32_t a, uint32_t b) {
    struct my_data *data = malloc(sizeof(struct my_data));
    if (data == NULL) {
        /* FfiFutureU32 has no way to report failure; fail loudly rather than
         * dereference NULL. */
        abort();
    }
    data->handle = 0;
    data->state = 0;
    data->a = a;
    data->b = b;
    /* Initialise every field so nothing is read uninitialised. */
    data->ret = 0;
    data->waker = NULL;
    struct FfiFutureU32 fut = {
        .fut = (void *)data,
        .poll = fut_poll,
        .drop = fut_drop,
    };
    return fut;
}

参考资料

  1. doc.rust-lang.org/nomicon/ffi…
  2. cloud.tencent.com/developer/a…
  3. github.com/yujinliang/…