From ee55d2892175d391859bb5f7be0b5f25a7f2f47c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cholewi=C5=84ski?= Date: Mon, 21 Oct 2024 17:11:39 +0200 Subject: [PATCH 1/7] write request before fist response --- lib/eio/client/client.ml | 180 +++++++++--------- lib/eio/client/client.mli | 28 ++- lib/eio/core/recv_seq.ml | 13 +- .../io_server_h2_ocaml_protoc.ml | 1 - 4 files changed, 129 insertions(+), 93 deletions(-) diff --git a/lib/eio/client/client.ml b/lib/eio/client/client.ml index afd7ae2..2030aed 100644 --- a/lib/eio/client/client.ml +++ b/lib/eio/client/client.ml @@ -121,92 +121,102 @@ module Bidirectional_streaming = struct (_, headers, stream_error, conn_error, net_response) result' = match call ~sw ~io ~service ~method_name ~headers () with | Ok { writer; recv; grpc_status; write_exn } -> ( - match Eio.Promise.await recv with - | Ok { net_response; recv_seq; trailers } -> - let (module Io') = io in - if Io'.Net_response.is_ok net_response then ( - let error = ref None in - let closed = ref false in - let writer = - { - write = writer.write; - close = - (fun () -> - writer.close (); - closed := true); - } - in - let rec read recv_seq' () = - match recv_seq' () with - | Grpc_eio_core.Recv_seq.Done -> Seq.Nil - | Err e -> - let () = error := Some e in - Seq.Nil - | Next (t, next) -> Seq.Cons (t, fun () -> read next ()) - in + let closed = ref false in + let writer = + { + write = + (fun req -> + let result = writer.write req in - let res = f net_response ~writer ~read:(read recv_seq) in - if not !closed then writer.close (); - match !error with - | Some error -> - `Stream_result - { - result = res; - trailers = Eio.Promise.await trailers; - err = - Some + result); + close = + (fun () -> + writer.close (); + closed := true); + } + in + let error = ref None in + let res = + f + (fun () -> + match Eio.Promise.await recv with + | Ok { net_response; _ } -> Ok net_response + | Error e -> Error e) + ~writer + ~read:(fun () -> + match Eio.Promise.await recv with + | Ok { net_response; recv_seq; _ } -> + let (module Io') = io in + if Io'.Net_response.is_ok net_response then + let rec read recv_seq' () = + match recv_seq' () with + | Grpc_eio_core.Recv_seq.Done -> Seq.Nil + | Err e -> + let () = error := Some e in + Seq.Nil + | Next (t, next) -> Seq.Cons (t, fun () -> read next ()) + in + read recv_seq () + else Seq.Nil + | Error _ -> Seq.Nil) + in + + match Eio.Promise.await recv with + | Error _e -> Obj.magic () + | Ok { net_response = _; trailers; _ } -> ( + if not !closed then writer.close (); + match !error with + | Some error -> + `Stream_result + { + result = res; + trailers = Eio.Promise.await trailers; + err = + Some + { + stream_error = Some error; + grpc_status = Eio.Promise.await grpc_status; + write_exn = !write_exn; + }; + } + | None -> ( + let status = Eio.Promise.await grpc_status in + match Grpc.Status.code status with + | Grpc.Status.OK -> ( + match !write_exn with + | None -> + `Stream_result { - stream_error = Some error; - grpc_status = Eio.Promise.await grpc_status; - write_exn = !write_exn; - }; - } - | None -> ( - let status = Eio.Promise.await grpc_status in - match Grpc.Status.code status with - | Grpc.Status.OK -> ( - match !write_exn with - | None -> - `Stream_result - { - result = res; - err = None; - trailers = Eio.Promise.await trailers; - } - | Some _ -> - `Stream_result + result = res; + err = None; + trailers = Eio.Promise.await trailers; + } + | Some _ -> + `Stream_result + { + result = res; + trailers = Eio.Promise.await trailers; + err = + Some + { + write_exn = !write_exn; + grpc_status = Eio.Promise.await grpc_status; + stream_error = None; + }; + }) + | _ -> + `Stream_result + { + result = res; + trailers = Eio.Promise.await trailers; + err = + Some { - result = res; - trailers = Eio.Promise.await trailers; - err = - Some - { - write_exn = !write_exn; - grpc_status = Eio.Promise.await grpc_status; - stream_error = None; - }; - }) - | _ -> - `Stream_result - { - result = res; - trailers = Eio.Promise.await trailers; - err = - Some - { - grpc_status = status; - stream_error = None; - write_exn = !write_exn; - }; - })) - else - `Response_not_ok - { - net_response; - grpc_status = Eio.Promise.await grpc_status; - trailers = Eio.Promise.await trailers; - } - | Error e -> `Connection_error e) + grpc_status = status; + stream_error = None; + write_exn = !write_exn; + }; + }))) | Error e -> `Connection_error e end @@ -281,8 +291,8 @@ module Unary = struct trailers = Eio.Promise.await trailers; } | _ -> - (* Not reachable under normal circumstances - https://github.com/grpc/grpc/issues/12824 *) + (* Not reachable under normal circumstances + https://github.com/grpc/grpc/issues/12824 *) `Response_not_ok { net_response; diff --git a/lib/eio/client/client.mli b/lib/eio/client/client.mli index be6bf2c..e404749 100644 --- a/lib/eio/client/client.mli +++ b/lib/eio/client/client.mli @@ -167,7 +167,9 @@ module Server_streaming : sig method_name:string -> headers:Grpc_client.request_headers -> 'request -> - ('net_response -> read:(unit -> 'response Seq.node) -> 'a) -> + ((unit -> ('net_response, 'conn_err) result) -> + read:(unit -> 'response Seq.node) -> + 'a) -> [ `Stream_result of ('a, 'headers, 'stream_error) streaming_result | `Write_error of ('stream_error, 'headers) streaming_err option * 'headers | ('net_response, 'headers, 'conn_err) common_error ] @@ -177,7 +179,7 @@ module Bidirectional_streaming : sig type ('a, 'headers, 'stream_err, 'conn_err, 'net_response) result' = [ `Stream_result of ('a, 'headers, 'stream_err) streaming_result | ('net_response, 'headers, 'conn_err) common_error ] - + (* val call : sw:Eio.Switch.t -> io: @@ -191,7 +193,29 @@ module Bidirectional_streaming : sig service:string -> method_name:string -> headers:Grpc_client.request_headers -> + ?init_requests:'request Seq.t -> ('net_response -> + 'request Seq.t -> + writer:'request writer -> + read:(unit -> 'response Seq.node) -> + 'a) -> + ('a, 'headers, 'stream_error, 'conn_error, 'net_response) result' + *) + + val call : + sw:Eio.Switch.t -> + io: + ( 'headers, + 'net_response, + 'request, + 'response, + 'stream_error, + 'conn_error ) + Io.t -> + service:string -> + method_name:string -> + headers:Grpc_client.request_headers -> + ((unit -> ('net_response, 'conn_error) result) -> writer:'request writer -> read:(unit -> 'response Seq.node) -> 'a) -> diff --git a/lib/eio/core/recv_seq.ml b/lib/eio/core/recv_seq.ml index 7818c22..299db65 100644 --- a/lib/eio/core/recv_seq.ml +++ b/lib/eio/core/recv_seq.ml @@ -1,8 +1,7 @@ type ('a, 'err) t = unit -> ('a, 'err) recv_item and ('a, 'err) recv_item = Done | Next of 'a * ('a, 'err) t | Err of 'err -let rec map f recv = - fun () -> +let rec map f recv () = match recv () with | Done -> Done | Next (x, recv) -> Next (f x, map f recv) @@ -13,8 +12,12 @@ let to_seq ?err_to_exn recv = match recv () with | Done -> Seq.Nil | Next (x, recv) -> Seq.Cons (x, loop recv) - | Err err -> match err_to_exn with - | None -> failwith "Unexpected error on read. Implement err_to_exn for a more granular error." - | Some f -> raise (f err) + | Err err -> ( + match err_to_exn with + | None -> + failwith + "Unexpected error on read. Implement err_to_exn for a more \ + granular error." + | Some f -> raise (f err)) in loop recv diff --git a/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.ml b/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.ml index f195ce4..09657b6 100644 --- a/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.ml +++ b/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.ml @@ -1,6 +1,5 @@ exception Unexpected_eof - module Io = struct type request = Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer type response = Pbrt.Encoder.t -> unit From 7b4a6555decd311f8b77afb338b3b45844dbb19a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cholewi=C5=84ski?= Date: Tue, 12 Nov 2024 13:00:50 +0100 Subject: [PATCH 2/7] change API for better error handling --- .../Route_guide/index.html | 2 +- examples/routeguide/src/client.ml | 366 +++++++++++++---- lib/eio/arpaca/bin/codegen.ml | 274 ++++++++++++- lib/eio/arpaca/integration_tests/client.ml | 16 +- lib/eio/client/client.ml | 382 +++++++++--------- lib/eio/client/client.mli | 87 ++-- lib/eio/client/io.ml | 57 ++- lib/eio/client/rpc_error.ml | 70 ++++ lib/eio/core/body_reader.ml | 32 +- lib/eio/core/recv_seq.ml | 14 +- .../io_client_h2_ocaml_protoc.ml | 111 +++-- .../io_client_h2_ocaml_protoc.mli | 4 +- .../io_server_h2_ocaml_protoc.ml | 28 +- .../io_server_h2_ocaml_protoc.mli | 6 +- 14 files changed, 1035 insertions(+), 414 deletions(-) create mode 100644 lib/eio/client/rpc_error.ml diff --git a/docs/grpc-examples/Routeguide_proto_async/Route_guide/index.html b/docs/grpc-examples/Routeguide_proto_async/Route_guide/index.html index 3962263..637dcd1 100644 --- a/docs/grpc-examples/Routeguide_proto_async/Route_guide/index.html +++ b/docs/grpc-examples/Routeguide_proto_async/Route_guide/index.html @@ -1,2 +1,2 @@ -Route_guide (grpc-examples.Routeguide_proto_async.Route_guide)

Module Routeguide_proto_async.Route_guide

module Routeguide : sig ... end
\ No newline at end of file +Route_guide (grpc-examples.Routeguide_proto_async.Route_guide)

Module Routeguide_proto_async.Route_guide

module Routeguide : sig ... end
diff --git a/examples/routeguide/src/client.ml b/examples/routeguide/src/client.ml index d039077..a091f22 100644 --- a/examples/routeguide/src/client.ml +++ b/examples/routeguide/src/client.ml @@ -1,45 +1,78 @@ open Routeguide -module Client = Grpc_client_eio.Client +open Grpc_client_eio -let get_feature sw io request = +let random_point () = + let latitude = (Random.int 180 - 90) * 10000000 in + let longitude = (Random.int 360 - 180) * 10000000 in + Route_guide.default_point ~latitude ~longitude () + +let get_feature (type headers net_response stream_error connection_error) sw + (io : + ( headers, + net_response, + Pbrt.Encoder.t -> unit, + Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer, + stream_error, + connection_error ) + Grpc_client_eio.Io.t) request = let response = Client.Unary.call ~sw ~io ~service:"routeguide.RouteGuide" ~method_name:"GetFeature" ~headers:(Grpc_client.make_request_headers `Proto) (fun encoder -> Route_guide.encode_pb_point request encoder) in + let (module Io') = io in match response with | `Success ({ response = res; _ } as result) -> - `Success - { - result with - response = - res.Grpc_eio_core.Body_reader.consume Route_guide.decode_pb_feature; - } - | ( `Premature_close _ | `Write_error _ | `Connection_error _ - | `Response_not_ok _ ) as rest -> - rest + { + result with + response = + res.Grpc_eio_core.Body_reader.consume Route_guide.decode_pb_feature; + } + | #Rpc_error.Unary.error' as rest -> Io'.raise_client_error (Unary rest) + +let _ = get_feature + +let run_record_route (type headers net_response stream_error connection_error) + sw + (io : + ( headers, + net_response, + Pbrt.Encoder.t -> unit, + Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer, + stream_error, + connection_error ) + Grpc_client_eio.Io.t) = + let points = + Random.int 100 + |> Seq.unfold (function 0 -> None | x -> Some (random_point (), x - 1)) + in -(* $MDX part-end *) -(* $MDX part-begin=client-get-feature *) -let call_get_feature sw io point = let response = - Client.Unary.call ~sw ~io ~service:"routeguide.RouteGuide" - ~method_name:"GetFeature" + Client.Client_streaming.call ~io ~sw ~service:"routeguide.RouteGuide" ~headers:(Grpc_client.make_request_headers `Proto) - (fun encoder -> Route_guide.encode_pb_point point encoder) + ~method_name:"RecordRoute" (fun _ ~writer -> + Seq.iter + (fun point -> + writer.write (Route_guide.encode_pb_point point) |> ignore; + Printf.printf "SENT = {%s}\n%!" (Route_guide.show_point point)) + points) in match response with - | `Success { response = res; _ } -> - Printf.printf "RESPONSE = {%s}%!" - (Route_guide.show_feature - (res.Grpc_eio_core.Body_reader.consume Route_guide.decode_pb_feature)) - | _ -> Printf.printf "an error occurred" + | `Success resp -> resp + | #Rpc_error.Client_streaming.error' as rest -> + let (module Io') = io in + Io'.raise_client_error (Client_streaming rest) -(* $MDX part-end *) - -(* $MDX part-begin=client-list-features *) -let print_features sw io = +let print_features (type headers net_response stream_error connection_error) sw + (io : + ( headers, + net_response, + Pbrt.Encoder.t -> unit, + Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer, + stream_error, + connection_error ) + Grpc_client_eio.Io.t) = let rectangle = Route_guide.default_rectangle ~lo: @@ -67,43 +100,21 @@ let print_features sw io = read) in match stream with - | `Stream_result { err = None; _ } -> () - | _ -> failwith "an erra" + | `Stream_result_success result -> result + | #Rpc_error.Server_streaming.error' as rest -> + let (module Io') = io in + Io'.raise_client_error (Server_streaming rest) -(* $MDX part-end *) -(* $MDX part-begin=client-record-route *) -let random_point () = - let latitude = (Random.int 180 - 90) * 10000000 in - let longitude = (Random.int 360 - 180) * 10000000 in - Route_guide.default_point ~latitude ~longitude () - -let run_record_route sw io = - let points = - Random.int 100 - |> Seq.unfold (function 0 -> None | x -> Some (random_point (), x - 1)) - in - - let response = - Client.Client_streaming.call ~io ~sw ~service:"routeguide.RouteGuide" - ~headers:(Grpc_client.make_request_headers `Proto) - ~method_name:"RecordRoute" (fun _ ~writer -> - Seq.iter - (fun point -> - writer.write (Route_guide.encode_pb_point point) |> ignore; - Printf.printf "SENT = {%s}\n%!" (Route_guide.show_point point)) - points) - in - match response with - | `Success { response; _ } -> - Printf.printf "SUMMARY = {%s}\n%!" - (Route_guide.show_route_summary - (response.Grpc_eio_core.Body_reader.consume - Route_guide.decode_pb_route_summary)) - | _ -> failwith "Error occured" - -(* $MDX part-end *) -(* $MDX part-begin=client-route-chat-1 *) -let run_route_chat clock io sw = +let run_route_chat (type headers net_response stream_error connection_error) + clock + (io : + ( headers, + net_response, + Pbrt.Encoder.t -> unit, + Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer, + stream_error, + connection_error ) + Grpc_client_eio.Io.t) sw = (* Generate locations. *) let location_count = 5 in Printf.printf "Generating %i locations\n" location_count; @@ -146,13 +157,215 @@ let run_route_chat clock io sw = []) in match result with - | `Stream_result { err = None; _ } -> () - | _e -> failwith "Error" + | `Stream_result_success result -> result + | #Rpc_error.Bidirectional_streaming.error' as rest -> + let (module Io') = io in + Io'.raise_client_error (Bidirectional_streaming rest) -(* $MDX part-end *) -(* $MDX part-end *) +module Expert = struct + let run_route_chat clock io sw = + (* Generate locations. *) + let location_count = 5 in + Printf.printf "Generating %i locations\n" location_count; + let route_notes = + location_count + |> Seq.unfold (function + | 0 -> None + | x -> + Some + ( Route_guide.default_route_note + ~location:(Some (random_point ())) + ~message:(Printf.sprintf "Random Message %i" x) + (), + x - 1 )) + in + (* $MDX part-end *) + (* $MDX part-begin=client-route-chat-2 *) + let rec go ~send ~close reader notes = + match Seq.uncons notes with + | None -> () (* Signal no more notes from the server. *) + | Some (route_note, xs) -> ( + send (Route_guide.encode_pb_route_note route_note) |> ignore; + + Eio.Time.sleep clock 1.0; + + match reader () with + | Seq.Nil -> failwith "Expecting response" + | Seq.Cons (route_note, reader') -> + Printf.printf "NOTE = {%s}\n%!" + (Route_guide.show_route_note + (route_note.Grpc_eio_core.Body_reader.consume + Route_guide.decode_pb_route_note)); + go ~send ~close reader' xs) + in + Client.Bidirectional_streaming.call ~service:"routeguide.RouteGuide" + ~method_name:"RouteChat" ~io ~sw + ~headers:(Grpc_client.make_request_headers `Proto) (fun _ ~writer ~read -> + go ~send:writer.write ~close:writer.close read route_notes; + []) + + let print_features sw io = + let rectangle = + Route_guide.default_rectangle + ~lo: + (Some + (Route_guide.default_point ~latitude:400000000 + ~longitude:(-750000000) ())) + ~hi: + (Some + (Route_guide.default_point ~latitude:420000000 + ~longitude:(-730000000) ())) + () + in + + Client.Server_streaming.call ~sw ~io ~service:"routeguide.RouteGuide" + ~method_name:"ListFeatures" + ~headers:(Grpc_client.make_request_headers `Proto) + (Route_guide.encode_pb_rectangle rectangle) (fun _ ~read -> + Seq.iter + (fun f -> + Printf.printf "RESPONSE = {%s}%!" + (Route_guide.show_feature + (f.Grpc_eio_core.Body_reader.consume + Route_guide.decode_pb_feature))) + read) + + let get_feature sw io request = + Client.Unary.call ~sw ~io ~service:"routeguide.RouteGuide" + ~method_name:"GetFeature" + ~headers:(Grpc_client.make_request_headers `Proto) (fun encoder -> + Route_guide.encode_pb_point request encoder) + + let _ = get_feature + + let run_record_route sw io = + let points = + Random.int 100 + |> Seq.unfold (function 0 -> None | x -> Some (random_point (), x - 1)) + in + + Client.Client_streaming.call ~io ~sw ~service:"routeguide.RouteGuide" + ~headers:(Grpc_client.make_request_headers `Proto) + ~method_name:"RecordRoute" (fun _ ~writer -> + Seq.iter + (fun point -> + writer.write (Route_guide.encode_pb_point point) |> ignore; + Printf.printf "SENT = {%s}\n%!" (Route_guide.show_point point)) + points) +end + +module Result = struct + let run_route_chat clock io sw = + (* Generate locations. *) + let location_count = 5 in + Printf.printf "Generating %i locations\n" location_count; + let route_notes = + location_count + |> Seq.unfold (function + | 0 -> None + | x -> + Some + ( Route_guide.default_route_note + ~location:(Some (random_point ())) + ~message:(Printf.sprintf "Random Message %i" x) + (), + x - 1 )) + in + (* $MDX part-end *) + (* $MDX part-begin=client-route-chat-2 *) + let rec go ~send ~close reader notes = + match Seq.uncons notes with + | None -> () (* Signal no more notes from the server. *) + | Some (route_note, xs) -> ( + send (Route_guide.encode_pb_route_note route_note) |> ignore; + + Eio.Time.sleep clock 1.0; + + match reader () with + | Seq.Nil -> failwith "Expecting response" + | Seq.Cons (route_note, reader') -> + Printf.printf "NOTE = {%s}\n%!" + (Route_guide.show_route_note + (route_note.Grpc_eio_core.Body_reader.consume + Route_guide.decode_pb_route_note)); + go ~send ~close reader' xs) + in + let result = + Client.Bidirectional_streaming.call ~service:"routeguide.RouteGuide" + ~method_name:"RouteChat" ~io ~sw + ~headers:(Grpc_client.make_request_headers `Proto) + (fun _ ~writer ~read -> + go ~send:writer.write ~close:writer.close read route_notes; + []) + in + match result with + | `Stream_result_success result -> Ok result + | #Rpc_error.Bidirectional_streaming.error' as rest -> Error rest + + let print_features sw io = + let rectangle = + Route_guide.default_rectangle + ~lo: + (Some + (Route_guide.default_point ~latitude:400000000 + ~longitude:(-750000000) ())) + ~hi: + (Some + (Route_guide.default_point ~latitude:420000000 + ~longitude:(-730000000) ())) + () + in + + let stream = + Client.Server_streaming.call ~sw ~io ~service:"routeguide.RouteGuide" + ~method_name:"ListFeatures" + ~headers:(Grpc_client.make_request_headers `Proto) + (Route_guide.encode_pb_rectangle rectangle) (fun _ ~read -> + Seq.iter + (fun f -> + Printf.printf "RESPONSE = {%s}%!" + (Route_guide.show_feature + (f.Grpc_eio_core.Body_reader.consume + Route_guide.decode_pb_feature))) + read) + in + match stream with + | `Stream_result_success result -> Ok result + | #Rpc_error.Server_streaming.error' as rest -> Error rest + + let run_record_route sw io = + let points = + Random.int 100 + |> Seq.unfold (function 0 -> None | x -> Some (random_point (), x - 1)) + in + + let response = + Client.Client_streaming.call ~io ~sw ~service:"routeguide.RouteGuide" + ~headers:(Grpc_client.make_request_headers `Proto) + ~method_name:"RecordRoute" (fun _ ~writer -> + Seq.iter + (fun point -> + writer.write (Route_guide.encode_pb_point point) |> ignore; + Printf.printf "SENT = {%s}\n%!" (Route_guide.show_point point)) + points) + in + match response with + | `Success resp -> Ok resp + | #Rpc_error.Client_streaming.error' as rest -> Error rest + + let get_feature sw io request = + let response = + Client.Unary.call ~sw ~io ~service:"routeguide.RouteGuide" + ~method_name:"GetFeature" + ~headers:(Grpc_client.make_request_headers `Proto) (fun encoder -> + Route_guide.encode_pb_point request encoder) + in + match response with + | `Success resp -> Ok resp + | #Rpc_error.Unary.error' as rest -> Error rest -(* $MDX part-begin=client-main *) + let _ = get_feature +end let main env = let clock = Eio.Stdenv.clock env in @@ -170,17 +383,24 @@ let main env = let request = Route_guide.default_point ~latitude:409146138 ~longitude:(-746188906) () in - let result = call_get_feature sw io request in + get_feature sw io request |> ignore; + Expert.get_feature sw io request |> ignore; + Result.get_feature sw io request |> ignore; Printf.printf "\n*** SERVER STREAMING ***\n%!"; - print_features sw io; + print_features sw io |> ignore; + Expert.print_features sw io |> ignore; + Result.print_features sw io |> ignore; Printf.printf "\n*** CLIENT STREAMING ***\n%!"; - run_record_route sw io; + run_record_route sw io |> ignore; + Expert.run_record_route |> ignore; + Result.run_record_route |> ignore; Printf.printf "\n*** BIDIRECTIONAL STREAMING ***\n%!"; - run_route_chat clock io sw; - result + run_route_chat clock io sw |> ignore; + Expert.run_route_chat clock io sw |> ignore; + Result.run_route_chat clock io sw |> ignore in Eio.Switch.run run @@ -189,7 +409,7 @@ let () = Eio_main.run main (* $MDX part-end *) -let list_features ~sw ~io request handler = +let _list_features ~sw ~io request handler = Client.Server_streaming.call ~sw ~io ~service:"routeguide.RouteGuide" ~method_name:"ListFeatures" ~headers:(Grpc_client.make_request_headers `Proto) diff --git a/lib/eio/arpaca/bin/codegen.ml b/lib/eio/arpaca/bin/codegen.ml index 1d2a3fc..14e9bd3 100644 --- a/lib/eio/arpaca/bin/codegen.ml +++ b/lib/eio/arpaca/bin/codegen.ml @@ -59,7 +59,257 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit = let typ_mod_name = String.capitalize_ascii proto_gen_module in let service_name = service.service_name in - let gen_rpc sc i (rpc : Ot.rpc) = + let gen_exn_rpc sc i (rpc : Ot.rpc) = + if i > 0 then F.empty_line sc; + let rpc_name = rpc.rpc_name in + match rpc_kind rpc.rpc_req rpc.rpc_res with + | `Unary -> + F.linep sc + {|let %s (type headers net_response stream_error connection_error) ~sw ~(io : + ( headers, + net_response, + Pbrt.Encoder.t -> unit, + Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer, + stream_error, + connection_error ) + Grpc_client_eio.Io.t) request = + let response = + Grpc_client_eio.Client.Unary.call ~sw ~io ~service:"%s.%s" + ~method_name:%S + ~headers:(Grpc_client.make_request_headers `Proto) + (%s.%s request) + in + let (module Io') = io in + match response with + | `Success ({ response = res; _ } as result) -> + { + result with + response = + res.Grpc_eio_core.Body_reader.consume %s.%s; + } + | #Grpc_client_eio.Rpc_error.Unary.error' as rest -> Io'.raise_client_error (Unary rest)|} + (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) + (service_name_of_package service.service_packages) + service.service_name rpc.rpc_name typ_mod_name + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) + typ_mod_name + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) + | `Server_streaming -> + F.linep sc + {|let %s (type headers net_response stream_error connection_error) ~sw ~(io : + ( headers, + net_response, + Pbrt.Encoder.t -> unit, + Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer, + stream_error, + connection_error ) + Grpc_client_eio.Io.t) request handler = + let stream = + Grpc_client_eio.Client.Server_streaming.call ~sw ~io ~service:"%s.%s" + ~method_name:"%s" + ~headers:(Grpc_client.make_request_headers `Proto) + (%s.%s request) (fun net_response ~read -> + let responses = + Seq.map + (fun response -> + response.Grpc_eio_core.Body_reader.consume + %s.%s) + read + in + let (module Io') = io in + handler net_response responses) + in + let (module Io') = io in + match stream with + | `Stream_result_success result -> result + | #Grpc_client_eio.Rpc_error.Server_streaming.error' as rest -> Io'.raise_client_error (Server_streaming rest) +|} + (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) + (service_name_of_package service.service_packages) + service.service_name rpc.rpc_name typ_mod_name + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) + typ_mod_name + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) + | `Client_streaming -> + F.linep sc + {|let %s (type headers net_response stream_error connection_error) ~sw ~(io : + ( headers, + net_response, + Pbrt.Encoder.t -> unit, + Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer, + stream_error, + connection_error ) + Grpc_client_eio.Io.t) handler = + let response = + Grpc_client_eio.Client.Client_streaming.call ~sw ~io ~service:"%s.%s" + ~method_name:"%s" + ~headers:(Grpc_client.make_request_headers `Proto) + (fun net_response ~writer -> + let writer' req = writer.write (%s.%s req) in + handler net_response ~writer:writer') + in + let (module Io') = io in + match response with + | `Success ({ response = res; _ } as result) -> + { + result with + response = + res.Grpc_eio_core.Body_reader.consume + %s.%s; + } + | #Grpc_client_eio.Rpc_error.Client_streaming.error' as rest -> Io'.raise_client_error (Client_streaming rest)|} + (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) + (service_name_of_package service.service_packages) + service.service_name rpc.rpc_name typ_mod_name + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) + typ_mod_name + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) + | `Bidirectional_streaming -> + F.linep sc + {|let %s (type headers net_response stream_error connection_error) ~sw ~(io : + ( headers, + net_response, + Pbrt.Encoder.t -> unit, + Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer, + stream_error, + connection_error ) + Grpc_client_eio.Io.t) handler = + let stream = + Grpc_client_eio.Client.Bidirectional_streaming.call ~sw ~io ~service:"%s.%s" + ~method_name:"%s" + ~headers:(Grpc_client.make_request_headers `Proto) + (fun net_response ~writer ~read -> + let writer' req = writer.write (%s.%s req) in + let read' = + Seq.map + (fun response -> + response.Grpc_eio_core.Body_reader.consume + %s.%s) + read + in + handler net_response ~writer:writer' ~read:read') + in + let (module Io') = io in + match stream with + | `Stream_result_success result -> result + | #Grpc_client_eio.Rpc_error.Bidirectional_streaming.error' as rest -> Io'.raise_client_error (Bidirectional_streaming rest)|} + (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) + (service_name_of_package service.service_packages) + service.service_name rpc.rpc_name typ_mod_name + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) + typ_mod_name + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) + in + let gen_result_rpc sc i (rpc : Ot.rpc) = + if i > 0 then F.empty_line sc; + let rpc_name = rpc.rpc_name in + match rpc_kind rpc.rpc_req rpc.rpc_res with + | `Unary -> + F.linep sc + {|let %s ~sw ~io request = + let response = + Grpc_client_eio.Client.Unary.call ~sw ~io ~service:"%s.%s" + ~method_name:%S + ~headers:(Grpc_client.make_request_headers `Proto) + (%s.%s request) + in + match response with + | `Success ({ response = res; _ } as result) -> + Ok + { + result with + response = + res.Grpc_eio_core.Body_reader.consume %s.%s; + } + | #Grpc_client_eio.Rpc_error.Unary.error' as rest -> Error rest|} + (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) + (service_name_of_package service.service_packages) + service.service_name rpc.rpc_name typ_mod_name + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) + typ_mod_name + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) + | `Server_streaming -> + F.linep sc + {|let %s ~sw ~io request handler = + let stream = + Grpc_client_eio.Client.Server_streaming.call ~sw ~io ~service:"%s.%s" + ~method_name:"%s" + ~headers:(Grpc_client.make_request_headers `Proto) + (%s.%s request) (fun net_response ~read -> + let responses = + Seq.map + (fun response -> + response.Grpc_eio_core.Body_reader.consume + %s.%s) + read + in + handler net_response responses) + in + match stream with + | `Stream_result_success result -> Ok result + | #Grpc_client_eio.Rpc_error.Server_streaming.error' as rest -> Error rest|} + (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) + (service_name_of_package service.service_packages) + service.service_name rpc.rpc_name typ_mod_name + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) + typ_mod_name + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) + | `Client_streaming -> + F.linep sc + {|let %s ~sw ~io handler = + let response = + Grpc_client_eio.Client.Client_streaming.call ~sw ~io ~service:"%s.%s" + ~method_name:"%s" + ~headers:(Grpc_client.make_request_headers `Proto) + (fun net_response ~writer -> + let writer' req = writer.write (%s.%s req) in + handler net_response ~writer:writer') + in + match response with + | `Success ({ response = res; _ } as result) -> + Ok + { + result with + response = + res.Grpc_eio_core.Body_reader.consume + %s.%s; + } + | #Grpc_client_eio.Rpc_error.Client_streaming.error' as rest -> Error rest|} + (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) + (service_name_of_package service.service_packages) + service.service_name rpc.rpc_name typ_mod_name + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) + typ_mod_name + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) + | `Bidirectional_streaming -> + F.linep sc + {|let %s ~sw ~io handler = + let stream = + Grpc_client_eio.Client.Bidirectional_streaming.call ~sw ~io ~service:"%s.%s" + ~method_name:"%s" + ~headers:(Grpc_client.make_request_headers `Proto) + (fun net_response ~writer ~read -> + let writer' req = writer.write (%s.%s req) in + let read' = + Seq.map + (fun response -> + response.Grpc_eio_core.Body_reader.consume + %s.%s) + read + in + handler net_response ~writer:writer' ~read:read') + in + match stream with + | `Stream_result_success result -> Ok result + | #Grpc_client_eio.Rpc_error.Bidirectional_streaming.error' as rest -> Error rest|} + (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) + (service_name_of_package service.service_packages) + service.service_name rpc.rpc_name typ_mod_name + (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) + typ_mod_name + (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) + in + let gen_expert_rpc sc i (rpc : Ot.rpc) = if i > 0 then F.empty_line sc; let rpc_name = rpc.rpc_name in match rpc_kind rpc.rpc_req rpc.rpc_res with @@ -80,8 +330,7 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit response = res.Grpc_eio_core.Body_reader.consume %s.%s; } - | ( `Premature_close _ | `Write_error _ | `Connection_error _ - | `Response_not_ok _ ) as rest -> + | #Grpc_client_eio.Rpc_error.Unary.error' as rest -> rest|} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) (service_name_of_package service.service_packages) @@ -130,8 +379,7 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit res.Grpc_eio_core.Body_reader.consume %s.%s; } - | ( `Premature_close _ | `Stream_error _ | `Connection_error _ - | `Response_not_ok _ ) as rest -> + | #Grpc_client_eio.Rpc_error.Client_streaming.error' as rest -> rest|} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) (service_name_of_package service.service_packages) @@ -162,7 +410,17 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit typ_mod_name (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) in - List.iteri (gen_rpc sc) service.service_body + List.iteri (gen_exn_rpc sc) service.service_body; + F.empty_line sc; + F.linep sc "module Result = struct"; + F.empty_line sc; + List.iteri (gen_result_rpc sc) service.service_body; + F.linep sc "end"; + F.empty_line sc; + F.linep sc "module Expert = struct"; + F.empty_line sc; + List.iteri (gen_expert_rpc sc) service.service_body; + F.linep sc "end" let gen_service_server_struct ~proto_gen_module (service : Ot.service) top_scope : unit = @@ -172,7 +430,9 @@ let gen_service_server_struct ~proto_gen_module (service : Ot.service) top_scope let name = Pb_codegen_util.function_name_of_rpc rpc in F.linep sc "val %s :" (to_snake_case name); - F.linep sc " Eio.Net.Sockaddr.stream * H2.Reqd.t * H2.Request.t ->"; + F.linep sc + " Eio.Net.Sockaddr.stream * H2.Reqd.t * H2.Request.t * H2.Reqd.error \ + Eio.Promise.t ->"; let req_type = Printf.sprintf "%s.%s" typ_mod_name (ocaml_type_of_rpc_type rpc.rpc_req) in diff --git a/lib/eio/arpaca/integration_tests/client.ml b/lib/eio/arpaca/integration_tests/client.ml index 75a0362..a5b2018 100644 --- a/lib/eio/arpaca/integration_tests/client.ml +++ b/lib/eio/arpaca/integration_tests/client.ml @@ -13,15 +13,13 @@ let print_features sw io = in let stream = - RouteGuide_client.list_features ~sw ~io rectangle (fun _ read -> + RouteGuide_client.Expert.list_features ~sw ~io rectangle (fun _ read -> Seq.iter (fun f -> Printf.printf "RESPONSE = {%s}%!" (Route_guide.show_feature f)) read) in - match stream with - | `Stream_result { err = None; _ } -> () - | _ -> failwith "an erra" + match stream with `Stream_result_success _ -> () | _ -> failwith "an erra" let random_point () = let latitude = (Random.int 180 - 90) * 10000000 in @@ -35,7 +33,7 @@ let run_record_route sw io = in let response = - RouteGuide_client.record_route ~io ~sw (fun _ ~writer -> + RouteGuide_client.Expert.record_route ~io ~sw (fun _ ~writer -> Seq.iter (fun point -> writer point |> ignore; @@ -82,13 +80,11 @@ let run_route_chat clock io sw = go ~send reader' xs) in let result = - RouteGuide_client.route_chat ~io ~sw (fun _ ~writer ~read -> + RouteGuide_client.Expert.route_chat ~io ~sw (fun _ ~writer ~read -> go ~send:writer read route_notes; []) in - match result with - | `Stream_result { err = None; _ } -> () - | _e -> failwith "Error" + match result with `Stream_result_success _ -> () | _e -> failwith "Error" let main env = let clock = Eio.Stdenv.clock env in @@ -104,7 +100,7 @@ let main env = Printf.printf "*** SIMPLE RPC ***\n%!"; let result = - RouteGuide_client.get_feature ~sw ~io + RouteGuide_client.Expert.get_feature ~sw ~io (Route_guide.default_point ~latitude:409146138 ~longitude:(-746188906) ()) in diff --git a/lib/eio/client/client.ml b/lib/eio/client/client.ml index 2030aed..644b13d 100644 --- a/lib/eio/client/client.ml +++ b/lib/eio/client/client.ml @@ -10,12 +10,6 @@ type 'request writer = { close : unit -> unit; } -type ('net_response, 'headers) resp_not_ok = { - net_response : 'net_response; - grpc_status : Grpc.Status.t; - trailers : 'headers; -} - type ('net_response, 'headers, 'request, @@ -31,12 +25,9 @@ type ('net_response, Eio.Promise.t; grpc_status : Grpc.Status.t Eio.Promise.t; write_exn : exn option ref; + conn_err : 'conn_error Eio.Promise.t; } -type ('net_response, 'headers, 'conn_err) common_error = - [ `Connection_error of 'conn_err - | `Response_not_ok of ('net_response, 'headers) resp_not_ok ] - let call (type headers net_response request response stream_error conn_error) ~sw ~(io : @@ -53,60 +44,55 @@ let call (type headers net_response request response stream_error conn_error) result = let (module Io') = io in let path = Grpc_client.make_path ~service ~method_name in - match Io'.send_request ~headers path with - | Error conn_error -> Error conn_error - | Ok (writer', recv_net) -> - let write_exn = ref None in - let writer = - { - write = - (fun req -> - try - writer'.write req; - true - with exn -> - write_exn := Some exn; - false); - close = writer'.close; - } - in - let status, status_notify = Eio.Promise.create () in - let recv, recv_notify = Eio.Promise.create () in - let () = - Eio.Fiber.fork ~sw (fun () -> - Eio.Promise.resolve recv_notify - (match Eio.Promise.await recv_net with - | Error conn_error -> + let writer', recv_net, conn_err_p = Io'.send_request ~headers path in + let write_exn = ref None in + let writer = + { + write = + (fun req -> + try + writer'.write req; + true + with exn -> + write_exn := Some exn; + false); + close = writer'.close; + } + in + let status, status_notify = Eio.Promise.create () in + let recv, recv_notify = Eio.Promise.create () in + let () = + Eio.Fiber.fork_daemon ~sw (fun () -> + Eio.Promise.resolve recv_notify + (match Eio.Promise.await recv_net with + | Error conn_error -> + Eio.Promise.resolve status_notify + (Grpc.Status.make ~error_message:"Connection error" + Grpc.Status.Unknown); + Error conn_error + | Ok { response; next; trailers } -> + Eio.Fiber.fork_daemon ~sw (fun () -> Eio.Promise.resolve status_notify - (Grpc.Status.make ~error_message:"Connection error" - Grpc.Status.Unknown); - Error conn_error - | Ok { response; next; trailers } -> - Eio.Fiber.fork ~sw (fun () -> - Eio.Promise.resolve status_notify - (Grpc_client.status_of_trailers - ~get_header: - (Io'.Headers.get (Eio.Promise.await trailers)))); - Ok { net_response = response; recv_seq = next; trailers })) - in - Ok { writer; recv; grpc_status = status; write_exn } + (Grpc_client.status_of_trailers + ~get_header: + (Io'.Headers.get (Eio.Promise.await trailers))); + `Stop_daemon); + Ok { net_response = response; recv_seq = next; trailers }); + `Stop_daemon) + in + Ok { writer; recv; grpc_status = status; write_exn; conn_err = conn_err_p } -type ('stream_err, 'headers) streaming_err = { - stream_error : 'stream_err option; - write_exn : exn option; - grpc_status : Grpc.Status.t; -} - -type ('a, 'headers, 'stream_err) streaming_result = { +type ('a, 'headers) streaming_result_success = { result : 'a; trailers : 'headers; - err : ('stream_err, 'headers) streaming_err option; } module Bidirectional_streaming = struct type ('a, 'headers, 'stream_err, 'conn_err, 'net_response) result' = - [ `Stream_result of ('a, 'headers, 'stream_err) streaming_result - | ('net_response, 'headers, 'conn_err) common_error ] + [ `Stream_result_success of ('a, 'headers) streaming_result_success + | `Stream_result_error of + ('a, 'headers, 'stream_err) Rpc_error.streaming_result_err + | ('net_response, 'headers, 'conn_err) Rpc_error.common_error ] let call (type headers net_response request response stream_error conn_error) ~sw @@ -120,14 +106,13 @@ module Bidirectional_streaming = struct Io.t) ~service ~method_name ~headers f : (_, headers, stream_error, conn_error, net_response) result' = match call ~sw ~io ~service ~method_name ~headers () with - | Ok { writer; recv; grpc_status; write_exn } -> ( + | Ok { writer; recv; grpc_status; write_exn; conn_err = conn_err_p } -> ( let closed = ref false in let writer = { write = (fun req -> let result = writer.write req in - result); close = (fun () -> @@ -137,97 +122,98 @@ module Bidirectional_streaming = struct in let error = ref None in let res = - f + Eio.Fiber.first (fun () -> - match Eio.Promise.await recv with - | Ok { net_response; _ } -> Ok net_response - | Error e -> Error e) - ~writer - ~read:(fun () -> - match Eio.Promise.await recv with - | Ok { net_response; recv_seq; _ } -> - let (module Io') = io in - if Io'.Net_response.is_ok net_response then - let rec read recv_seq' () = - match recv_seq' () with - | Grpc_eio_core.Recv_seq.Done -> Seq.Nil - | Err e -> - let () = error := Some e in - Seq.Nil - | Next (t, next) -> Seq.Cons (t, fun () -> read next ()) - in - read recv_seq () - else Seq.Nil - | Error _ -> Seq.Nil) + Ok + (f + (fun () -> + match Eio.Promise.await recv with + | Ok { net_response; _ } -> Ok net_response + | Error e -> Error e) + ~writer + ~read:(fun () -> + match Eio.Promise.await recv with + | Ok { net_response; recv_seq; _ } -> + let (module Io') = io in + if Io'.Net_response.is_ok net_response then + let rec read recv_seq' () = + match recv_seq' () with + | Grpc_eio_core.Recv_seq.Done -> Seq.Nil + | Err e -> + let () = error := Some e in + Seq.Nil + | Next (t, next) -> + Seq.Cons (t, fun () -> read next ()) + in + read recv_seq () + else Seq.Nil + | Error _ -> Seq.Nil))) + (fun () -> + let erra = Eio.Promise.await conn_err_p in + Error erra) in - match Eio.Promise.await recv with - | Error _e -> Obj.magic () - | Ok { net_response = _; trailers; _ } -> ( - if not !closed then writer.close (); - match !error with - | Some error -> - `Stream_result - { - result = res; - trailers = Eio.Promise.await trailers; - err = - Some - { - stream_error = Some error; - grpc_status = Eio.Promise.await grpc_status; - write_exn = !write_exn; - }; - } - | None -> ( - let status = Eio.Promise.await grpc_status in - match Grpc.Status.code status with - | Grpc.Status.OK -> ( - match !write_exn with - | None -> - `Stream_result + match res with + | Error erra -> `Connection_error erra + | Ok res -> ( + (* TODO: change await to peek to avoid deadlocking in case we never got response *) + match Eio.Promise.peek recv with + | None -> (* TODO: return something like `Cancelled *) Obj.magic () + | Some (Error e) -> `Connection_error e + | Some (Ok { net_response = _; trailers; _ }) -> ( + if not !closed then writer.close (); + match !error with + | Some error -> + `Stream_result_error + { + result = res; + trailers = Eio.Promise.await trailers; + err = { - result = res; - err = None; - trailers = Eio.Promise.await trailers; - } - | Some _ -> - `Stream_result + stream_error = Some error; + grpc_status = Eio.Promise.await grpc_status; + write_exn = !write_exn; + }; + } + | None -> ( + let status = Eio.Promise.await grpc_status in + match Grpc.Status.code status with + | Grpc.Status.OK -> ( + match !write_exn with + | None -> + `Stream_result_success + { + result = res; + trailers = Eio.Promise.await trailers; + } + | Some _ -> + `Stream_result_error + { + result = res; + trailers = Eio.Promise.await trailers; + err = + { + write_exn = !write_exn; + grpc_status = Eio.Promise.await grpc_status; + stream_error = None; + }; + }) + | _ -> + `Stream_result_error { result = res; trailers = Eio.Promise.await trailers; err = - Some - { - write_exn = !write_exn; - grpc_status = Eio.Promise.await grpc_status; - stream_error = None; - }; - }) - | _ -> - `Stream_result - { - result = res; - trailers = Eio.Promise.await trailers; - err = - Some - { - grpc_status = status; - stream_error = None; - write_exn = !write_exn; - }; - }))) + { + grpc_status = status; + stream_error = None; + write_exn = !write_exn; + }; + })))) | Error e -> `Connection_error e end module Unary = struct - type ('net_response, 'headers, 'stream_err) premature_close = { - trailers : 'headers; - grpc_status : Grpc.Status.t; - net_response : 'net_response; - stream_error : 'stream_err option; - } - type ('net_response, 'response, 'headers) success = { net_response : 'net_response; response : 'response; @@ -236,8 +222,9 @@ module Unary = struct type ('response, 'headers, 'stream_err, 'conn_err, 'net_response) result' = [ `Success of ('net_response, 'response, 'headers) success - | `Premature_close of ('net_response, 'headers, 'stream_err) premature_close - | `Response_not_ok of ('net_response, 'headers) resp_not_ok + | `Premature_close of + ('net_response, 'headers, 'stream_err) Rpc_error.Unary.premature_close + | `Response_not_ok of ('net_response, 'headers) Rpc_error.resp_not_ok | `Connection_error of 'conn_err | `Write_error of exn ] @@ -253,7 +240,7 @@ module Unary = struct Io.t) ~service ~method_name ~headers request : (_, headers, stream_error, conn_error, net_response) result' = match call ~sw ~io ~service ~method_name ~headers () with - | Ok { writer; recv; grpc_status; write_exn } -> ( + | Ok { writer; recv; grpc_status; write_exn; _ } -> ( try if not (writer.write request) then `Write_error (Option.get !write_exn) @@ -312,14 +299,6 @@ module Unary = struct end module Client_streaming = struct - type ('a, 'headers, 'stream_err) stream_err = { - trailers : 'headers; - grpc_status : Grpc.Status.t; - result : 'a; - stream_error : 'stream_err; - write_exn : exn option; - } - type ('a, 'response, 'headers) success = { result : 'a; response : 'response; @@ -327,18 +306,13 @@ module Client_streaming = struct write_exn : exn option; } - type ('a, 'headers) premature_close = { - result : 'a; - trailers : 'headers; - grpc_status : Grpc.Status.t; - write_exn : exn option; - } - type ('a, 'headers, 'stream_err, 'conn_err, 'net_response, 'response) result' = [ `Success of ('a, 'response, 'headers) success - | `Premature_close of ('a, 'headers) premature_close - | `Stream_error of ('a, 'headers, 'stream_err) stream_err - | ('net_response, 'headers, 'conn_err) common_error ] + | `Premature_close of + ('a, 'headers) Rpc_error.Client_streaming.premature_close + | `Stream_error of + ('a, 'headers, 'stream_err) Rpc_error.Client_streaming.stream_err + | ('net_response, 'headers, 'conn_err) Rpc_error.common_error ] let call (type headers net_response request response stream_error conn_error) ~sw @@ -352,7 +326,7 @@ module Client_streaming = struct Io.t) ~service ~method_name ~headers f : (_, headers, stream_error, conn_error, net_response, response) result' = match call ~sw ~io ~service ~method_name ~headers () with - | Ok { writer; recv; grpc_status; write_exn } -> ( + | Ok { writer; recv; grpc_status; write_exn; conn_err = conn_err_p } -> ( match Eio.Promise.await recv with | Error e -> `Connection_error e | Ok { net_response; recv_seq; trailers } -> @@ -369,45 +343,54 @@ module Client_streaming = struct } in - let res = f net_response ~writer in + let res = + Eio.Fiber.first + (fun () -> Ok (f net_response ~writer)) + (fun () -> + let erra = Eio.Promise.await conn_err_p in + Error erra) + in if not !closed then writer.close (); - match recv_seq () with - | Grpc_eio_core.Recv_seq.Done -> - `Premature_close - { - result = res; - trailers = Eio.Promise.await trailers; - grpc_status = Eio.Promise.await grpc_status; - write_exn = !write_exn; - } - | Err e -> - `Stream_error - { - result = res; - stream_error = e; - trailers = Eio.Promise.await trailers; - grpc_status = Eio.Promise.await grpc_status; - write_exn = !write_exn; - } - | Next (t, _) -> ( - let status = Eio.Promise.await grpc_status in - match Grpc.Status.code status with - | OK -> - `Success + match res with + | Error erra -> `Connection_error erra + | Ok res -> ( + match recv_seq () with + | Grpc_eio_core.Recv_seq.Done -> + `Premature_close { result = res; - response = t; trailers = Eio.Promise.await trailers; + grpc_status = Eio.Promise.await grpc_status; write_exn = !write_exn; } - | _ -> - `Response_not_ok + | Err e -> + `Stream_error { - net_response; - grpc_status = status; + result = res; + stream_error = e; trailers = Eio.Promise.await trailers; - })) + grpc_status = Eio.Promise.await grpc_status; + write_exn = !write_exn; + } + | Next (t, _) -> ( + let status = Eio.Promise.await grpc_status in + match Grpc.Status.code status with + | OK -> + `Success + { + result = res; + response = t; + trailers = Eio.Promise.await trailers; + write_exn = !write_exn; + } + | _ -> + `Response_not_ok + { + net_response; + grpc_status = status; + trailers = Eio.Promise.await trailers; + }))) else `Response_not_ok { @@ -419,6 +402,14 @@ module Client_streaming = struct end module Server_streaming = struct + type ('a, 'headers, 'stream_error, 'net_response, 'conn_err) result' = + [ `Stream_result_success of ('a, 'headers) streaming_result_success + | `Stream_result_error of + ('a, 'headers, 'stream_error) Rpc_error.streaming_result_err + | `Write_error of + ('stream_error, 'headers) Rpc_error.streaming_err option * 'headers + | ('net_response, 'headers, 'conn_err) Rpc_error.common_error ] + let call ~sw ~io ~service ~method_name ~headers request f = let result = Bidirectional_streaming.call ~sw ~io ~service ~method_name ~headers @@ -430,9 +421,14 @@ module Server_streaming = struct in let module Bs = Bidirectional_streaming in match result with - | (`Connection_error _ | `Response_not_ok _) as e -> e - | `Stream_result { result; err; trailers } -> ( + | #Rpc_error.common_error as common_err -> common_err + | `Stream_result_success { result; trailers } -> ( + match result with + | `Write_error -> `Write_error (None, trailers) + | `Stream res -> `Stream_result_success { result = res; trailers }) + | `Stream_result_error { result; err; trailers } -> ( match result with - | `Write_error -> `Write_error (err, trailers) - | `Stream res -> `Stream_result { result = res; err; trailers }) + | `Write_error -> `Write_error (Some err, trailers) + | `Stream res -> + `Stream_result_error { Rpc_error.result = res; err; trailers }) end diff --git a/lib/eio/client/client.mli b/lib/eio/client/client.mli index e404749..8bdbf42 100644 --- a/lib/eio/client/client.mli +++ b/lib/eio/client/client.mli @@ -21,18 +21,9 @@ type ('net_response, Eio.Promise.t; grpc_status : Grpc.Status.t Eio.Promise.t; write_exn : exn option ref; + conn_err : 'conn_error Eio.Promise.t; } -type ('net_response, 'headers) resp_not_ok = { - net_response : 'net_response; - grpc_status : Grpc.Status.t; - trailers : 'headers; -} - -type ('net_response, 'headers, 'conn_err) common_error = - [ `Connection_error of 'conn_err - | `Response_not_ok of ('net_response, 'headers) resp_not_ok ] - val call : sw:Eio.Switch.t -> io: @@ -57,26 +48,12 @@ val call : 'conn_error ) result -type ('stream_err, 'headers) streaming_err = { - stream_error : 'stream_err option; - write_exn : exn option; - grpc_status : Grpc.Status.t; -} - -type ('a, 'headers, 'stream_err) streaming_result = { +type ('a, 'headers) streaming_result_success = { result : 'a; trailers : 'headers; - err : ('stream_err, 'headers) streaming_err option; } module Unary : sig - type ('net_response, 'headers, 'stream_err) premature_close = { - trailers : 'headers; - grpc_status : Grpc.Status.t; - net_response : 'net_response; - stream_error : 'stream_err option; - } - type ('net_response, 'response, 'headers) success = { net_response : 'net_response; response : 'response; @@ -84,10 +61,13 @@ module Unary : sig } type ('response, 'headers, 'stream_err, 'conn_err, 'net_response) result' = - [ `Premature_close of ('net_response, 'headers, 'stream_err) premature_close - | `Success of ('net_response, 'response, 'headers) success - | `Write_error of exn - | ('net_response, 'headers, 'conn_err) common_error ] + [ `Success of ('net_response, 'response, 'headers) success + | ( 'response, + 'headers, + 'stream_err, + 'conn_err, + 'net_response ) + Rpc_error.Unary.error' ] val call : sw:Eio.Switch.t -> @@ -107,14 +87,6 @@ module Unary : sig end module Client_streaming : sig - type ('a, 'headers, 'stream_err) stream_err = { - trailers : 'headers; - grpc_status : Grpc.Status.t; - result : 'a; - stream_error : 'stream_err; - write_exn : exn option; - } - type ('a, 'response, 'headers) success = { result : 'a; response : 'response; @@ -122,18 +94,15 @@ module Client_streaming : sig write_exn : exn option; } - type ('a, 'headers) premature_close = { - result : 'a; - trailers : 'headers; - grpc_status : Grpc.Status.t; - write_exn : exn option; - } - type ('a, 'headers, 'stream_err, 'conn_err, 'net_response, 'response) result' = - [ `Premature_close of ('a, 'headers) premature_close - | `Stream_error of ('a, 'headers, 'stream_err) stream_err - | `Success of ('a, 'response, 'headers) success - | ('net_response, 'headers, 'conn_err) common_error ] + [ `Success of ('a, 'response, 'headers) success + | ( 'a, + 'headers, + 'stream_err, + 'conn_err, + 'net_response, + 'response ) + Rpc_error.Client_streaming.error' ] val call : sw:Eio.Switch.t -> @@ -153,6 +122,15 @@ module Client_streaming : sig end module Server_streaming : sig + type ('a, 'headers, 'stream_error, 'net_response, 'conn_err) result' = + [ `Stream_result_success of ('a, 'headers) streaming_result_success + | ( 'a, + 'headers, + 'stream_error, + 'net_response, + 'conn_err ) + Rpc_error.Server_streaming.error' ] + val call : sw:Eio.Switch.t -> io: @@ -170,15 +148,18 @@ module Server_streaming : sig ((unit -> ('net_response, 'conn_err) result) -> read:(unit -> 'response Seq.node) -> 'a) -> - [ `Stream_result of ('a, 'headers, 'stream_error) streaming_result - | `Write_error of ('stream_error, 'headers) streaming_err option * 'headers - | ('net_response, 'headers, 'conn_err) common_error ] + ('a, 'headers, 'stream_error, 'net_response, 'conn_err) result' end module Bidirectional_streaming : sig type ('a, 'headers, 'stream_err, 'conn_err, 'net_response) result' = - [ `Stream_result of ('a, 'headers, 'stream_err) streaming_result - | ('net_response, 'headers, 'conn_err) common_error ] + [ `Stream_result_success of ('a, 'headers) streaming_result_success + | ( 'a, + 'headers, + 'stream_err, + 'conn_err, + 'net_response ) + Rpc_error.Bidirectional_streaming.error' ] (* val call : sw:Eio.Switch.t -> diff --git a/lib/eio/client/io.ml b/lib/eio/client/io.ml index 5a0272c..248d0d2 100644 --- a/lib/eio/client/io.ml +++ b/lib/eio/client/io.ml @@ -37,19 +37,56 @@ module type S = sig type connection_error type stream_error - val send_request : - headers:Grpc_client.request_headers -> - string -> - ( request writer - * ( Net_response.t, - response, + type client_error = + | Unary of + ( response, + Headers.t, + stream_error, + connection_error, + Net_response.t ) + Rpc_error.Unary.error' + | Client_streaming : + ( 'a, + Headers.t, + stream_error, + connection_error, + Net_response.t, + response ) + Rpc_error.Client_streaming.error' + -> client_error + | Server_streaming : + ( 'a, Headers.t, stream_error, + Net_response.t, connection_error ) - reader_or_error - Eio.Promise.t, - connection_error ) - result + Rpc_error.Server_streaming.error' + -> client_error + | Bidirectional_streaming : + ( 'a, + Headers.t, + stream_error, + connection_error, + Net_response.t ) + Rpc_error.Bidirectional_streaming.error' + -> client_error + + exception Grpc_client_error of client_error + + val raise_client_error : client_error -> 'exn + + val send_request : + headers:Grpc_client.request_headers -> + string -> + request writer + * ( Net_response.t, + response, + Headers.t, + stream_error, + connection_error ) + reader_or_error + Eio.Promise.t + * connection_error Eio.Promise.t end type ('headers, diff --git a/lib/eio/client/rpc_error.ml b/lib/eio/client/rpc_error.ml new file mode 100644 index 0000000..e0c46a9 --- /dev/null +++ b/lib/eio/client/rpc_error.ml @@ -0,0 +1,70 @@ +type ('net_response, 'headers) resp_not_ok = { + net_response : 'net_response; + grpc_status : Grpc.Status.t; + trailers : 'headers; +} + +type ('net_response, 'headers, 'conn_err) common_error = + [ `Connection_error of 'conn_err + | `Response_not_ok of ('net_response, 'headers) resp_not_ok ] + +module Unary = struct + type ('net_response, 'headers, 'stream_err) premature_close = { + trailers : 'headers; + grpc_status : Grpc.Status.t; + net_response : 'net_response; + stream_error : 'stream_err option; + } + + type ('response, 'headers, 'stream_err, 'conn_err, 'net_response) error' = + [ `Premature_close of ('net_response, 'headers, 'stream_err) premature_close + | `Write_error of exn + | ('net_response, 'headers, 'conn_err) common_error ] +end + +module Client_streaming = struct + type ('a, 'headers) premature_close = { + result : 'a; + trailers : 'headers; + grpc_status : Grpc.Status.t; + write_exn : exn option; + } + + type ('a, 'headers, 'stream_err) stream_err = { + trailers : 'headers; + grpc_status : Grpc.Status.t; + result : 'a; + stream_error : 'stream_err; + write_exn : exn option; + } + + type ('a, 'headers, 'stream_err, 'conn_err, 'net_response, 'response) error' = + [ `Premature_close of ('a, 'headers) premature_close + | `Stream_error of ('a, 'headers, 'stream_err) stream_err + | ('net_response, 'headers, 'conn_err) common_error ] +end + +type ('stream_err, 'headers) streaming_err = { + stream_error : 'stream_err option; + write_exn : exn option; + grpc_status : Grpc.Status.t; +} + +type ('a, 'headers, 'stream_err) streaming_result_err = { + result : 'a; + trailers : 'headers; + err : ('stream_err, 'headers) streaming_err; +} + +module Server_streaming = struct + type ('a, 'headers, 'stream_error, 'net_response, 'conn_err) error' = + [ `Stream_result_error of ('a, 'headers, 'stream_error) streaming_result_err + | `Write_error of ('stream_error, 'headers) streaming_err option * 'headers + | ('net_response, 'headers, 'conn_err) common_error ] +end + +module Bidirectional_streaming = struct + type ('a, 'headers, 'stream_err, 'conn_err, 'net_response) error' = + [ `Stream_result_error of ('a, 'headers, 'stream_err) streaming_result_err + | ('net_response, 'headers, 'conn_err) common_error ] +end diff --git a/lib/eio/core/body_reader.ml b/lib/eio/core/body_reader.ml index 10f4278..507592e 100644 --- a/lib/eio/core/body_reader.ml +++ b/lib/eio/core/body_reader.ml @@ -96,7 +96,7 @@ and unwrap_message ~msg_len ~data ~off ~len ~into:promise ~read_next ~read_more Bigstringaf.blit_to_bytes data ~src_off:off bytes ~dst_off:0 ~len; read_more (`Body (bytes, msg_len, msg_len - len)) ~into:promise -let rec read_more schedule_read buffer ~into:promise = +let rec read_more ~error schedule_read buffer ~into:promise = schedule_read ~on_eof:(fun () -> Eio.Promise.resolve promise (Err `Unexpected_eof)) ~on_read:(fun bigstring ~off ~len -> @@ -105,7 +105,7 @@ let rec read_more schedule_read buffer ~into:promise = if len < remaining then ( Bigstringaf.blit bigstring ~src_off:off buffer ~dst_off:(5 - remaining) ~len; - read_more schedule_read + read_more ~error schedule_read (`Header (buffer, remaining - len)) ~into:promise) else ( @@ -115,8 +115,8 @@ let rec read_more schedule_read buffer ~into:promise = let msg_len = extract_msg_len ~data:buffer ~off:(off + 1) in unwrap_message ~msg_len ~data:buffer ~off:remaining ~len:(len - remaining) ~into:promise - ~read_next:(fun () -> read_next schedule_read) - ~read_more:(read_more schedule_read)) + ~read_next:(fun () -> read_next ~error schedule_read) + ~read_more:(read_more ~error schedule_read)) | `Body (buffer, msg_len, remaining) -> if len >= remaining then ( Bigstringaf.blit_to_bytes bigstring ~src_off:off buffer @@ -125,8 +125,8 @@ let rec read_more schedule_read buffer ~into:promise = let next, next_u = Eio.Promise.create () in unwrap_message_with_header ~data:bigstring ~off:(off + remaining) ~len:(len - remaining) ~into:next_u - ~read_next:(fun () -> read_next schedule_read) - ~read_more:(read_more schedule_read); + ~read_next:(fun () -> read_next ~error schedule_read) + ~read_more:(read_more ~error schedule_read); Eio.Promise.resolve promise (Next ( to_consumer { bytes = buffer; len = msg_len }, @@ -135,22 +135,22 @@ let rec read_more schedule_read buffer ~into:promise = Eio.Promise.resolve promise (Next ( to_consumer { bytes = buffer; len = msg_len }, - fun () -> read_next schedule_read ))) + fun () -> read_next ~error schedule_read ))) else ( Bigstringaf.blit_to_bytes bigstring ~src_off:off buffer ~dst_off:(msg_len - remaining) ~len; - read_more schedule_read + read_more ~error schedule_read (`Body (buffer, msg_len, remaining - len)) ~into:promise)) -and read_next schedule_read = +and read_next ~error schedule_read = let promise, promise_u = Eio.Promise.create () in schedule_read ~on_eof:(fun () -> Eio.Promise.resolve promise_u Done) ~on_read:(fun bigstring ~off ~len -> unwrap_message_with_header ~data:bigstring ~off ~len ~into:promise_u - ~read_next:(fun () -> read_next schedule_read) - ~read_more:(read_more schedule_read)); + ~read_next:(fun () -> read_next ~error schedule_read) + ~read_more:(read_more ~error schedule_read)); Eio.Promise.await promise let fill_header ~pos ~length buffer = @@ -245,7 +245,8 @@ let%test_module "reading body" = 2); ] in - let result = read_next schedule_read in + let error, _ = Eio.Promise.create () in + let result = read_next ~error schedule_read in match result with Err `Unexpected_eof -> true | _ -> false let%expect_test "reading body in multiple chunks" = @@ -271,15 +272,16 @@ let%test_module "reading body" = 5); ] in - let result = read_next schedule_read in + let error, _ = Eio.Promise.create () in + let result = read_next ~error schedule_read in (match result with | Done -> print_endline "failure" - | Err `Unexpected_eof -> print_endline "failure" + | Err _ -> print_endline "failure" | Next ({ consume }, cons) -> ( print_endline (consume (fun { bytes; len } -> Bytes.sub_string bytes 0 len)); match cons () with | Done -> () - | Err `Unexpected_eof | Next _ -> failwith "expected end of sequence")); + | Err _ | Next _ -> failwith "expected end of sequence")); [%expect "5555555555"] end) diff --git a/lib/eio/core/recv_seq.ml b/lib/eio/core/recv_seq.ml index 299db65..dc73552 100644 --- a/lib/eio/core/recv_seq.ml +++ b/lib/eio/core/recv_seq.ml @@ -1,12 +1,24 @@ type ('a, 'err) t = unit -> ('a, 'err) recv_item and ('a, 'err) recv_item = Done | Next of 'a * ('a, 'err) t | Err of 'err -let rec map f recv () = +let rec map f recv = + fun () -> match recv () with | Done -> Done | Next (x, recv) -> Next (f x, map f recv) | Err err -> Err err +(* let rec map ~error f recv () = *) +(* match recv () with *) +(* | Seq.Nil -> Done *) +(* | Seq.Cons (x, recv) -> *) +(* Next *) +(* ( f x, *) +(* fun () -> *) +(* Eio.Fiber.first *) +(* (fun () -> Err (Eio.Promise.await error)) *) +(* (fun () -> map ~error f recv ()) ) *) + let to_seq ?err_to_exn recv = let rec loop recv () = match recv () with diff --git a/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml index 4cd466f..28ac620 100644 --- a/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml +++ b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml @@ -21,7 +21,9 @@ module Net_response = struct end type connection_error = H2.Client_connection.error -type stream_error = [ connection_error | `Unexpected_eof ] + +type stream_error = + [ `Unexpected_eof | `Connection_error of H2.Client_connection.error ] type t = ( H2.Headers.t, @@ -54,6 +56,45 @@ module Make_net (Client : Client) : type request = Pbrt.Encoder.t -> unit type response = Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer + type client_error = + | Unary of + ( response, + Headers.t, + stream_error, + connection_error, + Net_response.t ) + Grpc_client_eio.Rpc_error.Unary.error' + | Client_streaming : + ( 'a, + Headers.t, + stream_error, + connection_error, + Net_response.t, + response ) + Grpc_client_eio.Rpc_error.Client_streaming.error' + -> client_error + | Server_streaming : + ( 'a, + Headers.t, + stream_error, + Net_response.t, + connection_error ) + Grpc_client_eio.Rpc_error.Server_streaming.error' + -> client_error + | Bidirectional_streaming : + ( 'a, + Headers.t, + stream_error, + connection_error, + Net_response.t ) + Grpc_client_eio.Rpc_error.Bidirectional_streaming.error' + -> client_error + + exception Grpc_client_error of client_error + + let raise_client_error (error : client_error) = + raise (Grpc_client_error error) + let send_request ~(headers : Grpc_client.request_headers) target = (* We are flushing headers immediately but potentially for the unary and server streaming cases we shouldn't do it @@ -78,14 +119,6 @@ module Make_net (Client : Client) : in (* Allocate once, use a pool of these *) let errored = ref false in - (* - let report_net_error resolver trailers_resolver err = - errored := true; - Eio.Promise.resolve resolver - (Grpc_client_eio.Io.Err (err :> stream_error)); - Eio.Promise.resolve trailers_resolver H2.Headers.empty - in - *) let response_handler response reader = let trailers, trailers_u = Eio.Promise.create () in let () = @@ -94,18 +127,8 @@ module Make_net (Client : Client) : in let next = - (* FIXME: connection error handling - - Eio.Switch.run (fun sw -> - Eio.Fiber.fork_daemon ~sw (fun () -> - report_net_error next_item_u trailers_u - (Eio.Promise.await Client.connection_error); - `Stop_daemon); - Eio.Promise.await next_item |> ignore)); - *) - let _ = Client.connection_error in (fun () -> - Grpc_eio_core.Body_reader.read_next + Grpc_eio_core.Body_reader.read_next ~error:Client.connection_error (H2.Body.Reader.schedule_read reader)) |> Grpc_eio_core.Recv_seq.map (fun { Grpc_eio_core.Body_reader.consume } -> @@ -127,25 +150,29 @@ module Make_net (Client : Client) : ~response_handler request in let encoder = Pbrt.Encoder.create ~size:65536 () in - Ok - ( { - Grpc_client_eio.Io.write = - (let header_buffer = Bytes.create 5 in - fun input -> - if !errored = true then raise Write_after_error - else ( - Pbrt.Encoder.clear encoder; - input encoder; - let msg = Pbrt.Encoder.to_bytes encoder in - Grpc.Message.fill_header ~length:(Bytes.length msg) - header_buffer; - H2.Body.Writer.write_string body_writer - (Bytes.unsafe_to_string header_buffer); - H2.Body.Writer.write_string body_writer - (Bytes.unsafe_to_string msg))); - close = (fun () -> H2.Body.Writer.close body_writer); - }, - result ) + ( { + Grpc_client_eio.Io.write = + (let header_buffer = Bytes.create 5 in + + fun input -> + if !errored = true then raise Write_after_error + else ( + Pbrt.Encoder.clear encoder; + input encoder; + + let msg = Pbrt.Encoder.to_bytes encoder in + + Grpc.Message.fill_header ~length:(Bytes.length msg) header_buffer; + + H2.Body.Writer.write_string body_writer + (Bytes.unsafe_to_string header_buffer); + + H2.Body.Writer.write_string body_writer + (Bytes.unsafe_to_string msg))); + close = (fun () -> H2.Body.Writer.close body_writer); + }, + result, + Client.connection_error ) end module Expert = struct @@ -157,9 +184,11 @@ module Expert = struct Eio.Switch.run (fun sw' -> let conn = H2_eio.Client.create_connection ~sw:sw' - ~error_handler:(Eio.Promise.resolve connection_error_resolve) + ~error_handler:(fun e -> + Eio.Promise.resolve connection_error_resolve e) socket in + Eio.Switch.on_release sw' (fun () -> Eio.Promise.await (H2_eio.Client.shutdown conn)); (* For now we're ignoring the errors, we should probably inject them into grpc handlers to let them handle it *) @@ -185,7 +214,9 @@ module Expert = struct |> List.hd in let addr = `Tcp (Eio_unix.Net.Ipaddr.of_unix inet, port) in + let socket = Eio.Net.connect ~sw net addr in + create_with_socket ~socket ~host ~scheme ~sw end diff --git a/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.mli b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.mli index 9804e41..800401d 100644 --- a/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.mli +++ b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.mli @@ -1,4 +1,5 @@ -type stream_error = [ H2.Client_connection.error | `Unexpected_eof ] +type stream_error = + [ `Unexpected_eof | `Connection_error of H2.Client_connection.error ] type t = ( H2.Headers.t, @@ -26,4 +27,5 @@ module Expert : sig t end +(* TODO: add logger *) val create_client : net:Eio_unix.Net.t -> sw:Eio.Switch.t -> string -> t diff --git a/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.ml b/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.ml index 09657b6..85388b1 100644 --- a/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.ml +++ b/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.ml @@ -1,4 +1,5 @@ exception Unexpected_eof +exception Connection_error of H2.Reqd.error module Io = struct type request = Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer @@ -7,15 +8,20 @@ module Io = struct module Growing_buffer = Grpc.Buffer module Net_request = struct - type t = Eio.Net.Sockaddr.stream * H2.Reqd.t * H2.Request.t + type t = + Eio.Net.Sockaddr.stream + * H2.Reqd.t + * H2.Request.t + * H2.Reqd.error Eio.Promise.t - let is_post (_, _, req) = + let is_post (_, _, req, _) = match req with { H2.Request.meth = `POST; _ } -> true | _ -> false - let target (_, _, req) = req.H2.Request.target + let target (_, _, req, _) = req.H2.Request.target (* Expose a way to interrupt *) - let get_header (_, _, req) name = H2.Headers.get req.H2.Request.headers name + let get_header (_, _, req, _) name = + H2.Headers.get req.H2.Request.headers name let to_seq recv = let rec loop recv () = @@ -23,13 +29,15 @@ module Io = struct | Grpc_eio_core.Recv_seq.Done -> Seq.Nil | Next (x, recv) -> Seq.Cons (x, loop recv) | Err `Unexpected_eof -> raise Unexpected_eof + | Err (`Connection_error error) -> raise (Connection_error error) in loop recv - let body (_, reqd, _) = + let body (_, reqd, _, error) = let body = H2.Reqd.request_body reqd in (fun () -> - Grpc_eio_core.Body_reader.read_next (H2.Body.Reader.schedule_read body)) + Grpc_eio_core.Body_reader.read_next ~error + (H2.Body.Reader.schedule_read body)) |> Grpc_eio_core.Recv_seq.map (fun { Grpc_eio_core.Body_reader.consume } -> { @@ -56,7 +64,7 @@ module Io = struct -> () - let respond_streaming ~headers (_, reqd, _) = + let respond_streaming ~headers (_, reqd, _, _) = let body_writer = H2.Reqd.respond_with_streaming ~flush_headers_immediately:true reqd (H2.Response.create @@ -83,7 +91,7 @@ module Io = struct let is_closed () = H2.Body.Writer.is_closed body_writer in { Grpc_server_eio.Io.close; write; write_trailers; is_closed } - let respond_error ~status_code ~headers (_, reqd, _) = + let respond_error ~status_code ~headers (_, reqd, _, _) = H2.Reqd.respond_with_string reqd (H2.Response.create ~headers:(H2.Headers.of_list headers) @@ -102,8 +110,10 @@ let io = let connection_handler ~sw ?config ?h2_error_handler ?grpc_error_handler server : 'a Eio.Net.connection_handler = fun socket addr -> + let error, error_r = Eio.Promise.create () in let error_handler client_address ?request error respond = (* Report internal error via headers *) + Eio.Promise.resolve error_r error; let () = match h2_error_handler with | Some f -> f client_address ?request error @@ -124,5 +134,5 @@ let connection_handler ~sw ?config ?h2_error_handler ?grpc_error_handler server Eio.Fiber.fork ~sw (fun () -> Grpc_server_eio.handle_request ~io ?error_handler:grpc_error_handler server - (client_addr, reqd, H2.Reqd.request reqd))) + (client_addr, reqd, H2.Reqd.request reqd, error))) ~error_handler addr socket ~sw diff --git a/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.mli b/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.mli index 2184b73..23c1521 100644 --- a/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.mli +++ b/lib/eio/io-server-h2-ocaml-protoc/io_server_h2_ocaml_protoc.mli @@ -1,6 +1,10 @@ include Grpc_server_eio.Io.S - with type Net_request.t = Eio.Net.Sockaddr.stream * H2.Reqd.t * H2.Request.t + with type Net_request.t = + Eio.Net.Sockaddr.stream + * H2.Reqd.t + * H2.Request.t + * H2.Reqd.error Eio.Promise.t and type request = Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer and type response = Pbrt.Encoder.t -> unit From b4e85dbf64d112485069f50c87086841f3aceb58 Mon Sep 17 00:00:00 2001 From: Wojtek Czekalski Date: Wed, 13 Nov 2024 20:34:14 +0100 Subject: [PATCH 3/7] Fix arpaca codegen --- lib/eio/arpaca/bin/codegen.ml | 79 ++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/lib/eio/arpaca/bin/codegen.ml b/lib/eio/arpaca/bin/codegen.ml index 14e9bd3..8ec07d1 100644 --- a/lib/eio/arpaca/bin/codegen.ml +++ b/lib/eio/arpaca/bin/codegen.ml @@ -53,7 +53,8 @@ let to_snake_case = in fun str -> regex str -let service_name_of_package path = String.concat "." path +let service_name_of_package service_packages service = + String.concat "." (service_packages @ [ service ]) let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit = @@ -74,7 +75,7 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit connection_error ) Grpc_client_eio.Io.t) request = let response = - Grpc_client_eio.Client.Unary.call ~sw ~io ~service:"%s.%s" + Grpc_client_eio.Client.Unary.call ~sw ~io ~service:"%s" ~method_name:%S ~headers:(Grpc_client.make_request_headers `Proto) (%s.%s request) @@ -89,8 +90,8 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit } | #Grpc_client_eio.Rpc_error.Unary.error' as rest -> Io'.raise_client_error (Unary rest)|} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) - (service_name_of_package service.service_packages) - service.service_name rpc.rpc_name typ_mod_name + (service_name_of_package service.service_packages service.service_name) + rpc.rpc_name typ_mod_name (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) typ_mod_name (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) @@ -99,13 +100,13 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit {|let %s (type headers net_response stream_error connection_error) ~sw ~(io : ( headers, net_response, - Pbrt.Encoder.t -> unit, + Pbrt.Encoder.t ->unit, Pbrt.Decoder.t Grpc_eio_core.Body_reader.consumer, stream_error, connection_error ) Grpc_client_eio.Io.t) request handler = let stream = - Grpc_client_eio.Client.Server_streaming.call ~sw ~io ~service:"%s.%s" + Grpc_client_eio.Client.Server_streaming.call ~sw ~io ~service:"%s" ~method_name:"%s" ~headers:(Grpc_client.make_request_headers `Proto) (%s.%s request) (fun net_response ~read -> @@ -125,8 +126,8 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit | #Grpc_client_eio.Rpc_error.Server_streaming.error' as rest -> Io'.raise_client_error (Server_streaming rest) |} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) - (service_name_of_package service.service_packages) - service.service_name rpc.rpc_name typ_mod_name + (service_name_of_package service.service_packages service.service_name) + rpc.rpc_name typ_mod_name (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) typ_mod_name (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) @@ -141,7 +142,7 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit connection_error ) Grpc_client_eio.Io.t) handler = let response = - Grpc_client_eio.Client.Client_streaming.call ~sw ~io ~service:"%s.%s" + Grpc_client_eio.Client.Client_streaming.call ~sw ~io ~service:"%s" ~method_name:"%s" ~headers:(Grpc_client.make_request_headers `Proto) (fun net_response ~writer -> @@ -159,8 +160,8 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit } | #Grpc_client_eio.Rpc_error.Client_streaming.error' as rest -> Io'.raise_client_error (Client_streaming rest)|} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) - (service_name_of_package service.service_packages) - service.service_name rpc.rpc_name typ_mod_name + (service_name_of_package service.service_packages service.service_name) + rpc.rpc_name typ_mod_name (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) typ_mod_name (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) @@ -175,7 +176,7 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit connection_error ) Grpc_client_eio.Io.t) handler = let stream = - Grpc_client_eio.Client.Bidirectional_streaming.call ~sw ~io ~service:"%s.%s" + Grpc_client_eio.Client.Bidirectional_streaming.call ~sw ~io ~service:"%s" ~method_name:"%s" ~headers:(Grpc_client.make_request_headers `Proto) (fun net_response ~writer ~read -> @@ -194,8 +195,8 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit | `Stream_result_success result -> result | #Grpc_client_eio.Rpc_error.Bidirectional_streaming.error' as rest -> Io'.raise_client_error (Bidirectional_streaming rest)|} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) - (service_name_of_package service.service_packages) - service.service_name rpc.rpc_name typ_mod_name + (service_name_of_package service.service_packages service.service_name) + rpc.rpc_name typ_mod_name (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) typ_mod_name (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) @@ -208,7 +209,7 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit F.linep sc {|let %s ~sw ~io request = let response = - Grpc_client_eio.Client.Unary.call ~sw ~io ~service:"%s.%s" + Grpc_client_eio.Client.Unary.call ~sw ~io ~service:"%s" ~method_name:%S ~headers:(Grpc_client.make_request_headers `Proto) (%s.%s request) @@ -223,8 +224,8 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit } | #Grpc_client_eio.Rpc_error.Unary.error' as rest -> Error rest|} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) - (service_name_of_package service.service_packages) - service.service_name rpc.rpc_name typ_mod_name + (service_name_of_package service.service_packages service.service_name) + rpc.rpc_name typ_mod_name (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) typ_mod_name (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) @@ -232,7 +233,7 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit F.linep sc {|let %s ~sw ~io request handler = let stream = - Grpc_client_eio.Client.Server_streaming.call ~sw ~io ~service:"%s.%s" + Grpc_client_eio.Client.Server_streaming.call ~sw ~io ~service:"%s" ~method_name:"%s" ~headers:(Grpc_client.make_request_headers `Proto) (%s.%s request) (fun net_response ~read -> @@ -249,8 +250,8 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit | `Stream_result_success result -> Ok result | #Grpc_client_eio.Rpc_error.Server_streaming.error' as rest -> Error rest|} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) - (service_name_of_package service.service_packages) - service.service_name rpc.rpc_name typ_mod_name + (service_name_of_package service.service_packages service.service_name) + rpc.rpc_name typ_mod_name (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) typ_mod_name (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) @@ -258,7 +259,7 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit F.linep sc {|let %s ~sw ~io handler = let response = - Grpc_client_eio.Client.Client_streaming.call ~sw ~io ~service:"%s.%s" + Grpc_client_eio.Client.Client_streaming.call ~sw ~io ~service:"%s" ~method_name:"%s" ~headers:(Grpc_client.make_request_headers `Proto) (fun net_response ~writer -> @@ -276,8 +277,8 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit } | #Grpc_client_eio.Rpc_error.Client_streaming.error' as rest -> Error rest|} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) - (service_name_of_package service.service_packages) - service.service_name rpc.rpc_name typ_mod_name + (service_name_of_package service.service_packages service.service_name) + rpc.rpc_name typ_mod_name (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) typ_mod_name (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) @@ -285,7 +286,7 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit F.linep sc {|let %s ~sw ~io handler = let stream = - Grpc_client_eio.Client.Bidirectional_streaming.call ~sw ~io ~service:"%s.%s" + Grpc_client_eio.Client.Bidirectional_streaming.call ~sw ~io ~service:"%s" ~method_name:"%s" ~headers:(Grpc_client.make_request_headers `Proto) (fun net_response ~writer ~read -> @@ -303,8 +304,8 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit | `Stream_result_success result -> Ok result | #Grpc_client_eio.Rpc_error.Bidirectional_streaming.error' as rest -> Error rest|} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) - (service_name_of_package service.service_packages) - service.service_name rpc.rpc_name typ_mod_name + (service_name_of_package service.service_packages service.service_name) + rpc.rpc_name typ_mod_name (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) typ_mod_name (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) @@ -333,7 +334,7 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit | #Grpc_client_eio.Rpc_error.Unary.error' as rest -> rest|} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) - (service_name_of_package service.service_packages) + (service_name_of_package service.service_packages service.service_name) service.service_name rpc.rpc_name typ_mod_name (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) typ_mod_name @@ -341,7 +342,7 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit | `Server_streaming -> F.linep sc {|let %s ~sw ~io request handler = - Grpc_client_eio.Client.Server_streaming.call ~sw ~io ~service:"%s.%s" + Grpc_client_eio.Client.Server_streaming.call ~sw ~io ~service:"%s" ~method_name:"%s" ~headers:(Grpc_client.make_request_headers `Proto) (%s.%s request) (fun net_response ~read -> @@ -354,8 +355,8 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit in handler net_response responses)|} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) - (service_name_of_package service.service_packages) - service.service_name rpc.rpc_name typ_mod_name + (service_name_of_package service.service_packages service.service_name) + rpc.rpc_name typ_mod_name (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) typ_mod_name (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) @@ -363,7 +364,7 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit F.linep sc {|let %s ~sw ~io handler = let response = - Grpc_client_eio.Client.Client_streaming.call ~sw ~io ~service:"%s.%s" + Grpc_client_eio.Client.Client_streaming.call ~sw ~io ~service:"%s" ~method_name:"%s" ~headers:(Grpc_client.make_request_headers `Proto) (fun net_response ~writer -> @@ -382,15 +383,15 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit | #Grpc_client_eio.Rpc_error.Client_streaming.error' as rest -> rest|} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) - (service_name_of_package service.service_packages) - service.service_name rpc.rpc_name typ_mod_name + (service_name_of_package service.service_packages service.service_name) + rpc.rpc_name typ_mod_name (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) typ_mod_name (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) | `Bidirectional_streaming -> F.linep sc {|let %s ~sw ~io handler = - Grpc_client_eio.Client.Bidirectional_streaming.call ~sw ~io ~service:"%s.%s" + Grpc_client_eio.Client.Bidirectional_streaming.call ~sw ~io ~service:"%s" ~method_name:"%s" ~headers:(Grpc_client.make_request_headers `Proto) (fun net_response ~writer ~read -> @@ -404,8 +405,8 @@ let gen_service_client_struct ~proto_gen_module (service : Ot.service) sc : unit in handler net_response ~writer:writer' ~read:read')|} (Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case) - (service_name_of_package service.service_packages) - service.service_name rpc.rpc_name typ_mod_name + (service_name_of_package service.service_packages service.service_name) + rpc.rpc_name typ_mod_name (function_name_encode_pb ~service_name ~rpc_name rpc.rpc_req) typ_mod_name (function_name_decode_pb ~service_name ~rpc_name rpc.rpc_res) @@ -471,9 +472,9 @@ let gen_service_server_struct ~proto_gen_module (service : Ot.service) top_scope let rpc_name = rpc.rpc_name in let service_name = service.service_name in - F.linep sc {|| "%s.%s", %S ->|} - (String.concat "." service.service_packages) - service.service_name rpc.rpc_name; + F.linep sc {|| "%s", %S ->|} + (String.concat "." (service.service_packages @ [ service.service_name ])) + rpc.rpc_name; let impl = Pb_codegen_util.function_name_of_rpc rpc |> to_snake_case in let decoder_func = From 23498436f168b1610bace989a42763890f4ace79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cholewi=C5=84ski?= Date: Sun, 17 Nov 2024 18:35:03 +0100 Subject: [PATCH 4/7] fix trailers on the first response --- .../io_client_h2_ocaml_protoc.ml | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml index 28ac620..7e3ae44 100644 --- a/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml +++ b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml @@ -120,10 +120,19 @@ module Make_net (Client : Client) : (* Allocate once, use a pool of these *) let errored = ref false in let response_handler response reader = - let trailers, trailers_u = Eio.Promise.create () in - let () = - trailers_handler := - fun trailers -> Eio.Promise.resolve trailers_u trailers + let grpc_status_header = + H2.Headers.get response.H2.Response.headers "grpc-status" + in + let trailers = + match grpc_status_header with + | Some status -> + Eio.Promise.create_resolved + @@ H2.Headers.of_list [ ("grpc-status", status) ] + | None -> + let trailers', trailers_u = Eio.Promise.create () in + (trailers_handler := + fun trailers -> Eio.Promise.resolve trailers_u trailers); + trailers' in let next = From 92deafd245700ee76b7ccbb084a9a6f3aa5682f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cholewi=C5=84ski?= Date: Tue, 10 Dec 2024 18:13:52 +0100 Subject: [PATCH 5/7] make network argument more generic --- .../io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml | 2 +- .../io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.mli | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml index 7e3ae44..2e513cb 100644 --- a/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml +++ b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.ml @@ -211,7 +211,7 @@ module Expert = struct let scheme = scheme end)) - let create_with_address ~(net : Eio_unix.Net.t) ~sw ~scheme ~host ~port = + let create_with_address ~(net : _ Eio.Net.t) ~sw ~scheme ~host ~port = let inet, port = Eio_unix.run_in_systhread (fun () -> Unix.getaddrinfo host (string_of_int port) diff --git a/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.mli b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.mli index 800401d..9e10bcf 100644 --- a/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.mli +++ b/lib/eio/io-client-h2-ocaml-protoc/io_client_h2_ocaml_protoc.mli @@ -19,7 +19,7 @@ module Expert : sig t val create_with_address : - net:Eio_unix.Net.t -> + net:_ Eio.Net.t -> sw:Eio.Switch.t -> scheme:string -> host:string -> @@ -28,4 +28,4 @@ module Expert : sig end (* TODO: add logger *) -val create_client : net:Eio_unix.Net.t -> sw:Eio.Switch.t -> string -> t +val create_client : net:_ Eio.Net.t -> sw:Eio.Switch.t -> string -> t From 80cace103068c0966d809d35df7070480285155b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cholewi=C5=84ski?= Date: Wed, 11 Dec 2024 13:44:36 +0100 Subject: [PATCH 6/7] fix relative proto paths in command --- lib/eio/arpaca/bin/main.ml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/eio/arpaca/bin/main.ml b/lib/eio/arpaca/bin/main.ml index 9665fb9..37abf23 100644 --- a/lib/eio/arpaca/bin/main.ml +++ b/lib/eio/arpaca/bin/main.ml @@ -120,6 +120,10 @@ let prepare proto_file_name include_dirs = let { Pb_codegen_ocaml_type.proto_services; _ }, _ = compile proto_file_name include_dirs false in + let proto_file_name = + if not (String.contains proto_file_name '/') then proto_file_name + else List.hd @@ List.rev @@ String.split_on_char '/' proto_file_name + in let proto_gen_module = Pb_codegen_util.caml_file_name_of_proto_file_name ~proto_file_name in From 5dfdb4e2744136bebf49b9caec044fbf40de7778 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cholewi=C5=84ski?= Date: Thu, 19 Dec 2024 17:43:27 +0100 Subject: [PATCH 7/7] handle connection error in unary --- lib/eio/client/client.ml | 97 ++++++++++++++++++++++------------------ 1 file changed, 53 insertions(+), 44 deletions(-) diff --git a/lib/eio/client/client.ml b/lib/eio/client/client.ml index 644b13d..0b7d370 100644 --- a/lib/eio/client/client.ml +++ b/lib/eio/client/client.ml @@ -240,60 +240,69 @@ module Unary = struct Io.t) ~service ~method_name ~headers request : (_, headers, stream_error, conn_error, net_response) result' = match call ~sw ~io ~service ~method_name ~headers () with - | Ok { writer; recv; grpc_status; write_exn; _ } -> ( + | Ok { writer; recv; grpc_status; write_exn; conn_err = conn_err_p } -> ( try if not (writer.write request) then `Write_error (Option.get !write_exn) else ( writer.close (); - match Eio.Promise.await recv with - | Ok { net_response; recv_seq; trailers } -> - let (module Io') = io in - if Io'.Net_response.is_ok net_response then - match recv_seq () with - | Grpc_eio_core.Recv_seq.Done -> - `Premature_close - { - net_response; - grpc_status = Eio.Promise.await grpc_status; - trailers = Eio.Promise.await trailers; - stream_error = None; - } - | Err stream_error -> - `Premature_close - { - net_response; - grpc_status = Eio.Promise.await grpc_status; - trailers = Eio.Promise.await trailers; - stream_error = Some stream_error; - } - | Next (response, _) -> ( - let status = Eio.Promise.await grpc_status in - match Grpc.Status.code status with - | OK -> - `Success + Eio.Fiber.first + (fun () -> `Connection_error (Eio.Promise.await conn_err_p)) + (fun () -> + match Eio.Promise.await recv with + | Ok { net_response; recv_seq; trailers } -> + let (module Io') = io in + if Io'.Net_response.is_ok net_response then + match recv_seq () with + | Grpc_eio_core.Recv_seq.Done -> + (`Premature_close + { + net_response; + grpc_status = Eio.Promise.await grpc_status; + trailers = Eio.Promise.await trailers; + stream_error = None; + } + : ( response, + headers, + stream_error, + conn_error, + net_response ) + result') + | Err stream_error -> + `Premature_close { net_response; - response; + grpc_status = Eio.Promise.await grpc_status; trailers = Eio.Promise.await trailers; + stream_error = Some stream_error; } - | _ -> - (* Not reachable under normal circumstances + | Next (response, _) -> ( + let status = Eio.Promise.await grpc_status in + match Grpc.Status.code status with + | OK -> + `Success + { + net_response; + response; + trailers = Eio.Promise.await trailers; + } + | _ -> + (* Not reachable under normal circumstances https://github.com/grpc/grpc/issues/12824 *) - `Response_not_ok - { - net_response; - grpc_status = status; - trailers = Eio.Promise.await trailers; - }) - else - `Response_not_ok - { - net_response; - grpc_status = Eio.Promise.await grpc_status; - trailers = Eio.Promise.await trailers; - } - | Error e -> `Connection_error e) + `Response_not_ok + { + net_response; + grpc_status = status; + trailers = Eio.Promise.await trailers; + }) + else + `Response_not_ok + { + net_response; + grpc_status = Eio.Promise.await grpc_status; + trailers = Eio.Promise.await trailers; + } + | Error e -> `Connection_error e)) with exn -> `Write_error exn) | Error e -> `Connection_error e end