...
 
Commits (5)
......@@ -7,7 +7,7 @@ module String = Base.String
(* Define a function to handle message_create *)
let check_command (message:Message.t) =
(* Simple example of command parsing. *)
let cmd, rest = match String.split ~on:' ' message.content with
let cmd, rest = match String.split ~on:' ' (String.lowercase message.content) with
| hd::tl -> hd, tl
| [] -> "", []
in match cmd with
......@@ -19,6 +19,7 @@ let check_command (message:Message.t) =
| "!echo" -> Commands.echo message rest
(* | "!cache" -> Commands.cache message rest *)
| "!shutdown" -> Commands.shutdown message rest
| "!restart" -> Commands.restart message rest
| "!rgm" -> Commands.request_members message rest
(* | "!new" -> Commands.new_guild message rest *)
(* | "!delall" -> Commands.delete_guilds message rest *)
......
......@@ -13,9 +13,11 @@ let client, r_client = Lwt.wait ()
(* Example ping command with REST round trip time edited into the response. *)
let ping message _args =
Message.reply message "Pong!" >>= function
| Ok _message -> Lwt.return_unit
(* let diff = Time.diff (Time.now ()) (Time.of_string message.timestamp) in
Message.set_content message (Printf.sprintf "Pong! `%d ms`" (Time.Span.to_ms diff |> Float.abs |> Float.to_int)) *)
| Ok message' ->
let `Message_id t, `Message_id t' = message.id, message'.id in
let time = Snowflake.(timestamp t' - timestamp t) |> abs in
Message.set_content message' (Printf.sprintf "Pong! `%d ms`" time)
>|= ignore
| Error e -> Error.(of_string e |> raise)
(* Send a list of consecutive integers of N size with 1 message per list item. *)
......@@ -108,6 +110,11 @@ let shutdown (message:Message.t) _args =
exit 0
else Lwt.return_unit
let restart (message:Message.t) _args =
if message.author.id = `User_id 242675474927583232 then
client >>= Client.shutdown_all
else Lwt.return_unit
(* Request guild members to be sent over the gateway for the guild the command is run in. This will cause multiple GUILD_MEMBERS_CHUNK events. *)
let request_members (message:Message.t) _args =
client >>= fun client ->
......
......@@ -19,7 +19,7 @@ let decompress src =
Zlib_inflate.bytes in_buf out_buf
(fun dst ->
let len = min 0xFFFF (src_len - !pos) in
Caml.Bytes.blit_string src !pos dst 0 len;
Bytes.blit_string src !pos dst 0 len;
pos := !pos + len;
len)
(fun obuf len ->
......@@ -34,6 +34,7 @@ module Shard = struct
{ compress: bool
; hb_interval: int Lwt.t * int Lwt.u
; hb_stopper: Lwt_engine.event option
; hb_acked: bool
; id: int
; large_threshold: int
; ready: unit Lwt.t * unit Lwt.u
......@@ -60,7 +61,7 @@ module Shard = struct
| Binary ->
if compress then `Ok (decompress frame.content |> Yojson.Safe.from_string)
else `Error "Failed to decompress"
| Close -> `Close frame
| Close -> `Close (String.sub frame.content 0 2)
| op ->
let op = Frame.Opcode.to_string op in
`Error ("Unexpected opcode " ^ op)
......@@ -82,8 +83,12 @@ module Shard = struct
match shard.seq with
| 0 -> Lwt.return shard
| i ->
if not shard.hb_acked then
shard.send (Frame.close 1001) >|= fun () -> shard
else
Logs_lwt.info (fun m -> m "Heartbeating - Shard: [%d, %d] - Seq: %d" shard.id shard.shard_count shard.seq) >>= fun () ->
push_frame ~payload:(`Int i) ~ev:HEARTBEAT shard
push_frame ~payload:(`Int i) ~ev:HEARTBEAT shard >|= fun shard ->
{ shard with hb_acked = false }
let dispatch ~payload shard =
let module J = Yojson.Safe.Util in
......@@ -179,7 +184,7 @@ module Shard = struct
end
| RECONNECT -> initialize shard
| HELLO -> initialize ~data:(J.member "d" f) shard
| HEARTBEAT_ACK -> Lwt.return shard
| HEARTBEAT_ACK -> Lwt.return { shard with hb_acked = true }
| opcode ->
Logs_lwt.warn (fun m -> m "Invalid Opcode: %s" (Opcode.to_string opcode)) >|= fun () ->
shard
......@@ -198,6 +203,7 @@ module Shard = struct
{ compress
; hb_interval = Lwt.wait ()
; hb_stopper = None
; hb_acked = true
; id = fst shards
; large_threshold
; ready = Lwt.wait ()
......@@ -210,13 +216,18 @@ module Shard = struct
}
let shutdown ?(clean=false) ?(restart=true) t =
let _ = clean in
t.can_resume <- restart;
Lwt.wakeup_later (snd t.stop) ();
Logs_lwt.info (fun m -> m "Performing shutdown. Shard [%d, %d]" t.state.id t.state.shard_count) >>= fun () ->
t.state.send (Frame.close 1001) >|= fun () ->
Option.map t.state.hb_stopper ~f:(fun ev -> Lwt_engine.stop_event ev)
|> ignore
if clean then
let re = if restart then "restart" else "shutdown" in
Lwt.wakeup_later (snd t.stop) ();
Logs_lwt.info (fun m -> m "Performing clean %s. Shard [%d, %d]" re t.state.id t.state.shard_count) >>= fun () ->
t.state.send (Frame.close 1000) >|= fun () ->
Option.map t.state.hb_stopper ~f:(fun ev -> Lwt_engine.stop_event ev) |> ignore
else
let re = if restart then "restarting..." else "shutting down." in
Logs_lwt.info (fun m -> m "Shard closed unexpectedly, %s Shard [%d, %d]" re t.state.id t.state.shard_count) >>= fun () ->
t.state.send (Frame.close 1001) >|= fun () ->
Option.map t.state.hb_stopper ~f:(fun ev -> Lwt_engine.stop_event ev) |> ignore
end
type t = { shards: (Shard.shard Shard.t) list }
......@@ -243,7 +254,7 @@ let start ?count ?compress ?large_threshold () =
Shard.handle_frame ~f t.state >|= fun s ->
t.state <- s
| `Close c ->
Logs_lwt.warn (fun m -> m "Close frame received. %s" (Frame.show c)) >>= fun () ->
Logs_lwt.warn (fun m -> m "Close frame received. %s" c) >>= fun () ->
Shard.shutdown t
| `Error e ->
Logs_lwt.warn (fun m -> m "Websocket soft error: %s" e) >>= fun () ->
......
......@@ -12,6 +12,7 @@ module Shard : sig
{ compress: bool (** Whether to compress payloads. *)
; hb_interval: int Lwt.t * int Lwt.u (** Time between heartbeats. Not known until HELLO is received. *)
; hb_stopper: Lwt_engine.event option (** Used to cancel heartbeat sequencer *)
; hb_acked: bool (** Whether the last heartbeat was acked. Missing an ack will reconnect the shard. *)
; id: int (** ID of the current shard. Must be less than shard_count. *)
; large_threshold: int (** Minimum number of members needed for a guild to be considered large. *)
; ready: unit Lwt.t * unit Lwt.u (** A simple promise indicating if the shard has received READY. *)
......
......@@ -41,7 +41,6 @@ module Base = struct
Lwt_result.fail @@ Printf.sprintf "Unsuccessful response received: %d - %s" code body
let request ?(body=`Null) ?(query=[]) m path =
Logs_lwt.info (fun m -> m "Making HTTP request. Path: %s" path) >>= fun () ->
let limit, rlm = Rl.get_rl m path !rl in
rl := rlm;
Lwt_mvar.take limit >>= fun limit ->
......@@ -58,10 +57,11 @@ module Base = struct
>>= process_response path
in if limit.remaining > 0 then process ()
else
(* let time = Time.(Span.of_int_sec limit.reset |> of_span_since_epoch) in
Logs.debug (fun m -> m "Rate-limiting [Route: %s] [Duration: %d ms]" path Time.(diff time (Time.now ()) |> Span.to_ms |> Float.to_int) );
Clock.at time >>= process *)
process ()
let time = float_of_int limit.reset -. Unix.time () in
Logs_lwt.info (fun m -> m
"Rate-limiting [Route: %s] [Duration: %f s]"
path time) >>= fun () ->
Lwt_unix.sleep time >>= process
end
let r_map f = function
......