#pragma once
#include <trantor/exports.h>
#include <trantor/net/EventLoop.h>
#include <trantor/net/InetAddress.h>
#include <trantor/utils/NonCopyable.h>
#include <trantor/utils/MsgBuffer.h>
#include <trantor/net/callbacks.h>
#include <trantor/net/Certificate.h>
#include <trantor/net/TLSPolicy.h>
#include <trantor/net/AsyncStream.h>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace trantor
{
// Forward declaration: in the visible part of this header TimingWheel is only
// named through std::shared_ptr<TimingWheel>, so its definition is not needed.
class TimingWheel;
// Forward declaration of the SSL context type; this header only handles it
// through SSLContextPtr.
struct SSLContext;
// Shared-ownership handle to an SSLContext.
using SSLContextPtr = std::shared_ptr<SSLContext>;
/**
 * @brief Abstract interface of a TCP connection.
 *
 * All I/O-related operations are pure virtual and must be provided by a
 * concrete subclass. The base class itself only stores the user-supplied
 * callbacks, a TLS policy and an opaque, type-erased user context.
 */
class TRANTOR_EXPORT TcpConnection
{
  public:
    // These classes are granted access to the protected/private members
    // (callbacks, TLS policy, context) declared below.
    friend class TcpServer;
    friend class TcpConnectionImpl;
    friend class TcpClient;

    TcpConnection() = default;
    virtual ~TcpConnection() = default;

    /**
     * @name Sending data
     * Overloads accepting a raw buffer plus length, a string, a MsgBuffer, or
     * a shared pointer to either. Delivery semantics (blocking vs. queued,
     * thread requirements) are defined by the implementing class and are not
     * visible in this header.
     * @{
     */
    virtual void send(const char *msg, size_t len) = 0;
    virtual void send(const void *msg, size_t len) = 0;
    virtual void send(const std::string &msg) = 0;
    virtual void send(std::string &&msg) = 0;
    virtual void send(const MsgBuffer &buffer) = 0;
    virtual void send(MsgBuffer &&buffer) = 0;
    virtual void send(const std::shared_ptr<std::string> &msgPtr) = 0;
    virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) = 0;
    /** @} */

    /**
     * @brief Send a file identified by a narrow-character path.
     * @param fileName Path of the file.
     * @param offset Start position within the file (defaults to 0).
     * @param length Number of bytes to send (defaults to 0).
     *        NOTE(review): 0 presumably means "until end of file" - confirm
     *        against the implementation.
     */
    virtual void sendFile(const char *fileName,
                          long long offset = 0,
                          long long length = 0) = 0;

    /**
     * @brief Send a file identified by a wide-character path.
     *
     * Same parameters as the narrow-character overload.
     */
    virtual void sendFile(const wchar_t *fileName,
                          long long offset = 0,
                          long long length = 0) = 0;

    /**
     * @brief Send data produced by a user callback.
     * @param callback Invoked with a buffer pointer and a size and returns a
     *        size. NOTE(review): presumably it fills the buffer and returns
     *        the number of bytes written, with 0 ending the stream - confirm
     *        against the implementation.
     */
    virtual void sendStream(
        std::function<std::size_t(char *, std::size_t)> callback) = 0;

    /**
     * @brief Obtain an asynchronous stream object for sending data.
     * @param disableKickoff Defaults to false. NOTE(review): presumably
     *        suspends idle-timeout kick-off while the stream is open -
     *        confirm against the implementation.
     * @return The stream handle.
     */
    virtual AsyncStreamPtr sendAsyncStream(bool disableKickoff = false) = 0;

    /// @return The address of the local end of the connection.
    virtual const InetAddress &localAddr() const = 0;
    /// @return The address of the remote end of the connection.
    virtual const InetAddress &peerAddr() const = 0;

    /// @return Whether the connection is in the connected state.
    virtual bool connected() const = 0;
    /// @return Whether the connection is in the disconnected state.
    virtual bool disconnected() const = 0;

    /**
     * @brief Install the high-water-mark callback together with its
     * threshold.
     * @param cb The callback.
     * @param markLen The threshold length. NOTE(review): presumably the size
     *        of the pending send buffer that triggers the callback - confirm.
     */
    virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
                                          size_t markLen) = 0;

    /// Toggle the TCP no-delay option on the connection.
    virtual void setTcpNoDelay(bool on) = 0;

    /// Shut the connection down (exact semantics are implementation-defined).
    virtual void shutdown() = 0;
    /// Forcibly close the connection (exact semantics are
    /// implementation-defined).
    virtual void forceClose() = 0;

    /// @return The event loop associated with this connection.
    virtual EventLoop *getLoop() = 0;

    /// Attach a user context to the connection (copies the shared pointer).
    void setContext(const std::shared_ptr<void> &context)
    {
        contextPtr_ = context;
    }
    /// Attach a user context to the connection (moves the shared pointer).
    void setContext(std::shared_ptr<void> &&context)
    {
        contextPtr_ = std::move(context);
    }

    /// @return The application protocol name. NOTE(review): presumably the
    /// protocol negotiated via ALPN - confirm against the implementation.
    virtual std::string applicationProtocol() const = 0;

    /**
     * @brief Retrieve the stored user context as a shared pointer to T.
     *
     * The stored pointer is converted with std::static_pointer_cast, so no
     * runtime type check is performed: T must be the type that was actually
     * stored. Returns an empty pointer when no context is set.
     */
    template <typename T>
    std::shared_ptr<T> getContext() const
    {
        return std::static_pointer_cast<T>(contextPtr_);
    }

    /// @return true if a non-null user context is currently attached.
    bool hasContext() const
    {
        return static_cast<bool>(contextPtr_);
    }

    /// Release this connection's reference to the user context.
    void clearContext()
    {
        contextPtr_.reset();
    }

    /// NOTE(review): presumably marks the connection as recently active so it
    /// is not kicked off for idleness - confirm against the implementation.
    virtual void keepAlive() = 0;
    /// @return Whether the connection is in keep-alive mode
    /// (implementation-defined).
    virtual bool isKeepAlive() = 0;

    /// @return The number of bytes sent, as reported by the implementation.
    virtual size_t bytesSent() const = 0;
    /// @return The number of bytes received, as reported by the
    /// implementation.
    virtual size_t bytesReceived() const = 0;

    /// @return Whether this connection is an SSL/TLS connection.
    virtual bool isSSLConnection() const = 0;

    /// @return A pointer to the receive buffer of the connection.
    virtual MsgBuffer *getRecvBuffer() = 0;

    /// @return The peer's certificate handle (may be empty; see
    /// implementation).
    virtual CertificatePtr peerCertificate() const = 0;

    /// @return The SNI name associated with the connection.
    virtual std::string sniName() const = 0;

    /**
     * @brief Start TLS on this connection using the given policy.
     * @param policy The TLS policy to apply.
     * @param isServer Whether this side acts as the TLS server.
     * @param upgradeCallback Optional callback (defaults to nullptr).
     *        NOTE(review): presumably invoked once the upgrade has completed
     *        - confirm against the implementation.
     */
    virtual void startEncryption(TLSPolicyPtr policy,
                                 bool isServer,
                                 std::function<void(const TcpConnectionPtr &)>
                                     upgradeCallback = nullptr) = 0;

    /**
     * @brief Deprecated convenience wrapper around startEncryption().
     *
     * Builds a policy from TLSPolicy::defaultClientPolicy(), applies the
     * given options to it and calls startEncryption() with isServer = false.
     *
     * @param callback Forwarded as the upgrade callback.
     * @param useOldTLS Passed to TLSPolicy::setUseOldTLS().
     * @param validateCert Passed to TLSPolicy::setValidate().
     * @param hostname Passed to TLSPolicy::setHostname().
     * @param sslConfCmds Passed to TLSPolicy::setConfCmds().
     */
    [[deprecated("Use startEncryption(TLSPolicyPtr) instead")]] void
    startClientEncryption(
        std::function<void(const TcpConnectionPtr &)> &&callback,
        bool useOldTLS = false,
        bool validateCert = true,
        const std::string &hostname = "",
        const std::vector<std::pair<std::string, std::string>> &sslConfCmds =
            {})
    {
        auto policy = TLSPolicy::defaultClientPolicy();
        policy->setUseOldTLS(useOldTLS)
            .setValidate(validateCert)
            .setHostname(hostname)
            .setConfCmds(sslConfCmds);
        startEncryption(std::move(policy), false, std::move(callback));
    }

    /// Replace the stored TLS policy (the argument is moved from).
    void setValidationPolicy(TLSPolicy &&policy)
    {
        tlsPolicy_ = std::move(policy);
    }

    /**
     * @name Callback setters
     * Each callback has a copying and a moving overload; both simply store
     * the callable in the corresponding protected member.
     * @{
     */
    void setRecvMsgCallback(const RecvMessageCallback &cb)
    {
        recvMsgCallback_ = cb;
    }
    void setRecvMsgCallback(RecvMessageCallback &&cb)
    {
        recvMsgCallback_ = std::move(cb);
    }
    void setConnectionCallback(const ConnectionCallback &cb)
    {
        connectionCallback_ = cb;
    }
    void setConnectionCallback(ConnectionCallback &&cb)
    {
        connectionCallback_ = std::move(cb);
    }
    void setWriteCompleteCallback(const WriteCompleteCallback &cb)
    {
        writeCompleteCallback_ = cb;
    }
    void setWriteCompleteCallback(WriteCompleteCallback &&cb)
    {
        writeCompleteCallback_ = std::move(cb);
    }
    void setCloseCallback(const CloseCallback &cb)
    {
        closeCallback_ = cb;
    }
    void setCloseCallback(CloseCallback &&cb)
    {
        closeCallback_ = std::move(cb);
    }
    void setSSLErrorCallback(const SSLErrorCallback &cb)
    {
        sslErrorCallback_ = cb;
    }
    void setSSLErrorCallback(SSLErrorCallback &&cb)
    {
        sslErrorCallback_ = std::move(cb);
    }
    /** @} */

    // NOTE(review): the two lifecycle hooks below are presumably invoked by
    // the owning server/client when the connection is set up and torn down -
    // confirm against the callers.
    virtual void connectEstablished() = 0;
    virtual void connectDestroyed() = 0;

    /**
     * @brief Enable idle kick-off for this connection.
     * @param timeout The timeout value. NOTE(review): unit is presumably
     *        seconds - confirm against the implementation.
     * @param timingWheel The timing wheel used to track the timeout.
     */
    virtual void enableKickingOff(
        size_t timeout,
        const std::shared_ptr<TimingWheel> &timingWheel) = 0;

  protected:
    // User callbacks; the setters above fill all of them except
    // highWaterMarkCallback_, which has no non-virtual setter in this class.
    RecvMessageCallback recvMsgCallback_;
    ConnectionCallback connectionCallback_;
    CloseCallback closeCallback_;
    WriteCompleteCallback writeCompleteCallback_;
    HighWaterMarkCallback highWaterMarkCallback_;
    SSLErrorCallback sslErrorCallback_;
    // TLS policy stored by setValidationPolicy().
    TLSPolicy tlsPolicy_;

  private:
    // Opaque, type-erased user context managed through setContext(),
    // getContext(), hasContext() and clearContext().
    std::shared_ptr<void> contextPtr_;
};
/**
 * @brief Obtain an SSL context for the given TLS policy (declared here,
 * defined elsewhere).
 * @param policy The TLS policy the context is based on.
 * @param server NOTE(review): presumably true for a server-side context and
 *        false for a client-side one - confirm against the definition.
 * @return A shared pointer to the SSLContext.
 */
TRANTOR_EXPORT SSLContextPtr newSSLContext(const TLSPolicy &policy,
                                           bool server);
}  