From 55e99f22f8a28cf64390963ef044e456f17fafb0 Mon Sep 17 00:00:00 2001 From: Twice Date: Sun, 3 Nov 2024 19:38:09 +0800 Subject: [PATCH] fix(script): avoid SetCurrentConnection on read-only scripting (#2640) --- src/server/redis_connection.cc | 12 ----------- src/server/server.cc | 4 ++-- src/server/server.h | 5 ----- src/server/worker.cc | 2 +- src/storage/scripting.cc | 38 +++++++++++++++++++--------------- src/storage/scripting.h | 6 +++--- 6 files changed, 27 insertions(+), 40 deletions(-) diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc index 584b506839a..70abfe70b54 100644 --- a/src/server/redis_connection.cc +++ b/src/server/redis_connection.cc @@ -417,22 +417,10 @@ void Connection::ExecuteCommands(std::deque *to_process_cmds) { // No lock guard, because 'exec' command has acquired 'WorkExclusivityGuard' } else if (cmd_flags & kCmdExclusive) { exclusivity = srv_->WorkExclusivityGuard(); - - // When executing lua script commands that have "exclusive" attribute, we need to know current connection, - // but we should set current connection after acquiring the WorkExclusivityGuard to make it thread-safe - srv_->SetCurrentConnection(this); } else { concurrency = srv_->WorkConcurrencyGuard(); } - auto category = attributes->category; - if ((category == CommandCategory::Function || category == CommandCategory::Script) && (cmd_flags & kCmdReadOnly)) { - // FIXME: since read-only script commands are not exclusive, - // SetCurrentConnection here is weird and can cause many issues, - // we should pass the Connection directly to the lua context instead - srv_->SetCurrentConnection(this); - } - if (srv_->IsLoading() && !(cmd_flags & kCmdLoading)) { Reply(redis::Error({Status::RedisLoading, errRestoringBackup})); if (is_multi_exec) multi_error_ = true; diff --git a/src/server/server.cc b/src/server/server.cc index e569d12e5f9..5b52eb333fb 100644 --- a/src/server/server.cc +++ b/src/server/server.cc @@ -102,7 +102,7 @@ Server::Server(engine::Storage *storage, Config *config) AdjustOpenFilesLimit(); slow_log_.SetMaxEntries(config->slowlog_max_len); perf_log_.SetMaxEntries(config->profiling_sample_record_max_len); - lua_ = lua::CreateState(this); + lua_ = lua::CreateState(); } Server::~Server() { @@ -1764,7 +1764,7 @@ Status Server::FunctionSetLib(const std::string &func, const std::string &lib) c } void Server::ScriptReset() { - auto lua = lua_.exchange(lua::CreateState(this)); + auto lua = lua_.exchange(lua::CreateState()); lua::DestroyState(lua); } diff --git a/src/server/server.h b/src/server/server.h index 7d8c8327f05..3c2ab16e997 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -285,9 +285,6 @@ class Server { Status ExecPropagatedCommand(const std::vector &tokens); Status ExecPropagateScriptCommand(const std::vector &tokens); - void SetCurrentConnection(redis::Connection *conn) { curr_connection_ = conn; } - redis::Connection *GetCurrentConnection() { return curr_connection_; } - LogCollector *GetPerfLog() { return &perf_log_; } LogCollector *GetSlowLog() { return &slow_log_; } void SlowlogPushEntryIfNeeded(const std::vector *args, uint64_t duration, const redis::Connection *conn); @@ -343,8 +340,6 @@ class Server { std::atomic lua_; - redis::Connection *curr_connection_ = nullptr; - // client counters std::atomic client_id_{1}; std::atomic connected_clients_{0}; diff --git a/src/server/worker.cc b/src/server/worker.cc index 4ddf31add0b..6420d76c361 100644 --- a/src/server/worker.cc +++ b/src/server/worker.cc @@ -76,7 +76,7 @@ Worker::Worker(Server *srv, Config *config) : srv(srv), base_(event_base_new()) } } } - lua_ = lua::CreateState(srv); + lua_ = lua::CreateState(); } Worker::~Worker() { diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index f38d942237f..5768aee8169 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -57,15 +57,12 @@ enum { namespace lua { -lua_State *CreateState(Server *srv) { +lua_State *CreateState() { lua_State *lua = lua_open(); LoadLibraries(lua); RemoveUnsupportedFunctions(lua); LoadFuncs(lua); - lua_pushlightuserdata(lua, srv); - lua_setglobal(lua, REDIS_LUA_SERVER_PTR); - EnableGlobalsProtection(lua); return lua; } @@ -273,7 +270,10 @@ int RedisRegisterFunction(lua_State *lua) { } // store the map from function name to library name - auto s = GetServer(lua)->FunctionSetLib(name, libname); + auto *script_run_ctx = GetFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME); + CHECK_NOTNULL(script_run_ctx); + + auto s = script_run_ctx->conn->GetServer()->FunctionSetLib(name, libname); if (!s) { lua_pushstring(lua, "redis.register_function() failed to store informantion."); return lua_error(lua); @@ -305,6 +305,12 @@ Status FunctionLoad(redis::Connection *conn, const std::string &script, bool nee if (!s) return s; } + ScriptRunCtx script_run_ctx; + script_run_ctx.conn = conn; + script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0; + + SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx); + lua_pushstring(lua, libname.c_str()); lua_setglobal(lua, REDIS_FUNCTION_LIBNAME); auto libname_exit = MakeScopeExit([lua] { @@ -331,6 +337,8 @@ Status FunctionLoad(redis::Connection *conn, const std::string &script, bool nee return {Status::NotOK, "Error while running new function lib: " + err_msg}; } + RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME); + if (!FunctionIsLibExist(conn, libname, false, read_only)) { return {Status::NotOK, "Please register some function in FUNCTION LOAD"}; } @@ -396,6 +404,7 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std: } ScriptRunCtx script_run_ctx; + script_run_ctx.conn = conn; script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0; lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str()); if (!lua_isnil(lua, -1)) { @@ -642,6 +651,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh } ScriptRunCtx current_script_run_ctx; + current_script_run_ctx.conn = conn; current_script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0; lua_getglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, funcname + 2).c_str()); if (!lua_isnil(lua, -1)) { @@ -709,14 +719,6 @@ int RedisCallCommand(lua_State *lua) { return RedisGenericCommand(lua, 1); } int RedisPCallCommand(lua_State *lua) { return RedisGenericCommand(lua, 0); } -Server *GetServer(lua_State *lua) { - lua_getglobal(lua, REDIS_LUA_SERVER_PTR); - auto srv = reinterpret_cast(lua_touserdata(lua, -1)); - lua_pop(lua, 1); - - return srv; -} - // TODO: we do not want to repeat same logic as Connection::ExecuteCommands, // so the function need to be refactored int RedisGenericCommand(lua_State *lua, int raise_error) { @@ -772,10 +774,10 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { std::string cmd_name = attributes->name; - auto srv = GetServer(lua); + auto *conn = script_run_ctx->conn; + auto *srv = conn->GetServer(); Config *config = srv->GetConfig(); - redis::Connection *conn = srv->GetCurrentConnection(); if (config->cluster_enabled) { if (script_run_ctx->flags & ScriptFlagType::kScriptNoCluster) { PushError(lua, "Can not run script on cluster, 'no-cluster' flag is set"); @@ -901,8 +903,10 @@ int RedisReturnSingleFieldTable(lua_State *lua, const char *field) { } int RedisSetResp(lua_State *lua) { - auto srv = GetServer(lua); - auto conn = srv->GetCurrentConnection(); + auto *script_run_ctx = GetFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME); + CHECK_NOTNULL(script_run_ctx); + auto *conn = script_run_ctx->conn; + auto *srv = conn->GetServer(); if (lua_gettop(lua) != 1) { PushError(lua, "redis.setresp() requires one argument."); diff --git a/src/storage/scripting.h b/src/storage/scripting.h index 9aa4044ba3c..188f855c9ba 100644 --- a/src/storage/scripting.h +++ b/src/storage/scripting.h @@ -35,7 +35,6 @@ inline constexpr const char REDIS_LUA_FUNC_SHA_PREFIX[] = "f_"; inline constexpr const char REDIS_LUA_FUNC_SHA_FLAGS[] = "f_{}_flags_"; inline constexpr const char REDIS_LUA_REGISTER_FUNC_PREFIX[] = "__redis_registered_"; inline constexpr const char REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX[] = "__redis_registered_flags_"; -inline constexpr const char REDIS_LUA_SERVER_PTR[] = "__server_ptr"; inline constexpr const char REDIS_FUNCTION_LIBNAME[] = "REDIS_FUNCTION_LIBNAME"; inline constexpr const char REDIS_FUNCTION_NEEDSTORE[] = "REDIS_FUNCTION_NEEDSTORE"; inline constexpr const char REDIS_FUNCTION_LIBRARIES[] = "REDIS_FUNCTION_LIBRARIES"; @@ -43,9 +42,8 @@ inline constexpr const char REGISTRY_SCRIPT_RUN_CTX_NAME[] = "SCRIPT_RUN_CTX"; namespace lua { -lua_State *CreateState(Server *srv); +lua_State *CreateState(); void DestroyState(lua_State *lua); -Server *GetServer(lua_State *lua); void LoadFuncs(lua_State *lua); void LoadLibraries(lua_State *lua); @@ -150,6 +148,8 @@ struct ScriptRunCtx { // and is used to detect whether there is cross-slot access // between multiple commands in a script or function. int current_slot = -1; + // the current connection + redis::Connection *conn = nullptr; }; /// SaveOnRegistry saves user-defined data to lua REGISTRY