diff --git a/src/common/auth.cpp b/src/common/auth.cpp index cc54aa8f1955af03cfa5dda1e5a18ff12272cd42..be23e884a63cfaabee016d76384b60de798eea1b 100644 --- a/src/common/auth.cpp +++ b/src/common/auth.cpp @@ -1035,11 +1035,10 @@ int RsaPrikeyDecryptPsk(const unsigned char* in, int inLen, unsigned char* out, } #else // DAEMON -int RsaPubkeyEncryptPsk(const uint32_t sessionId, - const unsigned char* in, int inLen, unsigned char* out, const string& pubkey) +int RsaPubkeyEncryptPsk(const unsigned char* in, int inLen, unsigned char* out, int outBufSize, const string& pubkey) { if (out == nullptr) { - WRITE_LOG(LOG_FATAL, "out buf is alloc failed"); + WRITE_LOG(LOG_FATAL, "RsaPubkeyEncryptPsk out buf is alloc failed"); return -1; } BIO *bio = nullptr; @@ -1048,23 +1047,23 @@ int RsaPubkeyEncryptPsk(const uint32_t sessionId, do { bio = BIO_new(BIO_s_mem()); if (bio == nullptr) { - WRITE_LOG(LOG_FATAL, "bio failed for session %u", sessionId); + WRITE_LOG(LOG_FATAL, "RsaPubkeyEncryptPsk create bio failed"); break; } int wbytes = BIO_write(bio, reinterpret_cast(pubkey.c_str()), pubkey.length()); if (wbytes <= 0) { - WRITE_LOG(LOG_FATAL, "bio write failed %d for session %u", wbytes, sessionId); + WRITE_LOG(LOG_FATAL, "RsaPubkeyEncryptPsk bio write failed %d ", wbytes); break; } rsa = PEM_read_bio_RSA_PUBKEY(bio, nullptr, nullptr, nullptr); if (rsa == nullptr) { - WRITE_LOG(LOG_FATAL, "rsa failed for session %u", sessionId); + WRITE_LOG(LOG_FATAL, "RsaPubkeyEncryptPsk rsa failed"); break; } unsigned char encryptedBuf[BUF_SIZE_DEFAULT2] = { 0 }; int encryptedBufSize = RSA_public_encrypt(inLen, in, encryptedBuf, rsa, RSA_PKCS1_OAEP_PADDING); - if (encryptedBufSize <= 0) { - WRITE_LOG(LOG_FATAL, "encrypt PreShared Key failed"); + if (encryptedBufSize <= 0 || (outBufSize < ((encryptedBufSize + 2) / 3 * 4))) { // (x+2)/3*4 base64 encode size + WRITE_LOG(LOG_FATAL, "encrypt PreShared Key failed, encryptedBufSize: %d", encryptedBufSize); break; } outLen = EVP_EncodeBlock(out, encryptedBuf, encryptedBufSize); diff --git a/src/common/auth.h b/src/common/auth.h index 1d6160e692eb502655ed28e47119e6303f438440..3b2536a36830560f7f2d81d0b10a81f70621127b 100644 --- a/src/common/auth.h +++ b/src/common/auth.h @@ -29,8 +29,7 @@ bool RsaSignAndBase64(string &buf, Hdc::AuthVerifyType type); bool GetPublicKeyinfo(string &pubkey_info); int RsaPrikeyDecryptPsk(const unsigned char* in, int inLen, unsigned char* out, int outBufSize); #else -int RsaPubkeyEncryptPsk(const uint32_t sessionId, - const unsigned char* in, int inLen, unsigned char* out, const string& pubkey); +int RsaPubkeyEncryptPsk(const unsigned char* in, int inLen, unsigned char* out, int outBufSize, const string& pubkey); #endif // host diff --git a/src/common/define.h b/src/common/define.h index 7b7107bb58dc02b92262d63f871e6a631ecc5246..fd414e9a2d93f247b3a6da84c342f6f1cce8fb29 100644 --- a/src/common/define.h +++ b/src/common/define.h @@ -102,6 +102,7 @@ constexpr uint16_t CMD_FILE_PENULT_PARAM = 2; constexpr uint16_t BUNDLE_MIN_SIZE = 7; constexpr uint16_t BUNDLE_MAX_SIZE = 128; constexpr uint16_t HEARTBEAT_INTERVAL = 5000; // 5 seconds +constexpr uint16_t SSL_HANDSHAKE_FINISHED_WAIT_TIME = 300; // 300 ms #ifdef HDC_HOST constexpr uint16_t MAX_DELETED_SESSION_ID_RECORD_COUNT = 32; #else diff --git a/src/common/define_plus.h b/src/common/define_plus.h index 51f09211bf9478be0871c7be3e3c832be52f2941..e1c9b12a1d9b1a0b1773f9ec2778cae5f9b0c88e 100644 --- a/src/common/define_plus.h +++ b/src/common/define_plus.h @@ -458,7 +458,7 @@ struct HdcSSLInfo { uint32_t sessionId; std::string cipher; }; -using HSSLInfo = struct HdcSSLInfo *; +using SSLInfoPtr = struct HdcSSLInfo *; #endif } #endif diff --git a/src/common/hdc_ssl.cpp b/src/common/hdc_ssl.cpp index 1183ab273b40b7ae99e132afe431150f46ae4aae..4bf219e1401a3b45cb7030ee16825b99ba5e995d 100644 --- a/src/common/hdc_ssl.cpp +++ b/src/common/hdc_ssl.cpp @@ -16,7 +16,7 @@ #include "hdc_ssl.h" namespace Hdc { -HdcSSLBase::HdcSSLBase(const HSSLInfo &hSSLInfo) +HdcSSLBase::HdcSSLBase(SSLInfoPtr hSSLInfo) { #if OPENSSL_VERSION_NUMBER >= 0x10100003L if (OPENSSL_init_ssl(OPENSSL_INIT_LOAD_CONFIG, NULL) == 0) { @@ -38,8 +38,9 @@ HdcSSLBase::~HdcSSLBase() if (!isInited) { return; } - if (SSL_shutdown(ssl)!= 1) { - SSL_get_error(ssl, SSL_shutdown(ssl)); + int ret = SSL_shutdown(ssl); + if (ret != 1) { + SSL_get_error(ssl, ret); uint8_t buf[BUF_SIZE_DEFAULT]; BIO_read(outBIO, buf, BUF_SIZE_DEFAULT); } @@ -55,7 +56,7 @@ HdcSSLBase::~HdcSSLBase() isInited = false; } -void HdcSSLBase::SetSSLInfo(HSSLInfo hSSLInfo, HSession hSession) +void HdcSSLBase::SetSSLInfo(SSLInfoPtr hSSLInfo, HSession hSession) { hSSLInfo->cipher = TLS_AES_128_GCM_SHA256; hSSLInfo->isDaemon = !hSession->serverOrDaemon; @@ -114,11 +115,7 @@ int HdcSSLBase::Encrypt(const int bufLen, uint8_t *bufPtr) int HdcSSLBase::DoSSLRead(const int bufLen, int &index, uint8_t *bufPtr) { - if (static_cast(index + BUF_SIZE_DEFAULT16) > bufLen) { - WRITE_LOG(LOG_FATAL, "DoSSLRead failed, buffer overwrite index: %d", index); - return ERR_GENERIC; - } - int nSSLRead = SSL_read(ssl, bufPtr + index, BUF_SIZE_DEFAULT16); + int nSSLRead = SSL_read(ssl, bufPtr + index, std::min(static_cast(BUF_SIZE_DEFAULT16), bufLen - index)); if (nSSLRead < 0) { int err = SSL_get_error(ssl, nSSLRead); if (err == SSL_ERROR_WANT_READ) { @@ -223,8 +220,8 @@ int HdcSSLBase::GetPskEncrypt(unsigned char *bufPtr, const int bufLen, const str return ERR_GENERIC; } unsigned char* buf = preSharedKey; - int payloadSize = RsaPubkeyEncrypt(sessionId, buf, BUF_SIZE_PSK, bufPtr, pubkey); - WRITE_LOG(LOG_INFO, "RsaPubkeyEncrypt payloadSize = %d", payloadSize); + int payloadSize = RsaPubkeyEncrypt(buf, BUF_SIZE_PSK, bufPtr, bufLen, pubkey); + WRITE_LOG(LOG_INFO, "RsaPubkeyEncrypt payloadSize = %d, sid: %u", payloadSize, sessionId); return payloadSize; // return the size of encrypted psk } @@ -255,7 +252,12 @@ int HdcSSLBase::Decrypt(const int nread, const int bufLen, uint8_t *bufPtr, int unsigned int HdcSSLBase::PskServerCallback(SSL *ssl, const char *identity, unsigned char *psk, unsigned int maxPskLen) { SSL_CTX *sslctx = SSL_get_SSL_CTX(ssl); - unsigned char *pskInput = reinterpret_cast(SSL_CTX_get_ex_data(sslctx, 0)); + void *exData = SSL_CTX_get_ex_data(sslctx, 0); + if (exData == nullptr) { + WRITE_LOG(LOG_FATAL, "exData is null"); + return 0; + } + unsigned char *pskInput = reinterpret_cast(exData); if (strcmp(identity, STR_PSK_IDENTITY.c_str()) != 0) { WRITE_LOG(LOG_FATAL, "identity not same"); return 0; @@ -276,8 +278,13 @@ unsigned int HdcSSLBase::PskClientCallback(SSL *ssl, const char *hint, char *ide unsigned char *psk, unsigned int maxPskLen) { SSL_CTX *sslctx = SSL_get_SSL_CTX(ssl); - unsigned char *pskInput = reinterpret_cast(SSL_CTX_get_ex_data(sslctx, 0)); - if (STR_PSK_IDENTITY.size() >= maxIdentityLen) { + void *exData = SSL_CTX_get_ex_data(sslctx, 0); + if (exData == nullptr) { + WRITE_LOG(LOG_FATAL, "exData is null"); + return 0; + } + unsigned char *pskInput = reinterpret_cast(exData); + if (STR_PSK_IDENTITY.size() + 1 > maxIdentityLen) { WRITE_LOG(LOG_FATAL, "Client identity buffer too small, maxIdentityLen = %d", maxIdentityLen); return 0; } @@ -307,12 +314,12 @@ int HdcSSLBase::RsaPrikeyDecrypt(const unsigned char *inBuf, int inLen, unsigned return outLen; } -int HdcSSLBase::RsaPubkeyEncrypt(const uint32_t sessionId, - const unsigned char *inBuf, int inLen, unsigned char *outBuf, const string &pubkey) +int HdcSSLBase::RsaPubkeyEncrypt(const unsigned char *inBuf, int inLen, + unsigned char *outBuf, int outBufSize, const string &pubkey) { int outLen = -1; #ifndef HDC_HOST - outLen = HdcAuth::RsaPubkeyEncryptPsk(sessionId, inBuf, inLen, outBuf, pubkey); + outLen = HdcAuth::RsaPubkeyEncryptPsk(inBuf, inLen, outBuf, outBufSize, pubkey); #endif return outLen; } diff --git a/src/common/hdc_ssl.h b/src/common/hdc_ssl.h index 7f84cdac651a98e69fd47fbfb904df1ba1dbee28..e2aefa61e7aa3ec15e5d069bf10bc5f9e50b8f9c 100644 --- a/src/common/hdc_ssl.h +++ b/src/common/hdc_ssl.h @@ -19,7 +19,7 @@ namespace Hdc { class HdcSSLBase { public: - explicit HdcSSLBase(const HSSLInfo &hSSLInfo); + explicit HdcSSLBase(SSLInfoPtr hSSLInfo); HdcSSLBase(const HdcSSLBase&) = delete; virtual ~HdcSSLBase(); int Encrypt(const int bufLen, uint8_t *bufPtr); @@ -38,7 +38,7 @@ public: bool GenPsk(); bool InputPsk(unsigned char *psk, int pskLen); int GetPskEncrypt(unsigned char *bufPtr, const int bufLen, const string &pubkey); - static void SetSSLInfo(HSSLInfo hSSLInfo, HSession hsession); + static void SetSSLInfo(SSLInfoPtr hSSLInfo, HSession hsession); int PerformHandshake(vector &outBuf); bool SetHandshakeLabel(HSession hSession); inline static int GetSSLBufLen(const int bufLen) @@ -47,8 +47,8 @@ public: } private: - static int RsaPubkeyEncrypt(const uint32_t sessionId, - const unsigned char *inBuf, int inLen, unsigned char *outBuf, const string &pubkey); + static int RsaPubkeyEncrypt(const unsigned char *inBuf, int inLen, + unsigned char *outBuf, int outBufSize, const string &pubkey); int DoSSLWrite(const int bufLen, uint8_t *bufPtr); int DoSSLRead(const int bufLen, int &index, uint8_t *bufPtr); bool isDaemon; diff --git a/src/common/password.cpp b/src/common/password.cpp index 0c6910fa14c207d1f9198eaeaeecdeb22caad411..76a621ff0ade5017e68a06d548d032f75509b8c1 100644 --- a/src/common/password.cpp +++ b/src/common/password.cpp @@ -132,7 +132,10 @@ bool HdcPassword::DecryptPwd(std::vector& encryptData) success = true; } while (0); - memset_s(result.first, result.second, 0, result.second); + if (memset_s(result.first, result.second, 0, PASSWORD_LENGTH) != EOK) { + WRITE_LOG(LOG_FATAL, "memset_s failed"); + success = false; + } delete[] result.first; return success; } diff --git a/src/daemon/daemon.cpp b/src/daemon/daemon.cpp index 69e20baed1b4e5bbe9ca2f371bb353f21d3ac9c8..31f70519b15feeb2f3c983d0113593a8c621723c 100755 --- a/src/daemon/daemon.cpp +++ b/src/daemon/daemon.cpp @@ -798,8 +798,6 @@ void HdcDaemon::DaemonSessionHandshakeInit(HSession &hSession, SessionHandShake // host( ) <--(TLS handshake server hello )--- hdcd( ) step 2 // host(ok) ---(TLS handshake change cipher)--> hdcd( ) step 3 // host(ok) <--(TLS handshake change cipher)--- hdcd(ok) step 4 -// host(ok) ---(encrypted: CHANNEL_CLOSE )--> hdcd(ok) step 5 -// host(ok) <--(encrypted: CHANNEL_CLOSE )--- hdcd(ok) step 6 #ifdef HDC_SUPPORT_ENCRYPT_TCP bool HdcDaemon::DaemonSSLHandshake(HSession hSession, const uint32_t channelId, SessionHandShake &handshake) { @@ -834,6 +832,7 @@ bool HdcDaemon::DaemonSSLHandshake(HSession hSession, const uint32_t channelId, reinterpret_cast(const_cast(bufString.c_str())), bufString.size()); } if (hssl->SetHandshakeLabel(hSession)) { + std::this_thread::sleep_for(std::chrono::milliseconds(SSL_HANDSHAKE_FINISHED_WAIT_TIME)); UpdateSessionAuthOk(hSession->sessionId); SendAuthOkMsg(handshake, channelId, hSession->sessionId); if (!hssl->ClearPsk()) { @@ -1249,9 +1248,18 @@ void HdcDaemon::SendAuthEncryptPsk(SessionHandShake &handshake, const uint32_t c UpdateSessionAuthPubkey(hSession->sessionId, pubkey); handshake.authType = AUTH_SSL_TLS_PSK; if (!hSession->classSSL) { - HSSLInfo hSSLInfo = new (std::nothrow) HdcSSLInfo(); + SSLInfoPtr hSSLInfo = new (std::nothrow) HdcSSLInfo(); + if (!hSSLInfo) { + WRITE_LOG(LOG_FATAL, "SendAuthEncryptPsk new HdcSSLInfo failed"); + return; + } HdcSSLBase::SetSSLInfo(hSSLInfo, hSession); hSession->classSSL = new (std::nothrow) HdcDaemonSSL(hSSLInfo); // long lifetime with session. + delete hSSLInfo; + if (!hSession->classSSL) { + WRITE_LOG(LOG_FATAL, "SendAuthEncryptPsk new HdcDaemonSSL failed"); + return; + } } HdcSSLBase *hssl = static_cast(hSession->classSSL); if (!hssl) { diff --git a/src/daemon/daemon_ssl.cpp b/src/daemon/daemon_ssl.cpp index 19a81400854672a609cdf11e46837c33d06d9632..2a147c1c6a57435523f353511bac267e7d890ae6 100644 --- a/src/daemon/daemon_ssl.cpp +++ b/src/daemon/daemon_ssl.cpp @@ -15,7 +15,7 @@ #ifdef HDC_SUPPORT_ENCRYPT_TCP #include "daemon_ssl.h" namespace Hdc { -HdcDaemonSSL::HdcDaemonSSL(const HSSLInfo &hSSLInfo) : HdcSSLBase(hSSLInfo) +HdcDaemonSSL::HdcDaemonSSL(SSLInfoPtr hSSLInfo) : HdcSSLBase(hSSLInfo) { } diff --git a/src/daemon/daemon_ssl.h b/src/daemon/daemon_ssl.h index b97a2b0943f96d3a9e5a07964ce51dcd53e92299..0f2a9135c9eddc8d32e5eb0f1d31fb8acddc1e62 100644 --- a/src/daemon/daemon_ssl.h +++ b/src/daemon/daemon_ssl.h @@ -19,7 +19,7 @@ namespace Hdc { class HdcDaemonSSL : public HdcSSLBase { public: - explicit HdcDaemonSSL(const HSSLInfo &hSSLInfo); + explicit HdcDaemonSSL(SSLInfoPtr hSSLInfo); ~HdcDaemonSSL(); const SSL_METHOD *SetSSLMethod() override; bool SetPskCallback() override; diff --git a/src/host/host_ssl.cpp b/src/host/host_ssl.cpp index db5dfa94833ebbe60a7cd2da3ca53e5936f71185..58740d8a5942d4d10800f12c5bba9a3096d6f588 100644 --- a/src/host/host_ssl.cpp +++ b/src/host/host_ssl.cpp @@ -15,7 +15,7 @@ #ifdef HDC_SUPPORT_ENCRYPT_TCP #include "host_ssl.h" namespace Hdc { -HdcHostSSL::HdcHostSSL(const HSSLInfo &hSSLInfo) : HdcSSLBase(hSSLInfo) +HdcHostSSL::HdcHostSSL(SSLInfoPtr hSSLInfo) : HdcSSLBase(hSSLInfo) { } diff --git a/src/host/host_ssl.h b/src/host/host_ssl.h index 0f14125ac2b6a151128225cadbf154739e1555e8..b942533252b57a9efbafbde4fe893ebf3f7f4538 100644 --- a/src/host/host_ssl.h +++ b/src/host/host_ssl.h @@ -19,7 +19,7 @@ namespace Hdc { class HdcHostSSL : public HdcSSLBase { public: - explicit HdcHostSSL(const HSSLInfo &hSSLInfo); + explicit HdcHostSSL(SSLInfoPtr hSSLInfo); ~HdcHostSSL(); const SSL_METHOD *SetSSLMethod() override; bool SetPskCallback() override; diff --git a/src/host/server.cpp b/src/host/server.cpp index 8d3e45586842bc4bfeb739042871fb9ecdffccb3..e85406d4a4d14dd1bb6299c80d8ae2f21aac29b7 100644 --- a/src/host/server.cpp +++ b/src/host/server.cpp @@ -562,8 +562,6 @@ void HdcServer::UpdateHdiInfo(Hdc::HdcSessionBase::SessionHandShake &handshake, // host( ) <--(TLS handshake server hello )--- hdcd( ) step 2 // host(ok) ---(TLS handshake change cipher)--> hdcd( ) step 3 // host(ok) <--(TLS handshake change cipher)--- hdcd(ok) step 4 -// host(ok) ---(encrypted: CHANNEL_CLOSE )--> hdcd(ok) step 5 -// host(ok) <--(encrypted: CHANNEL_CLOSE )--- hdcd(ok) step 6 bool HdcServer::ServerSSLHandshake(HSession hSession, SessionHandShake &handshake) { if (hSession->classSSL == nullptr) { @@ -596,11 +594,9 @@ bool HdcServer::ServerSSLHandshake(HSession hSession, SessionHandShake &handshak Send(hSession->sessionId, 0, CMD_KERNEL_HANDSHAKE, reinterpret_cast(const_cast(bufString.c_str())), bufString.size()); } - if (ret == RET_SSL_HANDSHAKE_FINISHED) { // SSL handshake step 5 + if (ret == RET_SSL_HANDSHAKE_FINISHED) { hssl->SetHandshakeLabel(hSession); - uint8_t count = 1; WRITE_LOG(LOG_DEBUG, "ssl handshake finished, SetHandshakeLabel"); - Send(hssl->sessionId, 0, CMD_KERNEL_CHANNEL_CLOSE, &count, 1); if (!hssl->ClearPsk()) { WRITE_LOG(LOG_WARN, "clear Pre Shared Key failed"); ret = ERR_GENERIC; @@ -628,13 +624,14 @@ bool HdcServer::ServerSessionSSLInit(HSession hSession, SessionHandShake &handsh WRITE_LOG(LOG_WARN, "ServerSessionSSLInit memset_s failed"); return false; } - HSSLInfo hSSLInfo = new (std::nothrow) HdcSSLInfo(); + SSLInfoPtr hSSLInfo = new (std::nothrow) HdcSSLInfo(); if (!hSSLInfo) { - WRITE_LOG(LOG_WARN, "new HSSLInfo failed"); + WRITE_LOG(LOG_WARN, "new SSLInfoPtr failed"); return false; } HdcSSLBase::SetSSLInfo(hSSLInfo, hSession); hSession->classSSL = new (std::nothrow) HdcHostSSL(hSSLInfo); + delete hSSLInfo; HdcSSLBase *hssl = static_cast(hSession->classSSL); if (!hssl) { WRITE_LOG(LOG_WARN, "new HdcHostSSL failed"); diff --git a/test/unittest/common/hdc_ssl_ut.cpp b/test/unittest/common/hdc_ssl_ut.cpp index 96419874c68c3f2798423b08710ac4443b691c8e..03cecac6ed70e478a785e066f83068e8b7940e6f 100644 --- a/test/unittest/common/hdc_ssl_ut.cpp +++ b/test/unittest/common/hdc_ssl_ut.cpp @@ -20,12 +20,12 @@ namespace Hdc { typedef size_t rsize_t; class MockHdcSSLBase : public HdcSSLBase { public: - MOCK_METHOD5(RsaPubkeyEncrypt, int(const uint32_t sessionId, const unsigned char* in, int inLen, - unsigned char* out, const std::string& pubkey)); + MOCK_METHOD5(RsaPubkeyEncrypt, int(const unsigned char* in, int inLen, + unsigned char* out, int outBufSize, const std::string& pubkey)); MOCK_METHOD0(IsHandshakeFinish, bool()); MOCK_METHOD0(ShowSSLInfo, void()); public: - explicit MockHdcSSLBase(const HSSLInfo &hSSLInfo) : HdcSSLBase(hSSLInfo) + explicit MockHdcSSLBase(SSLInfoPtr hSSLInfo) : HdcSSLBase(hSSLInfo) { } @@ -153,7 +153,7 @@ void SSLHandShakeEmulate(HdcSSLBase *sslClient, HdcSSLBase *sslServer) */ HWTEST_F(HdcSSLTest, SetSSLInfoTest001, TestSize.Level0) { - HSSLInfo hSSLInfo = new HdcSSLInfo(); + SSLInfoPtr hSSLInfo = new HdcSSLInfo(); HSession hSession = new HdcSession(); hSession->serverOrDaemon = false; hSession->sessionId = 123; @@ -172,7 +172,7 @@ HWTEST_F(HdcSSLTest, SetSSLInfoTest001, TestSize.Level0) */ HWTEST_F(HdcSSLTest, InitSSLTest001, TestSize.Level0) { - HSSLInfo hSSLInfo = new HdcSSLInfo(); + SSLInfoPtr hSSLInfo = new HdcSSLInfo(); HSession hSession = new HdcSession(); HdcSSLBase::SetSSLInfo(hSSLInfo, hSession); HdcSSLBase *sslBase = new (std::nothrow) HdcDaemonSSL(hSSLInfo); @@ -194,7 +194,7 @@ HWTEST_F(HdcSSLTest, InitSSLTest001, TestSize.Level0) */ HWTEST_F(HdcSSLTest, InitSSLTest002, TestSize.Level0) { - HSSLInfo hSSLInfo = new HdcSSLInfo(); + SSLInfoPtr hSSLInfo = new HdcSSLInfo(); HSession hSession = new HdcSession(); HdcSSLBase::SetSSLInfo(hSSLInfo, hSession); HdcSSLBase *sslBase = new (std::nothrow) HdcHostSSL(hSSLInfo); @@ -216,7 +216,7 @@ HWTEST_F(HdcSSLTest, InitSSLTest002, TestSize.Level0) */ HWTEST_F(HdcSSLTest, ClearSSLTest001, TestSize.Level0) { - HSSLInfo hSSLInfo = new HdcSSLInfo(); + SSLInfoPtr hSSLInfo = new HdcSSLInfo(); HSession hSession = new HdcSession(); HdcSSLBase::SetSSLInfo(hSSLInfo, hSession); HdcSSLBase *sslBase = new (std::nothrow) HdcHostSSL(hSSLInfo); @@ -240,7 +240,7 @@ HWTEST_F(HdcSSLTest, ClearSSLTest001, TestSize.Level0) */ HWTEST_F(HdcSSLTest, ClearSSLTest002, TestSize.Level0) { - HSSLInfo hSSLInfo = new HdcSSLInfo(); + SSLInfoPtr hSSLInfo = new HdcSSLInfo(); HSession hSession = new HdcSession(); HdcSSLBase::SetSSLInfo(hSSLInfo, hSession); HdcSSLBase *sslBase = new (std::nothrow) HdcDaemonSSL(hSSLInfo); @@ -266,12 +266,10 @@ HWTEST_F(HdcSSLTest, ClearSSLTest002, TestSize.Level0) // host( ) <--(TLS handshake server hello )--- hdcd( ) step 2 // host(ok) ---(TLS handshake change cipher)--> hdcd( ) step 3 // host(ok) <--(TLS handshake change cipher)--- hdcd(ok) step 4 -// host(ok) ---(encrypted: CHANNEL_CLOSE )--> hdcd(ok) step 5 -// host(ok) <--(encrypted: CHANNEL_CLOSE )--- hdcd(ok) step 6 HWTEST_F(HdcSSLTest, DoSSLHandshakeTest001, TestSize.Level0) { - HSSLInfo hSSLInfoDaemon = new HdcSSLInfo(); - HSSLInfo hSSLInfoHost = new HdcSSLInfo(); + SSLInfoPtr hSSLInfoDaemon = new HdcSSLInfo(); + SSLInfoPtr hSSLInfoHost = new HdcSSLInfo(); HSession hSessionDaemon = new HdcSession(); HSession hSessionHost = new HdcSession(); HdcSSLBase::SetSSLInfo(hSSLInfoDaemon, hSessionDaemon); @@ -320,7 +318,7 @@ HWTEST_F(HdcSSLTest, DoSSLHandshakeTest001, TestSize.Level0) */ HWTEST_F(HdcSSLTest, InputPskTest001, TestSize.Level0) { - HSSLInfo hSSLInfo = new HdcSSLInfo(); + SSLInfoPtr hSSLInfo = new HdcSSLInfo(); HSession hSession = new HdcSession(); HdcSSLBase::SetSSLInfo(hSSLInfo, hSession); HdcSSLBase *sslClient = new (std::nothrow) HdcHostSSL(hSSLInfo); @@ -381,6 +379,31 @@ HWTEST_F(HdcSSLTest, PskServerCallbackTest001, TestSize.Level0) SSL_CTX_free(sslCtx); } +/** + * @tc.name: PskServerCallbackTest002 + * @tc.desc: test PskServerCallback function with no pskInput + * @tc.type: FUNC + */ +HWTEST_F(HdcSSLTest, PskServerCallbackTest002, TestSize.Level0) +{ + SSL_library_init(); + OpenSSL_add_all_algorithms(); + SSL_load_error_strings(); + SSL *ssl; + SSL_CTX *sslCtx; + const SSL_METHOD *method; + method = TLS_server_method(); + sslCtx = SSL_CTX_new(method); + ssl = SSL_new(sslCtx); + SSL_set_accept_state(ssl); + unsigned char psk[BUF_SIZE_PSK]; + unsigned int maxPskLen = BUF_SIZE_PSK; + ASSERT_EQ(HdcSSLBase::PskServerCallback(ssl, STR_PSK_IDENTITY.c_str(), psk, maxPskLen), 0); + SSL_shutdown(ssl); + SSL_free(ssl); + SSL_CTX_free(sslCtx); +} + /** * @tc.name: PskClientCallbackTest001 * @tc.desc: test PskClientCallback function with normal and error input. @@ -419,6 +442,39 @@ HWTEST_F(HdcSSLTest, PskClientCallbackTest001, TestSize.Level0) ASSERT_EQ(HdcSSLBase::PskClientCallback(ssl, hint, identity, maxIdentityLen, psk, validLen), 0); ASSERT_EQ(HdcSSLBase::PskClientCallback(ssl, hint, identity, validLen, psk, maxPskLen), 0); ASSERT_EQ(HdcSSLBase::PskClientCallback(ssl, hint, identity, validLen, pskValid, maxPskLen), 0); + ASSERT_EQ(HdcSSLBase::PskClientCallback(ssl, hint, identity, STR_PSK_IDENTITY.size(), pskValid, maxPskLen), 0); + SSL_shutdown(ssl); + SSL_free(ssl); + SSL_CTX_free(sslCtx); +} + +/** + * @tc.name: PskClientCallbackTest001 + * @tc.desc: test PskClientCallback function with no pskInput. + * @tc.type: FUNC + */ +HWTEST_F(HdcSSLTest, PskClientCallbackTest002, TestSize.Level0) +{ + SSL_library_init(); + OpenSSL_add_all_algorithms(); + SSL_load_error_strings(); + SSL *ssl; + SSL_CTX *sslCtx; + const SSL_METHOD *method; + method = TLS_client_method(); + sslCtx = SSL_CTX_new(method); + ssl = SSL_new(sslCtx); + SSL_set_connect_state(ssl); + const char* hint = STR_PSK_IDENTITY.c_str(); + char identity[BUF_SIZE_PSK]; + unsigned int maxIdentityLen = BUF_SIZE_PSK; + unsigned char psk[BUF_SIZE_PSK]; + unsigned int maxPskLen = BUF_SIZE_PSK; + unsigned int ret = HdcSSLBase::PskClientCallback(ssl, hint, identity, maxIdentityLen, psk, maxPskLen); + ASSERT_EQ(ret, 0); + SSL_shutdown(ssl); + SSL_free(ssl); + SSL_CTX_free(sslCtx); } /** @@ -428,7 +484,7 @@ HWTEST_F(HdcSSLTest, PskClientCallbackTest001, TestSize.Level0) */ HWTEST_F(HdcSSLTest, RsaPrikeyDecryptTest001, TestSize.Level0) { - HSSLInfo hSSLInfo = new HdcSSLInfo(); + SSLInfoPtr hSSLInfo = new HdcSSLInfo(); HSession hSession = new HdcSession(); HdcSSLBase::SetSSLInfo(hSSLInfo, hSession); MockHdcSSLBase *sslBase = new (std::nothrow) MockHdcSSLBase(hSSLInfo); @@ -446,11 +502,10 @@ HWTEST_F(HdcSSLTest, RsaPrikeyDecryptTest001, TestSize.Level0) */ HWTEST_F(HdcSSLTest, RsaPubkeyEncryptTest001, TestSize.Level0) { - HSSLInfo hSSLInfo = new HdcSSLInfo(); + SSLInfoPtr hSSLInfo = new HdcSSLInfo(); HSession hSession = new HdcSession(); HdcSSLBase::SetSSLInfo(hSSLInfo, hSession); MockHdcSSLBase *sslBase = new (std::nothrow) MockHdcSSLBase(hSSLInfo); - uint32_t sessionId = 12345; unsigned char in[BUF_SIZE_DEFAULT2] = "test data"; int inLen = strlen((char*)in); unsigned char out[BUF_SIZE_DEFAULT2]; @@ -459,7 +514,7 @@ HWTEST_F(HdcSSLTest, RsaPubkeyEncryptTest001, TestSize.Level0) EXPECT_CALL(*sslBase, RsaPubkeyEncrypt(testing::_, testing::_, testing::_, testing::_, testing::_)) .WillOnce(testing::Return(inLen)); - int ret = sslBase->RsaPubkeyEncrypt(sessionId, in, inLen, out, pubkey); + int ret = sslBase->RsaPubkeyEncrypt(in, inLen, out, BUF_SIZE_DEFAULT2, pubkey); ASSERT_EQ(ret, inLen); delete sslBase; delete hSSLInfo; @@ -473,8 +528,8 @@ HWTEST_F(HdcSSLTest, RsaPubkeyEncryptTest001, TestSize.Level0) */ HWTEST_F(HdcSSLTest, SetHandshakeLabelTest001, TestSize.Level0) { - HSSLInfo hSSLInfoDaemon = new HdcSSLInfo(); - HSSLInfo hSSLInfoHost = new HdcSSLInfo(); + SSLInfoPtr hSSLInfoDaemon = new HdcSSLInfo(); + SSLInfoPtr hSSLInfoHost = new HdcSSLInfo(); HSession hSessionDaemon = new HdcSession(); HSession hSessionHost = new HdcSession(); HdcSSLBase::SetSSLInfo(hSSLInfoDaemon, hSessionDaemon); @@ -509,8 +564,8 @@ HWTEST_F(HdcSSLTest, SetHandshakeLabelTest001, TestSize.Level0) */ HWTEST_F(HdcSSLTest, SetHandshakeLabelTest002, TestSize.Level0) { - HSSLInfo hSSLInfoDaemon = new HdcSSLInfo(); - HSSLInfo hSSLInfoHost = new HdcSSLInfo(); + SSLInfoPtr hSSLInfoDaemon = new HdcSSLInfo(); + SSLInfoPtr hSSLInfoHost = new HdcSSLInfo(); HSession hSessionDaemon = new HdcSession(); HSession hSessionHost = new HdcSession(); HdcSSLBase::SetSSLInfo(hSSLInfoDaemon, hSessionDaemon); @@ -542,7 +597,7 @@ HWTEST_F(HdcSSLTest, SetHandshakeLabelTest002, TestSize.Level0) */ HWTEST_F(HdcSSLTest, GetPskEncryptTest001, TestSize.Level0) { - HSSLInfo hSSLInfo = new HdcSSLInfo(); + SSLInfoPtr hSSLInfo = new HdcSSLInfo(); HSession hSession = new HdcSession(); HdcSSLBase::SetSSLInfo(hSSLInfo, hSession); hSSLInfo->isDaemon = false;