Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reuse CallTag for streaming calls #414

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions grpc-sys/bindings/x86_64-unknown-linux-gnu-bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5948,6 +5948,11 @@ extern "C" {
extern "C" {
pub fn grpcwrap_slice_length(slice: *const grpc_slice) -> usize;
}
extern "C" {
pub fn grpcwrap_batch_context_take_send_message(
ctx: *mut grpcwrap_batch_context,
) -> *mut grpc_byte_buffer;
}
extern "C" {
pub fn grpcwrap_batch_context_take_recv_message(
ctx: *mut grpcwrap_batch_context,
Expand Down
10 changes: 10 additions & 0 deletions grpc-sys/grpc_wrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,16 @@ GPR_EXPORT size_t GPR_CALLTYPE grpcwrap_slice_length(const grpc_slice* slice) {
return GRPC_SLICE_LENGTH(*slice);
}

GPR_EXPORT grpc_byte_buffer* GPR_CALLTYPE
grpcwrap_batch_context_take_send_message(grpcwrap_batch_context* ctx) {
grpc_byte_buffer* buf = nullptr;
if (ctx->send_message) {
buf = ctx->send_message;
ctx->send_message = nullptr;
}
return buf;
}

GPR_EXPORT grpc_byte_buffer* GPR_CALLTYPE
grpcwrap_batch_context_take_recv_message(grpcwrap_batch_context* ctx) {
grpc_byte_buffer* buf = nullptr;
Expand Down
23 changes: 17 additions & 6 deletions src/call/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::channel::Channel;
use crate::codec::{DeserializeFn, SerializeFn};
use crate::error::{Error, Result};
use crate::metadata::Metadata;
use crate::task::{BatchFuture, BatchType, SpinLock};
use crate::task::{unref_raw_tag, BatchFuture, BatchType, CallTag, SpinLock};

/// Update the flag bit in res.
#[inline]
Expand Down Expand Up @@ -104,7 +104,7 @@ impl Call {
let call = channel.create_call(method, &opt)?;
let mut payload = vec![];
(method.req_ser())(req, &mut payload);
let cq_f = check_run(BatchType::CheckRead, |ctx, tag| unsafe {
let (cq_f, _) = check_run(BatchType::CheckRead, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_start_unary(
call.call,
ctx,
Expand All @@ -127,7 +127,7 @@ impl Call {
mut opt: CallOption,
) -> Result<(ClientCStreamSender<Req>, ClientCStreamReceiver<Resp>)> {
let call = channel.create_call(method, &opt)?;
let cq_f = check_run(BatchType::CheckRead, |ctx, tag| unsafe {
let (cq_f, _) = check_run(BatchType::CheckRead, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_start_client_streaming(
call.call,
ctx,
Expand Down Expand Up @@ -158,7 +158,7 @@ impl Call {
let call = channel.create_call(method, &opt)?;
let mut payload = vec![];
(method.req_ser())(req, &mut payload);
let cq_f = check_run(BatchType::Finish, |ctx, tag| unsafe {
let (cq_f, _) = check_run(BatchType::Finish, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_start_server_streaming(
call.call,
ctx,
Expand Down Expand Up @@ -187,7 +187,7 @@ impl Call {
mut opt: CallOption,
) -> Result<(ClientDuplexSender<Req>, ClientDuplexReceiver<Resp>)> {
let call = channel.create_call(method, &opt)?;
let cq_f = check_run(BatchType::Finish, |ctx, tag| unsafe {
let (cq_f, _) = check_run(BatchType::Finish, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_start_duplex_streaming(
call.call,
ctx,
Expand Down Expand Up @@ -410,8 +410,11 @@ struct ResponseStreamImpl<H, T> {
read_done: bool,
finished: bool,
resp_de: DeserializeFn<T>,
tag: *mut CallTag,
}

unsafe impl<H, T> Send for ResponseStreamImpl<H, T> {}

impl<H: ShareCallHolder, T> ResponseStreamImpl<H, T> {
fn new(call: H, resp_de: DeserializeFn<T>) -> ResponseStreamImpl<H, T> {
ResponseStreamImpl {
Expand All @@ -420,6 +423,7 @@ impl<H: ShareCallHolder, T> ResponseStreamImpl<H, T> {
read_done: false,
finished: false,
resp_de,
tag: ptr::null_mut(),
}
}

Expand Down Expand Up @@ -457,7 +461,8 @@ impl<H: ShareCallHolder, T> ResponseStreamImpl<H, T> {

// so msg_f must be either stale or not initialised yet.
self.msg_f.take();
let msg_f = self.call.call(|c| c.call.start_recv_message())?;
let tag = &mut self.tag;
let msg_f = self.call.call(|c| c.call.start_recv_message(tag))?;
self.msg_f = Some(msg_f);
if let Some(data) = bytes {
let msg = (self.resp_de)(data)?;
Expand All @@ -475,6 +480,12 @@ impl<H: ShareCallHolder, T> ResponseStreamImpl<H, T> {
}
}

impl<H, T> Drop for ResponseStreamImpl<H, T> {
fn drop(&mut self) {
unsafe { unref_raw_tag(self.tag) }
}
}

/// A receiver for server streaming call.
#[must_use = "if unused the ClientSStreamReceiver may immediately cancel the RPC"]
pub struct ClientSStreamReceiver<Resp> {
Expand Down
117 changes: 92 additions & 25 deletions src/call/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::buf::{GrpcByteBuffer, GrpcByteBufferReader};
use crate::codec::{DeserializeFn, Marshaller, SerializeFn};
use crate::error::{Error, Result};
use crate::grpc_sys::grpc_status_code::*;
use crate::task::{self, BatchFuture, BatchType, CallTag, SpinLock};
use crate::task::{self, unref_raw_tag, BatchFuture, BatchType, CallTag, SpinLock};

// By default buffers in `SinkBase` will be shrink to 4K size.
const BUF_SHRINK_SIZE: usize = 4 * 1024;
Expand Down Expand Up @@ -182,6 +182,15 @@ impl BatchContext {
}
}

pub fn take_send_message(&self) -> Option<GrpcByteBuffer> {
let ptr = unsafe { grpc_sys::grpcwrap_batch_context_take_send_message(self.ctx) };
if ptr.is_null() {
None
} else {
Some(unsafe { GrpcByteBuffer::from_raw(ptr) })
}
}

/// Get the status of the rpc call.
pub fn rpc_status(&self) -> RpcStatus {
let status = RpcStatusCode(unsafe {
Expand Down Expand Up @@ -228,20 +237,36 @@ fn box_batch_tag(tag: CallTag) -> (*mut grpcwrap_batch_context, *mut c_void) {
}

/// A helper function that runs the batch call and checks the result.
fn check_run<F>(bt: BatchType, f: F) -> BatchFuture
fn check_run<F>(bt: BatchType, f: F) -> (BatchFuture, *mut CallTag)
where
F: FnOnce(*mut grpcwrap_batch_context, *mut c_void) -> grpc_call_error,
{
let (cq_f, tag) = CallTag::batch_pair(bt);
let (batch_ptr, tag_ptr) = box_batch_tag(tag);
let code = f(batch_ptr, tag_ptr);
if code != grpc_call_error::GRPC_CALL_OK {
unsafe {
Box::from_raw(tag_ptr);
}
drop(unsafe { Box::from_raw(tag_ptr) });
panic!("create call fail: {:?}", code);
}
cq_f
(cq_f, tag_ptr as *mut CallTag)
}

fn check_run_with_tag<F>(tag: *mut CallTag, f: F) -> (BatchFuture, *mut CallTag)
where
F: FnOnce(*mut grpcwrap_batch_context, *mut c_void) -> grpc_call_error,
{
unsafe {
let cq_f = match &*tag {
CallTag::Batch(promise) => promise.cq_future(),
_ => unreachable!(),
};
let ctx = (*tag).batch_ctx().unwrap().as_ptr();
let code = f(ctx, tag as *mut c_void);
if code != grpc_call_error::GRPC_CALL_OK {
panic!("create call fail: {:?}", code);
}
(cq_f, tag)
}
}

/// A Call represents an RPC.
Expand All @@ -268,38 +293,55 @@ impl Call {
msg: &[u8],
write_flags: u32,
initial_meta: bool,
batch: &mut *mut CallTag,
) -> Result<BatchFuture> {
let _cq_ref = self.cq.borrow()?;
let ptr = msg.as_ptr() as _;
let len = msg.len();
let i = if initial_meta { 1 } else { 0 };
let f = check_run(BatchType::Finish, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_send_message(
self.call,
ctx,
msg.as_ptr() as _,
msg.len(),
write_flags,
i,
tag,
)
});
let send_message = |ctx, tag| unsafe {
match *(tag as *mut CallTag) {
CallTag::Batch(ref prom) => prom.ref_batch(),
_ => unreachable!(),
}
grpc_sys::grpcwrap_call_send_message(self.call, ctx, ptr, len, write_flags, i, tag)
};

let (f, tag) = if !batch.is_null() {
check_run_with_tag(*batch, send_message)
} else {
check_run(BatchType::Finish, send_message)
};
*batch = tag;
Ok(f)
}

/// Finish the rpc call from client.
pub fn start_send_close_client(&mut self) -> Result<BatchFuture> {
let _cq_ref = self.cq.borrow()?;
let f = check_run(BatchType::Finish, |_, tag| unsafe {
let (f, _) = check_run(BatchType::Finish, |_, tag| unsafe {
grpc_sys::grpcwrap_call_send_close_from_client(self.call, tag)
});
Ok(f)
}

/// Receive a message asynchronously.
pub fn start_recv_message(&mut self) -> Result<BatchFuture> {
pub fn start_recv_message(&mut self, batch: &mut *mut CallTag) -> Result<BatchFuture> {
let _cq_ref = self.cq.borrow()?;
let f = check_run(BatchType::Read, |ctx, tag| unsafe {
let recv_message = |ctx, tag| unsafe {
match *(tag as *mut CallTag) {
CallTag::Batch(ref prom) => prom.ref_batch(),
_ => unreachable!(),
}
grpc_sys::grpcwrap_call_recv_message(self.call, ctx, tag)
});
};

let (f, tag) = if !batch.is_null() {
check_run_with_tag(*batch, recv_message)
} else {
check_run(BatchType::Read, recv_message)
};
*batch = tag;
Ok(f)
}

Expand All @@ -308,7 +350,7 @@ impl Call {
/// Future will finish once close is received by the server.
pub fn start_server_side(&mut self) -> Result<BatchFuture> {
let _cq_ref = self.cq.borrow()?;
let f = check_run(BatchType::Finish, |ctx, tag| unsafe {
let (f, _) = check_run(BatchType::Finish, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_start_serverside(self.call, ctx, tag)
});
Ok(f)
Expand All @@ -327,7 +369,7 @@ impl Call {
let (payload_ptr, payload_len) = payload
.as_ref()
.map_or((ptr::null(), 0), |b| (b.as_ptr(), b.len()));
let f = check_run(BatchType::Finish, |ctx, tag| unsafe {
let (f, _) = check_run(BatchType::Finish, |ctx, tag| unsafe {
let details_ptr = status
.details
.as_ref()
Expand Down Expand Up @@ -487,14 +529,21 @@ struct StreamingBase {
close_f: Option<BatchFuture>,
msg_f: Option<BatchFuture>,
read_done: bool,

// `tag` can be reused during the stream's lifetime.
tag: *mut CallTag,
}

// Because it carrys a `CallTag`.
unsafe impl Send for StreamingBase {}

impl StreamingBase {
fn new(close_f: Option<BatchFuture>) -> StreamingBase {
StreamingBase {
close_f,
msg_f: None,
read_done: false,
tag: ptr::null_mut(),
}
}

Expand Down Expand Up @@ -539,7 +588,7 @@ impl StreamingBase {

// so msg_f must be either stale or not initialised yet.
self.msg_f.take();
let msg_f = call.call(|c| c.call.start_recv_message())?;
let msg_f = call.call(|c| c.call.start_recv_message(&mut self.tag))?;
self.msg_f = Some(msg_f);
if bytes.is_none() {
self.poll(call, true)
Expand All @@ -557,6 +606,12 @@ impl StreamingBase {
}
}

impl Drop for StreamingBase {
fn drop(&mut self) {
unsafe { unref_raw_tag(self.tag) }
}
}

/// Flags for write operations.
#[derive(Default, Clone, Copy)]
pub struct WriteFlags {
Expand Down Expand Up @@ -603,14 +658,20 @@ struct SinkBase {
batch_f: Option<BatchFuture>,
buf: Vec<u8>,
send_metadata: bool,

tag: *mut CallTag,
}

// Because it carrys a `CallTag`.
unsafe impl Send for SinkBase {}

impl SinkBase {
fn new(send_metadata: bool) -> SinkBase {
SinkBase {
batch_f: None,
buf: Vec::new(),
send_metadata,
tag: ptr::null_mut(),
}
}

Expand All @@ -637,7 +698,7 @@ impl SinkBase {
}
let write_f = call.call(|c| {
c.call
.start_send_message(&self.buf, flags.flags, self.send_metadata)
.start_send_message(&self.buf, flags.flags, self.send_metadata, &mut self.tag)
})?;
// NOTE: Content of `self.buf` is copied into grpc internal.
if self.buf.capacity() > BUF_SHRINK_SIZE {
Expand All @@ -658,3 +719,9 @@ impl SinkBase {
Ok(Async::Ready(()))
}
}

impl Drop for SinkBase {
fn drop(&mut self) {
unsafe { unref_raw_tag(self.tag) }
}
}
5 changes: 2 additions & 3 deletions src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::thread::{Builder as ThreadBuilder, JoinHandle};
use crate::grpc_sys;

use crate::cq::{CompletionQueue, CompletionQueueHandle, EventType, WorkQueue};
use crate::task::CallTag;
use crate::task::{self, CallTag};

// event loop
fn poll_queue(tx: mpsc::Sender<CompletionQueue>) {
Expand All @@ -26,8 +26,7 @@ fn poll_queue(tx: mpsc::Sender<CompletionQueue>) {
}

let tag: Box<CallTag> = unsafe { Box::from_raw(e.tag as _) };

tag.resolve(&cq, e.success != 0);
task::resolve(tag, &cq, e.success != 0);
while let Some(work) = unsafe { cq.worker.pop_work() } {
work.finish(&cq);
}
Expand Down
Loading