diff --git a/bundle.json b/bundle.json index ee88a25ec2b19c3886d68576a157f81467fc12dc..23b6cbf59e018b58a4ef2ca7f796af7a64b6926e 100644 --- a/bundle.json +++ b/bundle.json @@ -107,12 +107,13 @@ }, { "type": "so", - "name": "//foundation/communication/netstack/interfaces/innerkits/websocket_client:websocket_client", + "name": "//foundation/communication/netstack/interfaces/innerkits/websocket_native:websocket_native", "header": { "header_files": [ - "websocket_client_innerapi.h" + "websocket_client_innerapi.h", + "websocket_server_innerapi.h" ], - "header_base": "//foundation/communication/netstack/interfaces/innerkits/websocket_client/include" + "header_base": "//foundation/communication/netstack/interfaces/innerkits/websocket_native/include" } }, { diff --git a/frameworks/ets/ani/web_socket/BUILD.gn b/frameworks/ets/ani/web_socket/BUILD.gn index 77daf83c56dfa703b66915017a2d0411264ed11a..dc9656ee2ad826ba75301c5107c708d70b82fadc 100644 --- a/frameworks/ets/ani/web_socket/BUILD.gn +++ b/frameworks/ets/ani/web_socket/BUILD.gn @@ -34,7 +34,7 @@ ohos_static_library("websocket_ani_static") { "${target_gen_dir}/src", "$NETSTACK_DIR/utils/common_utils/include", "$NETSTACK_DIR/utils/log/include", - "$NETSTACK_NATIVE_ROOT/websocket_client/include", + "$NETSTACK_NATIVE_ROOT/websocket_native/include", ] sources = [ "src/cxx/websocket_ani.cpp" ] sources += get_target_outputs(":websocket_ani_cxx") @@ -43,7 +43,7 @@ ohos_static_library("websocket_ani_static") { ":websocket_ani_cxx", "//third_party/rust/crates/cxx:cxx_cppdeps", "$NETSTACK_DIR/utils/napi_utils:napi_utils", - "$NETSTACK_INNERKITS_DIR/websocket_client:websocket_client" + "$NETSTACK_INNERKITS_DIR/websocket_native:websocket_native" ] external_deps = [ "c_utils:utils", @@ -86,7 +86,7 @@ ohos_rust_shared_library("websocket_ani") { generate_static_abc("websocket") { base_url = "./ets" - files = [ "ets/@ohos.net.webSocket.d.ets" ] + files = [ "ets/@ohos.net.webSocket.ets" ] is_boot_abc = "True" device_dst_file = "/system/framework/websocket.abc" } diff --git a/frameworks/ets/ani/web_socket/ets/@ohos.net.webSocket.d.ets b/frameworks/ets/ani/web_socket/ets/@ohos.net.webSocket.ets similarity index 35% rename from frameworks/ets/ani/web_socket/ets/@ohos.net.webSocket.d.ets rename to frameworks/ets/ani/web_socket/ets/@ohos.net.webSocket.ets index 3a74921863ff737c0a4af2f105a606c2b0b7b720..6667d1f828dbcedd2a74e11cb68bcbb0007ac424 100644 --- a/frameworks/ets/ani/web_socket/ets/@ohos.net.webSocket.d.ets +++ b/frameworks/ets/ani/web_socket/ets/@ohos.net.webSocket.ets @@ -24,9 +24,9 @@ export default namespace webSocket { loadLibrary("websocket_ani"); class Cleaner { - private ptr: long = 0 - constructor(ptr:long) { - this.ptr = ptr + private nativePtr: long = 0 + constructor(nativePtr:long) { + this.nativePtr = nativePtr } native clean(): void } @@ -37,7 +37,19 @@ export default namespace webSocket { let unregisterToken = new object() export interface WebSocketRequestOptions { - header?: Record; + header?: Record; + + caPath?: string; + + clientCert?: ClientCert; + + proxy?: ProxyConfiguration; + + protocol?: string; + } + + export class WebSocketRequestOptionsInner implements WebSocketRequestOptions { + header?: Record; caPath?: string; @@ -58,19 +70,45 @@ export default namespace webSocket { keyPassword?: string; } + export class ClientCertInner implements ClientCert { + certPath: string; + + keyPath: string; + + keyPassword?: string; + } + export interface WebSocketCloseOptions { code?: int; reason?: string; } + export class WebSocketCloseOptionsInner implements WebSocketCloseOptions { + code?: int; + reason?: string; + } + export interface CloseResult { code: int; reason: string; } - // export type ResponseHeaders = { - // [k: string]: string | string[] | undefined; - // } + export class CloseResultInner implements CloseResult { + code: int; + reason: string; + } + + export interface OpenResult { + status: int; + message: string; + } + + export class OpenResultInner implements OpenResult { + status: int; + message: string; + } + + export type ResponseHeaders = Record; export interface WebSocket { connect(url: string, callback: AsyncCallback): void; @@ -89,29 +127,11 @@ export default namespace webSocket { close(options?: WebSocketCloseOptions): Promise; - // on(type: 'open', callback: AsyncCallback): void; - - // off(type: 'open', callback?: AsyncCallback): void; - - // on(type: 'message', callback: AsyncCallback): void; - - // off(type: 'message', callback?: AsyncCallback): void; - - // on(type: 'close', callback: AsyncCallback): void; + on(type: 'open' | 'message' | 'close' | 'error' | 'dataEnd' | 'headerReceive', + callback: Object): void; - // off(type: 'close', callback?: AsyncCallback): void; - - // on(type: 'error', callback: ErrorCallback): void; - - // off(type: 'error', callback?: ErrorCallback): void; - - // on(type: 'dataEnd', callback: Callback): void; - - // off(type: 'dataEnd', callback?: Callback): void; - - // on(type: 'headerReceive', callback: Callback): void; - - // off(type: 'headerReceive', callback?: Callback): void; + off(type: 'open' | 'message' | 'close' | 'error' | 'dataEnd' | 'headerReceive', + callback?: Object): void; } export class WebSocketInner implements WebSocket { @@ -125,10 +145,11 @@ export default namespace webSocket { this.registerCleaner(this.nativePtr) } - registerCleaner(ptr: long): void { - this.cleaner = new Cleaner(ptr) + registerCleaner(nativePtr: long): void { + this.cleaner = new Cleaner(nativePtr) destroyRegister.register(this, this.cleaner!, unregisterToken); } + unregisterCleaner(): void { destroyRegister.unregister(unregisterToken); } @@ -229,5 +250,305 @@ export default namespace webSocket { }); }); } + + native onOpen(callback: AsyncCallback): void; + native onMessage(callback: AsyncCallback): void; + native onClose(callback: AsyncCallback): void; + native onError(callback: ErrorCallback): void; + native onDataEnd(callback: Callback): void; + native onHeaderReceive(callback: Callback): void; + + on(type: 'open' | 'message' | 'close' | 'error' | 'dataEnd' | 'headerReceive', + callback: Object): void { + if (type == 'open') { + this.onOpen(callback as AsyncCallback) + } else if (type == 'message') { + this.onMessage(callback as AsyncCallback) + } else if (type == 'close') { + this.onClose(callback as AsyncCallback) + } else if (type == 'error') { + this.onError(callback as ErrorCallback) + } else if (type == 'dataEnd') { + this.onDataEnd(callback as Callback) + } else if (type == 'headerReceive') { + this.onHeaderReceive(callback as Callback) + } + } + + native offOpen(callback?: AsyncCallback): void; + native offMessage(callback?: AsyncCallback): void; + native offClose(callback?: AsyncCallback): void; + native offError(callback?: ErrorCallback): void; + native offDataEnd(callback?: Callback): void; + native offHeaderReceive(callback?: Callback): void; + + off(type: 'open' | 'message' | 'close' | 'error' | 'dataEnd' | 'headerReceive', + callback?: Object): void { + if (type == 'open') { + if (callback === undefined) { + this.offOpen() + } else { + this.offOpen(callback as AsyncCallback) + } + } else if (type == 'message') { + if (callback === undefined) { + this.offMessage() + } else { + this.offMessage(callback as AsyncCallback) + } + } else if (type == 'close') { + if (callback === undefined) { + this.offClose() + } else { + this.offClose(callback as AsyncCallback) + } + } else if (type == 'error') { + if (callback === undefined) { + this.offError() + } else { + this.offError(callback as ErrorCallback) + } + } else if (type == 'dataEnd') { + if (callback === undefined) { + this.offDataEnd() + } else { + this.offDataEnd(callback as Callback) + } + } else if (type == 'headerReceive') { + if (callback === undefined) { + this.offHeaderReceive() + } else { + this.offHeaderReceive(callback as Callback) + } + } + } + } + + class CleanerServer { + private nativePtr: long = 0 + constructor(nativePtr:long) { + this.nativePtr = nativePtr + } + native clean(): void + } + + let destroyRegisterServer = new FinalizationRegistry((cleaner: CleanerServer) => {cleaner.clean()}) + let unregisterTokenServer = new object() + + export native function createWebSocketServer(): WebSocketServer; + + export interface WebSocketServerConfig { + serverIP?: string; + + serverPort: int; + + serverCert?: ServerCert; + + maxConcurrentClientsNumber: int; + + protocol?: string; + + maxConnectionsForOneClient: int; + } + + export class WebSocketServerConfigInner implements WebSocketServerConfig { + serverIP?: string; + + serverPort: int; + + serverCert?: ServerCert; + + maxConcurrentClientsNumber: int; + + protocol?: string; + + maxConnectionsForOneClient: int; + } + + export interface ServerCert { + certPath: string; + + keyPath: string; + } + + export class ServerCertInner implements ServerCert { + certPath: string; + + keyPath: string; + } + + export interface WebSocketConnection { + clientIP: string; + + clientPort: int; + } + + export class WebSocketConnectionInner implements WebSocketConnection { + clientIP: string; + + clientPort: int; + } + + export type ClientConnectionCloseCallback = (clientConnection: WebSocketConnection, closeReason: CloseResult) => void; + + export interface WebSocketMessage { + data: string | ArrayBuffer; + + clientConnection: WebSocketConnection; + } + + export class WebSocketMessageInner implements WebSocketMessage { + data: string | ArrayBuffer; + + clientConnection: WebSocketConnection; + } + + export interface WebSocketServer { + start(config: WebSocketServerConfig): Promise; + + listAllConnections(): WebSocketConnection[]; + + close(connection: WebSocketConnection, options?: WebSocketCloseOptions): Promise; + + send(data: string | ArrayBuffer, connection: WebSocketConnection): Promise; + + stop(): Promise; + + on(type: 'error' | 'connect' | 'close' | 'messageReceive', callback: Object): void; + + off(type: 'error' | 'connect' | 'close' | 'messageReceive', callback?: Object): void; + } + + export class WebSocketServerInner implements WebSocketServer { + private nativePtr: long = 0; + private cleaner: CleanerServer | null = null; + + constructor(context: long) { + if (this.nativePtr == 0) { + this.nativePtr = context; + } + this.registerCleaner(this.nativePtr); + } + + registerCleaner(nativePtr: long): void { + this.cleaner = new CleanerServer(nativePtr); + destroyRegisterServer.register(this, this.cleaner!, unregisterTokenServer); + } + + unregisterCleaner(): void { + destroyRegisterServer.unregister(unregisterTokenServer); + } + + native startSync(config: WebSocketServerConfig): boolean; + + start(config: WebSocketServerConfig): Promise { + return new Promise((resolve, reject) => { + taskpool.execute((): boolean => { + return this.startSync(config); + }).then((content: NullishType) => { + resolve(content as boolean); + }, (err: Error): void => { + reject(err as BusinessError); + }); + }); + } + + native closeSync(connection: WebSocketConnection, options?: webSocket.WebSocketCloseOptions): boolean; + + close(connection: WebSocketConnection, options?: webSocket.WebSocketCloseOptions): Promise { + return new Promise((resolve, reject) => { + taskpool.execute((): boolean => { + return this.closeSync(connection, options); + }).then((content: NullishType) => { + resolve(content as boolean); + }, (err: Error): void => { + reject(err as BusinessError); + }); + }); + } + + native listAllConnectionsSync(): WebSocketConnection[]; + + listAllConnections(): WebSocketConnection[] { + return this.listAllConnectionsSync(); + } + + native sendSync(data: string | ArrayBuffer, connection: WebSocketConnection): boolean; + + send(data: string | ArrayBuffer, connection: WebSocketConnection): Promise { + return new Promise((resolve, reject) => { + taskpool.execute((): boolean => { + return this.sendSync(data, connection); + }).then((content: NullishType) => { + resolve(content as boolean); + }, (err: Error): void => { + reject(err as BusinessError); + }); + }); + } + + native stopSync(): boolean; + + stop(): Promise { + return new Promise((resolve, reject) => { + taskpool.execute((): boolean => { + return this.stopSync(); + }).then((content: NullishType) => { + resolve(content as boolean); + }, (err: Error): void => { + reject(err as BusinessError); + }); + }); + } + + native onError(callback: ErrorCallback): void; + native onConnect(callback: Callback): void; + native onClose(callback: ClientConnectionCloseCallback): void; + native onMessageReceive(callback: Callback): void; + + on(type: 'error' | 'connect' | 'close' | 'messageReceive', callback: Object): void { + if (type == 'error') { + this.onError(callback as ErrorCallback); + } else if (type == 'connect') { + this.onConnect(callback as Callback); + } else if (type == 'close') { + this.onClose(callback as ClientConnectionCloseCallback); + } else if (type == 'messageReceive') { + this.onMessageReceive(callback as Callback); + } + } + + native offError(callback?: ErrorCallback): void; + native offConnect(callback?: Callback): void; + native offClose(callback?: ClientConnectionCloseCallback): void; + native offMessageReceive(callback?: Callback): void; + + off(type: 'error' | 'connect' | 'close' | 'messageReceive', callback?: Object): void { + if (type == 'error') { + if (callback === undefined) { + this.offError(); + } else { + this.offError(callback as ErrorCallback); + } + } else if (type == 'connect') { + if (callback === undefined) { + this.offConnect(); + } else { + this.offConnect(callback as Callback); + } + } else if (type == 'close') { + if (callback === undefined) { + this.offClose(); + } else { + this.offClose(callback as ClientConnectionCloseCallback); + } + } else if (type == 'messageReceive') { + if (callback === undefined) { + this.offMessageReceive(); + } else { + this.offMessageReceive(callback as Callback); + } + } + } } } diff --git a/frameworks/ets/ani/web_socket/include/websocket_ani.h b/frameworks/ets/ani/web_socket/include/websocket_ani.h index d4e4a061eeb30590d3996dba28679936d7876ad0..64804187ca0da6437aeea13d3aea98dbc29c27af 100644 --- a/frameworks/ets/ani/web_socket/include/websocket_ani.h +++ b/frameworks/ets/ani/web_socket/include/websocket_ani.h @@ -22,23 +22,80 @@ #include "cxx.h" #include "websocket_client_innerapi.h" +#include "websocket_server_innerapi.h" namespace OHOS { namespace NetStackAni { -struct ConnectOptions; -struct CloseOption; -std::unique_ptr CreateWebSocket(); +struct AniConnectOptions; +struct AniCloseOption; +struct AniServerConfig; +struct AniServerConfigCert; +struct AniWebSocketConnection; -int32_t Connect(NetStack::WebSocketClient::WebSocketClient &client, const rust::str url, ConnectOptions options); +class WebSocketClientWrapper{ +public: + WebSocketClientWrapper(); + ~WebSocketClientWrapper(); + std::shared_ptr client = nullptr; +}; -void SetCaPath(NetStack::WebSocketClient::WebSocketClient &client, const rust::str caPath); -void SetClientCert(NetStack::WebSocketClient::WebSocketClient &client, const rust::str clientCert, - const rust::str clientKey); -void SetCertPassword(NetStack::WebSocketClient::WebSocketClient &client, const rust::str password); +std::unique_ptr CreateWebSocket(); +int32_t Connect(WebSocketClientWrapper &client, + const rust::str url, AniConnectOptions options); -int32_t Send(NetStack::WebSocketClient::WebSocketClient &client, const rust::str data); -int32_t Close(NetStack::WebSocketClient::WebSocketClient &client, CloseOption options); +void SetCaPath(WebSocketClientWrapper &client, + const rust::str caPath); +void SetClientCert(WebSocketClientWrapper &client, + const rust::str clientCert, const rust::str clientKey); +void SetCertPassword(WebSocketClientWrapper &client, + const rust::str password); + +int32_t Send(WebSocketClientWrapper &client, + const rust::Vec data, int32_t data_type); +int32_t Close(WebSocketClientWrapper &client, + AniCloseOption options); + +int32_t RegisterOpenCallback(WebSocketClientWrapper &client); +int32_t RegisterMessageCallback(WebSocketClientWrapper &client); +int32_t RegisterCloseCallback(WebSocketClientWrapper &client); +int32_t RegisterErrorCallback(WebSocketClientWrapper &client); +int32_t RegisterDataEndCallback(WebSocketClientWrapper &client); +int32_t RegisterHeaderReceiveCallback(WebSocketClientWrapper &client); + +int32_t UnregisterOpenCallback(WebSocketClientWrapper &client); +int32_t UnregisterMessageCallback(WebSocketClientWrapper &client); +int32_t UnregisterCloseCallback(WebSocketClientWrapper &client); +int32_t UnregisterErrorCallback(WebSocketClientWrapper &client); +int32_t UnregisterDataEndCallback(WebSocketClientWrapper &client); +int32_t UnregisterHeaderReceiveCallback(WebSocketClientWrapper &client); + +/** + * @brief server + */ +std::unique_ptr CreateWebSocketServer(); +int32_t StartServer(NetStack::WebSocketServer::WebSocketServer &server, + AniServerConfig options); +int32_t StopServer(NetStack::WebSocketServer::WebSocketServer &server); +int32_t SendServerData(NetStack::WebSocketServer::WebSocketServer &server, + const rust::Vec data, + const AniWebSocketConnection &connection, + int32_t data_type); +int32_t CloseServer(NetStack::WebSocketServer::WebSocketServer &server, + const AniWebSocketConnection &connection, + AniCloseOption options); +int32_t ListAllConnections(NetStack::WebSocketServer::WebSocketServer &server, + rust::Vec &connections); + +int32_t RegisterServerErrorCallback(NetStack::WebSocketServer::WebSocketServer &server); +int32_t RegisterServerConnectCallback(NetStack::WebSocketServer::WebSocketServer &server); +int32_t RegisterServerCloseCallback(NetStack::WebSocketServer::WebSocketServer &server); +int32_t RegisterServerMessageReceiveCallback(NetStack::WebSocketServer::WebSocketServer &server); + +int32_t UnregisterServerErrorCallback(NetStack::WebSocketServer::WebSocketServer &server); +int32_t UnregisterServerConnectCallback(NetStack::WebSocketServer::WebSocketServer &server); +int32_t UnregisterServerCloseCallback(NetStack::WebSocketServer::WebSocketServer &server); +int32_t UnregisterServerMessageReceiveCallback(NetStack::WebSocketServer::WebSocketServer &server); } // namespace NetStackAni } // namespace OHOS diff --git a/frameworks/ets/ani/web_socket/src/bridge.rs b/frameworks/ets/ani/web_socket/src/bridge.rs index 981e9acf18d73b166cd999fdcf40ecbd68fdae8e..7e9581865788be1f282066da7485ce78f26a9a81 100644 --- a/frameworks/ets/ani/web_socket/src/bridge.rs +++ b/frameworks/ets/ani/web_socket/src/bridge.rs @@ -13,21 +13,31 @@ use std::collections::HashMap; -use ani_rs::business_error::BusinessError; -use serde::Deserialize; +use ani_rs::{business_error::BusinessError, typed_array::ArrayBuffer}; +use serde::{Deserialize, Serialize}; -#[ani_rs::ani] -pub struct Cleaner { - pub native_ptr: i64, +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/Cleaner")] +pub struct AniCleaner { + pub nativePtr: i64, } -#[ani_rs::ani] -pub struct WebSocket { - pub native_ptr: i64, +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/CleanerServer")] +pub struct AniCleanerServer { + pub nativePtr: i64, +} + +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/WebSocketInner")] +pub struct AniWebSocket { + pub nativePtr: i64, +} + +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/WebSocketServerInner")] +pub struct AniWebSocketServer { + pub nativePtr: i64, } #[ani_rs::ani(path = "L@ohos/net/connection/connection/HttpProxyInner")] -pub struct HttpProxy { +pub struct AniHttpProxy { pub host: String, pub port: i32, @@ -39,46 +49,190 @@ pub struct HttpProxy { pub exclusion_list: Vec, } -#[derive(Deserialize)] -pub enum ProxyConfiguration { +#[derive(Serialize, Deserialize)] +pub enum AniProxyConfiguration { S(String), #[serde(rename = "L@ohos/net/connection/connection/HttpProxyInner;")] - Proxy(HttpProxy), + Proxy(AniHttpProxy), } -#[ani_rs::ani] -pub struct WebSocketRequestOptions { +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/WebSocketRequestOptionsInner")] +pub struct AniWebSocketRequestOptions { pub header: Option>, - pub ca_path: Option, + pub caPath: Option, - pub client_cert: Option, + pub clientCert: Option, - pub proxy: Option, + pub proxy: Option, pub protocol: Option, } -#[ani_rs::ani] -pub struct ClientCert { - pub cert_path: String, +impl AniWebSocketRequestOptions { + pub fn new() -> Self { + Self { + header: None, + caPath: None, + clientCert: None, + proxy: None, + protocol: None, + } + } +} + +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/ClientCertInner")] +pub struct AniClientCert { + pub certPath: String, - pub key_path: String, + pub keyPath: String, - pub key_password: Option, + pub keyPassword: Option, } -#[ani_rs::ani] -pub struct CloseResult { +impl AniClientCert { + pub fn new() -> Self { + Self { + certPath: "".to_string(), + keyPath: "".to_string(), + keyPassword: None, + } + } +} + +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/CloseResultInner")] +pub struct AniCloseResult { pub code: i32, pub reason: String, } -#[ani_rs::ani] -pub struct WebSocketCloseOptions { +impl AniCloseResult { + pub fn new() -> Self { + Self { + code: 0, + reason: "".to_string(), + } + } +} + +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/OpenResultInner")] +pub struct AniOpenResult { + pub status: i32, + pub message: String, +} + +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/WebSocketCloseOptionsInner")] +pub struct AniWebSocketCloseOptions { pub code: Option, pub reason: Option, } +impl AniWebSocketCloseOptions { + pub fn new() -> Self { + Self { + code: None, + reason: None, + } + } +} + +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/ServerCertInner")] +pub struct AniServerCert { + pub certPath: String, + pub keyPath: String, +} + +impl AniServerCert { + pub fn new() -> Self { + Self { + certPath: "".to_string(), + keyPath: "".to_string(), + } + } +} + +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/WebSocketServerConfigInner")] +pub struct AniWebSocketServerConfig { + pub serverIP: Option, + pub serverPort: i32, + pub serverCert: Option, + pub maxConcurrentClientsNumber: i32, + pub protocol: Option, + pub maxConnectionsForOneClient: i32, +} + +impl AniWebSocketServerConfig { + pub fn new() -> Self { + Self { + serverIP: None, + serverPort: 0, + serverCert: None, + maxConcurrentClientsNumber: 0, + protocol: None, + maxConnectionsForOneClient: 0, + } + } +} + +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/WebSocketConnectionInner")] +pub struct AniWebSocketConnection { + pub clientIP: String, + pub clientPort: i32, +} + +impl AniWebSocketConnection { + pub fn new() -> Self { + Self { + clientIP: "".to_string(), + clientPort: 0, + } + } +} + +pub fn get_web_socket_connection_client_ip(conn: &AniWebSocketConnection) -> String { + conn.clientIP.clone() +} + +pub fn get_web_socket_connection_client_port(conn: &AniWebSocketConnection) -> i32 { + conn.clientPort.clone() +} + +pub fn socket_connection_push_data( + connection_info_value: &mut Vec, + clientIP: String, + clientPort: i32, +) { + let connection_info = AniWebSocketConnection {clientIP, clientPort}; + connection_info_value.push(connection_info); +} + +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/ResponseHeaders")] +pub enum AniResponseHeaders { + MapBuffer(HashMap), + VecBuffer(Vec), + Undefined, +} + +#[derive(Serialize, Deserialize)] +pub enum AniData { + S(String), + ArrayBuffer(ArrayBuffer), +} + +#[ani_rs::ani(path = "L@ohos/net/webSocket/webSocket/WebSocketMessageInner")] +pub struct AniWebSocketMessage { + pub data: AniData, + pub clientConnection: AniWebSocketConnection, +} + +impl AniWebSocketMessage { + pub fn new(data: AniData, clientConnection: AniWebSocketConnection) -> Self { + Self { + data, + clientConnection, + } + } +} + pub const fn convert_to_business_error(code: i32) -> BusinessError { match code { 1004 => BusinessError::new_static(2302001, "Websocket Parse url error."), diff --git a/frameworks/ets/ani/web_socket/src/cxx/websocket_ani.cpp b/frameworks/ets/ani/web_socket/src/cxx/websocket_ani.cpp index 4b3d1945267a08087cd122d392def1a22e54abf1..22bb35d07c17a1560cff403ec2c4e251994134db 100644 --- a/frameworks/ets/ani/web_socket/src/cxx/websocket_ani.cpp +++ b/frameworks/ets/ani/web_socket/src/cxx/websocket_ani.cpp @@ -4,7 +4,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,40 +18,28 @@ #include "secure_char.h" #include "wrapper.rs.h" #include +#include namespace OHOS { namespace NetStackAni { -void OnMessageCallbackAni(NetStack::WebSocketClient::WebSocketClient *ptrInner, - const std::string &data, size_t length) -{ -} +static std::map clientMap; -void OnCloseCallbackAni(NetStack::WebSocketClient::WebSocketClient *ptrInner, - NetStack::WebSocketClient::CloseResult closeResult) -{ +WebSocketClientWrapper::WebSocketClientWrapper() { + client = std::make_shared(); + clientMap[client.get()] = this; } -void OnErrorCallbackAni(NetStack::WebSocketClient::WebSocketClient *ptrInner, - NetStack::WebSocketClient::ErrorResult error) -{ +WebSocketClientWrapper::~WebSocketClientWrapper() { + clientMap.erase(client.get()); } -void OnOpenCallbackAni(NetStack::WebSocketClient::WebSocketClient *ptrInner, - NetStack::WebSocketClient::OpenResult openResult) +std::unique_ptr CreateWebSocket() { + return std::make_unique(); } -std::unique_ptr CreateWebSocket() -{ - auto client = std::make_unique(); - client->Registcallback(OnOpenCallbackAni, OnMessageCallbackAni, - OnErrorCallbackAni, OnCloseCallbackAni); - - return client; -} - -int32_t Connect(NetStack::WebSocketClient::WebSocketClient &client, const rust::str url, ConnectOptions options) +int32_t Connect(WebSocketClientWrapper &client, const rust::str url, AniConnectOptions options) { NetStack::WebSocketClient::OpenOptions openOptions; bool isValue = false; @@ -67,42 +55,346 @@ int32_t Connect(NetStack::WebSocketClient::WebSocketClient &client, const rust:: isValue = true; } } - return client.Connect(std::string(url), openOptions); + return client.client->ConnectEx(std::string(url), openOptions); } -void SetCaPath(NetStack::WebSocketClient::WebSocketClient &client, const rust::str caPath) +void SetCaPath(WebSocketClientWrapper &client, const rust::str caPath) { - auto context = client.GetClientContext(); + auto context = client.client->GetClientContext(); context->SetUserCertPath(std::string(caPath)); } -void SetClientCert(NetStack::WebSocketClient::WebSocketClient &client, const rust::str clientCert, - const rust::str clientKey) +void SetClientCert(WebSocketClientWrapper &client, const rust::str clientCert, + const rust::str clientKey) { - auto context = client.GetClientContext(); + auto context = client.client->GetClientContext(); context->clientCert = std::string(clientCert); context->clientKey = NetStack::Secure::SecureChar(std::string(clientKey)); } -void SetCertPassword(NetStack::WebSocketClient::WebSocketClient &client, const rust::str password) +void SetCertPassword(WebSocketClientWrapper &client, const rust::str password) { - auto context = client.GetClientContext(); + auto context = client.client->GetClientContext(); context->keyPassword = NetStack::Secure::SecureChar(std::string(password)); } -int32_t Send(NetStack::WebSocketClient::WebSocketClient &client, const rust::str data) +int32_t Send(WebSocketClientWrapper &client, const rust::Vec data, int32_t data_type) { - return client.Send(const_cast(data.data()), data.size()); + NETSTACK_LOGI("WebSocketClient data %{public}p %{public}p %{public}d", &client, data.data(), data.size()); + if ((data.size() == 0 || data.data() == nullptr) && data_type == 1) { + return 401; + } + return client.client->SendEx((char*)data.data(), data.size()); } -int32_t Close(NetStack::WebSocketClient::WebSocketClient &client, CloseOption options) +int32_t Close(WebSocketClientWrapper &client, AniCloseOption options) { NetStack::WebSocketClient::CloseOption closeOption{ .code = options.code, .reason = options.reason.data(), }; - return client.Close(closeOption); + int ret = client.client->CloseEx(closeOption); + if (ret == -1) { + ret = 2302999; + } + return ret; +} + +void OnOpenCallbackC(NetStack::WebSocketClient::WebSocketClient *client, + NetStack::WebSocketClient::OpenResult openResult) +{ + auto iter = clientMap.find(client); + if (iter == clientMap.end()) { + NETSTACK_LOGE("OnOpenCallbackC can not find client"); + return; + } + on_open_websocket_client(*(iter->second), std::string(openResult.message), openResult.status); +} + +void OnMessageCallbackC(NetStack::WebSocketClient::WebSocketClient *client, const std::string &data, size_t length) +{ + auto iter = clientMap.find(client); + if (iter == clientMap.end()) { + NETSTACK_LOGE("OnOpenCallbackC can not find client"); + return; + } + on_message_websocket_client(*(iter->second), data, length); +} + +void OnCloseCallbackC(NetStack::WebSocketClient::WebSocketClient *client, + NetStack::WebSocketClient::CloseResult closeResult) +{ + auto iter = clientMap.find(client); + if (iter == clientMap.end()) { + NETSTACK_LOGE("OnOpenCallbackC can not find client"); + return; + } + on_close_websocket_client(*(iter->second), std::string(closeResult.reason), closeResult.code); +} + +void OnErrorCallbackC(NetStack::WebSocketClient::WebSocketClient *client, NetStack::WebSocketClient::ErrorResult error) +{ + auto iter = clientMap.find(client); + if (iter == clientMap.end()) { + NETSTACK_LOGE("OnOpenCallbackC can not find client"); + return; + } + if (error.errorCode == 1003) { + error.errorCode = 200; + } + on_error_websocket_client(*(iter->second), std::string(error.errorMessage), error.errorCode); +} + +void OnDataEndCallbackC(NetStack::WebSocketClient::WebSocketClient *client) +{ + auto iter = clientMap.find(client); + if (iter == clientMap.end()) { + NETSTACK_LOGE("OnOpenCallbackC can not find client"); + return; + } + on_data_end_websocket_client(*(iter->second)); +} + +void OnHeaderReceiveCallbackC(NetStack::WebSocketClient::WebSocketClient *client, + const std::map &headers) +{ + auto iter = clientMap.find(client); + if (iter == clientMap.end()) { + NETSTACK_LOGE("OnOpenCallbackC can not find client"); + return; + } + rust::Vec keys; + rust::Vec values; + for (const auto &pair : headers) { + header_push_data(keys, rust::String(pair.first.c_str())); + header_push_data(values, rust::String(pair.second.c_str())); + } + on_header_receive_websocket_client(*(iter->second), keys, values); +} + +int32_t RegisterOpenCallback(WebSocketClientWrapper &client) +{ + client.client->onOpenCallback_ = &OnOpenCallbackC; + return 0; +} + +int32_t RegisterMessageCallback(WebSocketClientWrapper &client) +{ + client.client->onMessageCallback_ = &OnMessageCallbackC; + return 0; +} + +int32_t RegisterCloseCallback(WebSocketClientWrapper &client) +{ + client.client->onCloseCallback_ = &OnCloseCallbackC; + return 0; +} + +int32_t RegisterErrorCallback(WebSocketClientWrapper &client) +{ + client.client->onErrorCallback_ = &OnErrorCallbackC; + return 0; +} + +int32_t RegisterDataEndCallback(WebSocketClientWrapper &client) +{ + client.client->onDataEndCallback_ = &OnDataEndCallbackC; + return 0; +} + +int32_t RegisterHeaderReceiveCallback(WebSocketClientWrapper &client) +{ + client.client->onHeaderReceiveCallback_ = &OnHeaderReceiveCallbackC; + return 0; +} + +int32_t UnregisterOpenCallback(WebSocketClientWrapper &client) +{ + client.client->onOpenCallback_ = nullptr; + return 0; +} + +int32_t UnregisterMessageCallback(WebSocketClientWrapper &client) +{ + client.client->onMessageCallback_ = nullptr; + return 0; +} + +int32_t UnregisterCloseCallback(WebSocketClientWrapper &client) +{ + client.client->onCloseCallback_ = nullptr; + return 0; +} + +int32_t UnregisterHeaderReceiveCallback(WebSocketClientWrapper &client) +{ + client.client->onHeaderReceiveCallback_ = nullptr; + return 0; +} + +int32_t UnregisterErrorCallback(WebSocketClientWrapper &client) +{ + client.client->onErrorCallback_ = nullptr; + return 0; +} + +int32_t UnregisterDataEndCallback(WebSocketClientWrapper &client) +{ + client.client->onDataEndCallback_ = nullptr; + return 0; +} + +/* * + * @brief server + */ +std::unique_ptr CreateWebSocketServer() +{ + return std::make_unique(); +} + +int32_t StartServer(NetStack::WebSocketServer::WebSocketServer &server, AniServerConfig options) +{ + NetStack::WebSocketServer::ServerCert serverCert{ + .certPath = options.serverCert.certPath.c_str(), + .keyPath = options.serverCert.keyPath.c_str() + }; + NetStack::WebSocketServer::ServerConfig severCfg{ + .serverIP = options.serverIP.c_str(), + .serverPort = options.serverPort, + .serverCert = serverCert, + .maxConcurrentClientsNumber = options.maxConcurrentClientsNumber, + .protocol = options.protocol.c_str(), + .maxConnectionsForOneClient = options.maxConnectionsForOneClient + }; + return server.Start(severCfg); +} + +int32_t StopServer(NetStack::WebSocketServer::WebSocketServer &server) +{ + return server.Stop(); +} + +int32_t CloseServer(NetStack::WebSocketServer::WebSocketServer &server, const AniWebSocketConnection &connection, + AniCloseOption options) +{ + std::string strIP(get_web_socket_connection_client_ip(connection).c_str()); + int32_t iPort = get_web_socket_connection_client_port(connection); + NetStack::WebSocketServer::SocketConnection socketConn{ + .clientIP = strIP, + .clientPort = static_cast(iPort), + }; + NetStack::WebSocketServer::CloseOption closeOpt{ + .code = options.code, + .reason = options.reason.data(), + }; + return server.Close(socketConn, closeOpt); +} + +int32_t SendServerData(NetStack::WebSocketServer::WebSocketServer &server, const rust::Vec data, + const AniWebSocketConnection &connection, int32_t data_type) +{ + if ((data.size() == 0 || data.data() == nullptr) && data_type == 1) { + return 401; + } + std::string strIP(get_web_socket_connection_client_ip(connection).c_str()); + int32_t iPort = get_web_socket_connection_client_port(connection); + NetStack::WebSocketServer::SocketConnection socketConn{ + .clientIP = strIP, + .clientPort = static_cast(iPort), + }; + return server.Send((char*)(data.data()), data.size(), socketConn); } +int32_t ListAllConnections(NetStack::WebSocketServer::WebSocketServer &server, + rust::Vec &connections) +{ + int32_t iRet; + std::vector connectionList; + iRet = server.ListAllConnections(connectionList); + if (iRet != 0) { + return iRet; + } + + for (size_t i = 0; i < connectionList.size(); ++i) { + std::string strIP = connectionList[i].clientIP; + int32_t iPort = static_cast(connectionList[i].clientPort); + socket_connection_push_data(connections, rust::String(strIP.c_str()), rust::i32(iPort)); + } + + return iRet; +} + +void OnErrorCallbackServerC(NetStack::WebSocketServer::WebSocketServer *server, + NetStack::WebSocketServer::ErrorResult error) +{ + on_error_websocket_server(*server, std::string(error.errorMessage), error.errorCode); +} + +void OnConnectCallbackServerC(NetStack::WebSocketServer::WebSocketServer *server, + NetStack::WebSocketServer::SocketConnection connection) +{ + on_connect_websocket_server(*server, std::string(connection.clientIP), connection.clientPort); +} + +void OnCloseCallbackServerC(NetStack::WebSocketServer::WebSocketServer *server, + NetStack::WebSocketServer::CloseResult result, NetStack::WebSocketServer::SocketConnection connection) +{ + on_close_websocket_server(*server, std::string(result.reason), result.code, std::string(connection.clientIP), + connection.clientPort); +} + +void OnMessageReceiveCallbackServerC(NetStack::WebSocketServer::WebSocketServer *server, const std::string &data, + size_t length, NetStack::WebSocketServer::SocketConnection connection) +{ + on_message_receive_websocket_server(*server, data, length, std::string(connection.clientIP), connection.clientPort); +} + +int32_t RegisterServerErrorCallback(NetStack::WebSocketServer::WebSocketServer &server) +{ + server.onErrorCallback_ = &OnErrorCallbackServerC; + return 0; +} + +int32_t RegisterServerConnectCallback(NetStack::WebSocketServer::WebSocketServer &server) +{ + server.onConnectCallback_ = &OnConnectCallbackServerC; + return 0; +} + +int32_t RegisterServerCloseCallback(NetStack::WebSocketServer::WebSocketServer &server) +{ + server.onCloseCallback_ = &OnCloseCallbackServerC; + return 0; +} + +int32_t RegisterServerMessageReceiveCallback(NetStack::WebSocketServer::WebSocketServer &server) +{ + server.onMessageReceiveCallback_ = &OnMessageReceiveCallbackServerC; + return 0; +} + +int32_t UnregisterServerErrorCallback(NetStack::WebSocketServer::WebSocketServer &server) +{ + server.onErrorCallback_ = nullptr; + return 0; +} + +int32_t UnregisterServerConnectCallback(NetStack::WebSocketServer::WebSocketServer &server) +{ + server.onConnectCallback_ = nullptr; + return 0; +} + +int32_t UnregisterServerCloseCallback(NetStack::WebSocketServer::WebSocketServer &server) +{ + server.onCloseCallback_ = nullptr; + return 0; +} + +int32_t UnregisterServerMessageReceiveCallback(NetStack::WebSocketServer::WebSocketServer &server) +{ + server.onMessageReceiveCallback_ = nullptr; + return 0; +} } // namespace NetStackAni } // namespace OHOS \ No newline at end of file diff --git a/frameworks/ets/ani/web_socket/src/lib.rs b/frameworks/ets/ani/web_socket/src/lib.rs index 69104ac3cac45486e35b1ce1a921ec7773469c8e..0a9a0664b0f4223fa38fd31cff5c6955bdf6eba2 100644 --- a/frameworks/ets/ani/web_socket/src/lib.rs +++ b/frameworks/ets/ani/web_socket/src/lib.rs @@ -22,23 +22,57 @@ extern crate netstack_common; mod bridge; mod web_socket; +mod web_socket_server; mod wrapper; ani_rs::ani_constructor! { namespace "L@ohos/net/webSocket/webSocket" [ - "createWebSocket" : web_socket::create_web_socket + "createWebSocket" : web_socket::create_web_socket, + "createWebSocketServer" : web_socket_server::create_web_socket_server, ] class "L@ohos/net/webSocket/webSocket/WebSocketInner" [ "connectSync" : web_socket::connect_sync, "sendSync" : web_socket::send_sync, "closeSync" : web_socket::close_sync, + "onOpen" : web_socket::on_open, + "onMessage" : web_socket::on_message, + "onClose" : web_socket::on_close, + "onError" : web_socket::on_error, + "onDataEnd" : web_socket::on_data_end, + "onHeaderReceive" : web_socket::on_header_receive, + "offOpen" : web_socket::off_open, + "offMessage" : web_socket::off_message, + "offClose" : web_socket::off_close, + "offError" : web_socket::off_error, + "offDataEnd" : web_socket::off_data_end, + "offHeaderReceive" : web_socket::off_header_receive, + ] + class "L@ohos/net/webSocket/webSocket/WebSocketServerInner" + [ + "startSync" : web_socket_server::start_sync, + "stopSync" : web_socket_server::stop_sync, + "sendSync" : web_socket_server::send_sync, + "closeSync" : web_socket_server::close_sync, + "listAllConnectionsSync" : web_socket_server::list_all_connections_sync, + "onError" : web_socket_server::on_error, + "onConnect" : web_socket_server::on_connect, + "onClose" : web_socket_server::on_close, + "onMessageReceive" : web_socket_server::on_message_receive, + "offError" : web_socket_server::off_error, + "offConnect" : web_socket_server::off_connect, + "offClose" : web_socket_server::off_close, + "offMessageReceive" : web_socket_server::off_message_receive, ] class "L@ohos/net/webSocket/webSocket/Cleaner" [ "clean" : web_socket::web_socket_clean, ] + class "L@ohos/net/webSocket/webSocket/CleanerServer" + [ + "clean" : web_socket_server::web_socket_server_clean, + ] } #[used] diff --git a/frameworks/ets/ani/web_socket/src/web_socket.rs b/frameworks/ets/ani/web_socket/src/web_socket.rs index e12e6d8eee1c0e408428103c19b4bbfebe83cd7c..59275d0f20bf6662393576aca8396a6ce745352b 100644 --- a/frameworks/ets/ani/web_socket/src/web_socket.rs +++ b/frameworks/ets/ani/web_socket/src/web_socket.rs @@ -13,18 +13,21 @@ use core::str; use std::{collections::HashMap, ffi::CStr}; - -use ani_rs::{business_error::BusinessError, objects::AniRef, AniEnv}; +use ani_rs::{ + business_error::BusinessError, + objects::{AniFnObject, AniAsyncCallback, AniErrorCallback, AniRef}, + AniEnv, +}; use serde::{Deserialize, Serialize}; use crate::{ - bridge::{self, convert_to_business_error, Cleaner}, - wrapper::WebSocket, + bridge::{self, convert_to_business_error, AniCleaner}, + wrapper::AniClient, }; #[ani_rs::native] -pub(crate) fn web_socket_clean(this: Cleaner) -> Result<(), BusinessError> { - let _ = unsafe { Box::from_raw(this.native_ptr as *mut WebSocket) }; +pub(crate) fn web_socket_clean(this: AniCleaner) -> Result<(), BusinessError> { + let _ = unsafe { Box::from_raw(this.nativePtr as *mut AniClient) }; Ok(()) } @@ -35,73 +38,68 @@ pub fn create_web_socket<'local>(env: &AniEnv<'local>) -> Result, CStr::from_bytes_with_nul_unchecked(b"L@ohos/net/webSocket/webSocket/WebSocketInner;\0") }; static CTOR_SIGNATURE: &CStr = unsafe { CStr::from_bytes_with_nul_unchecked(b"J:V\0") }; - let web_socket = Box::new(WebSocket::new()); - let ptr = Box::into_raw(web_socket); + let ptr = AniClient::new(); let class = env.find_class(WEB_SOCKET_CLASS).unwrap(); let obj = env - .new_object_with_signature(&class, CTOR_SIGNATURE, (ptr as i64,)) + .new_object_with_signature(&class, CTOR_SIGNATURE, (ptr,)) .unwrap(); Ok(obj.into()) } #[ani_rs::native] pub(crate) fn connect_sync( - this: bridge::WebSocket, + this: bridge::AniWebSocket, url: String, - options: Option, + options: Option, ) -> Result { info!("Connecting to WebSocket at URL: {}", url); - let web_socket = unsafe { &mut *(this.native_ptr as *mut WebSocket) }; + let web_socket = unsafe { &mut *(this.nativePtr as *mut AniClient) }; let mut headers = HashMap::new(); - let (mut ca_path, mut client_cert, mut protocol) = (None, None, None); + let (mut caPath, mut clientCert, mut protocol) = (None, None, None); if let Some(options) = options { if let Some(header) = options.header { headers = header; } - if let Some(path) = options.ca_path { - ca_path = Some(path); + if let Some(path) = options.caPath { + caPath = Some(path); } - if let Some(cert) = options.client_cert { - client_cert = Some(cert); + if let Some(cert) = options.clientCert { + clientCert = Some(cert); } if let Some(p) = options.protocol { protocol = Some(p); } } web_socket - .connect(&url, headers, ca_path, client_cert, protocol) + .connect(&url, headers, caPath, clientCert, protocol) .map(|_| true) .map_err(|e| convert_to_business_error(e)) } -#[derive(Serialize, Deserialize)] -pub(crate) enum Data<'a> { - S(String), - #[serde(borrow)] - ArrayBuffer(&'a [u8]), -} - #[ani_rs::native] -pub(crate) fn send_sync(this: bridge::WebSocket, data: Data) -> Result { - let web_socket = unsafe { &mut *(this.native_ptr as *mut WebSocket) }; - let s = match data { - Data::S(s) => s, - Data::ArrayBuffer(arr) => String::from_utf8_lossy(arr).to_string(), +pub(crate) fn send_sync( + this: bridge::AniWebSocket, + data: bridge::AniData, +) -> Result { + let web_socket = unsafe { &mut *(this.nativePtr as *mut AniClient) }; + let (s, data_type) = match data { + bridge::AniData::S(s) => (s.into_bytes(), 0), + bridge::AniData::ArrayBuffer(arr) => (arr.to_vec(), 1), }; web_socket - .send(&s) + .send(s, data_type) .map(|_| true) .map_err(|e| convert_to_business_error(e)) } #[ani_rs::native] pub(crate) fn close_sync( - this: bridge::WebSocket, - options: Option, + this: bridge::AniWebSocket, + options: Option, ) -> Result { - let web_socket = unsafe { &mut *(this.native_ptr as *mut WebSocket) }; + let web_socket = unsafe { &mut *(this.nativePtr as *mut AniClient) }; let code = options.as_ref().and_then(|opt| opt.code).unwrap_or(0) as u32; let reason = options @@ -115,3 +113,147 @@ pub(crate) fn close_sync( .map(|_| true) .map_err(|e| convert_to_business_error(e)) } + +#[ani_rs::native] +pub(crate) fn on_open( + env: &AniEnv, + this: bridge::AniWebSocket, + async_callback: AniAsyncCallback, +) -> Result<(), BusinessError> { + let web_socket = unsafe { &mut (*(this.nativePtr as *mut AniClient)) }; + web_socket.callback.on_open = Some(async_callback.into_global_callback(env).unwrap()); + web_socket.on_open_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn off_open( + env: &AniEnv, + this: bridge::AniWebSocket, + async_callback: AniAsyncCallback, +) -> Result<(), BusinessError> { + let web_socket = unsafe { &mut (*(this.nativePtr as *mut AniClient)) }; + web_socket.callback.on_open = None; + web_socket.off_open_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn on_message( + env: &AniEnv, + this: bridge::AniWebSocket, + async_callback: AniAsyncCallback, +) -> Result<(), BusinessError> { + let web_socket = unsafe { &mut (*(this.nativePtr as *mut AniClient)) }; + web_socket.callback.on_message = Some(async_callback.into_global_callback(env).unwrap()); + web_socket.on_message_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn off_message( + env: &AniEnv, + this: bridge::AniWebSocket, + async_callback: AniAsyncCallback, +) -> Result<(), BusinessError> { + let web_socket = unsafe { &mut (*(this.nativePtr as *mut AniClient)) }; + web_socket.callback.on_message = None; + web_socket.off_message_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn on_close( + env: &AniEnv, + this: bridge::AniWebSocket, + async_callback: AniAsyncCallback, +) -> Result<(), BusinessError> { + let web_socket = unsafe { &mut (*(this.nativePtr as *mut AniClient)) }; + web_socket.callback.on_close = Some(async_callback.into_global_callback(env).unwrap()); + web_socket.on_close_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn off_close( + env: &AniEnv, + this: bridge::AniWebSocket, + async_callback: AniAsyncCallback, +) -> Result<(), BusinessError> { + let web_socket = unsafe { &mut (*(this.nativePtr as *mut AniClient)) }; + web_socket.callback.on_close = None; + web_socket.off_close_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn on_error( + env: &AniEnv, + this: bridge::AniWebSocket, + error_callback: AniErrorCallback, +) -> Result<(), BusinessError> { + let web_socket = unsafe { &mut (*(this.nativePtr as *mut AniClient)) }; + web_socket.callback.on_error = Some(error_callback.into_global_callback(env).unwrap()); + web_socket.on_error_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn off_error( + env: &AniEnv, + this: bridge::AniWebSocket, + error_callback: AniErrorCallback, +) -> Result<(), BusinessError> { + let web_socket = unsafe { &mut (*(this.nativePtr as *mut AniClient)) }; + web_socket.callback.on_error = None; + web_socket.off_error_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn on_data_end( + env: &AniEnv, + this: bridge::AniWebSocket, + callback: AniFnObject, +) -> Result<(), BusinessError> { + let web_socket = unsafe { &mut (*(this.nativePtr as *mut AniClient)) }; + web_socket.callback.on_data_end = Some(callback.into_global_callback(env).unwrap()); + web_socket.on_data_end_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn off_data_end( + env: &AniEnv, + this: bridge::AniWebSocket, + callback: AniFnObject, +) -> Result<(), BusinessError> { + let web_socket = unsafe { &mut (*(this.nativePtr as *mut AniClient)) }; + web_socket.callback.on_data_end = None; + web_socket.off_data_end_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn on_header_receive( + env: &AniEnv, + this: bridge::AniWebSocket, + callback: AniFnObject, +) -> Result<(), BusinessError> { + let web_socket = unsafe { &mut (*(this.nativePtr as *mut AniClient)) }; + web_socket.callback.on_header_receive = Some(callback.into_global_callback(env).unwrap()); + web_socket.on_header_receive_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn off_header_receive( + env: &AniEnv, + this: bridge::AniWebSocket, + callback: AniFnObject, +) -> Result<(), BusinessError> { + let web_socket = unsafe { &mut (*(this.nativePtr as *mut AniClient)) }; + web_socket.callback.on_header_receive = None; + web_socket.off_header_receive_native(); + Ok(()) +} \ No newline at end of file diff --git a/frameworks/ets/ani/web_socket/src/web_socket_server.rs b/frameworks/ets/ani/web_socket/src/web_socket_server.rs new file mode 100644 index 0000000000000000000000000000000000000000..661f7b1b862b8d5bd637841e39e20b3df9be8f83 --- /dev/null +++ b/frameworks/ets/ani/web_socket/src/web_socket_server.rs @@ -0,0 +1,274 @@ +// Copyright (C) 2025 Huawei Device Co., Ltd. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use core::str; +use std::ffi::CStr; + +use ani_rs::{ + business_error::BusinessError, + objects::{AniFnObject, AniErrorCallback, AniRef, GlobalRefCallback, GlobalRefAsyncCallback}, + AniEnv, +}; + +use crate::{ + bridge::{self, AniCleanerServer}, + wrapper::AniServer, +}; + +#[ani_rs::native] +pub(crate) fn web_socket_server_clean(this: AniCleanerServer) -> Result<(), BusinessError> { + info!("Cleaning up WebSocket server"); + let _ = unsafe { Box::from_raw(this.nativePtr as *mut AniServer) }; + Ok(()) +} + +#[ani_rs::native] +pub fn create_web_socket_server<'local>( + env: &AniEnv<'local>, +) -> Result, BusinessError> { + info!("Creating WebSocket server instance"); + static WEB_SOCKET_SERVER_CLASS: &CStr = unsafe { + CStr::from_bytes_with_nul_unchecked( + b"L@ohos/net/webSocket/webSocket/WebSocketServerInner;\0", + ) + }; + static CTOR_SIGNATURE: &CStr = unsafe { CStr::from_bytes_with_nul_unchecked(b"J:V\0") }; + let ptr = AniServer::new(); + let class = env.find_class(WEB_SOCKET_SERVER_CLASS).unwrap(); + let obj = env + .new_object_with_signature(&class, CTOR_SIGNATURE, (ptr,)) + .unwrap(); + Ok(obj.into()) +} + +#[ani_rs::native] +pub(crate) fn start_sync( + this: bridge::AniWebSocketServer, + config: bridge::AniWebSocketServerConfig, +) -> Result { + let web_socket_server = unsafe { &mut *(this.nativePtr as *mut AniServer) }; + + let server_ip_str = config.serverIP + .as_ref() + .map(|s| s.as_str()) + .unwrap_or("0.0.0.0") + .to_string(); + info!("Starting WebSocket server at IP: {}", server_ip_str); + let server_port_num = config.serverPort; + let server_cert_path = config.serverCert + .as_ref() + .map(|s| s.certPath.as_str()) + .unwrap_or("") + .to_string(); + let server_key_path = config.serverCert + .as_ref() + .map(|s| s.keyPath.as_str()) + .unwrap_or("") + .to_string(); + let max_con_current_client_num = config.maxConcurrentClientsNumber; + let protocol_str = config.protocol + .as_ref() + .map(|s| s.as_str()) + .unwrap_or("") + .to_string(); + let max_connections_for_one_client_num = config.maxConnectionsForOneClient; + + web_socket_server + .start( + server_ip_str, + server_port_num, + server_cert_path, + server_key_path, + max_con_current_client_num, + protocol_str, + max_connections_for_one_client_num, + ) + .map(|_| true) + .map_err(|e| BusinessError::new(e, format!("Failed to start"))) +} + +#[ani_rs::native] +pub(crate) fn stop_sync(this: bridge::AniWebSocketServer) -> Result { + info!("Stopping WebSocket server"); + let web_socket_server = unsafe { &mut *(this.nativePtr as *mut AniServer) }; + web_socket_server + .stop() + .map(|_| true) + .map_err(|e| BusinessError::new(e, format!("Failed to stop"))) +} + +#[ani_rs::native] +pub(crate) fn send_sync( + this: bridge::AniWebSocketServer, + data: bridge::AniData, + connection: bridge::AniWebSocketConnection, +) -> Result { + info!( + "Sending data to connection ip: {} and port: {}", + connection.clientIP, connection.clientPort + ); + let web_socket_server = unsafe { &mut *(this.nativePtr as *mut AniServer) }; + let (s, data_type) = match data { + bridge::AniData::S(s) => (s.into_bytes(), 0), + bridge::AniData::ArrayBuffer(arr) => (arr.to_vec(), 1), + }; + web_socket_server + .send(s, &connection, data_type) + .map(|_| true) + .map_err(|e| { + BusinessError::new(e, format!("Failed to send data to connection ip: {} and port: {}", + connection.clientIP, connection.clientPort)) + }) +} + +#[ani_rs::native] +pub(crate) fn close_sync( + this: bridge::AniWebSocketServer, + connection: bridge::AniWebSocketConnection, + options: Option, +) -> Result { + info!( + "Closing connection ip: {} and port: {}", + connection.clientIP, connection.clientPort + ); + let web_socket_server = unsafe { &mut *(this.nativePtr as *mut AniServer) }; + + let code = options.as_ref().and_then(|opt| opt.code).unwrap_or(0) as u32; + let reason = options + .as_ref() + .and_then(|opt| opt.reason.as_ref()) + .map(|s| s.as_str()) + .unwrap_or(""); + + web_socket_server + .close(&connection, code, &reason) + .map(|_| true) + .map_err(|e| { + BusinessError::new(e, format!("Failed to close connection ip: {} and port: {}", + connection.clientIP, connection.clientPort))}) +} + +#[ani_rs::native] +pub(crate) fn list_all_connections_sync( + this: bridge::AniWebSocketServer, +) -> Result, BusinessError> { + info!("Listing all WebSocket connections"); + let web_socket_server = unsafe { &mut *(this.nativePtr as *mut AniServer) }; + let mut socket_connection = Vec::new(); + web_socket_server.list_all_connections(&mut socket_connection); + Ok(socket_connection) +} + +#[ani_rs::native] +pub(crate) fn on_error( + env: &AniEnv, + this: bridge::AniWebSocketServer, + error_callback: AniErrorCallback, +) -> Result<(), BusinessError> { + info!("Setting up error callback for WebSocket server"); + let web_socket_server = unsafe { &mut *(this.nativePtr as *mut AniServer) }; + web_socket_server.callback.on_error = Some(error_callback.into_global_callback(env).unwrap()); + web_socket_server.on_error_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn off_error( + env: &AniEnv, + this: bridge::AniWebSocketServer, + error_callback: AniErrorCallback, +) -> Result<(), BusinessError> { + info!("Removing error callback for WebSocket server"); + let web_socket_server = unsafe { &mut *(this.nativePtr as *mut AniServer) }; + web_socket_server.callback.on_error = None; + web_socket_server.off_error_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn on_connect( + env: &AniEnv, + this: bridge::AniWebSocketServer, + callback: AniFnObject, +) -> Result<(), BusinessError> { + info!("Setting up connect callback for WebSocket server"); + let web_socket_server = unsafe { &mut *(this.nativePtr as *mut AniServer) }; + web_socket_server.callback.on_connect = Some(callback.into_global_callback(env).unwrap()); + web_socket_server.on_connect_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn off_connect( + env: &AniEnv, + this: bridge::AniWebSocketServer, + callback: AniFnObject, +) -> Result<(), BusinessError> { + info!("Removing connect callback for WebSocket server"); + let web_socket_server = unsafe { &mut *(this.nativePtr as *mut AniServer) }; + web_socket_server.callback.on_connect = None; + web_socket_server.off_connect_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn on_close( + env: &AniEnv, + this: bridge::AniWebSocketServer, + callback: AniFnObject, +) -> Result<(), BusinessError> { + info!("Setting up close callback for WebSocket server"); + let web_socket_server = unsafe { &mut *(this.nativePtr as *mut AniServer) }; + web_socket_server.callback.on_close = Some(callback.into_global_callback(env).unwrap()); + web_socket_server.on_close_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn off_close( + env: &AniEnv, + this: bridge::AniWebSocketServer, + callback: AniFnObject, +) -> Result<(), BusinessError> { + info!("Removing close callback for WebSocket server"); + let web_socket_server = unsafe { &mut *(this.nativePtr as *mut AniServer) }; + web_socket_server.callback.on_close = None; + web_socket_server.off_close_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn on_message_receive( + env: &AniEnv, + this: bridge::AniWebSocketServer, + callback: AniFnObject, +) -> Result<(), BusinessError> { + info!("add message receive callback for WebSocket server"); + let web_socket_server = unsafe { &mut *(this.nativePtr as *mut AniServer) }; + web_socket_server.callback.on_message_receive = Some(callback.into_global_callback(env).unwrap()); + web_socket_server.on_message_receive_native(); + Ok(()) +} + +#[ani_rs::native] +pub(crate) fn off_message_receive( + env: &AniEnv, + this: bridge::AniWebSocketServer, + callback: AniFnObject, +) -> Result<(), BusinessError> { + info!("Removing message receive callback for WebSocket server"); + let web_socket_server = unsafe { &mut *(this.nativePtr as *mut AniServer) }; + web_socket_server.callback.on_message_receive = None; + web_socket_server.off_message_receive_native(); + Ok(()) +} \ No newline at end of file diff --git a/frameworks/ets/ani/web_socket/src/wrapper.rs b/frameworks/ets/ani/web_socket/src/wrapper.rs index 747dffc34e5580ca302cbde16b1096aa9eab998a..1889b740bf6e3a05b7a94223821e128ea65518be 100644 --- a/frameworks/ets/ani/web_socket/src/wrapper.rs +++ b/frameworks/ets/ani/web_socket/src/wrapper.rs @@ -12,40 +12,169 @@ // limitations under the License. use std::collections::HashMap; +use std::pin::Pin; +use std::sync::{Mutex, OnceLock}; -use crate::bridge::ClientCert; +use ani_rs::{ + business_error::BusinessError, + objects::{GlobalRefCallback, GlobalRefAsyncCallback, GlobalRefErrorCallback}, + AniEnv, +}; -pub struct WebSocket { - client: cxx::UniquePtr, +use crate::bridge::{ + get_web_socket_connection_client_ip, get_web_socket_connection_client_port, + socket_connection_push_data, AniClientCert, AniCloseResult, AniOpenResult, + AniWebSocketConnection, AniWebSocketMessage, AniData, AniResponseHeaders, +}; + + +static WS_MAP_CLIENT: OnceLock>> = OnceLock::new(); + +fn get_ws_client_map() -> &'static Mutex> { + WS_MAP_CLIENT.get_or_init(|| Mutex::new(HashMap::new())) +} + +pub fn on_open_websocket_client(client: Pin<&mut ffi::WebSocketClientWrapper>, message: String, status: u32) { + let client_ptr = &*client as *const _ as *mut ffi::WebSocketClientWrapper as usize; + if let Some(&ws_ptr) = get_ws_client_map().lock().unwrap().get(&client_ptr) { + let ws = unsafe { &mut *(ws_ptr as *mut AniClient) }; + if let Some(cb) = &ws.callback.on_open { + let cr = AniOpenResult { + status: status as i32, + message: message, + }; + cb.execute(None, (cr,)); + } + } +} + +pub fn on_message_websocket_client(client: Pin<&mut ffi::WebSocketClientWrapper>, data: String, len: u32) { + let client_ptr = &*client as *const _ as *mut ffi::WebSocketClientWrapper as usize; + if let Some(&ws_ptr) = get_ws_client_map().lock().unwrap().get(&client_ptr) { + let ws = unsafe { &mut *(ws_ptr as *mut AniClient) }; + if let Some(cb) = &ws.callback.on_message { + let message = AniData::S(data); + cb.execute(None, (message,)); + } + } +} + +pub fn on_close_websocket_client(client: Pin<&mut ffi::WebSocketClientWrapper>, reason: String, code: u32) { + let client_ptr = &*client as *const _ as *mut ffi::WebSocketClientWrapper as usize; + if let Some(&ws_ptr) = get_ws_client_map().lock().unwrap().get(&client_ptr) { + let ws = unsafe { &mut *(ws_ptr as *mut AniClient) }; + if let Some(cb) = &ws.callback.on_close { + let cr = AniCloseResult { + code: code as i32, + reason: reason, + }; + cb.execute(None, (cr,)); + } + } +} + +pub fn on_error_websocket_client(client: Pin<&mut ffi::WebSocketClientWrapper>, errMessage: String, errCode: u32) { + let client_ptr = &*client as *const _ as *mut ffi::WebSocketClientWrapper as usize; + if let Some(&ws_ptr) = get_ws_client_map().lock().unwrap().get(&client_ptr) { + let ws = unsafe { &mut *(ws_ptr as *mut AniClient) }; + if let Some(cb) = &ws.callback.on_error { + let err = BusinessError::new(errCode as i32, errMessage); + cb.execute(err); + } + } } -impl WebSocket { +pub fn on_data_end_websocket_client(client: Pin<&mut ffi::WebSocketClientWrapper>) { + let client_ptr = &*client as *const _ as *mut ffi::WebSocketClientWrapper as usize; + if let Some(&ws_ptr) = get_ws_client_map().lock().unwrap().get(&client_ptr) { + let ws = unsafe { &mut *(ws_ptr as *mut AniClient) }; + if let Some(cb) = &ws.callback.on_data_end { + cb.execute(()); + } + } +} + +pub fn on_header_receive_websocket_client(client: Pin<&mut ffi::WebSocketClientWrapper>, keys: &mut Vec, + values: &mut Vec) { + let client_ptr = &*client as *const _ as *mut ffi::WebSocketClientWrapper as usize; + if let Some(&ws_ptr) = get_ws_client_map().lock().unwrap().get(&client_ptr) { + let ws = unsafe { &mut *(ws_ptr as *mut AniClient) }; + if let Some(cb) = &ws.callback.on_header_receive { + let mut data = HashMap::new(); + for (key, value) in keys.iter().zip(values.iter()) { + data.insert(key.clone(), value.clone()); + } + let map_headers = AniResponseHeaders::MapBuffer(data); + cb.execute((map_headers,)); + } + } +} + +pub fn header_push_data(header: &mut Vec, data: String) +{ + header.push(data); +} + +pub struct CallBackWebSocketClient { + pub on_open: Option>, + pub on_message: Option>, + pub on_close: Option>, + pub on_error: Option, + pub on_data_end: Option>, + pub on_header_receive: Option>, +} + +impl CallBackWebSocketClient { pub fn new() -> Self { + Self { + on_open: None, + on_message: None, + on_close: None, + on_error: None, + on_data_end: None, + on_header_receive: None, + } + } +} + +pub struct AniClient { + client: cxx::UniquePtr, + pub callback: CallBackWebSocketClient, +} + +impl AniClient { + pub fn new() -> i64 { let client = ffi::CreateWebSocket(); - WebSocket { client } + let callback = CallBackWebSocketClient::new(); + let ws = AniClient { client, callback }; + let client_ptr = ws.client.as_ref().unwrap() as *const _ as usize; + let web_socket = Box::new(ws); + let ptr = Box::into_raw(web_socket); + get_ws_client_map().lock().unwrap().insert(client_ptr, ptr as usize); + ptr as i64 } pub fn connect( &mut self, url: &str, headers: HashMap, - ca_path: Option, - client_cert: Option, + caPath: Option, + clientCert: Option, protocol: Option, ) -> Result<(), i32> { - let options = ffi::ConnectOptions { + let options = ffi::AniConnectOptions { headers: headers .iter() .map(|(k, v)| [k.as_str(), v.as_str()]) .flatten() .collect(), }; - if let Some(ca_path) = ca_path { - ffi::SetCaPath(self.client.pin_mut(), &ca_path); + if let Some(caPath) = caPath { + ffi::SetCaPath(self.client.pin_mut(), &caPath); } - if let Some(cert) = client_cert { - ffi::SetClientCert(self.client.pin_mut(), &cert.cert_path, &cert.key_path); - if let Some(password) = cert.key_password { + if let Some(cert) = clientCert { + ffi::SetClientCert(self.client.pin_mut(), &cert.certPath, &cert.keyPath); + if let Some(password) = cert.keyPassword { ffi::SetCertPassword(self.client.pin_mut(), &password); } } @@ -57,8 +186,8 @@ impl WebSocket { Ok(()) } - pub fn send(&mut self, data: &str) -> Result<(), i32> { - let ret = ffi::Send(self.client.pin_mut(), data); + pub fn send(&mut self, data: Vec, data_type: i32) -> Result<(), i32> { + let ret = ffi::Send(self.client.pin_mut(), data, data_type); if ret != 0 { return Err(ret); } @@ -66,45 +195,533 @@ impl WebSocket { } pub fn close(&mut self, code: u32, reason: &str) -> Result<(), i32> { - let options = ffi::CloseOption { code, reason }; + let options = ffi::AniCloseOption { code, reason }; let ret = ffi::Close(self.client.pin_mut(), options); if ret != 0 { return Err(ret); } Ok(()) } + + pub fn on_open_native(&mut self) -> Result<(), i32> { + let ret = ffi::RegisterOpenCallback(self.client.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn on_message_native(&mut self) -> Result<(), i32> { + let ret = ffi::RegisterMessageCallback(self.client.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn on_close_native(&mut self) -> Result<(), i32> { + let ret = ffi::RegisterCloseCallback(self.client.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn on_error_native(&mut self) -> Result<(), i32> { + let ret = ffi::RegisterErrorCallback(self.client.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn on_data_end_native(&mut self) -> Result<(), i32> { + let ret = ffi::RegisterDataEndCallback(self.client.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn on_header_receive_native(&mut self) -> Result<(), i32> { + let ret = ffi::RegisterHeaderReceiveCallback(self.client.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn off_open_native(&mut self) -> Result<(), i32> { + let ret = ffi::UnregisterOpenCallback(self.client.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn off_message_native(&mut self) -> Result<(), i32> { + let ret = ffi::UnregisterMessageCallback(self.client.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn off_close_native(&mut self) -> Result<(), i32> { + let ret = ffi::UnregisterCloseCallback(self.client.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn off_error_native(&mut self) -> Result<(), i32> { + let ret = ffi::UnregisterErrorCallback(self.client.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn off_data_end_native(&mut self) -> Result<(), i32> { + let ret = ffi::UnregisterDataEndCallback(self.client.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn off_header_receive_native(&mut self) -> Result<(), i32> { + let ret = ffi::UnregisterHeaderReceiveCallback(self.client.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } +} + +impl Drop for AniClient { + fn drop(&mut self) { + if let Some(client_ptr) = self.client.as_ref().map(|c| c as *const _ as usize) { + get_ws_client_map().lock().unwrap().remove(&client_ptr); + } + } +} + +/** + * @brief server + */ +static WS_MAP_SERVER: OnceLock>> = OnceLock::new(); + +fn get_ws_server_map() -> &'static Mutex> { + WS_MAP_SERVER.get_or_init(|| Mutex::new(HashMap::new())) +} + +pub fn on_error_websocket_server( + server: Pin<&mut ffi::WebSocketServer>, + message: String, + code: u32, +) { + let server_ptr = &*server as *const _ as *mut ffi::WebSocketServer as usize; + if let Some(&ws_ptr) = get_ws_server_map().lock().unwrap().get(&server_ptr) { + let ws = unsafe { &mut *(ws_ptr as *mut AniServer) }; + if let Some(cb) = &ws.callback.on_error { + let error = BusinessError::new(code as i32, message); + cb.execute(error); + } + } +} + +pub fn on_connect_websocket_server(server: Pin<&mut ffi::WebSocketServer>, ip: String, port: u32) { + let server_ptr = &*server as *const _ as *mut ffi::WebSocketServer as usize; + if let Some(&ws_ptr) = get_ws_server_map().lock().unwrap().get(&server_ptr) { + let ws = unsafe { &mut *(ws_ptr as *mut AniServer) }; + if let Some(cb) = &ws.callback.on_connect { + let connection = AniWebSocketConnection { + clientIP: ip, + clientPort: port as i32, + }; + cb.execute((connection,)); + } + } +} + +pub fn on_close_websocket_server( + server: Pin<&mut ffi::WebSocketServer>, + reason: String, + code: u32, + ip: String, + port: u32, +) { + let server_ptr = &*server as *const _ as *mut ffi::WebSocketServer as usize; + if let Some(&ws_ptr) = get_ws_server_map().lock().unwrap().get(&server_ptr) { + let ws = unsafe { &mut *(ws_ptr as *mut AniServer) }; + if let Some(cb) = &ws.callback.on_close { + let connection = AniWebSocketConnection { + clientIP: ip, + clientPort: port as i32, + }; + let result = AniCloseResult { + code: code as i32, + reason: reason, + }; + cb.execute((connection, result,)); + } + } +} + +pub fn on_message_receive_websocket_server( + server: Pin<&mut ffi::WebSocketServer>, + data: String, + length: u32, + ip: String, + port: u32, +) { + let server_ptr = &*server as *const _ as *mut ffi::WebSocketServer as usize; + if let Some(&ws_ptr) = get_ws_server_map().lock().unwrap().get(&server_ptr) { + let ws = unsafe { &mut *(ws_ptr as *mut AniServer) }; + if let Some(cb) = &ws.callback.on_message_receive { + let data = AniData::S(data); + let connection = AniWebSocketConnection { + clientIP: ip, + clientPort: port as i32, + }; + let message = AniWebSocketMessage { + data: data, + clientConnection: connection, + }; + cb.execute((message,)); + } + } +} + +pub struct CallBackWebSocketServer { + pub on_error: Option, + pub on_connect: Option>, + pub on_close: Option>, + pub on_message_receive: Option>, +} + +impl CallBackWebSocketServer { + pub fn new() -> Self { + Self { + on_error: None, + on_connect: None, + on_close: None, + on_message_receive: None, + } + } +} + +pub struct AniServer { + server: cxx::UniquePtr, + pub callback: CallBackWebSocketServer, +} + +impl AniServer { + pub fn new() -> i64 { + let server = ffi::CreateWebSocketServer(); + let callback = CallBackWebSocketServer::new(); + let ws = AniServer { server, callback }; + let server_ptr = ws.server.as_ref().unwrap() as *const _ as usize; + let web_socket_server = Box::new(ws); + let ptr = Box::into_raw(web_socket_server); + get_ws_server_map().lock().unwrap().insert(server_ptr, ptr as usize); + ptr as i64 + } + + pub fn start( + &mut self, + serverIP: String, + serverPort: i32, + server_cert_path: String, + server_key_path: String, + maxConcurrentClientsNumber: i32, + protocol: String, + maxConnectionsForOneClient: i32, + ) -> Result<(), i32> { + let server_config_cert = ffi::AniServerConfigCert { + certPath: server_cert_path, + keyPath: server_key_path, + }; + let server_config = ffi::AniServerConfig { + serverIP: serverIP, + serverPort: serverPort, + serverCert: server_config_cert, + maxConcurrentClientsNumber: maxConcurrentClientsNumber, + protocol: protocol, + maxConnectionsForOneClient: maxConnectionsForOneClient, + }; + let ret = ffi::StartServer(self.server.pin_mut(), server_config); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn stop(&mut self) -> Result<(), i32> { + let ret = ffi::StopServer(self.server.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn send(&mut self, data: Vec, connection: &AniWebSocketConnection, data_type: i32) -> Result<(), i32> { + let ret = ffi::SendServerData(self.server.pin_mut(), data, &connection, data_type); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn close( + &mut self, + connection: &AniWebSocketConnection, + code: u32, + reason: &str, + ) -> Result<(), i32> { + let option = ffi::AniCloseOption { code, reason }; + let ret = ffi::CloseServer(self.server.pin_mut(), &connection, option); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn list_all_connections( + &mut self, + connections: &mut Vec, + ) -> Result<(), i32> { + let ret = ffi::ListAllConnections(self.server.pin_mut(), connections); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn on_error_native(&mut self) -> Result<(), i32> { + let ret = ffi::RegisterServerErrorCallback(self.server.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn off_error_native(&mut self) -> Result<(), i32> { + let ret = ffi::UnregisterServerErrorCallback(self.server.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn on_connect_native(&mut self) -> Result<(), i32> { + let ret = ffi::RegisterServerConnectCallback(self.server.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn off_connect_native(&mut self) -> Result<(), i32> { + let ret = ffi::UnregisterServerConnectCallback(self.server.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn on_close_native(&mut self) -> Result<(), i32> { + let ret = ffi::RegisterServerCloseCallback(self.server.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn off_close_native(&mut self) -> Result<(), i32> { + let ret = ffi::UnregisterServerCloseCallback(self.server.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn on_message_receive_native(&mut self) -> Result<(), i32> { + let ret = ffi::RegisterServerMessageReceiveCallback(self.server.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } + + pub fn off_message_receive_native(&mut self) -> Result<(), i32> { + let ret = ffi::UnregisterServerMessageReceiveCallback(self.server.pin_mut()); + if ret != 0 { + return Err(ret); + } + Ok(()) + } +} + +impl Drop for AniServer { + fn drop(&mut self) { + if let Some(server_ptr) = self.server.as_ref().map(|c| c as *const _ as usize) { + get_ws_server_map().lock().unwrap().remove(&server_ptr); + } + } } #[cxx::bridge(namespace = "OHOS::NetStackAni")] mod ffi { - - pub struct ConnectOptions<'a> { + pub struct AniConnectOptions<'a> { pub headers: Vec<&'a str>, } - struct CloseOption<'a> { + struct AniCloseOption<'a> { code: u32, reason: &'a str, } + pub struct AniServerConfigCert { + certPath: String, + keyPath: String, + } + + pub struct AniServerConfig { + serverIP: String, + serverPort: i32, + serverCert: AniServerConfigCert, + maxConcurrentClientsNumber: i32, + protocol: String, + maxConnectionsForOneClient: i32, + } + + extern "Rust" { + type AniClient; + fn on_open_websocket_client(client: Pin<&mut WebSocketClientWrapper>, message: String, status: u32); + fn on_message_websocket_client(client: Pin<&mut WebSocketClientWrapper>, data: String, len: u32); + fn on_close_websocket_client(client: Pin<&mut WebSocketClientWrapper>, reason: String, code: u32); + fn on_error_websocket_client(client: Pin<&mut WebSocketClientWrapper>, errMessage: String, errCode: u32); + fn on_data_end_websocket_client(client: Pin<&mut WebSocketClientWrapper>); + fn on_header_receive_websocket_client(client: Pin<&mut WebSocketClientWrapper>, keys: &mut Vec, values: &mut Vec); + fn header_push_data(header: &mut Vec, data: String); + + type AniServer; + type AniWebSocketConnection; + fn get_web_socket_connection_client_ip(conn: &AniWebSocketConnection) -> String; + fn get_web_socket_connection_client_port(conn: &AniWebSocketConnection) -> i32; + fn socket_connection_push_data( + connection_info_value: &mut Vec, + clientIP: String, + clientPort: i32, + ); + + fn on_error_websocket_server(server: Pin<&mut WebSocketServer>, message: String, code: u32); + fn on_connect_websocket_server(server: Pin<&mut WebSocketServer>, ip: String, port: u32); + fn on_close_websocket_server( + server: Pin<&mut WebSocketServer>, + reason: String, + code: u32, + ip: String, + port: u32, + ); + fn on_message_receive_websocket_server( + server: Pin<&mut WebSocketServer>, + data: String, + length: u32, + ip: String, + port: u32, + ); + } + unsafe extern "C++" { include!("websocket_ani.h"); - #[namespace = "OHOS::NetStack::WebSocketClient"] - type WebSocketClient; + //#[namespace = "OHOS::NetStack::WebSocketClient"] + type WebSocketClientWrapper; + + fn CreateWebSocket() -> UniquePtr; + + fn Connect(client: Pin<&mut WebSocketClientWrapper>, url: &str, options: AniConnectOptions) + -> i32; + + fn SetCaPath(client: Pin<&mut WebSocketClientWrapper>, caPath: &str); + + fn SetClientCert(client: Pin<&mut WebSocketClientWrapper>, certPath: &str, key: &str); + + fn SetCertPassword(client: Pin<&mut WebSocketClientWrapper>, password: &str); + + fn Send(client: Pin<&mut WebSocketClientWrapper>, data: Vec, data_type: i32) -> i32; + + fn Close(client: Pin<&mut WebSocketClientWrapper>, options: AniCloseOption) -> i32; + + fn RegisterOpenCallback(client: Pin<&mut WebSocketClientWrapper>) -> i32; + + fn RegisterMessageCallback(client: Pin<&mut WebSocketClientWrapper>) -> i32; + + fn RegisterCloseCallback(client: Pin<&mut WebSocketClientWrapper>) -> i32; + + fn RegisterErrorCallback(client: Pin<&mut WebSocketClientWrapper>) -> i32; + + fn RegisterDataEndCallback(client: Pin<&mut WebSocketClientWrapper>) -> i32; + + fn RegisterHeaderReceiveCallback(client: Pin<&mut WebSocketClientWrapper>) -> i32; + + fn UnregisterOpenCallback(client: Pin<&mut WebSocketClientWrapper>) -> i32; + + fn UnregisterMessageCallback(client: Pin<&mut WebSocketClientWrapper>) -> i32; + + fn UnregisterCloseCallback(client: Pin<&mut WebSocketClientWrapper>) -> i32; + + fn UnregisterErrorCallback(client: Pin<&mut WebSocketClientWrapper>) -> i32; + + fn UnregisterDataEndCallback(client: Pin<&mut WebSocketClientWrapper>) -> i32; + + fn UnregisterHeaderReceiveCallback(client: Pin<&mut WebSocketClientWrapper>) -> i32; + + #[namespace = "OHOS::NetStack::WebSocketServer"] + type WebSocketServer; + + fn CreateWebSocketServer() -> UniquePtr; + + fn StartServer(server: Pin<&mut WebSocketServer>, options: AniServerConfig) -> i32; + + fn StopServer(server: Pin<&mut WebSocketServer>) -> i32; + + fn SendServerData( + server: Pin<&mut WebSocketServer>, + data: Vec, + connection: &AniWebSocketConnection, + data_type: i32, + ) -> i32; + + fn CloseServer( + server: Pin<&mut WebSocketServer>, + connection: &AniWebSocketConnection, + options: AniCloseOption, + ) -> i32; + + fn ListAllConnections( + server: Pin<&mut WebSocketServer>, + connections: &mut Vec, + ) -> i32; + + fn RegisterServerErrorCallback(server: Pin<&mut WebSocketServer>) -> i32; - fn CreateWebSocket() -> UniquePtr; + fn RegisterServerConnectCallback(server: Pin<&mut WebSocketServer>) -> i32; - fn Connect(client: Pin<&mut WebSocketClient>, url: &str, options: ConnectOptions) -> i32; + fn RegisterServerCloseCallback(server: Pin<&mut WebSocketServer>) -> i32; - fn SetCaPath(client: Pin<&mut WebSocketClient>, ca_path: &str); + fn RegisterServerMessageReceiveCallback(server: Pin<&mut WebSocketServer>) -> i32; - fn SetClientCert(client: Pin<&mut WebSocketClient>, cert_path: &str, key: &str); + fn UnregisterServerErrorCallback(server: Pin<&mut WebSocketServer>) -> i32; - fn SetCertPassword(client: Pin<&mut WebSocketClient>, password: &str); + fn UnregisterServerConnectCallback(server: Pin<&mut WebSocketServer>) -> i32; - fn Send(client: Pin<&mut WebSocketClient>, data: &str) -> i32; + fn UnregisterServerCloseCallback(server: Pin<&mut WebSocketServer>) -> i32; - fn Close(client: Pin<&mut WebSocketClient>, options: CloseOption) -> i32; + fn UnregisterServerMessageReceiveCallback(server: Pin<&mut WebSocketServer>) -> i32; } } diff --git a/frameworks/js/napi/websocket/websocket_module/src/websocket_client.cpp b/frameworks/js/napi/websocket/websocket_module/src/websocket_client.cpp index 0716525459e0a09059844b9b86c78f27df2fe418..ea8bb4f3dea37e33d0e8d29130b44b0e01ba3c94 100644 --- a/frameworks/js/napi/websocket/websocket_module/src/websocket_client.cpp +++ b/frameworks/js/napi/websocket/websocket_module/src/websocket_client.cpp @@ -27,6 +27,19 @@ #include "net_conn_client.h" #endif +enum WebsocketErrorCodeEx { + WEBSOCKET_CONNECT_FAILED = -1, + WEBSOCKET_ERROR_CODE_BASE = 2302000, + WEBSOCKET_ERROR_CODE_URL_ERROR = WEBSOCKET_ERROR_CODE_BASE + 1, + WEBSOCKET_ERROR_CODE_FILE_NOT_EXIST = WEBSOCKET_ERROR_CODE_BASE + 2, + WEBSOCKET_ERROR_CODE_CONNECT_AlREADY_EXIST = WEBSOCKET_ERROR_CODE_BASE + 3, + WEBSOCKET_ERROR_CODE_INVALID_NIC = WEBSOCKET_ERROR_CODE_BASE + 4, + WEBSOCKET_ERROR_CODE_INVALID_PORT = WEBSOCKET_ERROR_CODE_BASE + 5, + WEBSOCKET_ERROR_CODE_CONNECTION_NOT_EXIST = WEBSOCKET_ERROR_CODE_BASE + 6, + WEBSOCKET_NOT_ALLOWED_HOST = 2302998, + WEBSOCKET_UNKNOWN_OTHER_ERROR = 2302999 +}; + static constexpr const char *PATH_START = "/"; static constexpr const char *NAME_END = ":"; static constexpr const char *STATUS_LINE_SEP = " "; @@ -197,7 +210,9 @@ int LwsCallbackClientConnectionError(lws *wsi, lws_callback_reasons reason, void ErrorResult errorResult; errorResult.errorCode = WebSocketErrorCode::WEBSOCKET_CONNECTION_ERROR; errorResult.errorMessage = data; - client->onErrorCallback_(client, errorResult); + if (client->onErrorCallback_) { + client->onErrorCallback_(client, errorResult); + } return HttpDummy(wsi, reason, user, in, len); } @@ -211,7 +226,12 @@ int LwsCallbackClientReceive(lws *wsi, lws_callback_reasons reason, void *user, return HttpDummy(wsi, reason, user, in, len); } std::string data = client->GetData(); - client->onMessageCallback_(client, data.c_str(), data.size()); + if (client->onMessageCallback_) { + client->onMessageCallback_(client, data.c_str(), data.size()); + } + if (client->onDataEndCallback_) { + client->onDataEndCallback_(client); + } client->ClearData(); return HttpDummy(wsi, reason, user, in, len); } @@ -269,6 +289,34 @@ int LwsCallbackClientFilterPreEstablish(lws *wsi, lws_callback_reasons reason, v if (vec.size() >= FUNCTION_PARAM_TWO) { client->GetClientContext()->openMessage = vec[1]; } + char buffer[MAX_HDR_LENGTH] = {}; + std::map responseHeader; + for (int i = 0; i < WSI_TOKEN_COUNT; i++) { + if (lws_hdr_total_length(wsi, static_cast(i)) > 0) { + lws_hdr_copy(wsi, buffer, sizeof(buffer), static_cast(i)); + std::string str; + if (lws_token_to_string(static_cast(i))) { + str = + std::string(reinterpret_cast(lws_token_to_string(static_cast(i)))); + } + if (!str.empty() && str.back() == ':') { + responseHeader.emplace(str.substr(0, str.size() - 1), std::string(buffer)); + } + } + } + lws_hdr_custom_name_foreach( + wsi, + [](const char *name, int nlen, void *opaque) -> void { + auto header = static_cast *>(opaque); + if (header == nullptr) { + return; + } + header->emplace(std::string(name).substr(0, nlen - 1), std::string(name).substr(nlen)); + }, + &responseHeader); + if (client->onHeaderReceiveCallback_) { + client->onHeaderReceiveCallback_(client, responseHeader); + } return HttpDummy(wsi, reason, user, in, len); } @@ -285,8 +333,9 @@ int LwsCallbackClientEstablished(lws *wsi, lws_callback_reasons reason, void *us OpenResult openResult; openResult.status = client->GetClientContext()->openStatus; openResult.message = client->GetClientContext()->openMessage.c_str(); - client->onOpenCallback_(client, openResult); - + if (client->onOpenCallback_) { + client->onOpenCallback_(client, openResult); + } return HttpDummy(wsi, reason, user, in, len); } @@ -302,9 +351,20 @@ int LwsCallbackClientClosed(lws *wsi, lws_callback_reasons reason, void *user, v char *data = static_cast(in); buf.assign(data, len); CloseResult closeResult; - closeResult.code = CLOSE_RESULT_FROM_SERVER_CODE; - closeResult.reason = CLOSE_REASON_FORM_SERVER; - client->onCloseCallback_(client, closeResult); + auto ctx = client->GetClientContext(); + if (ctx != nullptr && ctx->closeStatus != LWS_CLOSE_STATUS_NOSTATUS) { + closeResult.code = static_cast(ctx->closeStatus); + } else { + closeResult.code = CLOSE_RESULT_FROM_SERVER_CODE; + } + if (ctx != nullptr && !ctx->closeReason.empty()) { + closeResult.reason = ctx->closeReason.c_str(); + } else { + closeResult.reason = CLOSE_REASON_FORM_SERVER; + } + if (client->onCloseCallback_) { + client->onCloseCallback_(client, closeResult); + } client->GetClientContext()->SetThreadStop(true); if ((client->GetClientContext()->closeReason).empty()) { client->GetClientContext()->Close(client->GetClientContext()->closeStatus, LINK_DOWN); @@ -353,7 +413,7 @@ int LwsCallback(lws *wsi, lws_callback_reasons reason, void *user, void *in, siz }; auto it = std::find_if(std::begin(dispatchers), std::end(dispatchers), [&reason](const CallbackDispatcher &dispatcher) { return dispatcher.reason == reason; }); - if (it != std::end(dispatchers)) { + if (it != std::end(dispatchers) && user != nullptr) { return it->callback(wsi, reason, user, in, len); } return HttpDummy(wsi, reason, user, in, len); @@ -609,7 +669,6 @@ int WebSocketClient::Send(char *data, size_t length) int WebSocketClient::Close(CloseOption options) { - NETSTACK_LOGI("Close start"); if (this->GetClientContext() == nullptr) { return WebSocketErrorCode::WEBSOCKET_ERROR_NO_CLIENTCONTEX; } @@ -646,4 +705,173 @@ int WebSocketClient::Destroy() return WebSocketErrorCode::WEBSOCKET_NONE_ERR; } +int CreatConnectInfoEx(const std::string url, lws_context *lwsContext, WebSocketClient *client) +{ + lws_client_connect_info connectInfo = {}; + char prefix[MAX_URI_LENGTH] = {0}; + char address[MAX_URI_LENGTH] = {0}; + char pathWithoutStart[MAX_URI_LENGTH] = {0}; + int port = 0; + if (!ParseUrl(url, prefix, address, pathWithoutStart, &port)) { + return WebSocketErrorCode::WEBSOCKET_CONNECTION_PARSEURL_ERROR; + } + std::string path = PATH_START + std::string(pathWithoutStart); + std::string tempHost; + if ((strcmp(prefix, PREFIX_WS) == 0 && port == WS_DEFAULT_PORT) || + (strcmp(prefix, PREFIX_WSS) == 0 && port == WSS_DEFAULT_PORT)) { + tempHost = std::string(address); + } else { + tempHost = std::string(address) + NAME_END + std::to_string(port); + } + connectInfo.context = lwsContext; + connectInfo.address = address; + connectInfo.port = port; + connectInfo.path = path.c_str(); + connectInfo.host = tempHost.c_str(); + connectInfo.origin = address; + + NETSTACK_LOGI("Connect info %{public}s, %{public}d, %{public}s, %{public}s", address, port, path.c_str(), tempHost.c_str()); + + connectInfo.local_protocol_name = "lws-minimal-client1"; + connectInfo.retry_and_idle_policy = &RETRY; + if (strcmp(prefix, PREFIX_HTTPS) == 0 || strcmp(prefix, PREFIX_WSS) == 0) { + connectInfo.ssl_connection = + LCCSCF_USE_SSL | LCCSCF_SKIP_SERVER_CERT_HOSTNAME_CHECK | LCCSCF_ALLOW_INSECURE | LCCSCF_ALLOW_SELFSIGNED; + } + lws *wsi = nullptr; + connectInfo.pwsi = &wsi; + connectInfo.userdata = client; + if (lws_client_connect_via_info(&connectInfo) == nullptr) { + NETSTACK_LOGE("Connect lws_context_destroy"); + return -1; + } + return WebSocketErrorCode::WEBSOCKET_NONE_ERR; +} + +int WebSocketClient::ConnectEx(std::string url, struct OpenOptions options) +{ + NETSTACK_LOGI("ClientId:%{public}d, Connect start, %{public}p, %{public}p", this->GetClientContext()->GetClientId(), this, onOpenCallback_); + if (!CommonUtils::HasInternetPermission()) { + this->GetClientContext()->permissionDenied = true; + return WebSocketErrorCode::WEBSOCKET_ERROR_PERMISSION_DENIED; + } + if (this->GetClientContext()->isAtomicService && !CommonUtils::IsAllowedHostname(this->GetClientContext()-> + bundleName, CommonUtils::DOMAIN_TYPE_WEBSOCKET_REQUEST, this->GetClientContext()->url)) { + this->GetClientContext()->noAllowedHost = true; + return WebSocketErrorCode::WEBSOCKET_ERROR_DISALLOW_HOST; + } + if (!options.headers.empty()) { + if (options.headers.size() > MAX_HEADER_LENGTH) { + return WebSocketErrorCode::WEBSOCKET_ERROR_NO_HEADR_EXCEEDS; + } + for (const auto &item : options.headers) { + const std::string &key = item.first; + const std::string &value = item.second; + this->GetClientContext()->header[key] = value; + } + } + lws_context_creation_info info = {}; + FillContextInfo(this->GetClientContext(), info); + FillCaPath(this->GetClientContext(), info); + if (this->GetClientContext()->GetContext() != nullptr) { + NETSTACK_LOGE("Websocket connect already exist"); + return WebsocketErrorCodeEx::WEBSOCKET_ERROR_CODE_CONNECT_AlREADY_EXIST; + } + lws_context *lwsContext = lws_create_context(&info); + if (lwsContext == nullptr) { + return WebSocketErrorCode::WEBSOCKET_CONNECTION_NO_MEMOERY; + } + this->GetClientContext()->SetContext(lwsContext); + int ret = CreatConnectInfoEx(url, lwsContext, this); + if (ret != WEBSOCKET_NONE_ERR) { + NETSTACK_LOGE("websocket CreatConnectInfo error"); + GetClientContext()->SetContext(nullptr); + lws_context_destroy(lwsContext); + return ret; + } + std::weak_ptr weak = shared_from_this(); + std::thread serviceThread = std::thread([weak]() { + auto client = weak.lock(); + if(client == nullptr) { + NETSTACK_LOGE("WebSocketClient instance has been destroyed"); + return; + } + int id = client->GetClientContext()->GetClientId(); + auto* context = client->GetClientContext()->GetContext(); + if (context == nullptr) { + return; + } + NETSTACK_LOGI("start RunService %{public}d, %{public}p, %{public}p", id, client.get(), context); + int res = 0; + while (res >= 0 && !client->GetClientContext()->IsThreadStop()) { + res = lws_service(context, 0); + } + NETSTACK_LOGI("end RunService %{public}d, %{public}p, %{public}p, %{public}d", id, client.get(), context, res); + lws_context_destroy(context); + NETSTACK_LOGI("end1 RunService %{public}d, %{public}p, %{public}p, %{public}d", id, client.get(), context, res); + client->GetClientContext()->SetContext(nullptr); + NETSTACK_LOGI("end2 RunService %{public}d, %{public}p", id, client.get()); + client = nullptr; + }); +#if defined(MAC_PLATFORM) || defined(IOS_PLATFORM) + pthread_setname_np(WEBSOCKET_CLIENT_THREAD_RUN); +#else + pthread_setname_np(serviceThread.native_handle(), WEBSOCKET_CLIENT_THREAD_RUN); +#endif + serviceThread.detach(); + return WebSocketErrorCode::WEBSOCKET_NONE_ERR; +} + +int WebSocketClient::SendEx(char *data, size_t length) +{ + NETSTACK_LOGI("WebSocketClient::SendEx start %{public}s, %{public}d", data, length); + if (data == nullptr) { + return WebSocketErrorCode::WEBSOCKET_SEND_DATA_NULL; + } + if (length == 0) { + return WebSocketErrorCode::WEBSOCKET_NONE_ERR; + } + if (length > MAX_DATA_LENGTH) { + return WebSocketErrorCode::WEBSOCKET_DATA_LENGTH_EXCEEDS; + } + if (this->GetClientContext() == nullptr) { + return WebSocketErrorCode::WEBSOCKET_ERROR_NO_CLIENTCONTEX; + } + if (this->GetClientContext()->GetContext() == nullptr) { + return -1; + } + + lws_write_protocol protocol = (strlen(data) == length) ? LWS_WRITE_TEXT : LWS_WRITE_BINARY; + auto dataCopy = reinterpret_cast(malloc(length)); + if (dataCopy == nullptr) { + NETSTACK_LOGE("webSocketClient malloc error"); + return WEBSOCKET_SEND_NO_MEMOERY_ERROR; + } else if (memcpy_s(dataCopy, length, data, length) != EOK) { + free(dataCopy); + NETSTACK_LOGE("webSocketClient malloc copy error"); + return WEBSOCKET_SEND_NO_MEMOERY_ERROR; + } + this->GetClientContext()->Push(dataCopy, length, protocol); + this->GetClientContext()->TriggerWritable(); + NETSTACK_LOGI("WebSocketClient::Send end %{public}s, %{public}s, %{public}d", dataCopy, data, length); + return WebSocketErrorCode::WEBSOCKET_NONE_ERR; +} + +int WebSocketClient::CloseEx(CloseOption options) +{ + if (this->GetClientContext() == nullptr) { + return WebSocketErrorCode::WEBSOCKET_ERROR_NO_CLIENTCONTEX; + } + if (this->GetClientContext()->GetContext() == nullptr) { + return -1; + } + if (options.reason == nullptr || options.code == 0) { + options.reason = ""; + options.code = CLOSE_RESULT_FROM_CLIENT_CODE; + } + this->GetClientContext()->Close(static_cast(options.code), options.reason); + this->GetClientContext()->TriggerWritable(); + return WebSocketErrorCode::WEBSOCKET_NONE_ERR; +} + } // namespace OHOS::NetStack::WebSocketClient \ No newline at end of file diff --git a/frameworks/js/napi/websocket/websocket_module/src/websocket_server.cpp b/frameworks/js/napi/websocket/websocket_module/src/websocket_server.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2304c5be85f3bc3e1e327e45a0929b8a843dc923 --- /dev/null +++ b/frameworks/js/napi/websocket/websocket_module/src/websocket_server.cpp @@ -0,0 +1,786 @@ +/* + * Copyright (c) 2025 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "netstack_log.h" +#include "netstack_common_utils.h" +#include "websocket_server_innerapi.h" + +#define LWS_PLUGIN_STATIC + +static constexpr const char *WEBSOCKET_SERVER_THREAD_RUN = "OS_NET_WSJsSer"; + +static constexpr const char *LINK_DOWN = "The link is down"; + +static constexpr const uint32_t MAX_CONCURRENT_CLIENTS_NUMBER = 10; + +static constexpr const uint32_t MAX_CONNECTIONS_FOR_ONE_CLIENT = 10; + +static constexpr const int32_t COMMON_ERROR_CODE = 200; + +namespace OHOS::NetStack::WebSocketServer { +enum WebsocketErrorCode { + WEBSOCKET_CONNECT_FAILED = -1, + WEBSOCKET_ERROR_CODE_BASE = 2302000, + WEBSOCKET_ERROR_CODE_URL_ERROR = WEBSOCKET_ERROR_CODE_BASE + 1, + WEBSOCKET_ERROR_CODE_FILE_NOT_EXIST = WEBSOCKET_ERROR_CODE_BASE + 2, + WEBSOCKET_ERROR_CODE_CONNECT_ALREADY_EXIST = WEBSOCKET_ERROR_CODE_BASE + 3, + WEBSOCKET_ERROR_CODE_INVALID_NIC = WEBSOCKET_ERROR_CODE_BASE + 4, + WEBSOCKET_ERROR_CODE_INVALID_PORT = WEBSOCKET_ERROR_CODE_BASE + 5, + WEBSOCKET_ERROR_CODE_CONNECTION_NOT_EXIST = WEBSOCKET_ERROR_CODE_BASE + 6, + WEBSOCKET_NOT_ALLOWED_HOST = 2302998, + WEBSOCKET_UNKNOWN_OTHER_ERROR = 2302999 +}; + +static const std::map WEBSOCKET_ERR_MAP = { { WEBSOCKET_CONNECT_FAILED, + "Websocket connect failed" }, + { WEBSOCKET_ERROR_CODE_URL_ERROR, "Websocket url error" }, + { WEBSOCKET_ERROR_CODE_FILE_NOT_EXIST, "Websocket file not exist" }, + { WEBSOCKET_ERROR_CODE_CONNECT_ALREADY_EXIST, "Websocket connection exist" }, + { WEBSOCKET_ERROR_CODE_INVALID_NIC, "Can't listen to the given NIC" }, + { WEBSOCKET_ERROR_CODE_INVALID_PORT, "Can't listen to the given Port" }, + { WEBSOCKET_ERROR_CODE_CONNECTION_NOT_EXIST, "websocket connection does not exist" }, + { WEBSOCKET_NOT_ALLOWED_HOST, "It is not allowed to access this domain" }, + { WEBSOCKET_UNKNOWN_OTHER_ERROR, "Websocket Unknown Other Error" } }; + +enum { + CLOSE_REASON_NORMAL_CLOSE [[maybe_unused]] = 1000, + CLOSE_REASON_SERVER_CLOSED [[maybe_unused]] = 1001, + CLOSE_REASON_PROTOCOL_ERROR [[maybe_unused]] = 1002, + CLOSE_REASON_UNSUPPORT_DATA_TYPE [[maybe_unused]] = 1003, + CLOSE_REASON_RESERVED1 [[maybe_unused]], + CLOSE_REASON_RESERVED2 [[maybe_unused]], + CLOSE_REASON_RESERVED3 [[maybe_unused]], + CLOSE_REASON_RESERVED4 [[maybe_unused]], + CLOSE_REASON_RESERVED5 [[maybe_unused]], + CLOSE_REASON_RESERVED6 [[maybe_unused]], + CLOSE_REASON_RESERVED7 [[maybe_unused]], + CLOSE_REASON_RESERVED8 [[maybe_unused]], + CLOSE_REASON_RESERVED9 [[maybe_unused]], + CLOSE_REASON_RESERVED10 [[maybe_unused]], + CLOSE_REASON_RESERVED11 [[maybe_unused]], + CLOSE_REASON_RESERVED12 [[maybe_unused]], +}; + +struct CallbackDispatcher { + lws_callback_reasons reason; + int (*callback)(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len); +}; + +static const lws_http_mount mount = { + NULL, "/", "./mount-origin", "index.html", NULL, NULL, NULL, NULL, 0, 0, 0, 0, 0, 0, LWSMPRO_FILE, 1, NULL, +}; + +void OnServerError(WebSocketServer *server, int32_t code) +{ + if (server == nullptr || server->GetServerContext() == nullptr) { + NETSTACK_LOGE("server or context is null"); + return; + } + if (server->onErrorCallback_ == nullptr) { + NETSTACK_LOGE("onErrorCallback_ is null"); + return; + } + ErrorResult errorResult; + errorResult.errorCode = code; + auto it = WEBSOCKET_ERR_MAP.find(code); + if (it != WEBSOCKET_ERR_MAP.end()) { + errorResult.errorMessage = it->second.c_str(); + } + server->onErrorCallback_(server, errorResult); +} + +void RunServerService(WebSocketServer *server) +{ + NETSTACK_LOGI("websocket run service start"); + int res = 0; + lws_context *context = server->GetServerContext()->GetContext(); + if (context == nullptr) { + NETSTACK_LOGE("context is null"); + return; + } + while (res >= 0 && !server->GetServerContext()->IsThreadStop()) { + res = lws_service(context, 0); + } + server->Destroy(); +} + +int RaiseServerError(WebSocketServer *server) +{ + OnServerError(server, COMMON_ERROR_CODE); + return -1; +} + +int HttpDummy(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) +{ + int ret = lws_callback_http_dummy(wsi, reason, user, in, len); + if (ret < 0) { + OnServerError(reinterpret_cast(user), COMMON_ERROR_CODE); + } + return 0; +} + +bool GetPeerConnMsg(lws *wsi, std::string &clientId, SocketConnection &conn) +{ + struct sockaddr_storage addr {}; + socklen_t addrLen = sizeof(addr); + int ret = getpeername(lws_get_socket_fd(wsi), reinterpret_cast(&addr), &addrLen); + if (ret != 0) { + NETSTACK_LOGE("getpeername failed"); + return false; + } + char ipStr[INET6_ADDRSTRLEN] = {0}; + if (addr.ss_family == AF_INET) { + NETSTACK_LOGI("family is ipv4"); + auto *addrIn = reinterpret_cast(&addr); + inet_ntop(AF_INET, &addrIn->sin_addr, ipStr, sizeof(ipStr)); + uint16_t port = ntohs(addrIn->sin_port); + conn.clientPort = static_cast(port); + conn.clientIP = ipStr; + clientId = std::string(ipStr) + ":" + std::to_string(port); + } else if (addr.ss_family == AF_INET6) { + NETSTACK_LOGI("family is ipv6"); + auto *addrIn6 = reinterpret_cast(&addr); + inet_ntop(AF_INET6, &addrIn6->sin6_addr, ipStr, sizeof(ipStr)); + uint16_t port = ntohs(addrIn6->sin6_port); + conn.clientPort = static_cast(port); + conn.clientIP = ipStr; + clientId = std::string(ipStr) + ":" + std::to_string(port); + } else { + NETSTACK_LOGE("getpeer Ipv4 or Ipv6 failed"); + return false; + } + return true; +} + +int LwsCallbackEstablished(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) +{ + NETSTACK_LOGD("lws callback server established"); + lws_context *context = lws_get_context(wsi); + WebSocketServer *server = static_cast(lws_context_user(context)); + if (server == nullptr) { + NETSTACK_LOGE("server is null"); + return RaiseServerError(server); + } + if (server->GetServerContext() == nullptr) { + NETSTACK_LOGE("server context is null"); + return RaiseServerError(server); + } + if (server->GetServerContext()->IsClosed() || server->GetServerContext()->IsThreadStop()) { + NETSTACK_LOGE("server is closed or thread is stopped"); + return RaiseServerError(server); + } + lws_context *lwsContext = lws_get_context(wsi); + auto clientUserData = std::make_shared(lwsContext); + lws_set_wsi_user(wsi, clientUserData.get()); + server->GetServerContext()->AddClientUserData(wsi, clientUserData); + std::string clientId; + SocketConnection connection; + bool ret = GetPeerConnMsg(wsi, clientId, connection); + if (!ret) { + NETSTACK_LOGE("GetPeerConnMsg failed"); + return RaiseServerError(server); + } + server->GetServerContext()->AddConnections(clientId, wsi, connection); + clientUserData->SetLws(wsi); + clientUserData->TriggerWritable(); + if (server->onConnectCallback_ != nullptr) { + server->onConnectCallback_(server, connection); + } + return HttpDummy(wsi, reason, user, in, len); +} + +bool IsOverMaxConcurrentClientsCnt(WebSocketServer *server, const std::vector &connections, + const std::string &ip) +{ + std::unordered_set uniqueIp; + for (const auto &conn : connections) { + uniqueIp.insert(conn.clientIP); + } + if (uniqueIp.find(ip) != uniqueIp.end()) { + return uniqueIp.size() > server->GetServerContext()->startServerConfig_.maxConcurrentClientsNumber; + } else { + return (uniqueIp.size() + 1) > server->GetServerContext()->startServerConfig_.maxConcurrentClientsNumber; + } +} + +bool IsOverMaxCntForOneClient(WebSocketServer *server, const std::vector &connections, + const std::string &ip) +{ + uint32_t cnt = 0; + for (auto it = connections.begin(); it != connections.end(); ++it) { + if (ip == it->clientIP) { + ++cnt; + } + } + if (cnt + 1 > server->GetServerContext()->startServerConfig_.maxConnectionsForOneClient) { + return true; + } + return false; +} + +bool IsOverMaxClientConns(WebSocketServer *server, const std::string &ip) +{ + std::vector connections; + server->ListAllConnections(connections); + if (IsOverMaxConcurrentClientsCnt(server, connections, ip)) { + NETSTACK_LOGI("current client connections is over max concurrent number"); + return true; + } + if (IsOverMaxCntForOneClient(server, connections, ip)) { + NETSTACK_LOGI("current connections for one client is over max number"); + return true; + } + return false; +} + +int LwsCallbackClosed(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) +{ + NETSTACK_LOGD("lws callback server closed"); + lws_context *context = lws_get_context(wsi); + WebSocketServer *server = static_cast(lws_context_user(context)); + if (server == nullptr) { + NETSTACK_LOGE("server is null"); + return -1; + } + if (server->GetServerContext() == nullptr) { + NETSTACK_LOGE("server context is null"); + return -1; + } + if (server->GetServerContext()->IsClosed() || server->GetServerContext()->IsThreadStop()) { + NETSTACK_LOGE("server is closed or thread is stopped"); + return -1; + } + if (wsi == nullptr) { + NETSTACK_LOGE("wsi is null"); + return -1; + } + auto clientUserData = reinterpret_cast(lws_wsi_user(wsi)); + if (clientUserData == nullptr) { + NETSTACK_LOGE("clientUserData is null"); + return RaiseServerError(server); + } + clientUserData->SetThreadStop(true); + if ((clientUserData->closeReason).empty()) { + clientUserData->Close(clientUserData->closeStatus, LINK_DOWN); + } + if (clientUserData->closeStatus == LWS_CLOSE_STATUS_NOSTATUS) { + NETSTACK_LOGE("The link is down, onError"); + OnServerError(server, COMMON_ERROR_CODE); + } + std::string clientId = server->GetServerContext()->GetClientIdFromConnectionByWsi(wsi); + if (server->onCloseCallback_ != nullptr) { + SocketConnection sc = server->GetServerContext()->GetConnectionFromWsi(wsi); + CloseResult cr; + cr.code = clientUserData->closeStatus; + cr.reason = clientUserData->closeReason.c_str(); + server->onCloseCallback_(server, cr, sc); + } + server->GetServerContext()->RemoveConnections(clientId); + server->GetServerContext()->RemoveClientUserData(wsi); + lws_set_wsi_user(wsi, nullptr); + if (server->GetServerContext()->IsClosed() && !server->GetServerContext()->IsThreadStop()) { + NETSTACK_LOGI("server service is stopped"); + server->GetServerContext()->SetThreadStop(true); + } + return HttpDummy(wsi, reason, user, in, len); +} + +int LwsCallbackWsiDestroyServer(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) +{ + NETSTACK_LOGD("lws server callback wsi destroy"); + lws_context *context = lws_get_context(wsi); + WebSocketServer *server = static_cast(lws_context_user(context)); + if (wsi == nullptr) { + NETSTACK_LOGE("wsi is null"); + return -1; + } + if (server == nullptr) { + NETSTACK_LOGE("server is null"); + return RaiseServerError(server); + } + if (server->GetServerContext() == nullptr) { + NETSTACK_LOGE("server context is null"); + return RaiseServerError(server); + } + server->GetServerContext()->SetContext(nullptr); + return HttpDummy(wsi, reason, user, in, len); +} + +int LwsCallbackProtocolDestroyServer(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) +{ + NETSTACK_LOGD("lws server callback protocol destroy"); + return HttpDummy(wsi, reason, user, in, len); +} + +int LwsCallbackServerWriteable(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) +{ + NETSTACK_LOGD("lws callback Server writable"); + lws_context *context = lws_get_context(wsi); + WebSocketServer *server = static_cast(lws_context_user(context)); + if (server == nullptr) { + NETSTACK_LOGE("server is null"); + return -1; + } + if (server->GetServerContext() == nullptr) { + NETSTACK_LOGE("server context is null"); + return -1; + } + if (server->GetServerContext()->IsThreadStop()) { + NETSTACK_LOGE("server is closed or thread is stopped"); + return -1; + } + // client + auto *clientUserData = reinterpret_cast(lws_wsi_user(wsi)); + if (clientUserData == nullptr) { + NETSTACK_LOGE("clientUserData is null"); + return RaiseServerError(server); + } + if (clientUserData->IsClosed()) { + NETSTACK_LOGI("client is closed, need to close"); + lws_close_reason(wsi, clientUserData->closeStatus, + reinterpret_cast(const_cast(clientUserData->closeReason.c_str())), + strlen(clientUserData->closeReason.c_str())); + return -1; + } + auto sendData = clientUserData->Pop(); + if (sendData.data == nullptr || sendData.length == 0) { + NETSTACK_LOGE("send data is empty"); + return HttpDummy(wsi, reason, user, in, len); + } + int sendLength = lws_write(wsi, reinterpret_cast(sendData.data) + LWS_SEND_BUFFER_PRE_PADDING, + sendData.length, sendData.protocol); + free(sendData.data); + NETSTACK_LOGD("lws send data length is %{public}d", sendLength); + if (!clientUserData->IsEmpty()) { + NETSTACK_LOGE("userData is not empty"); + clientUserData->TriggerWritable(); + } + return HttpDummy(wsi, reason, user, in, len); +} + +int LwsCallbackWsPeerInitiatedCloseServer(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) +{ + NETSTACK_LOGD("lws server callback ws peer initiated close"); + if (wsi == nullptr) { + NETSTACK_LOGE("wsi is null"); + return -1; + } + lws_context *context = lws_get_context(wsi); + WebSocketServer *server = static_cast(lws_context_user(context)); + if (server == nullptr) { + NETSTACK_LOGE("server is null"); + return -1; + } + if (server->GetServerContext() == nullptr) { + NETSTACK_LOGE("server context is null"); + return -1; + } + if (in == nullptr || len < sizeof(uint16_t)) { + NETSTACK_LOGI("No close reason"); + server->GetServerContext()->Close(LWS_CLOSE_STATUS_NORMAL, ""); + return HttpDummy(wsi, reason, user, in, len); + } + uint16_t closeStatus = ntohs(*reinterpret_cast(in)); + std::string closeReason; + closeReason.append(reinterpret_cast(in) + sizeof(uint16_t), len - sizeof(uint16_t)); + auto *clientUserData = reinterpret_cast(lws_wsi_user(wsi)); + clientUserData->Close(static_cast(closeStatus), closeReason); + return HttpDummy(wsi, reason, user, in, len); +} + +int LwsCallbackFilterProtocolConnection(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) +{ + lws_context *context = lws_get_context(wsi); + WebSocketServer *server = static_cast(lws_context_user(context)); + if (server == nullptr) { + NETSTACK_LOGE("server is null"); + return -1; + } + if (server->GetServerContext() == nullptr) { + NETSTACK_LOGE("server context is null"); + return -1; + } + if (server->GetServerContext()->IsClosed() || server->GetServerContext()->IsThreadStop()) { + NETSTACK_LOGE("server is closed or thread is stopped"); + return -1; + } + std::string clientId; + SocketConnection connection; + bool ret = GetPeerConnMsg(wsi, clientId, connection); + if (!ret) { + NETSTACK_LOGE("GetPeerConnMsg failed"); + return RaiseServerError(server); + } + if (IsOverMaxClientConns(server, connection.clientIP)) { + NETSTACK_LOGE("current connections count is more than limit, need to close"); + return RaiseServerError(server); + } + if (!server->GetServerContext()->IsAllowConnection(connection.clientIP)) { + NETSTACK_LOGE("Rejected malicious connection"); + return RaiseServerError(server); + } + return HttpDummy(wsi, reason, user, in, len); +} + +int LwsCallbackReceive(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) +{ + NETSTACK_LOGD("lws callback server receive"); + lws_context *context = lws_get_context(wsi); + WebSocketServer *server = static_cast(lws_context_user(context)); + if (server == nullptr) { + NETSTACK_LOGE("server is null"); + return -1; + } + if (server->GetServerContext() == nullptr) { + NETSTACK_LOGE("server context is null"); + return -1; + } + if (len > INT32_MAX) { + NETSTACK_LOGE("data length too long"); + return -1; + } + bool isBinary = lws_frame_is_binary(wsi); + if (isBinary) { + server->GetServerContext()->AppendWsServerBinaryData(wsi, in, len); + } else { + server->GetServerContext()->AppendWsServerTextData(wsi, in, len); + } + auto isFinal = lws_is_final_fragment(wsi); + if (!isFinal) { + return HttpDummy(wsi, reason, user, in, len); + } + SocketConnection connection = server->GetServerContext()->GetConnectionFromWsi(wsi); + if (server->onMessageReceiveCallback_ != nullptr) { + if (isBinary) { + auto data = server->GetServerContext()->GetWsServerBinaryData(wsi); + server->onMessageReceiveCallback_(server, data, data.size(), connection); + } else { + auto data = server->GetServerContext()->GetWsServerTextData(wsi); + server->onMessageReceiveCallback_(server, data, data.size(), connection); + } + } + server->GetServerContext()->ClearWsServerBinaryData(wsi); + server->GetServerContext()->ClearWsServerTextData(wsi); + return HttpDummy(wsi, reason, user, in, len); +} + +int lwsServerCallback(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) +{ + NETSTACK_LOGI("lws server callback reason is %{public}d", reason); + CallbackDispatcher dispatchers[] = { + {LWS_CALLBACK_ESTABLISHED, LwsCallbackEstablished}, + {LWS_CALLBACK_FILTER_PROTOCOL_CONNECTION, LwsCallbackFilterProtocolConnection}, + {LWS_CALLBACK_RECEIVE, LwsCallbackReceive}, + {LWS_CALLBACK_SERVER_WRITEABLE, LwsCallbackServerWriteable}, + {LWS_CALLBACK_WS_PEER_INITIATED_CLOSE, LwsCallbackWsPeerInitiatedCloseServer}, + {LWS_CALLBACK_CLOSED, LwsCallbackClosed}, + {LWS_CALLBACK_WSI_DESTROY, LwsCallbackWsiDestroyServer}, + {LWS_CALLBACK_PROTOCOL_DESTROY, LwsCallbackProtocolDestroyServer}, + }; + for (const auto dispatcher : dispatchers) { + if (dispatcher.reason == reason) { + return dispatcher.callback(wsi, reason, user, in, len); + } + } + return HttpDummy(wsi, reason, user, in, len); +} + +static const lws_protocols LWS_SERVER_PROTOCOLS[] = { + {"lws_server1", lwsServerCallback, 0, 0}, + {NULL, NULL, 0, 0}, // this line is needed +}; + +void FillServerContextInfo(WebSocketServer *server, lws_context_creation_info &info) +{ + info.options = LWS_SERVER_OPTION_HTTP_HEADERS_SECURITY_BEST_PRACTICES_ENFORCE; + info.port = static_cast(server->GetServerContext()->startServerConfig_.serverPort); + info.mounts = &mount; + info.protocols = LWS_SERVER_PROTOCOLS; + info.vhost_name = "localhost"; + info.user = server; + // maybe + info.gid = -1; + info.uid = -1; +} + +static bool CheckFilePath(std::string &path) +{ + char tmpPath[PATH_MAX] = {0}; + if (!realpath(static_cast(path.c_str()), tmpPath)) { + NETSTACK_LOGE("path is error"); + return false; + } + path = tmpPath; + return true; +} + +bool FillServerCertPath(ServerContext *context, lws_context_creation_info &info) +{ + ServerCert sc = context->startServerConfig_.serverCert; + if (!sc.certPath.empty()) { + if (!CheckFilePath(sc.certPath) || !CheckFilePath(sc.keyPath)) { + NETSTACK_LOGE("client cert not exist"); + return false; + } + info.ssl_cert_filepath = sc.certPath.c_str(); + info.ssl_private_key_filepath = sc.keyPath.c_str(); + } + return true; +} + +void CloseAllConnection(ServerContext *serverContext) +{ + if (serverContext == nullptr) { + NETSTACK_LOGE("server context is nullptr"); + return; + } + auto connListTmp = serverContext->GetWebSocketConnection(); + if (connListTmp.empty()) { + NETSTACK_LOGE("webSocketConnection is empty"); + if (!serverContext->IsThreadStop()) { + NETSTACK_LOGI("server service is stopped"); + serverContext->SetThreadStop(true); + } + return; + } + const char *closeReason = "server is going away"; + for (auto [id, connPair] : connListTmp) { + if (connPair.first == nullptr) { + NETSTACK_LOGE("clientId not found:%{public}s", id.c_str()); + continue; + } + auto *clientUserData = reinterpret_cast(lws_wsi_user(connPair.first)); + clientUserData->Close(LWS_CLOSE_STATUS_GOINGAWAY, closeReason); + clientUserData->TriggerWritable(); + } + NETSTACK_LOGI("CloseAllConnection OK"); +} + +WebSocketServer::WebSocketServer() +{ + serverContext_ = new ServerContext(); +} + +WebSocketServer::~WebSocketServer() +{ + Destroy(); + delete serverContext_; + serverContext_ = nullptr; +} + +int WebSocketServer::Start(const ServerConfig &config) +{ + NETSTACK_LOGD("websocket server start exec"); + if (!CommonUtils::HasInternetPermission()) { + serverContext_->SetPermissionDenied(true); + return -1; + } + if (!CommonUtils::IsValidIPV4(config.serverIP) && !CommonUtils::IsValidIPV6(config.serverIP)) { + NETSTACK_LOGE("IPV4 and IPV6 are not valid"); + return WEBSOCKET_ERROR_CODE_INVALID_NIC; + } + if (!CommonUtils::IsValidPort(config.serverPort)) { + NETSTACK_LOGE("Port is not valid"); + return WEBSOCKET_ERROR_CODE_INVALID_PORT; + } + if (config.maxConcurrentClientsNumber > MAX_CONCURRENT_CLIENTS_NUMBER) { + NETSTACK_LOGE("max concurrent clients number is set over limit"); + return WEBSOCKET_UNKNOWN_OTHER_ERROR; + } + if (config.maxConnectionsForOneClient > MAX_CONNECTIONS_FOR_ONE_CLIENT) { + NETSTACK_LOGE("max connection number for one client is set over limit"); + return WEBSOCKET_UNKNOWN_OTHER_ERROR; + } + serverContext_->startServerConfig_ = config; + lws_context_creation_info info = {}; + FillServerContextInfo(this, info); + if (!FillServerCertPath(serverContext_, info)) { + NETSTACK_LOGE("FillServerCertPath error"); + return WEBSOCKET_ERROR_CODE_FILE_NOT_EXIST; + } + lws_context *lwsContext = nullptr; + std::shared_ptr userData; + lwsContext = lws_create_context(&info); + serverContext_->SetContext(lwsContext); + std::thread serviceThread(RunServerService, this); +#if defined(MAC_PLATFORM) || defined(IOS_PLATFORM) + pthread_setname_np(WEBSOCKET_SERVER_THREAD_RUN); +#else + pthread_setname_np(serviceThread.native_handle(), WEBSOCKET_SERVER_THREAD_RUN); +#endif + serviceThread.detach(); + return 0; +} + +int WebSocketServer::Stop() +{ + if (serverContext_->GetContext() == nullptr) { + NETSTACK_LOGE("context is nullptr"); + return -1; + } + if (!CommonUtils::HasInternetPermission()) { + serverContext_->SetPermissionDenied(true); + return -1; + } + if (serverContext_->IsClosed() || serverContext_->IsThreadStop()) { + NETSTACK_LOGE("session is closed or stopped"); + return -1; + } + CloseAllConnection(serverContext_); + serverContext_->Close(LWS_CLOSE_STATUS_GOINGAWAY, ""); + NETSTACK_LOGI("CloseServer OK"); + return 0; +} + +int WebSocketServer::Close(const SocketConnection &connection, const CloseOption &option) +{ + if (serverContext_->GetContext() == nullptr) { + NETSTACK_LOGE("context is nullptr"); + return -1; + } + if (!CommonUtils::HasInternetPermission()) { + serverContext_->SetPermissionDenied(true); + return -1; + } + if (connection.clientIP.empty()) { + NETSTACK_LOGE("connection is empty"); + return -1; + } + std::string clientId = connection.clientIP + ":" + std::to_string(connection.clientPort); + NETSTACK_LOGI("Close, clientID:%{public}s", clientId.c_str()); + auto wsi = serverContext_->GetClientWsi(clientId); + if (wsi == nullptr) { + NETSTACK_LOGE("clientId not found:%{public}s", clientId.c_str()); + return WEBSOCKET_ERROR_CODE_CONNECTION_NOT_EXIST; + } + auto *clientUserData = reinterpret_cast(lws_wsi_user(wsi)); + if (clientUserData == nullptr) { + NETSTACK_LOGE("clientUser data is nullptr"); + return -1; + } + if (clientUserData->IsClosed() || clientUserData->IsThreadStop()) { + NETSTACK_LOGE("session is closed or stopped"); + return -1; + } + clientUserData->Close(static_cast(option.code), option.reason); + clientUserData->TriggerWritable(); + NETSTACK_LOGI("Close OK"); + return 0; +} + +int WebSocketServer::Send(const char *data, int length, const SocketConnection &connection) +{ + if (serverContext_->GetContext() == nullptr) { + NETSTACK_LOGE("context is nullptr"); + return -1; + } + if (!CommonUtils::HasInternetPermission()) { + serverContext_->SetPermissionDenied(true); + return -1; + } + if (connection.clientIP.empty()) { + NETSTACK_LOGE("connection is empty"); + return -1; + } + std::string clientId = connection.clientIP + ":" + std::to_string(connection.clientPort); + NETSTACK_LOGI("connection clientid:%{public}s", clientId.c_str()); + auto wsi = serverContext_->GetClientWsi(clientId); + if (wsi == nullptr) { + NETSTACK_LOGE("clientId not found:%{public}s", clientId.c_str()); + return WEBSOCKET_ERROR_CODE_CONNECTION_NOT_EXIST; + } + auto *clientUserData = reinterpret_cast(lws_wsi_user(wsi)); + if (clientUserData == nullptr) { + NETSTACK_LOGE("clientUser data is nullptr"); + return -1; + } + if (clientUserData->IsClosed() || clientUserData->IsThreadStop()) { + NETSTACK_LOGE("session is closed or stopped"); + return -1; + } + lws_write_protocol protocol = (strlen(data) == length) ? LWS_WRITE_TEXT : LWS_WRITE_BINARY; + size_t dataLen = LWS_SEND_BUFFER_PRE_PADDING + length + LWS_SEND_BUFFER_POST_PADDING; + char *tmpData = (char *)malloc(dataLen); + if (tmpData == nullptr) { + NETSTACK_LOGE("malloc failed"); + return -1; + } + if (memcpy_s(reinterpret_cast(reinterpret_cast(tmpData) + LWS_SEND_BUFFER_PRE_PADDING), length, + data, length) < 0) { + NETSTACK_LOGE("copy failed"); + free(tmpData); + return -1; + } + clientUserData->Push((void *)tmpData, length, protocol); + clientUserData->TriggerWritable(); + NETSTACK_LOGD("lws ts send success"); + return 0; +} + +int WebSocketServer::ListAllConnections(std::vector &connections) const +{ + NETSTACK_LOGD("websocket server list all connections exec"); + if (serverContext_->GetContext() == nullptr) { + NETSTACK_LOGE("websocket server context is null"); + return -1; + } + serverContext_->ListAllConnections(connections); + return 0; +} + +int WebSocketServer::Registcallback(OnErrorCallback onError, OnConnectCallback onConnect, OnCloseCallback onClose, + OnMessageReceiveCallback onMessageReceive) +{ + onErrorCallback_ = onError; + onConnectCallback_ = onConnect; + onCloseCallback_ = onClose; + onMessageReceiveCallback_ = onMessageReceive; + return 0; +} + +ServerContext *WebSocketServer::GetServerContext() const +{ + return serverContext_; +} + +int WebSocketServer::Destroy() +{ + NETSTACK_LOGI("websocket server destroy exec"); + if (this->GetServerContext()->GetContext() == nullptr) { + return -1; + } + lws_context_destroy(this->GetServerContext()->GetContext()); + this->GetServerContext()->SetContext(nullptr); + return 0; +} +} \ No newline at end of file diff --git a/frameworks/native/websocket_client/include/websocket_client_error.h b/frameworks/native/websocket_native/include/websocket_client_error.h old mode 100755 new mode 100644 similarity index 100% rename from frameworks/native/websocket_client/include/websocket_client_error.h rename to frameworks/native/websocket_native/include/websocket_client_error.h diff --git a/interfaces/innerkits/websocket_client/libwebsocket_client.map b/interfaces/innerkits/websocket_client/libwebsocket_client.map deleted file mode 100755 index 29c7d3d4a5b4f99947b9c946eddeac0b94716b88..0000000000000000000000000000000000000000 --- a/interfaces/innerkits/websocket_client/libwebsocket_client.map +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2023 Huawei Device Co., Ltd. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{ - global: - *WebSocketClient*; - local: - *; -}; \ No newline at end of file diff --git a/interfaces/innerkits/websocket_client/BUILD.gn b/interfaces/innerkits/websocket_native/BUILD.gn similarity index 83% rename from interfaces/innerkits/websocket_client/BUILD.gn rename to interfaces/innerkits/websocket_native/BUILD.gn index 50b443054cbf8b91b9fb9adb2c637dc2d7cc3616..fc5b987fd6ad71e4dc5635c0884561ab4b2f7ed9 100644 --- a/interfaces/innerkits/websocket_client/BUILD.gn +++ b/interfaces/innerkits/websocket_native/BUILD.gn @@ -14,10 +14,10 @@ import("//build/ohos.gni") import("//foundation/communication/netstack/netstack_config.gni") -config("websocket_client_config") { +config("websocket_native_config") { # header file path include_dirs = - [ "$NETSTACK_DIR/interfaces/innerkits/websocket_client/include" ] + [ "$NETSTACK_DIR/interfaces/innerkits/websocket_native/include" ] cflags = [] if (is_double_framework) { @@ -40,7 +40,7 @@ config("websocket_client_config") { } } -ohos_shared_library("websocket_client") { +ohos_shared_library("websocket_native") { sanitize = { cfi = true cfi_cross_dso = true @@ -49,12 +49,13 @@ ohos_shared_library("websocket_client") { branch_protector_ret = "pac_ret" - sources = [ "$SUBSYSTEM_DIR/netstack/frameworks/js/napi/websocket/websocket_module/src/websocket_client.cpp" ] + sources = [ "$SUBSYSTEM_DIR/netstack/frameworks/js/napi/websocket/websocket_module/src/websocket_client.cpp", + "$SUBSYSTEM_DIR/netstack/frameworks/js/napi/websocket/websocket_module/src/websocket_server.cpp" ] include_dirs = [ "$NETSTACK_DIR/utils/common_utils/include", "$NETSTACK_DIR/utils/log/include", - "$NETSTACK_NATIVE_ROOT/websocket_client/include", + "$NETSTACK_NATIVE_ROOT/websocket_native/include", ] cflags = [ @@ -68,9 +69,9 @@ ohos_shared_library("websocket_client") { "-O2", ] - version_script = "libwebsocket_client.map" + version_script = "libwebsocket_native.map" - public_configs = [ ":websocket_client_config" ] + public_configs = [ ":websocket_native_config" ] defines = [ "OHOS_LIBWEBSOCKETS=1" ] diff --git a/interfaces/innerkits/websocket_client/include/client_context.h b/interfaces/innerkits/websocket_native/include/client_context.h old mode 100755 new mode 100644 similarity index 100% rename from interfaces/innerkits/websocket_client/include/client_context.h rename to interfaces/innerkits/websocket_native/include/client_context.h diff --git a/interfaces/innerkits/websocket_native/include/server_context.h b/interfaces/innerkits/websocket_native/include/server_context.h new file mode 100644 index 0000000000000000000000000000000000000000..e99315a3027a5a7e6d3336d7508eee23aedcb78b --- /dev/null +++ b/interfaces/innerkits/websocket_native/include/server_context.h @@ -0,0 +1,442 @@ +/* + * Copyright (c) 2025 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SERVER_CONTEXT_H +#define SERVER_CONTEXT_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "netstack_log.h" +#include "secure_char.h" + +namespace OHOS { +namespace NetStack { +namespace WebSocketServer { +struct ClientInfo { + int32_t cnt; + uint64_t lastConnectionTime; +}; + +struct ServerCert { + std::string certPath; + std::string keyPath; +}; + +struct ServerConfig { + std::string serverIP; + int serverPort; + ServerCert serverCert; + int maxConcurrentClientsNumber; + std::string protocol; + int maxConnectionsForOneClient; +}; + +struct SocketConnection { + std::string clientIP; + uint32_t clientPort; +}; + +struct CloseOption { + unsigned int code; + const char *reason; +}; + +struct ErrorResult { + unsigned int errorCode; + const char *errorMessage; +}; + +struct CloseResult { + unsigned int code; + const char *reason; +}; + +class UserData { +public: + struct SendData { + SendData(void *paraData, size_t paraLength, lws_write_protocol paraProtocol) + : data(paraData), length(paraLength), protocol(paraProtocol) + {} + + SendData() = delete; + + ~SendData() = default; + + void *data; + size_t length; + lws_write_protocol protocol; + }; + + explicit UserData(lws_context *context) + : closeStatus(LWS_CLOSE_STATUS_NOSTATUS), openStatus(0), closed_(false), threadStop_(false), context_(context) + {} + + bool IsClosed() + { + std::lock_guard lock(mutex_); + return closed_; + } + + bool IsThreadStop() + { + return threadStop_.load(); + } + + void SetThreadStop(bool threadStop) + { + threadStop_.store(threadStop); + } + + void Close(lws_close_status status, const std::string &reason) + { + std::lock_guard lock(mutex_); + closeStatus = status; + closeReason = reason; + closed_ = true; + } + + void Push(void *data, size_t length, lws_write_protocol protocol) + { + std::lock_guard lock(mutex_); + dataQueue_.emplace(data, length, protocol); + } + + SendData Pop() + { + std::lock_guard lock(mutex_); + if (dataQueue_.empty()) { + return { nullptr, 0, LWS_WRITE_TEXT }; + } + SendData data = dataQueue_.front(); + dataQueue_.pop(); + return data; + } + + void SetContext(lws_context *context) + { + context_ = context; + } + + lws_context *GetContext() + { + return context_; + } + + bool IsEmpty() + { + std::lock_guard lock(mutex_); + if (dataQueue_.empty()) { + return true; + } + return false; + } + + void SetLws(lws *wsi) + { + std::lock_guard lock(mutexForLws_); + if (wsi == nullptr) { + NETSTACK_LOGD("set wsi nullptr"); + } + wsi_ = wsi; + } + + void TriggerWritable() + { + std::lock_guard lock(mutexForLws_); + if (wsi_ == nullptr) { + NETSTACK_LOGE("wsi is nullptr, can not trigger"); + return; + } + lws_callback_on_writable(wsi_); + } + + std::map header; + + lws_close_status closeStatus; + + std::string closeReason; + + uint32_t openStatus; + + std::string openMessage; + +private: + volatile bool closed_; + + std::atomic_bool threadStop_; + + std::mutex mutex_; + + std::mutex mutexForLws_; + + lws_context *context_; + + std::queue dataQueue_; + + lws *wsi_ = nullptr; +}; + +class ServerContext { +public: + ServerContext() {} + ~ServerContext() = default; + uint64_t GetCurrentSecond() + { + return std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()) + .count(); + } + bool IsClosed() + { + std::lock_guard lock(closeMutex_); + return closed_; + } + void Close(lws_close_status status, const std::string &reason) + { + std::lock_guard lock(closeMutex_); + closeStatus_ = status; + closeReason_ = reason; + closed_ = true; + } + bool IsThreadStop() + { + return threadStop_.load(); + } + void SetThreadStop(bool threadStop) + { + threadStop_.store(threadStop); + } + void SetContext(lws_context *context) + { + context_ = context; + } + lws_context *GetContext() + { + return context_; + } + void AddClientUserData(void *wsi, std::shared_ptr &data) + { + std::lock_guard lock(mapUserDataMutex_); + userDataMap_[wsi] = data; + } + void RemoveClientUserData(void *wsi) + { + std::lock_guard lock(mapUserDataMutex_); + auto it = userDataMap_.find(wsi); + if (it != userDataMap_.end()) { + userDataMap_.erase(it); + } + } + void AddConnections(const std::string &id, lws *wsi, SocketConnection &conn) + { + if (IsClosed() || IsThreadStop()) { + NETSTACK_LOGE("AddConnections failed: session %s", IsClosed() ? "closed" : "thread stopped"); + return; + } + std::unique_lock lock(wsMutex_); + webSocketConnection_[id].first = wsi; + webSocketConnection_[id].second = conn; + } + std::string GetClientIdFromConnectionByWsi(lws *wsi) + { + std::shared_lock lock(wsMutex_); + for (const auto &pair : webSocketConnection_) { + if (pair.second.first == wsi) { + return pair.first; + } + } + return ""; + } + SocketConnection GetConnectionFromWsi(lws *wsi) + { + std::shared_lock lock(wsMutex_); + for (const auto &pair : webSocketConnection_) { + if (pair.second.first == wsi) { + return pair.second.second; + } + } + return {}; + } + void RemoveConnections(const std::string &id) + { + if (webSocketConnection_.empty()) { + return; + } + { + std::unique_lock lock(wsMutex_); + if (webSocketConnection_.find(id) == webSocketConnection_.end()) { + return; + } + webSocketConnection_.erase(id); + } + } + void ListAllConnections(std::vector &connections) + { + std::shared_lock lock(wsMutex_); + connections.clear(); + for (const auto &pair : webSocketConnection_) { + connections.push_back(pair.second.second); + } + } + lws *GetClientWsi(const std::string &clientId) + { + std::shared_lock lock(wsMutex_); + if (webSocketConnection_.empty()) { + return nullptr; + } + auto it = webSocketConnection_.find(clientId); + if (it != webSocketConnection_.end()) { + return it->second.first; + } + return nullptr; + } + const std::unordered_map> &GetWebSocketConnection() + { + std::shared_lock lock(wsMutex_); + return webSocketConnection_; + } + void AddBanList(const std::string &ip) + { + std::shared_lock lock(banListMutex_); + banList_[ip] = GetCurrentSecond() + ONE_MINUTE_IN_SEC; + } + + bool IsIpInBanList(const std::string &ip) + { + std::shared_lock lock(banListMutex_); + auto it = banList_.find(ip); + if (it != banList_.end()) { + auto now = GetCurrentSecond(); + if (now < it->second) { + return true; + } else { + banList_.erase(it); + } + } + return false; + } + void UpdateClientList(const std::string &ip) + { + std::shared_lock lock(connListMutex_); + auto it = clientList_.find(ip); + if (it == clientList_.end()) { + NETSTACK_LOGI("add clientid to clientlist"); + clientList_[ip] = {1, GetCurrentSecond()}; + } else { + auto now = GetCurrentSecond() - it->second.lastConnectionTime; + if (now > ONE_MINUTE_IN_SEC) { + NETSTACK_LOGI("reset clientid connections cnt"); + it->second = { 1, GetCurrentSecond() }; + } else { + it->second.cnt++; + } + } + } + bool IsHighFreqConnection(const std::string &ip) + { + std::shared_lock lock(connListMutex_); + auto it = clientList_.find(ip); + if (it != clientList_.end()) { + auto duration = GetCurrentSecond() - it->second.lastConnectionTime; + if (duration <= ONE_MINUTE_IN_SEC) { + return it->second.cnt > MAX_CONNECTIONS_PER_MINUTE; + } + } + return false; + } + bool IsAllowConnection(const std::string &ip) + { + if (IsIpInBanList(ip)) { + NETSTACK_LOGE("client is in banlist"); + return false; + } + if (IsHighFreqConnection(ip)) { + NETSTACK_LOGE("client reach high frequency connection"); + AddBanList(ip); + return false; + } + UpdateClientList(ip); + return true; + } + const std::string &GetWsServerBinaryData(void *wsi) + { + return wsServerBinaryData_[wsi]; + } + + const std::string &GetWsServerTextData(void *wsi) + { + return wsServerTextData_[wsi]; + } + + void AppendWsServerBinaryData(void *wsi, void *data, size_t length) + { + wsServerBinaryData_[wsi].append(reinterpret_cast(data), length); + } + + void AppendWsServerTextData(void *wsi, void *data, size_t length) + { + wsServerTextData_[wsi].append(reinterpret_cast(data), length); + } + + void ClearWsServerBinaryData(void *wsi) + { + wsServerBinaryData_[wsi].clear(); + } + + void ClearWsServerTextData(void *wsi) + { + wsServerTextData_[wsi].clear(); + } + void SetPermissionDenied(bool denied) + { + permissionDenied = denied; + } + +public: + lws_close_status closeStatus_ = LWS_CLOSE_STATUS_NOSTATUS; + std::string closeReason_; + ServerConfig startServerConfig_; + +private: + bool permissionDenied = false; + lws_context *context_ = nullptr; + std::atomic_bool threadStop_ = false; + std::mutex closeMutex_; + volatile bool closed_ = false; + std::shared_mutex wsMutex_; + std::unordered_map> webSocketConnection_; + std::shared_mutex connListMutex_; + std::unordered_map clientList_; + std::shared_mutex banListMutex_; + std::unordered_map banList_; + std::mutex mapUserDataMutex_; + std::unordered_map> userDataMap_; + std::unordered_map wsServerBinaryData_; + std::unordered_map wsServerTextData_; + static constexpr const uint64_t ONE_MINUTE_IN_SEC = 60; + static constexpr const int32_t MAX_CONNECTIONS_PER_MINUTE = 50; +}; +}; // namespace WebSocketServer +} // namespace NetStack +} // namespace OHOS +#endif diff --git a/interfaces/innerkits/websocket_client/include/websocket_client_innerapi.h b/interfaces/innerkits/websocket_native/include/websocket_client_innerapi.h old mode 100755 new mode 100644 similarity index 82% rename from interfaces/innerkits/websocket_client/include/websocket_client_innerapi.h rename to interfaces/innerkits/websocket_native/include/websocket_client_innerapi.h index 3125c970f8ccbed2b05ac7367cc0fe2ee8e03431..fb3e061a054531b815f4d744a3d19dade41bb869 --- a/interfaces/innerkits/websocket_client/include/websocket_client_innerapi.h +++ b/interfaces/innerkits/websocket_native/include/websocket_client_innerapi.h @@ -58,7 +58,7 @@ struct OpenOptions { std::map headers; }; -class WebSocketClient { +class WebSocketClient : public std::enable_shared_from_this { public: WebSocketClient(); ~WebSocketClient(); @@ -66,12 +66,17 @@ public: typedef void (*OnCloseCallback)(WebSocketClient *client, CloseResult closeResult); typedef void (*OnErrorCallback)(WebSocketClient *client, ErrorResult error); typedef void (*OnOpenCallback)(WebSocketClient *client, OpenResult openResult); + typedef void (*OnHeaderReceiveCallback)(WebSocketClient *client, const std::map &headers); + typedef void (*OnDataEndCallback)(WebSocketClient *client); int Connect(std::string URL, OpenOptions Options); + int ConnectEx(std::string URL, OpenOptions Options); int Send(char *data, size_t length); + int SendEx(char *data, size_t length); int Close(CloseOption options); + int CloseEx(CloseOption options); int Registcallback(OnOpenCallback OnOpen, OnMessageCallback onMessage, OnErrorCallback OnError, - OnCloseCallback onclose); + OnCloseCallback onclose); int Destroy(); void AppendData(void *data, size_t length); const std::string &GetData(); @@ -81,6 +86,8 @@ public: OnCloseCallback onCloseCallback_ = nullptr; OnErrorCallback onErrorCallback_ = nullptr; OnOpenCallback onOpenCallback_ = nullptr; + OnHeaderReceiveCallback onHeaderReceiveCallback_ = nullptr; + OnDataEndCallback onDataEndCallback_ = nullptr; ClientContext *GetClientContext() const; private: diff --git a/interfaces/innerkits/websocket_native/include/websocket_server_innerapi.h b/interfaces/innerkits/websocket_native/include/websocket_server_innerapi.h new file mode 100644 index 0000000000000000000000000000000000000000..99dd2bbaf175e5a70b4c809d771d81981bf139a8 --- /dev/null +++ b/interfaces/innerkits/websocket_native/include/websocket_server_innerapi.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2025 Huawei Device Co., Ltd. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMUNICATIONNETSTACK_WEBSOCKET_SERVER_H +#define COMMUNICATIONNETSTACK_WEBSOCKET_SERVER_H + +#include +#include "server_context.h" + +namespace OHOS { +namespace NetStack { +namespace WebSocketServer { + +class WebSocketServer { +public: + WebSocketServer(); + ~WebSocketServer(); + typedef void (*OnErrorCallback)(WebSocketServer *server, ErrorResult error); + typedef void (*OnConnectCallback)(WebSocketServer *server, SocketConnection connection); + typedef void (*OnCloseCallback)(WebSocketServer *server, CloseResult result, SocketConnection connection); + typedef void (*OnMessageReceiveCallback)(WebSocketServer *server, const std::string &data, + size_t length, SocketConnection connection); + + int Start(const ServerConfig &config); + int Stop(); + int Close(const SocketConnection &connection, const CloseOption &option); + int Send(const char *data, int length, const SocketConnection &connection); + int ListAllConnections(std::vector &connections) const; + int Destroy(); + + ServerContext *GetServerContext() const; + + int Registcallback(OnErrorCallback onError, OnConnectCallback onConnect, + OnCloseCallback onClose, OnMessageReceiveCallback onMessageReceive); + + OnErrorCallback onErrorCallback_ = nullptr; + OnConnectCallback onConnectCallback_ = nullptr; + OnCloseCallback onCloseCallback_ = nullptr; + OnMessageReceiveCallback onMessageReceiveCallback_ = nullptr; + +private: + ServerContext *serverContext_ = nullptr; +}; +} // namespace WebSocketServer +} // namespace NetStack +} // namespace OHOS + +#endif // COMMUNICATIONNETSTACK_WEBSOCKET_SERVER_H \ No newline at end of file diff --git a/interfaces/innerkits/websocket_native/libwebsocket_native.map b/interfaces/innerkits/websocket_native/libwebsocket_native.map new file mode 100644 index 0000000000000000000000000000000000000000..b2604e089da8b440cc6097ab3452fffdddccb3fa --- /dev/null +++ b/interfaces/innerkits/websocket_native/libwebsocket_native.map @@ -0,0 +1,38 @@ +# Copyright (c) 2023 Huawei Device Co., Ltd. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +{ + global: + extern "C++" { + "OHOS::NetStack::WebSocketClient::WebSocketClient::WebSocketClient()"; + "OHOS::NetStack::WebSocketClient::WebSocketClient::~WebSocketClient()"; + "OHOS::NetStack::WebSocketClient::WebSocketClient::Connect(std::__h::basic_string, std::__h::allocator>, OHOS::NetStack::WebSocketClient::OpenOptions)"; + "OHOS::NetStack::WebSocketClient::WebSocketClient::ConnectEx(std::__h::basic_string, std::__h::allocator>, OHOS::NetStack::WebSocketClient::OpenOptions)"; + "OHOS::NetStack::WebSocketClient::WebSocketClient::GetClientContext() const"; + "OHOS::NetStack::WebSocketClient::WebSocketClient::Send(char*, unsigned int)"; + "OHOS::NetStack::WebSocketClient::WebSocketClient::SendEx(char*, unsigned int)"; + "OHOS::NetStack::WebSocketClient::WebSocketClient::Close(OHOS::NetStack::WebSocketClient::CloseOption)"; + "OHOS::NetStack::WebSocketClient::WebSocketClient::CloseEx(OHOS::NetStack::WebSocketClient::CloseOption)"; + "OHOS::NetStack::WebSocketClient::WebSocketClient::Registcallback(void (*)(OHOS::NetStack::WebSocketClient::WebSocketClient*, OHOS::NetStack::WebSocketClient::OpenResult), void (*)(OHOS::NetStack::WebSocketClient::WebSocketClient*, std::__h::basic_string, std::__h::allocator> const&, unsigned int), void (*)(OHOS::NetStack::WebSocketClient::WebSocketClient*, OHOS::NetStack::WebSocketClient::ErrorResult), void (*)(OHOS::NetStack::WebSocketClient::WebSocketClient*, OHOS::NetStack::WebSocketClient::CloseResult))"; + "OHOS::NetStack::WebSocketClient::WebSocketClient::Destroy()"; + "OHOS::NetStack::WebSocketServer::WebSocketServer::WebSocketServer()"; + "OHOS::NetStack::WebSocketServer::WebSocketServer::~WebSocketServer()"; + "OHOS::NetStack::WebSocketServer::WebSocketServer::Start(OHOS::NetStack::WebSocketServer::ServerConfig const&)"; + "OHOS::NetStack::WebSocketServer::WebSocketServer::Stop()"; + "OHOS::NetStack::WebSocketServer::WebSocketServer::Close(OHOS::NetStack::WebSocketServer::SocketConnection const&, OHOS::NetStack::WebSocketServer::CloseOption const&)"; + "OHOS::NetStack::WebSocketServer::WebSocketServer::Send(char const*, int, OHOS::NetStack::WebSocketServer::SocketConnection const&)"; + "OHOS::NetStack::WebSocketServer::WebSocketServer::ListAllConnections(std::__h::vector>&) const"; + }; + local: + *; +}; \ No newline at end of file diff --git a/interfaces/kits/c/net_websocket/BUILD.gn b/interfaces/kits/c/net_websocket/BUILD.gn old mode 100755 new mode 100644 index 1b880a83d10137527b9e495b07d27e1f7bcdeeba..017529a02ad9d613e6ca1b6fe95906e02bb0091a --- a/interfaces/kits/c/net_websocket/BUILD.gn +++ b/interfaces/kits/c/net_websocket/BUILD.gn @@ -26,8 +26,8 @@ ohos_shared_library("net_websocket") { output_extension = "so" include_dirs = [ "$NETSTACK_DIR/interfaces/kits/c/net_websocket/include", - "$NETSTACK_DIR/interfaces/innerkits/websocket_client/include", - "$NETSTACK_DIR/frameworks/native/websocket_client/include", + "$NETSTACK_DIR/interfaces/innerkits/websocket_native/include", + "$NETSTACK_DIR/frameworks/native/websocket_native/include", "$NETSTACK_DIR/utils/log/include", ] @@ -37,7 +37,7 @@ ohos_shared_library("net_websocket") { ] deps = [ - "$NETSTACK_DIR/interfaces/innerkits/websocket_client:websocket_client", + "$NETSTACK_DIR/interfaces/innerkits/websocket_native:websocket_native", "$NETSTACK_DIR/utils/napi_utils:napi_utils", ] diff --git a/test/fuzztest/websocketcapi_fuzzer/BUILD.gn b/test/fuzztest/websocketcapi_fuzzer/BUILD.gn index 545b43b193dd13251150fcf77f4825b5873b0338..02fb915a53f552926f2c56cc9d3a166d682cefd3 100644 --- a/test/fuzztest/websocketcapi_fuzzer/BUILD.gn +++ b/test/fuzztest/websocketcapi_fuzzer/BUILD.gn @@ -18,7 +18,7 @@ import("//build/test.gni") import("//foundation/communication/netstack/netstack_config.gni") ##############################fuzztest########################################## -WEBSOCKET_INNERAPI = "$NETSTACK_DIR/frameworks/native/websocket_client" +WEBSOCKET_INNERAPI = "$NETSTACK_DIR/frameworks/native/websocket_native" utils_include = [ "$SUBSYSTEM_DIR/netstack/utils/common_utils/include", diff --git a/test/unittest/websocket_inner_unittest/BUILD.gn b/test/unittest/websocket_inner_unittest/BUILD.gn old mode 100755 new mode 100644 index 7e7b7b40d7139df65d54c2808445993cb0d65ddb..f7a05cf9d5d00c7f002ad7f7a525ac21dab05802 --- a/test/unittest/websocket_inner_unittest/BUILD.gn +++ b/test/unittest/websocket_inner_unittest/BUILD.gn @@ -16,7 +16,7 @@ import("//build/test.gni") import("//foundation/communication/netstack/netstack_config.gni") #SOCKET_NAPI = "$NETSTACK_DIR/frameworks/js/napi/socket" -WEBSOCKET_INNERAPI = "$NETSTACK_DIR/frameworks/native/websocket_client" +WEBSOCKET_INNERAPI = "$NETSTACK_DIR/frameworks/native/websocket_native" utils_include = [ "$SUBSYSTEM_DIR/netstack/utils/common_utils/include", @@ -43,7 +43,7 @@ ohos_unittest("websocket_inner_unittest") { deps = [ "$NETSTACK_DIR/utils/napi_utils:napi_utils", - "$NETSTACK_INNERKITS_DIR/websocket_client:websocket_client", + "$NETSTACK_INNERKITS_DIR/websocket_native:websocket_native", ] external_deps = common_external_deps