diff --git a/src/cpp/shared-table.cpp b/src/cpp/shared-table.cpp index 16622c5..914ca8b 100644 --- a/src/cpp/shared-table.cpp +++ b/src/cpp/shared-table.cpp @@ -22,7 +22,9 @@ sol::object SharedTable::getUserType(sol::state_view &lua) { "new", sol::no_constructor, sol::meta_function::new_index, &SharedTable::luaSet, sol::meta_function::index, &SharedTable::luaGet, - sol::meta_function::length, &SharedTable::size + sol::meta_function::length, &SharedTable::length, + "__pairs", &SharedTable::pairs, + "__ipairs", &SharedTable::ipairs ); sol::stack::push(lua, type); return sol::stack::pop(lua); @@ -37,6 +39,16 @@ void SharedTable::set(StoredObject&& key, StoredObject&& value) { data_->entries[std::move(key)] = std::move(value); } +sol::object SharedTable::get(const StoredObject& key, const sol::this_state& state) const { + std::lock_guard g(data_->lock); + auto val = data_->entries.find(key); + if (val == data_->entries.end()) { + return sol::nil; + } else { + return val->second->unpack(state); + } +} + void SharedTable::luaSet(const sol::stack_object& luaKey, const sol::stack_object& luaValue) { REQUIRE(luaKey.valid()) << "Indexing by nil"; @@ -59,15 +71,8 @@ void SharedTable::luaSet(const sol::stack_object& luaKey, const sol::stack_objec sol::object SharedTable::luaGet(const sol::stack_object& luaKey, const sol::this_state& state) const { REQUIRE(luaKey.valid()) << "Indexing by nil"; - StoredObject key = createStoredObject(luaKey); - std::lock_guard g(data_->lock); - auto val = data_->entries.find(key); - if (val == data_->entries.end()) { - return sol::nil; - } else { - return val->second->unpack(state); - } + return get(key, state); } size_t SharedTable::size() const { @@ -75,4 +80,61 @@ size_t SharedTable::size() const { return data_->entries.size(); } +size_t SharedTable::length() const { + std::lock_guard g(data_->lock); + + DataEntries::const_iterator iter; + size_t l = 0u; + while((iter = data_->entries.find(createStoredObject(static_cast(l + 1)))) != data_->entries.end()) { + l++; + }; + return l; +} + +SharedTable::PairsIterator SharedTable::getNext(const sol::object& key, sol::this_state lua) +{ + std::lock_guard g(data_->lock); + if (key) + { + auto obj = createStoredObject(key); + auto upper = data_->entries.upper_bound(obj); + if (upper != data_->entries.end()) + return std::tuple(upper->first->unpack(lua), upper->second->unpack(lua)); + } + else + { + if (!data_->entries.empty()) + { + const auto& begin = data_->entries.begin(); + return std::tuple(begin->first->unpack(lua), begin->second->unpack(lua)); + } + } + return std::tuple(sol::nil, sol::nil); +} + +SharedTable::PairsIterator SharedTable::pairs(sol::this_state lua) const { + auto next = [](sol::this_state lua, SharedTable table, sol::stack_object key){ return table.getNext(key, lua); }; + return std::tuple( + sol::make_object(lua, std::function(next)).as(), + sol::make_object(lua, *this) + ); +} + +std::tuple ipairsNext(sol::this_state lua, SharedTable table, sol::optional key) +{ + unsigned long index = key ? key.value() + 1 : 1ul; + auto objKey = createStoredObject(static_cast(index)); + sol::object value = table.get(objKey, lua); + if (!value.valid()) + return std::tuple(sol::nil, sol::nil); + return std::tuple(objKey->unpack(lua), value); +} + +std::tuple SharedTable::ipairs(sol::this_state lua) const { + return std::tuple( + sol::make_object(lua, ipairsNext).as(), + sol::make_object(lua, *this) + ); +} + } // effil diff --git a/src/cpp/shared-table.h b/src/cpp/shared-table.h index bdd0bac..5f4458e 100644 --- a/src/cpp/shared-table.h +++ b/src/cpp/shared-table.h @@ -12,6 +12,10 @@ namespace effil { class SharedTable : public GCObject { +private: + typedef std::tuple PairsIterator; + typedef std::map DataEntries; + public: SharedTable(); SharedTable(SharedTable&&) = default; @@ -20,20 +24,24 @@ public: static sol::object getUserType(sol::state_view &lua); void set(StoredObject&&, StoredObject&&); + sol::object get(const StoredObject& key, const sol::this_state& state) const; + PairsIterator getNext(const sol::object& key, sol::this_state lua); // These functions could be invoked from lua scripts void luaSet(const sol::stack_object& luaKey, const sol::stack_object& luaValue); - sol::object luaGet(const sol::stack_object& luaKey, const sol::this_state& state) const; + sol::object luaGet(const sol::stack_object& key, const sol::this_state& state) const; size_t size() const; + size_t length() const; + PairsIterator pairs(sol::this_state) const; + PairsIterator ipairs(sol::this_state) const; private: - typedef std::unique_ptr StoredObject; + struct SharedData { SpinMutex lock; - std::unordered_map entries; + DataEntries entries; }; -private: std::shared_ptr data_; }; diff --git a/src/cpp/spin-mutex.h b/src/cpp/spin-mutex.h index 0f5d2a4..265707a 100644 --- a/src/cpp/spin-mutex.h +++ b/src/cpp/spin-mutex.h @@ -2,6 +2,7 @@ #include #include +#include namespace effil { @@ -15,6 +16,7 @@ public: void unlock() noexcept { lock_.clear(std::memory_order_release); + } private: diff --git a/src/cpp/stored-object.cpp b/src/cpp/stored-object.cpp index afae288..ffbb014 100644 --- a/src/cpp/stored-object.cpp +++ b/src/cpp/stored-object.cpp @@ -26,11 +26,7 @@ public: : data_(init) {} bool rawCompare(const BaseHolder* other) const noexcept final { - return static_cast*>(other)->data_ == data_; - } - - std::size_t hash() const noexcept final { - return std::hash()(data_); + return static_cast*>(other)->data_ < data_; } sol::object unpack(sol::this_state state) const final { @@ -52,11 +48,7 @@ public: } bool rawCompare(const BaseHolder* other) const noexcept final { - return static_cast(other)->function_ == function_; - } - - std::size_t hash() const noexcept final { - return std::hash()(function_); + return static_cast(other)->function_ < function_; } sol::object unpack(sol::this_state state) const final { @@ -86,11 +78,7 @@ public: : handle_(handle) {} bool rawCompare(const BaseHolder *other) const final { - return static_cast(other)->handle_ == handle_; - } - - std::size_t hash() const final { - return std::hash()(handle_); + return static_cast(other)->handle_ < handle_; } sol::object unpack(sol::this_state state) const final { diff --git a/src/cpp/stored-object.h b/src/cpp/stored-object.h index 54ed1ed..ae7e4b7 100644 --- a/src/cpp/stored-object.h +++ b/src/cpp/stored-object.h @@ -12,14 +12,14 @@ public: virtual ~BaseHolder() = default; bool compare(const BaseHolder* other) const { - return typeid(*this) == typeid(*other) && rawCompare(other); + if (typeid(*this) == typeid(*other)) + return rawCompare(other); + return typeid(*this).before(typeid(*other)); } + virtual bool rawCompare(const BaseHolder* other) const = 0; virtual const std::type_info& type() { return typeid(*this); } - - virtual std::size_t hash() const = 0; virtual sol::object unpack(sol::this_state state) const = 0; - virtual GCObjectHandle gcHandle() const { return GCNull; } private: @@ -29,13 +29,7 @@ private: typedef std::unique_ptr StoredObject; -struct StoredObjectHash { - size_t operator()(const StoredObject& o) const { - return o->hash(); - } -}; - -struct StoredObjectEqual { +struct StoredObjectLess { bool operator()(const StoredObject& lhs, const StoredObject& rhs) const { return lhs->compare(rhs.get()); } diff --git a/tests/lua/smoke_test.lua b/tests/lua/smoke_test.lua index 7c1172f..90e6926 100644 --- a/tests/lua/smoke_test.lua +++ b/tests/lua/smoke_test.lua @@ -218,3 +218,49 @@ function TestSmoke:testCheckThreadReturns() test.assertEquals(returns[5](11, 89), 100) end +function TestSmoke:testCheckPairsInterating() + local effil = require('libeffil') + local share = effil.share() + local data = { 0, 0, 0, ["key1"] = 0, ["key2"] = 0, ["key3"] = 0 } + + for k, _ in pairs(data) do + share[k] = k .. "-value" + end + + for k,v in pairs(share) do + test.assertEquals(data[k], 0) + data[k] = 1 + test.assertEquals(v, k .. "-value") + end + + for k,v in pairs(data) do + log("Check: " .. k) + test.assertEquals(v, 1) + end + + for k,v in ipairs(share) do + test.assertEquals(data[k], 1) + data[k] = 2 + test.assertEquals(v, k .. "-value") + end + + for k,v in ipairs(data) do + log("Check: " .. k) + test.assertEquals(v, 2) + end +end + +function TestSmoke:testCheckLengthOperator() + local effil = require('libeffil') + local share = effil.share() + share[1] = 10 + share[2] = 20 + share[3] = 30 + share[4] = 40 + log "Check values" + test.assertEquals(#share, 4) + share[3] = nil + test.assertEquals(#share, 2) + share[1] = nil + test.assertEquals(#share, 0) +end