From e84fbb32f432d5a190148716665c7ffa3c630f1f Mon Sep 17 00:00:00 2001 From: mihacooper Date: Thu, 19 Oct 2017 22:49:51 +0300 Subject: [PATCH] Function upvalues implementation (#86) --- src/cpp/function.cpp | 104 +++++++++++++++++++++++++++++ src/cpp/function.h | 35 ++++++++++ src/cpp/garbage-collector.cpp | 2 +- src/cpp/lua-helpers.cpp | 5 +- src/cpp/lua-helpers.h | 5 +- src/cpp/lua-module.cpp | 5 +- src/cpp/stored-object.cpp | 52 +++++++-------- src/cpp/threading.cpp | 56 +++++++++------- src/cpp/threading.h | 9 ++- src/lua/effil.lua | 3 +- tests/lua/bootstrap-tests.lua | 7 ++ tests/lua/gc-stress.lua | 2 - tests/lua/metatable.lua | 2 +- tests/lua/run_tests | 1 + tests/lua/thread.lua | 4 +- tests/lua/type.lua | 11 +++- tests/lua/type_mismatch.lua | 14 ++-- tests/lua/upvalues.lua | 119 ++++++++++++++++++++++++++++++++++ 18 files changed, 356 insertions(+), 80 deletions(-) create mode 100644 src/cpp/function.cpp create mode 100644 src/cpp/function.h create mode 100644 tests/lua/upvalues.lua diff --git a/src/cpp/function.cpp b/src/cpp/function.cpp new file mode 100644 index 0000000..a583f3c --- /dev/null +++ b/src/cpp/function.cpp @@ -0,0 +1,104 @@ +#include "function.h" + +namespace effil { + +namespace { + +bool allowTableUpvalues(const sol::optional& newValue = sol::nullopt) { + static std::atomic_bool value(true); + + if (newValue) + return value.exchange(newValue.value()); + return value; +} + +} // anonymous + +sol::object luaAllowTableUpvalues(sol::this_state state, const sol::stack_object& value) { + if (value.valid()) { + REQUIRE(value.get_type() == sol::type::boolean) << "bad argument #1 to 'effil.allow_table_upvalues' (boolean expected, got " << luaTypename(value) << ")"; + return sol::make_object(state, allowTableUpvalues(value.template as())); + } + else { + return sol::make_object(state, allowTableUpvalues()); + } +} + +void FunctionObject::initialize(const sol::function& luaObject) { + assert(luaObject.valid()); + assert(luaObject.get_type() == sol::type::function); + + lua_State* state = luaObject.lua_state(); + sol::stack::push(state, luaObject); + + lua_Debug dbgInfo; + lua_getinfo(state, ">u", &dbgInfo); // function is popped from stack here + sol::stack::push(state, luaObject); + + data_->function = dumpFunction(luaObject); + data_->upvalues.resize(dbgInfo.nups); +#if LUA_VERSION_NUM > 501 + data_->envUpvaluePos = 0; // means no _ENV upvalue +#endif // LUA_VERSION_NUM > 501 + + for (unsigned char i = 1; i <= dbgInfo.nups; ++i) { + const char* valueName = lua_getupvalue(state, -1, i); // push value on stack + (void)valueName; // get rid of 'unused' warning for Lua5.1 + assert(valueName != nullptr); + +#if LUA_VERSION_NUM > 501 + if (strcmp(valueName, "_ENV") == 0) { // do not serialize _ENV + sol::stack::pop(state); + data_->envUpvaluePos = i; + continue; + } +#endif // LUA_VERSION_NUM > 501 + + const auto& upvalue = sol::stack::pop(state); // pop from stack + if (!allowTableUpvalues() && upvalue.get_type() == sol::type::table) { + sol::stack::pop(state); + throw effil::Exception() << "bad function upvalue #" << (int)i << " (table is disabled by effil.allow_table_upvalues)"; + } + + StoredObject storedObject; + try { + storedObject = createStoredObject(upvalue); + assert(storedObject.get() != nullptr); + } + catch(const std::exception& err) { + sol::stack::pop(state); + throw effil::Exception() << "bad function upvalue #" << (int)i << " (" << err.what() << ")"; + } + + if (storedObject->gcHandle() != nullptr) { + addReference(storedObject->gcHandle()); + storedObject->releaseStrongReference(); + } + data_->upvalues[i - 1] = std::move(storedObject); + } + sol::stack::pop(state); +} + +sol::object FunctionObject::loadFunction(lua_State* state) { + sol::function result = loadString(state, data_->function); + assert(result.valid()); + + sol::stack::push(state, result); + for(size_t i = 0; i < data_->upvalues.size(); ++i) { +#if LUA_VERSION_NUM > 501 + if (data_->envUpvaluePos == i + 1) { + lua_rawgeti(state, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS); // push _ENV to stack + lua_setupvalue(state, -2, i + 1); // pop _ENV and set as upvalue + continue; + } +#endif // LUA_VERSION_NUM > 501 + assert(data_->upvalues[i].get() != nullptr); + + const auto& obj = data_->upvalues[i]->unpack(sol::this_state{state}); + sol::stack::push(state, obj); + lua_setupvalue(state, -2, i + 1); + } + return sol::stack::pop(state); +} + +} // namespace effil diff --git a/src/cpp/function.h b/src/cpp/function.h new file mode 100644 index 0000000..4f20083 --- /dev/null +++ b/src/cpp/function.h @@ -0,0 +1,35 @@ +#pragma once + +#include "gc-object.h" +#include "utils.h" +#include "lua-helpers.h" + +namespace effil { + +sol::object luaAllowTableUpvalues(sol::this_state state, const sol::stack_object&); + +class FunctionObject: public GCObject { +public: + template + FunctionObject(const SolType& luaObject) + : data_(std::make_shared()) { + initialize(luaObject); + } + + sol::object loadFunction(lua_State* state); + +private: + void initialize(const sol::function& luaObject); + + struct SharedData { + std::string function; +#if LUA_VERSION_NUM > 501 + unsigned char envUpvaluePos; +#endif // LUA_VERSION_NUM > 501 + std::vector upvalues; + }; + + std::shared_ptr data_; +}; + +} // namespace effil diff --git a/src/cpp/garbage-collector.cpp b/src/cpp/garbage-collector.cpp index adb20b2..6783544 100644 --- a/src/cpp/garbage-collector.cpp +++ b/src/cpp/garbage-collector.cpp @@ -37,8 +37,8 @@ void GC::collect() { } } - DEBUG << "Removing " << (objects_.size() - black.size()) << " out of " << objects_.size() << std::endl; // Sweep phase + DEBUG << "Removing " << (objects_.size() - black.size()) << " out of " << objects_.size() << std::endl; objects_ = std::move(black); lastCleanup_.store(0); diff --git a/src/cpp/lua-helpers.cpp b/src/cpp/lua-helpers.cpp index a9c5cfd..21a37a9 100644 --- a/src/cpp/lua-helpers.cpp +++ b/src/cpp/lua-helpers.cpp @@ -42,8 +42,9 @@ std::string dumpFunction(const sol::function& f) { return result; } -sol::function loadString(const sol::state_view& lua, const std::string& str) { - int ret = luaL_loadbuffer(lua, str.c_str(), str.size(), nullptr); +sol::function loadString(const sol::state_view& lua, const std::string& str, + const sol::optional& source /* = sol::nullopt*/) { + int ret = luaL_loadbuffer(lua, str.c_str(), str.size(), source ? source.value().c_str() : nullptr); REQUIRE(ret == LUA_OK) << "Unable to load function from string: " << luaError(ret); return sol::stack::pop(lua); } diff --git a/src/cpp/lua-helpers.h b/src/cpp/lua-helpers.h index 25990bd..dd4a2fa 100644 --- a/src/cpp/lua-helpers.h +++ b/src/cpp/lua-helpers.h @@ -11,7 +11,8 @@ class Channel; class Thread; std::string dumpFunction(const sol::function& f); -sol::function loadString(const sol::state_view& lua, const std::string& str); +sol::function loadString(const sol::state_view& lua, const std::string& str, + const sol::optional& source = sol::nullopt); std::chrono::milliseconds fromLuaTime(int duration, const sol::optional& period); template @@ -21,7 +22,7 @@ std::string luaTypename(const SolObject& obj) { return "effil.table"; else if (obj.template is()) return "effil.channel"; - else if (obj.template is>()) + else if (obj.template is()) return "effil.thread"; else return "userdata"; diff --git a/src/cpp/lua-module.cpp b/src/cpp/lua-module.cpp index d11bbab..c66df4c 100644 --- a/src/cpp/lua-module.cpp +++ b/src/cpp/lua-module.cpp @@ -15,7 +15,7 @@ sol::object createThread(const sol::this_state& lua, int step, const sol::function& function, const sol::variadic_args& args) { - return sol::make_object(lua, std::make_shared(path, cpath, step, function, args)); + return sol::make_object(lua, GC::instance().create(path, cpath, step, function, args)); } sol::object createTable(sol::this_state lua, const sol::optional& tbl) { @@ -66,7 +66,8 @@ int luaopen_libeffil(lua_State* L) { "channel", createChannel, "type", getLuaTypename, "pairs", SharedTable::globalLuaPairs, - "ipairs", SharedTable::globalLuaIPairs + "ipairs", SharedTable::globalLuaIPairs, + "allow_table_upvalues", luaAllowTableUpvalues ); sol::stack::push(lua, publicApi); return 1; diff --git a/src/cpp/stored-object.cpp b/src/cpp/stored-object.cpp index 50cf564..6fa790c 100644 --- a/src/cpp/stored-object.cpp +++ b/src/cpp/stored-object.cpp @@ -1,8 +1,8 @@ #include "stored-object.h" #include "channel.h" - #include "threading.h" #include "shared-table.h" +#include "function.h" #include "utils.h" #include @@ -45,29 +45,6 @@ private: StoredType data_; }; -class FunctionHolder : public BaseHolder { -public: - template - FunctionHolder(SolObject luaObject) noexcept { - sol::state_view lua(luaObject.lua_state()); - function_ = dumpFunction(luaObject); - } - - bool rawCompare(const BaseHolder* other) const noexcept final { - return function_ < static_cast(other)->function_; - } - - sol::object unpack(sol::this_state state) const final { - sol::function result = loadString(state, function_); - // The result of restaring always is valid function. - assert(result.valid()); - return sol::make_object(state, result); - } - -private: - std::string function_; -}; - template class GCObjectHolder : public BaseHolder { public: @@ -87,7 +64,7 @@ public: return handle_ < static_cast*>(other)->handle_; } - sol::object unpack(sol::this_state state) const final { + sol::object unpack(sol::this_state state) const override { return sol::make_object(state, GC::instance().get(handle_)); } @@ -103,11 +80,22 @@ public: } } -private: +protected: GCObjectHandle handle_; sol::optional strongRef_; }; +class FunctionHolder : public GCObjectHolder { +public: + template + FunctionHolder(const SolType& luaObject) : GCObjectHolder(luaObject) {} + FunctionHolder(GCObjectHandle handle) : GCObjectHolder(handle) {} + + sol::object unpack(sol::this_state state) const final { + return GC::instance().get(handle_).loadFunction(state); + } +}; + // This class is used as a storage for visited sol::tables // TODO: try to use map or unordered map instead of linear search in vector // TODO: Trick is - sol::object has only operator==:/ @@ -171,12 +159,16 @@ StoredObject fromSolObject(const SolObject& luaObject) { return std::make_unique>(luaObject); else if (luaObject.template is()) return std::make_unique>(luaObject); - else if (luaObject.template is>()) - return std::make_unique>>(luaObject); + else if (luaObject.template is()) + return std::make_unique(luaObject); + else if (luaObject.template is()) + return std::make_unique>(luaObject); else throw Exception() << "Unable to store userdata object\n"; - case sol::type::function: - return std::make_unique(luaObject); + case sol::type::function: { + FunctionObject func = GC::instance().create(luaObject); + return std::make_unique(func.handle()); + } case sol::type::table: { sol::table luaTable = luaObject; // Tables pool is used to store tables. diff --git a/src/cpp/threading.cpp b/src/cpp/threading.cpp index eae3e75..202c092 100644 --- a/src/cpp/threading.cpp +++ b/src/cpp/threading.cpp @@ -163,43 +163,46 @@ void luaHook(lua_State*, lua_Debug*) { } } -void runThread(std::shared_ptr handle, - std::string strFunction, +} // namespace + +void Thread::runThread(Thread thread, + FunctionObject function, effil::StoredArray arguments) { - assert(handle); - thisThreadHandle = handle.get(); + thisThreadHandle = thread.handle_.get(); + assert(thisThreadHandle != nullptr); try { { - ScopeGuard reportComplete([handle, &arguments](){ - DEBUG << "Finished " << std::endl; + ScopeGuard reportComplete([thread, &arguments](){ // Let's destroy accociated state // to release all resources as soon as possible arguments.clear(); - handle->destroyLua(); + thread.handle_->destroyLua(); }); - sol::function userFuncObj = loadString(handle->lua(), strFunction); + sol::function userFuncObj = function.loadFunction(thread.handle_->lua()); sol::function_result results = userFuncObj(std::move(arguments)); (void)results; // just leave all returns on the stack - sol::variadic_args args(handle->lua(), -lua_gettop(handle->lua())); + sol::variadic_args args(thread.handle_->lua(), -lua_gettop(thread.handle_->lua())); for (const auto& iter : args) { StoredObject store = createStoredObject(iter.get()); - handle->result().emplace_back(std::move(store)); + if (store->gcHandle() != nullptr) + { + thread.addReference(store->gcHandle()); + store->releaseStrongReference(); + } + thread.handle_->result().emplace_back(std::move(store)); } } - handle->changeStatus(Status::Completed); + thread.handle_->changeStatus(Status::Completed); } catch (const LuaHookStopException&) { - handle->changeStatus(Status::Canceled); + thread.handle_->changeStatus(Status::Canceled); } catch (const sol::error& err) { DEBUG << "Failed with msg: " << err.what() << std::endl; - handle->result().emplace_back(createStoredObject(err.what())); - handle->changeStatus(Status::Failed); + thread.handle_->result().emplace_back(createStoredObject(err.what())); + thread.handle_->changeStatus(Status::Failed); } } -} // namespace - - std::string threadId() { std::stringstream ss; ss << std::this_thread::get_id(); @@ -233,6 +236,11 @@ Thread::Thread(const std::string& path, const sol::variadic_args& variadicArgs) : handle_(std::make_shared()) { + sol::optional functionObj; + try { + functionObj = FunctionObject(function); + } RETHROW_WITH_PREFIX("effil.thread"); + handle_->lua()["package"]["path"] = path; handle_->lua()["package"]["cpath"] = cpath; handle_->lua().script("require 'effil'"); @@ -240,20 +248,20 @@ Thread::Thread(const std::string& path, if (step != 0) lua_sethook(handle_->lua(), luaHook, LUA_MASKCOUNT, step); - std::string strFunction = dumpFunction(function); - effil::StoredArray arguments; try { for (const auto& arg : variadicArgs) { - arguments.emplace_back(createStoredObject(arg.get())); + const auto& storedObj = createStoredObject(arg.get()); + addReference(storedObj->gcHandle()); + storedObj->releaseStrongReference(); + arguments.emplace_back(storedObj); } } RETHROW_WITH_PREFIX("effil.thread"); - std::thread thr(&runThread, - handle_, - std::move(strFunction), + std::thread thr(&Thread::runThread, + *this, + functionObj.value(), std::move(arguments)); - DEBUG << "Created " << thr.get_id() << std::endl; thr.detach(); } diff --git a/src/cpp/threading.h b/src/cpp/threading.h index 881f0da..a55cc19 100644 --- a/src/cpp/threading.h +++ b/src/cpp/threading.h @@ -2,6 +2,7 @@ #include #include "lua-helpers.h" +#include "function.h" namespace effil { @@ -12,7 +13,7 @@ void sleep(const sol::stack_object& duration, const sol::stack_object& metric); class ThreadHandle; -class Thread { +class Thread : public GCObject { public: Thread(const std::string& path, const std::string& cpath, @@ -37,11 +38,9 @@ public: void resume(); private: - std::shared_ptr handle_; + static void runThread(Thread, FunctionObject, effil::StoredArray); -private: - Thread(const Thread&) = delete; - Thread& operator=(const Thread&) = delete; + std::shared_ptr handle_; }; } // effil diff --git a/src/lua/effil.lua b/src/lua/effil.lua index 73c9816..38cd591 100644 --- a/src/lua/effil.lua +++ b/src/lua/effil.lua @@ -15,7 +15,8 @@ local api = { channel = capi.channel, type = capi.type, pairs = capi.pairs, - ipairs = capi.ipairs + ipairs = capi.ipairs, + allow_table_upvalues = capi.allow_table_upvalues } api.size = function (something) diff --git a/tests/lua/bootstrap-tests.lua b/tests/lua/bootstrap-tests.lua index e533f32..189765d 100644 --- a/tests/lua/bootstrap-tests.lua +++ b/tests/lua/bootstrap-tests.lua @@ -9,6 +9,13 @@ function default_tear_down() effil.gc.collect() -- effil.G is always present -- thus, gc has one object + if effil.gc.count() ~= 1 then + print "Not all bojects were removed, gonna sleep for 2 seconds" + effil.sleep(2) + + collectgarbage() + effil.gc.collect() + end test.equal(effil.gc.count(), 1) end diff --git a/tests/lua/gc-stress.lua b/tests/lua/gc-stress.lua index 6a43cc4..64d39b6 100644 --- a/tests/lua/gc-stress.lua +++ b/tests/lua/gc-stress.lua @@ -34,6 +34,4 @@ test.gc_stress.create_and_collect_in_parallel = function () for i = 1, thread_num do test.equal(threads[i]:wait(), "completed") end - - test.equal(effil.gc.count(), 1) end diff --git a/tests/lua/metatable.lua b/tests/lua/metatable.lua index 2794f65..a372701 100644 --- a/tests/lua/metatable.lua +++ b/tests/lua/metatable.lua @@ -9,7 +9,7 @@ test.metatable.tear_down = function (metatable) if type(metatable) == "table" then test.equal(effil.gc.count(), 1) else - test.equal(effil.gc.count(), 2) + test.equal(effil.gc.count(), 3) end end diff --git a/tests/lua/run_tests b/tests/lua/run_tests index 711a46e..13d9b6d 100755 --- a/tests/lua/run_tests +++ b/tests/lua/run_tests @@ -14,6 +14,7 @@ require "thread" require "shared-table" require "metatable" require "type_mismatch" +require "upvalues" if os.getenv("STRESS") then require "channel-stress" diff --git a/tests/lua/thread.lua b/tests/lua/thread.lua index de916c4..5e33ddd 100644 --- a/tests/lua/thread.lua +++ b/tests/lua/thread.lua @@ -98,6 +98,7 @@ test.thread.cancel = function () test.is_true(thread:cancel()) test.equal(thread:status(), "canceled") end + test.thread.async_cancel = function () local thread_runner = effil.thread( function() @@ -213,9 +214,6 @@ test.thread.returns = function () test.is_function(returns[5]) test.equal(returns[5](11, 89), 100) - - -- Workaround to get child thread free all return values - effil.sleep(2) end test.thread.timed_cancel = function () diff --git a/tests/lua/type.lua b/tests/lua/type.lua index ef1a98a..bd84a03 100644 --- a/tests/lua/type.lua +++ b/tests/lua/type.lua @@ -1,9 +1,16 @@ require "bootstrap-tests" -test.type = function() +test.type.tear_down = default_tear_down + +test.type.check_all_types = function() test.equal(effil.type(1), "number") test.equal(effil.type("string"), "string") + test.equal(effil.type(true), "boolean") + test.equal(effil.type(nil), "nil") + test.equal(effil.type(function()end), "function") test.equal(effil.type(effil.table()), "effil.table") test.equal(effil.type(effil.channel()), "effil.channel") - test.equal(effil.type(effil.thread(function() end)()), "effil.thread") + local thr = effil.thread(function() end)() + test.equal(effil.type(thr), "effil.thread") + thr:wait() end diff --git a/tests/lua/type_mismatch.lua b/tests/lua/type_mismatch.lua index 254aec8..3d5d4c1 100644 --- a/tests/lua/type_mismatch.lua +++ b/tests/lua/type_mismatch.lua @@ -122,6 +122,11 @@ local function generate_tests() -- effil.gc.step test.type_mismatch.input_types_mismatch_p(1, "number", "gc.step", type_instance) end + + if typename ~= "boolean" then + -- effil.allow_table_upvalue + test.type_mismatch.input_types_mismatch_p(1, "boolean", "allow_table_upvalues", type_instance) + end end -- Below presented tests which support everything except coroutines @@ -149,8 +154,7 @@ end -- Put it to function to limit the lifetime of objects generate_tests() -test.type_mismatch.gc_checks_after_tests = function () - collectgarbage() - effil.gc.collect() - test.equal(effil.gc.count(), 1) -end +test.type_mismatch.gc_checks_after_tests = function() + effil.allow_table_upvalues(true) + default_tear_down() +end \ No newline at end of file diff --git a/tests/lua/upvalues.lua b/tests/lua/upvalues.lua new file mode 100644 index 0000000..704c119 --- /dev/null +++ b/tests/lua/upvalues.lua @@ -0,0 +1,119 @@ +require "bootstrap-tests" + +test.upvalues.tear_down = default_tear_down + +test.upvalues.check_single_upvalue_p = function(type_creator, type_checker) + local obj = type_creator() + local thread_worker = function(checker) return require("effil").type(obj) .. ": " .. checker(obj) end + local ret = effil.thread(thread_worker)(type_checker):get() + + print("Returned: " .. ret) + test.equal(ret, effil.type(obj) .. ": " .. type_checker(obj)) +end + +local foo = function() return 22 end + +test.upvalues.check_single_upvalue_p(function() return 1488 end, + function() return "1488" end) + +test.upvalues.check_single_upvalue_p(function() return "awesome" end, + function() return "awesome" end) + +test.upvalues.check_single_upvalue_p(function() return true end, + function() return "true" end) + +test.upvalues.check_single_upvalue_p(function() return nil end, + function() return "nil" end) + +test.upvalues.check_single_upvalue_p(function() return foo end, + function(f) return f() end) + +test.upvalues.check_single_upvalue_p(function() return effil.table({key = 44}) end, + function(t) return t.key end) + +test.upvalues.check_single_upvalue_p(function() local c = effil.channel() c:push(33) c:push(33) return c end, + function(c) return c:pop() end) + +test.upvalues.check_single_upvalue_p(function() return effil.thread(foo)() end, + function(t) return t:get() end) + +test.upvalues.check_invalid_coroutine = function() + local obj = coroutine.create(foo) + local thread_worker = function() return tostring(obj) end + local ret, err = pcall(effil.thread(thread_worker)) + if ret then + ret:wait() + end + test.is_false(ret) + print("Returned: " .. err) + upvalue_num = LUA_VERSION > 51 and 2 or 1 + test.equal(err, "effil.thread: bad function upvalue #" .. upvalue_num .. + " (unable to store object of thread type)") +end + +test.upvalues.check_table = function() + local obj = { key = "value" } + local thread_worker = function() return require("effil").type(obj) .. ": " .. obj.key end + local ret = effil.thread(thread_worker)():get() + + print("Returned: " .. ret) + test.equal(ret, "effil.table: value") +end + +test.upvalues.check_env = function() + local obj1 = 13 -- local + obj2 = { key = "origin" } -- global + local obj3 = 79 -- local + + local function foo() -- _ENV is 2nd upvalue + return obj1, obj2.key, obj3 + end + + local function thread_worker(func) + obj1 = 31 -- global + obj2 = { key = "local" } -- global + obj3 = 97 -- global + return table.concat({func()}, ", ") + end + + local ret = effil.thread(thread_worker)(foo):get() + print("Returned: " .. ret) + test.equal(ret, "13, local, 79") +end + +local function check_works(should_work) + local obj = { key = "value"} + local function worker() + return obj.key + end + + local ret, err = pcall(effil.thread(worker)) + if ret then + err:wait() + end + test.equal(ret, should_work) + if not should_work then + test.equal(err, "effil.thread: bad function upvalue #1 (table is disabled by effil.allow_table_upvalues)") + end +end + +test.upvalues_table.tear_down = function() + effil.allow_table_upvalues(true) + default_tear_down() +end + +test.upvalues_table.disabling_table_upvalues = function() + test.equal(effil.allow_table_upvalues(), true) + -- works by default + check_works(true) + + -- disable + test.equal(effil.allow_table_upvalues(false), true) + check_works(false) + test.equal(effil.allow_table_upvalues(), false) + + -- enable back + test.equal(effil.allow_table_upvalues(true), false) + check_works(true) + test.equal(effil.allow_table_upvalues(), true) +end