Channels implementation and tests (#42)

Channel implementation
This commit is contained in:
mihacooper 2017-04-16 22:58:21 +03:00 committed by Ilia
parent 64628c1757
commit 5057116136
15 changed files with 289 additions and 38 deletions

View File

@ -4,5 +4,5 @@ set -e
for build_type in debug release; do for build_type in debug release; do
mkdir -p $build_type mkdir -p $build_type
(cd $build_type && cmake -DCMAKE_BUILD_TYPE=$build_type $@ .. && make -j4 install) (cd $build_type && cmake -DCMAKE_BUILD_TYPE=$build_type $@ .. && make -j4 install)
(cd $build_type && ./tests && lua run_tests.lua) (cd $build_type && ./tests && lua run_tests.lua --extra-checks)
done done

68
src/cpp/channel.cpp Normal file
View File

@ -0,0 +1,68 @@
#include "channel.h"
#include "sol.hpp"
namespace effil {
void Channel::getUserType(sol::state_view& lua) {
sol::usertype<Channel> type("new", sol::no_constructor,
"push", &Channel::push,
"pop", &Channel::pop
);
sol::stack::push(lua, type);
sol::stack::pop<sol::object>(lua);
}
Channel::Channel(sol::optional<int> capacity) : data_(std::make_shared<SharedData>()){
if (capacity) {
REQUIRE(capacity.value() >= 0) << "Invalid capacity value = " << capacity.value();
data_->capacity_ = static_cast<size_t>(capacity.value());
}
else {
data_->capacity_ = 0;
}
}
bool Channel::push(const sol::variadic_args& args) {
if (!args.leftover_count())
return false;
std::unique_lock<std::mutex> lock(data_->lock_);
if (data_->capacity_ && data_->channel_.size() >= data_->capacity_)
return false;
effil::StoredArray array;
for (const auto& arg : args) {
auto obj = createStoredObject(arg.get<sol::object>());
if (obj->gcHandle())
refs_->insert(obj->gcHandle());
array.emplace_back(obj);
}
if (data_->channel_.empty())
data_->cv_.notify_one();
data_->channel_.emplace(array);
return true;
}
StoredArray Channel::pop(const sol::optional<int>& duration,
const sol::optional<std::string>& period) {
std::unique_lock<std::mutex> lock(data_->lock_);
while (data_->channel_.empty()) {
if (duration) {
if (data_->cv_.wait_for(lock, fromLuaTime(duration.value(), period)) == std::cv_status::timeout)
return StoredArray();
}
else { // No time limit
data_->cv_.wait(lock);
}
}
auto ret = data_->channel_.front();
for (const auto& obj: ret) {
if (obj->gcHandle())
refs_->erase(obj->gcHandle());
}
data_->channel_.pop();
return ret;
}
} // namespace effil

30
src/cpp/channel.h Normal file
View File

@ -0,0 +1,30 @@
#pragma once
#include "notifier.h"
#include "stored-object.h"
#include "lua-helpers.h"
#include "queue"
namespace effil {
class Channel : public GCObject {
public:
Channel(sol::optional<int> capacity);
static void getUserType(sol::state_view& lua);
bool push(const sol::variadic_args& args);
StoredArray pop(const sol::optional<int>& duration,
const sol::optional<std::string>& period);
protected:
struct SharedData {
std::mutex lock_;
std::condition_variable cv_;
size_t capacity_;
std::queue<StoredArray> channel_;
};
std::shared_ptr<SharedData> data_;
};
} // namespace effil

View File

@ -1,9 +1,7 @@
#pragma once #pragma once
#include "spin-mutex.h" #include "spin-mutex.h"
#include <sol.hpp> #include <sol.hpp>
#include <mutex> #include <mutex>
#include <map> #include <map>
#include <set> #include <set>

17
src/cpp/lua-helpers.cpp Normal file
View File

@ -0,0 +1,17 @@
#include "lua-helpers.h"
namespace effil {
std::chrono::milliseconds fromLuaTime(int duration, const sol::optional<std::string>& period) {
using namespace std::chrono;
REQUIRE(duration >= 0) << "Invalid duration interval: " << duration;
std::string metric = period ? period.value() : "s";
if (metric == "ms") return milliseconds(duration);
else if (metric == "s") return seconds(duration);
else if (metric == "m") return minutes(duration);
else throw sol::error("invalid time identification: " + metric);
}
} // namespace effil

View File

@ -27,6 +27,8 @@ inline sol::function loadString(const sol::state_view& lua, const std::string& s
return loader(str); return loader(str);
} }
std::chrono::milliseconds fromLuaTime(int duration, const sol::optional<std::string>& period);
typedef std::vector<effil::StoredObject> StoredArray; typedef std::vector<effil::StoredObject> StoredArray;
} // namespace effil } // namespace effil

View File

@ -1,6 +1,7 @@
#include "threading.h" #include "threading.h"
#include "shared-table.h" #include "shared-table.h"
#include "garbage-collector.h" #include "garbage-collector.h"
#include "channel.h"
#include <lua.hpp> #include <lua.hpp>
@ -22,6 +23,10 @@ sol::object createTable(sol::this_state lua) {
return sol::make_object(lua, GC::instance().create<SharedTable>()); return sol::make_object(lua, GC::instance().create<SharedTable>());
} }
sol::object createChannel(sol::optional<int> capacity, sol::this_state lua) {
return sol::make_object(lua, GC::instance().create<Channel>(capacity));
}
SharedTable globalTable = GC::instance().create<SharedTable>(); SharedTable globalTable = GC::instance().create<SharedTable>();
} // namespace } // namespace
@ -30,6 +35,7 @@ extern "C" int luaopen_libeffil(lua_State* L) {
sol::state_view lua(L); sol::state_view lua(L);
Thread::getUserType(lua); Thread::getUserType(lua);
SharedTable::getUserType(lua); SharedTable::getUserType(lua);
Channel::getUserType(lua);
sol::table publicApi = lua.create_table_with( sol::table publicApi = lua.create_table_with(
"thread", createThread, "thread", createThread,
"thread_id", threadId, "thread_id", threadId,
@ -43,7 +49,8 @@ extern "C" int luaopen_libeffil(lua_State* L) {
"getmetatable", SharedTable::luaGetMetatable, "getmetatable", SharedTable::luaGetMetatable,
"G", sol::make_object(lua, globalTable), "G", sol::make_object(lua, globalTable),
"getmetatable", SharedTable::luaGetMetatable, "getmetatable", SharedTable::luaGetMetatable,
"gc", GC::getLuaApi(lua) "gc", GC::getLuaApi(lua),
"channel", createChannel
); );
sol::stack::push(lua, publicApi); sol::stack::push(lua, publicApi);
return 1; return 1;

View File

@ -1,4 +1,5 @@
#include "stored-object.h" #include "stored-object.h"
#include "channel.h"
#include "threading.h" #include "threading.h"
#include "shared-table.h" #include "shared-table.h"
@ -73,24 +74,25 @@ private:
std::string function_; std::string function_;
}; };
class TableHolder : public BaseHolder { template<typename T>
class GCObjectHolder : public BaseHolder {
public: public:
template <typename SolType> template <typename SolType>
TableHolder(const SolType& luaObject) { GCObjectHolder(const SolType& luaObject) {
assert(luaObject.template is<SharedTable>()); assert(luaObject.template is<T>());
handle_ = luaObject.template as<SharedTable>().handle(); handle_ = luaObject.template as<T>().handle();
assert(GC::instance().has(handle_)); assert(GC::instance().has(handle_));
} }
TableHolder(GCObjectHandle handle) GCObjectHolder(GCObjectHandle handle)
: handle_(handle) {} : handle_(handle) {}
bool rawCompare(const BaseHolder* other) const final { bool rawCompare(const BaseHolder* other) const final {
return handle_ < static_cast<const TableHolder*>(other)->handle_; return handle_ < static_cast<const GCObjectHolder<T>*>(other)->handle_;
} }
sol::object unpack(sol::this_state state) const final { sol::object unpack(sol::this_state state) const final {
return sol::make_object(state, GC::instance().get<SharedTable>(handle_)); return sol::make_object(state, GC::instance().get<T>(handle_));
} }
GCObjectHandle gcHandle() const override { return handle_; } GCObjectHandle gcHandle() const override { return handle_; }
@ -118,9 +120,9 @@ StoredObject makeStoredObject(sol::object luaObject, SolTableToShared& visited)
SharedTable table = GC::instance().create<SharedTable>(); SharedTable table = GC::instance().create<SharedTable>();
visited.emplace_back(std::make_pair(luaTable, table.handle())); visited.emplace_back(std::make_pair(luaTable, table.handle()));
dumpTable(&table, luaTable, visited); dumpTable(&table, luaTable, visited);
return createStoredObject(table.handle()); return std::make_unique<GCObjectHolder<SharedTable>>(table.handle());
} else { } else {
return createStoredObject(st->second); return std::make_unique<GCObjectHolder<SharedTable>>(st->second);
} }
} else { } else {
return createStoredObject(luaObject); return createStoredObject(luaObject);
@ -149,7 +151,9 @@ StoredObject fromSolObject(const SolObject& luaObject) {
return std::make_unique<PrimitiveHolder<void*>>(luaObject); return std::make_unique<PrimitiveHolder<void*>>(luaObject);
case sol::type::userdata: case sol::type::userdata:
if (luaObject.template is<SharedTable>()) if (luaObject.template is<SharedTable>())
return std::make_unique<TableHolder>(luaObject); return std::make_unique<GCObjectHolder<SharedTable>>(luaObject);
else if (luaObject.template is<Channel>())
return std::make_unique<GCObjectHolder<Channel>>(luaObject);
else if (luaObject.template is<std::shared_ptr<Thread>>()) else if (luaObject.template is<std::shared_ptr<Thread>>())
return std::make_unique<PrimitiveHolder<std::shared_ptr<Thread>>>(luaObject); return std::make_unique<PrimitiveHolder<std::shared_ptr<Thread>>>(luaObject);
else else
@ -167,7 +171,7 @@ StoredObject fromSolObject(const SolObject& luaObject) {
// SolTableToShared is used to prevent from infinity recursion // SolTableToShared is used to prevent from infinity recursion
// in recursive tables // in recursive tables
dumpTable(&table, luaTable, visited); dumpTable(&table, luaTable, visited);
return std::make_unique<TableHolder>(table.handle()); return std::make_unique<GCObjectHolder<SharedTable>>(table.handle());
} }
default: default:
throw Exception() << "Unable to store object of that type: " << (int)luaObject.get_type() << "\n"; throw Exception() << "Unable to store object of that type: " << (int)luaObject.get_type() << "\n";
@ -193,8 +197,6 @@ StoredObject createStoredObject(const sol::object& object) { return fromSolObjec
StoredObject createStoredObject(const sol::stack_object& object) { return fromSolObject(object); } StoredObject createStoredObject(const sol::stack_object& object) { return fromSolObject(object); }
StoredObject createStoredObject(GCObjectHandle handle) { return std::make_unique<TableHolder>(handle); }
template <typename DataType> template <typename DataType>
sol::optional<DataType> getPrimitiveHolderData(const StoredObject& sobj) { sol::optional<DataType> getPrimitiveHolderData(const StoredObject& sobj) {
auto ptr = dynamic_cast<PrimitiveHolder<DataType>*>(sobj.get()); auto ptr = dynamic_cast<PrimitiveHolder<DataType>*>(sobj.get());

View File

@ -37,7 +37,6 @@ StoredObject createStoredObject(bool);
StoredObject createStoredObject(double); StoredObject createStoredObject(double);
StoredObject createStoredObject(const std::string&); StoredObject createStoredObject(const std::string&);
StoredObject createStoredObject(const char*); StoredObject createStoredObject(const char*);
StoredObject createStoredObject(GCObjectHandle);
StoredObject createStoredObject(const sol::object&); StoredObject createStoredObject(const sol::object&);
StoredObject createStoredObject(const sol::stack_object&); StoredObject createStoredObject(const sol::stack_object&);

View File

@ -162,18 +162,6 @@ void runThread(std::shared_ptr<ThreadHandle> handle,
} }
} }
std::chrono::milliseconds fromLuaTime(int duration, const sol::optional<std::string>& period) {
using namespace std::chrono;
REQUIRE(duration >= 0) << "Invalid duration interval: " << duration;
std::string metric = period ? period.value() : "s";
if (metric == "ms") return milliseconds(duration);
else if (metric == "s") return seconds(duration);
else if (metric == "m") return minutes(duration);
else throw sol::error("invalid time identification: " + metric);
}
} // namespace } // namespace

View File

@ -23,7 +23,8 @@ local api = {
setmetatable = capi.setmetatable, setmetatable = capi.setmetatable,
getmetatable = capi.getmetatable, getmetatable = capi.getmetatable,
G = capi.G, G = capi.G,
gc = capi.gc gc = capi.gc,
channel = capi.channel
} }
local function run_thread(config, f, ...) local function run_thread(config, f, ...)

131
tests/lua/channel.lua Normal file
View File

@ -0,0 +1,131 @@
TestChannels = {tearDown = tearDown}
function TestChannels:testCapacityUsage()
local chan = effil.channel(2)
test.assertTrue(chan:push(14))
test.assertTrue(chan:push(88))
test.assertFalse(chan:push(1488))
test.assertEquals(chan:pop(), 14)
test.assertEquals(chan:pop(), 88)
test.assertIsNil(chan:pop(0))
test.assertTrue(chan:push(14, 88), true)
local ret1, ret2 = chan:pop()
test.assertEquals(ret1, 14)
test.assertEquals(ret2, 88)
end
function TestChannels:testRecursiveChannels()
local chan1 = effil.channel()
local chan2 = effil.channel()
local msg1, msg2 = "first channel", "second channel"
test.assertTrue(chan1:push(msg1, chan2))
test.assertTrue(chan2:push(msg2, chan1))
local ret1 = { chan1:pop() }
test.assertEquals(ret1[1], msg1)
test.assertEquals(type(ret1[2]), "userdata")
local ret2 = { ret1[2]:pop() }
test.assertEquals(ret2[1], msg2)
test.assertEquals(type(ret2[2]), "userdata")
end
function TestChannels:testWithThread()
local chan = effil.channel()
local thread = effil.thread(function(chan)
chan:push("message1")
chan:push("message2")
chan:push("message3")
chan:push("message4")
end
)(chan)
local start_time = os.time()
test.assertEquals(chan:pop(), "message1")
thread:wait()
test.assertEquals(chan:pop(0), "message2")
test.assertEquals(chan:pop(1), "message3")
test.assertEquals(chan:pop(1, 'm'), "message4")
test.assertTrue(os.time() < start_time + 1)
end
function TestChannels:testWithSharedTable()
local chan = effil.channel()
local table = effil.table()
local test_value = "i'm value"
table.test_key = test_value
chan:push(table)
test.assertEquals(chan:pop().test_key, test_value)
table.channel = chan
table.channel:push(test_value)
test.assertEquals(table.channel:pop(), test_value)
end
if WITH_EXTRA_CHECKS then
function TestChannels:testStressLoadWithMultipleThreads()
local exchange_channel, result_channel = effil.channel(), effil.channel()
local threads_number = 1000
for i = 1, threads_number do
effil.thread(function(exchange_channel, result_channel, indx)
if indx % 2 == 0 then
for i = 1, 10000 do
exchange_channel:push(indx .. "_".. i)
end
else
repeat
local ret = exchange_channel:pop(10)
if ret then
result_channel:push(ret)
end
until ret == nil
end
end
)(exchange_channel, result_channel, i)
end
local data = {}
for i = 1, (threads_number / 2) * 10000 do
local ret = result_channel:pop(10)
test.assertNotIsNil(ret)
test.assertIsString(ret)
test.assertIsNil(data[ret])
data[ret] = true
end
for thr_id = 2, threads_number, 2 do
for iter = 1, 10000 do
test.assertTrue(data[thr_id .. "_".. iter])
end
end
end
function TestChannels:testTimedRead()
local chan = effil.channel()
local delayedWriter = function(channel, delay)
require("effil").sleep(delay)
channel:push("hello!")
end
effil.thread(delayedWriter)(chan, 70)
local function check_time(real_time, use_time, metric, result)
local start_time = os.time()
test.assertEquals(chan:pop(use_time, metric), result)
test.assertAlmostEquals(os.time(), start_time + real_time, 1)
end
check_time(2, 2, nil, nil) -- second by default
check_time(2, 2, 's', nil)
check_time(60, 1, 'm', nil)
local start_time = os.time()
test.assertEquals(chan:pop(10), "hello!")
test.assertTrue(os.time() < start_time + 10)
end
end -- WITH_EXTRA_CHECKS

View File

@ -9,14 +9,17 @@ print("---------------")
do do
-- Hack input arguments to make tests verbose by default -- Hack input arguments to make tests verbose by default
local found = false local make_verbose = true
for _, v in ipairs(arg) do for i, v in ipairs(arg) do
if v == '-o' or v == '--output' then if v == '-o' or v == '--output' then
found = true make_verbose = false
break elseif v == "--extra-checks" then
table.remove(arg, i)
WITH_EXTRA_CHECKS = true
print "# RUN TESTS WITH EXTRA CHECKS"
end end
end end
if not found then if make_verbose then
table.insert(arg, '-o') table.insert(arg, '-o')
table.insert(arg, 'TAP') table.insert(arg, 'TAP')
end end
@ -32,6 +35,7 @@ require 'test_utils'
require 'thread' require 'thread'
require 'shared_table' require 'shared_table'
require 'gc' require 'gc'
require 'channel'
-- Hack tests functions to print when test starts -- Hack tests functions to print when test starts
for suite_name, suite in pairs(_G) do for suite_name, suite in pairs(_G) do

View File

@ -284,6 +284,8 @@ function TestThisThread:testThisThreadFunctions()
effil.yield() -- just call it effil.yield() -- just call it
end end
if WITH_EXTRA_CHECKS then
function TestThisThread:testTime() function TestThisThread:testTime()
local function check_time(real_time, use_time, metric) local function check_time(real_time, use_time, metric)
local start_time = os.time() local start_time = os.time()
@ -295,3 +297,5 @@ function TestThisThread:testTime()
check_time(4, 4000, 'ms') check_time(4, 4000, 'ms')
check_time(60, 1, 'm') check_time(60, 1, 'm')
end end
end -- WITH_EXTRA_CHECKS