parent
64628c1757
commit
5057116136
@ -4,5 +4,5 @@ set -e
|
||||
for build_type in debug release; do
|
||||
mkdir -p $build_type
|
||||
(cd $build_type && cmake -DCMAKE_BUILD_TYPE=$build_type $@ .. && make -j4 install)
|
||||
(cd $build_type && ./tests && lua run_tests.lua)
|
||||
(cd $build_type && ./tests && lua run_tests.lua --extra-checks)
|
||||
done
|
||||
|
||||
68
src/cpp/channel.cpp
Normal file
68
src/cpp/channel.cpp
Normal file
@ -0,0 +1,68 @@
|
||||
#include "channel.h"
|
||||
|
||||
#include "sol.hpp"
|
||||
|
||||
namespace effil {
|
||||
|
||||
void Channel::getUserType(sol::state_view& lua) {
|
||||
sol::usertype<Channel> type("new", sol::no_constructor,
|
||||
"push", &Channel::push,
|
||||
"pop", &Channel::pop
|
||||
);
|
||||
sol::stack::push(lua, type);
|
||||
sol::stack::pop<sol::object>(lua);
|
||||
}
|
||||
|
||||
Channel::Channel(sol::optional<int> capacity) : data_(std::make_shared<SharedData>()){
|
||||
if (capacity) {
|
||||
REQUIRE(capacity.value() >= 0) << "Invalid capacity value = " << capacity.value();
|
||||
data_->capacity_ = static_cast<size_t>(capacity.value());
|
||||
}
|
||||
else {
|
||||
data_->capacity_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool Channel::push(const sol::variadic_args& args) {
|
||||
if (!args.leftover_count())
|
||||
return false;
|
||||
|
||||
std::unique_lock<std::mutex> lock(data_->lock_);
|
||||
if (data_->capacity_ && data_->channel_.size() >= data_->capacity_)
|
||||
return false;
|
||||
effil::StoredArray array;
|
||||
for (const auto& arg : args) {
|
||||
auto obj = createStoredObject(arg.get<sol::object>());
|
||||
if (obj->gcHandle())
|
||||
refs_->insert(obj->gcHandle());
|
||||
array.emplace_back(obj);
|
||||
}
|
||||
if (data_->channel_.empty())
|
||||
data_->cv_.notify_one();
|
||||
data_->channel_.emplace(array);
|
||||
return true;
|
||||
}
|
||||
|
||||
StoredArray Channel::pop(const sol::optional<int>& duration,
|
||||
const sol::optional<std::string>& period) {
|
||||
std::unique_lock<std::mutex> lock(data_->lock_);
|
||||
while (data_->channel_.empty()) {
|
||||
if (duration) {
|
||||
if (data_->cv_.wait_for(lock, fromLuaTime(duration.value(), period)) == std::cv_status::timeout)
|
||||
return StoredArray();
|
||||
}
|
||||
else { // No time limit
|
||||
data_->cv_.wait(lock);
|
||||
}
|
||||
}
|
||||
|
||||
auto ret = data_->channel_.front();
|
||||
for (const auto& obj: ret) {
|
||||
if (obj->gcHandle())
|
||||
refs_->erase(obj->gcHandle());
|
||||
}
|
||||
data_->channel_.pop();
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace effil
|
||||
30
src/cpp/channel.h
Normal file
30
src/cpp/channel.h
Normal file
@ -0,0 +1,30 @@
|
||||
#pragma once
|
||||
|
||||
#include "notifier.h"
|
||||
#include "stored-object.h"
|
||||
#include "lua-helpers.h"
|
||||
#include "queue"
|
||||
|
||||
namespace effil {
|
||||
|
||||
class Channel : public GCObject {
|
||||
public:
|
||||
Channel(sol::optional<int> capacity);
|
||||
static void getUserType(sol::state_view& lua);
|
||||
|
||||
bool push(const sol::variadic_args& args);
|
||||
StoredArray pop(const sol::optional<int>& duration,
|
||||
const sol::optional<std::string>& period);
|
||||
|
||||
protected:
|
||||
struct SharedData {
|
||||
std::mutex lock_;
|
||||
std::condition_variable cv_;
|
||||
size_t capacity_;
|
||||
std::queue<StoredArray> channel_;
|
||||
};
|
||||
|
||||
std::shared_ptr<SharedData> data_;
|
||||
};
|
||||
|
||||
} // namespace effil
|
||||
@ -1,9 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "spin-mutex.h"
|
||||
|
||||
#include <sol.hpp>
|
||||
|
||||
#include <mutex>
|
||||
#include <map>
|
||||
#include <set>
|
||||
@ -87,4 +85,4 @@ private:
|
||||
GC(const GC&) = delete;
|
||||
};
|
||||
|
||||
} // effil
|
||||
} // effil
|
||||
|
||||
17
src/cpp/lua-helpers.cpp
Normal file
17
src/cpp/lua-helpers.cpp
Normal file
@ -0,0 +1,17 @@
|
||||
#include "lua-helpers.h"
|
||||
|
||||
namespace effil {
|
||||
|
||||
std::chrono::milliseconds fromLuaTime(int duration, const sol::optional<std::string>& period) {
|
||||
using namespace std::chrono;
|
||||
|
||||
REQUIRE(duration >= 0) << "Invalid duration interval: " << duration;
|
||||
|
||||
std::string metric = period ? period.value() : "s";
|
||||
if (metric == "ms") return milliseconds(duration);
|
||||
else if (metric == "s") return seconds(duration);
|
||||
else if (metric == "m") return minutes(duration);
|
||||
else throw sol::error("invalid time identification: " + metric);
|
||||
}
|
||||
|
||||
} // namespace effil
|
||||
@ -27,6 +27,8 @@ inline sol::function loadString(const sol::state_view& lua, const std::string& s
|
||||
return loader(str);
|
||||
}
|
||||
|
||||
std::chrono::milliseconds fromLuaTime(int duration, const sol::optional<std::string>& period);
|
||||
|
||||
typedef std::vector<effil::StoredObject> StoredArray;
|
||||
|
||||
} // namespace effil
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#include "threading.h"
|
||||
#include "shared-table.h"
|
||||
#include "garbage-collector.h"
|
||||
#include "channel.h"
|
||||
|
||||
#include <lua.hpp>
|
||||
|
||||
@ -22,6 +23,10 @@ sol::object createTable(sol::this_state lua) {
|
||||
return sol::make_object(lua, GC::instance().create<SharedTable>());
|
||||
}
|
||||
|
||||
sol::object createChannel(sol::optional<int> capacity, sol::this_state lua) {
|
||||
return sol::make_object(lua, GC::instance().create<Channel>(capacity));
|
||||
}
|
||||
|
||||
SharedTable globalTable = GC::instance().create<SharedTable>();
|
||||
|
||||
} // namespace
|
||||
@ -30,6 +35,7 @@ extern "C" int luaopen_libeffil(lua_State* L) {
|
||||
sol::state_view lua(L);
|
||||
Thread::getUserType(lua);
|
||||
SharedTable::getUserType(lua);
|
||||
Channel::getUserType(lua);
|
||||
sol::table publicApi = lua.create_table_with(
|
||||
"thread", createThread,
|
||||
"thread_id", threadId,
|
||||
@ -43,7 +49,8 @@ extern "C" int luaopen_libeffil(lua_State* L) {
|
||||
"getmetatable", SharedTable::luaGetMetatable,
|
||||
"G", sol::make_object(lua, globalTable),
|
||||
"getmetatable", SharedTable::luaGetMetatable,
|
||||
"gc", GC::getLuaApi(lua)
|
||||
"gc", GC::getLuaApi(lua),
|
||||
"channel", createChannel
|
||||
);
|
||||
sol::stack::push(lua, publicApi);
|
||||
return 1;
|
||||
|
||||
@ -46,4 +46,4 @@ private:
|
||||
Notifier(Notifier& ) = delete;
|
||||
};
|
||||
|
||||
} // namespace effil
|
||||
} // namespace effil
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
#include "stored-object.h"
|
||||
#include "channel.h"
|
||||
|
||||
#include "threading.h"
|
||||
#include "shared-table.h"
|
||||
@ -73,24 +74,25 @@ private:
|
||||
std::string function_;
|
||||
};
|
||||
|
||||
class TableHolder : public BaseHolder {
|
||||
template<typename T>
|
||||
class GCObjectHolder : public BaseHolder {
|
||||
public:
|
||||
template <typename SolType>
|
||||
TableHolder(const SolType& luaObject) {
|
||||
assert(luaObject.template is<SharedTable>());
|
||||
handle_ = luaObject.template as<SharedTable>().handle();
|
||||
GCObjectHolder(const SolType& luaObject) {
|
||||
assert(luaObject.template is<T>());
|
||||
handle_ = luaObject.template as<T>().handle();
|
||||
assert(GC::instance().has(handle_));
|
||||
}
|
||||
|
||||
TableHolder(GCObjectHandle handle)
|
||||
GCObjectHolder(GCObjectHandle handle)
|
||||
: handle_(handle) {}
|
||||
|
||||
bool rawCompare(const BaseHolder* other) const final {
|
||||
return handle_ < static_cast<const TableHolder*>(other)->handle_;
|
||||
return handle_ < static_cast<const GCObjectHolder<T>*>(other)->handle_;
|
||||
}
|
||||
|
||||
sol::object unpack(sol::this_state state) const final {
|
||||
return sol::make_object(state, GC::instance().get<SharedTable>(handle_));
|
||||
return sol::make_object(state, GC::instance().get<T>(handle_));
|
||||
}
|
||||
|
||||
GCObjectHandle gcHandle() const override { return handle_; }
|
||||
@ -118,9 +120,9 @@ StoredObject makeStoredObject(sol::object luaObject, SolTableToShared& visited)
|
||||
SharedTable table = GC::instance().create<SharedTable>();
|
||||
visited.emplace_back(std::make_pair(luaTable, table.handle()));
|
||||
dumpTable(&table, luaTable, visited);
|
||||
return createStoredObject(table.handle());
|
||||
return std::make_unique<GCObjectHolder<SharedTable>>(table.handle());
|
||||
} else {
|
||||
return createStoredObject(st->second);
|
||||
return std::make_unique<GCObjectHolder<SharedTable>>(st->second);
|
||||
}
|
||||
} else {
|
||||
return createStoredObject(luaObject);
|
||||
@ -149,7 +151,9 @@ StoredObject fromSolObject(const SolObject& luaObject) {
|
||||
return std::make_unique<PrimitiveHolder<void*>>(luaObject);
|
||||
case sol::type::userdata:
|
||||
if (luaObject.template is<SharedTable>())
|
||||
return std::make_unique<TableHolder>(luaObject);
|
||||
return std::make_unique<GCObjectHolder<SharedTable>>(luaObject);
|
||||
else if (luaObject.template is<Channel>())
|
||||
return std::make_unique<GCObjectHolder<Channel>>(luaObject);
|
||||
else if (luaObject.template is<std::shared_ptr<Thread>>())
|
||||
return std::make_unique<PrimitiveHolder<std::shared_ptr<Thread>>>(luaObject);
|
||||
else
|
||||
@ -167,7 +171,7 @@ StoredObject fromSolObject(const SolObject& luaObject) {
|
||||
// SolTableToShared is used to prevent from infinity recursion
|
||||
// in recursive tables
|
||||
dumpTable(&table, luaTable, visited);
|
||||
return std::make_unique<TableHolder>(table.handle());
|
||||
return std::make_unique<GCObjectHolder<SharedTable>>(table.handle());
|
||||
}
|
||||
default:
|
||||
throw Exception() << "Unable to store object of that type: " << (int)luaObject.get_type() << "\n";
|
||||
@ -193,8 +197,6 @@ StoredObject createStoredObject(const sol::object& object) { return fromSolObjec
|
||||
|
||||
StoredObject createStoredObject(const sol::stack_object& object) { return fromSolObject(object); }
|
||||
|
||||
StoredObject createStoredObject(GCObjectHandle handle) { return std::make_unique<TableHolder>(handle); }
|
||||
|
||||
template <typename DataType>
|
||||
sol::optional<DataType> getPrimitiveHolderData(const StoredObject& sobj) {
|
||||
auto ptr = dynamic_cast<PrimitiveHolder<DataType>*>(sobj.get());
|
||||
|
||||
@ -37,7 +37,6 @@ StoredObject createStoredObject(bool);
|
||||
StoredObject createStoredObject(double);
|
||||
StoredObject createStoredObject(const std::string&);
|
||||
StoredObject createStoredObject(const char*);
|
||||
StoredObject createStoredObject(GCObjectHandle);
|
||||
StoredObject createStoredObject(const sol::object&);
|
||||
StoredObject createStoredObject(const sol::stack_object&);
|
||||
|
||||
|
||||
@ -162,18 +162,6 @@ void runThread(std::shared_ptr<ThreadHandle> handle,
|
||||
}
|
||||
}
|
||||
|
||||
std::chrono::milliseconds fromLuaTime(int duration, const sol::optional<std::string>& period) {
|
||||
using namespace std::chrono;
|
||||
|
||||
REQUIRE(duration >= 0) << "Invalid duration interval: " << duration;
|
||||
|
||||
std::string metric = period ? period.value() : "s";
|
||||
if (metric == "ms") return milliseconds(duration);
|
||||
else if (metric == "s") return seconds(duration);
|
||||
else if (metric == "m") return minutes(duration);
|
||||
else throw sol::error("invalid time identification: " + metric);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
|
||||
@ -23,7 +23,8 @@ local api = {
|
||||
setmetatable = capi.setmetatable,
|
||||
getmetatable = capi.getmetatable,
|
||||
G = capi.G,
|
||||
gc = capi.gc
|
||||
gc = capi.gc,
|
||||
channel = capi.channel
|
||||
}
|
||||
|
||||
local function run_thread(config, f, ...)
|
||||
|
||||
131
tests/lua/channel.lua
Normal file
131
tests/lua/channel.lua
Normal file
@ -0,0 +1,131 @@
|
||||
TestChannels = {tearDown = tearDown}
|
||||
|
||||
function TestChannels:testCapacityUsage()
|
||||
local chan = effil.channel(2)
|
||||
|
||||
test.assertTrue(chan:push(14))
|
||||
test.assertTrue(chan:push(88))
|
||||
test.assertFalse(chan:push(1488))
|
||||
|
||||
test.assertEquals(chan:pop(), 14)
|
||||
test.assertEquals(chan:pop(), 88)
|
||||
test.assertIsNil(chan:pop(0))
|
||||
|
||||
test.assertTrue(chan:push(14, 88), true)
|
||||
local ret1, ret2 = chan:pop()
|
||||
test.assertEquals(ret1, 14)
|
||||
test.assertEquals(ret2, 88)
|
||||
end
|
||||
|
||||
function TestChannels:testRecursiveChannels()
|
||||
local chan1 = effil.channel()
|
||||
local chan2 = effil.channel()
|
||||
local msg1, msg2 = "first channel", "second channel"
|
||||
test.assertTrue(chan1:push(msg1, chan2))
|
||||
test.assertTrue(chan2:push(msg2, chan1))
|
||||
|
||||
local ret1 = { chan1:pop() }
|
||||
test.assertEquals(ret1[1], msg1)
|
||||
test.assertEquals(type(ret1[2]), "userdata")
|
||||
local ret2 = { ret1[2]:pop() }
|
||||
test.assertEquals(ret2[1], msg2)
|
||||
test.assertEquals(type(ret2[2]), "userdata")
|
||||
end
|
||||
|
||||
function TestChannels:testWithThread()
|
||||
local chan = effil.channel()
|
||||
local thread = effil.thread(function(chan)
|
||||
chan:push("message1")
|
||||
chan:push("message2")
|
||||
chan:push("message3")
|
||||
chan:push("message4")
|
||||
end
|
||||
)(chan)
|
||||
|
||||
local start_time = os.time()
|
||||
test.assertEquals(chan:pop(), "message1")
|
||||
thread:wait()
|
||||
test.assertEquals(chan:pop(0), "message2")
|
||||
test.assertEquals(chan:pop(1), "message3")
|
||||
test.assertEquals(chan:pop(1, 'm'), "message4")
|
||||
test.assertTrue(os.time() < start_time + 1)
|
||||
end
|
||||
|
||||
function TestChannels:testWithSharedTable()
|
||||
local chan = effil.channel()
|
||||
local table = effil.table()
|
||||
|
||||
local test_value = "i'm value"
|
||||
table.test_key = test_value
|
||||
|
||||
chan:push(table)
|
||||
test.assertEquals(chan:pop().test_key, test_value)
|
||||
|
||||
table.channel = chan
|
||||
table.channel:push(test_value)
|
||||
test.assertEquals(table.channel:pop(), test_value)
|
||||
end
|
||||
|
||||
if WITH_EXTRA_CHECKS then
|
||||
|
||||
function TestChannels:testStressLoadWithMultipleThreads()
|
||||
local exchange_channel, result_channel = effil.channel(), effil.channel()
|
||||
|
||||
local threads_number = 1000
|
||||
for i = 1, threads_number do
|
||||
effil.thread(function(exchange_channel, result_channel, indx)
|
||||
if indx % 2 == 0 then
|
||||
for i = 1, 10000 do
|
||||
exchange_channel:push(indx .. "_".. i)
|
||||
end
|
||||
else
|
||||
repeat
|
||||
local ret = exchange_channel:pop(10)
|
||||
if ret then
|
||||
result_channel:push(ret)
|
||||
end
|
||||
until ret == nil
|
||||
end
|
||||
end
|
||||
)(exchange_channel, result_channel, i)
|
||||
end
|
||||
|
||||
local data = {}
|
||||
for i = 1, (threads_number / 2) * 10000 do
|
||||
local ret = result_channel:pop(10)
|
||||
test.assertNotIsNil(ret)
|
||||
test.assertIsString(ret)
|
||||
test.assertIsNil(data[ret])
|
||||
data[ret] = true
|
||||
end
|
||||
|
||||
for thr_id = 2, threads_number, 2 do
|
||||
for iter = 1, 10000 do
|
||||
test.assertTrue(data[thr_id .. "_".. iter])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function TestChannels:testTimedRead()
|
||||
local chan = effil.channel()
|
||||
local delayedWriter = function(channel, delay)
|
||||
require("effil").sleep(delay)
|
||||
channel:push("hello!")
|
||||
end
|
||||
effil.thread(delayedWriter)(chan, 70)
|
||||
|
||||
local function check_time(real_time, use_time, metric, result)
|
||||
local start_time = os.time()
|
||||
test.assertEquals(chan:pop(use_time, metric), result)
|
||||
test.assertAlmostEquals(os.time(), start_time + real_time, 1)
|
||||
end
|
||||
check_time(2, 2, nil, nil) -- second by default
|
||||
check_time(2, 2, 's', nil)
|
||||
check_time(60, 1, 'm', nil)
|
||||
|
||||
local start_time = os.time()
|
||||
test.assertEquals(chan:pop(10), "hello!")
|
||||
test.assertTrue(os.time() < start_time + 10)
|
||||
end
|
||||
|
||||
end -- WITH_EXTRA_CHECKS
|
||||
@ -9,14 +9,17 @@ print("---------------")
|
||||
|
||||
do
|
||||
-- Hack input arguments to make tests verbose by default
|
||||
local found = false
|
||||
for _, v in ipairs(arg) do
|
||||
local make_verbose = true
|
||||
for i, v in ipairs(arg) do
|
||||
if v == '-o' or v == '--output' then
|
||||
found = true
|
||||
break
|
||||
make_verbose = false
|
||||
elseif v == "--extra-checks" then
|
||||
table.remove(arg, i)
|
||||
WITH_EXTRA_CHECKS = true
|
||||
print "# RUN TESTS WITH EXTRA CHECKS"
|
||||
end
|
||||
end
|
||||
if not found then
|
||||
if make_verbose then
|
||||
table.insert(arg, '-o')
|
||||
table.insert(arg, 'TAP')
|
||||
end
|
||||
@ -32,6 +35,7 @@ require 'test_utils'
|
||||
require 'thread'
|
||||
require 'shared_table'
|
||||
require 'gc'
|
||||
require 'channel'
|
||||
|
||||
-- Hack tests functions to print when test starts
|
||||
for suite_name, suite in pairs(_G) do
|
||||
|
||||
@ -284,6 +284,8 @@ function TestThisThread:testThisThreadFunctions()
|
||||
effil.yield() -- just call it
|
||||
end
|
||||
|
||||
if WITH_EXTRA_CHECKS then
|
||||
|
||||
function TestThisThread:testTime()
|
||||
local function check_time(real_time, use_time, metric)
|
||||
local start_time = os.time()
|
||||
@ -295,3 +297,5 @@ function TestThisThread:testTime()
|
||||
check_time(4, 4000, 'ms')
|
||||
check_time(60, 1, 'm')
|
||||
end
|
||||
|
||||
end -- WITH_EXTRA_CHECKS
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user