diff --git a/modules/softwareintegration/softwareintegrationmodule.cpp b/modules/softwareintegration/softwareintegrationmodule.cpp index c5a40fceb7..e8721e202c 100644 --- a/modules/softwareintegration/softwareintegrationmodule.cpp +++ b/modules/softwareintegration/softwareintegrationmodule.cpp @@ -47,14 +47,24 @@ namespace { namespace openspace { - const unsigned int SoftwareIntegrationModule::ProtocolVersion = 1; + const unsigned int Connection::ProtocolVersion = 1; SoftwareIntegrationModule::SoftwareIntegrationModule() : OpenSpaceModule(Name) {} + Connection::Message::Message(MessageType t, std::vector c) + : type(t) + , content(std::move(c)) + {} + + Connection::ConnectionLostError::ConnectionLostError() + : ghoul::RuntimeError("Connection lost", "Connection") + {} + Connection::Connection(std::unique_ptr socket) : _socket(std::move(socket)) {} + void SoftwareIntegrationModule::internalInitialize(const ghoul::Dictionary&) { auto fRenderable = FactoryManager::ref().factory(); ghoul_assert(fRenderable, "No renderable factory existed"); @@ -85,24 +95,71 @@ namespace openspace { return _socket.get(); } + // Connection + Connection::Message Connection::receiveMessage() { + // Header consists of... + /*constexpr size_t HeaderSize = + 2 * sizeof(char) + // OS + 3 * sizeof(uint32_t); // Protocol version, message type and message size*/ + + // Create basic buffer for receiving first part of messages + // std::vector headerBuffer(HeaderSize); + std::vector messageBuffer; + + // Receive the header data + /*if (!_socket->get(headerBuffer.data(), HeaderSize)) { + LERROR("Failed to read header from socket. Disconnecting."); + throw ConnectionLostError(); + } + + // Make sure that header matches this version of OpenSpace + if (!(headerBuffer[0] == 'O' && headerBuffer[1] && 'S')) { + LERROR("Expected to read message header 'OS' from socket."); + throw ConnectionLostError(); + } + + size_t offset = 2; + const uint32_t protocolVersionIn = + *reinterpret_cast(headerBuffer.data() + offset); + offset += sizeof(uint32_t); + + if (protocolVersionIn != ProtocolVersion) { + LERROR(fmt::format( + "Protocol versions do not match. Remote version: {}, Local version: {}", + protocolVersionIn, + ProtocolVersion + )); + throw ConnectionLostError(); + } + + const uint32_t messageTypeIn = + *reinterpret_cast(headerBuffer.data() + offset); + offset += sizeof(uint32_t); + + const uint32_t messageSizeIn = + *reinterpret_cast(headerBuffer.data() + offset); + offset += sizeof(uint32_t);*/ + + const size_t messageSize = 2; + + // Receive the payload + messageBuffer.resize(messageSize); + if (!_socket->get(messageBuffer.data(), messageSize)) { + LERROR("Failed to read message from socket. Disconnecting."); + throw ConnectionLostError(); + } + + // And delegate decoding depending on type + return Message(MessageType::Data, messageBuffer); + } + // Server void SoftwareIntegrationModule::start(int port) { _socketServer.listen(port); - _socketServer.awaitPendingTcpSocket(); - //_serverThread = std::thread([this]() { handleNewPeers(); }); - } - // Server - void SoftwareIntegrationModule::setDefaultHostAddress(std::string defaultHostAddress) { - std::lock_guard lock(_hostInfoMutex); - _defaultHostAddress = std::move(defaultHostAddress); - } - - // Server - std::string SoftwareIntegrationModule::defaultHostAddress() const { - std::lock_guard lock(_hostInfoMutex); - return _defaultHostAddress; + _serverThread = std::thread([this]() { handleNewPeers(); }); + _eventLoopThread = std::thread([this]() { eventLoop(); }); } // Server @@ -129,8 +186,8 @@ namespace openspace { }); auto it = _peers.emplace(p->id, p); it.first->second->thread = std::thread([this, id]() { - // handlePeer(id); - }); + handlePeer(id); + }); } } @@ -144,6 +201,68 @@ namespace openspace { return it->second; } + void SoftwareIntegrationModule::handlePeer(size_t id) { + while (!_shouldStop) { + std::shared_ptr p = peer(id); + if (!p) { + return; + } + + if (!p->connection.isConnectedOrConnecting()) { + return; + } + try { + Connection::Message m = p->connection.receiveMessage(); + _incomingMessages.push({ id, m }); + } + catch (const Connection::ConnectionLostError&) { + LERROR(fmt::format("Connection lost to {}", p->id)); + _incomingMessages.push({ + id, + Connection::Message( + Connection::MessageType::Disconnection, std::vector() + ) + }); + return; + } + } + } + + void SoftwareIntegrationModule::eventLoop() { + while (!_shouldStop) { + PeerMessage pm = _incomingMessages.pop(); + handlePeerMessage(std::move(pm)); + } + } + + void SoftwareIntegrationModule::handlePeerMessage(PeerMessage peerMessage) { + const size_t peerId = peerMessage.peerId; + auto it = _peers.find(peerId); + if (it == _peers.end()) { + return; + } + + std::shared_ptr& peer = it->second; + + const Connection::MessageType messageType = peerMessage.message.type; + std::vector& data = peerMessage.message.content; + std::string sData(data.begin(), data.end()); + switch (messageType) { + case Connection::MessageType::Data: + //handleData(*peer, std::move(data)); + LERROR(fmt::format("Peer message: {}", sData)); + break; + case Connection::MessageType::Disconnection: + disconnect(*peer); + break; + default: + LERROR(fmt::format( + "Unsupported message type: {}", static_cast(messageType) + )); + break; + } + } + // Server bool SoftwareIntegrationModule::isConnected(const Peer& peer) const { return peer.status != Connection::Status::Connecting && @@ -153,19 +272,7 @@ namespace openspace { // Server void SoftwareIntegrationModule::disconnect(Peer& peer) { if (isConnected(peer)) { - //nConnections() - 1; - } - - size_t hostPeerId = 0; - { - std::lock_guard lock(_hostInfoMutex); - hostPeerId = _hostPeerId; - } - - // Make sure any disconnecting host is first degraded to client, - // in order to notify other clients about host disconnection. - if (peer.id == hostPeerId) { - //setToClient(peer); + _nConnections = nConnections() - 1; } peer.connection.disconnect(); @@ -173,6 +280,10 @@ namespace openspace { _peers.erase(peer.id); } + size_t SoftwareIntegrationModule::nConnections() const { + return _nConnections; + } + std::vector SoftwareIntegrationModule::documentations() const { return { RenderablePointsCloud::Documentation(), diff --git a/modules/softwareintegration/softwareintegrationmodule.h b/modules/softwareintegration/softwareintegrationmodule.h index f1662c502f..b5574fa024 100644 --- a/modules/softwareintegration/softwareintegrationmodule.h +++ b/modules/softwareintegration/softwareintegrationmodule.h @@ -47,6 +47,27 @@ public: Connecting }; + enum class MessageType : uint32_t { + Authentication = 0, + Data, + ConnectionStatus, + NConnections, + Disconnection + }; + + struct Message { + Message() = default; + Message(MessageType t, std::vector c); + + MessageType type; + std::vector content; + }; + + class ConnectionLostError : public ghoul::RuntimeError { + public: + explicit ConnectionLostError(); + }; + Connection(std::unique_ptr socket); // Connection @@ -54,6 +75,10 @@ public: void disconnect(); ghoul::io::TcpSocket* socket(); + Connection::Message receiveMessage(); + + static const unsigned int ProtocolVersion; + private: // Connection std::unique_ptr _socket; @@ -69,15 +94,12 @@ public: // Server void start(int port); - void setDefaultHostAddress(std::string defaultHostAddress); - std::string defaultHostAddress() const; void stop(); - //size_t nConnections() const; + size_t nConnections() const; std::vector documentations() const override; scripting::LuaLibrary luaLibrary() const override; - static const unsigned int ProtocolVersion; private: // Server struct Peer { @@ -88,26 +110,28 @@ private: std::thread thread; }; + struct PeerMessage { + size_t peerId; + Connection::Message message; + }; + // Server bool isConnected(const Peer& peer) const; void disconnect(Peer& peer); - //void setName(Peer& peer, std::string name); - //void assignHost(std::shared_ptr newHost); - //void handleDisconnection(std::shared_ptr peer); void handleNewPeers(); + void eventLoop(); std::shared_ptr peer(size_t id); - //void handlePeer(size_t id); + void handlePeer(size_t id); + void handlePeerMessage(PeerMessage peerMessage); std::unordered_map> _peers; mutable std::mutex _peerListMutex; std::thread _serverThread; + std::thread _eventLoopThread; ghoul::io::TcpSocketServer _socketServer; size_t _nextConnectionId = 1; std::atomic_bool _shouldStop = false; std::atomic_size_t _nConnections = 0; - std::atomic_size_t _hostPeerId = 0; - mutable std::mutex _hostInfoMutex; - std::string _hostName; - std::string _defaultHostAddress; + ConcurrentQueue _incomingMessages; void internalInitialize(const ghoul::Dictionary&) override; void internalDeinitializeGL() override;