From 2d4fa917ada2dd24ecfb213703a7d0bb4db1a918 Mon Sep 17 00:00:00 2001 From: Denis Slynko Date: Fri, 6 Oct 2023 16:47:27 +0300 Subject: [PATCH] Refactor websockets implementation Change-Id: Ie4c159938456d5cf2d0c88c97bc93bc4e9dba1e4 Signed-off-by: Denis Slynko --- .vscode/c_cpp_properties.json | 16 + .vscode/settings.json | 55 ++ inspector/BUILD.gn | 4 +- inspector/connect_server.cpp | 4 +- inspector/connect_server.h | 5 +- inspector/ws_server.cpp | 10 +- inspector/ws_server.h | 5 +- tooling/client/BUILD.gn | 2 +- tooling/client/session/session.h | 8 +- tooling/client/utils/cli_command.h | 1 - tooling/client/websocket/websocket_client.cpp | 442 ------------- tooling/test/client_utils/test_util.cpp | 2 +- tooling/test/client_utils/test_util.h | 2 +- websocket/BUILD.gn | 94 ++- websocket/client/websocket_client.cpp | 272 ++++++++ .../client}/websocket_client.h | 84 ++- websocket/define.h | 97 ++- websocket/frame_builder.cpp | 148 +++++ websocket/frame_builder.h | 80 +++ websocket/handshake_helper.cpp | 58 ++ websocket/handshake_helper.h | 121 ++++ websocket/http.cpp | 122 ++++ websocket/http.h | 62 ++ websocket/network.cpp | 82 +++ websocket/network.h | 43 ++ websocket/server/websocket_server.cpp | 310 +++++++++ websocket/server/websocket_server.h | 78 +++ websocket/string_utils.h | 49 ++ websocket/test/BUILD.gn | 6 +- websocket/test/frame_builder_test.cpp | 176 +++++ websocket/test/http_decoder_test.cpp | 84 +++ websocket/test/web_socket_frame_test.cpp | 45 ++ websocket/test/websocket_test.cpp | 338 +--------- websocket/web_socket_frame.h | 73 +++ websocket/websocket.cpp | 612 ------------------ websocket/websocket.h | 106 --- websocket/websocket_base.cpp | 258 ++++++++ websocket/websocket_base.h | 136 ++++ 38 files changed, 2471 insertions(+), 1619 deletions(-) create mode 100644 .vscode/c_cpp_properties.json create mode 100644 .vscode/settings.json delete mode 100644 tooling/client/websocket/websocket_client.cpp create mode 100644 websocket/client/websocket_client.cpp rename {tooling/client/websocket => websocket/client}/websocket_client.h (49%) create mode 100644 websocket/frame_builder.cpp create mode 100644 websocket/frame_builder.h create mode 100644 websocket/handshake_helper.cpp create mode 100644 websocket/handshake_helper.h create mode 100644 websocket/http.cpp create mode 100644 websocket/http.h create mode 100644 websocket/network.cpp create mode 100644 websocket/network.h create mode 100644 websocket/server/websocket_server.cpp create mode 100644 websocket/server/websocket_server.h create mode 100644 websocket/string_utils.h create mode 100644 websocket/test/frame_builder_test.cpp create mode 100644 websocket/test/http_decoder_test.cpp create mode 100644 websocket/test/web_socket_frame_test.cpp create mode 100644 websocket/web_socket_frame.h delete mode 100644 websocket/websocket.cpp delete mode 100644 websocket/websocket.h create mode 100644 websocket/websocket_base.cpp create mode 100644 websocket/websocket_base.h diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json new file mode 100644 index 00000000..07c08df7 --- /dev/null +++ b/.vscode/c_cpp_properties.json @@ -0,0 +1,16 @@ +{ + "configurations": [ + { + "name": "Linux", + "includePath": [ + "${workspaceFolder}/**" + ], + "defines": [], + "compilerPath": "/usr/bin/clang", + "cStandard": "c17", + "cppStandard": "c++17", + "intelliSenseMode": "linux-clang-x64" + } + ], + "version": 4 +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..bb4ace7e --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,55 @@ +{ + "files.associations": { + "array": "cpp", + "bitset": "cpp", + "string_view": "cpp", + "initializer_list": "cpp", + "ranges": "cpp", + "span": "cpp", + "utility": "cpp", + "shared_mutex": "cpp", + "map": "cpp", + "atomic": "cpp", + "bit": "cpp", + "*.tcc": "cpp", + "cctype": "cpp", + "chrono": "cpp", + "clocale": "cpp", + "compare": "cpp", + "concepts": "cpp", + "condition_variable": "cpp", + "cstddef": "cpp", + "cstdint": "cpp", + "cstdlib": "cpp", + "ctime": "cpp", + "cwchar": "cpp", + "cwctype": "cpp", + "unordered_map": "cpp", + "exception": "cpp", + "algorithm": "cpp", + "iterator": "cpp", + "memory": "cpp", + "memory_resource": "cpp", + "optional": "cpp", + "ratio": "cpp", + "string": "cpp", + "system_error": "cpp", + "tuple": "cpp", + "type_traits": "cpp", + "iosfwd": "cpp", + "istream": "cpp", + "limits": "cpp", + "new": "cpp", + "ostream": "cpp", + "sstream": "cpp", + "stdexcept": "cpp", + "stop_token": "cpp", + "streambuf": "cpp", + "thread": "cpp", + "typeinfo": "cpp", + "semaphore": "cpp", + "cmath": "cpp", + "cstdio": "cpp", + "numbers": "cpp" + } +} \ No newline at end of file diff --git a/inspector/BUILD.gn b/inspector/BUILD.gn index e30db975..caa22453 100644 --- a/inspector/BUILD.gn +++ b/inspector/BUILD.gn @@ -32,7 +32,7 @@ ohos_source_set("ark_debugger_static") { libs = [ "log" ] } - deps += [ "../websocket:websocket" ] + deps += [ "../websocket:websocket_server" ] sources = [ "../common/log_wrapper.cpp", "inspector.cpp", @@ -81,7 +81,7 @@ ohos_source_set("connectserver_debugger_static") { include_dirs = [ "$toolchain_root" ] - deps += [ "../websocket:websocket" ] + deps += [ "../websocket:websocket_server" ] sources = [ "../common/log_wrapper.cpp", "connect_inspector.cpp", diff --git a/inspector/connect_server.cpp b/inspector/connect_server.cpp index 3e40c00f..a8a17c7a 100644 --- a/inspector/connect_server.cpp +++ b/inspector/connect_server.cpp @@ -25,7 +25,7 @@ std::shared_mutex g_sendMutex; void ConnectServer::RunServer() { terminateExecution_ = false; - webSocket_ = std::make_unique(); + webSocket_ = std::make_unique(); tid_ = pthread_self(); int appPid = getpid(); std::string pidStr = std::to_string(appPid); @@ -37,7 +37,7 @@ void ConnectServer::RunServer() #endif while (!terminateExecution_) { #if defined(OHOS_PLATFORM) - if (!webSocket_->ConnectUnixWebSocket()) { + if (!webSocket_->AcceptNewConnection()) { return; } #endif diff --git a/inspector/connect_server.h b/inspector/connect_server.h index bd94fe07..6f96be09 100644 --- a/inspector/connect_server.h +++ b/inspector/connect_server.h @@ -18,7 +18,8 @@ #include #include -#include "websocket/websocket.h" +#include +#include "websocket/server/websocket_server.h" #ifdef WINDOWS_PLATFORM #include #endif @@ -39,7 +40,7 @@ private: std::string bundleName_; pthread_t tid_ {0}; std::function wsOnMessage_ {}; - std::unique_ptr webSocket_ { nullptr }; + std::unique_ptr webSocket_ { nullptr }; }; } // namespace OHOS::ArkCompiler::Toolchain diff --git a/inspector/ws_server.cpp b/inspector/ws_server.cpp index 565bf707..24fca952 100644 --- a/inspector/ws_server.cpp +++ b/inspector/ws_server.cpp @@ -34,7 +34,7 @@ void WsServer::RunServer() LOGE("WsServer has been terminated unexpectedly"); return; } - webSocket_ = std::make_unique(); + webSocket_ = std::make_unique(); #if !defined(OHOS_PLATFORM) LOGI("WsSever Runsever: Init tcp websocket %{public}d", port_); if (!webSocket_->InitTcpWebSocket(port_)) { @@ -57,15 +57,9 @@ void WsServer::RunServer() #endif } while (!terminateExecution_) { -#if !defined(OHOS_PLATFORM) - if (!webSocket_->ConnectTcpWebSocket()) { + if (!webSocket_->AcceptNewConnection()) { return; } -#else - if (!webSocket_->ConnectUnixWebSocket()) { - return; - } -#endif while (webSocket_->IsConnected()) { std::string message = webSocket_->Decode(); if (!message.empty() && webSocket_->IsDecodeDisconnectMsg(message)) { diff --git a/inspector/ws_server.h b/inspector/ws_server.h index ed76fabd..44d07ca1 100644 --- a/inspector/ws_server.h +++ b/inspector/ws_server.h @@ -18,12 +18,13 @@ #include #include +#include #include #ifdef WINDOWS_PLATFORM #include #endif -#include "websocket/websocket.h" +#include "websocket/server/websocket_server.h" namespace OHOS::ArkCompiler::Toolchain { class WsServer { @@ -46,7 +47,7 @@ private: std::mutex wsMutex_; std::string componentName_ {}; std::function wsOnMessage_ {}; - std::unique_ptr webSocket_ { nullptr }; + std::unique_ptr webSocket_ { nullptr }; [[maybe_unused]] int port_ = -1; }; } // namespace OHOS::ArkCompiler::Toolchain diff --git a/tooling/client/BUILD.gn b/tooling/client/BUILD.gn index c01354cc..6f0c6a83 100644 --- a/tooling/client/BUILD.gn +++ b/tooling/client/BUILD.gn @@ -43,13 +43,13 @@ ohos_source_set("libark_client_set") { "session/session.cpp", "utils/cli_command.cpp", "utils/utils.cpp", - "websocket/websocket_client.cpp", ] deps += [ "$ark_third_party_root/libuv:uv", "$ark_third_party_root/openssl:libcrypto_shared", "..:libark_ecma_debugger", + "$toolchain_root/websocket:websocket_client", sdk_libc_secshared_dep, ] diff --git a/tooling/client/session/session.h b/tooling/client/session/session.h index aeaba1ac..267bbe80 100755 --- a/tooling/client/session/session.h +++ b/tooling/client/session/session.h @@ -30,7 +30,7 @@ #include "tooling/client/manager/stack_manager.h" #include "tooling/client/manager/variable_manager.h" #include "tooling/client/manager/watch_manager.h" -#include "tooling/client/websocket/websocket_client.h" +#include "websocket/server/websocket_server.h" namespace OHOS::ArkCompiler::Toolchain { using CmdForAllCB = std::function; @@ -61,7 +61,7 @@ public: bool ClientSendReq(const std::string &message) { - return cliSocket_.ClientSendReq(message); + return cliSocket_.SendReply(message); } DomainManager& GetDomainManager() @@ -84,7 +84,7 @@ public: return variableManager_; } - WebsocketClient& GetWebsocketClient() + WebSocketClient& GetWebSocketClient() { return cliSocket_; } @@ -118,7 +118,7 @@ private: uint32_t sessionId_; std::string sockInfo_; DomainManager domainManager_; - WebsocketClient cliSocket_; + WebSocketClient cliSocket_; uv_thread_t socketTid_; std::atomic messageId_ {1}; BreakPointManager breakpoint_; diff --git a/tooling/client/utils/cli_command.h b/tooling/client/utils/cli_command.h index 5410695e..f184a06f 100644 --- a/tooling/client/utils/cli_command.h +++ b/tooling/client/utils/cli_command.h @@ -26,7 +26,6 @@ #include "tooling/client/domain/profiler_client.h" #include "tooling/client/manager/domain_manager.h" #include "tooling/client/utils/utils.h" -#include "tooling/client/websocket/websocket_client.h" #include "tooling/client/session/session.h" namespace OHOS::ArkCompiler::Toolchain { diff --git a/tooling/client/websocket/websocket_client.cpp b/tooling/client/websocket/websocket_client.cpp deleted file mode 100644 index 9cc9f21d..00000000 --- a/tooling/client/websocket/websocket_client.cpp +++ /dev/null @@ -1,442 +0,0 @@ -/* - * Copyright (c) 2023 Huawei Device Co., Ltd. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include -#include -#include - -#include "common/log_wrapper.h" -#include "websocket_client.h" - - -namespace OHOS::ArkCompiler::Toolchain { -bool WebsocketClient::InitToolchainWebSocketForPort(int port, uint32_t timeoutLimit) -{ - if (socketState_ != ToolchainSocketState::UNINITED) { - LOGE("InitToolchainWebSocketForPort::client has inited."); - return true; - } - - client_ = socket(AF_INET, SOCK_STREAM, 0); - if (client_ < SOCKET_SUCCESS) { - LOGE("InitToolchainWebSocketForPort::client socket failed, error = %{public}d , desc = %{public}s", - errno, strerror(errno)); - return false; - } - - // set send and recv timeout limit - if (!SetWebSocketTimeOut(client_, timeoutLimit)) { - LOGE("InitToolchainWebSocketForPort::client SetWebSocketTimeOut failed, error = %{public}d , desc = %{public}s", - errno, strerror(errno)); - close(client_); - client_ = -1; - return false; - } - - sockaddr_in clientAddr; - if (memset_s(&clientAddr, sizeof(clientAddr), 0, sizeof(clientAddr)) != EOK) { - LOGE("InitToolchainWebSocketForPort::client memset_s clientAddr failed, error = %{public}d, desc = %{public}s", - errno, strerror(errno)); - close(client_); - client_ = -1; - return false; - } - clientAddr.sin_family = AF_INET; - clientAddr.sin_port = htons(port); - if (int ret = inet_pton(AF_INET, "127.0.0.1", &clientAddr.sin_addr) < NET_SUCCESS) { - LOGE("InitToolchainWebSocketForPort::client inet_pton failed, error = %{public}d, desc = %{public}s", - errno, strerror(errno)); - close(client_); - client_ = -1; - return false; - } - - int ret = connect(client_, reinterpret_cast(&clientAddr), sizeof(clientAddr)); - if (ret != SOCKET_SUCCESS) { - LOGE("InitToolchainWebSocketForPort::client connect failed, error = %{public}d, desc = %{public}s", - errno, strerror(errno)); - close(client_); - client_ = -1; - return false; - } - socketState_ = ToolchainSocketState::INITED; - LOGE("InitToolchainWebSocketForPort::client connect success."); - return true; -} - -bool WebsocketClient::InitToolchainWebSocketForSockName(const std::string &sockName, uint32_t timeoutLimit) -{ - if (socketState_ != ToolchainSocketState::UNINITED) { - LOGE("InitToolchainWebSocketForSockName::client has inited."); - return true; - } - - client_ = socket(AF_UNIX, SOCK_STREAM, 0); - if (client_ < SOCKET_SUCCESS) { - LOGE("InitToolchainWebSocketForSockName::client socket failed, error = %{public}d , desc = %{public}s", - errno, strerror(errno)); - return false; - } - - // set send and recv timeout limit - if (!SetWebSocketTimeOut(client_, timeoutLimit)) { - LOGE("InitToolchainWebSocketForSockName::client SetWebSocketTimeOut failed, error = %{public}d ,\ - desc = %{public}s", errno, strerror(errno)); - close(client_); - client_ = -1; - return false; - } - - struct sockaddr_un serverAddr; - if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) { - LOGE("InitToolchainWebSocketForSockName::client memset_s clientAddr failed, error = %{public}d,\ - desc = %{public}s", errno, strerror(errno)); - close(client_); - client_ = -1; - return false; - } - serverAddr.sun_family = AF_UNIX; - if (strcpy_s(serverAddr.sun_path + 1, sizeof(serverAddr.sun_path) - 1, sockName.c_str()) != EOK) { - LOGE("InitToolchainWebSocketForSockName::client strcpy_s serverAddr.sun_path failed, error = %{public}d,\ - desc = %{public}s", errno, strerror(errno)); - close(client_); - client_ = -1; - return false; - } - serverAddr.sun_path[0] = '\0'; - - uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1; - int ret = connect(client_, reinterpret_cast(&serverAddr), static_cast(len)); - if (ret != SOCKET_SUCCESS) { - LOGE("InitToolchainWebSocketForSockName::client connect failed, error = %{public}d, desc = %{public}s", - errno, strerror(errno)); - close(client_); - client_ = -1; - return false; - } - socketState_ = ToolchainSocketState::INITED; - LOGE("InitToolchainWebSocketForSockName::client connect success."); - return true; -} - -bool WebsocketClient::ClientSendWSUpgradeReq() -{ - if (socketState_ == ToolchainSocketState::UNINITED) { - LOGE("ClientSendWSUpgradeReq::client has not inited."); - return false; - } - if (socketState_ == ToolchainSocketState::CONNECTED) { - LOGE("ClientSendWSUpgradeReq::client has connected."); - return true; - } - - int msgLen = strlen(CLIENT_WEBSOCKET_UPGRADE_REQ); - int32_t sendLen = send(client_, CLIENT_WEBSOCKET_UPGRADE_REQ, msgLen, 0); - if (sendLen != msgLen) { - LOGE("ClientSendWSUpgradeReq::client send wsupgrade req failed, error = %{public}d, desc = %{public}sn", - errno, strerror(errno)); - socketState_ = ToolchainSocketState::UNINITED; - close(client_); - client_ = -1; - return false; - } - LOGE("ClientSendWSUpgradeReq::client send wsupgrade req success."); - return true; -} - -bool WebsocketClient::ClientRecvWSUpgradeRsp() -{ - if (socketState_ == ToolchainSocketState::UNINITED) { - LOGE("ClientRecvWSUpgradeRsp::client has not inited."); - return false; - } - if (socketState_ == ToolchainSocketState::CONNECTED) { - LOGE("ClientRecvWSUpgradeRsp::client has connected."); - return true; - } - - char recvBuf[CLIENT_WEBSOCKET_UPGRADE_RSP_LEN + 1] = {0}; - int32_t bufLen = recv(client_, recvBuf, CLIENT_WEBSOCKET_UPGRADE_RSP_LEN, 0); - if (bufLen != CLIENT_WEBSOCKET_UPGRADE_RSP_LEN) { - LOGE("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, error = %{public}d, desc = %{public}sn", - errno, strerror(errno)); - socketState_ = ToolchainSocketState::UNINITED; - close(client_); - client_ = -1; - return false; - } - socketState_ = ToolchainSocketState::CONNECTED; - LOGE("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp success."); - return true; -} - -bool WebsocketClient::ClientSendReq(const std::string &message) -{ - if (socketState_ != ToolchainSocketState::CONNECTED) { - LOGE("ClientSendReq::client has not connected."); - return false; - } - - uint32_t msgLen = message.length(); - std::unique_ptr msgBuf = std::make_unique(msgLen + 15); // 15: the maximum expand length - char *sendBuf = msgBuf.get(); - uint32_t sendMsgLen = 0; - sendBuf[0] = 0x81; // 0x81: the text message sent by the server should start with '0x81'. - uint32_t mask = 1; - // Depending on the length of the messages, client will use shift operation to get the res - // and store them in the buffer. - if (msgLen <= 125) { // 125: situation 1 when message's length <= 125 - sendBuf[1] = msgLen | (mask << 7); // 7: mask need shift left by 7 bits - sendMsgLen = 2; // 2: the length of header frame is 2; - } else if (msgLen < 65536) { // 65536: message's length - sendBuf[1] = 126 | (mask << 7); // 126: payloadLen according to the spec; 7: mask shift left by 7 bits - sendBuf[2] = ((msgLen >> 8) & 0xff); // 8: shift right by 8 bits => res * (256^1) - sendBuf[3] = (msgLen & 0xff); // 3: store len's data => res * (256^0) - sendMsgLen = 4; // 4: the length of header frame is 4 - } else { - sendBuf[1] = 127 | (mask << 7); // 127: payloadLen according to the spec; 7: mask shift left by 7 bits - for (int32_t i = 2; i <= 5; i++) { // 2 ~ 5: unused bits - sendBuf[i] = 0; - } - sendBuf[6] = ((msgLen & 0xff000000) >> 24); // 6: shift 24 bits => res * (256^3) - sendBuf[7] = ((msgLen & 0x00ff0000) >> 16); // 7: shift 16 bits => res * (256^2) - sendBuf[8] = ((msgLen & 0x0000ff00) >> 8); // 8: shift 8 bits => res * (256^1) - sendBuf[9] = (msgLen & 0x000000ff); // 9: res * (256^0) - sendMsgLen = 10; // 10: the length of header frame is 10 - } - - if (memcpy_s(sendBuf + sendMsgLen, SOCKET_MASK_LEN, MASK_KEY, SOCKET_MASK_LEN) != EOK) { - LOGE("ClientSendReq::client memcpy_s MASK_KEY failed, error = %{public}d, desc = %{public}s", - errno, strerror(errno)); - return false; - } - sendMsgLen += SOCKET_MASK_LEN; - - std::string maskMessage; - for (uint64_t i = 0; i < msgLen; i++) { - uint64_t j = i % SOCKET_MASK_LEN; - maskMessage.push_back(message[i] ^ MASK_KEY[j]); - } - if (memcpy_s(sendBuf + sendMsgLen, msgLen, maskMessage.c_str(), msgLen) != EOK) { - LOGE("ClientSendReq::client memcpy_s maskMessage failed, error = %{public}d, desc = %{public}s", - errno, strerror(errno)); - return false; - } - msgBuf[sendMsgLen + msgLen] = '\0'; - - if (send(client_, sendBuf, sendMsgLen + msgLen, 0) != static_cast(sendMsgLen + msgLen)) { - LOGE("ClientSendReq::client send msg req failed, error = %{public}d, desc = %{public}s", - errno, strerror(errno)); - return false; - } - LOGE("ClientRecvWSUpgradeRsp::client send msg req success."); - return true; -} - -std::string WebsocketClient::Decode() -{ - if (socketState_ != ToolchainSocketState::CONNECTED) { - LOGE("WebsocketClient:Decode failed, websocket not connected!"); - return ""; - } - char recvbuf[SOCKET_HEADER_LEN + 1]; - errno = 0; - if (!Recv(client_, recvbuf, SOCKET_HEADER_LEN, 0)) { - if (errno != EAGAIN) { - LOGE("WebsocketClient:Decode failed, client websocket disconnect"); - socketState_ = ToolchainSocketState::INITED; - close(client_); - client_ = -1; - } - return ""; - } - ToolchainWebSocketFrame wsFrame; - int32_t index = 0; - wsFrame.fin = static_cast(recvbuf[index] >> 7); // 7: shift right by 7 bits to get the fin - wsFrame.opcode = static_cast(recvbuf[index] & 0xf); - if (wsFrame.opcode == 0x1) { // 0x1: 0x1 means a text frame - index++; - wsFrame.mask = static_cast((recvbuf[index] >> 7) & 0x1); // 7: to get the mask - wsFrame.payloadLen = recvbuf[index] & 0x7f; - if (HandleFrame(wsFrame)) { - return wsFrame.payload.get(); - } - return ""; - } else if (wsFrame.opcode == 0x9) { // 0x9: 0x9 means a ping frame - // send pong frame - char pongFrame[SOCKET_HEADER_LEN] = {0}; - pongFrame[0] = 0x8a; // 0x8a: 0x8a means a pong frame - pongFrame[1] = 0x0; - if (!Send(client_, pongFrame, SOCKET_HEADER_LEN, 0)) { - LOGE("WebsocketClient Decode: Send pong frame failed"); - return ""; - } - } - return ""; -} - -bool WebsocketClient::HandleFrame(ToolchainWebSocketFrame& wsFrame) -{ - if (wsFrame.payloadLen == 126) { // 126: the payloadLen read from frame - char recvbuf[PAYLOAD_LEN + 1] = {0}; - if (!Recv(client_, recvbuf, PAYLOAD_LEN, 0)) { - LOGE("WebsocketClient HandleFrame: Recv payloadLen == 126 failed"); - return false; - } - - uint16_t msgLen = 0; - if (memcpy_s(&msgLen, sizeof(recvbuf), recvbuf, sizeof(recvbuf) - 1) != EOK) { - LOGE("WebsocketClient HandleFrame: memcpy_s failed"); - return false; - } - wsFrame.payloadLen = ntohs(msgLen); - } else if (wsFrame.payloadLen > 126) { // 126: the payloadLen read from frame - char recvbuf[EXTEND_PAYLOAD_LEN + 1] = {0}; - if (!Recv(client_, recvbuf, EXTEND_PAYLOAD_LEN, 0)) { - LOGE("WebsocketClient HandleFrame: Recv payloadLen > 127 failed"); - return false; - } - wsFrame.payloadLen = NetToHostLongLong(recvbuf, EXTEND_PAYLOAD_LEN); - } - return DecodeMessage(wsFrame); -} - -bool WebsocketClient::DecodeMessage(ToolchainWebSocketFrame& wsFrame) -{ - if (wsFrame.payloadLen == 0 || wsFrame.payloadLen > UINT64_MAX) { - LOGE("WebsocketClient:ReadMsg length error, expected greater than zero and less than UINT64_MAX"); - return false; - } - uint64_t msgLen = wsFrame.payloadLen; - wsFrame.payload = std::make_unique(msgLen + 1); - if (wsFrame.mask == 1) { - char buf[msgLen + 1]; - if (!Recv(client_, wsFrame.maskingkey, SOCKET_MASK_LEN, 0)) { - LOGE("WebsocketClient DecodeMessage: Recv maskingkey failed"); - return false; - } - - if (!Recv(client_, buf, msgLen, 0)) { - LOGE("WebsocketClient DecodeMessage: Recv message with mask failed"); - return false; - } - - for (uint64_t i = 0; i < msgLen; i++) { - uint64_t j = i % SOCKET_MASK_LEN; - wsFrame.payload.get()[i] = buf[i] ^ wsFrame.maskingkey[j]; - } - } else { - char buf[msgLen + 1]; - if (!Recv(client_, buf, msgLen, 0)) { - LOGE("WebsocketClient DecodeMessage: Recv message without mask failed"); - return false; - } - - if (memcpy_s(wsFrame.payload.get(), msgLen, buf, msgLen) != EOK) { - LOGE("WebsocketClient DecodeMessage: memcpy_s failed"); - return false; - } - } - wsFrame.payload.get()[msgLen] = '\0'; - return true; -} - -uint64_t WebsocketClient::NetToHostLongLong(char* buf, uint32_t len) -{ - uint64_t result = 0; - for (uint32_t i = 0; i < len; i++) { - result |= static_cast(buf[i]); - if ((i + 1) < len) { - result <<= 8; // 8: result need shift left 8 bits in order to big endian convert to int - } - } - return result; -} - -bool WebsocketClient::Send(int32_t fd, const char* buf, size_t totalLen, int32_t flags) const -{ - size_t sendLen = 0; - while (sendLen < totalLen) { - ssize_t len = send(fd, buf + sendLen, totalLen - sendLen, flags); - if (len <= 0) { - LOGE("WebsocketClient Send Message in while failed, WebsocketClient disconnect"); - return false; - } - sendLen += static_cast(len); - } - return true; -} - -bool WebsocketClient::Recv(int32_t fd, char* buf, size_t totalLen, int32_t flags) const -{ - size_t recvLen = 0; - while (recvLen < totalLen) { - ssize_t len = recv(fd, buf + recvLen, totalLen - recvLen, flags); - if (len <= 0) { - LOGE("WebsocketClient Recv payload in while failed, WebsocketClient disconnect"); - return false; - } - recvLen += static_cast(len); - } - buf[totalLen] = '\0'; - return true; -} - -void WebsocketClient::Close() -{ - if (socketState_ == ToolchainSocketState::UNINITED) { - return; - } - socketState_ = ToolchainSocketState::UNINITED; - close(client_); - client_ = -1; -} - -bool WebsocketClient::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit) -{ - if (timeoutLimit > 0) { - struct timeval timeout = {timeoutLimit, 0}; - if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, - reinterpret_cast(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) { - LOGE("WebsocketClient:SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed"); - return false; - } - if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, - reinterpret_cast(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) { - LOGE("WebsocketClient:SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed"); - return false; - } - } - return true; -} - -bool WebsocketClient::IsConnected() -{ - return socketState_ == ToolchainSocketState::CONNECTED; -} - -std::string WebsocketClient::GetSocketStateString() -{ - std::vector stateStr = { - "uninited", - "inited", - "connected" - }; - - return stateStr[socketState_]; -} -} // namespace OHOS::ArkCompiler::Toolchain \ No newline at end of file diff --git a/tooling/test/client_utils/test_util.cpp b/tooling/test/client_utils/test_util.cpp index db52114a..9c4d5fef 100644 --- a/tooling/test/client_utils/test_util.cpp +++ b/tooling/test/client_utils/test_util.cpp @@ -107,7 +107,7 @@ void TestUtil::ForkSocketClient([[maybe_unused]] int port, const std::string &na int ret = SessionManager::getInstance().CreateTestSession(sockInfo); LOG_ECMA_IF(ret, FATAL) << "CreateTestSession fail"; - WebsocketClient &client = SessionManager::getInstance().GetCurrentSession()->GetWebsocketClient(); + WebSocketClient &client = SessionManager::getInstance().GetCurrentSession()->GetWebSocketClient(); auto &testAction = TestUtil::GetTest(name)->testAction; for (const auto &action: testAction) { LOG_DEBUGGER(INFO) << "message: " << action.message; diff --git a/tooling/test/client_utils/test_util.h b/tooling/test/client_utils/test_util.h index f7839cce..1f6f2184 100644 --- a/tooling/test/client_utils/test_util.h +++ b/tooling/test/client_utils/test_util.h @@ -19,7 +19,7 @@ #include "tooling/test/client_utils/test_actions.h" #include "tooling/client/manager/domain_manager.h" -#include "tooling/client/websocket/websocket_client.h" +#include "websocket/server/websocket_server.h" #include "ecmascript/jspandafile/js_pandafile_manager.h" #include "ecmascript/debugger/js_debugger.h" #include "os/mutex.h" diff --git a/websocket/BUILD.gn b/websocket/BUILD.gn index 32d3c7b7..43b5995f 100644 --- a/websocket/BUILD.gn +++ b/websocket/BUILD.gn @@ -13,46 +13,98 @@ import("//arkcompiler/toolchain/toolchain.gni") -ohos_source_set("websocket") { - stack_protector_ret = false +config("websocket_config") { defines = [] - deps = [] - configs = [ sdk_libc_secshared_config ] + include_dirs = [ + "$toolchain_root/inspector", + "//utils/native/base/include", + "//third_party/openssl/include", + ] + if (is_mingw || is_mac) { cflags = [ "-std=c++17" ] } + cflags_cc = [ "-Wno-vla-extension" ] +} + +websocket_configs = [ + sdk_libc_secshared_config, + "..:ark_toolchain_common_config", + ":websocket_config", +] + +websocket_deps = [ sdk_libc_secshared_dep ] +websocket_deps += hiviewdfx_deps +if (is_arkui_x && target_os == "ios") { + websocket_deps += [ "$ark_third_party_root/openssl:libcrypto_static" ] +} else if (is_mingw) { + websocket_deps += [ "$ark_third_party_root/openssl:libcrypto_restool" ] +} else { + websocket_deps += [ "$ark_third_party_root/openssl:libcrypto_shared" ] +} + +ohos_source_set("websocket_base") { + stack_protector_ret = false + + configs = websocket_configs + # hiviewdfx libraries external_deps = hiviewdfx_ext_deps - deps += hiviewdfx_deps + deps = websocket_deps if (target_os == "android" && !ark_standalone_build) { libs = [ "log" ] } - include_dirs = [] - - include_dirs += [ - "$toolchain_root/inspector", - "//utils/native/base/include", - "//third_party/openssl/include", + sources = [ + "frame_builder.cpp", + "handshake_helper.cpp", + "http.cpp", + "network.cpp", + "websocket_base.cpp", ] - sources = [ "websocket.cpp" ] + subsystem_name = "arkcompiler" + part_name = "toolchain" +} - deps += [ sdk_libc_secshared_dep ] - if (is_arkui_x && target_os == "ios") { - deps += [ "$ark_third_party_root/openssl:libcrypto_static" ] - } else if (is_mingw) { - deps += [ "$ark_third_party_root/openssl:libcrypto_restool" ] - } else { - deps += [ "$ark_third_party_root/openssl:libcrypto_shared" ] +ohos_source_set("websocket_server") { + stack_protector_ret = false + + configs = websocket_configs + + # hiviewdfx libraries + external_deps = hiviewdfx_ext_deps + deps = websocket_deps + deps += [ ":websocket_base" ] + + if (target_os == "android" && !ark_standalone_build) { + libs = [ "log" ] } - configs += [ "..:ark_toolchain_common_config" ] + sources = [ "server/websocket_server.cpp" ] - cflags_cc = [ "-Wno-vla-extension" ] + subsystem_name = "arkcompiler" + part_name = "toolchain" +} + +ohos_source_set("websocket_client") { + stack_protector_ret = false + + configs = websocket_configs + + # hiviewdfx libraries + external_deps = hiviewdfx_ext_deps + deps = websocket_deps + deps += [ ":websocket_base" ] + + if (target_os == "android" && !ark_standalone_build) { + libs = [ "log" ] + } + + sources = [ "client/websocket_client.cpp" ] subsystem_name = "arkcompiler" part_name = "toolchain" diff --git a/websocket/client/websocket_client.cpp b/websocket/client/websocket_client.cpp new file mode 100644 index 00000000..1a3530cd --- /dev/null +++ b/websocket/client/websocket_client.cpp @@ -0,0 +1,272 @@ +/* + * Copyright (c) 2023 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/log_wrapper.h" +#include "websocket/client/websocket_client.h" +#include "websocket/frame_builder.h" +#include "websocket/handshake_helper.h" +#include "websocket/network.h" +#include "websocket/string_utils.h" + +#include +#include +#include +#include +#include + +namespace OHOS::ArkCompiler::Toolchain { +bool WebSocketClient::InitToolchainWebSocketForPort(int port, uint32_t timeoutLimit) +{ + if (socketState_ != SocketState::UNINITED) { + LOGE("InitToolchainWebSocketForPort::client has inited."); + return true; + } + + connectionFd_ = socket(AF_INET, SOCK_STREAM, 0); + if (connectionFd_ < SOCKET_SUCCESS) { + LOGE("InitToolchainWebSocketForPort::client socket failed, error = %{public}d , desc = %{public}s", + errno, strerror(errno)); + return false; + } + + // set send and recv timeout limit + if (!SetWebSocketTimeOut(connectionFd_, timeoutLimit)) { + LOGE("InitToolchainWebSocketForPort::client SetWebSocketTimeOut failed, error = %{public}d , desc = %{public}s", + errno, strerror(errno)); + CloseConnectionSocketOnFail(); + return false; + } + + sockaddr_in clientAddr; + if (memset_s(&clientAddr, sizeof(clientAddr), 0, sizeof(clientAddr)) != EOK) { + LOGE("InitToolchainWebSocketForPort::client memset_s clientAddr failed, error = %{public}d, desc = %{public}s", + errno, strerror(errno)); + CloseConnectionSocketOnFail(); + return false; + } + clientAddr.sin_family = AF_INET; + clientAddr.sin_port = htons(port); + int ret = inet_pton(AF_INET, "127.0.0.1", &clientAddr.sin_addr); + if (ret != NET_SUCCESS) { + LOGE("InitToolchainWebSocketForPort::client inet_pton failed, error = %{public}d, desc = %{public}s", + errno, strerror(errno)); + CloseConnectionSocketOnFail(); + return false; + } + + ret = connect(connectionFd_, reinterpret_cast(&clientAddr), sizeof(clientAddr)); + if (ret != SOCKET_SUCCESS) { + LOGE("InitToolchainWebSocketForPort::client connect failed, error = %{public}d, desc = %{public}s", + errno, strerror(errno)); + CloseConnectionSocketOnFail(); + return false; + } + socketState_ = SocketState::INITED; + LOGI("InitToolchainWebSocketForPort::client connect success."); + return true; +} + +bool WebSocketClient::InitToolchainWebSocketForSockName(const std::string &sockName, uint32_t timeoutLimit) +{ + if (socketState_ != SocketState::UNINITED) { + LOGE("InitToolchainWebSocketForSockName::client has inited."); + return true; + } + + connectionFd_ = socket(AF_UNIX, SOCK_STREAM, 0); + if (connectionFd_ < SOCKET_SUCCESS) { + LOGE("InitToolchainWebSocketForSockName::client socket failed, error = %{public}d , desc = %{public}s", + errno, strerror(errno)); + return false; + } + + // set send and recv timeout limit + if (!SetWebSocketTimeOut(connectionFd_, timeoutLimit)) { + LOGE("InitToolchainWebSocketForSockName::client SetWebSocketTimeOut failed, error = %{public}d ,\ + desc = %{public}s", errno, strerror(errno)); + CloseConnectionSocketOnFail(); + return false; + } + + struct sockaddr_un serverAddr; + if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) { + LOGE("InitToolchainWebSocketForSockName::client memset_s clientAddr failed, error = %{public}d,\ + desc = %{public}s", errno, strerror(errno)); + CloseConnectionSocketOnFail(); + return false; + } + serverAddr.sun_family = AF_UNIX; + if (strcpy_s(serverAddr.sun_path + 1, sizeof(serverAddr.sun_path) - 1, sockName.c_str()) != EOK) { + LOGE("InitToolchainWebSocketForSockName::client strcpy_s serverAddr.sun_path failed, error = %{public}d,\ + desc = %{public}s", errno, strerror(errno)); + CloseConnectionSocketOnFail(); + return false; + } + serverAddr.sun_path[0] = '\0'; + + uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1; + int ret = connect(connectionFd_, reinterpret_cast(&serverAddr), static_cast(len)); + if (ret != SOCKET_SUCCESS) { + LOGE("InitToolchainWebSocketForSockName::client connect failed, error = %{public}d, desc = %{public}s", + errno, strerror(errno)); + CloseConnectionSocketOnFail(); + return false; + } + socketState_ = SocketState::INITED; + LOGI("InitToolchainWebSocketForSockName::client connect success."); + return true; +} + +bool WebSocketClient::ClientSendWSUpgradeReq() +{ + if (socketState_ == SocketState::UNINITED) { + LOGE("ClientSendWSUpgradeReq::client has not inited."); + return false; + } + if (socketState_ == SocketState::CONNECTED) { + LOGE("ClientSendWSUpgradeReq::client has connected."); + return true; + } + + // length without null-terminator + if (!Send(connectionFd_, CLIENT_WEBSOCKET_UPGRADE_REQ, sizeof(CLIENT_WEBSOCKET_UPGRADE_REQ) - 1, 0)) { + LOGE("ClientSendWSUpgradeReq::client send wsupgrade req failed, error = %{public}d, desc = %{public}sn", + errno, strerror(errno)); + CloseConnectionSocketOnFail(); + return false; + } + LOGI("ClientSendWSUpgradeReq::client send wsupgrade req success."); + return true; +} + +bool WebSocketClient::ClientRecvWSUpgradeRsp() +{ + if (socketState_ == SocketState::UNINITED) { + LOGE("ClientRecvWSUpgradeRsp::client has not inited."); + return false; + } + if (socketState_ == SocketState::CONNECTED) { + LOGE("ClientRecvWSUpgradeRsp::client has connected."); + return true; + } + + std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0); + auto msgLen = recv(connectionFd_, msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0); + if (msgLen <= 0) { + LOGE("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, error = %{public}d, desc = %{public}sn", + errno, strerror(errno)); + CloseConnectionSocketOnFail(); + return false; + } + // reduce to received size + msgBuf.resize(msgLen); + + HttpResponse response; + if (!HttpResponse::Decode(msgBuf, response) || !ValidateServerHandShake(response)) { + LOGE("ClientRecvWSUpgradeRsp::client server handshake response is invalid"); + CloseConnectionSocketOnFail(); + return false; + } + + socketState_ = SocketState::CONNECTED; + LOGI("ClientRecvWSUpgradeRsp::client recv wsupgrade rsp success."); + return true; +} + +/* static */ +bool WebSocketClient::ValidateServerHandShake(HttpResponse& response) +{ + // in accordance to https://www.rfc-editor.org/rfc/rfc6455#section-4.1 + if (response.status != HTTP_SWITCHING_PROTOCOLS_STATUS_CODE) { + return false; + } + ToLowerCase(response.upgrade); + if (response.upgrade != HTTP_RESPONSE_REQUIRED_UPGRADE) { + return false; + } + ToLowerCase(response.connection); + if (response.connection != HTTP_RESPONSE_REQUIRED_CONNECTION) { + return false; + } + + unsigned char expectedAcceptEncoding[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1]; + // TODO: assert `EncodeKey` return value + WebSocketKeyEncoder::EncodeKey(DEFAULT_WEB_SOCKET_KEY, expectedAcceptEncoding); + Trim(response.secWebSocketAccept); + if (response.secWebSocketAccept.size() != WebSocketKeyEncoder::ENCODED_KEY_LEN || + response.secWebSocketAccept.compare(reinterpret_cast(expectedAcceptEncoding)) != 0) { + return false; + } + + // may support two remaining checks + return true; +} + +bool WebSocketClient::DecodeMessage(WebSocketFrame& wsFrame, bool &isRecvFail) const +{ + uint64_t msgLen = wsFrame.payloadLen; + if (msgLen == 0) { + // receiving empty data is OK + return true; + } + auto& buffer = wsFrame.payload; + buffer.resize(msgLen, 0); + + if (!Recv(connectionFd_, buffer, 0)) { + LOGE("DecodeMessage: Recv message without mask failed"); + SetSocketFail(isRecvFail); + return false; + } + + return true; +} + +void WebSocketClient::Close() +{ + if (socketState_ == SocketState::CONNECTED) { + CloseConnection(CloseStatusCode::SERVER_GO_AWAY, SocketState::UNINITED); + } +} + +void WebSocketClient::CloseConnectionSocketOnFail() +{ + CloseConnectionSocket(ConnectionCloseReason::FAIL, SocketState::UNINITED); +} + +bool WebSocketClient::ValidateIncomingFrame(const WebSocketFrame& wsFrame) +{ + // "A server MUST NOT mask any frames that it sends to the client." + // https://www.rfc-editor.org/rfc/rfc6455#section-5.1 + return wsFrame.mask == 0; +} + +std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType) const +{ + ClientFrameBuilder builder(isLast, frameType, MASK_KEY); + return builder.Build(); +} + +std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const +{ + ClientFrameBuilder builder(isLast, frameType, MASK_KEY); + return builder.SetPayload(payload).Build(); +} + +std::string WebSocketClient::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const +{ + ClientFrameBuilder builder(isLast, frameType, MASK_KEY); + return builder.SetPayload(std::move(payload)).Build(); +} +} // namespace OHOS::ArkCompiler::Toolchain diff --git a/tooling/client/websocket/websocket_client.h b/websocket/client/websocket_client.h similarity index 49% rename from tooling/client/websocket/websocket_client.h rename to websocket/client/websocket_client.h index 3fb7f6bb..4d8bdfb3 100644 --- a/tooling/client/websocket/websocket_client.h +++ b/websocket/client/websocket_client.h @@ -13,72 +13,64 @@ * limitations under the License. */ -#ifndef ECMASCRIPT_TOOLING_CLIENT_WEBSOCKET_CLIENT_H -#define ECMASCRIPT_TOOLING_CLIENT_WEBSOCKET_CLIENT_H +#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_CLIENT_WEBSOCKET_CLIENT_H +#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_CLIENT_WEBSOCKET_CLIENT_H + +#include "websocket/http.h" +#include "websocket/websocket_base.h" #include #include -#include #include - -#include "ecmascript/log_wrapper.h" -#include "websocket/websocket.h" +#include namespace OHOS::ArkCompiler::Toolchain { -struct ToolchainWebSocketFrame { - uint8_t fin = 0; - uint8_t opcode = 0; - uint8_t mask = 0; - uint64_t payloadLen = 0; - char maskingkey[5] = {0}; - std::unique_ptr payload = nullptr; -}; -class WebsocketClient : public WebSocket { +class WebSocketClient final : public WebSocketBase { public: - enum ToolchainSocketState : uint8_t { - UNINITED, - INITED, - CONNECTED, - }; - WebsocketClient() = default; - ~WebsocketClient() = default; + ~WebSocketClient() noexcept override = default; + + bool DecodeMessage(WebSocketFrame& wsFrame, bool &isRecvFail) const override; + void Close() override; + bool InitToolchainWebSocketForPort(int port, uint32_t timeoutLimit = 5); bool InitToolchainWebSocketForSockName(const std::string &sockName, uint32_t timeoutLimit = 5); bool ClientSendWSUpgradeReq(); bool ClientRecvWSUpgradeRsp(); - bool ClientSendReq(const std::string &message); - std::string Decode(); - bool HandleFrame(ToolchainWebSocketFrame& wsFrame); - bool DecodeMessage(ToolchainWebSocketFrame& wsFrame); - uint64_t NetToHostLongLong(char* buf, uint32_t len); - bool Recv(int32_t fd, char* buf, size_t totalLen, int32_t flags) const; - bool Send(int32_t fd, const char* buf, size_t totalLen, int32_t flags) const; - void Close(); - bool SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit); - bool IsConnected(); - std::string GetSocketStateString(); private: - int32_t client_ {-1}; - std::atomic socketState_ {ToolchainSocketState::UNINITED}; - static constexpr int32_t CLIENT_WEBSOCKET_UPGRADE_RSP_LEN = 129; - static constexpr char CLIENT_WEBSOCKET_UPGRADE_REQ[] = "GET / HTTP/1.1\r\n" + static bool ValidateServerHandShake(HttpResponse& response); + + void CloseConnectionSocketOnFail(); + + bool ValidateIncomingFrame(const WebSocketFrame& wsFrame) override; + std::string CreateFrame(bool isLast, FrameType frameType) const override; + std::string CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const override; + std::string CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const override; + +private: + static constexpr std::string_view HTTP_SWITCHING_PROTOCOLS_STATUS_CODE = "101"; + static constexpr std::string_view HTTP_RESPONSE_REQUIRED_UPGRADE = "websocket"; + static constexpr std::string_view HTTP_RESPONSE_REQUIRED_CONNECTION = "upgrade"; + +#define WEB_SOCKET_KEY "64b4B+s5JDlgkdg7NekJ+g==" + static constexpr unsigned char DEFAULT_WEB_SOCKET_KEY[] = WEB_SOCKET_KEY; + static constexpr char CLIENT_WEBSOCKET_UPGRADE_REQ[] = "GET / HTTP/1.1\r\n" "Connection: Upgrade\r\n" "Pragma: no-cache\r\n" "Cache-Control: no-cache\r\n" "Upgrade: websocket\r\n" "Sec-WebSocket-Version: 13\r\n" "Accept-Encoding: gzip, deflate, br\r\n" - "Sec-WebSocket-Key: 64b4B+s5JDlgkdg7NekJ+g==\r\n" - "Sec-WebSocket-Extensions: permessage-deflate\r\n"; - static constexpr int32_t SOCKET_SUCCESS = 0; + "Sec-WebSocket-Key: " + WEB_SOCKET_KEY + "\r\n" + "Sec-WebSocket-Extensions: permessage-deflate\r\n" + "\r\n"; +#undef WEB_SOCKET_KEY + static constexpr int NET_SUCCESS = 1; - static constexpr int32_t SOCKET_MASK_LEN = 4; - static constexpr int32_t SOCKET_HEADER_LEN = 2; - static constexpr int32_t PAYLOAD_LEN = 2; - static constexpr int32_t EXTEND_PAYLOAD_LEN = 8; - static constexpr char MASK_KEY[SOCKET_MASK_LEN + 1] = "abcd"; + static constexpr uint8_t MASK_KEY[] = {0xa, 0xb, 0xc, 0xd}; }; } // namespace OHOS::ArkCompiler::Toolchain -#endif \ No newline at end of file +#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_CLIENT_WEBSOCKET_CLIENT_H diff --git a/websocket/define.h b/websocket/define.h index a17f168a..ba40e3c0 100644 --- a/websocket/define.h +++ b/websocket/define.h @@ -1,61 +1,36 @@ -/* - * Copyright (c) 2022-2023 Huawei Device Co., Ltd. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_DEFINE_H -#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_DEFINE_H - -#include -#include -#include -#include -#include -#include -#if defined(WINDOWS_PLATFORM) -#include -#include -#ifdef ERROR -#undef ERROR -#endif -#else -#include -#include -#include -#include -#endif -#include -#include - -namespace OHOS::ArkCompiler::Toolchain { -std::vector ProtocolSplit(const std::string& str, const std::string& input) -{ - std::vector result; - size_t prev = 0; - size_t len = input.length(); - size_t cur = str.find(input); - while (cur != std::string::npos) { - std::string tmp = str.substr(prev, cur - prev); - result.push_back(tmp); - prev = cur + len; - cur = str.find(input, prev); - } - if (prev < str.size()) { - std::string tmp = str.substr(prev); - result.push_back(tmp); - } - return result; -} -} // namespace OHOS::ArkCompiler::Toolchain - -#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_DEFINE_H +/* + * Copyright (c) 2022-2023 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_DEFINE_H +#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_DEFINE_H + +#include +#include +#include +#if defined(WINDOWS_PLATFORM) +#include +#include +#ifdef ERROR +#undef ERROR +#endif +#else +#include +#include +#include +#include +#endif +#include + +#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_DEFINE_H diff --git a/websocket/frame_builder.cpp b/websocket/frame_builder.cpp new file mode 100644 index 00000000..97ed10c4 --- /dev/null +++ b/websocket/frame_builder.cpp @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "websocket/frame_builder.h" + +#include + +namespace OHOS::ArkCompiler::Toolchain { +ServerFrameBuilder& ServerFrameBuilder::SetFinal(bool fin) +{ + fin_ = fin; + return *this; +} + +ServerFrameBuilder& ServerFrameBuilder::SetOpcode(FrameType opcode) +{ + opcode_ = opcode; + return *this; +} + +ServerFrameBuilder& ServerFrameBuilder::SetPayload(const std::string& payload) +{ + payload_ = payload; + return *this; +} + +ServerFrameBuilder& ServerFrameBuilder::SetPayload(std::string&& payload) +{ + payload_ = std::move(payload); + return *this; +} + +ServerFrameBuilder& ServerFrameBuilder::AppendPayload(const std::string& payload) +{ + payload_.append(payload); + return *this; +} + +std::string ServerFrameBuilder::Build() const +{ + std::string message; + PushFullHeader(message, 0); + PushPayload(message); + return message; +} + +void ServerFrameBuilder::PushFullHeader(std::string& message, size_t additionalReservedMem) const +{ + auto headerBytes = WebSocketFrame::HEADER_LEN; + auto payloadBytes = payload_.size(); + uint8_t payloadLenField = 0; + + if (payloadBytes <= WebSocketFrame::ONE_BYTE_LENTH_ENC_LIMIT) { + payloadLenField = static_cast(payloadBytes); + } else if (payloadBytes <= std::numeric_limits::max()) { // condition equals to `payloadBytes < 65536` + payloadLenField = WebSocketFrame::TWO_BYTES_LENTH_ENC; + headerBytes += WebSocketFrame::TWO_BYTES_LENTH; + } else { + payloadLenField = WebSocketFrame::EIGHT_BYTES_LENTH_ENC; + headerBytes += WebSocketFrame::EIGHT_BYTES_LENTH; + } + + message.reserve(headerBytes + payloadBytes + additionalReservedMem); + PushHeader(message, payloadLenField); + PushPayloadLength(message, payloadLenField); +} + +void ServerFrameBuilder::PushHeader(std::string& message, uint8_t payloadLenField) const +{ + uint8_t byte = EnumToNumber(opcode_); + if (fin_) { + byte |= 0x80; + } + message.push_back(byte); + + // A server MUST NOT mask any frames that it sends to the client, + // hence mask bit must be set to zero (see https://www.rfc-editor.org/rfc/rfc6455#section-5.1) + byte = payloadLenField & 0x7f; + message.push_back(byte); +} + +void ServerFrameBuilder::PushPayloadLength(std::string& message, uint8_t payloadLenField) const +{ + uint64_t payloadLen = payload_.size(); + if (payloadLenField == WebSocketFrame::TWO_BYTES_LENTH_ENC) { + PushNumberPerByte(message, static_cast(payloadLen)); + } else if (payloadLenField == WebSocketFrame::EIGHT_BYTES_LENTH_ENC) { + PushNumberPerByte(message, payloadLen); + } +} + +void ServerFrameBuilder::PushPayload(std::string& message) const +{ + message.append(payload_); +} + +ClientFrameBuilder::ClientFrameBuilder(bool final, FrameType opcode, const uint8_t maskingKey[WebSocketFrame::MASK_LEN]) + : ServerFrameBuilder(final, opcode) +{ + SetMask(maskingKey); +} + +ClientFrameBuilder& ClientFrameBuilder::SetMask(const uint8_t maskingKey[WebSocketFrame::MASK_LEN]) +{ + for (size_t i = 0; i < WebSocketFrame::MASK_LEN; ++i) { + maskingKey_[i] = maskingKey[i]; + } + return *this; +} + +void ClientFrameBuilder::PushFullHeader(std::string& message, size_t additionalReservedMem) const +{ + // reserve additional 4 bytes for mask + ServerFrameBuilder::PushFullHeader(message, additionalReservedMem + WebSocketFrame::MASK_LEN); + // If the data is being sent by the client, the frame(s) MUST be masked + // (see https://www.rfc-editor.org/rfc/rfc6455#section-6.1) + message[1] |= 0x80; + PushMask(message); +} + +void ClientFrameBuilder::PushPayload(std::string& message) const +{ + // push masked payload + for (size_t i = 0, end = payload_.size(); i < end; ++i) { + char c = payload_[i] ^ maskingKey_[i % WebSocketFrame::MASK_LEN]; + message.push_back(c); + } +} + +void ClientFrameBuilder::PushMask(std::string& message) const +{ + for (size_t i = 0; i < WebSocketFrame::MASK_LEN; ++i) { + message.push_back(static_cast(maskingKey_[i])); + } +} +} // OHOS::ArkCompiler::Toolchain diff --git a/websocket/frame_builder.h b/websocket/frame_builder.h new file mode 100644 index 00000000..b932ae27 --- /dev/null +++ b/websocket/frame_builder.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_FRAME_BUILDER_H +#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_FRAME_BUILDER_H + +#include "websocket/web_socket_frame.h" + +#include + +namespace OHOS::ArkCompiler::Toolchain { +template >> +inline void PushNumberPerByte(std::string& message, T number) +{ + constexpr size_t bytesCount = sizeof(T); + constexpr size_t bitsCount = 8; + size_t shiftCount = (bytesCount - 1) * bitsCount; + for (size_t i = 0; i < bytesCount; ++i, shiftCount -= bitsCount) { + message.push_back((number >> shiftCount) & 0xff); + } +} + +class ServerFrameBuilder { +public: + // force users to specify opcode and final bit + ServerFrameBuilder() = delete; + ServerFrameBuilder(bool final, FrameType opcode) : fin_(final), opcode_(opcode) + { + } + ~ServerFrameBuilder() noexcept = default; + + ServerFrameBuilder& SetFinal(bool fin); + ServerFrameBuilder& SetOpcode(FrameType opcode); + ServerFrameBuilder& SetPayload(const std::string& payload); + ServerFrameBuilder& SetPayload(std::string&& payload); + ServerFrameBuilder& AppendPayload(const std::string& payload); + + std::string Build() const; + +protected: + void PushHeader(std::string& message, uint8_t payloadLenField) const; + void PushPayloadLength(std::string& message, uint8_t payloadLenField) const; + virtual void PushFullHeader(std::string& message, size_t additionalReservedMem) const; + virtual void PushPayload(std::string& message) const; + +protected: + bool fin_; + FrameType opcode_; + std::string payload_; +}; + +class ClientFrameBuilder final : public ServerFrameBuilder { +public: + ClientFrameBuilder(bool final, FrameType opcode, const uint8_t maskingKey[WebSocketFrame::MASK_LEN]); + + ClientFrameBuilder& SetMask(const uint8_t maskingKey[WebSocketFrame::MASK_LEN]); + +private: + void PushMask(std::string& message) const; + void PushFullHeader(std::string& message, size_t additionalReservedMem) const override; + void PushPayload(std::string& message) const override; + +private: + uint8_t maskingKey_[WebSocketFrame::MASK_LEN] = {0}; +}; +} // namespace OHOS::ArkCompiler::Toolchain + +#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_FRAME_BUILDER_H diff --git a/websocket/handshake_helper.cpp b/websocket/handshake_helper.cpp new file mode 100644 index 00000000..cf1d30e6 --- /dev/null +++ b/websocket/handshake_helper.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/log_wrapper.h" +#include "websocket/handshake_helper.h" + +namespace OHOS::ArkCompiler::Toolchain { +/* static */ +bool WebSocketKeyEncoder::EncodeKey(std::string_view key, unsigned char (&destination)[ENCODED_KEY_LEN + 1]) +{ + std::string buffer(key.size() + WEB_SOCKET_GUID.size(), 0); + key.copy(buffer.data(), key.size()); + WEB_SOCKET_GUID.copy(buffer.data() + key.size(), WEB_SOCKET_GUID.size()); + + return EncodeKey(reinterpret_cast(buffer.data()), buffer.size(), destination); +} + +/* static */ +bool WebSocketKeyEncoder::EncodeKey(const unsigned char(&key)[KEY_LENGTH + 1], + unsigned char (&destination)[ENCODED_KEY_LEN + 1]) +{ + constexpr size_t bufferSize = KEY_LENGTH + WEB_SOCKET_GUID.size(); + unsigned char buffer[bufferSize]; + auto *guid = std::copy(key, key + KEY_LENGTH, buffer); + WEB_SOCKET_GUID.copy(reinterpret_cast(guid), WEB_SOCKET_GUID.size()); + + return EncodeKey(buffer, bufferSize, destination); +} + +/* static */ +bool WebSocketKeyEncoder::EncodeKey(const unsigned char *source, size_t length, + unsigned char (&destination)[ENCODED_KEY_LEN + 1]) +{ + unsigned char hash[SHA_DIGEST_LENGTH]; + SHA1(source, length, hash); + + // base64-encoding is done via EVP_EncodeBlock, which writes a null-terminated string. + int encodedBytes = EVP_EncodeBlock(destination, hash, SHA_DIGEST_LENGTH); + // "EVP_EncodeBlock() returns the number of bytes encoded excluding the NUL terminator." + if (encodedBytes != ENCODED_KEY_LEN) { + LOGE("EVP_EncodeBlock failed to encode all bytes, encodedBytes = %{public}d", encodedBytes); + return false; + } + return true; +} +} // namespace OHOS::ArkCompiler::Toolchain diff --git a/websocket/handshake_helper.h b/websocket/handshake_helper.h new file mode 100644 index 00000000..13e4b94b --- /dev/null +++ b/websocket/handshake_helper.h @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_HANDSHAKE_HELPER_H +#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_HANDSHAKE_HELPER_H + +#include "websocket/define.h" +#include "websocket/http.h" +#include "websocket/network.h" + +#include +#include + +namespace OHOS::ArkCompiler::Toolchain { +class WebSocketKeyEncoder { +public: + // WebSocket Globally Unique Identifier + static constexpr std::string_view WEB_SOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + // The value of |Sec-WebSocket-Key| header field MUST be a nonce consisting of a randomly selected 16-byte value + static constexpr size_t KEY_LENGTH = GetBase64EncodingLength(16); + // SHA1 will write SHA_DIGEST_LENGTH == 20 bytes of output + static constexpr size_t ENCODED_KEY_LEN = GetBase64EncodingLength(SHA_DIGEST_LENGTH); + + static bool EncodeKey(std::string_view key, unsigned char (&destination)[ENCODED_KEY_LEN + 1]); + static bool EncodeKey(const unsigned char(&key)[KEY_LENGTH + 1], unsigned char (&destination)[ENCODED_KEY_LEN + 1]); + +private: + static bool EncodeKey(const unsigned char *source, size_t length, + unsigned char (&destination)[ENCODED_KEY_LEN + 1]); +}; + +class ProtocolUpgradeBuilder { +private: + constexpr size_t copyStringToBuffer(std::string_view source, size_t startIndex) + { + for (size_t i = 0, end = source.size(); i < end; ++i, ++startIndex) { + upgrade_buffer_[startIndex] = source[i]; + } + return startIndex; + } + + template + constexpr size_t copyStringToBuffer(const T (&source)[LENGTH], size_t startIndex) + { + for (size_t i = 0, end = LENGTH - 1; i < end; ++i, ++startIndex) { + upgrade_buffer_[startIndex] = source[i]; + } + return startIndex; + } + +public: + constexpr ProtocolUpgradeBuilder() + { + size_t index = copyStringToBuffer(SWITCHING_PROTOCOLS, 0); + index = copyStringToBuffer(HttpBase::EOL, index); + index = copyStringToBuffer(CONNECTION_UPGRADE, index); + index = copyStringToBuffer(HttpBase::EOL, index); + index = copyStringToBuffer(UPGRADE_WEBSOCKET, index); + index = copyStringToBuffer(HttpBase::EOL, index); + index = copyStringToBuffer(ACCEPT_KEY, index); + // will copy key without null terminator + index += WebSocketKeyEncoder::ENCODED_KEY_LEN; + index = copyStringToBuffer(HttpBase::EOL, index); + index = copyStringToBuffer(HttpBase::EOL, index); + } + + constexpr explicit ProtocolUpgradeBuilder( + const unsigned char (&encodedKey)[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1]) + : ProtocolUpgradeBuilder() + { + SetKey(encodedKey); + } + + constexpr void SetKey(const unsigned char (&encodedKey)[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1]) + { + copyStringToBuffer(encodedKey, KEY_START); + } + + constexpr const char *GetUpgradeMessage() + { + return upgrade_buffer_.data(); + } + + static constexpr size_t GetLength() + { + return MESSAGE_LENGTH; + } + +private: + static constexpr std::string_view SWITCHING_PROTOCOLS = "HTTP/1.1 101 Switching Protocols"; + static constexpr std::string_view CONNECTION_UPGRADE = "Connection: Upgrade"; + static constexpr std::string_view UPGRADE_WEBSOCKET = "Upgrade: websocket"; + static constexpr std::string_view ACCEPT_KEY = "Sec-WebSocket-Accept: "; + static constexpr size_t KEY_START = SWITCHING_PROTOCOLS.size() + + CONNECTION_UPGRADE.size() + + UPGRADE_WEBSOCKET.size() + + ACCEPT_KEY.size() + + 3 * HttpBase::EOL.size(); + static constexpr size_t MESSAGE_LENGTH = KEY_START + + WebSocketKeyEncoder::ENCODED_KEY_LEN + + 2 * HttpBase::EOL.size(); + +private: + // null-terminated string buffer + std::array upgrade_buffer_ {0}; +}; +} // namespace OHOS::ArkCompiler::Toolchain + +#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_HANDSHAKE_HELPER_H diff --git a/websocket/http.cpp b/websocket/http.cpp new file mode 100644 index 00000000..479970ab --- /dev/null +++ b/websocket/http.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/log_wrapper.h" +#include "websocket/http.h" + +namespace OHOS::ArkCompiler::Toolchain { +/* static */ +std::string HttpBase::DecodeHeader(const std::string& headersText, std::string_view headerName) +{ + auto startPos = headersText.find(headerName); + if (startPos != std::string::npos) { + auto endOfLinePos = headersText.find(EOL, startPos); + startPos += headerName.size(); + if (startPos < headersText.size() && startPos < endOfLinePos) { + return headersText.substr(startPos, endOfLinePos - startPos); + } + } + return ""; +} + +/* static */ +std::string HttpRequest::DecodeVersion(const std::string& request, std::string::size_type methodStartPos) +{ + if (methodStartPos >= request.size()) { + return ""; + } + + auto endOfLinePos = request.find(EOL, methodStartPos); + // the typical header is "GET /chat HTTP/1.1", where protocol version is located after the second space symbol + methodStartPos = request.find(' ', methodStartPos); + if (methodStartPos != std::string::npos) { + methodStartPos = request.find(' ', methodStartPos + 1); + } + if (methodStartPos != std::string::npos && + methodStartPos + 1 < request.size() && + methodStartPos + 1 < endOfLinePos) { + return request.substr(methodStartPos + 1, endOfLinePos - (methodStartPos + 1)); + } + return ""; +} + +// request example can be found at https://www.rfc-editor.org/rfc/rfc6455#section-1.3 +/* static */ +bool HttpRequest::Decode(const std::string& request, HttpRequest& parsed) +{ + auto pos = request.find(GET); + if (pos == std::string::npos) { + LOGW("Handshake failed: lack of necessary info"); + return false; + } + + parsed.version = DecodeVersion(request, pos); + parsed.connection = DecodeHeader(request, CONNECTION); + parsed.upgrade = DecodeHeader(request, UPGRADE); + parsed.secWebSocketKey = DecodeHeader(request, SEC_WEBSOCKET_KEY); + + return true; +} + +/* static */ +std::string HttpResponse::DecodeVersion(const std::string& response, std::string::size_type versionStartPos) +{ + // status-line example: "HTTP/1.1 404 Not Found" + if (versionStartPos < response.size()) { + auto versionEndPos = response.find(' ', versionStartPos); + if (versionEndPos != std::string::npos) { + return response.substr(versionStartPos, versionEndPos - versionStartPos); + } + } + return ""; +} + +/* static */ +std::string HttpResponse::DecodeStatus(const std::string& response, std::string::size_type versionEndPos) +{ + // status-line example: "HTTP/1.1 404 Not Found" + if (versionEndPos < response.size() && response[versionEndPos] == ' ') { + auto statusStartPos = response.find_first_not_of(' ', versionEndPos); + if (statusStartPos != std::string::npos) { + auto statusEndPos = response.find(' ', statusStartPos); + if (statusEndPos != std::string::npos || + (statusEndPos = response.find(EOL, statusStartPos)) != std::string::npos) { + return response.substr(statusStartPos, statusEndPos - statusStartPos); + } + } + } + return ""; +} + +// request example can be found at https://www.rfc-editor.org/rfc/rfc6455#section-1.2 +/* static */ +bool HttpResponse::Decode(const std::string& response, HttpResponse& parsed) +{ + // find start of status-line + auto versionStartPos = response.find("HTTP"); + if (versionStartPos == std::string::npos) { + LOGW("Handshake failed: lack of necessary info, no status-line found"); + return false; + } + + parsed.version = DecodeVersion(response, versionStartPos); + parsed.status = DecodeStatus(response, versionStartPos + parsed.version.size()); + parsed.connection = DecodeHeader(response, CONNECTION); + parsed.upgrade = DecodeHeader(response, UPGRADE); + parsed.secWebSocketAccept = DecodeHeader(response, SEC_WEBSOCKET_ACCEPT); + + return true; +} +} // namespace OHOS::ArkCompiler::Toolchain diff --git a/websocket/http.h b/websocket/http.h new file mode 100644 index 00000000..8955404f --- /dev/null +++ b/websocket/http.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_HTTP_H +#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_HTTP_H + +#include + +namespace OHOS::ArkCompiler::Toolchain { +struct HttpBase { + static constexpr std::string_view EOL = "\r\n"; + static constexpr std::string_view GET = "GET"; + static constexpr std::string_view CONNECTION = "Connection: "; + static constexpr std::string_view UPGRADE = "Upgrade: "; + static constexpr std::string_view SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept: "; + static constexpr std::string_view SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key: "; + + static std::string DecodeHeader(const std::string& headersText, std::string_view headerName); +}; + + +struct HttpRequest : private HttpBase { + std::string version; + std::string connection; + std::string upgrade; + std::string secWebSocketKey; + + static bool Decode(const std::string& request, HttpRequest& parsed); + +private: + static std::string DecodeVersion(const std::string& request, std::string::size_type methodStartPos); +}; + + +struct HttpResponse : private HttpBase { + std::string version; + std::string status; + std::string connection; + std::string upgrade; + std::string secWebSocketAccept; + + static bool Decode(const std::string& response, HttpResponse& parsed); + +private: + static std::string DecodeVersion(const std::string& response, std::string::size_type versionStartPos); + static std::string DecodeStatus(const std::string& response, std::string::size_type versionEndPos); +}; +} // namespace OHOS::ArkCompiler::Toolchain + +#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_HTTP_H diff --git a/websocket/network.cpp b/websocket/network.cpp new file mode 100644 index 00000000..f960760e --- /dev/null +++ b/websocket/network.cpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/log_wrapper.h" +#include "websocket/define.h" +#include "websocket/network.h" + +namespace OHOS::ArkCompiler::Toolchain { +bool Recv(int32_t client, std::string& buffer, int32_t flags) +{ + if (buffer.empty()) { + return false; + } + auto succeeded = Recv(client, buffer.data(), buffer.size(), flags); + if (!succeeded) { + buffer.clear(); + } + return succeeded; +} + +bool Recv(int32_t client, char* buf, size_t totalLen, int32_t flags) +{ + size_t recvLen = 0; + while (recvLen < totalLen) { + ssize_t len = recv(client, buf + recvLen, totalLen - recvLen, flags); + if (len <= 0) { + LOGE("Recv payload in while failed, len = %{public}ld, errno = %{public}d", static_cast(len), errno); + return false; + } + recvLen += static_cast(len); + } + return true; +} + +bool Recv(int32_t client, uint8_t* buf, size_t totalLen, int32_t flags) +{ + return Recv(client, reinterpret_cast(buf), totalLen, flags); +} + +bool Send(int32_t client, const std::string& message, int32_t flags) +{ + return Send(client, message.c_str(), message.size(), flags); +} + +bool Send(int32_t client, const char* buf, size_t totalLen, int32_t flags) +{ + size_t sendLen = 0; + while (sendLen < totalLen) { + ssize_t len = send(client, buf + sendLen, totalLen - sendLen, flags); + if (len <= 0) { + LOGE("Send Message in while failed, len = %{public}ld, errno = %{public}d", static_cast(len), errno); + return false; + } + sendLen += static_cast(len); + } + return true; +} + +uint64_t NetToHostLongLong(uint8_t* buf, uint32_t len) +{ + uint64_t result = 0; + for (uint32_t i = 0; i < len; i++) { + result |= buf[i]; + if ((i + 1) < len) { + result <<= 8; // 8: result need shift left 8 bits in order to big endian convert to int + } + } + return result; +} +} // OHOS::ArkCompiler::Toolchain diff --git a/websocket/network.h b/websocket/network.h new file mode 100644 index 00000000..013c0c56 --- /dev/null +++ b/websocket/network.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_NETWORK_H +#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_NETWORK_H + +#include + +namespace OHOS::ArkCompiler::Toolchain { +// Receives a message of size `buffer.size()`. Clears the string buffer on error. +bool Recv(int32_t client, std::string& buffer, int32_t flags); + +bool Recv(int32_t client, char* buf, size_t totalLen, int32_t flags); + +bool Recv(int32_t client, uint8_t* buf, size_t totalLen, int32_t flags); + +bool Send(int32_t client, const std::string& message, int32_t flags); + +bool Send(int32_t client, const char* buf, size_t totalLen, int32_t flags); + +uint64_t NetToHostLongLong(uint8_t* buf, uint32_t len); + +constexpr inline size_t GetBase64EncodingLength(size_t inputLength) +{ + size_t paddingOffset = 2; + // base64-encoding produces padded output of 4 characters for every 3 bytes input + return ((inputLength + paddingOffset) / 3) * 4; +} +} // namespace OHOS::ArkCompiler::Toolchain + +#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_NETWORK_H diff --git a/websocket/server/websocket_server.cpp b/websocket/server/websocket_server.cpp new file mode 100644 index 00000000..fea5868b --- /dev/null +++ b/websocket/server/websocket_server.cpp @@ -0,0 +1,310 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/log_wrapper.h" +#include "websocket/frame_builder.h" +#include "websocket/handshake_helper.h" +#include "websocket/network.h" +#include "websocket/server/websocket_server.h" + +namespace OHOS::ArkCompiler::Toolchain { +bool WebSocketServer::DecodeMessage(WebSocketFrame& wsFrame, bool &isRecvFail) const +{ + uint64_t msgLen = wsFrame.payloadLen; + if (msgLen == 0) { + // receiving empty data is OK + return true; + } + auto& buffer = wsFrame.payload; + buffer.resize(msgLen, 0); + + if (!Recv(connectionFd_, wsFrame.maskingKey, sizeof(wsFrame.maskingKey), 0)) { + LOGE("DecodeMessage: Recv maskingKey failed"); + SetSocketFail(isRecvFail); + return false; + } + + if (!Recv(connectionFd_, buffer, 0)) { + LOGE("DecodeMessage: Recv message with mask failed"); + SetSocketFail(isRecvFail); + return false; + } + + for (uint64_t i = 0; i < msgLen; i++) { + auto j = i % WebSocketFrame::MASK_LEN; + buffer[i] = static_cast(buffer[i]) ^ wsFrame.maskingKey[j]; + } + + return true; +} + +bool WebSocketServer::ProtocolUpgrade(const HttpRequest& req) +{ + unsigned char encodedKey[WebSocketKeyEncoder::ENCODED_KEY_LEN + 1]; + if (!WebSocketKeyEncoder::EncodeKey(req.secWebSocketKey, encodedKey)) { + LOGE("ProtocolUpgrade: failed to encode WebSocket-Key"); + return false; + } + + ProtocolUpgradeBuilder request_builder(encodedKey); + if (!Send(connectionFd_, request_builder.GetUpgradeMessage(), request_builder.GetLength(), 0)) { + LOGE("ProtocolUpgrade: Send failed"); + return false; + } + return true; +} + +bool WebSocketServer::ResponseInvalidHandShake() const +{ + std::string response(BAD_REQUEST_RESPONSE); + return Send(connectionFd_, response, 0); +} + +bool WebSocketServer::HttpHandShake() +{ + std::string msgBuf(HTTP_HANDSHAKE_MAX_LEN, 0); + ssize_t msgLen = recv(connectionFd_, msgBuf.data(), HTTP_HANDSHAKE_MAX_LEN, 0); + if (msgLen <= 0) { + LOGE("ReadMsg failed, msgLen = %{public}ld, errno = %{public}d", static_cast(msgLen), errno); + return false; + } + // reduce to received size + msgBuf.resize(msgLen); + + HttpRequest req; + if (!HttpRequest::Decode(msgBuf, req)) { + LOGE("HttpHandShake: Upgrade failed"); + return false; + } + if (validateCb_ && !validateCb_(req)) { + LOGE("HttpHandShake: Validation failed"); + return false; + } + + if (ValidateHandShakeMessage(req)) { + return ProtocolUpgrade(req); + } + + LOGE("HttpHandShake: HTTP upgrade parameters failure"); + if (!ResponseInvalidHandShake()) { + LOGE("HttpHandShake: failed to send 'bad request' response"); + } + return false; +} + +/* static */ +bool WebSocketServer::ValidateHandShakeMessage(const HttpRequest& req) +{ + return req.connection.find("Upgrade") != std::string::npos && + req.upgrade.find("websocket") != std::string::npos && + req.version.compare("HTTP/1.1") == 0; +} + +bool WebSocketServer::AcceptNewConnection() +{ + if (socketState_ == SocketState::UNINITED) { + LOGE("AcceptNewConnection failed, websocket not inited"); + return false; + } + if (socketState_ == SocketState::CONNECTED) { + LOGI("AcceptNewConnection websocket has connected"); + return true; + } + + // TODO: this must be better done in async manner via epoll + if ((connectionFd_ = accept(server_fd_, nullptr, nullptr)) < SOCKET_SUCCESS) { + LOGI("AcceptNewConnection accept has exited"); + return false; + } + + if (!HttpHandShake()) { + LOGE("AcceptNewConnection HttpHandShake failed"); + CloseConnectionSocket(ConnectionCloseReason::FAIL, SocketState::INITED); + return false; + } + OnNewConnection(); + return true; +} + +#if !defined(OHOS_PLATFORM) +bool WebSocketServer::InitTcpWebSocket(int port, uint32_t timeoutLimit) +{ + if (port < 0) { + LOGE("InitTcpWebSocket invalid port"); + return false; + } + if (socketState_ != SocketState::UNINITED) { + LOGI("InitTcpWebSocket websocket has inited"); + return true; + } +#if defined(WINDOWS_PLATFORM) + WORD sockVersion = MAKEWORD(2, 2); // 2: version 2.2 + WSADATA wsaData; + if (WSAStartup(sockVersion, &wsaData) != 0) { + LOGE("InitTcpWebSocket WSA init failed"); + return false; + } +#endif + server_fd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (server_fd_ < SOCKET_SUCCESS) { + LOGE("InitTcpWebSocket socket init failed, errno = %{public}d", errno); + return false; + } + // allow specified port can be used at once and not wait TIME_WAIT status ending + int sockOptVal = 1; + if ((setsockopt(server_fd_, SOL_SOCKET, SO_REUSEADDR, + reinterpret_cast(&sockOptVal), sizeof(sockOptVal))) != SOCKET_SUCCESS) { + LOGE("InitTcpWebSocket setsockopt SO_REUSEADDR failed, errno = %{public}d", errno); + CloseServerSocket(); + return false; + } + // set send and recv timeout + if (!SetWebSocketTimeOut(server_fd_, timeoutLimit)) { + LOGE("InitTcpWebSocket SetWebSocketTimeOut failed"); + CloseServerSocket(); + return false; + } + sockaddr_in addr_sin = {}; + addr_sin.sin_family = AF_INET; + addr_sin.sin_port = htons(port); + addr_sin.sin_addr.s_addr = INADDR_ANY; + if (bind(server_fd_, reinterpret_cast(&addr_sin), sizeof(addr_sin)) != SOCKET_SUCCESS) { + LOGE("InitTcpWebSocket bind failed, errno = %{public}d", errno); + CloseServerSocket(); + return false; + } + if (listen(server_fd_, 1) != SOCKET_SUCCESS) { + LOGE("InitTcpWebSocket listen failed, errno = %{public}d", errno); + CloseServerSocket(); + return false; + } + socketState_ = SocketState::INITED; + return true; +} +#else +bool WebSocketServer::InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit) +{ + if (socketState_ != SocketState::UNINITED) { + LOGI("InitUnixWebSocket websocket has inited"); + return true; + } + server_fd_ = socket(AF_UNIX, SOCK_STREAM, 0); // 0: default protocol + if (server_fd_ < SOCKET_SUCCESS) { + LOGE("InitUnixWebSocket socket init failed, errno = %{public}d", errno); + return false; + } + // set send and recv timeout + if (!SetWebSocketTimeOut(server_fd_, timeoutLimit)) { + LOGE("InitUnixWebSocket SetWebSocketTimeOut failed"); + CloseServerSocket(); + return false; + } + + struct sockaddr_un un; + if (memset_s(&un, sizeof(un), 0, sizeof(un)) != EOK) { + LOGE("InitUnixWebSocket memset_s failed"); + CloseServerSocket(); + return false; + } + un.sun_family = AF_UNIX; + if (strcpy_s(un.sun_path + 1, sizeof(un.sun_path) - 1, sockName.c_str()) != EOK) { + LOGE("InitUnixWebSocket strcpy_s failed"); + CloseServerSocket(); + return false; + } + un.sun_path[0] = '\0'; + uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1; + if (bind(server_fd_, reinterpret_cast(&un), static_cast(len)) != SOCKET_SUCCESS) { + LOGE("InitUnixWebSocket bind failed, errno = %{public}d", errno); + CloseServerSocket(); + return false; + } + if (listen(server_fd_, 1) != SOCKET_SUCCESS) { // 1: connection num + LOGE("InitUnixWebSocket listen failed, errno = %{public}d", errno); + CloseServerSocket(); + return false; + } + socketState_ = SocketState::INITED; + return true; +} +#endif + +void WebSocketServer::CloseServerSocket() +{ + close(server_fd_); + server_fd_ = -1; + socketState_ = SocketState::UNINITED; +} + +void WebSocketServer::OnNewConnection() +{ + LOGI("New client connected"); + socketState_ = SocketState::CONNECTED; + if (openCb_) { + openCb_(); + } +} + +void WebSocketServer::SetValidateConnectionCallback(ValidateConnectionCallback cb) +{ + validateCb_ = std::move(cb); +} + +void WebSocketServer::SetOpenConnectionCallback(OpenConnectionCallback cb) +{ + openCb_ = std::move(cb); +} + +bool WebSocketServer::ValidateIncomingFrame(const WebSocketFrame& wsFrame) +{ + // "The server MUST close the connection upon receiving a frame that is not masked." + // https://www.rfc-editor.org/rfc/rfc6455#section-5.1 + return wsFrame.mask == 1; +} + +std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType) const +{ + ServerFrameBuilder builder(isLast, frameType); + return builder.Build(); +} + +std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const +{ + ServerFrameBuilder builder(isLast, frameType); + return builder.SetPayload(payload).Build(); +} + +std::string WebSocketServer::CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const +{ + ServerFrameBuilder builder(isLast, frameType); + return builder.SetPayload(std::move(payload)).Build(); +} + +void WebSocketServer::Close() +{ + if (socketState_ == SocketState::UNINITED) { + return; + } + if (socketState_ == SocketState::CONNECTED) { + CloseConnection(CloseStatusCode::SERVER_GO_AWAY, SocketState::INITED); + } + // TODO: research why sleep is needed + usleep(10000); // 10000: time for websocket to enter the accept +#if defined(OHOS_PLATFORM) + shutdown(server_fd_, SHUT_RDWR); +#endif + CloseServerSocket(); +} +} // namespace OHOS::ArkCompiler::Toolchain diff --git a/websocket/server/websocket_server.h b/websocket/server/websocket_server.h new file mode 100644 index 00000000..8573cef8 --- /dev/null +++ b/websocket/server/websocket_server.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_SERVER_WEBSOCKET_SERVER_H +#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_SERVER_WEBSOCKET_SERVER_H + +#include "websocket/http.h" +#include "websocket/websocket_base.h" + +#include + +namespace OHOS::ArkCompiler::Toolchain { +class WebSocketServer final : public WebSocketBase { +public: + using ValidateConnectionCallback = std::function; + using OpenConnectionCallback = std::function; + +public: + ~WebSocketServer() noexcept override = default; + + bool AcceptNewConnection(); + +#if !defined(OHOS_PLATFORM) + // Initialize server socket, transition to `INITED` state. + bool InitTcpWebSocket(int port, uint32_t timeoutLimit = 0); +#else + // Initialize server socket, transition to `INITED` state. + bool InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit = 0); +#endif + + void SetValidateConnectionCallback(ValidateConnectionCallback cb); + void SetOpenConnectionCallback(OpenConnectionCallback cb); + + void Close() override; + +private: + static bool ValidateHandShakeMessage(const HttpRequest& req); + + bool ValidateIncomingFrame(const WebSocketFrame& wsFrame) override; + std::string CreateFrame(bool isLast, FrameType frameType) const override; + std::string CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const override; + std::string CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const override; + bool DecodeMessage(WebSocketFrame& wsFrame, bool &isRecvFail) const override; + + bool HttpHandShake(); + bool ProtocolUpgrade(const HttpRequest& req); + bool ResponseInvalidHandShake() const; + // Run `openCb_`, transition to `OnNewConnection` state. + void OnNewConnection(); + // Close server socket, transition to `UNINITED` state. + void CloseServerSocket(); + +private: + int32_t server_fd_ {-1}; + + // Callbacks used during different stages of connection lifecycle. + // E.g. validation callback - it is executed during handshake + // and used to indicate whether the incoming connection should be accepted. + ValidateConnectionCallback validateCb_; + OpenConnectionCallback openCb_; + + static constexpr std::string_view BAD_REQUEST_RESPONSE = "HTTP/1.1 400 Bad Request\r\n\r\n"; +}; +} // namespace OHOS::ArkCompiler::Toolchain + +#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_SERVER_WEBSOCKET_SERVER_H diff --git a/websocket/string_utils.h b/websocket/string_utils.h new file mode 100644 index 00000000..b9eb9e74 --- /dev/null +++ b/websocket/string_utils.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_STRING_UTILS_H +#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_STRING_UTILS_H + +#include +#include +#include + +namespace OHOS::ArkCompiler::Toolchain { +// all cctype function arguments must be representable as unsigned char +inline void TrimLeft(std::string &str) +{ + str.erase(str.begin(), std::find_if(str.begin(), str.end(), [](unsigned char ch) { return !std::isspace(ch); })); +} + +inline void TrimRight(std::string &str) +{ + str.erase(std::find_if(str.rbegin(), str.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base(), + str.end()); +} + +inline void Trim(std::string &str) +{ + TrimLeft(str); + TrimRight(str); +} + +inline void ToLowerCase(std::string& str) +{ + std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) { return std::tolower(c); }); +} +} // namespace OHOS::ArkCompiler::Toolchain + +#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_STRING_UTILS_H + diff --git a/websocket/test/BUILD.gn b/websocket/test/BUILD.gn index 0b858074..d5c10fc5 100644 --- a/websocket/test/BUILD.gn +++ b/websocket/test/BUILD.gn @@ -24,13 +24,17 @@ host_unittest_action("WebSocketTest") { sources = [ # test file "../../common/log_wrapper.cpp", + "frame_builder_test.cpp", + "http_decoder_test.cpp", + "web_socket_frame_test.cpp", "websocket_test.cpp", ] configs = [ "$toolchain_root:toolchain_test_config" ] deps = [ - "$toolchain_root/websocket:websocket", + "$toolchain_root/websocket:websocket_server", + "$toolchain_root/websocket:websocket_client", sdk_libc_secshared_dep, ] diff --git a/websocket/test/frame_builder_test.cpp b/websocket/test/frame_builder_test.cpp new file mode 100644 index 00000000..ff06ae85 --- /dev/null +++ b/websocket/test/frame_builder_test.cpp @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2023 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gtest/gtest.h" +#include "websocket/frame_builder.h" + +using namespace OHOS::ArkCompiler::Toolchain; + +namespace panda::test { +class FrameBuilderTest : public testing::Test { +public: + // final message, ping-frame opcode + static constexpr char PING_EXPECTED_FIRST_BYTE = 0x89; + + static constexpr uint8_t MASKING_KEY[WebSocketFrame::MASK_LEN] = {0xab}; + + static constexpr size_t SHORT_MSG_SIZE = 10; + static constexpr size_t LONG_MSG_SIZE = 1000; + static constexpr size_t LONG_LONG_MSG_SIZE = 0xfffff; + + static const std::string SHORT_MSG; + static const std::string LONG_MSG; + static const std::string LONG_LONG_MSG; +}; + +const std::string FrameBuilderTest::SHORT_MSG = std::string(SHORT_MSG_SIZE, 'f'); +const std::string FrameBuilderTest::LONG_MSG = std::string(LONG_MSG_SIZE, 'f'); +const std::string FrameBuilderTest::LONG_LONG_MSG = std::string(LONG_LONG_MSG_SIZE, 'f'); + +HWTEST_F(FrameBuilderTest, TestNoPayload, testing::ext::TestSize.Level0) +{ + ServerFrameBuilder frameBuilder(true, FrameType::PING); + auto message = frameBuilder.Build(); + + constexpr size_t EXPECTED_MESSAGE_SIZE = 2; + + ASSERT_EQ(message.size(), EXPECTED_MESSAGE_SIZE); + ASSERT_EQ(message[0], PING_EXPECTED_FIRST_BYTE); + // unmasked, zero length + ASSERT_EQ(message[1], 0); +} + +HWTEST_F(FrameBuilderTest, TestShortPayload, testing::ext::TestSize.Level0) +{ + ServerFrameBuilder frameBuilder(true, FrameType::PING); + auto message = frameBuilder + .SetPayload(SHORT_MSG) + .Build(); + + constexpr size_t HEADER_LENGTH = 2; + constexpr size_t EXPECTED_MESSAGE_SIZE = HEADER_LENGTH + SHORT_MSG_SIZE; + + ASSERT_EQ(message.size(), EXPECTED_MESSAGE_SIZE); + ASSERT_EQ(message[0], PING_EXPECTED_FIRST_BYTE); + // length fits into [0, 126) range + ASSERT_EQ(message[1], static_cast(SHORT_MSG_SIZE)); + for (size_t i = HEADER_LENGTH; i < message.size(); ++i) { + ASSERT_EQ(message[i], SHORT_MSG[i - HEADER_LENGTH]); + } +} + +HWTEST_F(FrameBuilderTest, TestLongPayload, testing::ext::TestSize.Level0) +{ + ServerFrameBuilder frameBuilder(true, FrameType::PING); + auto message = frameBuilder + .SetPayload(LONG_MSG) + .Build(); + + constexpr size_t HEADER_LENGTH = 2 + 2; + constexpr size_t EXPECTED_MESSAGE_SIZE = HEADER_LENGTH + LONG_MSG_SIZE; + + ASSERT_EQ(message.size(), EXPECTED_MESSAGE_SIZE); + ASSERT_EQ(message[0], PING_EXPECTED_FIRST_BYTE); + // length fits into [125, 65536) range - encoded with 126 + ASSERT_EQ(message[1], 126); + // everything is encoded as big-endian + ASSERT_EQ(message[2], static_cast((LONG_MSG_SIZE >> 8) & 0xff)); + ASSERT_EQ(message[3], static_cast(LONG_MSG_SIZE & 0xff)); + for (size_t i = HEADER_LENGTH; i < message.size(); ++i) { + ASSERT_EQ(message[i], LONG_MSG[i - HEADER_LENGTH]); + } +} + +HWTEST_F(FrameBuilderTest, TestLongLongPayload, testing::ext::TestSize.Level0) +{ + ServerFrameBuilder frameBuilder(true, FrameType::PING); + auto message = frameBuilder + .SetPayload(LONG_LONG_MSG) + .Build(); + + constexpr size_t HEADER_LENGTH = 2 + 8; + constexpr size_t EXPECTED_MESSAGE_SIZE = HEADER_LENGTH + LONG_LONG_MSG_SIZE; + + ASSERT_EQ(message.size(), EXPECTED_MESSAGE_SIZE); + ASSERT_EQ(message[0], PING_EXPECTED_FIRST_BYTE); + // length is bigger than 65536 - encoded with 127 + ASSERT_EQ(message[1], 127); + // everything is encoded as big-endian + for (size_t idx = 2, shiftCount = 8 * (sizeof(uint64_t) - 1); idx < HEADER_LENGTH; ++idx, shiftCount -= 8) { + ASSERT_EQ(message[idx], static_cast((LONG_LONG_MSG_SIZE >> shiftCount) & 0xff)); + } + for (size_t i = HEADER_LENGTH; i < message.size(); ++i) { + ASSERT_EQ(message[i], LONG_LONG_MSG[i - HEADER_LENGTH]); + } +} + +HWTEST_F(FrameBuilderTest, TestAppendPayload, testing::ext::TestSize.Level0) +{ + ServerFrameBuilder frameBuilder(true, FrameType::PING); + auto message = frameBuilder + .SetPayload(SHORT_MSG) + .AppendPayload(SHORT_MSG) + .Build(); + + constexpr size_t HEADER_LENGTH = 2; + constexpr size_t PAYLOAD_SIZE = SHORT_MSG_SIZE * 2; + constexpr size_t EXPECTED_MESSAGE_SIZE = HEADER_LENGTH + PAYLOAD_SIZE; + + ASSERT_EQ(message.size(), EXPECTED_MESSAGE_SIZE); + ASSERT_EQ(message[0], PING_EXPECTED_FIRST_BYTE); + // length fits into [0, 126) range + ASSERT_EQ(message[1], static_cast(PAYLOAD_SIZE)); + for (size_t i = HEADER_LENGTH; i < message.size(); ++i) { + ASSERT_EQ(message[i], SHORT_MSG[(i - HEADER_LENGTH) % SHORT_MSG_SIZE]); + } +} + +HWTEST_F(FrameBuilderTest, TestClientNoPayload, testing::ext::TestSize.Level0) +{ + ClientFrameBuilder frameBuilder(true, FrameType::PING, MASKING_KEY); + auto message = frameBuilder.Build(); + + constexpr size_t EXPECTED_MESSAGE_SIZE = 2 + WebSocketFrame::MASK_LEN; + + ASSERT_EQ(message.size(), EXPECTED_MESSAGE_SIZE); + ASSERT_EQ(message[0], PING_EXPECTED_FIRST_BYTE); + // masked, even if no payload provided + ASSERT_EQ(message[1], static_cast(0x80)); +} + +HWTEST_F(FrameBuilderTest, TestClientMasking, testing::ext::TestSize.Level0) +{ + ClientFrameBuilder frameBuilder(true, FrameType::PING, MASKING_KEY); + auto message = frameBuilder + .SetPayload(LONG_MSG) + .Build(); + + constexpr size_t HEADER_LENGTH = 2 + 2 + WebSocketFrame::MASK_LEN; + constexpr size_t EXPECTED_MESSAGE_SIZE = HEADER_LENGTH + LONG_MSG_SIZE; + + ASSERT_EQ(message.size(), EXPECTED_MESSAGE_SIZE); + ASSERT_EQ(message[0], PING_EXPECTED_FIRST_BYTE); + // masked, length fits into [125, 65536) range - encoded with 126 + ASSERT_EQ(message[1], static_cast(0x80 | 126)); + // everything is encoded as big-endian + ASSERT_EQ(message[2], static_cast((LONG_MSG_SIZE >> 8) & 0xff)); + ASSERT_EQ(message[3], static_cast(LONG_MSG_SIZE & 0xff)); + // message must be masked + for (size_t i = HEADER_LENGTH; i < message.size(); ++i) { + ASSERT_EQ(static_cast(message[i] ^ MASKING_KEY[i % WebSocketFrame::MASK_LEN]), + static_cast(LONG_MSG[i - HEADER_LENGTH])); + } +} +} // namespace panda::test diff --git a/websocket/test/http_decoder_test.cpp b/websocket/test/http_decoder_test.cpp new file mode 100644 index 00000000..58c5778c --- /dev/null +++ b/websocket/test/http_decoder_test.cpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2023 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "gtest/gtest.h" +#include "websocket/http.h" + +using namespace OHOS::ArkCompiler::Toolchain; + +namespace panda::test { +class HttpDecoderTest : public testing::Test { +public: +#define SEC_WEBSOCKET_KEY "AyuTxzyBTJJdViDskomT0Q==" + static constexpr std::string_view REQUEST_HEADERS = "GET / HTTP/1.1\r\n" + "Host: 127.0.0.1:19015\r\n" + "Connection: Upgrade\r\n" + "Pragma: no-cache\r\n" + "Cache-Control: no-cache\r\n" + "User-Agent: Mozilla/5.0 (X11; Linux x86_64) Chrome/117.0.0.0 Safari/537.36\r\n" + "Upgrade: websocket\r\n" + "Origin: dvtls://dvtls\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Accept-Encoding: gzip, deflate, br\r\n" + "Accept-Language: en-US,en;q=0.9\r\n" + "Sec-WebSocket-Key: " + SEC_WEBSOCKET_KEY + "\r\n" + "Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n\r\n"; + std::string requestHeaders = std::string(REQUEST_HEADERS); + + static constexpr std::string_view RESPONSE_HEADERS = "HTTP/1.1 101 Switching Protocols\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Accept: " + SEC_WEBSOCKET_KEY + "\r\n\r\n"; + std::string responseHeaders = std::string(RESPONSE_HEADERS); + + static constexpr std::string_view EXPECTED_VERSION = "HTTP/1.1"; + static constexpr std::string_view EXPECTED_STATUS = "101"; + static constexpr std::string_view EXPECTED_CONNECTION = "Upgrade"; + static constexpr std::string_view EXPECTED_UPGRADE = "websocket"; + static constexpr std::string_view EXPECTED_SEC_WEBSOCKET_KEY = SEC_WEBSOCKET_KEY; +#undef SEC_WEBSOCKET_KEY +}; + +HWTEST_F(HttpDecoderTest, TestRequestDecode, testing::ext::TestSize.Level0) +{ + HttpRequest parsed; + auto succeeded = HttpRequest::Decode(requestHeaders, parsed); + + ASSERT_TRUE(succeeded); + ASSERT_EQ(parsed.version, EXPECTED_VERSION); + ASSERT_EQ(parsed.connection, EXPECTED_CONNECTION); + ASSERT_EQ(parsed.upgrade, EXPECTED_UPGRADE); + ASSERT_EQ(parsed.secWebSocketKey, EXPECTED_SEC_WEBSOCKET_KEY); +} + +HWTEST_F(HttpDecoderTest, TestResponseDecode, testing::ext::TestSize.Level0) +{ + HttpResponse parsed; + auto succeeded = HttpResponse::Decode(responseHeaders, parsed); + + ASSERT_TRUE(succeeded); + ASSERT_EQ(parsed.version, EXPECTED_VERSION); + ASSERT_EQ(parsed.status, EXPECTED_STATUS); + ASSERT_EQ(parsed.connection, EXPECTED_CONNECTION); + ASSERT_EQ(parsed.upgrade, EXPECTED_UPGRADE); + ASSERT_EQ(parsed.secWebSocketAccept, EXPECTED_SEC_WEBSOCKET_KEY); +} +} // namespace panda::test diff --git a/websocket/test/web_socket_frame_test.cpp b/websocket/test/web_socket_frame_test.cpp new file mode 100644 index 00000000..dd998329 --- /dev/null +++ b/websocket/test/web_socket_frame_test.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2023 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gtest/gtest.h" +#include "websocket/web_socket_frame.h" + +using namespace OHOS::ArkCompiler::Toolchain; + +namespace panda::test { +class WebSocketFrameTest : public testing::Test { +public: + static constexpr uint8_t HEADER_RAW[WebSocketFrame::HEADER_LEN] = {0x01, 0x9a}; + static constexpr uint8_t EXPECTED_FIN = 0; + static constexpr uint8_t EXPECTED_OPCODE = 0x1; + static constexpr uint8_t EXPECTED_MASK_BIT = 1; + static constexpr uint8_t EXPECTED_PAYLOAD_LEN = 0x1a; +}; + +HWTEST_F(WebSocketFrameTest, TestDecode, testing::ext::TestSize.Level0) +{ + WebSocketFrame wsFrame(HEADER_RAW); + + ASSERT_EQ(wsFrame.fin, EXPECTED_FIN); + ASSERT_EQ(wsFrame.opcode, EXPECTED_OPCODE); + ASSERT_EQ(wsFrame.mask, EXPECTED_MASK_BIT); + ASSERT_EQ(wsFrame.payloadLen, EXPECTED_PAYLOAD_LEN); + // these fields must not be filled + ASSERT_TRUE(wsFrame.payload.empty()); + for (size_t i = 0; i < WebSocketFrame::MASK_LEN; ++i) { + ASSERT_EQ(wsFrame.maskingKey[i], 0); + } +} +} // namespace panda::test diff --git a/websocket/test/websocket_test.cpp b/websocket/test/websocket_test.cpp index 137a5e94..f086d702 100644 --- a/websocket/test/websocket_test.cpp +++ b/websocket/test/websocket_test.cpp @@ -14,11 +14,12 @@ */ #include +#include #include #include "gtest/gtest.h" -#include "websocket/websocket.h" -#include "securec.h" +#include "websocket/client/websocket_client.h" +#include "websocket/server/websocket_server.h" using namespace OHOS::ArkCompiler::Toolchain; @@ -43,283 +44,16 @@ public: { } - class ClientWebSocket : public WebSocket { - public: - ClientWebSocket() = default; - ~ClientWebSocket() = default; -#if defined(OHOS_PLATFORM) - bool ClientConnectUnixWebSocket(const std::string &sockName, uint32_t timeoutLimit = 0) - { - if (socketState_ != SocketState::UNINITED) { - std::cout << "ClientConnectUnixWebSocket::client has inited..." << std::endl; - return true; - } - - client_ = socket(AF_UNIX, SOCK_STREAM, 0); - if (client_ < SOCKET_SUCCESS) { - std::cerr << "ClientConnectUnixWebSocket::client socket failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - return false; - } - - // set send and recv timeout limit - if (!SetWebSocketTimeOut(client_, timeoutLimit)) { - std::cerr << "ClientConnectUnixWebSocket::client SetWebSocketTimeOut failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - close(client_); - client_ = -1; - return false; - } - - struct sockaddr_un serverAddr; - if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) { - std::cerr << "ClientConnectUnixWebSocket::client memset_s serverAddr failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - close(client_); - client_ = -1; - return false; - } - serverAddr.sun_family = AF_UNIX; - if (strcpy_s(serverAddr.sun_path + 1, sizeof(serverAddr.sun_path) - 1, sockName.c_str()) != EOK) { - std::cerr << "ClientConnectUnixWebSocket::client strcpy_s serverAddr.sun_path failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - close(client_); - client_ = -1; - return false; - } - serverAddr.sun_path[0] = '\0'; - - uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1; - int ret = connect(client_, reinterpret_cast(&serverAddr), static_cast(len)); - if (ret != SOCKET_SUCCESS) { - std::cerr << "ClientConnectUnixWebSocket::client connect failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - close(client_); - client_ = -1; - return false; - } - socketState_ = SocketState::INITED; - std::cout << "ClientConnectUnixWebSocket::client connect success..." << std::endl; - return true; - } -#else - bool ClientConnectTcpWebSocket(uint32_t timeoutLimit = 0) - { - if (socketState_ != SocketState::UNINITED) { - std::cout << "ClientConnectTcpWebSocket::client has inited..." << std::endl; - return true; - } - - client_ = socket(AF_INET, SOCK_STREAM, 0); - if (client_ < SOCKET_SUCCESS) { - std::cerr << "ClientConnectTcpWebSocket::client socket failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - return false; - } - - // set send and recv timeout limit - if (!SetWebSocketTimeOut(client_, timeoutLimit)) { - std::cerr << "ClientConnectTcpWebSocket::client SetWebSocketTimeOut failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - close(client_); - client_ = -1; - return false; - } - - struct sockaddr_in serverAddr; - if (memset_s(&serverAddr, sizeof(serverAddr), 0, sizeof(serverAddr)) != EOK) { - std::cerr << "ClientConnectTcpWebSocket::client memset_s serverAddr failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - close(client_); - client_ = -1; - return false; - } - serverAddr.sin_family = AF_INET; - serverAddr.sin_port = htons(9230); // 9230: sockName for tcp - if (int ret = inet_pton(AF_INET, "127.0.0.1", &serverAddr.sin_addr) < NET_SUCCESS) { - std::cerr << "ClientConnectTcpWebSocket::client inet_pton failed, ret = " - << ret << ", error = " << errno << ", desc = " << strerror(errno) << std::endl; - close(client_); - client_ = -1; - return false; - } - - int ret = connect(client_, reinterpret_cast(&serverAddr), sizeof(serverAddr)); - if (ret != SOCKET_SUCCESS) { - std::cerr << "ClientConnectTcpWebSocket::client connect failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - close(client_); - client_ = -1; - return false; - } - socketState_ = SocketState::INITED; - std::cout << "ClientConnectTcpWebSocket::client connect success..." << std::endl; - return true; - } -#endif - - bool ClientSendWSUpgradeReq() - { - if (socketState_ == SocketState::UNINITED) { - std::cerr << "ClientSendWSUpgradeReq::client has not inited..." << std::endl; - return false; - } - if (socketState_ == SocketState::CONNECTED) { - std::cout << "ClientSendWSUpgradeReq::client has connected..." << std::endl; - return true; - } - - int msgLen = strlen(CLIENT_WEBSOCKET_UPGRADE_REQ); - int32_t sendLen = send(client_, CLIENT_WEBSOCKET_UPGRADE_REQ, msgLen, 0); - if (sendLen != msgLen) { - std::cerr << "ClientSendWSUpgradeReq::client send wsupgrade req failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - socketState_ = SocketState::UNINITED; -#if defined(OHOS_PLATFORM) - shutdown(client_, SHUT_RDWR); -#endif - close(client_); - client_ = -1; - return false; - } - std::cout << "ClientSendWSUpgradeReq::client send wsupgrade req success..." << std::endl; - return true; - } - - bool ClientRecvWSUpgradeRsp() - { - if (socketState_ == SocketState::UNINITED) { - std::cerr << "ClientRecvWSUpgradeRsp::client has not inited..." << std::endl; - return false; - } - if (socketState_ == SocketState::CONNECTED) { - std::cout << "ClientRecvWSUpgradeRsp::client has connected..." << std::endl; - return true; - } - - char recvBuf[CLIENT_WEBSOCKET_UPGRADE_RSP_LEN + 1] = {0}; - int32_t bufLen = recv(client_, recvBuf, CLIENT_WEBSOCKET_UPGRADE_RSP_LEN, 0); - if (bufLen != CLIENT_WEBSOCKET_UPGRADE_RSP_LEN) { - std::cerr << "ClientRecvWSUpgradeRsp::client recv wsupgrade rsp failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - socketState_ = SocketState::UNINITED; -#if defined(OHOS_PLATFORM) - shutdown(client_, SHUT_RDWR); -#endif - close(client_); - client_ = -1; - return false; - } - socketState_ = SocketState::CONNECTED; - std::cout << "ClientRecvWSUpgradeRsp::client recv wsupgrade rsp success..." << std::endl; - return true; - } - - bool ClientSendReq(const std::string &message, FrameType frameType = FrameType::TEXT, bool isLast = true) - { - if (socketState_ != SocketState::CONNECTED) { - std::cerr << "ClientSendReq::client has not connected..." << std::endl; - return false; - } - - uint32_t msgLen = message.length(); - std::unique_ptr msgBuf = std::make_unique(msgLen + 15); // 15: the maximum expand length - char *sendBuf = msgBuf.get(); - uint32_t sendMsgLen = 0; - if (isLast) { - sendBuf[0] = 0x80; // 0x80: 0x80 means the last frame - } else { - sendBuf[0] = 0x0; // 0x0: 0x0 means not the last frame - } - sendBuf[0] |= GetFrameType(frameType); - uint32_t mask = 1; - // Depending on the length of the messages, client will use shift operation to get the res - // and store them in the buffer. - if (msgLen <= 125) { // 125: situation 1 when message's length <= 125 - sendBuf[1] = msgLen | (mask << 7); // 7: mask need shift left by 7 bits - sendMsgLen = 2; // 2: the length of header frame is 2; - } else if (msgLen < 65536) { // 65536: message's length - sendBuf[1] = 126 | (mask << 7); // 126: payloadLen according to the spec; 7: mask shift left by 7 bits - sendBuf[2] = ((msgLen >> 8) & 0xff); // 8: shift right by 8 bits => res * (256^1) - sendBuf[3] = (msgLen & 0xff); // 3: store len's data => res * (256^0) - sendMsgLen = 4; // 4: the length of header frame is 4 - } else { - sendBuf[1] = 127 | (mask << 7); // 127: payloadLen according to the spec; 7: mask shift left by 7 bits - for (int32_t i = 2; i <= 5; i++) { // 2 ~ 5: unused bits - sendBuf[i] = 0; - } - sendBuf[6] = ((msgLen & 0xff000000) >> 24); // 6: shift 24 bits => res * (256^3) - sendBuf[7] = ((msgLen & 0x00ff0000) >> 16); // 7: shift 16 bits => res * (256^2) - sendBuf[8] = ((msgLen & 0x0000ff00) >> 8); // 8: shift 8 bits => res * (256^1) - sendBuf[9] = (msgLen & 0x000000ff); // 9: res * (256^0) - sendMsgLen = 10; // 10: the length of header frame is 10 - } - - if (memcpy_s(sendBuf + sendMsgLen, SOCKET_MASK_LEN, MASK_KEY, SOCKET_MASK_LEN) != EOK) { - std::cerr << "ClientSendReq::client memcpy_s MASK_KEY failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - return false; - } - sendMsgLen += SOCKET_MASK_LEN; - - std::string maskMessage; - for (uint64_t i = 0; i < msgLen; i++) { - uint64_t j = i % SOCKET_MASK_LEN; - maskMessage.push_back(message[i] ^ MASK_KEY[j]); - } - if (memcpy_s(sendBuf + sendMsgLen, msgLen, maskMessage.c_str(), msgLen) != EOK) { - std::cerr << "ClientSendReq::client memcpy_s maskMessage failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - return false; - } - msgBuf[sendMsgLen + msgLen] = '\0'; - - if (send(client_, sendBuf, sendMsgLen + msgLen, 0) != static_cast(sendMsgLen + msgLen)) { - std::cerr << "ClientSendReq::client send msg req failed, error = " - << errno << ", desc = " << strerror(errno) << std::endl; - return false; - } - std::cout << "ClientRecvWSUpgradeRsp::client send msg req success..." << std::endl; - return true; - } - - void Close() - { - if (socketState_ == SocketState::UNINITED) { - return; - } -#if defined(OHOS_PLATFORM) - shutdown(client_, SHUT_RDWR); -#endif - close(client_); - client_ = -1; - socketState_ = SocketState::UNINITED; - } - - private: - static constexpr char CLIENT_WEBSOCKET_UPGRADE_REQ[] = "GET / HTTP/1.1\r\n" - "Connection: Upgrade\r\n" - "Pragma: no-cache\r\n" - "Cache-Control: no-cache\r\n" - "Upgrade: websocket\r\n" - "Sec-WebSocket-Version: 13\r\n" - "Accept-Encoding: gzip, deflate, br\r\n" - "Sec-WebSocket-Key: 64b4B+s5JDlgkdg7NekJ+g==\r\n" - "Sec-WebSocket-Extensions: permessage-deflate\r\n"; - static constexpr int32_t CLIENT_WEBSOCKET_UPGRADE_RSP_LEN = 129; - static constexpr char MASK_KEY[SOCKET_MASK_LEN + 1] = "abcd"; - static constexpr int NET_SUCCESS = 1; - }; - #if defined(OHOS_PLATFORM) static constexpr char UNIX_DOMAIN_PATH[] = "server.sock"; #endif - static constexpr char HELLO_SERVER[] = "hello server"; - static constexpr char HELLO_CLIENT[] = "hello client"; - static constexpr char SERVER_OK[] = "server ok"; - static constexpr char CLIENT_OK[] = "client ok"; - static constexpr char QUIT[] = "quit"; - static constexpr char PING[] = "ping"; + static constexpr char HELLO_SERVER[] = "hello server"; + static constexpr char HELLO_CLIENT[] = "hello client"; + static constexpr char SERVER_OK[] = "server ok"; + static constexpr char CLIENT_OK[] = "client ok"; + static constexpr char QUIT[] = "quit"; + static constexpr char PING[] = "ping"; + static constexpr int TCP_PORT = 9230; static const std::string LONG_MSG; static const std::string LONG_LONG_MSG; }; @@ -329,13 +63,13 @@ const std::string WebSocketTest::LONG_LONG_MSG = std::string(0xfffff, 'f'); HWTEST_F(WebSocketTest, ConnectWebSocketTest, testing::ext::TestSize.Level0) { - WebSocket serverSocket; + WebSocketServer serverSocket; bool ret = false; #if defined(OHOS_PLATFORM) int appPid = getpid(); ret = serverSocket.InitUnixWebSocket(UNIX_DOMAIN_PATH + std::to_string(appPid), 5); #else - ret = serverSocket.InitTcpWebSocket(9230, 5); + ret = serverSocket.InitTcpWebSocket(TCP_PORT, 5); #endif ASSERT_TRUE(ret); pid_t pid = fork(); @@ -343,57 +77,53 @@ HWTEST_F(WebSocketTest, ConnectWebSocketTest, testing::ext::TestSize.Level0) // subprocess, handle client connect and recv/send message // note: EXPECT/ASSERT produce errors in subprocess that can not lead to failure of testcase in mainprocess, // so testcase still success finally. - ClientWebSocket clientSocket; + WebSocketClient clientSocket; bool retClient = false; #if defined(OHOS_PLATFORM) - retClient = clientSocket.ClientConnectUnixWebSocket(UNIX_DOMAIN_PATH + std::to_string(appPid), 5); + retClient = clientSocket.InitToolchainWebSocketForSockName(UNIX_DOMAIN_PATH + std::to_string(appPid), 5); #else - retClient = clientSocket.ClientConnectTcpWebSocket(5); + retClient = clientSocket.InitToolchainWebSocketForPort(TCP_PORT, 5); #endif ASSERT_TRUE(retClient); retClient = clientSocket.ClientSendWSUpgradeReq(); ASSERT_TRUE(retClient); retClient = clientSocket.ClientRecvWSUpgradeRsp(); ASSERT_TRUE(retClient); - retClient = clientSocket.ClientSendReq(HELLO_SERVER); + retClient = clientSocket.SendReply(HELLO_SERVER); EXPECT_TRUE(retClient); std::string recv = clientSocket.Decode(); EXPECT_EQ(strcmp(recv.c_str(), HELLO_CLIENT), 0); if (strcmp(recv.c_str(), HELLO_CLIENT) == 0) { - retClient = clientSocket.ClientSendReq(CLIENT_OK); + retClient = clientSocket.SendReply(CLIENT_OK); EXPECT_TRUE(retClient); } - retClient = clientSocket.ClientSendReq(LONG_MSG); + retClient = clientSocket.SendReply(LONG_MSG); EXPECT_TRUE(retClient); recv = clientSocket.Decode(); EXPECT_EQ(strcmp(recv.c_str(), SERVER_OK), 0); if (strcmp(recv.c_str(), SERVER_OK) == 0) { - retClient = clientSocket.ClientSendReq(CLIENT_OK); + retClient = clientSocket.SendReply(CLIENT_OK); EXPECT_TRUE(retClient); } - retClient = clientSocket.ClientSendReq(LONG_LONG_MSG); + retClient = clientSocket.SendReply(LONG_LONG_MSG); EXPECT_TRUE(retClient); recv = clientSocket.Decode(); EXPECT_EQ(strcmp(recv.c_str(), SERVER_OK), 0); if (strcmp(recv.c_str(), SERVER_OK) == 0) { - retClient = clientSocket.ClientSendReq(CLIENT_OK); + retClient = clientSocket.SendReply(CLIENT_OK); EXPECT_TRUE(retClient); } - retClient = clientSocket.ClientSendReq(PING, FrameType::PING); // send a ping frame and wait for pong frame + retClient = clientSocket.SendReply(PING, FrameType::PING); // send a ping frame and wait for pong frame EXPECT_TRUE(retClient); recv = clientSocket.Decode(); // get the pong frame EXPECT_EQ(strcmp(recv.c_str(), ""), 0); // pong frame has no data - retClient = clientSocket.ClientSendReq(QUIT); + retClient = clientSocket.SendReply(QUIT); EXPECT_TRUE(retClient); clientSocket.Close(); exit(0); } else if (pid > 0) { // mainprocess, handle server connect and recv/send message -#if defined(OHOS_PLATFORM) - ret = serverSocket.ConnectUnixWebSocket(); -#else - ret = serverSocket.ConnectTcpWebSocket(); -#endif + ret = serverSocket.AcceptNewConnection(); ASSERT_TRUE(ret); std::string recv = serverSocket.Decode(); bool isSendFail = false; @@ -426,13 +156,13 @@ HWTEST_F(WebSocketTest, ConnectWebSocketTest, testing::ext::TestSize.Level0) HWTEST_F(WebSocketTest, ReConnectWebSocketTest, testing::ext::TestSize.Level0) { - WebSocket serverSocket; + WebSocketServer serverSocket; bool ret = false; #if defined(OHOS_PLATFORM) int appPid = getpid(); ret = serverSocket.InitUnixWebSocket(UNIX_DOMAIN_PATH + std::to_string(appPid), 5); #else - ret = serverSocket.InitTcpWebSocket(9230, 5); + ret = serverSocket.InitTcpWebSocket(TCP_PORT, 5); #endif ASSERT_TRUE(ret); for (int i = 0; i < 5; i++) { @@ -441,35 +171,31 @@ HWTEST_F(WebSocketTest, ReConnectWebSocketTest, testing::ext::TestSize.Level0) // subprocess, handle client connect and recv/send message // note: EXPECT/ASSERT produce errors in subprocess that can not lead to failure of testcase in mainprocess, // so testcase still success finally. - ClientWebSocket clientSocket; + WebSocketClient clientSocket; bool retClient = false; #if defined(OHOS_PLATFORM) - retClient = clientSocket.ClientConnectUnixWebSocket(UNIX_DOMAIN_PATH + std::to_string(appPid), 5); + retClient = clientSocket.InitToolchainWebSocketForSockName(UNIX_DOMAIN_PATH + std::to_string(appPid), 5); #else - retClient = clientSocket.ClientConnectTcpWebSocket(5); + retClient = clientSocket.InitToolchainWebSocketForPort(TCP_PORT, 5); #endif ASSERT_TRUE(retClient); retClient = clientSocket.ClientSendWSUpgradeReq(); ASSERT_TRUE(retClient); retClient = clientSocket.ClientRecvWSUpgradeRsp(); ASSERT_TRUE(retClient); - retClient = clientSocket.ClientSendReq(HELLO_SERVER + std::to_string(i)); + retClient = clientSocket.SendReply(HELLO_SERVER + std::to_string(i)); EXPECT_TRUE(retClient); std::string recv = clientSocket.Decode(); EXPECT_EQ(strcmp(recv.c_str(), (HELLO_CLIENT + std::to_string(i)).c_str()), 0); if (strcmp(recv.c_str(), (HELLO_CLIENT + std::to_string(i)).c_str()) == 0) { - retClient = clientSocket.ClientSendReq(CLIENT_OK + std::to_string(i)); + retClient = clientSocket.SendReply(CLIENT_OK + std::to_string(i)); EXPECT_TRUE(retClient); } clientSocket.Close(); exit(0); } else if (pid > 0) { // mainprocess, handle server connect and recv/send message -#if defined(OHOS_PLATFORM) - ret = serverSocket.ConnectUnixWebSocket(); -#else - ret = serverSocket.ConnectTcpWebSocket(); -#endif + ret = serverSocket.AcceptNewConnection(); ASSERT_TRUE(ret); std::string recv = serverSocket.Decode(); bool isSendFail = false; diff --git a/websocket/web_socket_frame.h b/websocket/web_socket_frame.h new file mode 100644 index 00000000..97a5f96a --- /dev/null +++ b/websocket/web_socket_frame.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WS_FRAME_H +#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WS_FRAME_H + +#include + +namespace OHOS::ArkCompiler::Toolchain { +enum class FrameType : uint8_t { + CONTINUATION = 0x0, + TEXT = 0x1, + BINARY = 0x2, + CLOSE = 0x8, + PING = 0x9, + PONG = 0xa, +}; + +constexpr inline bool IsControlFrame(uint8_t opcode) +{ + return opcode >= static_cast(FrameType::CLOSE); +} + +template>> +constexpr inline auto EnumToNumber(T type) +{ + using UnderlyingT = std::underlying_type_t; + return static_cast(type); +} + +struct WebSocketFrame { + static constexpr size_t MASK_LEN = 4; + static constexpr size_t HEADER_LEN = 2; + static constexpr size_t ONE_BYTE_LENTH_ENC_LIMIT = 125; + static constexpr size_t TWO_BYTES_LENTH_ENC = 126; + static constexpr size_t TWO_BYTES_LENTH = 2; + static constexpr size_t EIGHT_BYTES_LENTH_ENC = 127; + static constexpr size_t EIGHT_BYTES_LENTH = 8; + + uint64_t payloadLen = 0; + uint8_t fin = 0; + uint8_t opcode = 0; + uint8_t mask = 0; + uint8_t maskingKey[MASK_LEN] = {0}; + std::string payload; + + WebSocketFrame() = default; + explicit WebSocketFrame(const uint8_t headerRaw[HEADER_LEN]) + : payloadLen(static_cast(headerRaw[1]) & 0x7f), + fin(static_cast((headerRaw[0] >> MSB_SHIFT_COUNT) & 0x1)), + opcode(static_cast(headerRaw[0] & 0xf)), + mask(static_cast((headerRaw[1] >> MSB_SHIFT_COUNT) & 0x1)) + { + } + +private: + static constexpr int MSB_SHIFT_COUNT = 7; +}; +} // namespace OHOS::ArkCompiler::Toolchain + +#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WS_FRAME_H diff --git a/websocket/websocket.cpp b/websocket/websocket.cpp deleted file mode 100644 index 40ecad7d..00000000 --- a/websocket/websocket.cpp +++ /dev/null @@ -1,612 +0,0 @@ -/* - * Copyright (c) 2022 Huawei Device Co., Ltd. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "websocket/websocket.h" - -#include "define.h" -#include "common/log_wrapper.h" -#include "securec.h" - -namespace OHOS::ArkCompiler::Toolchain { -/** - * SendMessage in WebSocket has 3 situations: - * 1. message's length <= 125 - * 2. message's length >= 126 && messages's length < 65536 - * 3. message's length >= 65536 - */ - -// if the data is too large, it will be split into multiple frames, the first frame will be marked as 0x0 -// and the last frame will be marked as 0x1. -// we just add the 'isLast' parameter to indicate whether it is the last frame. -bool WebSocket::SendReply(const std::string& message, bool& isSendFail, FrameType frameType, bool isLast) const -{ - if (socketState_ != SocketState::CONNECTED) { - LOGE("SendReply failed, websocket not connected"); - return false; - } - const size_t msgLen = message.length(); - const uint32_t headerSize = 11; // 11: the maximum expandable length - std::unique_ptr msgBuf = std::make_unique(msgLen + headerSize); - char* sendBuf = msgBuf.get(); - if (!isLast) { - sendBuf[0] = 0x0; // 0x0: 0x0 means a continuation frame - } else { - sendBuf[0] = 0x80; // 0x80: 0x80 means a text frame - } - uint32_t sendMsgLen; - sendBuf[0] |= GetFrameType(frameType); - // Depending on the length of the messages, server will use shift operation to get the res - // and store them in the buffer. - if (msgLen <= 125) { // 125: situation 1 when message's length <= 125 - sendBuf[1] = msgLen; - sendMsgLen = 2; // 2: the length of header frame is 2 - } else if (msgLen < 65536) { // 65536: message's length - sendBuf[1] = 126; // 126: payloadLen according to the spec - sendBuf[2] = ((msgLen >> 8) & 0xff); // 8: shift right by 8 bits => res * (256^1) - sendBuf[3] = (msgLen & 0xff); // 3: store len's data => res * (256^0) - sendMsgLen = 4; // 4: the length of header frame is 4 - } else { - sendBuf[1] = 127; // 127: payloadLen according to the spec - for (int32_t i = 2; i <= 5; i++) { // 2 ~ 5: unused bits - sendBuf[i] = 0; - } - sendBuf[6] = ((msgLen & 0xff000000) >> 24); // 6: shift 24 bits => res * (256^3) - sendBuf[7] = ((msgLen & 0x00ff0000) >> 16); // 7: shift 16 bits => res * (256^2) - sendBuf[8] = ((msgLen & 0x0000ff00) >> 8); // 8: shift 8 bits => res * (256^1) - sendBuf[9] = (msgLen & 0x000000ff); // 9: res * (256^0) - sendMsgLen = 10; // 10: the length of header frame is 10 - } - if (memcpy_s(sendBuf + sendMsgLen, msgLen, message.c_str(), msgLen) != EOK) { - LOGE("SendReply: memcpy_s failed"); - return false; - } - sendBuf[sendMsgLen + msgLen] = '\0'; - if (!Send(client_, sendBuf, sendMsgLen + msgLen, 0)) { - isSendFail = errno == EINTR ? false : true; - LOGE("SendReply: send failed"); - return false; - } - return true; -} - -char WebSocket::GetFrameType(FrameType frameType) const -{ - switch (frameType) { - case FrameType::CONTINUATION: - return 0x0; // 0x0: 0x0 means a continuation frame - case FrameType::TEXT: - return 0x1; // 0x1: 0x1 means a text frame - case FrameType::BINARY: - return 0x2; // 0x2: 0x2 means a binary frame - case FrameType::CLOSE: - return 0x8; // 0x8: 0x8 means a close frame - case FrameType::PING: - return 0x9; // 0x9: 0x9 means a ping frame - case FrameType::PONG: - return 0xa; // 0xa: 0xa means a pong frame - default: - LOGF("GetFrameType failed, invalid frame type"); - return 0x0; - } -} - -bool WebSocket::HttpProtocolDecode(const std::string& request, HttpProtocol& req) -{ - if (request.find("GET") == std::string::npos) { - LOGE("Handshake failed: lack of necessary info"); - return false; - } - std::vector reqStr = ProtocolSplit(request, EOL); - for (size_t i = 0; i < reqStr.size(); i++) { - if (i == 0) { - std::vector headers = ProtocolSplit(reqStr.at(i), " "); - req.version = headers.at(2); // 2: to get the version param - } else if (i < reqStr.size() - 1) { - std::vector headers = ProtocolSplit(reqStr.at(i), ": "); - if (reqStr.at(i).find("Connection") != std::string::npos) { - req.connection = headers.at(1); // 1: to get the connection param - } else if (reqStr.at(i).find("Upgrade") != std::string::npos) { - req.upgrade = headers.at(1); // 1: to get the upgrade param - } else if (reqStr.at(i).find("Sec-WebSocket-Key") != std::string::npos) { - req.secWebSocketKey = headers.at(1); // 1: to get the secWebSocketKey param - } - } - } - return true; -} - -/** - * The wired format of this data transmission section is described in detail through ABNFRFC5234. - * When receive the message, we should decode it according the spec. The structure is as follows: - * 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 - * +-+-+-+-+-------+-+-------------+-------------------------------+ - * |F|R|R|R| opcode|M| Payload len | Extended payload length | - * |I|S|S|S| (4) |A| (7) | (16/64) | - * |N|V|V|V| |S| | (if payload len==126/127) | - * | |1|2|3| |K| | | - * +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + - * | Extended payload length continued, if payload len == 127 | - * + - - - - - - - - - - - - - - - +-------------------------------+ - * | |Masking-key, if MASK set to 1 | - * +-------------------------------+-------------------------------+ - * | Masking-key (continued) | Payload Data | - * +-------------------------------- - - - - - - - - - - - - - - - + - * : Payload Data continued ... : - * + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - * | Payload Data continued ... | - * +---------------------------------------------------------------+ - */ - -bool WebSocket::HandleFrame(WebSocketFrame& wsFrame, bool &isRecvFail) -{ - if (wsFrame.payloadLen == 126) { // 126: the payloadLen read from frame - char recvbuf[PAYLOAD_LEN + 1] = {0}; - if (!Recv(client_, recvbuf, PAYLOAD_LEN, 0)) { - LOGE("HandleFrame: Recv payloadLen == 126 failed"); - isRecvFail = errno == EINTR ? false : true; - return false; - } - - uint16_t msgLen = 0; - if (memcpy_s(&msgLen, sizeof(recvbuf), recvbuf, sizeof(recvbuf) - 1) != EOK) { - LOGE("HandleFrame: memcpy_s failed"); - return false; - } - wsFrame.payloadLen = ntohs(msgLen); - } else if (wsFrame.payloadLen > 126) { // 126: the payloadLen read from frame - char recvbuf[EXTEND_PAYLOAD_LEN + 1] = {0}; - if (!Recv(client_, recvbuf, EXTEND_PAYLOAD_LEN, 0)) { - LOGE("HandleFrame: Recv payloadLen > 127 failed"); - isRecvFail = errno == EINTR ? false : true; - return false; - } - wsFrame.payloadLen = NetToHostLongLong(recvbuf, EXTEND_PAYLOAD_LEN); - } - return DecodeMessage(wsFrame, isRecvFail); -} - -bool WebSocket::DecodeMessage(WebSocketFrame& wsFrame, bool &isRecvFail) -{ - if (wsFrame.payloadLen > UINT64_MAX) { - LOGE("ReadMsg length error, the length should be less than UINT64_MAX"); - return false; - } - uint64_t msgLen = wsFrame.payloadLen; - wsFrame.payload = std::make_unique(msgLen + 1); - if (wsFrame.mask == 1) { - char buf[msgLen + 1]; - if (!Recv(client_, wsFrame.maskingkey, SOCKET_MASK_LEN, 0)) { - LOGE("DecodeMessage: Recv maskingkey failed"); - isRecvFail = errno == EINTR ? false : true; - return false; - } - - if (!Recv(client_, buf, msgLen, 0)) { - LOGE("DecodeMessage: Recv message with mask failed"); - isRecvFail = errno == EINTR ? false : true; - return false; - } - - for (uint64_t i = 0; i < msgLen; i++) { - uint64_t j = i % SOCKET_MASK_LEN; - wsFrame.payload.get()[i] = buf[i] ^ wsFrame.maskingkey[j]; - } - } else { - char buf[msgLen + 1]; - if (!Recv(client_, buf, msgLen, 0)) { - LOGE("DecodeMessage: Recv message without mask failed"); - isRecvFail = errno == EINTR ? false : true; - return false; - } - - if (memcpy_s(wsFrame.payload.get(), msgLen, buf, msgLen) != EOK) { - LOGE("DecodeMessage: memcpy_s failed"); - return false; - } - } - wsFrame.payload.get()[msgLen] = '\0'; - return true; -} - -bool WebSocket::ProtocolUpgrade(const HttpProtocol& req) -{ - std::string rawKey = req.secWebSocketKey + WEB_SOCKET_GUID; - unsigned const char* webSocketKey = reinterpret_cast(std::move(rawKey).c_str()); - unsigned char hash[SHA_DIGEST_LENGTH + 1]; - SHA1(webSocketKey, strlen(reinterpret_cast(webSocketKey)), hash); - hash[SHA_DIGEST_LENGTH] = '\0'; - unsigned char encodedKey[ENCODED_KEY_LEN]; - EVP_EncodeBlock(encodedKey, reinterpret_cast(hash), SHA_DIGEST_LENGTH); - std::string response; - - std::ostringstream sstream; - sstream << "HTTP/1.1 101 Switching Protocols" << EOL; - sstream << "Connection: upgrade" << EOL; - sstream << "Upgrade: websocket" << EOL; - sstream << "Sec-WebSocket-Accept: " << encodedKey << EOL; - sstream << EOL; - response = sstream.str(); - if (!Send(client_, response.c_str(), response.length(), 0)) { - LOGE("ProtocolUpgrade: Send failed"); - return false; - } - return true; -} - -std::string WebSocket::ResolveHeader(int32_t index, WebSocketFrame& wsFrame, const char* recvbuf) -{ - index++; - wsFrame.mask = static_cast((recvbuf[index] >> 7) & 0x1); // 7: to get the mask - wsFrame.payloadLen = recvbuf[index] & 0x7f; - bool isRecvFail = false; - if (HandleFrame(wsFrame, isRecvFail)) { - return wsFrame.payload.get(); - } - return isRecvFail ? std::string(DECODE_DISCONNECT_MSG) : ""; -} - -std::string WebSocket::Decode() -{ - if (socketState_ != SocketState::CONNECTED) { - LOGE("Decode failed, websocket not connected!"); - return ""; - } - char recvbuf[SOCKET_HEADER_LEN + 1]; - if (!Recv(client_, recvbuf, SOCKET_HEADER_LEN, 0)) { - LOGE("Decode failed, client websocket disconnect"); - socketState_ = SocketState::INITED; -#if defined(OHOS_PLATFORM) - shutdown(client_, SHUT_RDWR); - close(client_); - client_ = -1; -#else - close(client_); - client_ = -1; -#endif - return std::string(DECODE_DISCONNECT_MSG); - } - WebSocketFrame wsFrame; - int32_t index = 0; - wsFrame.fin = static_cast(recvbuf[index] >> 7); // 7: shift right by 7 bits to get the fin - wsFrame.opcode = static_cast(recvbuf[index] & 0xf); - if (wsFrame.opcode == 0x1) { // 0x1: 0x1 means a text frame - return ResolveHeader(index, wsFrame, recvbuf); - } else if (wsFrame.opcode == 0x9) { // 0x9: 0x9 means a ping frame - // send pong frame - char pongFrame[SOCKET_HEADER_LEN] = {0}; - pongFrame[0] = 0x8a; // 0x8a: 0x8a means a pong frame - pongFrame[1] = 0x0; - if (!Send(client_, pongFrame, SOCKET_HEADER_LEN, 0)) { - LOGE("Decode: Send pong frame failed"); - } - return ResolveHeader(index, wsFrame, recvbuf); - } else if (wsFrame.opcode == 0xa) { // 0xa: 0xa means a pong frame - // pong frame does not contain any data - LOGI("Decode: pong frame"); - return ""; - } - return ""; -} - -bool WebSocket::HttpHandShake() -{ - char msgBuf[SOCKET_HANDSHAKE_LEN] = {0}; - ssize_t msgLen = recv(client_, msgBuf, SOCKET_HANDSHAKE_LEN, 0); - if (msgLen <= 0) { - LOGE("ReadMsg failed, msgLen = %{public}ld, errno = %{public}d", static_cast(msgLen), errno); - return false; - } else { - msgBuf[msgLen - 1] = '\0'; - HttpProtocol req; - if (!HttpProtocolDecode(msgBuf, req)) { - LOGE("HttpHandShake: Upgrade failed"); - return false; - } else if (req.connection.find("Upgrade") != std::string::npos && - req.upgrade.find("websocket") != std::string::npos && req.version.compare("HTTP/1.1") == 0) { - return ProtocolUpgrade(req); - } - } - return true; -} - -#if !defined(OHOS_PLATFORM) -bool WebSocket::InitTcpWebSocket(int port, uint32_t timeoutLimit) -{ - if (port < 0) { - LOGE("InitTcpWebSocket invalid port"); - return false; - } - - if (socketState_ != SocketState::UNINITED) { - LOGI("InitTcpWebSocket websocket has inited"); - return true; - } -#if defined(WINDOWS_PLATFORM) - WORD sockVersion = MAKEWORD(2, 2); // 2: version 2.2 - WSADATA wsaData; - if (WSAStartup(sockVersion, &wsaData) != 0) { - LOGE("InitTcpWebSocket WSA init failed"); - return false; - } -#endif - fd_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - if (fd_ < SOCKET_SUCCESS) { - LOGE("InitTcpWebSocket socket init failed, errno = %{public}d", errno); - return false; - } - // allow specified port can be used at once and not wait TIME_WAIT status ending - int sockOptVal = 1; - if ((setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, - reinterpret_cast(&sockOptVal), sizeof(sockOptVal))) != SOCKET_SUCCESS) { - LOGE("InitTcpWebSocket setsockopt SO_REUSEADDR failed, errno = %{public}d", errno); - close(fd_); - fd_ = -1; - return false; - } - - // set send and recv timeout - if (!SetWebSocketTimeOut(fd_, timeoutLimit)) { - LOGE("InitTcpWebSocket SetWebSocketTimeOut failed"); - close(fd_); - fd_ = -1; - return false; - } - - sockaddr_in addr_sin = {}; - addr_sin.sin_family = AF_INET; - addr_sin.sin_port = htons(port); - addr_sin.sin_addr.s_addr = INADDR_ANY; - if (bind(fd_, reinterpret_cast(&addr_sin), sizeof(addr_sin)) < SOCKET_SUCCESS) { - LOGE("InitTcpWebSocket bind failed, errno = %{public}d", errno); - close(fd_); - fd_ = -1; - return false; - } - if (listen(fd_, 1) < SOCKET_SUCCESS) { - LOGE("InitTcpWebSocket listen failed, errno = %{public}d", errno); - close(fd_); - fd_ = -1; - return false; - } - socketState_ = SocketState::INITED; - return true; -} - -bool WebSocket::ConnectTcpWebSocket() -{ - if (socketState_ == SocketState::UNINITED) { - LOGE("ConnectTcpWebSocket failed, websocket not inited"); - return false; - } - if (socketState_ == SocketState::CONNECTED) { - LOGI("ConnectTcpWebSocket websocket has connected"); - return true; - } - - if ((client_ = accept(fd_, nullptr, nullptr)) < SOCKET_SUCCESS) { - LOGI("ConnectTcpWebSocket accept has exited"); - socketState_ = SocketState::UNINITED; - close(fd_); - fd_ = -1; - return false; - } - - if (!HttpHandShake()) { - LOGE("ConnectTcpWebSocket HttpHandShake failed"); - socketState_ = SocketState::UNINITED; - close(client_); - client_ = -1; - close(fd_); - fd_ = -1; - return false; - } - socketState_ = SocketState::CONNECTED; - return true; -} -#else -bool WebSocket::InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit) -{ - if (socketState_ != SocketState::UNINITED) { - LOGI("InitUnixWebSocket websocket has inited"); - return true; - } - fd_ = socket(AF_UNIX, SOCK_STREAM, 0); // 0: defautlt protocol - if (fd_ < SOCKET_SUCCESS) { - LOGE("InitUnixWebSocket socket init failed, errno = %{public}d", errno); - return false; - } - // set send and recv timeout - if (!SetWebSocketTimeOut(fd_, timeoutLimit)) { - LOGE("InitUnixWebSocket SetWebSocketTimeOut failed"); - close(fd_); - fd_ = -1; - return false; - } - - struct sockaddr_un un; - if (memset_s(&un, sizeof(un), 0, sizeof(un)) != EOK) { - LOGE("InitUnixWebSocket memset_s failed"); - close(fd_); - fd_ = -1; - return false; - } - un.sun_family = AF_UNIX; - if (strcpy_s(un.sun_path + 1, sizeof(un.sun_path) - 1, sockName.c_str()) != EOK) { - LOGE("InitUnixWebSocket strcpy_s failed"); - close(fd_); - fd_ = -1; - return false; - } - un.sun_path[0] = '\0'; - uint32_t len = offsetof(struct sockaddr_un, sun_path) + strlen(sockName.c_str()) + 1; - if (bind(fd_, reinterpret_cast(&un), static_cast(len)) < SOCKET_SUCCESS) { - LOGE("InitUnixWebSocket bind failed, errno = %{public}d", errno); - close(fd_); - fd_ = -1; - return false; - } - if (listen(fd_, 1) < SOCKET_SUCCESS) { // 1: connection num - LOGE("InitUnixWebSocket listen failed, errno = %{public}d", errno); - close(fd_); - fd_ = -1; - return false; - } - socketState_ = SocketState::INITED; - return true; -} - -bool WebSocket::ConnectUnixWebSocket() -{ - if (socketState_ == SocketState::UNINITED) { - LOGE("ConnectUnixWebSocket failed, websocket not inited"); - return false; - } - if (socketState_ == SocketState::CONNECTED) { - LOGI("ConnectUnixWebSocket websocket has connected"); - return true; - } - - if ((client_ = accept(fd_, nullptr, nullptr)) < SOCKET_SUCCESS) { - LOGI("ConnectUnixWebSocket accept has exited"); - socketState_ = SocketState::UNINITED; - close(fd_); - fd_ = -1; - return false; - } - if (!HttpHandShake()) { - LOGE("ConnectUnixWebSocket HttpHandShake failed"); - socketState_ = SocketState::UNINITED; - shutdown(client_, SHUT_RDWR); - close(client_); - client_ = -1; - shutdown(fd_, SHUT_RDWR); - close(fd_); - fd_ = -1; - return false; - } - socketState_ = SocketState::CONNECTED; - return true; -} -#endif - -bool WebSocket::IsConnected() -{ - return socketState_ == SocketState::CONNECTED; -} - -void WebSocket::Close() -{ - if (socketState_ == SocketState::UNINITED) { - return; - } - if (socketState_ == SocketState::CONNECTED) { -#if defined(OHOS_PLATFORM) - shutdown(client_, SHUT_RDWR); -#endif - close(client_); - client_ = -1; - } - socketState_ = SocketState::UNINITED; - usleep(10000); // 10000: time for websocket to enter the accept -#if defined(OHOS_PLATFORM) - shutdown(fd_, SHUT_RDWR); -#endif - close(fd_); - fd_ = -1; -} - -uint64_t WebSocket::NetToHostLongLong(char* buf, uint32_t len) -{ - uint64_t result = 0; - for (uint32_t i = 0; i < len; i++) { - result |= static_cast(buf[i]); - if ((i + 1) < len) { - result <<= 8; // 8: result need shift left 8 bits in order to big endian convert to int - } - } - return result; -} - -bool WebSocket::Recv(int32_t client, char* buf, size_t totalLen, int32_t flags) const -{ - size_t recvLen = 0; - while (recvLen < totalLen) { - ssize_t len = recv(client, buf + recvLen, totalLen - recvLen, flags); - if (len <= 0) { - LOGE("Recv payload in while failed, websocket disconnect, len = %{public}ld, errno = %{public}d", - static_cast(len), errno); - return false; - } - recvLen += static_cast(len); - } - buf[totalLen] = '\0'; - return true; -} - -bool WebSocket::Send(int32_t client, const char* buf, size_t totalLen, int32_t flags) const -{ - size_t sendLen = 0; - while (sendLen < totalLen) { - ssize_t len = send(client, buf + sendLen, totalLen - sendLen, flags); - if (len <= 0) { - LOGE("Send Message in while failed, websocket disconnect, len = %{public}ld, errno = %{public}d", - static_cast(len), errno); - return false; - } - sendLen += static_cast(len); - } - return true; -} - -#if !defined(OHOS_PLATFORM) -bool WebSocket::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit) -{ - if (timeoutLimit > 0) { - struct timeval timeout = {timeoutLimit, 0}; - if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, - reinterpret_cast(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) { - LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno); - return false; - } - if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, - reinterpret_cast(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) { - LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno); - return false; - } - } - return true; -} -#else -bool WebSocket::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit) -{ - if (timeoutLimit > 0) { - struct timeval timeout = {timeoutLimit, 0}; - if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) { - LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno); - return false; - } - if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) { - LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno); - return false; - } - } - return true; -} -#endif -bool WebSocket::IsDecodeDisconnectMsg(const std::string& message) -{ - return message == std::string(DECODE_DISCONNECT_MSG); -} -} // namespace OHOS::ArkCompiler::Toolchain \ No newline at end of file diff --git a/websocket/websocket.h b/websocket/websocket.h deleted file mode 100644 index 30e7b918..00000000 --- a/websocket/websocket.h +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright (c) 2022 Huawei Device Co., Ltd. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_H -#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_H - -#include -#include -#include - -namespace panda::test { -class WebSocketTest; -} -namespace OHOS::ArkCompiler::Toolchain { -enum class FrameType : uint8_t { - CONTINUATION = 0x0, - TEXT = 0x1, - BINARY = 0x2, - CLOSE = 0x8, - PING = 0x9, - PONG = 0xa, -}; - -struct WebSocketFrame { - uint8_t fin = 0; - uint8_t opcode = 0; - uint8_t mask = 0; - uint64_t payloadLen = 0; - char maskingkey[5] = {0}; - std::unique_ptr payload = nullptr; -}; - -struct HttpProtocol { - std::string connection; - std::string upgrade; - std::string version; - std::string secWebSocketKey; -}; - -class WebSocket { -public: - enum SocketState : uint8_t { - UNINITED, - INITED, - CONNECTED, - }; - WebSocket() = default; - ~WebSocket() = default; - std::string Decode(); - void Close(); - bool SendReply(const std::string& message, bool& isSendFail, - FrameType frameType = FrameType::TEXT, bool isLast = true) const; -#if !defined(OHOS_PLATFORM) - bool InitTcpWebSocket(int port, uint32_t timeoutLimit = 0); - bool ConnectTcpWebSocket(); -#else - bool InitUnixWebSocket(const std::string& sockName, uint32_t timeoutLimit = 0); - bool ConnectUnixWebSocket(); -#endif - bool IsConnected(); - bool IsDecodeDisconnectMsg(const std::string& message); - -private: - friend class panda::test::WebSocketTest; - - bool DecodeMessage(WebSocketFrame& wsFrame, bool &isRecvFail); - bool HttpHandShake(); - bool HttpProtocolDecode(const std::string& request, HttpProtocol& req); - bool HandleFrame(WebSocketFrame& wsFrame, bool &isRecvFail); - bool ProtocolUpgrade(const HttpProtocol& req); - uint64_t NetToHostLongLong(char* buf, uint32_t len); - bool SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit); - bool Recv(int32_t client, char* buf, size_t totalLen, int32_t flags) const; - bool Send(int32_t client, const char* buf, size_t totalLen, int32_t flags) const; - std::string ResolveHeader(int32_t index, WebSocketFrame& wsFrame, const char* recvbuf); - char GetFrameType(FrameType frameType) const; - - int32_t client_ {-1}; - int32_t fd_ {-1}; - std::atomic socketState_ {SocketState::UNINITED}; - static constexpr int32_t ENCODED_KEY_LEN = 128; - static constexpr char EOL[] = "\r\n"; - static constexpr int32_t SOCKET_HANDSHAKE_LEN = 1024; - static constexpr int32_t SOCKET_HEADER_LEN = 2; - static constexpr int32_t SOCKET_MASK_LEN = 4; - static constexpr int32_t SOCKET_SUCCESS = 0; - static constexpr int32_t PAYLOAD_LEN = 2; - static constexpr int32_t EXTEND_PAYLOAD_LEN = 8; - static constexpr char WEB_SOCKET_GUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - static constexpr char DECODE_DISCONNECT_MSG[] = "disconnect"; -}; -} // namespace OHOS::ArkCompiler::Toolchain - -#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_H \ No newline at end of file diff --git a/websocket/websocket_base.cpp b/websocket/websocket_base.cpp new file mode 100644 index 00000000..cff0593b --- /dev/null +++ b/websocket/websocket_base.cpp @@ -0,0 +1,258 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/log_wrapper.h" +#include "websocket/define.h" +#include "websocket/network.h" +#include "websocket/websocket_base.h" + +namespace OHOS::ArkCompiler::Toolchain { +// if the data is too large, it will be split into multiple frames, the first frame will be marked as 0x0 +// and the last frame will be marked as 0x1. +// we just add the 'isLast' parameter to indicate whether it is the last frame. +bool WebSocketBase::SendReply(const std::string& message, bool& isSendFail, FrameType frameType, bool isLast) const +{ + if (socketState_ != SocketState::CONNECTED) { + LOGE("SendReply failed, websocket not connected"); + return false; + } + + auto frame = CreateFrame(isLast, frameType, message); + if (!Send(connectionFd_, frame, 0)) { + LOGE("SendReply: send failed"); + SetSocketFail(isSendFail); + return false; + } + return true; +} + +bool WebSocketBase::SendReply(const std::string& message, FrameType frameType, bool isLast) const +{ + bool ignored = false; + return SendReply(message, ignored, frameType, isLast); +} + +/** + * The wired format of this data transmission section is described in detail through ABNFRFC5234. + * When receive the message, we should decode it according the spec. The structure is as follows: + * 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 + * +-+-+-+-+-------+-+-------------+-------------------------------+ + * |F|R|R|R| opcode|M| Payload len | Extended payload length | + * |I|S|S|S| (4) |A| (7) | (16/64) | + * |N|V|V|V| |S| | (if payload len==126/127) | + * | |1|2|3| |K| | | + * +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + * | Extended payload length continued, if payload len == 127 | + * + - - - - - - - - - - - - - - - +-------------------------------+ + * | |Masking-key, if MASK set to 1 | + * +-------------------------------+-------------------------------+ + * | Masking-key (continued) | Payload Data | + * +-------------------------------- - - - - - - - - - - - - - - - + + * : Payload Data continued ... : + * + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + * | Payload Data continued ... | + * +---------------------------------------------------------------+ + */ + +bool WebSocketBase::ReadPayload(WebSocketFrame& wsFrame, bool &isRecvFail) +{ + if (wsFrame.payloadLen == WebSocketFrame::TWO_BYTES_LENTH_ENC) { + uint8_t recvbuf[WebSocketFrame::TWO_BYTES_LENTH] = {0}; + if (!Recv(connectionFd_, recvbuf, WebSocketFrame::TWO_BYTES_LENTH, 0)) { + LOGE("ReadPayload: Recv payloadLen == 126 failed"); + SetSocketFail(isRecvFail); + return false; + } + wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::TWO_BYTES_LENTH); + } else if (wsFrame.payloadLen == WebSocketFrame::EIGHT_BYTES_LENTH_ENC) { + uint8_t recvbuf[WebSocketFrame::EIGHT_BYTES_LENTH] = {0}; + if (!Recv(connectionFd_, recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH, 0)) { + LOGE("ReadPayload: Recv `payloadLen == 127` failed"); + SetSocketFail(isRecvFail); + return false; + } + wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH); + } + return DecodeMessage(wsFrame, isRecvFail); +} + +bool WebSocketBase::HandleDataFrame(WebSocketFrame& wsFrame, bool &isRecvFail) +{ + if (wsFrame.opcode == EnumToNumber(FrameType::TEXT)) { + return ReadPayload(wsFrame, isRecvFail); + } else { + LOGW("Received unsupported data frame, opcode = %{public}d", wsFrame.opcode); + } + return false; +} + +bool WebSocketBase::HandleControlFrame(WebSocketFrame& wsFrame, bool &isRecvFail) +{ + if (wsFrame.opcode == EnumToNumber(FrameType::PING)) { + // A Pong frame sent in response to a Ping frame must have identical + // "Application data" as found in the message body of the Ping frame + // being replied to. + // https://www.rfc-editor.org/rfc/rfc6455#section-5.5.3 + if (!ReadPayload(wsFrame, isRecvFail)) { + LOGE("Failed to read ping frame payload"); + return false; + } + SendPongFrame(wsFrame.payload); + } else if (wsFrame.opcode == EnumToNumber(FrameType::CLOSE)) { + // TODO: might read payload to response by echoing the status code + CloseConnection(CloseStatusCode::NO_STATUS_CODE, SocketState::INITED); + } + return true; +} + +std::string WebSocketBase::Decode() +{ + if (socketState_ != SocketState::CONNECTED) { + LOGE("Decode failed, websocket not connected!"); + return ""; + } + + uint8_t recvbuf[WebSocketFrame::HEADER_LEN] = {0}; + if (!Recv(connectionFd_, recvbuf, WebSocketFrame::HEADER_LEN, 0)) { + LOGE("Decode failed, client websocket disconnect"); + CloseConnection(CloseStatusCode::UNEXPECTED_ERROR, SocketState::INITED); + return std::string(DECODE_DISCONNECT_MSG); + } + WebSocketFrame wsFrame(recvbuf); + if (!ValidateIncomingFrame(wsFrame)) { + LOGE("Received websocket frame is invalid - header is %02x%02x", recvbuf[0], recvbuf[1]); + CloseConnection(CloseStatusCode::PROTOCOL_ERROR, SocketState::INITED); + } + + bool isRecvFail = false; + if (IsControlFrame(wsFrame.opcode)) { + if (HandleControlFrame(wsFrame, isRecvFail)) { + return wsFrame.payload; + } + } else if (HandleDataFrame(wsFrame, isRecvFail)) { + return wsFrame.payload; + } + return isRecvFail ? std::string(DECODE_DISCONNECT_MSG) : ""; +} + +bool WebSocketBase::IsConnected() +{ + return socketState_ == SocketState::CONNECTED; +} + +void WebSocketBase::SetCloseConnectionCallback(CloseConnectionCallback cb) +{ + closeCb_ = std::move(cb); +} + +void WebSocketBase::SetFailConnectionCallback(FailConnectionCallback cb) +{ + failCb_ = std::move(cb); +} + +void WebSocketBase::CloseConnectionSocket(ConnectionCloseReason status, SocketState newSocketState) +{ +#if defined(OHOS_PLATFORM) + shutdown(connectionFd_, SHUT_RDWR); +#endif + close(connectionFd_); + connectionFd_ = -1; + socketState_ = newSocketState; + + if (status == ConnectionCloseReason::FAIL) { + if (failCb_) { + failCb_(); + } + } else if (status == ConnectionCloseReason::CLOSE) { + if (closeCb_) { + closeCb_(); + } + } +} + +void WebSocketBase::SendPongFrame(std::string payload) +{ + auto frame = CreateFrame(true, FrameType::PONG, std::move(payload)); + if (!Send(connectionFd_, frame, 0)) { + LOGE("Decode: Send pong frame failed"); + } +} + +void WebSocketBase::SendCloseFrame(CloseStatusCode status) +{ + auto frame = CreateFrame(true, FrameType::CLOSE, ToString(status)); + if (!Send(connectionFd_, frame, 0)) { + LOGE("SendCloseFrame: Send close frame failed"); + } +} + +void WebSocketBase::CloseConnection(CloseStatusCode status, SocketState newSocketState) +{ + LOGI("Close connection, status = %{public}d", static_cast(status)); + SendCloseFrame(status); + // can close connection right after sending back close frame. + CloseConnectionSocket(ConnectionCloseReason::CLOSE, newSocketState); +} + +/* static */ +bool WebSocketBase::IsDecodeDisconnectMsg(const std::string& message) +{ + return message == DECODE_DISCONNECT_MSG; +} + +/* static */ +void WebSocketBase::SetSocketFail(bool& isSendFail) +{ + isSendFail = (errno != EINTR); +} + +#if !defined(OHOS_PLATFORM) +/* static */ +bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit) +{ + if (timeoutLimit > 0) { + struct timeval timeout = {timeoutLimit, 0}; + if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, + reinterpret_cast(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) { + LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno); + return false; + } + if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) { + LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno); + return false; + } + } + return true; +} +#else +/* static */ +bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit) +{ + if (timeoutLimit > 0) { + struct timeval timeout = {timeoutLimit, 0}; + if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) { + LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno); + return false; + } + if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) { + LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno); + return false; + } + } + return true; +} +#endif +} // namespace OHOS::ArkCompiler::Toolchain diff --git a/websocket/websocket_base.h b/websocket/websocket_base.h new file mode 100644 index 00000000..4cfbc4a3 --- /dev/null +++ b/websocket/websocket_base.h @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2022 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_BASE_H +#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_BASE_H + +#include "websocket/frame_builder.h" +#include "websocket/web_socket_frame.h" + +#include +#include +#include + +namespace panda::test { +class WebSocketTest; +} // namespace panda::test + +namespace OHOS::ArkCompiler::Toolchain { +enum CloseStatusCode : uint16_t { + NO_STATUS_CODE = 0, + NORMAL = 1000, + SERVER_GO_AWAY = 1001, + PROTOCOL_ERROR = 1002, + UNACCEPTABLE_DATA = 1003, + INCONSISTENT_DATA = 1007, + POLICY_VIOLATION = 1008, + MESSAGE_TOO_BIG = 1009, + UNEXPECTED_ERROR = 1011, +}; + +inline std::string ToString(CloseStatusCode status) +{ + if (status == CloseStatusCode::NO_STATUS_CODE) { + return ""; + } + std::string result; + PushNumberPerByte(result, EnumToNumber(status)); + return result; +} + +class WebSocketBase { +public: + using CloseConnectionCallback = std::function; + using FailConnectionCallback = std::function; + +public: + static bool IsDecodeDisconnectMsg(const std::string& message); + + WebSocketBase() = default; + virtual ~WebSocketBase() noexcept = default; + + // Receive and decode a message. + // In case of control frames this method handles it accordingly and returns an empty string, + // otherwise returns the decoded received message. + std::string Decode(); + // Send message on current connection. + // Returns success status. + bool SendReply(const std::string& message, bool& isSendFail, + FrameType frameType = FrameType::TEXT, bool isLast = true) const; + // Send message on current connection. + // Returns success status. + bool SendReply(const std::string& message, FrameType frameType = FrameType::TEXT, bool isLast = true) const; + + bool IsConnected(); + + void SetCloseConnectionCallback(CloseConnectionCallback cb); + void SetFailConnectionCallback(FailConnectionCallback cb); + + // Close current websocket endpoint and connections (if any). + virtual void Close() = 0; + +protected: + enum class SocketState : uint8_t { + UNINITED, + INITED, + CONNECTED, + }; + + enum class ConnectionCloseReason: uint8_t { + FAIL, + CLOSE, + }; + +protected: + static bool SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit); + static void SetSocketFail(bool& isSendFail); + + bool ReadPayload(WebSocketFrame& wsFrame, bool &isRecvFail); + void SendPongFrame(std::string payload); + void SendCloseFrame(CloseStatusCode status); + // Sending close frame and close connection. + void CloseConnection(CloseStatusCode status, SocketState newSocketState); + // Close connection socket. + void CloseConnectionSocket(ConnectionCloseReason status, SocketState newSocketState); + + virtual bool HandleDataFrame(WebSocketFrame& wsFrame, bool &isRecvFail); + virtual bool HandleControlFrame(WebSocketFrame& wsFrame, bool &isRecvFail); + + virtual bool ValidateIncomingFrame(const WebSocketFrame& wsFrame) = 0; + virtual std::string CreateFrame(bool isLast, FrameType frameType) const = 0; + virtual std::string CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const = 0; + virtual std::string CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const = 0; + virtual bool DecodeMessage(WebSocketFrame& wsFrame, bool &isRecvFail) const = 0; + +protected: + std::atomic socketState_ {SocketState::UNINITED}; + + int connectionFd_ {-1}; + + // Callbacks used during different stages of connection lifecycle. + CloseConnectionCallback closeCb_; + FailConnectionCallback failCb_; + + static constexpr char WEB_SOCKET_GUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + static constexpr size_t HTTP_HANDSHAKE_MAX_LEN = 1024; + static constexpr int SOCKET_SUCCESS = 0; + static constexpr std::string_view DECODE_DISCONNECT_MSG = "disconnect"; + +private: + friend class panda::test::WebSocketTest; +}; +} // namespace OHOS::ArkCompiler::Toolchain + +#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_BASE_H -- Gitee