From 1c09cd84fb2d5e0293fad6febcd58ecbaace1755 Mon Sep 17 00:00:00 2001 From: Jackarain Date: Fri, 24 Nov 2023 15:21:25 +0800 Subject: [PATCH] Improve proxy_server, move proto detect to session class. --- proxy/include/proxy/proxy_server.hpp | 598 ++++++++++++++------------- 1 file changed, 316 insertions(+), 282 deletions(-) diff --git a/proxy/include/proxy/proxy_server.hpp b/proxy/include/proxy/proxy_server.hpp index 27f7475cd4..96608bc51e 100644 --- a/proxy/include/proxy/proxy_server.hpp +++ b/proxy/include/proxy/proxy_server.hpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -424,13 +425,17 @@ R"x*x*x( } public: - proxy_session(proxy_stream_type&& socket, - size_t id, std::weak_ptr server) - : m_local_socket(std::move(socket)) - , m_remote_socket(instantiate_proxy_stream( - m_local_socket.get_executor())) - , m_udp_socket(m_local_socket.get_executor()) - , m_timer(m_local_socket.get_executor()) + proxy_session( + net::any_io_executor executor, + proxy_stream_type&& socket, + size_t id, + std::weak_ptr server + ) + : m_executor(executor) + , m_local_socket(std::move(socket)) + , m_remote_socket(instantiate_proxy_stream(executor)) + , m_udp_socket(executor) + , m_timer(executor) , m_connection_id(id) , m_proxy_server(server) { @@ -442,7 +447,10 @@ R"x*x*x( if (!server) return; + // 从 server 中移除当前 session. server->remove_session(m_connection_id); + + // 打印当前 session 数量. auto num = server->num_session(); XLOG_DBG << "connection id: " @@ -458,14 +466,20 @@ R"x*x*x( if (!server) return; + // 保存 server 的参数选项. m_option = server->option(); + // 如果指定了 proxy_pass_ 参数, 则解析它, 这说明它是一个 + // 多层代理, 本服务器将会连接到下一个代理服务器. + // 所有数据将会通过本服务器转发到由 proxy_pass_ 指定的下一 + // 个代理服务器. if (!m_option.proxy_pass_.empty()) { try { m_next_proxy = - std::make_unique(m_option.proxy_pass_); + std::make_unique( + m_option.proxy_pass_); } catch (const std::exception& e) { @@ -475,29 +489,299 @@ R"x*x*x( << m_option.proxy_pass_ << ", exception: " << e.what(); + return; } } + // 保持 self 对象指针, 以防止在协程完成后 this 被销毁. auto self = shared_from_this(); - net::co_spawn(m_local_socket.get_executor(), + // 启动协议侦测协程. + net::co_spawn(m_executor, [this, self, server]() -> net::awaitable { - co_await start_proxy(); + co_await proto_detect(); + co_return; }, net::detached); } virtual void close() override { + if (m_abort) + return; + m_abort = true; boost::system::error_code ignore_ec; + + // 关闭所有 socket. m_local_socket.close(ignore_ec); m_remote_socket.close(ignore_ec); + + m_udp_socket.close(ignore_ec); + + // 取消所有定时器. + m_timer.cancel(ignore_ec); } private: + inline net::awaitable + noise_handshake(tcp_socket& socket) + { + boost::system::error_code error; + + std::vector noise = + generate_noise(nosie_injection_max_len, global_known_proto); + + XLOG_DBG << "connection id: " + << m_connection_id + << ", send noise, length: " + << noise.size(); + + // 发送 noise 消息. + co_await net::async_write( + socket, + net::buffer(noise), + net_awaitable[error]); + if (error) + { + XLOG_WARN << "connection id: " + << m_connection_id + << ", noise write error: " + << error.message(); + co_return false; + } + + // 接收客户端发过来的 noise 回应消息. + size_t len = 0; + int noise_length = -1; + int recv_length = 2; + uint8_t bufs[2]; + uint16_t fvalue = 0; + uint16_t cvalue = 0; + + while (true) + { + if (m_abort) + co_return false; + + fvalue = cvalue; + + co_await net::async_read( + socket, + net::buffer(bufs, recv_length), + net_awaitable[error]); + + if (error) + { + XLOG_WARN << "connection id: " + << m_connection_id + << ", noise read error: " + << error.message(); + + co_return false; + } + + cvalue = + static_cast(bufs[1]) | + (static_cast(bufs[0]) << 8); + + len += recv_length; + if (len == 1) + continue; + + if (len >= nosie_injection_max_len) + { + XLOG_WARN << "connection id: " + << m_connection_id + << ", noise max length reached"; + + co_return false; + } + + if (noise_length != -1) + { + recv_length = noise_length - len; + recv_length = std::min(recv_length, 2); + + if (recv_length != 0) + continue; + + XLOG_DBG << "connection id: " + << m_connection_id + << ", noise length: " + << noise_length + << ", receive completed"; + + break; + } + + noise_length = fvalue & cvalue; + if (noise_length >= nosie_injection_max_len || + noise_length < 4) + { + noise_length = -1; + continue; + } + + XLOG_DBG << "connection id: " + << m_connection_id + << ", noise length: " + << noise_length + << ", receive"; + } + + // 在完成 noise 握手后, 重新检测协议. + co_await proto_detect(false); + + co_return true; + } + + // 协议侦测协程. + inline net::awaitable proto_detect(bool noise = true) + { + auto self = shared_from_this(); + auto error = boost::system::error_code{}; + + // 从 m_local_socket 中获取 tcp::socket 对象的引用. + auto& socket = boost::variant2::get(m_local_socket); + + // 等待 read 事件以确保下面 recv 偷看数据时能有数据. + co_await socket.async_wait( + tcp_socket::wait_read, net_awaitable[error]); + if (error) + { + XLOG_WARN << "connection id: " + << m_connection_id + << ", socket.async_wait error: " + << error.message(); + co_return; + } + + // 检查协议. + auto fd = socket.native_handle(); + uint8_t detect[5] = { 0 }; + +#if defined(WIN32) || defined(__APPLE__) + auto ret = recv(fd, (char*)detect, sizeof(detect), + MSG_PEEK); +#else + auto ret = recv(fd, (void*)detect, sizeof(detect), + MSG_PEEK | MSG_NOSIGNAL | MSG_DONTWAIT); +#endif + if (ret <= 0) + { + XLOG_WARN << "connection id: " + << m_connection_id + << ", peek message return: " + << ret; + co_return; + } + + // 非安全连接检查. + if (m_option.disable_insecure_) + { + if (detect[0] != 0x16) + { + XLOG_DBG << "connection id: " + << m_connection_id + << ", insecure protocol disabled"; + co_return; + } + } + + // plain socks4/5 protocol. + if (detect[0] == 0x05 || detect[0] == 0x04) + { + if (m_option.disable_socks_) + { + XLOG_DBG << "connection id: " + << m_connection_id + << ", socks protocol disabled"; + co_return; + } + + XLOG_DBG << "connection id: " + << m_connection_id + << ", plain socks4/5 protocol"; + + // 开始启动代理协议. + co_await start_proxy(); + } + else if (detect[0] == 0x16) // http/socks proxy with ssl crypto protocol. + { + XLOG_DBG << "connection id: " + << m_connection_id + << ", ssl protocol"; + + // instantiate socks stream with ssl context. + auto ssl_socks_stream = instantiate_proxy_stream( + std::move(socket), m_ssl_context); + + // get origin ssl stream type. + ssl_stream& ssl_socket = + boost::variant2::get(ssl_socks_stream); + + // do async ssl handshake. + co_await ssl_socket.async_handshake( + net::ssl::stream_base::server, + net_awaitable[error]); + + if (error) + { + XLOG_DBG << "connection id: " + << m_connection_id + << ", ssl protocol handshake error: " + << error.message(); + co_return; + } + + // 使用 ssl_socks_stream 替换 m_local_socket. + m_local_socket = std::move(ssl_socks_stream); + + // 开始启动代理协议. + co_await start_proxy(); + } // plain http protocol. + else if (detect[0] == 0x47 || // 'G' + detect[0] == 0x50 || // 'P' + detect[0] == 0x43) // 'C' + { + if (m_option.disable_http_) + { + XLOG_DBG << "connection id: " + << m_connection_id + << ", http protocol disabled"; + co_return; + } + + XLOG_DBG << "connection id: " + << m_connection_id + << ", plain http protocol"; + + // 开始启动代理协议. + co_await start_proxy(); + } + else if (noise && m_option.scramble_) + { + // 进入噪声过滤协议, 同时返回一段噪声给客户端. + XLOG_DBG << "connection id: " + << m_connection_id + << ", noise protocol"; + + if (!co_await noise_handshake(socket)) + co_return; + } + else + { + XLOG_DBG << "connection id: " + << m_connection_id + << ", unknown protocol"; + } + + co_return; + } + inline net::awaitable start_proxy() { // read @@ -3268,6 +3552,7 @@ R"x*x*x( } private: + net::any_io_executor m_executor; proxy_stream_type m_local_socket; proxy_stream_type m_remote_socket; udp::socket m_udp_socket; @@ -3293,7 +3578,7 @@ R"x*x*x( proxy_server(const proxy_server&) = delete; proxy_server& operator=(const proxy_server&) = delete; - proxy_server(net::io_context::executor_type executor, + proxy_server(net::any_io_executor executor, const tcp::endpoint& endp, proxy_server_option opt) : m_executor(executor) , m_acceptor(executor) @@ -3351,10 +3636,10 @@ R"x*x*x( } public: - inline static std::shared_ptr make( - net::io_context::executor_type executor, + inline static std::shared_ptr + make(net::any_io_executor executor, const tcp::endpoint& endp, - proxy_server_option opt) + proxy_server_option opt) { return std::shared_ptr(new proxy_server(executor, std::cref(endp), opt)); @@ -3481,270 +3766,9 @@ R"x*x*x( } private: - inline net::awaitable - noise_process(tcp_socket socket, size_t connection_id) - { - boost::system::error_code error; - - std::vector noise = - generate_noise(nosie_injection_max_len, global_known_proto); - - XLOG_DBG << "connection id: " - << connection_id - << ", send noise, length: " - << noise.size(); - - // 发送 noise 消息. - co_await net::async_write( - socket, - net::buffer(noise), - net_awaitable[error]); - if (error) - { - XLOG_WARN << "connection id: " - << connection_id - << ", noise write error: " - << error.message(); - co_return false; - } - - // 接收客户端发过来的 noise 回应消息. - size_t len = 0; - int noise_length = -1; - int recv_length = 2; - uint8_t bufs[2]; - uint16_t fvalue = 0; - uint16_t cvalue = 0; - - while (true) - { - if (m_abort) - co_return false; - - fvalue = cvalue; - - co_await net::async_read( - socket, - net::buffer(bufs, recv_length), - net_awaitable[error]); - - if (error) { - XLOG_WARN << "connection id: " - << connection_id - << ", noise read error: " - << error.message(); - - co_return false; - } - - cvalue = - static_cast(bufs[1]) | - (static_cast(bufs[0]) << 8); - - len += recv_length; - if (len == 1) - continue; - - if (len >= nosie_injection_max_len) - { - XLOG_WARN << "connection id: " - << connection_id - << ", noise max length reached"; - - co_return false; - } - - if (noise_length != -1) - { - recv_length = noise_length - len; - recv_length = std::min(recv_length, 2); - - if (recv_length != 0) - continue; - - XLOG_DBG << "connection id: " - << connection_id - << ", noise length: " - << noise_length - << ", receive completed"; - - break; - } - - noise_length = fvalue & cvalue; - if (noise_length >= nosie_injection_max_len || - noise_length < 4) - { - noise_length = -1; - continue; - } - - XLOG_DBG << "connection id: " - << connection_id - << ", noise length: " - << noise_length - << ", receive"; - } - - // 在完成 noise 握手后, 重新检测协议. - co_await socket_detect(std::move(socket), connection_id, false); - - co_return true; - } - - inline net::awaitable - socket_detect(tcp_socket socket, size_t connection_id, bool noise = true) - { - auto self = shared_from_this(); - auto error = boost::system::error_code{}; - - // 等待 read 事件以确保下面 recv 偷看数据时能有数据. - co_await socket.async_wait( - tcp_socket::wait_read, net_awaitable[error]); - if (error) - { - XLOG_WARN << "connection id: " - << connection_id - << ", socket.async_wait error: " - << error.message(); - co_return; - } - - // 检查协议. - auto fd = socket.native_handle(); - uint8_t detect[5] = { 0 }; - -#if defined(WIN32) || defined(__APPLE__) - auto ret = recv(fd, (char*)detect, sizeof(detect), - MSG_PEEK); -#else - auto ret = recv(fd, (void*)detect, sizeof(detect), - MSG_PEEK | MSG_NOSIGNAL | MSG_DONTWAIT); -#endif - if (ret <= 0) - { - XLOG_WARN << "connection id: " - << connection_id - << ", peek message return: " - << ret; - co_return; - } - - // 非安全连接检查. - if (m_option.disable_insecure_) - { - if (detect[0] != 0x16) - { - XLOG_DBG << "connection id: " - << connection_id - << ", insecure protocol disabled"; - co_return; - } - } - - // plain socks4/5 protocol. - if (detect[0] == 0x05 || detect[0] == 0x04) - { - if (m_option.disable_socks_) - { - XLOG_DBG << "connection id: " - << connection_id - << ", socks protocol disabled"; - co_return; - } - - XLOG_DBG << "connection id: " - << connection_id - << ", socks4/5 protocol"; - - auto new_session = - std::make_shared( - instantiate_proxy_stream(std::move(socket)), - connection_id, self); - - m_clients[connection_id] = new_session; - - new_session->start(); - } - else if (detect[0] == 0x16) // http/socks proxy with ssl crypto protocol. - { - XLOG_DBG << "connection id: " - << connection_id - << ", socks/https protocol"; - - // instantiate socks stream with ssl context. - auto ssl_socks_stream = instantiate_proxy_stream( - std::move(socket), m_ssl_context); - - // get origin ssl stream type. - ssl_stream& ssl_socket = - boost::variant2::get(ssl_socks_stream); - - // do async ssl handshake. - co_await ssl_socket.async_handshake( - net::ssl::stream_base::server, - net_awaitable[error]); - if (error) - { - XLOG_DBG << "connection id: " - << connection_id - << ", ssl protocol handshake error: " - << error.message(); - co_return; - } - - // make socks session shared ptr. - auto new_session = - std::make_shared( - std::move(ssl_socks_stream), connection_id, self); - m_clients[connection_id] = new_session; - - new_session->start(); - } // plain http protocol. - else if (detect[0] == 0x47 || // 'G' - detect[0] == 0x50 || // 'P' - detect[0] == 0x43) // 'C' - { - if (m_option.disable_http_) - { - XLOG_DBG << "connection id: " - << connection_id - << ", http protocol disabled"; - co_return; - } - - XLOG_DBG << "connection id: " - << connection_id - << ", http protocol"; - - auto new_session = - std::make_shared( - instantiate_proxy_stream(std::move(socket)), - connection_id, self); - m_clients[connection_id] = new_session; - - new_session->start(); - } - else if (noise && m_option.scramble_) - { - // 进入噪声过滤协议, 同时返回一段噪声给客户端. - XLOG_DBG << "connection id: " - << connection_id - << ", noise protocol"; - - if (!co_await noise_process(std::move(socket), connection_id)) - co_return; - } - else - { - XLOG_DBG << "connection id: " - << connection_id - << ", unknown protocol"; - } - - co_return; - } - + // start_proxy_listen 启动一个协程, 用于监听 proxy client 的连接. + // 当有新的连接到来时, 会创建一个 proxy_session 对象, 并启动 proxy_session + // 的对象. inline net::awaitable start_proxy_listen(tcp_acceptor& a) { boost::system::error_code error; @@ -3788,9 +3812,19 @@ R"x*x*x( << ", start client incoming: " << client; - net::co_spawn(m_executor, - socket_detect(std::move(socket), connection_id), - net::detached); + // 创建 proxy_session 对象. + auto new_session = + std::make_shared( + m_executor, + instantiate_proxy_stream(std::move(socket)), + connection_id, + self); + + // 保存 proxy_session 对象到 m_clients 中. + m_clients[connection_id] = new_session; + + // 启动 proxy_session 对象. + new_session->start(); } XLOG_WARN << "start_proxy_listen exit ..."; @@ -3798,7 +3832,7 @@ R"x*x*x( } private: - net::io_context::executor_type m_executor; + net::any_io_executor m_executor; tcp_acceptor m_acceptor; proxy_server_option m_option; using proxy_session_weak_ptr =