#pragma once
#include <drogon/HttpRequest.h>
#include <drogon/HttpResponse.h>
#include <drogon/WebSocketConnection.h>
#include <drogon/HttpTypes.h>
#ifdef __cpp_impl_coroutine
#include <drogon/utils/coroutine.h>
#endif
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <trantor/net/EventLoop.h>
namespace drogon
{
// Forward declaration: the full interface is defined further down.
class WebSocketClient;

/// Shared-ownership handle to a WebSocketClient.
using WebSocketClientPtr = std::shared_ptr<WebSocketClient>;

/// Completion callback for WebSocketClient::connectToServer().
/// Receives the outcome of the attempt, the server's response, and the client.
using WebSocketRequestCallback =
    std::function<void(ReqResult,
                       const HttpResponsePtr &,
                       const WebSocketClientPtr &)>;
#ifdef __cpp_impl_coroutine
namespace internal
{
/// Awaiter returned by WebSocketClient::connectToServerCoro().
///
/// When the awaiting coroutine suspends, await_suspend() calls
/// connectToServer() with the stored request. The completion callback then
/// stores either the response or an exception and resumes the coroutine.
struct [[nodiscard]] WebSocketConnectionAwaiter
    : public CallbackAwaiter<HttpResponsePtr>
{
    /// @param client Client whose connectToServer() is invoked; not owned,
    ///               so it must still be valid when the awaiter is awaited.
    /// @param req    Request forwarded to connectToServer(); taken by value
    ///               and moved into the awaiter.
    WebSocketConnectionAwaiter(WebSocketClient *client, HttpRequestPtr req)
        : client_{client}, req_{std::move(req)}
    {
    }

    /// Starts the connection attempt; defined after WebSocketClient is complete.
    void await_suspend(std::coroutine_handle<> handle);

  private:
    WebSocketClient *client_;  // non-owning
    HttpRequestPtr req_;
};
}  
#endif
/**
 * @brief Abstract WebSocket client interface.
 *
 * Concrete clients are created through the static newWebSocketClient()
 * factories and shared as WebSocketClientPtr (std::shared_ptr).
 */
class DROGON_EXPORT WebSocketClient
{
  public:
    /// Returns the WebSocket connection associated with this client.
    /// NOTE(review): presumably empty until connectToServer() succeeds —
    /// confirm against the implementation.
    virtual WebSocketConnectionPtr getConnection() = 0;

    /// Installs the per-message handler. The handler receives the message
    /// payload (as an rvalue), this client, and the message type.
    virtual void setMessageHandler(
        const std::function<void(std::string &&message,
                                 const WebSocketClientPtr &,
                                 const WebSocketMessageType &)> &callback) = 0;

    /// Installs the connection-closed handler, which receives this client.
    virtual void setConnectionClosedHandler(
        const std::function<void(const WebSocketClientPtr &)> &callback) = 0;

    /// Starts connecting to the server with @p request. @p callback receives
    /// the outcome, the server's response, and this client.
    virtual void connectToServer(const HttpRequestPtr &request,
                                 const WebSocketRequestCallback &callback) = 0;

    /// Sets certificate and private-key paths (presumably for SSL
    /// connections — TODO confirm in the implementation).
    virtual void setCertPath(const std::string &cert,
                             const std::string &key) = 0;

    /// Adds SSL configuration commands as (name, value) pairs.
    /// NOTE(review): presumably forwarded to the TLS backend — confirm.
    virtual void addSSLConfigs(
        const std::vector<std::pair<std::string, std::string>>
            &sslConfCmds) = 0;
#ifdef __cpp_impl_coroutine
    /// Coroutine variant of setMessageHandler(): every incoming message
    /// spawns a detached AsyncTask that co_awaits @p callback.
    void setAsyncMessageHandler(
        const std::function<Task<>(std::string &&message,
                                   const WebSocketClientPtr &,
                                   const WebSocketMessageType &)> &callback)
    {
        using Callback = std::function<Task<>(std::string &&,
                                              const WebSocketClientPtr &,
                                              const WebSocketMessageType &)>;
        setMessageHandler([callback](std::string &&message,
                                     const WebSocketClientPtr &client,
                                     const WebSocketMessageType &type) -> void {
            // Everything the coroutine uses is a *by-value parameter*, never
            // a lambda capture or a reference parameter. Parameters are
            // copied into the coroutine frame and live until the co_await
            // below completes. A capturing coroutine lambda's closure is a
            // temporary destroyed at the first suspension, which would leave
            // the handler object (and its captures) dangling. A
            // `std::string &&` parameter would keep pointing at caller-owned
            // storage.
            [](Callback handler,
               std::string payload,
               WebSocketClientPtr wsClient,
               WebSocketMessageType msgType) -> AsyncTask {
                co_await handler(std::move(payload), wsClient, msgType);
            }(callback, std::move(message), client, type);
        });
    }

    /// Coroutine variant of setConnectionClosedHandler(): each close event
    /// spawns a detached AsyncTask that co_awaits @p callback.
    void setAsyncConnectionClosedHandler(
        const std::function<Task<>(const WebSocketClientPtr &)> &callback)
    {
        using Callback = std::function<Task<>(const WebSocketClientPtr &)>;
        setConnectionClosedHandler(
            [callback](const WebSocketClientPtr &client) {
                // By-value parameters, not captures: the user coroutine gets
                // `const WebSocketClientPtr &` bound to the frame-owned copy,
                // which stays alive across its suspensions.
                [](Callback handler,
                   WebSocketClientPtr wsClient) -> AsyncTask {
                    co_await handler(wsClient);
                }(callback, client);
            });
    }

    /// Coroutine form of connectToServer(). co_await the result to obtain the
    /// server response. A failed attempt is reported via setException() with
    /// a std::runtime_error naming the ReqResult.
    internal::WebSocketConnectionAwaiter connectToServerCoro(
        const HttpRequestPtr &request)
    {
        return internal::WebSocketConnectionAwaiter(this, request);
    }
#endif
    /// Returns the event loop this client is associated with.
    virtual trantor::EventLoop *getLoop() = 0;

    /// Stops the client. NOTE(review): exact semantics (closing the
    /// connection, cancelling a pending connect) live in the implementation.
    virtual void stop() = 0;

    /// Creates a client for @p ip : @p port.
    /// @param useSSL       whether to use SSL (default false)
    /// @param loop         event loop to use; nullptr by default (how a loop
    ///                     is then chosen is implementation-defined — confirm)
    /// @param useOldTLS    allow older TLS versions (default false)
    /// @param validateCert validate the server certificate (default true)
    static WebSocketClientPtr newWebSocketClient(
        const std::string &ip,
        uint16_t port,
        bool useSSL = false,
        trantor::EventLoop *loop = nullptr,
        bool useOldTLS = false,
        bool validateCert = true);

    /// Creates a client from a host string.
    /// NOTE(review): presumably of the form "ws://host:port" or
    /// "wss://host:port" — confirm the accepted syntax in the implementation.
    static WebSocketClientPtr newWebSocketClient(
        const std::string &hostString,
        trantor::EventLoop *loop = nullptr,
        bool useOldTLS = false,
        bool validateCert = true);

    virtual ~WebSocketClient() = default;
};
#ifdef __cpp_impl_coroutine
/// Starts the connection attempt and resumes @p handle from the completion
/// callback.
///
/// On ReqResult::Ok the server response becomes the awaited value. Any other
/// result is stored as a std::runtime_error whose message names the result.
/// Results not listed below (for example enumerators added later) are
/// reported as "ReqResult(<numeric value>)" instead of an empty message.
inline void internal::WebSocketConnectionAwaiter::await_suspend(
    std::coroutine_handle<> handle)
{
    client_->connectToServer(
        req_,
        [this, handle](ReqResult result,
                       const HttpResponsePtr &resp,
                       const WebSocketClientPtr &) {
            if (result == ReqResult::Ok)
            {
                setValue(resp);
            }
            else
            {
                std::string reason;
                switch (result)
                {
                    case ReqResult::BadResponse:
                        reason = "BadResponse";
                        break;
                    case ReqResult::NetworkFailure:
                        reason = "NetworkFailure";
                        break;
                    case ReqResult::BadServerAddress:
                        reason = "BadServerAddress";
                        break;
                    case ReqResult::Timeout:
                        reason = "Timeout";
                        break;
                    default:
                        // Never raise an exception with an empty message.
                        reason = "ReqResult(" +
                                 std::to_string(static_cast<int>(result)) +
                                 ")";
                        break;
                }
                setException(
                    std::make_exception_ptr(std::runtime_error(reason)));
            }
            // The awaiter is not touched after this point: resuming may let
            // the coroutine run past the co_await and destroy it.
            handle.resume();
        });
}
#endif
}  