diff --git a/openvpn/client/cliconnect.hpp b/openvpn/client/cliconnect.hpp index 6981e52bf..5ca834963 100644 --- a/openvpn/client/cliconnect.hpp +++ b/openvpn/client/cliconnect.hpp @@ -252,7 +252,7 @@ class ClientConnect : ClientProto::NotifyCallback, void post_cc_msg(const std::string &msg) { if (!halt && client) - client->post_cc_msg(msg); + client->validate_and_post_cc_msg(msg); } void thread_safe_post_cc_msg(std::string msg) diff --git a/openvpn/client/cliproto.hpp b/openvpn/client/cliproto.hpp index a50c1272a..c94fb153e 100644 --- a/openvpn/client/cliproto.hpp +++ b/openvpn/client/cliproto.hpp @@ -59,6 +59,7 @@ using namespace std::chrono_literals; #include #include +#include #ifdef OPENVPN_DEBUG_CLIPROTO #define OPENVPN_LOG_CLIPROTO(x) OPENVPN_LOG(x) @@ -221,6 +222,29 @@ class Session : ProtoContextCallbackInterface, tun->set_disconnect(); } + /** + * Posts a control message from the client API. To ensure the client that will always send + * valid message (e.g. no extra newlines or invalid) character this method will first check the + * message for validity before sending it to the control channel + * @param msg control channel message + */ + void validate_and_post_cc_msg(const std::string &msg) + { + if (!Unicode::is_valid_utf8(msg, Unicode::UTF8_NO_CTRL)) + { + ClientEvent::Base::Ptr ev = new ClientEvent::UnsupportedFeature{"Invalid chars in control message", "Control channel message with invalid characters not allowed to be send with post_cc_msg", false}; + cli_events->add_event(std::move(ev)); + return; + } + post_cc_msg(msg); + } + + /** + * Post a control message to the control channel. This only intended to be used by consumers that + * either validated the message itself beforehand or construct a message in a way that it is + * always valid. + * @param msg The message to send on the control channel. + */ void post_cc_msg(const std::string &msg) { proto_context.update_now(); diff --git a/test/unittests/test_proto.cpp b/test/unittests/test_proto.cpp index d75b15ee3..ad2638e87 100644 --- a/test/unittests/test_proto.cpp +++ b/test/unittests/test_proto.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #define OPENVPN_DEBUG @@ -42,7 +43,6 @@ #ifndef BF #define BF 0 #endif -#define OPENVPN_BS64_DATA_LIMIT 50000 #if BF == 1 #define PROTO_CIPHER "BF-CBC" #define TLS_VER_MIN TLSVersion::UNDEF @@ -855,136 +855,155 @@ class MySessionStats : public SessionStats count_t errors[Error::N_ERRORS]; }; -// execute the unit test in one thread -int test(const int thread_num, bool use_tls_ekm) +/** + * Create a client ssl config for testing. + * @return + */ +static auto create_client_ssl_config(Frame::Ptr frame, ClientRandomAPI::Ptr rng) { - try - { - // frame - Frame::Ptr frame(new Frame(Frame::Context(128, 378, 128, 0, 16, 0))); - - // RNG - ClientRandomAPI::Ptr rng_cli(new ClientRandomAPI()); - ClientRandomAPI::Ptr prng_cli(new ClientRandomAPI()); - ServerRandomAPI::Ptr rng_serv(new ServerRandomAPI()); - ServerRandomAPI::Ptr prng_serv(new ServerRandomAPI()); - MTRand rng_noncrypto; - - // init simulated time - Time time; - const Time::Duration time_step = Time::Duration::binary_ms(100); - - // client config files - const std::string ca_crt = read_text(TEST_KEYCERT_DIR "ca.crt"); - const std::string client_crt = read_text(TEST_KEYCERT_DIR "client.crt"); - const std::string client_key = read_text(TEST_KEYCERT_DIR "client.key"); - const std::string server_crt = read_text(TEST_KEYCERT_DIR "server.crt"); - const std::string server_key = read_text(TEST_KEYCERT_DIR "server.key"); - const std::string dh_pem = read_text(TEST_KEYCERT_DIR "dh.pem"); - const std::string tls_auth_key = read_text(TEST_KEYCERT_DIR "tls-auth.key"); - const std::string tls_crypt_v2_server_key = read_text(TEST_KEYCERT_DIR "tls-crypt-v2-server.key"); - const std::string tls_crypt_v2_client_key = read_text(TEST_KEYCERT_DIR "tls-crypt-v2-client.key"); - - // client config - ClientSSLAPI::Config::Ptr cc(new ClientSSLAPI::Config()); - cc->set_mode(Mode(Mode::CLIENT)); - cc->set_frame(frame); - cc->set_rng(rng_cli); + const std::string client_crt = read_text(TEST_KEYCERT_DIR "client.crt"); + const std::string client_key = read_text(TEST_KEYCERT_DIR "client.key"); + const std::string ca_crt = read_text(TEST_KEYCERT_DIR "ca.crt"); + + // client config + ClientSSLAPI::Config::Ptr cc(new ClientSSLAPI::Config()); + cc->set_mode(Mode(Mode::CLIENT)); + cc->set_frame(frame); + cc->set_rng(rng); #ifdef USE_APPLE_SSL - cc->load_identity("etest"); + cc->load_identity("etest"); #else - cc->load_ca(ca_crt, true); - cc->load_cert(client_crt); - cc->load_private_key(client_key); + cc->load_ca(ca_crt, true); + cc->load_cert(client_crt); + cc->load_private_key(client_key); #endif - cc->set_tls_version_min(TLS_VER_MIN); + cc->set_tls_version_min(TLS_VER_MIN); #ifdef VERBOSE - cc->set_debug_level(1); + cc->set_debug_level(1); #endif + return cc; +} - // stats - MySessionStats::Ptr cli_stats(new MySessionStats); - MySessionStats::Ptr serv_stats(new MySessionStats); +static auto create_client_proto_context(ClientSSLAPI::Config::Ptr cc, Frame::Ptr frame, ClientRandomAPI::Ptr rng, MySessionStats::Ptr cli_stats, Time &time) - // client ProtoContext config - typedef ProtoContext ClientProtoContext; - ClientProtoContext::ProtoConfig::Ptr cp(new ClientProtoContext::ProtoConfig); - cp->ssl_factory = cc->new_factory(); - CryptoAlgs::allow_default_dc_algs(cp->ssl_factory->libctx(), false, false); - cp->dc.set_factory(new CryptoDCSelect(cp->ssl_factory->libctx(), frame, cli_stats, prng_cli)); - cp->tlsprf_factory.reset(new CryptoTLSPRFFactory()); - cp->frame = frame; - cp->now = &time; - cp->rng = rng_cli; - cp->prng = prng_cli; - cp->protocol = Protocol(Protocol::UDPv4); - cp->layer = Layer(Layer::OSI_LAYER_3); +{ + const std::string tls_crypt_v2_client_key = read_text(TEST_KEYCERT_DIR "tls-crypt-v2-client.key"); + + // client ProtoContext config + typedef ProtoContext ClientProtoContext; + ClientProtoContext::ProtoConfig::Ptr cp(new ClientProtoContext::ProtoConfig); + cp->ssl_factory = cc->new_factory(); + CryptoAlgs::allow_default_dc_algs(cp->ssl_factory->libctx(), false, false); + cp->dc.set_factory(new CryptoDCSelect(cp->ssl_factory->libctx(), frame, cli_stats, rng)); + cp->tlsprf_factory.reset(new CryptoTLSPRFFactory()); + cp->frame = std::move(frame); + cp->now = &time; + cp->rng = rng; + cp->prng = rng; + cp->protocol = Protocol(Protocol::UDPv4); + cp->layer = Layer(Layer::OSI_LAYER_3); #ifdef PROTOv2 - cp->enable_op32 = true; - cp->remote_peer_id = 100; + cp->enable_op32 = true; + cp->remote_peer_id = 100; #endif - cp->comp_ctx = CompressContext(COMP_METH, false); - cp->dc.set_cipher(CryptoAlgs::lookup(PROTO_CIPHER)); - cp->dc.set_digest(CryptoAlgs::lookup(PROTO_DIGEST)); - if (use_tls_ekm) - cp->dc.set_key_derivation(CryptoAlgs::KeyDerivation::TLS_EKM); + cp->comp_ctx = CompressContext(COMP_METH, false); + cp->dc.set_cipher(CryptoAlgs::lookup(PROTO_CIPHER)); + cp->dc.set_digest(CryptoAlgs::lookup(PROTO_DIGEST)); + #ifdef USE_TLS_AUTH - cp->tls_auth_factory.reset(new CryptoOvpnHMACFactory()); - cp->tls_key.parse(tls_auth_key); - cp->set_tls_auth_digest(CryptoAlgs::lookup(PROTO_DIGEST)); - cp->key_direction = 0; + cp->tls_auth_factory.reset(new CryptoOvpnHMACFactory()); + cp->tls_key.parse(tls_auth_key); + cp->set_tls_auth_digest(CryptoAlgs::lookup(PROTO_DIGEST)); + cp->key_direction = 0; #endif #ifdef USE_TLS_CRYPT - cp->tls_crypt_factory.reset(new CryptoTLSCryptFactory()); - cp->tls_key.parse(tls_auth_key); - cp->set_tls_crypt_algs(); - cp->tls_crypt_ = ClientProtoContext::Config::TLSCrypt::V1; + cp->tls_crypt_factory.reset(new CryptoTLSCryptFactory()); + cp->tls_key.parse(tls_auth_key); + cp->set_tls_crypt_algs(); + cp->tls_crypt_ = ProtoContext::Config::TLSCrypt::V1; #endif #ifdef USE_TLS_CRYPT_V2 - cp->tls_crypt_factory.reset(new CryptoTLSCryptFactory()); - cp->set_tls_crypt_algs(); - { - TLSCryptV2ClientKey tls_crypt_v2_key(cp->tls_crypt_context); - tls_crypt_v2_key.parse(tls_crypt_v2_client_key); - tls_crypt_v2_key.extract_key(cp->tls_key); - tls_crypt_v2_key.extract_wkc(cp->wkc); - } - cp->tls_crypt_ = ClientProtoContext::ProtoConfig::TLSCrypt::V2; + cp->tls_crypt_factory.reset(new CryptoTLSCryptFactory()); + cp->set_tls_crypt_algs(); + { + TLSCryptV2ClientKey tls_crypt_v2_key(cp->tls_crypt_context); + tls_crypt_v2_key.parse(tls_crypt_v2_client_key); + tls_crypt_v2_key.extract_key(cp->tls_key); + tls_crypt_v2_key.extract_wkc(cp->wkc); + } + cp->tls_crypt_ = ProtoContext::ProtoConfig::TLSCrypt::V2; #endif #if defined(HANDSHAKE_WINDOW) - cp->handshake_window = Time::Duration::seconds(HANDSHAKE_WINDOW); + cp->handshake_window = Time::Duration::seconds(HANDSHAKE_WINDOW); #elif SITER > 1 - cp->handshake_window = Time::Duration::seconds(30); + cp->handshake_window = Time::Duration::seconds(30); #else - cp->handshake_window = Time::Duration::seconds(18); // will cause a small number of handshake failures + cp->handshake_window = Time::Duration::seconds(18); // will cause a small number of handshake failures #endif #ifdef BECOME_PRIMARY_CLIENT - cp->become_primary = Time::Duration::seconds(BECOME_PRIMARY_CLIENT); + cp->become_primary = Time::Duration::seconds(BECOME_PRIMARY_CLIENT); #else - cp->become_primary = cp->handshake_window; + cp->become_primary = cp->handshake_window; #endif - cp->tls_timeout = Time::Duration::milliseconds(TLS_TIMEOUT_CLIENT); + cp->tls_timeout = Time::Duration::milliseconds(TLS_TIMEOUT_CLIENT); #if defined(CLIENT_NO_RENEG) - cp->renegotiate = Time::Duration::infinite(); + cp->renegotiate = Time::Duration::infinite(); #else - cp->renegotiate = Time::Duration::seconds(RENEG); + cp->renegotiate = Time::Duration::seconds(RENEG); #endif - cp->expire = cp->renegotiate + cp->renegotiate; - cp->keepalive_ping = Time::Duration::seconds(5); - cp->keepalive_timeout = Time::Duration::seconds(60); - cp->keepalive_timeout_early = cp->keepalive_timeout; + cp->expire = cp->renegotiate + cp->renegotiate; + cp->keepalive_ping = Time::Duration::seconds(5); + cp->keepalive_timeout = Time::Duration::seconds(60); + cp->keepalive_timeout_early = cp->keepalive_timeout; #ifdef VERBOSE - std::cout << "CLIENT OPTIONS: " << cp->options_string() << std::endl; - std::cout << "CLIENT PEER INFO:" << std::endl; - std::cout << cp->peer_info_string(); + std::cout << "CLIENT OPTIONS: " << cp->options_string() << std::endl; + std::cout << "CLIENT PEER INFO:" << std::endl; + std::cout << cp->peer_info_string(); #endif + return cp; +} + +// execute the unit test in one thread +int test(const int thread_num, bool use_tls_ekm) +{ + try + { + // frame + Frame::Ptr frame(new Frame(Frame::Context(128, 378, 128, 0, 16, 0))); + + // RNG + ClientRandomAPI::Ptr prng_cli(new ClientRandomAPI()); + ServerRandomAPI::Ptr prng_serv(new ServerRandomAPI()); + MTRand rng_noncrypto; + + // init simulated time + Time time; + const Time::Duration time_step = Time::Duration::binary_ms(100); + + // config files + const std::string ca_crt = read_text(TEST_KEYCERT_DIR "ca.crt"); + const std::string server_crt = read_text(TEST_KEYCERT_DIR "server.crt"); + const std::string server_key = read_text(TEST_KEYCERT_DIR "server.key"); + const std::string dh_pem = read_text(TEST_KEYCERT_DIR "dh.pem"); + const std::string tls_auth_key = read_text(TEST_KEYCERT_DIR "tls-auth.key"); + const std::string tls_crypt_v2_server_key = read_text(TEST_KEYCERT_DIR "tls-crypt-v2-server.key"); + + // client config + ClientSSLAPI::Config::Ptr cc = create_client_ssl_config(frame, prng_cli); + MySessionStats::Ptr cli_stats(new MySessionStats); + + auto cp = create_client_proto_context(std::move(cc), frame, prng_cli, cli_stats, time); + if (use_tls_ekm) + cp->dc.set_key_derivation(CryptoAlgs::KeyDerivation::TLS_EKM); // server config - ClientSSLAPI::Config::Ptr sc(new ClientSSLAPI::Config()); + MySessionStats::Ptr serv_stats(new MySessionStats); + + ServerSSLAPI::Config::Ptr sc(new ClientSSLAPI::Config()); sc->set_mode(Mode(Mode::SERVER)); sc->set_frame(frame); - sc->set_rng(rng_serv); + sc->set_rng(prng_serv); sc->load_ca(ca_crt, true); sc->load_cert(server_crt); sc->load_private_key(server_key); @@ -1002,7 +1021,7 @@ int test(const int thread_num, bool use_tls_ekm) sp->tlsprf_factory.reset(new CryptoTLSPRFFactory()); sp->frame = frame; sp->now = &time; - sp->rng = rng_serv; + sp->rng = prng_serv; sp->prng = prng_serv; sp->protocol = Protocol(Protocol::UDPv4); sp->layer = Layer(Layer::OSI_LAYER_3); @@ -1036,7 +1055,7 @@ int test(const int thread_num, bool use_tls_ekm) } sp->set_tls_crypt_algs(); sp->tls_crypt_metadata_factory.reset(new CryptoTLSCryptMetadataFactory()); - sp->tls_crypt_ = ClientProtoContext::ProtoConfig::TLSCrypt::V2; + sp->tls_crypt_ = ProtoContext::ProtoConfig::TLSCrypt::V2; #endif #if defined(HANDSHAKE_WINDOW) sp->handshake_window = Time::Duration::seconds(HANDSHAKE_WINDOW); @@ -1371,3 +1390,60 @@ TEST(proto, controlmessage_invalidchar) EXPECT_EQ(msg6, ""); EXPECT_TRUE(Unicode::is_valid_utf8(msg5, Unicode::UTF8_NO_CTRL)); } + +class MockCallback : public openvpn::ClientProto::NotifyCallback +{ + void client_proto_terminate() + { + } +}; + +class EventQueueVector : public openvpn::ClientEvent::Queue +{ + public: + void add_event(openvpn::ClientEvent::Base::Ptr event) override + { + events.push_back(event); + } + + std::vector events; +}; + +TEST(proto, client_proto_check_cc_msg) +{ + asio::io_context io_context; + ClientRandomAPI::Ptr rng_cli(new ClientRandomAPI()); + Frame::Ptr frame(new Frame(Frame::Context(128, 378, 128, 0, 16, 0))); + MySessionStats::Ptr cli_stats(new MySessionStats); + Time time; + + openvpn::ClientEvent::Queue::Ptr eqv_ptr = new EventQueueVector{}; + /* keep a reference to the right class to avoid repeated casted */ + EventQueueVector *eqv = dynamic_cast(eqv_ptr.get()); + /* check that the cast worked */ + ASSERT_TRUE(eqv); + + MockCallback mockCB; + openvpn::ClientProto::Session::Config clisessconf{}; + clisessconf.proto_context_config = create_client_proto_context(create_client_ssl_config(frame, rng_cli), + frame, + rng_cli, + std::move(cli_stats), + time); + clisessconf.cli_events = std::move(eqv_ptr); + openvpn::ClientProto::Session::Ptr clisession = new ClientProto::Session{io_context, clisessconf, &mockCB}; + + clisession->validate_and_post_cc_msg("valid message"); + + + EXPECT_TRUE(eqv->events.empty()); + + clisession->validate_and_post_cc_msg("invalid\nmessage"); + EXPECT_EQ(eqv->events.size(), 1); + auto ev = eqv->events.back(); + auto uf = dynamic_cast(ev.get()); + /* check that the cast worked */ + ASSERT_TRUE(uf); + EXPECT_EQ(uf->name, "Invalid chars in control message"); + EXPECT_EQ(uf->reason, "Control channel message with invalid characters not allowed to be send with post_cc_msg"); +}