Function upvalues implementation (#86)

This commit is contained in:
mihacooper 2017-10-19 22:49:51 +03:00 committed by Ilia
parent f77391baa5
commit e84fbb32f4
18 changed files with 356 additions and 80 deletions

104
src/cpp/function.cpp Normal file
View File

@ -0,0 +1,104 @@
#include "function.h"
namespace effil {
namespace {
bool allowTableUpvalues(const sol::optional<bool>& newValue = sol::nullopt) {
static std::atomic_bool value(true);
if (newValue)
return value.exchange(newValue.value());
return value;
}
} // anonymous
sol::object luaAllowTableUpvalues(sol::this_state state, const sol::stack_object& value) {
if (value.valid()) {
REQUIRE(value.get_type() == sol::type::boolean) << "bad argument #1 to 'effil.allow_table_upvalues' (boolean expected, got " << luaTypename(value) << ")";
return sol::make_object(state, allowTableUpvalues(value.template as<bool>()));
}
else {
return sol::make_object(state, allowTableUpvalues());
}
}
void FunctionObject::initialize(const sol::function& luaObject) {
assert(luaObject.valid());
assert(luaObject.get_type() == sol::type::function);
lua_State* state = luaObject.lua_state();
sol::stack::push(state, luaObject);
lua_Debug dbgInfo;
lua_getinfo(state, ">u", &dbgInfo); // function is popped from stack here
sol::stack::push(state, luaObject);
data_->function = dumpFunction(luaObject);
data_->upvalues.resize(dbgInfo.nups);
#if LUA_VERSION_NUM > 501
data_->envUpvaluePos = 0; // means no _ENV upvalue
#endif // LUA_VERSION_NUM > 501
for (unsigned char i = 1; i <= dbgInfo.nups; ++i) {
const char* valueName = lua_getupvalue(state, -1, i); // push value on stack
(void)valueName; // get rid of 'unused' warning for Lua5.1
assert(valueName != nullptr);
#if LUA_VERSION_NUM > 501
if (strcmp(valueName, "_ENV") == 0) { // do not serialize _ENV
sol::stack::pop<sol::object>(state);
data_->envUpvaluePos = i;
continue;
}
#endif // LUA_VERSION_NUM > 501
const auto& upvalue = sol::stack::pop<sol::object>(state); // pop from stack
if (!allowTableUpvalues() && upvalue.get_type() == sol::type::table) {
sol::stack::pop<sol::object>(state);
throw effil::Exception() << "bad function upvalue #" << (int)i << " (table is disabled by effil.allow_table_upvalues)";
}
StoredObject storedObject;
try {
storedObject = createStoredObject(upvalue);
assert(storedObject.get() != nullptr);
}
catch(const std::exception& err) {
sol::stack::pop<sol::object>(state);
throw effil::Exception() << "bad function upvalue #" << (int)i << " (" << err.what() << ")";
}
if (storedObject->gcHandle() != nullptr) {
addReference(storedObject->gcHandle());
storedObject->releaseStrongReference();
}
data_->upvalues[i - 1] = std::move(storedObject);
}
sol::stack::pop<sol::object>(state);
}
sol::object FunctionObject::loadFunction(lua_State* state) {
sol::function result = loadString(state, data_->function);
assert(result.valid());
sol::stack::push(state, result);
for(size_t i = 0; i < data_->upvalues.size(); ++i) {
#if LUA_VERSION_NUM > 501
if (data_->envUpvaluePos == i + 1) {
lua_rawgeti(state, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS); // push _ENV to stack
lua_setupvalue(state, -2, i + 1); // pop _ENV and set as upvalue
continue;
}
#endif // LUA_VERSION_NUM > 501
assert(data_->upvalues[i].get() != nullptr);
const auto& obj = data_->upvalues[i]->unpack(sol::this_state{state});
sol::stack::push(state, obj);
lua_setupvalue(state, -2, i + 1);
}
return sol::stack::pop<sol::function>(state);
}
} // namespace effil

35
src/cpp/function.h Normal file
View File

@ -0,0 +1,35 @@
#pragma once
#include "gc-object.h"
#include "utils.h"
#include "lua-helpers.h"
namespace effil {
sol::object luaAllowTableUpvalues(sol::this_state state, const sol::stack_object&);
class FunctionObject: public GCObject {
public:
template <typename SolType>
FunctionObject(const SolType& luaObject)
: data_(std::make_shared<SharedData>()) {
initialize(luaObject);
}
sol::object loadFunction(lua_State* state);
private:
void initialize(const sol::function& luaObject);
struct SharedData {
std::string function;
#if LUA_VERSION_NUM > 501
unsigned char envUpvaluePos;
#endif // LUA_VERSION_NUM > 501
std::vector<StoredObject> upvalues;
};
std::shared_ptr<SharedData> data_;
};
} // namespace effil

View File

@ -37,8 +37,8 @@ void GC::collect() {
}
}
DEBUG << "Removing " << (objects_.size() - black.size()) << " out of " << objects_.size() << std::endl;
// Sweep phase
DEBUG << "Removing " << (objects_.size() - black.size()) << " out of " << objects_.size() << std::endl;
objects_ = std::move(black);
lastCleanup_.store(0);

View File

@ -42,8 +42,9 @@ std::string dumpFunction(const sol::function& f) {
return result;
}
sol::function loadString(const sol::state_view& lua, const std::string& str) {
int ret = luaL_loadbuffer(lua, str.c_str(), str.size(), nullptr);
sol::function loadString(const sol::state_view& lua, const std::string& str,
const sol::optional<std::string>& source /* = sol::nullopt*/) {
int ret = luaL_loadbuffer(lua, str.c_str(), str.size(), source ? source.value().c_str() : nullptr);
REQUIRE(ret == LUA_OK) << "Unable to load function from string: " << luaError(ret);
return sol::stack::pop<sol::function>(lua);
}

View File

@ -11,7 +11,8 @@ class Channel;
class Thread;
std::string dumpFunction(const sol::function& f);
sol::function loadString(const sol::state_view& lua, const std::string& str);
sol::function loadString(const sol::state_view& lua, const std::string& str,
const sol::optional<std::string>& source = sol::nullopt);
std::chrono::milliseconds fromLuaTime(int duration, const sol::optional<std::string>& period);
template <typename SolObject>
@ -21,7 +22,7 @@ std::string luaTypename(const SolObject& obj) {
return "effil.table";
else if (obj.template is<Channel>())
return "effil.channel";
else if (obj.template is<std::shared_ptr<Thread>>())
else if (obj.template is<Thread>())
return "effil.thread";
else
return "userdata";

View File

@ -15,7 +15,7 @@ sol::object createThread(const sol::this_state& lua,
int step,
const sol::function& function,
const sol::variadic_args& args) {
return sol::make_object(lua, std::make_shared<Thread>(path, cpath, step, function, args));
return sol::make_object(lua, GC::instance().create<Thread>(path, cpath, step, function, args));
}
sol::object createTable(sol::this_state lua, const sol::optional<sol::object>& tbl) {
@ -66,7 +66,8 @@ int luaopen_libeffil(lua_State* L) {
"channel", createChannel,
"type", getLuaTypename,
"pairs", SharedTable::globalLuaPairs,
"ipairs", SharedTable::globalLuaIPairs
"ipairs", SharedTable::globalLuaIPairs,
"allow_table_upvalues", luaAllowTableUpvalues
);
sol::stack::push(lua, publicApi);
return 1;

View File

@ -1,8 +1,8 @@
#include "stored-object.h"
#include "channel.h"
#include "threading.h"
#include "shared-table.h"
#include "function.h"
#include "utils.h"
#include <map>
@ -45,29 +45,6 @@ private:
StoredType data_;
};
class FunctionHolder : public BaseHolder {
public:
template <typename SolObject>
FunctionHolder(SolObject luaObject) noexcept {
sol::state_view lua(luaObject.lua_state());
function_ = dumpFunction(luaObject);
}
bool rawCompare(const BaseHolder* other) const noexcept final {
return function_ < static_cast<const FunctionHolder*>(other)->function_;
}
sol::object unpack(sol::this_state state) const final {
sol::function result = loadString(state, function_);
// The result of restaring always is valid function.
assert(result.valid());
return sol::make_object(state, result);
}
private:
std::string function_;
};
template<typename T>
class GCObjectHolder : public BaseHolder {
public:
@ -87,7 +64,7 @@ public:
return handle_ < static_cast<const GCObjectHolder<T>*>(other)->handle_;
}
sol::object unpack(sol::this_state state) const final {
sol::object unpack(sol::this_state state) const override {
return sol::make_object(state, GC::instance().get<T>(handle_));
}
@ -103,11 +80,22 @@ public:
}
}
private:
protected:
GCObjectHandle handle_;
sol::optional<T> strongRef_;
};
class FunctionHolder : public GCObjectHolder<FunctionObject> {
public:
template <typename SolType>
FunctionHolder(const SolType& luaObject) : GCObjectHolder<FunctionObject>(luaObject) {}
FunctionHolder(GCObjectHandle handle) : GCObjectHolder(handle) {}
sol::object unpack(sol::this_state state) const final {
return GC::instance().get<FunctionObject>(handle_).loadFunction(state);
}
};
// This class is used as a storage for visited sol::tables
// TODO: try to use map or unordered map instead of linear search in vector
// TODO: Trick is - sol::object has only operator==:/
@ -171,12 +159,16 @@ StoredObject fromSolObject(const SolObject& luaObject) {
return std::make_unique<GCObjectHolder<SharedTable>>(luaObject);
else if (luaObject.template is<Channel>())
return std::make_unique<GCObjectHolder<Channel>>(luaObject);
else if (luaObject.template is<std::shared_ptr<Thread>>())
return std::make_unique<PrimitiveHolder<std::shared_ptr<Thread>>>(luaObject);
else if (luaObject.template is<FunctionObject>())
return std::make_unique<FunctionHolder>(luaObject);
else if (luaObject.template is<Thread>())
return std::make_unique<GCObjectHolder<Thread>>(luaObject);
else
throw Exception() << "Unable to store userdata object\n";
case sol::type::function:
return std::make_unique<FunctionHolder>(luaObject);
case sol::type::function: {
FunctionObject func = GC::instance().create<FunctionObject>(luaObject);
return std::make_unique<FunctionHolder>(func.handle());
}
case sol::type::table: {
sol::table luaTable = luaObject;
// Tables pool is used to store tables.

View File

@ -163,43 +163,46 @@ void luaHook(lua_State*, lua_Debug*) {
}
}
void runThread(std::shared_ptr<ThreadHandle> handle,
std::string strFunction,
} // namespace
void Thread::runThread(Thread thread,
FunctionObject function,
effil::StoredArray arguments) {
assert(handle);
thisThreadHandle = handle.get();
thisThreadHandle = thread.handle_.get();
assert(thisThreadHandle != nullptr);
try {
{
ScopeGuard reportComplete([handle, &arguments](){
DEBUG << "Finished " << std::endl;
ScopeGuard reportComplete([thread, &arguments](){
// Let's destroy accociated state
// to release all resources as soon as possible
arguments.clear();
handle->destroyLua();
thread.handle_->destroyLua();
});
sol::function userFuncObj = loadString(handle->lua(), strFunction);
sol::function userFuncObj = function.loadFunction(thread.handle_->lua());
sol::function_result results = userFuncObj(std::move(arguments));
(void)results; // just leave all returns on the stack
sol::variadic_args args(handle->lua(), -lua_gettop(handle->lua()));
sol::variadic_args args(thread.handle_->lua(), -lua_gettop(thread.handle_->lua()));
for (const auto& iter : args) {
StoredObject store = createStoredObject(iter.get<sol::object>());
handle->result().emplace_back(std::move(store));
if (store->gcHandle() != nullptr)
{
thread.addReference(store->gcHandle());
store->releaseStrongReference();
}
thread.handle_->result().emplace_back(std::move(store));
}
}
handle->changeStatus(Status::Completed);
thread.handle_->changeStatus(Status::Completed);
} catch (const LuaHookStopException&) {
handle->changeStatus(Status::Canceled);
thread.handle_->changeStatus(Status::Canceled);
} catch (const sol::error& err) {
DEBUG << "Failed with msg: " << err.what() << std::endl;
handle->result().emplace_back(createStoredObject(err.what()));
handle->changeStatus(Status::Failed);
thread.handle_->result().emplace_back(createStoredObject(err.what()));
thread.handle_->changeStatus(Status::Failed);
}
}
} // namespace
std::string threadId() {
std::stringstream ss;
ss << std::this_thread::get_id();
@ -233,6 +236,11 @@ Thread::Thread(const std::string& path,
const sol::variadic_args& variadicArgs)
: handle_(std::make_shared<ThreadHandle>()) {
sol::optional<FunctionObject> functionObj;
try {
functionObj = FunctionObject(function);
} RETHROW_WITH_PREFIX("effil.thread");
handle_->lua()["package"]["path"] = path;
handle_->lua()["package"]["cpath"] = cpath;
handle_->lua().script("require 'effil'");
@ -240,20 +248,20 @@ Thread::Thread(const std::string& path,
if (step != 0)
lua_sethook(handle_->lua(), luaHook, LUA_MASKCOUNT, step);
std::string strFunction = dumpFunction(function);
effil::StoredArray arguments;
try {
for (const auto& arg : variadicArgs) {
arguments.emplace_back(createStoredObject(arg.get<sol::object>()));
const auto& storedObj = createStoredObject(arg.get<sol::object>());
addReference(storedObj->gcHandle());
storedObj->releaseStrongReference();
arguments.emplace_back(storedObj);
}
} RETHROW_WITH_PREFIX("effil.thread");
std::thread thr(&runThread,
handle_,
std::move(strFunction),
std::thread thr(&Thread::runThread,
*this,
functionObj.value(),
std::move(arguments));
DEBUG << "Created " << thr.get_id() << std::endl;
thr.detach();
}

View File

@ -2,6 +2,7 @@
#include <sol.hpp>
#include "lua-helpers.h"
#include "function.h"
namespace effil {
@ -12,7 +13,7 @@ void sleep(const sol::stack_object& duration, const sol::stack_object& metric);
class ThreadHandle;
class Thread {
class Thread : public GCObject {
public:
Thread(const std::string& path,
const std::string& cpath,
@ -37,11 +38,9 @@ public:
void resume();
private:
std::shared_ptr<ThreadHandle> handle_;
static void runThread(Thread, FunctionObject, effil::StoredArray);
private:
Thread(const Thread&) = delete;
Thread& operator=(const Thread&) = delete;
std::shared_ptr<ThreadHandle> handle_;
};
} // effil

View File

@ -15,7 +15,8 @@ local api = {
channel = capi.channel,
type = capi.type,
pairs = capi.pairs,
ipairs = capi.ipairs
ipairs = capi.ipairs,
allow_table_upvalues = capi.allow_table_upvalues
}
api.size = function (something)

View File

@ -9,6 +9,13 @@ function default_tear_down()
effil.gc.collect()
-- effil.G is always present
-- thus, gc has one object
if effil.gc.count() ~= 1 then
print "Not all bojects were removed, gonna sleep for 2 seconds"
effil.sleep(2)
collectgarbage()
effil.gc.collect()
end
test.equal(effil.gc.count(), 1)
end

View File

@ -34,6 +34,4 @@ test.gc_stress.create_and_collect_in_parallel = function ()
for i = 1, thread_num do
test.equal(threads[i]:wait(), "completed")
end
test.equal(effil.gc.count(), 1)
end

View File

@ -9,7 +9,7 @@ test.metatable.tear_down = function (metatable)
if type(metatable) == "table" then
test.equal(effil.gc.count(), 1)
else
test.equal(effil.gc.count(), 2)
test.equal(effil.gc.count(), 3)
end
end

View File

@ -14,6 +14,7 @@ require "thread"
require "shared-table"
require "metatable"
require "type_mismatch"
require "upvalues"
if os.getenv("STRESS") then
require "channel-stress"

View File

@ -98,6 +98,7 @@ test.thread.cancel = function ()
test.is_true(thread:cancel())
test.equal(thread:status(), "canceled")
end
test.thread.async_cancel = function ()
local thread_runner = effil.thread(
function()
@ -213,9 +214,6 @@ test.thread.returns = function ()
test.is_function(returns[5])
test.equal(returns[5](11, 89), 100)
-- Workaround to get child thread free all return values
effil.sleep(2)
end
test.thread.timed_cancel = function ()

View File

@ -1,9 +1,16 @@
require "bootstrap-tests"
test.type = function()
test.type.tear_down = default_tear_down
test.type.check_all_types = function()
test.equal(effil.type(1), "number")
test.equal(effil.type("string"), "string")
test.equal(effil.type(true), "boolean")
test.equal(effil.type(nil), "nil")
test.equal(effil.type(function()end), "function")
test.equal(effil.type(effil.table()), "effil.table")
test.equal(effil.type(effil.channel()), "effil.channel")
test.equal(effil.type(effil.thread(function() end)()), "effil.thread")
local thr = effil.thread(function() end)()
test.equal(effil.type(thr), "effil.thread")
thr:wait()
end

View File

@ -122,6 +122,11 @@ local function generate_tests()
-- effil.gc.step
test.type_mismatch.input_types_mismatch_p(1, "number", "gc.step", type_instance)
end
if typename ~= "boolean" then
-- effil.allow_table_upvalue
test.type_mismatch.input_types_mismatch_p(1, "boolean", "allow_table_upvalues", type_instance)
end
end
-- Below presented tests which support everything except coroutines
@ -149,8 +154,7 @@ end
-- Put it to function to limit the lifetime of objects
generate_tests()
test.type_mismatch.gc_checks_after_tests = function ()
collectgarbage()
effil.gc.collect()
test.equal(effil.gc.count(), 1)
test.type_mismatch.gc_checks_after_tests = function()
effil.allow_table_upvalues(true)
default_tear_down()
end

119
tests/lua/upvalues.lua Normal file
View File

@ -0,0 +1,119 @@
require "bootstrap-tests"
test.upvalues.tear_down = default_tear_down
test.upvalues.check_single_upvalue_p = function(type_creator, type_checker)
local obj = type_creator()
local thread_worker = function(checker) return require("effil").type(obj) .. ": " .. checker(obj) end
local ret = effil.thread(thread_worker)(type_checker):get()
print("Returned: " .. ret)
test.equal(ret, effil.type(obj) .. ": " .. type_checker(obj))
end
local foo = function() return 22 end
test.upvalues.check_single_upvalue_p(function() return 1488 end,
function() return "1488" end)
test.upvalues.check_single_upvalue_p(function() return "awesome" end,
function() return "awesome" end)
test.upvalues.check_single_upvalue_p(function() return true end,
function() return "true" end)
test.upvalues.check_single_upvalue_p(function() return nil end,
function() return "nil" end)
test.upvalues.check_single_upvalue_p(function() return foo end,
function(f) return f() end)
test.upvalues.check_single_upvalue_p(function() return effil.table({key = 44}) end,
function(t) return t.key end)
test.upvalues.check_single_upvalue_p(function() local c = effil.channel() c:push(33) c:push(33) return c end,
function(c) return c:pop() end)
test.upvalues.check_single_upvalue_p(function() return effil.thread(foo)() end,
function(t) return t:get() end)
test.upvalues.check_invalid_coroutine = function()
local obj = coroutine.create(foo)
local thread_worker = function() return tostring(obj) end
local ret, err = pcall(effil.thread(thread_worker))
if ret then
ret:wait()
end
test.is_false(ret)
print("Returned: " .. err)
upvalue_num = LUA_VERSION > 51 and 2 or 1
test.equal(err, "effil.thread: bad function upvalue #" .. upvalue_num ..
" (unable to store object of thread type)")
end
test.upvalues.check_table = function()
local obj = { key = "value" }
local thread_worker = function() return require("effil").type(obj) .. ": " .. obj.key end
local ret = effil.thread(thread_worker)():get()
print("Returned: " .. ret)
test.equal(ret, "effil.table: value")
end
test.upvalues.check_env = function()
local obj1 = 13 -- local
obj2 = { key = "origin" } -- global
local obj3 = 79 -- local
local function foo() -- _ENV is 2nd upvalue
return obj1, obj2.key, obj3
end
local function thread_worker(func)
obj1 = 31 -- global
obj2 = { key = "local" } -- global
obj3 = 97 -- global
return table.concat({func()}, ", ")
end
local ret = effil.thread(thread_worker)(foo):get()
print("Returned: " .. ret)
test.equal(ret, "13, local, 79")
end
local function check_works(should_work)
local obj = { key = "value"}
local function worker()
return obj.key
end
local ret, err = pcall(effil.thread(worker))
if ret then
err:wait()
end
test.equal(ret, should_work)
if not should_work then
test.equal(err, "effil.thread: bad function upvalue #1 (table is disabled by effil.allow_table_upvalues)")
end
end
test.upvalues_table.tear_down = function()
effil.allow_table_upvalues(true)
default_tear_down()
end
test.upvalues_table.disabling_table_upvalues = function()
test.equal(effil.allow_table_upvalues(), true)
-- works by default
check_works(true)
-- disable
test.equal(effil.allow_table_upvalues(false), true)
check_works(false)
test.equal(effil.allow_table_upvalues(), false)
-- enable back
test.equal(effil.allow_table_upvalues(true), false)
check_works(true)
test.equal(effil.allow_table_upvalues(), true)
end