Commit 07595c9e authored by Grégoire Henry's avatar Grégoire Henry

P2p: improve cancelation

Pending connections were not easily interuptible.
parent 4adf696c
......@@ -390,10 +390,11 @@ let register st conn =
conn
end
let write { write_queue } msg =
Lwt.catch
(fun () -> Lwt_pipe.push write_queue msg >>= return)
(fun _ -> fail P2p_errors.Connection_closed)
let write ?canceler { write_queue } msg =
trace P2p_errors.Connection_closed @@
protect ?canceler begin fun () ->
Lwt_pipe.push write_queue msg >>= return
end
let write_now { write_queue } msg = Lwt_pipe.push_now write_queue msg
let read_from conn ?pos ?len buf msg =
......@@ -426,7 +427,7 @@ let read_now conn ?pos ?len buf =
(Lwt_pipe.pop_now conn.read_queue)
with Lwt_pipe.Closed -> Some (Error [P2p_errors.Connection_closed])
let read conn ?pos ?len buf =
let read ?canceler conn ?pos ?len buf =
match conn.partial_read with
| Some msg ->
conn.partial_read <- None ;
......@@ -434,11 +435,13 @@ let read conn ?pos ?len buf =
| None ->
Lwt.catch
(fun () ->
Lwt_pipe.pop conn.read_queue >|= fun msg ->
protect ?canceler begin fun () ->
Lwt_pipe.pop conn.read_queue
end >|= fun msg ->
read_from conn ?pos ?len buf msg)
(fun _ -> fail P2p_errors.Connection_closed)
let read_full conn ?pos ?len buf =
let read_full ?canceler conn ?pos ?len buf =
let maxlen = MBytes.length buf in
let pos = Option.unopt ~default:0 pos in
let len = Option.unopt ~default:(maxlen - pos) len in
......@@ -448,7 +451,7 @@ let read_full conn ?pos ?len buf =
if len = 0 then
return_unit
else
read conn ~pos ~len buf >>=? fun read_len ->
read ?canceler conn ~pos ~len buf >>=? fun read_len ->
loop (pos + read_len) (len - read_len) in
loop pos len
......
......@@ -61,7 +61,9 @@ val create:
val register: t -> P2p_fd.t -> connection
(** [register sched fd] is a [connection] managed by [sched]. *)
val write: connection -> MBytes.t -> unit tzresult Lwt.t
val write:
?canceler:Lwt_canceler.t ->
connection -> MBytes.t -> unit tzresult Lwt.t
(** [write conn msg] returns [Ok ()] when [msg] has been added to
[conn]'s write queue, or fail with an error. *)
......@@ -76,11 +78,13 @@ val read_now:
[buf] starting at [pos]. *)
val read:
?canceler:Lwt_canceler.t ->
connection -> ?pos:int -> ?len:int -> MBytes.t -> int tzresult Lwt.t
(** Like [read_now], but waits till [conn] read queue has at least one
element instead of failing. *)
val read_full:
?canceler:Lwt_canceler.t ->
connection -> ?pos:int -> ?len:int -> MBytes.t -> unit tzresult Lwt.t
(** Like [read], but blits exactly [len] bytes in [buf]. *)
......
......@@ -783,6 +783,7 @@ and raw_authenticate pool ?point_info canceler fd point =
(if incoming then " incoming" else "") >>= fun () ->
protect ~canceler begin fun () ->
P2p_socket.authenticate
~canceler
~proof_of_work_target:pool.config.proof_of_work_target
~incoming fd point
?listening_port:pool.config.listening_port
......@@ -885,6 +886,7 @@ and raw_authenticate pool ?point_info canceler fd point =
?incoming_message_queue_size:pool.config.incoming_message_queue_size
?outgoing_message_queue_size:pool.config.outgoing_message_queue_size
?binary_chunks_size:pool.config.binary_chunks_size
~canceler
auth_fd pool.encoding >>=? fun conn ->
lwt_debug "authenticate: %a -> Connected %a"
P2p_point.Id.pp point
......
......@@ -56,7 +56,7 @@ module Crypto = struct
input and output. *)
let () = assert (Crypto_box.boxzerobytes >= header_length)
let write_chunk fd cryptobox_data msg =
let write_chunk ?canceler fd cryptobox_data msg =
let msglen = MBytes.length msg in
fail_unless
(msglen <= max_content_length) P2p_errors.Invalid_message_size >>=? fun () ->
......@@ -71,15 +71,15 @@ module Crypto = struct
let header_pos = Crypto_box.boxzerobytes - header_length in
MBytes.set_int16 buf header_pos encrypted_length ;
let payload = MBytes.sub buf header_pos (buf_length - header_pos) in
P2p_io_scheduler.write fd payload
P2p_io_scheduler.write ?canceler fd payload
let read_chunk fd cryptobox_data =
let read_chunk ?canceler fd cryptobox_data =
let header_buf = MBytes.create header_length in
P2p_io_scheduler.read_full ~len:header_length fd header_buf >>=? fun () ->
P2p_io_scheduler.read_full ?canceler ~len:header_length fd header_buf >>=? fun () ->
let encrypted_length = MBytes.get_uint16 header_buf 0 in
let buf_length = encrypted_length + Crypto_box.boxzerobytes in
let buf = MBytes.make buf_length '\x00' in
P2p_io_scheduler.read_full
P2p_io_scheduler.read_full ?canceler
~pos:Crypto_box.boxzerobytes ~len:encrypted_length fd buf >>=? fun () ->
let remote_nonce = cryptobox_data.remote_nonce in
cryptobox_data.remote_nonce <- Crypto_box.increment_nonce remote_nonce ;
......@@ -140,7 +140,7 @@ module Connection_message = struct
(req "message_nonce" Crypto_box.nonce_encoding)
(req "versions" (Variable.list P2p_version.encoding)))
let write fd message =
let write ~canceler fd message =
let encoded_message_len =
Data_encoding.Binary.length encoding message in
fail_unless
......@@ -155,20 +155,20 @@ module Connection_message = struct
| Some last ->
fail_unless (last = len) P2p_errors.Encoding_error >>=? fun () ->
MBytes.set_int16 buf 0 encoded_message_len ;
P2p_io_scheduler.write fd buf >>=? fun () ->
P2p_io_scheduler.write ~canceler fd buf >>=? fun () ->
(* We return the raw message as it is used later to compute
the nonces *)
return buf
let read fd =
let read ~canceler fd =
let header_buf = MBytes.create Crypto.header_length in
P2p_io_scheduler.read_full
P2p_io_scheduler.read_full ~canceler
~len:Crypto.header_length fd header_buf >>=? fun () ->
let len = MBytes.get_uint16 header_buf 0 in
let pos = Crypto.header_length in
let buf = MBytes.create (pos + len) in
MBytes.set_int16 buf 0 len ;
P2p_io_scheduler.read_full ~len ~pos fd buf >>=? fun () ->
P2p_io_scheduler.read_full ~canceler ~len ~pos fd buf >>=? fun () ->
match Data_encoding.Binary.read encoding buf pos len with
| None ->
fail P2p_errors.Decoding_error
......@@ -188,7 +188,7 @@ type 'meta metadata_config = {
module Metadata = struct
let write metadata_config cryptobox_data fd message =
let write ~canceler metadata_config cryptobox_data fd message =
let encoded_message_len =
Data_encoding.Binary.length metadata_config.conn_meta_encoding message in
let buf = MBytes.create encoded_message_len in
......@@ -201,10 +201,10 @@ module Metadata = struct
| Some last ->
fail_unless (last = encoded_message_len)
P2p_errors.Encoding_error >>=? fun () ->
Crypto.write_chunk cryptobox_data fd buf
Crypto.write_chunk ~canceler cryptobox_data fd buf
let read metadata_config fd cryptobox_data =
Crypto.read_chunk fd cryptobox_data >>=? fun buf ->
let read ~canceler metadata_config fd cryptobox_data =
Crypto.read_chunk ~canceler fd cryptobox_data >>=? fun buf ->
let length = MBytes.length buf in
let encoding = metadata_config.conn_meta_encoding in
match
......@@ -248,7 +248,7 @@ module Ack = struct
nack_case (Tag 255) ;
]
let write fd cryptobox_data message =
let write ?canceler fd cryptobox_data message =
let encoded_message_len =
Data_encoding.Binary.length encoding message in
let buf = MBytes.create encoded_message_len in
......@@ -258,10 +258,10 @@ module Ack = struct
| Some last ->
fail_unless (last = encoded_message_len)
P2p_errors.Encoding_error >>=? fun () ->
Crypto.write_chunk fd cryptobox_data buf
Crypto.write_chunk ?canceler fd cryptobox_data buf
let read fd cryptobox_data =
Crypto.read_chunk fd cryptobox_data >>=? fun buf ->
let read ?canceler fd cryptobox_data =
Crypto.read_chunk ?canceler fd cryptobox_data >>=? fun buf ->
let length = MBytes.length buf in
match Data_encoding.Binary.read encoding buf 0 length with
| None ->
......@@ -289,18 +289,19 @@ let kick { fd ; cryptobox_data ; _ } =
whether we're trying to connect to a peer or checking an incoming
connection, both parties must first introduce themselves. *)
let authenticate
~canceler
~proof_of_work_target
~incoming fd (remote_addr, remote_socket_port as point)
?listening_port identity supported_versions metadata_config =
let local_nonce_seed = Crypto_box.random_nonce () in
lwt_debug "Sending authenfication to %a" P2p_point.Id.pp point >>= fun () ->
Connection_message.write fd
Connection_message.write ~canceler fd
{ public_key = identity.P2p_identity.public_key ;
proof_of_work_stamp = identity.proof_of_work_stamp ;
message_nonce = local_nonce_seed ;
port = listening_port ;
versions = supported_versions } >>=? fun sent_msg ->
Connection_message.read fd >>=? fun (msg, recv_msg) ->
Connection_message.read ~canceler fd >>=? fun (msg, recv_msg) ->
let remote_listening_port =
if incoming then msg.port else Some remote_socket_port in
let id_point = remote_addr, remote_listening_port in
......@@ -318,8 +319,8 @@ let authenticate
Crypto_box.generate_nonces ~incoming ~sent_msg ~recv_msg in
let cryptobox_data = { Crypto.channel_key ; local_nonce ; remote_nonce } in
let local_metadata = metadata_config.conn_meta_value remote_peer_id in
Metadata.write metadata_config fd cryptobox_data local_metadata >>=? fun () ->
Metadata.read metadata_config fd cryptobox_data >>=? fun remote_metadata ->
Metadata.write ~canceler metadata_config fd cryptobox_data local_metadata >>=? fun () ->
Metadata.read ~canceler metadata_config fd cryptobox_data >>=? fun remote_metadata ->
let info =
{ P2p_connection.Info.peer_id = remote_peer_id ;
versions = msg.versions ; incoming ;
......@@ -351,9 +352,8 @@ module Reader = struct
lwt_debug "[read_message] incremental decoding error" >>= fun () ->
return_none
| Await decode_next_buf ->
protect ~canceler:st.canceler begin fun () ->
Crypto.read_chunk st.conn.fd st.conn.cryptobox_data
end >>=? fun buf ->
Crypto.read_chunk ~canceler:st.canceler
st.conn.fd st.conn.cryptobox_data >>=? fun buf ->
lwt_debug
"reading %d bytes from %a"
(MBytes.length buf) P2p_peer.Id.pp st.conn.info.peer_id >>= fun () ->
......@@ -432,9 +432,8 @@ module Writer = struct
let rec loop = function
| [] -> return_unit
| buf :: l ->
protect ~canceler:st.canceler begin fun () ->
Crypto.write_chunk st.conn.fd st.conn.cryptobox_data buf
end >>=? fun () ->
Crypto.write_chunk ~canceler:st.canceler
st.conn.fd st.conn.cryptobox_data buf >>=? fun () ->
lwt_debug "writing %d bytes to %a"
(MBytes.length buf) P2p_peer.Id.pp st.conn.info.peer_id >>= fun () ->
loop l in
......@@ -561,11 +560,12 @@ let private_node { conn } = conn.info.private_node
let accept
?incoming_message_queue_size ?outgoing_message_queue_size
?binary_chunks_size
~canceler
conn
encoding =
protect begin fun () ->
Ack.write conn.fd conn.cryptobox_data Ack >>=? fun () ->
Ack.read conn.fd conn.cryptobox_data
Ack.write ~canceler conn.fd conn.cryptobox_data Ack >>=? fun () ->
Ack.read ~canceler conn.fd conn.cryptobox_data
end ~on_error:begin fun err ->
P2p_io_scheduler.close conn.fd >>= fun _ ->
match err with
......
......@@ -62,6 +62,7 @@ val private_node: ('msg, 'meta) t -> bool
(** {1 Low-level functions (do not use directly)} *)
val authenticate:
canceler:Lwt_canceler.t ->
proof_of_work_target:Crypto_box.target ->
incoming:bool ->
P2p_io_scheduler.connection -> P2p_point.Id.t ->
......@@ -84,6 +85,7 @@ val accept:
?incoming_message_queue_size:int ->
?outgoing_message_queue_size:int ->
?binary_chunks_size: int ->
canceler:Lwt_canceler.t ->
'meta authenticated_connection ->
'msg Data_encoding.t -> ('msg, 'meta) t tzresult Lwt.t
(** (Low-level) (Cancelable) Accepts a remote peer given an
......
......@@ -27,6 +27,8 @@ include Logging.Make (struct let name = "test.p2p.connection" end)
let addr = ref Ipaddr.V6.localhost
let canceler = Lwt_canceler.create () (* unused *)
let proof_of_work_target = Crypto_box.make_target 16.
let id1 = P2p_identity.generate proof_of_work_target
let id2 = P2p_identity.generate proof_of_work_target
......@@ -117,6 +119,7 @@ let raw_accept sched main_socket =
let accept sched main_socket =
raw_accept sched main_socket >>= fun (fd, point) ->
P2p_socket.authenticate
~canceler
~proof_of_work_target
~incoming:true fd point id1 versions
conn_meta_config
......@@ -132,6 +135,7 @@ let raw_connect sched addr port =
let connect sched addr port id =
raw_connect sched addr port >>= fun fd ->
P2p_socket.authenticate
~canceler
~proof_of_work_target
~incoming:false fd
(addr, port) id versions conn_meta_config >>=? fun (info, auth_fd) ->
......@@ -197,7 +201,7 @@ module Kick = struct
let client _ch sched addr port =
connect sched addr port id2 >>=? fun auth_fd ->
P2p_socket.accept auth_fd encoding >>= fun conn ->
P2p_socket.accept ~canceler auth_fd encoding >>= fun conn ->
_assert (is_rejected conn) __LOC__ "" >>=? fun () ->
return_unit
......@@ -211,7 +215,7 @@ module Kicked = struct
let server _ch sched socket =
accept sched socket >>=? fun (_info, auth_fd) ->
P2p_socket.accept auth_fd encoding >>= fun conn ->
P2p_socket.accept ~canceler auth_fd encoding >>= fun conn ->
_assert (Kick.is_rejected conn) __LOC__ "" >>=? fun () ->
return_unit
......@@ -233,7 +237,7 @@ module Simple_message = struct
let server ch sched socket =
accept sched socket >>=? fun (_info, auth_fd) ->
P2p_socket.accept auth_fd encoding >>=? fun conn ->
P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn ->
P2p_socket.write_sync conn simple_msg >>=? fun () ->
P2p_socket.read conn >>=? fun (_msg_size, msg) ->
_assert (MBytes.compare simple_msg2 msg = 0) __LOC__ "" >>=? fun () ->
......@@ -243,7 +247,7 @@ module Simple_message = struct
let client ch sched addr port =
connect sched addr port id2 >>=? fun auth_fd ->
P2p_socket.accept auth_fd encoding >>=? fun conn ->
P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn ->
P2p_socket.write_sync conn simple_msg2 >>=? fun () ->
P2p_socket.read conn >>=? fun (_msg_size, msg) ->
_assert (MBytes.compare simple_msg msg = 0) __LOC__ "" >>=? fun () ->
......@@ -265,6 +269,7 @@ module Chunked_message = struct
let server ch sched socket =
accept sched socket >>=? fun (_info, auth_fd) ->
P2p_socket.accept
~canceler
~binary_chunks_size:21 auth_fd encoding >>=? fun conn ->
P2p_socket.write_sync conn simple_msg >>=? fun () ->
P2p_socket.read conn >>=? fun (_msg_size, msg) ->
......@@ -276,6 +281,7 @@ module Chunked_message = struct
let client ch sched addr port =
connect sched addr port id2 >>=? fun auth_fd ->
P2p_socket.accept
~canceler
~binary_chunks_size:21 auth_fd encoding >>=? fun conn ->
P2p_socket.write_sync conn simple_msg2 >>=? fun () ->
P2p_socket.read conn >>=? fun (_msg_size, msg) ->
......@@ -297,7 +303,7 @@ module Oversized_message = struct
let server ch sched socket =
accept sched socket >>=? fun (_info, auth_fd) ->
P2p_socket.accept auth_fd encoding >>=? fun conn ->
P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn ->
P2p_socket.write_sync conn simple_msg >>=? fun () ->
P2p_socket.read conn >>=? fun (_msg_size, msg) ->
_assert (MBytes.compare simple_msg2 msg = 0) __LOC__ "" >>=? fun () ->
......@@ -307,7 +313,7 @@ module Oversized_message = struct
let client ch sched addr port =
connect sched addr port id2 >>=? fun auth_fd ->
P2p_socket.accept auth_fd encoding >>=? fun conn ->
P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn ->
P2p_socket.write_sync conn simple_msg2 >>=? fun () ->
P2p_socket.read conn >>=? fun (_msg_size, msg) ->
_assert (MBytes.compare simple_msg msg = 0) __LOC__ "" >>=? fun () ->
......@@ -327,14 +333,14 @@ module Close_on_read = struct
let server ch sched socket =
accept sched socket >>=? fun (_info, auth_fd) ->
P2p_socket.accept auth_fd encoding >>=? fun conn ->
P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn ->
sync ch >>=? fun () ->
P2p_socket.close conn >>= fun _stat ->
return_unit
let client ch sched addr port =
connect sched addr port id2 >>=? fun auth_fd ->
P2p_socket.accept auth_fd encoding >>=? fun conn ->
P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn ->
sync ch >>=? fun () ->
P2p_socket.read conn >>= fun err ->
_assert (is_connection_closed err) __LOC__ "" >>=? fun () ->
......@@ -353,14 +359,14 @@ module Close_on_write = struct
let server ch sched socket =
accept sched socket >>=? fun (_info, auth_fd) ->
P2p_socket.accept auth_fd encoding >>=? fun conn ->
P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn ->
P2p_socket.close conn >>= fun _stat ->
sync ch >>=? fun ()->
return_unit
let client ch sched addr port =
connect sched addr port id2 >>=? fun auth_fd ->
P2p_socket.accept auth_fd encoding >>=? fun conn ->
P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn ->
sync ch >>=? fun ()->
Lwt_unix.sleep 0.1 >>= fun () ->
P2p_socket.write_sync conn simple_msg >>= fun err ->
......@@ -390,7 +396,7 @@ module Garbled_data = struct
let server _ch sched socket =
accept sched socket >>=? fun (_info, auth_fd) ->
P2p_socket.accept auth_fd encoding >>=? fun conn ->
P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn ->
P2p_socket.raw_write_sync conn garbled_msg >>=? fun () ->
P2p_socket.read conn >>= fun err ->
_assert (is_connection_closed err) __LOC__ "" >>=? fun () ->
......@@ -399,7 +405,7 @@ module Garbled_data = struct
let client _ch sched addr port =
connect sched addr port id2 >>=? fun auth_fd ->
P2p_socket.accept auth_fd encoding >>=? fun conn ->
P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn ->
P2p_socket.read conn >>= fun err ->
_assert (is_decoding_error err) __LOC__ "" >>=? fun () ->
P2p_socket.close conn >>= fun _stat ->
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment