diff --git a/lib/libimhex/include/hex/mcp/server.hpp b/lib/libimhex/include/hex/mcp/server.hpp index 959a5d931..7599449c6 100644 --- a/lib/libimhex/include/hex/mcp/server.hpp +++ b/lib/libimhex/include/hex/mcp/server.hpp @@ -1,11 +1,51 @@ #pragma once +#include + #include #include #include namespace hex::mcp { + class JsonRpc { + public: + explicit JsonRpc(std::string request) : m_request(std::move(request)){ } + + struct MethodNotFoundException : std::exception {}; + struct InvalidParametersException : std::exception {}; + + enum class ErrorCode: i16 { + ParseError = -32700, + InvalidRequest = -32600, + MethodNotFound = -32601, + InvalidParams = -32602, + InternalError = -32603, + }; + + using Callback = std::function; + std::optional execute(const Callback &callback); + void setError(ErrorCode code, std::string message); + + private: + std::optional handleMessage(const nlohmann::json &request, const Callback &callback); + std::optional handleBatchedMessages(const nlohmann::json &request, const Callback &callback); + + nlohmann::json createDefaultMessage(); + nlohmann::json createErrorMessage(ErrorCode code, const std::string &message); + nlohmann::json createResponseMessage(const nlohmann::json &result); + + private: + std::string m_request; + std::optional m_id; + + struct Error { + ErrorCode code; + std::string message; + }; + std::optional m_error; + }; + struct TextContent { std::string text; diff --git a/lib/libimhex/source/mcp/client.cpp b/lib/libimhex/source/mcp/client.cpp index 41c34fcaf..e09b3830f 100644 --- a/lib/libimhex/source/mcp/client.cpp +++ b/lib/libimhex/source/mcp/client.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -17,20 +18,26 @@ namespace hex::mcp { wolv::net::SocketClient client(wolv::net::SocketClient::Type::TCP, true); client.connect("127.0.0.1", Server::McpInternalPort); - if (!client.isConnected()) { - log::resumeLogging(); - log::error("Cannot connect to ImHex. Do you have an instance running and is the MCP server enabled?"); - return EXIT_FAILURE; - } + fprintf(stderr, "Established connection to main ImHex instance!\n"); while (true) { std::string request; std::getline(input, request); + if (ImHexApi::System::isMainInstance()) { + JsonRpc response(request); + response.setError(JsonRpc::ErrorCode::InternalError, "No other instance of ImHex is running. Make sure that you have ImHex open already."); + output << response.execute([](auto, auto){ return nlohmann::json::object(); }).value_or("") << '\n'; + continue; + } + client.writeString(request); auto response = client.readString(); if (!response.empty() && response.front() != 0x00) output << response << '\n'; + + if (!client.isConnected()) + break; } return EXIT_SUCCESS; diff --git a/lib/libimhex/source/mcp/server.cpp b/lib/libimhex/source/mcp/server.cpp index 8d9a0e0da..c6151627b 100644 --- a/lib/libimhex/source/mcp/server.cpp +++ b/lib/libimhex/source/mcp/server.cpp @@ -12,45 +12,41 @@ namespace hex::mcp { - class JsonRpc { - public: - explicit JsonRpc(std::string request) : m_request(std::move(request)) { } + std::optional JsonRpc::execute(const Callback &callback) { + try { + auto requestJson = nlohmann::json::parse(m_request); - struct MethodNotFoundException : std::exception {}; - struct InvalidParametersException : std::exception {}; - - std::optional execute(auto callback) { - try { - auto requestJson = nlohmann::json::parse(m_request); - - if (requestJson.is_array()) { - return handleBatchedMessages(requestJson, callback).transform([](const auto &response) { return response.dump(); }); - } else { - return handleMessage(requestJson, callback).transform([](const auto &response) { return response.dump(); }); - } - } catch (const MethodNotFoundException &) { - return createErrorMessage(ErrorCode::MethodNotFound, "Method not found").dump(); - } catch (const InvalidParametersException &) { - return createErrorMessage(ErrorCode::InvalidParams, "Invalid params").dump(); - } catch (const nlohmann::json::parse_error &) { - return createErrorMessage(ErrorCode::ParseError, "Parse error").dump(); - } catch (const std::exception &e) { - return createErrorMessage(ErrorCode::InternalError, e.what()).dump(); + if (requestJson.is_array()) { + return handleBatchedMessages(requestJson, callback).transform([](const auto &response) { return response.dump(); }); + } else { + return handleMessage(requestJson, callback).transform([](const auto &response) { return response.dump(); }); } + } catch (const nlohmann::json::exception &) { + return createErrorMessage(ErrorCode::ParseError, "Parse error").dump(); } + } - private: - std::optional handleMessage(const nlohmann::json &request, auto callback) { + void JsonRpc::setError(ErrorCode code, std::string message) { + m_error = Error{ code, std::move(message) }; + } + + std::optional JsonRpc::handleMessage(const nlohmann::json &request, const Callback &callback) { + try { // Validate JSON-RPC request if (!request.contains("jsonrpc") || request["jsonrpc"] != "2.0" || !request.contains("method") || !request["method"].is_string()) { m_id = request.contains("id") ? std::optional(request["id"].get()) : std::nullopt; - return createErrorMessage(ErrorCode::InvalidRequest, "Invalid Request").dump(); + return createErrorMessage(ErrorCode::InvalidRequest, "Invalid Request"); } m_id = request.contains("id") ? std::optional(request["id"].get()) : std::nullopt; + // Return a user-specified error if set + if (m_error.has_value()) { + return createErrorMessage(m_error->code, m_error->message); + } + // Execute the method auto result = callback(request["method"].get(), request.value("params", nlohmann::json::object())); @@ -58,64 +54,58 @@ namespace hex::mcp { return std::nullopt; return createResponseMessage(result.is_null() ? nlohmann::json::object() : result); + } catch (const MethodNotFoundException &) { + return createErrorMessage(ErrorCode::MethodNotFound, "Method not found"); + } catch (const InvalidParametersException &) { + return createErrorMessage(ErrorCode::InvalidParams, "Invalid params"); + } catch (const std::exception &e) { + return createErrorMessage(ErrorCode::InternalError, e.what()); + } + } + + std::optional JsonRpc::handleBatchedMessages(const nlohmann::json &request, const Callback &callback) { + if (!request.is_array()) { + return createErrorMessage(ErrorCode::InvalidRequest, "Invalid Request"); } - std::optional handleBatchedMessages(const nlohmann::json &request, auto callback) { - if (!request.is_array()) { - return createErrorMessage(ErrorCode::InvalidRequest, "Invalid Request").dump(); - } - - nlohmann::json responses = nlohmann::json::array(); - for (const auto &message : request) { - auto response = handleMessage(message, callback); - if (response.has_value()) - responses.push_back(*response); - } - - if (responses.empty()) - return std::nullopt; - - return responses.dump(); + nlohmann::json responses = nlohmann::json::array(); + for (const auto &message : request) { + auto response = handleMessage(message, callback); + if (response.has_value()) + responses.push_back(*response); } - enum class ErrorCode: i16 { - ParseError = -32700, - InvalidRequest = -32600, - MethodNotFound = -32601, - InvalidParams = -32602, - InternalError = -32603, + if (responses.empty()) + return std::nullopt; + + return responses; + } + + nlohmann::json JsonRpc::createDefaultMessage() { + nlohmann::json message; + message["jsonrpc"] = "2.0"; + if (m_id.has_value()) + message["id"] = m_id.value(); + else + message["id"] = nullptr; + + return message; + } + + nlohmann::json JsonRpc::createErrorMessage(ErrorCode code, const std::string &message) { + auto json = createDefaultMessage(); + json["error"] = { + { "code", int(code) }, + { "message", message } }; + return json; + } - nlohmann::json createDefaultMessage() { - nlohmann::json message; - message["jsonrpc"] = "2.0"; - if (m_id.has_value()) - message["id"] = m_id.value(); - else - message["id"] = nullptr; - - return message; - } - - nlohmann::json createErrorMessage(ErrorCode code, const std::string &message) { - auto json = createDefaultMessage(); - json["error"] = { - { "code", int(code) }, - { "message", message } - }; - return json; - } - - nlohmann::json createResponseMessage(const nlohmann::json &result) { - auto json = createDefaultMessage(); - json["result"] = result; - return json; - } - - private: - std::string m_request; - std::optional m_id; - }; + nlohmann::json JsonRpc::createResponseMessage(const nlohmann::json &result) { + auto json = createDefaultMessage(); + json["result"] = result; + return json; + } Server::Server() : m_server(McpInternalPort, 1024, 1, true) {