Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cable::Server#active_connections_for and Cable::Server#subscribed_channels_for public checkup methods #86

20 changes: 15 additions & 5 deletions spec/cable/connection_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,22 @@ require "../spec_helper"
include RequestHelpers

describe Cable::Connection do
it "removes the connection channel on close" do
connect do |connection, _socket|
connection.receive({"command" => "subscribe", "identifier" => {channel: "ChatChannel", room: "1"}.to_json}.to_json)
ConnectionTest::CHANNELS.keys.size.should eq(1)
describe "#close" do
it "closes the connection socket even without channel subscriptions" do
connect do |connection, _socket|
connection.closed?.should eq(false)
connection.close
connection.closed?.should eq(true)
end
end
it "removes the connection channel on close" do
connect do |connection, _socket|
connection.receive({"command" => "subscribe", "identifier" => {channel: "ChatChannel", room: "1"}.to_json}.to_json)
ConnectionTest::CHANNELS.keys.size.should eq(1)
connection.close
ConnectionTest::CHANNELS.keys.size.should eq(0)
end
end
ConnectionTest::CHANNELS.keys.size.should eq(0)
end

describe "#receive" do
Expand Down
4 changes: 4 additions & 0 deletions spec/cable/handler_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ describe Cable::Handler do
ws2 = HTTP::WebSocket.new("ws://#{listen_address}/updates?test_token=1")

Cable.server.connections.size.should eq(1)
Cable.server.active_connections_for("1").size.should eq(1)
Cable.server.subscribed_channels_for("1").size.should eq(0)

messages = [
{type: "welcome"}.to_json,
Expand All @@ -69,6 +71,8 @@ describe Cable::Handler do
ws2.run

Cable.server.connections.size.should eq(1)
Cable.server.active_connections_for("1").size.should eq(1)
Cable.server.subscribed_channels_for("1").size.should eq(1)
end

it "malformed data from client" do
Expand Down
83 changes: 80 additions & 3 deletions spec/cable/server_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ describe Cable::Server do
it "finds the connection and disconnects it" do
Cable.reset_server
Cable.temp_config(backend_class: Cable::DevBackend) do
socket = DummySocket.new(IO::Memory.new)
request = builds_request("abc123")
connection = ApplicationCable::Connection.new(request, socket)
connection = creates_new_connection("abc123")
Cable.server.add_connection(connection)
connection.connection_identifier.should contain("abc123")

Expand All @@ -19,4 +17,83 @@ describe Cable::Server do
end
end
end

describe "#active_connections_for" do
it "accurately returns active connections for a specificic token" do
Cable.reset_server
Cable.temp_config(backend_class: Cable::DevBackend) do
Cable.server.active_connections_for("abc123").size.should eq(0)
Cable.server.active_connections_for("def456").size.should eq(0)

connection = creates_new_connection("abc123")
Cable.server.add_connection(connection)

Cable.server.active_connections_for("abc123").size.should eq(1)

other_connection = creates_new_connection("def456")
Cable.server.add_connection(other_connection)

Cable.server.active_connections_for("def456").size.should eq(1)

connection.close

Cable.server.active_connections_for("abc123").size.should eq(0)
Cable.server.active_connections_for("def456").size.should eq(1)

other_connection.close

Cable.server.active_connections_for("def456").size.should eq(0)
end
end
end

describe "#subscribed_channels_for" do
it "accurately returns active channel subscriptions for a specificic token" do
Cable.reset_server
Cable.temp_config(backend_class: Cable::DevBackend) do
connection_1 = creates_new_connection("aa")
connection_2 = creates_new_connection("bb")

Cable.server.add_connection(connection_1)
Cable.server.add_connection(connection_2)

Cable.server.subscribed_channels_for("aa").size.should eq(0)
Cable.server.subscribed_channels_for("bb").size.should eq(0)

connection_1.subscribe(subscribe_payload("room_a"))

Cable.server.subscribed_channels_for("aa").size.should eq(1)
Cable.server.subscribed_channels_for("bb").size.should eq(0)

connection_1.subscribe(subscribe_payload("room_b"))

Cable.server.subscribed_channels_for("aa").size.should eq(2)
Cable.server.subscribed_channels_for("bb").size.should eq(0)

connection_2.subscribe(subscribe_payload("room_a"))

Cable.server.subscribed_channels_for("aa").size.should eq(2)
Cable.server.subscribed_channels_for("bb").size.should eq(1)

connection_1.close
connection_2.close
end
Cable.reset_server
end
end
end

def creates_new_connection(token : String | Nil) : ApplicationCable::Connection
ApplicationCable::Connection.new(builds_request(token: token), DummySocket.new(IO::Memory.new))
end

def subscribe_payload(room : String) : Cable::Payload
payload_json = {
command: "subscribe",
identifier: {
channel: "ChatChannel",
room: room,
}.to_json,
}.to_json
Cable::Payload.from_json(payload_json)
end
51 changes: 5 additions & 46 deletions src/cable/channel.cr
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ module Cable
@stream_identifier = stream_identifier.to_s
end

def self.broadcast_to(channel : String, message : JSON::Any)
def self.broadcast_to(channel : String, message : JSON::Any | Hash(String, String))
Cable::Logger.info { "[ActionCable] Broadcasting to #{channel}: #{message}" }
Cable.server.publish(channel, message.to_json)
end
Expand All @@ -70,60 +70,19 @@ module Cable
Cable.server.publish(channel, message)
end

def self.broadcast_to(channel : String, message : Hash(String, String))
Cable::Logger.info { "[ActionCable] Broadcasting to #{channel}: #{message}" }
Cable.server.publish(channel, message.to_json)
end

def broadcast(message : String)
if stream_id = stream_identifier.presence
Cable::Logger.info { "[ActionCable] Broadcasting to #{self.class}: #{message}" }
Cable.server.send_to_channels(stream_id, message)
else
Cable::Logger.error { "#{self.class}.transmit(message : String) with #{message} without already using stream_from(stream_identifier)" }
end
end

def broadcast(message : JSON::Any)
def broadcast(message : String | JSON::Any | Hash(String, String))
if stream_id = stream_identifier.presence
Cable::Logger.info { "[ActionCable] Broadcasting to #{self.class}: #{message}" }
Cable.server.send_to_channels(stream_id, message)
else
Cable::Logger.error { "#{self.class}.transmit(message : JSON::Any) with #{message} without already using stream_from(stream_identifier)" }
Cable::Logger.error { "#{self.class}.transmit(message : #{message.class}) with #{message} without already using stream_from(stream_identifier)" }
end
end

def broadcast(message : Hash(String, String))
if stream_id = stream_identifier.presence
Cable::Logger.info { "[ActionCable] Broadcasting to #{self.class}: #{message}" }
Cable.server.send_to_channels(stream_id, message.to_json)
else
Cable::Logger.error { "#{self.class}.transmit(message : Hash(String, String)) with #{message} without already using stream_from(stream_identifier)" }
end
end

# broadcast single message to single connection for this channel
def transmit(message : String)
Cable::Logger.info { "[ActionCable] transmitting to #{self.class}: #{message}" }
connection.socket.send({
identifier: identifier,
message: Cable.server.safe_decode_message(message),
}.to_json)
end

# broadcast single message to single connection for this channel
def transmit(message : JSON::Any)
Cable::Logger.info { "[ActionCable] transmitting to #{self.class}: #{message}" }
connection.socket.send({
identifier: identifier,
message: Cable.server.safe_decode_message(message),
}.to_json)
end

# broadcast single message to single connection for this channel
def transmit(message : Hash(String, String))
def transmit(message : String | JSON::Any | Hash(String, String))
Cable::Logger.info { "[ActionCable] transmitting to #{self.class}: #{message}" }
connection.socket.send({
connection.send_message({
identifier: identifier,
message: Cable.server.safe_decode_message(message),
}.to_json)
Expand Down
35 changes: 21 additions & 14 deletions src/cable/connection.cr
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ require "uuid"

module Cable
abstract class Connection
class UnathorizedConnectionException < Exception; end
class UnauthorizedConnectionException < Exception; end
rmarronnier marked this conversation as resolved.
Show resolved Hide resolved

property internal_identifier : String = "0"
property connection_identifier : String = ""
Expand Down Expand Up @@ -38,7 +38,7 @@ module Cable
# gather connection_identifier after the connection has gathered the id from identified_by(field)
self.connection_identifier = "#{internal_identifier}-#{UUID.random}"
subscribe_to_internal_channel
rescue e : UnathorizedConnectionException
rescue e : UnauthorizedConnectionException
reject_connection!
unsubscribe_from_internal_channel
socket.close(HTTP::WebSocket::CloseCode::NormalClosure, "Farewell")
Expand All @@ -52,36 +52,43 @@ module Cable
@connection_rejected = true
end

def channels : Array(Cable::Channel)
return Array(Cable::Channel).new unless Connection::CHANNELS.has_key?(connection_identifier)
Connection::CHANNELS.[connection_identifier].values
end

def closed? : Bool
socket.closed?
end

def close
return true unless Connection::CHANNELS.has_key?(connection_identifier)
if Connection::CHANNELS.has_key?(connection_identifier)
Comment on lines -60 to +65
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't return before closing the connection socket in the case of a connection without channel subscriptions.

Connection::CHANNELS[connection_identifier].each do |identifier, channel|
# the ordering here is important
Connection::CHANNELS[connection_identifier].delete(identifier)
channel.close
rescue e : IO::Error
Cable.settings.on_error.call(e, "IO::Error: #{e.message} -> #{self.class.name}#close")
end

Connection::CHANNELS[connection_identifier].each do |identifier, channel|
# the ordering here is important
Connection::CHANNELS[connection_identifier].delete(identifier)
channel.close
rescue e : IO::Error
Cable.settings.on_error.call(e, "IO::Error: #{e.message} -> #{self.class.name}#close")
Connection::CHANNELS.delete(connection_identifier)
unsubscribe_from_internal_channel
end

Connection::CHANNELS.delete(connection_identifier)
unsubscribe_from_internal_channel
Cable::Logger.info { "Terminating connection #{connection_identifier}" }
return true if closed?

Cable::Logger.info { "Terminating connection #{connection_identifier}" }
socket.close
end

def send_message(message : String)
return if socket.closed?
return if closed?

socket.send(message)
end

def reject_unauthorized_connection
raise UnathorizedConnectionException.new
raise UnauthorizedConnectionException.new
end

# Convert the `message` to a proper `Payload`.
Expand Down
2 changes: 1 addition & 1 deletion src/cable/handler.cr
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ module Cable
socket.close(HTTP::WebSocket::CloseCode::InvalidFramePayloadData, "Invalid message")
Cable.server.remove_connection(connection_id)
Cable.settings.on_error.call(e, "Cable::Handler#socket.on_message")
rescue e : Cable::Connection::UnathorizedConnectionException
rescue e : Cable::Connection::UnauthorizedConnectionException
# handle unauthorized connections
# no need to log them
ws_pinger.stop
Expand Down
16 changes: 15 additions & 1 deletion src/cable/server.cr
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@ module Cable
connections.delete(connection_id).try(&.close)
end

# You shouldn't rely on these following two methods
# for an exhaustive array of connections and channels
# if your application can spawn more than one Cable.server instance.

# Only returns connections opened on this instance.
def active_connections_for(token : String) : Array(Connection)
connections.values.select { |connection| connection.token == token && !connection.closed? }
end

# Only returns channel subscriptions opened on this instance.
def subscribed_channels_for(token : String) : Array(Channel)
active_connections_for(token).sum(&.channels)
end

def subscribe_channel(channel : Channel, identifier : String)
@channel_mutex.synchronize do
if [email protected]_key?(identifier)
Expand Down Expand Up @@ -112,7 +126,7 @@ module Cable
@channels[channel_identifier].each do |channel|
# TODO: would be nice to have a test where we open two connections
# close one, and make sure the other one receives the message
if channel.connection.socket.closed?
if channel.connection.closed?
channel.close
else
Cable::Logger.info { "#{channel.class} transmitting #{parsed_message} (via streamed from #{channel.stream_identifier})" }
Expand Down