#include "HttpRequestParser.h"
#include <drogon/HttpTypes.h>
#include <trantor/utils/Logger.h>
#include <trantor/utils/MsgBuffer.h>
#include <iostream>
#include "HttpAppFrameworkImpl.h"
#include "HttpRequestImpl.h"
#include "HttpResponseImpl.h"
#include "HttpUtils.h"
using namespace trantor;
using namespace drogon;
// Number of bytes in the "\r\n" line terminator; added to a found CRLF
// position to step past it.
static constexpr size_t CRLF_LEN = 2;
// While no space after the method has been found, at most this many bytes may
// sit in the buffer before the request is rejected with 400 (presumably
// chosen as the length of the longest standard method name, e.g. "OPTIONS").
static constexpr size_t METHOD_MAX_LEN = 7;
// While waiting for the CRLF that ends a chunk-size line, at most this many
// bytes (plus CRLF_LEN) may be buffered before the request is rejected with
// 400. NOTE(review): "TRUNK" looks like a misspelling of "chunk"; 16 is
// presumably the number of hex digits in a 64-bit length.
static constexpr size_t TRUNK_LEN_MAX_LEN = 16;
/// Creates a parser bound to one TCP connection.
/// The state machine starts at kExpectMethod, and the connection's event loop
/// is remembered in loop_ (several methods below assert they run on that
/// loop's thread).
/// conn_ is later accessed via lock() (see the 100-continue handling in
/// parseRequest), so it presumably holds a weak reference that does not keep
/// the connection alive.
/// NOTE(review): request_ is not initialised here; reset() is what installs
/// it, so presumably callers invoke reset() before the first parseRequest().
HttpRequestParser::HttpRequestParser(const trantor::TcpConnectionPtr &connPtr)
    : status_(HttpRequestParseStatus::kExpectMethod),
      loop_(connPtr->getLoop()),
      conn_(connPtr)
{
}
/// Parses the remainder of a request line (everything after the method, with
/// the trailing CRLF excluded): "<request-target> HTTP/1.x".
///
/// Path and query are stored on request_ as soon as the target is located,
/// even if the version check that follows fails. A target with no path (e.g.
/// "http://host") is recorded as "/".
/// @return true only if the version is exactly "HTTP/1.0" or "HTTP/1.1".
bool HttpRequestParser::processRequestLine(const char *begin, const char *end)
{
    const char *const targetEnd = std::find(begin, end, ' ');
    if (targetEnd == end)
    {
        return false;
    }

    // For an absolute-form target ("scheme://authority/path") skip past the
    // "//" so the first '/' of the authority is not mistaken for the path.
    const char *pathBegin = std::find(begin, targetEnd, '/');
    const bool hasAuthority = pathBegin != begin &&
                              pathBegin + 1 < targetEnd && pathBegin[1] == '/';
    if (hasAuthority)
    {
        pathBegin = std::find(pathBegin + 2, targetEnd, '/');
    }

    const char *const queryMark = std::find(pathBegin, targetEnd, '?');
    if (pathBegin == targetEnd)
    {
        request_->setPath("/");
    }
    else
    {
        request_->setPath(pathBegin, queryMark);
    }
    if (queryMark != targetEnd)
    {
        request_->setQuery(queryMark + 1, targetEnd);
    }

    // The version token must be exactly "HTTP/1." plus one minor digit.
    const char *const version = targetEnd + 1;
    if (end - version != 8 || !std::equal(version, end - 1, "HTTP/1."))
    {
        return false;
    }
    switch (*(end - 1))
    {
        case '1':
            request_->setVersion(Version::kHttp11);
            return true;
        case '0':
            request_->setVersion(Version::kHttp10);
            return true;
        default:
            return false;
    }
}
/// Wraps a raw HttpRequestImpl in a shared_ptr whose custom deleter recycles
/// the object into requestsPool_ instead of freeing it, for as long as this
/// parser is alive.
///
/// The deleter captures only a weak_ptr to the parser, so outstanding
/// requests never keep the parser alive. When the last reference is dropped:
///  - the parser is gone (or was never owned by a shared_ptr, in which case
///    weak_from_this() is empty): the object is deleted;
///  - the caller is on loop_'s thread: the object is reset() and pushed into
///    requestsPool_ right away;
///  - any other thread: the same work is queued onto loop_, presumably
///    because requestsPool_ is only ever touched from the loop thread.
/// The pooled entry is produced by calling makeRequestForPool again, so it
/// carries the same recycling deleter the next time it is handed out.
/// NOTE(review): in the queued case p is freed only if loop_ eventually runs
/// the task; whether queueInLoop guarantees that is not visible here.
HttpRequestImplPtr HttpRequestParser::makeRequestForPool(HttpRequestImpl *ptr)
{
    return std::shared_ptr<HttpRequestImpl>(
        ptr, [weakPtr = weak_from_this()](HttpRequestImpl *p) {
            auto thisPtr = weakPtr.lock();
            if (thisPtr)
            {
                if (thisPtr->loop_->isInLoopThread())
                {
                    p->reset();
                    thisPtr->requestsPool_.emplace_back(
                        thisPtr->makeRequestForPool(p));
                }
                else
                {
                    // The queued task holds a strong reference, keeping the
                    // parser alive until the object has been re-pooled.
                    auto &loop = thisPtr->loop_;
                    loop->queueInLoop([thisPtr = std::move(thisPtr), p]() {
                        p->reset();
                        thisPtr->requestsPool_.emplace_back(
                            thisPtr->makeRequestForPool(p));
                    });
                }
            }
            else
            {
                // No parser left to own a pool: release the object for real.
                delete p;
            }
        });
}
/// Prepares the parser for the next request. It clears the body-length
/// counter, rewinds the state machine to kExpectMethod and installs a fresh
/// request_, reusing a pooled object when one is available. Must be called on
/// loop_'s thread.
void HttpRequestParser::reset()
{
    assert(loop_->isInLoopThread());
    remainContentLength_ = 0;
    status_ = HttpRequestParseStatus::kExpectMethod;
    if (requestsPool_.empty())
    {
        // Pool exhausted: allocate a new request whose deleter returns it to
        // the pool once its last owner lets go.
        request_ = makeRequestForPool(new HttpRequestImpl(loop_));
    }
    else
    {
        // Order matters: move the pooled entry out and pop it BEFORE
        // assigning to request_. The assignment may drop the last reference
        // to the previous request, and on this thread its deleter (see
        // makeRequestForPool) immediately emplace_back()s into requestsPool_.
        // Assigning straight from requestsPool_.back() would let that push
        // happen before pop_back(), which would then remove the wrong element.
        auto req = std::move(requestsPool_.back());
        requestsPool_.pop_back();
        request_ = std::move(req);
        // The recycled object was constructed for an earlier request, so its
        // creation date is refreshed here (fresh objects are presumably
        // stamped by the HttpRequestImpl constructor, not visible here).
        request_->setCreationDate(trantor::Date::now());
    }
}
/// Parses a Content-Length field value, which RFC 7230 section 3.3.2 defines
/// as 1*DIGIT.
/// Unlike std::stoull, this rejects an empty value, a sign ("-1" would
/// otherwise wrap to a huge length), whitespace and trailing garbage
/// ("12abc"). It also reports overflow of size_t instead of truncating.
/// @param str   raw header value
/// @param value receives the parsed length on success (untouched on failure)
/// @return true on success, false if the value is malformed or too large
static bool parseContentLengthValue(const std::string &str, size_t &value)
{
    if (str.empty())
    {
        return false;
    }
    constexpr size_t kMaxValue = static_cast<size_t>(-1);
    size_t result = 0;
    for (const char c : str)
    {
        if (c < '0' || c > '9')
        {
            return false;
        }
        const auto digit = static_cast<size_t>(c - '0');
        if (result > (kMaxValue - digit) / 10)
        {
            return false;  // would overflow size_t
        }
        result = result * 10 + digit;
    }
    value = result;
    return true;
}

/// Parses a chunk-size line, i.e. the bytes in [begin, end) before its CRLF.
/// The grammar is 1*HEXDIG, optionally followed by chunk extensions starting
/// with ';' (ignored); spaces/tabs are tolerated after the size.
/// Unlike strtol, this rejects an empty or non-hex size (which would
/// otherwise read as 0, i.e. "last chunk") and a sign ("-1" would otherwise
/// become a huge unsigned length). It also rejects values overflowing size_t.
/// @param size receives the chunk length on success (untouched on failure)
/// @return true on success, false if the line is malformed
static bool parseChunkSizeLine(const char *begin, const char *end, size_t &size)
{
    constexpr size_t kMaxValue = static_cast<size_t>(-1);
    size_t result = 0;
    const char *p = begin;
    for (; p != end; ++p)
    {
        const char c = *p;
        size_t digit;
        if (c >= '0' && c <= '9')
        {
            digit = static_cast<size_t>(c - '0');
        }
        else if (c >= 'a' && c <= 'f')
        {
            digit = static_cast<size_t>(c - 'a' + 10);
        }
        else if (c >= 'A' && c <= 'F')
        {
            digit = static_cast<size_t>(c - 'A' + 10);
        }
        else
        {
            break;
        }
        if (result > (kMaxValue >> 4))
        {
            return false;  // would overflow size_t
        }
        result = (result << 4) | digit;
    }
    if (p == begin)
    {
        return false;  // at least one hex digit is required
    }
    while (p != end && (*p == ' ' || *p == '\t'))
    {
        ++p;
    }
    if (p != end && *p != ';')
    {
        return false;  // only chunk extensions may follow the size
    }
    size = result;
    return true;
}

/// Incrementally parses one HTTP request from buf into request_, consuming
/// the bytes it has processed. It can be called repeatedly as more data
/// arrives; status_ keeps the parse position between calls.
/// @return  0  more data is needed;
///          1  a complete request has been parsed;
///          2  stream mode: headers parsed and no body follows;
///          3  stream mode: headers parsed and a body follows;
///         -1  the connection is gone (while answering 100-continue);
///         other negative values: the negated HTTP status code to reply with.
int HttpRequestParser::parseRequest(MsgBuffer *buf)
{
    while (true)
    {
        switch (status_)
        {
            case (HttpRequestParseStatus::kExpectMethod):
            {
                auto *space = std::find(buf->peek(),
                                        (const char *)buf->beginWrite(),
                                        ' ');
                if (space == buf->beginWrite())
                {
                    if (buf->readableBytes() > METHOD_MAX_LEN)
                    {
                        return -k400BadRequest;
                    }
                    return 0;
                }
                if (!request_->setMethod(buf->peek(), space))
                {
                    return -k405MethodNotAllowed;
                }
                status_ = HttpRequestParseStatus::kExpectRequestLine;
                buf->retrieveUntil(space + 1);
                continue;
            }
            case HttpRequestParseStatus::kExpectRequestLine:
            {
                const char *crlf = buf->findCRLF();
                if (!crlf)
                {
                    if (buf->readableBytes() >= 64 * 1024)
                    {
                        return -k414RequestURITooLarge;
                    }
                    return 0;
                }
                if (!processRequestLine(buf->peek(), crlf))
                {
                    return -k400BadRequest;
                }
                buf->retrieveUntil(crlf + CRLF_LEN);
                status_ = HttpRequestParseStatus::kExpectHeaders;
                continue;
            }
            case HttpRequestParseStatus::kExpectHeaders:
            {
                const char *crlf = buf->findCRLF();
                if (!crlf)
                {
                    if (buf->readableBytes() >= 64 * 1024)
                    {
                        return -k400BadRequest;
                    }
                    return 0;
                }
                const char *colon = std::find(buf->peek(), crlf, ':');
                if (colon != crlf)
                {
                    request_->addHeader(buf->peek(), colon, crlf);
                    buf->retrieveUntil(crlf + CRLF_LEN);
                    continue;
                }
                if (crlf != buf->peek())
                {
                    // A non-empty line without ':' is not a header field.
                    // Treating it as the blank end-of-headers line would
                    // desynchronize message framing.
                    return -k400BadRequest;
                }
                // Blank line: the header section is complete.
                buf->retrieveUntil(crlf + CRLF_LEN);
                const std::string &len =
                    request_->getHeaderBy("content-length");
                const std::string &encode =
                    request_->getHeaderBy("transfer-encoding");
                if (!len.empty() && !encode.empty())
                {
                    // RFC 7230 3.3.3: both framing headers at once is a
                    // request-smuggling signal and ought to be an error.
                    return -k400BadRequest;
                }
                if (!len.empty())
                {
                    size_t contentLength = 0;
                    if (!parseContentLengthValue(len, contentLength))
                    {
                        return -k400BadRequest;
                    }
                    remainContentLength_ = contentLength;
                    request_->contentLengthHeaderValue_ = remainContentLength_;
                    if (remainContentLength_ == 0)
                    {
                        status_ = HttpRequestParseStatus::kGotAll;
                    }
                    else
                    {
                        status_ = HttpRequestParseStatus::kExpectBody;
                    }
                }
                else if (encode.empty())
                {
                    status_ = HttpRequestParseStatus::kGotAll;
                }
                else if (encode == "chunked")
                {
                    status_ = HttpRequestParseStatus::kExpectChunkLen;
                }
                else
                {
                    return -k501NotImplemented;
                }
                if (remainContentLength_ >
                    HttpAppFrameworkImpl::instance().getClientMaxBodySize())
                {
                    return -k413RequestEntityTooLarge;
                }
                auto &expect = request_->expect();
                if (expect == "100-continue" &&
                    request_->getVersion() >= Version::kHttp11)
                {
                    if (status_ == HttpRequestParseStatus::kGotAll)
                    {
                        // 100-continue only makes sense when a body follows.
                        // Chunked requests do carry one even though
                        // remainContentLength_ is still 0 here.
                        return -k400BadRequest;
                    }
                    else
                    {
                        auto connPtr = conn_.lock();
                        if (!connPtr)
                        {
                            return -1;
                        }
                        auto resp = HttpResponse::newHttpResponse();
                        resp->setStatusCode(k100Continue);
                        auto httpString =
                            static_cast<HttpResponseImpl *>(resp.get())
                                ->renderToBuffer();
                        connPtr->send(std::move(*httpString));
                    }
                }
                else if (!expect.empty())
                {
                    LOG_WARN << "417ExpectationFailed for \"" << expect << "\"";
                    return -k417ExpectationFailed;
                }
                assert(status_ == HttpRequestParseStatus::kGotAll ||
                       status_ == HttpRequestParseStatus::kExpectBody ||
                       status_ == HttpRequestParseStatus::kExpectChunkLen);
                if (app().isRequestStreamEnabled())
                {
                    request_->streamStart();
                    if (status_ == HttpRequestParseStatus::kGotAll)
                    {
                        ++requestsCounter_;
                        return 2;
                    }
                    else
                    {
                        return 3;
                    }
                }
                if (remainContentLength_)
                {
                    request_->reserveBodySize(remainContentLength_);
                }
                continue;
            }
            case HttpRequestParseStatus::kExpectBody:
            {
                size_t bytesToConsume =
                    remainContentLength_ <= buf->readableBytes()
                        ? remainContentLength_
                        : buf->readableBytes();
                if (bytesToConsume)
                {
                    request_->appendToBody(buf->peek(), bytesToConsume);
                    buf->retrieve(bytesToConsume);
                    remainContentLength_ -= bytesToConsume;
                }
                if (remainContentLength_ == 0)
                {
                    status_ = HttpRequestParseStatus::kGotAll;
                    ++requestsCounter_;
                    return 1;
                }
                return 0;
            }
            case HttpRequestParseStatus::kExpectChunkLen:
            {
                const char *crlf = buf->findCRLF();
                if (!crlf)
                {
                    if (buf->readableBytes() > TRUNK_LEN_MAX_LEN + CRLF_LEN)
                    {
                        return -k400BadRequest;
                    }
                    return 0;
                }
                size_t chunkLength = 0;
                if (!parseChunkSizeLine(buf->peek(), crlf, chunkLength))
                {
                    return -k400BadRequest;
                }
                currentChunkLength_ = chunkLength;
                if (chunkLength != 0)
                {
                    // In chunked mode remainContentLength_ accumulates the body
                    // size received so far. Compare by subtraction so that a
                    // huge chunk length cannot wrap the sum past the limit.
                    const size_t maxBodySize =
                        HttpAppFrameworkImpl::instance().getClientMaxBodySize();
                    if (remainContentLength_ > maxBodySize ||
                        chunkLength > maxBodySize - remainContentLength_)
                    {
                        return -k413RequestEntityTooLarge;
                    }
                    status_ = HttpRequestParseStatus::kExpectChunkBody;
                }
                else
                {
                    status_ = HttpRequestParseStatus::kExpectLastEmptyChunk;
                }
                buf->retrieveUntil(crlf + CRLF_LEN);
                continue;
            }
            case HttpRequestParseStatus::kExpectChunkBody:
            {
                // Wait until the whole chunk plus its trailing CRLF is buffered.
                if (buf->readableBytes() < (currentChunkLength_ + CRLF_LEN))
                {
                    return 0;
                }
                if (*(buf->peek() + currentChunkLength_) != '\r' ||
                    *(buf->peek() + currentChunkLength_ + 1) != '\n')
                {
                    return -k400BadRequest;
                }
                request_->appendToBody(buf->peek(), currentChunkLength_);
                buf->retrieve(currentChunkLength_ + CRLF_LEN);
                remainContentLength_ += currentChunkLength_;
                currentChunkLength_ = 0;
                status_ = HttpRequestParseStatus::kExpectChunkLen;
                continue;
            }
            case HttpRequestParseStatus::kExpectLastEmptyChunk:
            {
                // Only a bare CRLF is accepted after the zero-size chunk; a
                // trailer section is rejected with 400.
                if (buf->readableBytes() < CRLF_LEN)
                {
                    return 0;
                }
                if (*(buf->peek()) != '\r' || *(buf->peek() + 1) != '\n')
                {
                    return -k400BadRequest;
                }
                buf->retrieve(CRLF_LEN);
                if (!request_->isStreamMode())
                {
                    // The body is now fully buffered: present it as a plain
                    // Content-Length message.
                    request_->addHeader("content-length",
                                        std::to_string(
                                            request_->realContentLength()));
                    request_->removeHeaderBy("transfer-encoding");
                }
                status_ = HttpRequestParseStatus::kGotAll;
                ++requestsCounter_;
                return 1;
            }
            case HttpRequestParseStatus::kGotAll:
            {
                ++requestsCounter_;
                return 1;
            }
        }
    }
    return -1;  // unreachable; keeps compilers from warning about no return
}
void HttpRequestParser::pushRequestToPipelining(const HttpRequestPtr &req,
                                                bool isHeadMethod)
{
    assert(loop_->isInLoopThread());
    requestPipelining_.push_back({req, {nullptr, isHeadMethod}});
}
/// Stores resp in the pipelining slot of the first queued entry whose request
/// is req. Must run on loop_'s thread.
/// @return true if that entry is at the head of the queue, i.e. the response
///         can be sent right away; false otherwise. Asserts (and returns
///         false) if req was never queued.
bool HttpRequestParser::pushResponseToPipelining(const HttpRequestPtr &req,
                                                 HttpResponsePtr resp)
{
    assert(loop_->isInLoopThread());
    auto slot = std::find_if(requestPipelining_.begin(),
                             requestPipelining_.end(),
                             [&req](const auto &entry) {
                                 return entry.first == req;
                             });
    if (slot == requestPipelining_.end())
    {
        assert(false);
        return false;
    }
    slot->second.first = std::move(resp);
    return slot == requestPipelining_.begin();
}
/// Moves responses out of the pipelining queue into buffer, in request order,
/// stopping at the first queued request whose response is not ready yet.
/// Each moved element pairs the response with its stored isHeadMethod flag.
void HttpRequestParser::popReadyResponses(
    std::vector<std::pair<HttpResponsePtr, bool>> &buffer)
{
    for (;;)
    {
        if (requestPipelining_.empty())
        {
            break;
        }
        auto &head = requestPipelining_.front();
        if (!head.second.first)
        {
            break;  // head-of-line response still pending; keep order
        }
        buffer.push_back(std::move(head.second));
        requestPipelining_.pop_front();
    }
}