#pragma once
#include <trantor/net/EventLoop.h>
#include <trantor/utils/Logger.h>
#include <atomic>
#include <deque>
#include <map>
#include <mutex>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <future>
#include <assert.h>
/// Default number of timing wheels (default for CacheMap's `wheelsNum`).
#define WHEELS_NUM 4
/// Default number of buckets in each wheel (default for CacheMap's
/// `bucketsNumPerWheel`).
#define BUCKET_NUM_PER_WHEEL 200
/// Default tick interval of the timing wheel (default for CacheMap's
/// `tickInterval`, which is handed to EventLoop::runEvery).
#define TICK_INTERVAL 1.0
namespace drogon
{
/**
 * @brief Runs a callback when the object is destroyed.
 *
 * Instances are meant to be held through std::shared_ptr: the callback fires
 * exactly once, when the last owner releases the entry.
 *
 * @note The callback is invoked from the destructor, so it must not throw.
 */
class CallbackEntry
{
  public:
    /// @param cb The callable to run on destruction; may be empty.
    CallbackEntry(std::function<void()> cb) : cb_(std::move(cb))
    {
    }

    // Copying would make the same callback run once per copy.
    CallbackEntry(const CallbackEntry &) = delete;
    CallbackEntry &operator=(const CallbackEntry &) = delete;

    ~CallbackEntry()
    {
        // Calling an empty std::function throws std::bad_function_call, which
        // would terminate the program when it escapes a destructor.
        if (cb_)
        {
            cb_();
        }
    }

  private:
    std::function<void()> cb_;
};
/// Shared owner of a CallbackEntry.
using CallbackEntryPtr = std::shared_ptr<CallbackEntry>;
/// Non-owning observer of a CallbackEntry.
using WeakCallbackEntryPtr = std::weak_ptr<CallbackEntry>;
/// One bucket of a timing wheel: the set of entries stored in that slot.
using CallbackBucket = std::unordered_set<CallbackEntryPtr>;
/// One timing wheel: a queue of buckets.
using CallbackBucketQueue = std::deque<CallbackBucket>;
/**
 * @brief A key/value cache whose entries can expire after a timeout.
 *
 * Expiration is implemented with a set of timing wheels that are advanced by
 * a periodic timer registered on a trantor::EventLoop. Wheel `i` is rotated
 * once every `bucketsNumPerWheel^i` ticks; when a bucket is rotated out, the
 * CallbackEntry objects it held are released and their callbacks run.
 *
 * Access to the map is serialised by `mtx_`; access to the wheels is
 * serialised by `bucketMutex_`.
 *
 * @tparam T1 The key type.
 * @tparam T2 The value type.
 */
template <typename T1, typename T2>
class CacheMap
{
  public:
    /**
     * @brief Create the cache and, when the configuration is valid, start the
     * periodic tick timer on @p loop.
     *
     * @param loop The event loop the tick timer is registered on.
     * @param tickInterval Interval handed to EventLoop::runEvery; timeouts and
     * delays are expressed in the same unit.
     * @param wheelsNum Number of timing wheels.
     * @param bucketsNumPerWheel Number of buckets in every wheel.
     * @param fnOnInsert Optional hook called with the key after an insertion.
     * @param fnOnErase Optional hook called with the key after an erasure.
     *
     * @note If @p tickInterval, @p wheelsNum or @p bucketsNumPerWheel is not
     * positive, no timer is created and entries never expire.
     */
    CacheMap(trantor::EventLoop *loop,
             float tickInterval = TICK_INTERVAL,
             size_t wheelsNum = WHEELS_NUM,
             size_t bucketsNumPerWheel = BUCKET_NUM_PER_WHEEL,
             std::function<void(const T1 &)> fnOnInsert = nullptr,
             std::function<void(const T1 &)> fnOnErase = nullptr)
        : loop_(loop),
          tickInterval_(tickInterval),
          wheelsNumber_(wheelsNum),
          bucketsNumPerWheel_(bucketsNumPerWheel),
          ctrlBlockPtr_(std::make_shared<ControlBlock>()),
          fnOnInsert_(std::move(fnOnInsert)),
          fnOnErase_(std::move(fnOnErase))
    {
        wheels_.resize(wheelsNumber_);
        for (size_t i = 0; i < wheelsNumber_; ++i)
        {
            wheels_[i].resize(bucketsNumPerWheel_);
        }
        if (tickInterval_ > 0 && wheelsNumber_ > 0 && bucketsNumPerWheel_ > 0)
        {
            // The control block is shared with the timer so that a tick that
            // fires after this object is destroyed can detect it and bail out
            // before touching `this`.
            timerId_ = loop_->runEvery(
                tickInterval_, [this, ctrlBlockPtr = ctrlBlockPtr_]() {
                    std::lock_guard<std::mutex> ctrlLock(ctrlBlockPtr->mtx);
                    if (ctrlBlockPtr->destructed)
                        return;
                    size_t t = ++ticksCounter_;
                    size_t pow = 1;
                    for (size_t i = 0; i < wheelsNumber_; ++i)
                    {
                        if ((t % pow) == 0)
                        {
                            // Move the expired bucket out under the lock and
                            // let it be destroyed (running its callbacks)
                            // after bucketMutex_ has been released.
                            CallbackBucket tmp;
                            {
                                std::lock_guard<std::mutex> bucketLock(
                                    bucketMutex_);
                                wheels_[i].front().swap(tmp);
                                wheels_[i].pop_front();
                                wheels_[i].push_back(CallbackBucket());
                            }
                        }
                        pow = pow * bucketsNumPerWheel_;
                    }
                });
            // Remember that the loop has quit so the destructor does not try
            // to invalidate the timer afterwards.
            loop_->runOnQuit([ctrlBlockPtr = ctrlBlockPtr_] {
                std::lock_guard<std::mutex> lock(ctrlBlockPtr->mtx);
                ctrlBlockPtr->loopEnded = true;
            });
        }
        else
        {
            noWheels_ = true;
        }
    }

    /**
     * @brief Stop the tick timer (if one was started and the loop has not
     * quit) and release every stored value and pending entry.
     *
     * The map is cleared before the wheels, so the expiration callbacks that
     * run while the wheels are cleared find no entry to erase.
     */
    ~CacheMap()
    {
        std::lock_guard<std::mutex> lock(ctrlBlockPtr_->mtx);
        ctrlBlockPtr_->destructed = true;
        map_.clear();
        // timerId_ only holds a real timer when the wheels are enabled.
        if (!noWheels_ && !ctrlBlockPtr_->loopEnded)
        {
            loop_->invalidateTimer(timerId_);
        }
        for (auto iter = wheels_.rbegin(); iter != wheels_.rend(); ++iter)
        {
            iter->clear();
        }
        LOG_TRACE << "CacheMap destruct!";
    }

    /// The record stored for every key.
    struct MapValue
    {
        MapValue(const T2 &value,
                 size_t timeout,
                 std::function<void()> &&callback)
            : value_(value),
              timeout_(timeout),
              timeoutCallback_(std::move(callback))
        {
        }
        MapValue(T2 &&value, size_t timeout, std::function<void()> &&callback)
            : value_(std::move(value)),
              timeout_(timeout),
              timeoutCallback_(std::move(callback))
        {
        }
        MapValue(T2 &&value, size_t timeout)
            : value_(std::move(value)), timeout_(timeout)
        {
        }
        MapValue(const T2 &value, size_t timeout)
            : value_(value), timeout_(timeout)
        {
        }
        MapValue(T2 &&value) : value_(std::move(value))
        {
        }
        MapValue(const T2 &value) : value_(value)
        {
        }
        MapValue() = default;
        /// The cached value.
        T2 value_;
        /// Timeout of the entry; 0 means the entry never expires.
        size_t timeout_{0};
        /// Called after the entry has been erased because it expired.
        std::function<void()> timeoutCallback_;
        /// The wheel entry currently responsible for expiring this record.
        WeakCallbackEntryPtr weakEntryPtr_;
    };

    /**
     * @brief Insert a key/value pair, moving the value into the cache.
     *
     * If the key is already present the stored value is left unchanged.
     *
     * @param key The key.
     * @param value The value, moved into the cache.
     * @param timeout Lifetime of the entry; 0 means it never expires.
     * @param timeoutCallback Called after the entry expired and was erased.
     */
    void insert(const T1 &key,
                T2 &&value,
                size_t timeout = 0,
                std::function<void()> timeoutCallback = std::function<void()>())
    {
        insertImpl(key, std::move(value), timeout, std::move(timeoutCallback));
    }

    /**
     * @brief Insert a key/value pair, copying the value into the cache.
     *
     * If the key is already present the stored value is left unchanged.
     *
     * @param key The key.
     * @param value The value, copied into the cache.
     * @param timeout Lifetime of the entry; 0 means it never expires.
     * @param timeoutCallback Called after the entry expired and was erased.
     */
    void insert(const T1 &key,
                const T2 &value,
                size_t timeout = 0,
                std::function<void()> timeoutCallback = std::function<void()>())
    {
        insertImpl(key, value, timeout, std::move(timeoutCallback));
    }

    /**
     * @brief Return a copy of the value stored for @p key, refreshing its
     * timeout; return a default-constructed T2 if the key is absent.
     */
    T2 operator[](const T1 &key)
    {
        std::lock_guard<std::mutex> lock(mtx_);
        auto iter = map_.find(key);
        if (iter == map_.end())
        {
            return T2();
        }
        const size_t timeout = iter->second.timeout_;
        if (timeout > 0)
            eraseAfter(timeout, key);
        return iter->second.value_;
    }

    /**
     * @brief Modify the value stored for @p key in place.
     *
     * If the key exists, @p handler is applied to the stored value and the
     * entry's own timeout is refreshed (@p timeout is ignored). Otherwise a
     * default-constructed T2 is passed to @p handler, inserted with
     * @p timeout, and the insert hook is called.
     *
     * @param key The key.
     * @param handler Callable invoked as `handler(T2 &)` while the map mutex
     * is held, so it must not call back into this CacheMap.
     * @param timeout Timeout used only when a new entry is created.
     */
    template <typename Callable>
    void modify(const T1 &key, Callable &&handler, size_t timeout = 0)
    {
        {
            std::lock_guard<std::mutex> lock(mtx_);
            auto iter = map_.find(key);
            if (iter != map_.end())
            {
                timeout = iter->second.timeout_;
                handler(iter->second.value_);
                if (timeout > 0)
                    eraseAfter(timeout, key);
                return;
            }
            MapValue v{T2(), timeout};
            handler(v.value_);
            map_.emplace(key, std::move(v));
            if (timeout > 0)
            {
                eraseAfter(timeout, key);
            }
        }
        if (fnOnInsert_)
            fnOnInsert_(key);
    }

    /**
     * @brief Check whether @p key is cached, refreshing its timeout if so.
     * @return true if the key is present.
     */
    bool find(const T1 &key)
    {
        std::lock_guard<std::mutex> lock(mtx_);
        auto iter = map_.find(key);
        if (iter == map_.end())
        {
            return false;
        }
        const size_t timeout = iter->second.timeout_;
        if (timeout > 0)
            eraseAfter(timeout, key);
        return true;
    }

    /**
     * @brief Copy the value stored for @p key into @p value, refreshing the
     * entry's timeout.
     * @return true if the key is present; @p value is untouched otherwise.
     */
    bool findAndFetch(const T1 &key, T2 &value)
    {
        std::lock_guard<std::mutex> lock(mtx_);
        auto iter = map_.find(key);
        if (iter == map_.end())
        {
            return false;
        }
        const size_t timeout = iter->second.timeout_;
        value = iter->second.value_;
        if (timeout > 0)
            eraseAfter(timeout, key);
        return true;
    }

    /**
     * @brief Remove @p key from the cache.
     *
     * The entry's timeout callback is not invoked. The erase hook is called
     * whether or not the key was present.
     */
    void erase(const T1 &key)
    {
        {
            std::lock_guard<std::mutex> lock(mtx_);
            map_.erase(key);
        }
        if (fnOnErase_)
            fnOnErase_(key);
    }

    /// @return The event loop passed to the constructor.
    trantor::EventLoop *getLoop()
    {
        return loop_;
    }

    /**
     * @brief Schedule @p task on the timing wheels.
     *
     * @param delay Delay in the unit of the tick interval. If it is 0, or the
     * wheels are disabled, the task is not scheduled and runs before this
     * function returns.
     * @param task The task to run.
     */
    void runAfter(size_t delay, std::function<void()> &&task)
    {
        // Keep a reference outside the critical section: if the entry is not
        // scheduled, the task then runs after bucketMutex_ has been released
        // and may safely call runAfter() again.
        auto entryPtr = std::make_shared<CallbackEntry>(std::move(task));
        {
            std::lock_guard<std::mutex> lock(bucketMutex_);
            insertEntry(delay, entryPtr);
        }
    }

    /// Overload of runAfter() that copies @p task.
    void runAfter(size_t delay, const std::function<void()> &task)
    {
        auto entryPtr = std::make_shared<CallbackEntry>(task);
        {
            std::lock_guard<std::mutex> lock(bucketMutex_);
            insertEntry(delay, entryPtr);
        }
    }

  private:
    /// State shared with the timer callbacks, which can outlive the CacheMap.
    struct ControlBlock
    {
        ControlBlock() : destructed(false), loopEnded(false)
        {
        }
        /// Set by ~CacheMap; a tick that sees it returns immediately.
        bool destructed;
        /// Set when the event loop quits.
        bool loopEnded;
        /// Guards the two flags and serialises ticks with destruction.
        std::mutex mtx;
    };

    std::unordered_map<T1, MapValue> map_;
    std::vector<CallbackBucketQueue> wheels_;
    std::atomic<size_t> ticksCounter_{0};
    std::mutex mtx_;          // guards map_
    std::mutex bucketMutex_;  // guards wheels_
    trantor::TimerId timerId_{};
    trantor::EventLoop *loop_;
    float tickInterval_;
    size_t wheelsNumber_;
    size_t bucketsNumPerWheel_;
    std::shared_ptr<ControlBlock> ctrlBlockPtr_;
    std::function<void(const T1 &)> fnOnInsert_;
    std::function<void(const T1 &)> fnOnErase_;
    bool noWheels_{false};

    /// Shared implementation of both insert() overloads; @p value is
    /// forwarded so that it is copied or moved as the caller requested.
    template <typename V>
    void insertImpl(const T1 &key,
                    V &&value,
                    size_t timeout,
                    std::function<void()> &&timeoutCallback)
    {
        if (timeout > 0)
        {
            MapValue v{std::forward<V>(value),
                       timeout,
                       std::move(timeoutCallback)};
            std::lock_guard<std::mutex> lock(mtx_);
            map_.emplace(key, std::move(v));
            eraseAfter(timeout, key);
        }
        else
        {
            MapValue v{std::forward<V>(value)};
            std::lock_guard<std::mutex> lock(mtx_);
            map_.emplace(key, std::move(v));
        }
        if (fnOnInsert_)
            fnOnInsert_(key);
    }

    /**
     * @brief Put @p entryPtr into the wheels so that it is released after
     * roughly @p delay.
     *
     * Must be called with bucketMutex_ held. A delay that does not fit into
     * the current wheel is wrapped into an entry stored in a higher wheel;
     * when that wrapper is released it re-inserts the original entry into the
     * lower wheel. A delay too long for all wheels ends up in the last bucket
     * of the last wheel.
     */
    void insertEntry(size_t delay, CallbackEntryPtr entryPtr)
    {
        // Without a running timer nothing can be scheduled; bailing out here
        // also avoids dividing by a zero tick interval or bucket count.
        if (noWheels_ || delay == 0)
            return;
        delay = static_cast<size_t>(delay / tickInterval_ + 1);
        size_t t = ticksCounter_;
        for (size_t i = 0; i < wheelsNumber_; ++i)
        {
            if (delay <= bucketsNumPerWheel_)
            {
                wheels_[i][delay - 1].insert(entryPtr);
                break;
            }
            if (i < (wheelsNumber_ - 1))
            {
                entryPtr = std::make_shared<CallbackEntry>(
                    [this, delay, i, t, entryPtr]() {
                        if (delay > 0)
                        {
                            std::lock_guard<std::mutex> lock(bucketMutex_);
                            wheels_[i][(delay + (t % bucketsNumPerWheel_) - 1) %
                                       bucketsNumPerWheel_]
                                .insert(entryPtr);
                        }
                    });
            }
            else
            {
                wheels_[i][bucketsNumPerWheel_ - 1].insert(entryPtr);
            }
            delay =
                (delay + (t % bucketsNumPerWheel_) - 1) / bucketsNumPerWheel_;
            t = t / bucketsNumPerWheel_;
        }
    }

    /**
     * @brief Arrange for @p key to be erased after @p delay, or push back an
     * already pending expiration.
     *
     * Must be called with mtx_ held and with @p key present in the map. Does
     * nothing when the wheels are disabled.
     */
    void eraseAfter(size_t delay, const T1 &key)
    {
        if (noWheels_)
            return;
        auto iter = map_.find(key);
        assert(iter != map_.end());
        if (iter == map_.end())
            return;

        CallbackEntryPtr entryPtr = iter->second.weakEntryPtr_.lock();
        if (!entryPtr)
        {
            // No live entry for this key yet: create the one that erases it.
            std::function<void()> cb = [this, key]() {
                bool erased{false};
                std::function<void()> timeoutCallback;
                {
                    std::lock_guard<std::mutex> lock(mtx_);
                    auto it = map_.find(key);
                    if (it != map_.end())
                    {
                        auto &value = it->second;
                        // A live entry here means a newer expiration was
                        // registered for the key; leave the record alone.
                        auto liveEntry = value.weakEntryPtr_.lock();
                        if (value.timeout_ > 0 && !liveEntry)
                        {
                            erased = true;
                            timeoutCallback = std::move(value.timeoutCallback_);
                            map_.erase(it);
                        }
                    }
                }
                // The hooks run outside the map lock.
                if (erased && fnOnErase_)
                    fnOnErase_(key);
                if (erased && timeoutCallback)
                    timeoutCallback();
            };
            entryPtr = std::make_shared<CallbackEntry>(std::move(cb));
            iter->second.weakEntryPtr_ = entryPtr;
        }
        // Re-inserting a live entry adds another owner in a later bucket, so
        // its callback only runs once the latest bucket is released.
        std::lock_guard<std::mutex> lock(bucketMutex_);
        insertEntry(delay, entryPtr);
    }
};
}  