diff --git a/ci/test_all.sh b/ci/test_all.sh index 49a63f5..b705d35 100755 --- a/ci/test_all.sh +++ b/ci/test_all.sh @@ -4,5 +4,5 @@ set -e for build_type in debug release; do mkdir -p $build_type (cd $build_type && cmake -DCMAKE_BUILD_TYPE=$build_type $@ .. && make -j4 install) - (cd $build_type && ./tests && lua run_tests.lua) + (cd $build_type && ./tests && lua run_tests.lua --extra-checks) done diff --git a/src/cpp/channel.cpp b/src/cpp/channel.cpp new file mode 100644 index 0000000..0302619 --- /dev/null +++ b/src/cpp/channel.cpp @@ -0,0 +1,68 @@ +#include "channel.h" + +#include "sol.hpp" + +namespace effil { + +void Channel::getUserType(sol::state_view& lua) { + sol::usertype type("new", sol::no_constructor, + "push", &Channel::push, + "pop", &Channel::pop + ); + sol::stack::push(lua, type); + sol::stack::pop(lua); +} + +Channel::Channel(sol::optional capacity) : data_(std::make_shared()){ + if (capacity) { + REQUIRE(capacity.value() >= 0) << "Invalid capacity value = " << capacity.value(); + data_->capacity_ = static_cast(capacity.value()); + } + else { + data_->capacity_ = 0; + } +} + +bool Channel::push(const sol::variadic_args& args) { + if (!args.leftover_count()) + return false; + + std::unique_lock lock(data_->lock_); + if (data_->capacity_ && data_->channel_.size() >= data_->capacity_) + return false; + effil::StoredArray array; + for (const auto& arg : args) { + auto obj = createStoredObject(arg.get()); + if (obj->gcHandle()) + refs_->insert(obj->gcHandle()); + array.emplace_back(obj); + } + if (data_->channel_.empty()) + data_->cv_.notify_one(); + data_->channel_.emplace(array); + return true; +} + +StoredArray Channel::pop(const sol::optional& duration, + const sol::optional& period) { + std::unique_lock lock(data_->lock_); + while (data_->channel_.empty()) { + if (duration) { + if (data_->cv_.wait_for(lock, fromLuaTime(duration.value(), period)) == std::cv_status::timeout) + return StoredArray(); + } + else { // No time limit + data_->cv_.wait(lock); + } + } + + auto ret = data_->channel_.front(); + for (const auto& obj: ret) { + if (obj->gcHandle()) + refs_->erase(obj->gcHandle()); + } + data_->channel_.pop(); + return ret; +} + +} // namespace effil diff --git a/src/cpp/channel.h b/src/cpp/channel.h new file mode 100644 index 0000000..cfa3a67 --- /dev/null +++ b/src/cpp/channel.h @@ -0,0 +1,30 @@ +#pragma once + +#include "notifier.h" +#include "stored-object.h" +#include "lua-helpers.h" +#include "queue" + +namespace effil { + +class Channel : public GCObject { +public: + Channel(sol::optional capacity); + static void getUserType(sol::state_view& lua); + + bool push(const sol::variadic_args& args); + StoredArray pop(const sol::optional& duration, + const sol::optional& period); + +protected: + struct SharedData { + std::mutex lock_; + std::condition_variable cv_; + size_t capacity_; + std::queue channel_; + }; + + std::shared_ptr data_; +}; + +} // namespace effil diff --git a/src/cpp/garbage-collector.h b/src/cpp/garbage-collector.h index 1596b56..268e2b3 100644 --- a/src/cpp/garbage-collector.h +++ b/src/cpp/garbage-collector.h @@ -1,9 +1,7 @@ #pragma once #include "spin-mutex.h" - #include - #include #include #include @@ -87,4 +85,4 @@ private: GC(const GC&) = delete; }; -} // effil \ No newline at end of file +} // effil diff --git a/src/cpp/lua-helpers.cpp b/src/cpp/lua-helpers.cpp new file mode 100644 index 0000000..4a95696 --- /dev/null +++ b/src/cpp/lua-helpers.cpp @@ -0,0 +1,17 @@ +#include "lua-helpers.h" + +namespace effil { + +std::chrono::milliseconds fromLuaTime(int duration, const sol::optional& period) { + using namespace std::chrono; + + REQUIRE(duration >= 0) << "Invalid duration interval: " << duration; + + std::string metric = period ? period.value() : "s"; + if (metric == "ms") return milliseconds(duration); + else if (metric == "s") return seconds(duration); + else if (metric == "m") return minutes(duration); + else throw sol::error("invalid time identification: " + metric); +} + +} // namespace effil diff --git a/src/cpp/lua-helpers.h b/src/cpp/lua-helpers.h index 7c7b7aa..2862849 100644 --- a/src/cpp/lua-helpers.h +++ b/src/cpp/lua-helpers.h @@ -27,6 +27,8 @@ inline sol::function loadString(const sol::state_view& lua, const std::string& s return loader(str); } +std::chrono::milliseconds fromLuaTime(int duration, const sol::optional& period); + typedef std::vector StoredArray; } // namespace effil diff --git a/src/cpp/lua-module.cpp b/src/cpp/lua-module.cpp index 3b951bd..bfb0bf9 100644 --- a/src/cpp/lua-module.cpp +++ b/src/cpp/lua-module.cpp @@ -1,6 +1,7 @@ #include "threading.h" #include "shared-table.h" #include "garbage-collector.h" +#include "channel.h" #include @@ -22,6 +23,10 @@ sol::object createTable(sol::this_state lua) { return sol::make_object(lua, GC::instance().create()); } +sol::object createChannel(sol::optional capacity, sol::this_state lua) { + return sol::make_object(lua, GC::instance().create(capacity)); +} + SharedTable globalTable = GC::instance().create(); } // namespace @@ -30,6 +35,7 @@ extern "C" int luaopen_libeffil(lua_State* L) { sol::state_view lua(L); Thread::getUserType(lua); SharedTable::getUserType(lua); + Channel::getUserType(lua); sol::table publicApi = lua.create_table_with( "thread", createThread, "thread_id", threadId, @@ -43,7 +49,8 @@ extern "C" int luaopen_libeffil(lua_State* L) { "getmetatable", SharedTable::luaGetMetatable, "G", sol::make_object(lua, globalTable), "getmetatable", SharedTable::luaGetMetatable, - "gc", GC::getLuaApi(lua) + "gc", GC::getLuaApi(lua), + "channel", createChannel ); sol::stack::push(lua, publicApi); return 1; diff --git a/src/cpp/notifier.h b/src/cpp/notifier.h index 997867e..a02db77 100644 --- a/src/cpp/notifier.h +++ b/src/cpp/notifier.h @@ -46,4 +46,4 @@ private: Notifier(Notifier& ) = delete; }; -} // namespace effil \ No newline at end of file +} // namespace effil diff --git a/src/cpp/stored-object.cpp b/src/cpp/stored-object.cpp index 9a3dfc6..387cbbf 100644 --- a/src/cpp/stored-object.cpp +++ b/src/cpp/stored-object.cpp @@ -1,4 +1,5 @@ #include "stored-object.h" +#include "channel.h" #include "threading.h" #include "shared-table.h" @@ -73,24 +74,25 @@ private: std::string function_; }; -class TableHolder : public BaseHolder { +template +class GCObjectHolder : public BaseHolder { public: template - TableHolder(const SolType& luaObject) { - assert(luaObject.template is()); - handle_ = luaObject.template as().handle(); + GCObjectHolder(const SolType& luaObject) { + assert(luaObject.template is()); + handle_ = luaObject.template as().handle(); assert(GC::instance().has(handle_)); } - TableHolder(GCObjectHandle handle) + GCObjectHolder(GCObjectHandle handle) : handle_(handle) {} bool rawCompare(const BaseHolder* other) const final { - return handle_ < static_cast(other)->handle_; + return handle_ < static_cast*>(other)->handle_; } sol::object unpack(sol::this_state state) const final { - return sol::make_object(state, GC::instance().get(handle_)); + return sol::make_object(state, GC::instance().get(handle_)); } GCObjectHandle gcHandle() const override { return handle_; } @@ -118,9 +120,9 @@ StoredObject makeStoredObject(sol::object luaObject, SolTableToShared& visited) SharedTable table = GC::instance().create(); visited.emplace_back(std::make_pair(luaTable, table.handle())); dumpTable(&table, luaTable, visited); - return createStoredObject(table.handle()); + return std::make_unique>(table.handle()); } else { - return createStoredObject(st->second); + return std::make_unique>(st->second); } } else { return createStoredObject(luaObject); @@ -149,7 +151,9 @@ StoredObject fromSolObject(const SolObject& luaObject) { return std::make_unique>(luaObject); case sol::type::userdata: if (luaObject.template is()) - return std::make_unique(luaObject); + return std::make_unique>(luaObject); + else if (luaObject.template is()) + return std::make_unique>(luaObject); else if (luaObject.template is>()) return std::make_unique>>(luaObject); else @@ -167,7 +171,7 @@ StoredObject fromSolObject(const SolObject& luaObject) { // SolTableToShared is used to prevent from infinity recursion // in recursive tables dumpTable(&table, luaTable, visited); - return std::make_unique(table.handle()); + return std::make_unique>(table.handle()); } default: throw Exception() << "Unable to store object of that type: " << (int)luaObject.get_type() << "\n"; @@ -193,8 +197,6 @@ StoredObject createStoredObject(const sol::object& object) { return fromSolObjec StoredObject createStoredObject(const sol::stack_object& object) { return fromSolObject(object); } -StoredObject createStoredObject(GCObjectHandle handle) { return std::make_unique(handle); } - template sol::optional getPrimitiveHolderData(const StoredObject& sobj) { auto ptr = dynamic_cast*>(sobj.get()); diff --git a/src/cpp/stored-object.h b/src/cpp/stored-object.h index 17f3b0a..6768051 100644 --- a/src/cpp/stored-object.h +++ b/src/cpp/stored-object.h @@ -37,7 +37,6 @@ StoredObject createStoredObject(bool); StoredObject createStoredObject(double); StoredObject createStoredObject(const std::string&); StoredObject createStoredObject(const char*); -StoredObject createStoredObject(GCObjectHandle); StoredObject createStoredObject(const sol::object&); StoredObject createStoredObject(const sol::stack_object&); diff --git a/src/cpp/threading.cpp b/src/cpp/threading.cpp index 3cfa80d..fed4d2a 100644 --- a/src/cpp/threading.cpp +++ b/src/cpp/threading.cpp @@ -162,18 +162,6 @@ void runThread(std::shared_ptr handle, } } -std::chrono::milliseconds fromLuaTime(int duration, const sol::optional& period) { - using namespace std::chrono; - - REQUIRE(duration >= 0) << "Invalid duration interval: " << duration; - - std::string metric = period ? period.value() : "s"; - if (metric == "ms") return milliseconds(duration); - else if (metric == "s") return seconds(duration); - else if (metric == "m") return minutes(duration); - else throw sol::error("invalid time identification: " + metric); -} - } // namespace diff --git a/src/lua/effil.lua b/src/lua/effil.lua index 94d2934..d3360e7 100644 --- a/src/lua/effil.lua +++ b/src/lua/effil.lua @@ -23,7 +23,8 @@ local api = { setmetatable = capi.setmetatable, getmetatable = capi.getmetatable, G = capi.G, - gc = capi.gc + gc = capi.gc, + channel = capi.channel } local function run_thread(config, f, ...) diff --git a/tests/lua/channel.lua b/tests/lua/channel.lua new file mode 100644 index 0000000..87a2c8f --- /dev/null +++ b/tests/lua/channel.lua @@ -0,0 +1,131 @@ +TestChannels = {tearDown = tearDown} + +function TestChannels:testCapacityUsage() + local chan = effil.channel(2) + + test.assertTrue(chan:push(14)) + test.assertTrue(chan:push(88)) + test.assertFalse(chan:push(1488)) + + test.assertEquals(chan:pop(), 14) + test.assertEquals(chan:pop(), 88) + test.assertIsNil(chan:pop(0)) + + test.assertTrue(chan:push(14, 88), true) + local ret1, ret2 = chan:pop() + test.assertEquals(ret1, 14) + test.assertEquals(ret2, 88) +end + +function TestChannels:testRecursiveChannels() + local chan1 = effil.channel() + local chan2 = effil.channel() + local msg1, msg2 = "first channel", "second channel" + test.assertTrue(chan1:push(msg1, chan2)) + test.assertTrue(chan2:push(msg2, chan1)) + + local ret1 = { chan1:pop() } + test.assertEquals(ret1[1], msg1) + test.assertEquals(type(ret1[2]), "userdata") + local ret2 = { ret1[2]:pop() } + test.assertEquals(ret2[1], msg2) + test.assertEquals(type(ret2[2]), "userdata") +end + +function TestChannels:testWithThread() + local chan = effil.channel() + local thread = effil.thread(function(chan) + chan:push("message1") + chan:push("message2") + chan:push("message3") + chan:push("message4") + end + )(chan) + + local start_time = os.time() + test.assertEquals(chan:pop(), "message1") + thread:wait() + test.assertEquals(chan:pop(0), "message2") + test.assertEquals(chan:pop(1), "message3") + test.assertEquals(chan:pop(1, 'm'), "message4") + test.assertTrue(os.time() < start_time + 1) +end + +function TestChannels:testWithSharedTable() + local chan = effil.channel() + local table = effil.table() + + local test_value = "i'm value" + table.test_key = test_value + + chan:push(table) + test.assertEquals(chan:pop().test_key, test_value) + + table.channel = chan + table.channel:push(test_value) + test.assertEquals(table.channel:pop(), test_value) +end + +if WITH_EXTRA_CHECKS then + +function TestChannels:testStressLoadWithMultipleThreads() + local exchange_channel, result_channel = effil.channel(), effil.channel() + + local threads_number = 1000 + for i = 1, threads_number do + effil.thread(function(exchange_channel, result_channel, indx) + if indx % 2 == 0 then + for i = 1, 10000 do + exchange_channel:push(indx .. "_".. i) + end + else + repeat + local ret = exchange_channel:pop(10) + if ret then + result_channel:push(ret) + end + until ret == nil + end + end + )(exchange_channel, result_channel, i) + end + + local data = {} + for i = 1, (threads_number / 2) * 10000 do + local ret = result_channel:pop(10) + test.assertNotIsNil(ret) + test.assertIsString(ret) + test.assertIsNil(data[ret]) + data[ret] = true + end + + for thr_id = 2, threads_number, 2 do + for iter = 1, 10000 do + test.assertTrue(data[thr_id .. "_".. iter]) + end + end +end + +function TestChannels:testTimedRead() + local chan = effil.channel() + local delayedWriter = function(channel, delay) + require("effil").sleep(delay) + channel:push("hello!") + end + effil.thread(delayedWriter)(chan, 70) + + local function check_time(real_time, use_time, metric, result) + local start_time = os.time() + test.assertEquals(chan:pop(use_time, metric), result) + test.assertAlmostEquals(os.time(), start_time + real_time, 1) + end + check_time(2, 2, nil, nil) -- second by default + check_time(2, 2, 's', nil) + check_time(60, 1, 'm', nil) + + local start_time = os.time() + test.assertEquals(chan:pop(10), "hello!") + test.assertTrue(os.time() < start_time + 10) +end + +end -- WITH_EXTRA_CHECKS diff --git a/tests/lua/run_tests.lua b/tests/lua/run_tests.lua index b2292d7..2d6bc18 100755 --- a/tests/lua/run_tests.lua +++ b/tests/lua/run_tests.lua @@ -9,14 +9,17 @@ print("---------------") do -- Hack input arguments to make tests verbose by default - local found = false - for _, v in ipairs(arg) do + local make_verbose = true + for i, v in ipairs(arg) do if v == '-o' or v == '--output' then - found = true - break + make_verbose = false + elseif v == "--extra-checks" then + table.remove(arg, i) + WITH_EXTRA_CHECKS = true + print "# RUN TESTS WITH EXTRA CHECKS" end end - if not found then + if make_verbose then table.insert(arg, '-o') table.insert(arg, 'TAP') end @@ -32,6 +35,7 @@ require 'test_utils' require 'thread' require 'shared_table' require 'gc' +require 'channel' -- Hack tests functions to print when test starts for suite_name, suite in pairs(_G) do diff --git a/tests/lua/thread.lua b/tests/lua/thread.lua index 1237b8d..2a9b859 100644 --- a/tests/lua/thread.lua +++ b/tests/lua/thread.lua @@ -284,6 +284,8 @@ function TestThisThread:testThisThreadFunctions() effil.yield() -- just call it end +if WITH_EXTRA_CHECKS then + function TestThisThread:testTime() local function check_time(real_time, use_time, metric) local start_time = os.time() @@ -295,3 +297,5 @@ function TestThisThread:testTime() check_time(4, 4000, 'ms') check_time(60, 1, 'm') end + +end -- WITH_EXTRA_CHECKS