diff --git a/.gitignore b/.gitignore index 46e947f7f60e5f095569673939896a29c2a39d35..5a4aa7088ea45ead93c11cfa8a60defd74e6e902 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,7 @@ target_wrapper.* # QtCreator CMake CMakeLists.txt.user* +.idea +__pycache__ +bitXiaoSha/data/ +.vscode diff --git a/Client/Client.pro b/Client/Client.pro index 853be7688dab6ffc11252033badc81bdbe12410d..ebba76d86a2334a12168496674a9ee8d68b376ac 100644 --- a/Client/Client.pro +++ b/Client/Client.pro @@ -39,6 +39,7 @@ HEADERS += \ clientmain.h \ databaseoperation.h \ kuang.h \ + ltest.h \ mainwindow.h \ message.h \ messagemodel.h \ @@ -49,6 +50,7 @@ HEADERS += \ FORMS += \ kuang.ui \ mainwindow.ui \ + message.ui \ userlogin.ui \ userregister.ui @@ -58,4 +60,5 @@ else: unix:!android: target.path = /opt/$${TARGET}/bin !isEmpty(target.path): INSTALLS += target RESOURCES += \ + rsc.qrc \ rsc.qrc diff --git a/Client/client.cpp b/Client/client.cpp deleted file mode 100644 index 429be7b25a0610bf52d28e16061e773ffaaacd97..0000000000000000000000000000000000000000 --- a/Client/client.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include "client.h" -#include "ui_client.h" -#include -#include -#include - -Client::Client(QWidget *parent) - : QMainWindow(parent) - , ui(new Ui::Client) -{ - ui->setupUi(this); - this->hide(); - if(connectserver()){ - userlogin = new UserLogin(this); - } -} - -Client::~Client() -{ - delete ui; -} - - -void Client::send(QJsonObject data) -{ - QString str = QString(QJsonDocument(data).toJson()); - socket->write(str.toUtf8()); -} - -bool Client::connectserver() -{ - socket = new QTcpSocket(this); - socket->connectToHost (QHostAddress("127.0.0.7"),8888); - bool bo = false; - - //连接信息提示 - connect(socket, &QTcpSocket::connected,this, [=,&bo](){ - QMessageBox::information (this, "连接信息", "连接成功!"); - bo = true; - }); - connect(socket, &QTcpSocket::disconnected,this, [=,&bo](){ - QMessageBox::information (this, "连接信息", "断开连接!"); - bo = false; - }); - - //接受消息 - connect(socket, &QTcpSocket::readyRead, this, &Client::receiveMessage); - return bo; -} - -QJsonObject Client::receiveMessage() -{ - QByteArray arr = socket->readAll (); - QJsonDocument doc = QJsonDocument::fromJson(arr); - return doc.object(); -} diff --git a/Client/client.h b/Client/client.h deleted file mode 100644 index 99fd60b8eb4521883598ece17e36d570b5844264..0000000000000000000000000000000000000000 --- a/Client/client.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef CLIENT_H -#define CLIENT_H - -#include -#include -#include - -QT_BEGIN_NAMESPACE -namespace Ui { class Client; } -QT_END_NAMESPACE - -class Client : public QMainWindow -{ - Q_OBJECT - -public: - Client(QWidget *parent = nullptr); - ~Client(); - UserLogin *userlogin = nullptr; - QTcpSocket *socket; - -private slots: - - //发送功能实现 - void send(QJsonObject data); - //连接功能实现 - bool connectserver(); - //接收并打印的槽函数 - QJsonObject receiveMessage(); - -private: - Ui::Client *ui; -}; -#endif // CLIENT_H diff --git a/Client/clientdatacenter.cpp b/Client/clientdatacenter.cpp index 703de75f734da272eed35e169f08b49255d5e275..b906e54dca5fb729d07d3a66ae3ad25b1a2d901b 100644 --- a/Client/clientdatacenter.cpp +++ b/Client/clientdatacenter.cpp @@ -2,16 +2,17 @@ #include -void ClientDataCenter::registerUser(OnlineUserModel * newuser) { +void ServerDataCenter::registerUser(OnlineUserModel * newuser) { if (users.contains(newuser->getUsername())) { return; } users[newuser->getUsername()] = newuser; newuser->setParent(this); registeredObjects.append(newuser); + } -void ClientDataCenter::registerSession(OnlineSession *session) { +void ServerDataCenter::registerSession(OnlineSession *session) { if (sessions.contains(session->getSessionID())) { return; } @@ -20,17 +21,29 @@ void ClientDataCenter::registerSession(OnlineSession *session) { registeredObjects.append(session); } -void ClientDataCenter::registerMessage(OnlineMessage *msg) { +void ServerDataCenter::registerMessage(OnlineMessage *msg) { if (messages.contains({msg->getSessionID(), msg->getMessageID()})) { return; } messages[{msg->getSessionID(), msg->getMessageID()}] = msg; msg->setParent(this); registeredObjects.append(msg); - qDebug() << "### ClientDataCenter Down"; + qDebug() << "### ServerDataCenter Down"; +} + + +bool ServerDataCenter::hasUser(QString username) { + return _getUser(username) != nullptr; +} + +bool ServerDataCenter::hasSession(int sessionId) { + return _getSession(sessionId) != nullptr; +} +bool ServerDataCenter::hasMessage(int sessionId, int messageId) { + return _getMessage(sessionId, messageId) != nullptr; } -void ClientDataCenter::clean() { +void ServerDataCenter::clean() { users.clear(); sessions.clear(); messages.clear(); diff --git a/Client/clientdatacenter.h b/Client/clientdatacenter.h index 21159c25f94d1de48374b03010ab965d12555a2d..b3606d239013e81e6e3faa04f04a30727665ce23 100644 --- a/Client/clientdatacenter.h +++ b/Client/clientdatacenter.h @@ -1,5 +1,5 @@ -#ifndef CLIENTDATACENTER_H -#define CLIENTDATACENTER_H +#ifndef SERVERDATACENTER_H +#define SERVERDATACENTER_H #include #include @@ -9,15 +9,16 @@ #include "Session/abstractsession.h" #include "Session/onlinesession.h" #include "Session/offlinesession.h" + #include "usermodel.h" #include "messagemodel.h" -class ClientDataCenter : public QObject +class ServerDataCenter : public QObject { Q_OBJECT public: - static ClientDataCenter& Singleton(QObject * parent = nullptr) { - static ClientDataCenter * singleton = new ClientDataCenter(parent); + static ServerDataCenter& Singleton(QObject * parent = nullptr) { + static ServerDataCenter * singleton = new ServerDataCenter(parent); return * singleton; } @@ -37,7 +38,7 @@ public slots: void clean(); private: - explicit ClientDataCenter(QObject *parent = nullptr); + explicit ServerDataCenter(QObject *parent = nullptr); QMap users; QList registeredObjects; QMap sessions; @@ -45,7 +46,8 @@ private: OnlineSession* _getSession(int SessionId); OnlineUserModel* _getUser(QString username); OnlineMessage* _getMessage(int SessionId, int MessageId); + }; -#endif // CLIENTDATACENTER_H +#endif // SERVERDATACENTER_H diff --git a/Client/clientmain.cpp b/Client/clientmain.cpp index 7b793afeb78a6d3f220f2652f36cd1fed21ce151..6628b824d0e0e12155657a40f70796a784860c19 100644 --- a/Client/clientmain.cpp +++ b/Client/clientmain.cpp @@ -1,11 +1,20 @@ #include "clientmain.h" #include +#include #include ClientMain::ClientMain(QString IPAddress, int portOpen, QObject *parent) : QObject(parent), ipAdd(IPAddress), port(portOpen) { connectToServer(); + connect(socket, &QTcpSocket::connected, this, [=](){ + emit serverConnected(); + }); + connect(socket, &QTcpSocket::disconnected,this, [=](){ + emit serverDisconnected(); + }); + //接受消息 + connect(socket, &QTcpSocket::readyRead, this, &ClientMain::receiveMessage); } @@ -33,20 +42,62 @@ void ClientMain::connectToServer() { socket->connectToHost (QHostAddress(ipAdd), port); } -QJsonObject ClientMain::receiveMessage() +void ClientMain::processMethod(QJsonObject data) { + if(data["MsgType"].toString()=="UserData"){ + emit UserDataReceived(data); + } + if(data["MsgType"].toString()=="LogInConfirm"){ + emit LogInConfirmReceived(data); + } + if(data["MsgType"].toString()=="RegistConfirm"){ + emit RegistConfirmReceived(data); + } + if(data["MsgType"].toString()=="SessionMessage"){ + emit SessionMessageReceived(data); + } + if(data["MsgType"].toString()=="SessionData"){ + if(data["SessionType"].toString()=="FRIEND"){ + emit FriendSessionDataReceived(data); + } + if(data["SessionType"].toString()=="GROUP"){ + emit GroupSessionDataReceived(data); + } + } +} + +void ClientMain::receiveMessage() { QByteArray arr = socket->readAll (); QJsonDocument doc = QJsonDocument::fromJson(arr); - return doc.object(); + QJsonObject data = doc.object(); + if(!data.contains("MsgType")) return; + if (data["MsgType"].toString() == "JsonArray") { + auto array = data["MsgList"].toArray(); + for (int i = 0; i < array.size(); i++) { + processMethod(array[i].toObject()); + } + } + else { + processMethod(data); + } } void ClientMain::operator()() { login = new UserLogin(); login->show(); + regist = new UserRegister(); + // register form connect(login, &UserLogin::registerButtonClicked, this, [=]() { - regist = new UserRegister(); regist->show(); }); + connect(regist,&UserRegister::registfinished,login,&UserLogin::show); + connect(login,&UserLogin::sendlogindata,this,&ClientMain::send); + connect(regist,&UserRegister::sendregistdata,this,&ClientMain::send); + connect(this,&ClientMain::LogInConfirmReceived,login,&UserLogin::loginconfirm); + connect(this,&ClientMain::RegistConfirmReceived,regist,&UserRegister::registconfirm); + connect(this,&ClientMain::FriendSessionDataReceived,this->login->main,&MainWindow::FriendSessionAdd); + connect(this,&ClientMain::GroupSessionDataReceived,this->login->main,&MainWindow::GroupSessionAdd); + connect(this,&ClientMain::SessionMessageReceived,this->login->main,&MainWindow::AddMessage); } diff --git a/Client/clientmain.h b/Client/clientmain.h index df0131652988a779b2b6a0ce18fa72d3589e60c3..d9fcb277df96cc787cda738269f8478a908f1877 100644 --- a/Client/clientmain.h +++ b/Client/clientmain.h @@ -23,19 +23,30 @@ public: //发送功能实现 void send(QJsonObject data); //接收并打印的槽函数 - QJsonObject receiveMessage(); + + void receiveMessage(); + void processMethod(QJsonObject data); bool isConnected() { return is_connected; } + void createmainwindow(QJsonObject data); + void MessageFromMainwindow(const QString & sendername, const QString &text); + + void receiveJsonObject(QJsonObject); signals: void serverConnected(); void serverDisconnected(); + void UserDataReceived(QJsonObject data); + void LogInConfirmReceived(QJsonObject data); + void SessionMessageReceived(QJsonObject data); + void FriendSessionDataReceived(QJsonObject data); + void GroupSessionDataReceived(QJsonObject data); + void RegistConfirmReceived(QJsonObject data); private: ClientMain(QString IPAddress, int portOpen, QObject *parent = nullptr); QTcpSocket *socket; bool is_connected = false; void connectToServer(); - QString ipAdd; int port; UserLogin * login; diff --git a/Client/clientmainwindow.cpp b/Client/clientmainwindow.cpp deleted file mode 100644 index 1669bedda0a8176ff57ce9e53c95663dd7e3a6a6..0000000000000000000000000000000000000000 --- a/Client/clientmainwindow.cpp +++ /dev/null @@ -1,15 +0,0 @@ -#include "clientmainwindow.h" -#include "ui_clientmainwindow.h" - -ClientMainWindow::ClientMainWindow(QWidget *parent) - : QMainWindow(parent) - , ui(new Ui::ClientMainWindow) -{ - ui->setupUi(this); -} - -ClientMainWindow::~ClientMainWindow() -{ - delete ui; -} - diff --git a/Client/clientmainwindow.h b/Client/clientmainwindow.h deleted file mode 100644 index 2818a2bedd7ba7f73c2a206761408941f36745d8..0000000000000000000000000000000000000000 --- a/Client/clientmainwindow.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef CLIENTMAINWINDOW_H -#define CLIENTMAINWINDOW_H - -#include - -QT_BEGIN_NAMESPACE -namespace Ui { class ClientMainWindow; } -QT_END_NAMESPACE - -class ClientMainWindow : public QMainWindow -{ - Q_OBJECT - -public: - ClientMainWindow(QWidget *parent = nullptr); - ~ClientMainWindow(); - -private: - Ui::ClientMainWindow *ui; -}; -#endif // CLIENTMAINWINDOW_H diff --git a/Client/clientmainwindow.ui b/Client/clientmainwindow.ui deleted file mode 100644 index b785ed0a882765d8b65bad8d98d678f560af877f..0000000000000000000000000000000000000000 --- a/Client/clientmainwindow.ui +++ /dev/null @@ -1,22 +0,0 @@ - - - ClientMainWindow - - - - 0 - 0 - 800 - 600 - - - - ClientMainWindow - - - - - - - - diff --git a/Client/databaseoperation.cpp b/Client/databaseoperation.cpp index c3938c0480bb502814738b934d36fd8c165be20e..18b00dc30332a8b9260c3439d10f3aff2d3d7346 100644 --- a/Client/databaseoperation.cpp +++ b/Client/databaseoperation.cpp @@ -36,6 +36,9 @@ void DatabaseOperation::startDatabaseConnection(QString dbfilename) { } status = Status::Running; createTables(); + findAllUsers(); + findAllSessions(); + findAllMessages(); } void DatabaseOperation::executeSqlStatement(QString str) { @@ -46,7 +49,10 @@ void DatabaseOperation::executeSqlStatement(QString str) { } void DatabaseOperation::createTables() { - executeSqlStatement("CREATE TABLE Message(SessionID INT NOT NULL, MessageID INT NOT NULL, SenderUsername TEXT NOT NULL, MessageText TEXT NOT NULL, Profile TEXT, PRIMARY KEY(SessionID, MessageID))"); + executeSqlStatement("CREATE TABLE User(Username TEXT PRIMARY KEY, Nickname TEXT NOT NULL, Password TEXT NOT NULL, Profile TEXT)"); + executeSqlStatement("CREATE TABLE Session(SessionID INT PRIMARY KEY, SessionType TEXT NOT NULL, Profile TEXT)"); + executeSqlStatement("CREATE TABLE Message(SessionID INT NOT NULL, MessageID INT NOT NULL, SenderUsername TEXT NOT NULL, MessageText TEXT NOT NULL, Profile TEXT, PRIMARY KEY(SessionID, MessageID), FOREIGN KEY(SessionID) REFERENCES Session(SessionID), FOREIGN KEY(SenderUserName) REFERENCES User(Username))"); + executeSqlStatement("CREATE TABLE IsMember(SessionID INT NOT NULL, Username TEXT NOT NULL, PRIMARY KEY (Username, SessionID), FOREIGN KEY (Username) REFERENCES User(username), FOREIGN KEY (SessionID) REFERENCES Session(SessionID))"); } bool DatabaseOperation::isDBExist(QString dbfilename) const { @@ -62,6 +68,50 @@ void DatabaseOperation::closeDB() { emit signal_DBstop(); } +QList DatabaseOperation::findAllUsers() { + QList ret; + QSqlQuery sqlQuery; + sqlQuery.exec("SELECT Username, Nickname, Password, Profile FROM User"); + if(!sqlQuery.exec()) { + qDebug() << "Error: Fail to query table. " << sqlQuery.lastError(); + throw "DB read error"; + } + else { + auto & dcenter = ServerDataCenter::Singleton(); + while(sqlQuery.next()) { + QString username = sqlQuery.value(0).toString(); + QString nickname = sqlQuery.value(1).toString(); + QString profile = sqlQuery.value(3).toString(); + auto newuser = new OnlineUserModel(username, nickname, str2json(profile)); + dcenter.registerUser(newuser); + ret.append(newuser); + } + } + return ret; +} + +QList DatabaseOperation::findAllSessions() { + QList ret; + QSqlQuery query; + if (!query.exec("SELECT SessionID, Profile FROM Session")) { + qDebug() << "findAllSessions: " << query.lastError(); + throw "DB read error"; + } + auto & dcenter = ServerDataCenter::Singleton(); + while (query.next()) { + int sessionId = query.value(0).toInt(); + QJsonObject json = query.value(1).toJsonObject(); + QString sessionName = json.contains("SessionName") ? json["SessionName"].toString() : "None"; + + QList members = queryMembersBySession(sessionId); + + OnlineSession * session = new OnlineSession(sessionId, sessionName, json, members); + ret.append(session); + dcenter.registerSession(session); + } + return ret; +} + QList DatabaseOperation::findAllMessages() { QList ret; QSqlQuery query; @@ -75,12 +125,43 @@ QList DatabaseOperation::findAllMessages() { auto * msg = new OnlineMessage(msgId, sessionID, sender, text, profile); ret.append(msg); - ClientDataCenter::Singleton().registerMessage(msg); + ServerDataCenter::Singleton().registerMessage(msg); } } return ret; } + +//往User表中插入数据 +bool DatabaseOperation::insertUser(const char* username, const char* nickname, const char* password, const char* profile){ + QSqlQuery query; +// QString insert_sql = "insert into User(Username, Nickname, Password, Profile) values (?, ?, ?, ?)"; + QString insert_sql = QString("INSERT INTO User(Username, Nickname, Password, Profile) VALUES ('%1', '%2', '%3', '%4')") + .arg(username).arg(nickname).arg(password).arg(profile); + qDebug() << insert_sql; +// query.prepare(insert_sql); +// query.addBindValue(username); +// query.addBindValue(nickname); +// query.addBindValue(password); +// query.addBindValue(profile); + if (! query.exec(insert_sql) ) { + qDebug() << query.lastError(); + return false; + } + ServerDataCenter::Singleton().registerUser(new OnlineUserModel( + QString(username), QString(nickname), str2json(QString(profile)) )); + return true; +} + +bool DatabaseOperation::insertUser(const OnlineUserModel &user, const QString &password) { + return insertUser(user.getUsername().toUtf8().data(), + user.getNickname().toUtf8().data(), + password.toUtf8().data(), + json2str(user.getProfile()).toUtf8().data()); +} + + + int DatabaseOperation::getTableCount(const char * tableName) const { QSqlQuery query; QString sql = QString("SELECT COUNT(*) FROM %1").arg(tableName); @@ -93,6 +174,91 @@ int DatabaseOperation::getTableCount(const char * tableName) const { return query.value(0).toInt(); } +//往 Session 表中插入数据 +int DatabaseOperation::insertSessionBasicInfo(const char* SessionType, const char* profile) { + QSqlQuery query; + QString insert_sql = "insert into Session(SessionID, SessionType, Profile) values(?, ?, ?)"; + query.prepare(insert_sql); + int sessionId = getTableCount("Session") + 1; + query.addBindValue(sessionId); + query.addBindValue(SessionType); + query.addBindValue(profile); + if (!query.exec()) { + qDebug() << query.lastError(); + return -1; + } + return sessionId; +} + +bool DatabaseOperation::insertMember(int sessionID, const char * user){ + QSqlQuery query; + QString insert_sql = "insert into IsMember values(?, ?)"; + query.prepare(insert_sql); + query.addBindValue(sessionID); + query.addBindValue(user); + if(!query.exec()){ + qDebug()<<"query error: "< DatabaseOperation::queryMembersBySession(int sessionID){ + QList member_List; + QSqlQuery query; + QString query_sql = "SELECT SessionID, Username FROM IsMember WHERE sessionID = (?)"; + query.prepare(query_sql); + query.addBindValue(sessionID); + if (! query.exec()) { + qDebug() << "error occurred in queryMembersBySession, " << query.lastError(); + return member_List; + } + while(query.next()){ + member_List.append(query.value(1).toString()); + } + return member_List; +} + +//查询用户所拥有的session +QList DatabaseOperation::querySessionsByMember(const char * username){ + QList member_List; + QSqlQuery query; + QString query_sql = "SELECT SessionId, Username FROM IsMember WHERE username = (?)"; + query.prepare(query_sql); + query.addBindValue(username); + query.exec(); + while(query.next()){ + member_List.append(query.value(0).toInt()); + } + return member_List; +} + +bool DatabaseOperation::attemptLogIn(QString username, QString password) const { + //用户名检测 + QSqlQuery query(database); + query.prepare("select username, password from User where username=?"); + query.addBindValue(username); + bool ok = query.exec(); + if(!ok){ + qDebug() << "Fail query register username" << query.lastError(); + return false; + } + if(query.next()){ + //密码检测 + if (query.value(1).toString() == password) + return true; + qDebug() << "password incorrect"; + return false; + } + qDebug() << "Username not found"; + return false; +} + int DatabaseOperation::insertNewMessage(int SessionId, const char *senderUsername, const char *MessageText, const char *profile) { QSqlQuery query; QString sql = "select count (*) from Message WHERE SessionId = ?"; @@ -119,6 +285,93 @@ int DatabaseOperation::insertNewMessage(int SessionId, const char *senderUsernam return msgId; } +QList DatabaseOperation::getMessageListBySessionID(int SessionId) const { + QList ret; + QSqlQuery query; + QString sql = "SELECT MessageID, SenderUsername, MessageText, Profile FROM Message WHERE SessionID = ?"; + query.prepare(sql); + query.addBindValue(SessionId); + if (!query.exec()) { + qDebug() << "getMessageListBySessionID: " << query.lastError(); + throw query.lastError(); + } + while(query.next()) { + int msgId = query.value(0).toInt(); + QString senderUsername = query.value(1).toString(); + QString messageText = query.value(2).toString(); + QJsonObject profile = query.value(3).toJsonObject(); + auto * msg = new OnlineMessage(msgId, SessionId, senderUsername, messageText, profile); + ret.append(msgId); + ServerDataCenter::Singleton().registerMessage(msg); + } + return ret; +} + +OnlineUserModel & ServerDataCenter::getUser(QString username) { + if (_getUser(username) == nullptr) throw "Not exist"; + return *users[username]; +} + +OnlineMessage & ServerDataCenter::getMessage(int SessionId, int MessageId) { + if (_getMessage(SessionId, MessageId)) throw "Not exist"; + return *messages[{SessionId, MessageId}]; +} + +OnlineSession & ServerDataCenter::getSession(int SessionId) { + if (_getSession(SessionId) == nullptr) throw "Not exist"; + return *sessions[SessionId]; +} + +OnlineUserModel* ServerDataCenter::_getUser(QString username) { + if (users.contains(username)) + return users[username]; + return nullptr; +} + +OnlineSession* ServerDataCenter::_getSession(int SessionId) { + if (sessions.contains(SessionId)) + return sessions[SessionId]; + return nullptr; +} + +OnlineMessage* ServerDataCenter::_getMessage(int SessionId, int MessageId) { + if (messages.contains({SessionId, MessageId})) + return messages[{SessionId, MessageId}]; + return nullptr; +} + +OnlineUserModel * DatabaseOperation::findUser(QString username) { + QSqlQuery query; + QString sql = "SELECT Username, Nickname, Profile FROM User WHERE Username = " + username; + if (!query.exec(sql) || !query.first()) { + qDebug() << "DBOps::findUser: " << query.lastError(); + return nullptr; + } + OnlineUserModel * ret = new OnlineUserModel(query.value(0).toString(), + query.value(1).toString(), + query.value(2).toJsonObject()); + ServerDataCenter::Singleton().registerUser(ret); + return ret; +} + +OnlineSession * DatabaseOperation::findSession(int sessionID) { + QSqlQuery query; + QString sql = "SELECT Profile FROM Session WHERE SessionID = " + QString::number(sessionID); + if (!query.exec(sql) || !query.first()) { + qDebug() << "DBOps::findSession: " << query.lastError(); + return nullptr; + } + + auto json = query.value(0).toJsonObject(); + QString SessionName = json.contains("SessionName") ? + json["SessionName"].toString() : "None"; + + OnlineSession * ret = new OnlineSession(sessionID, SessionName, json, + queryMembersBySession(sessionID)); + ServerDataCenter::Singleton().registerSession(ret); + return ret; +} + OnlineMessage * DatabaseOperation::findMessage(int sessionId, int MessageId) { QSqlQuery query; QString sql = "SELECT SenderUsername, MessageText, Profile FROM Message WHERE SessionID = " + @@ -132,11 +385,11 @@ OnlineMessage * DatabaseOperation::findMessage(int sessionId, int MessageId) { QString text = query.value(1).toString(); QJsonObject json = query.value(2).toJsonObject(); OnlineMessage * ret = new OnlineMessage(MessageId, sessionId, sender, text, json); - ClientDataCenter::Singleton().registerMessage(ret); + ServerDataCenter::Singleton().registerMessage(ret); return ret; } -ClientDataCenter::ClientDataCenter(QObject *parent) : QObject(parent) +ServerDataCenter::ServerDataCenter(QObject *parent) : QObject(parent) { - connect(&DatabaseOperation::Singleton(), &DatabaseOperation::signal_DBstop, this, &ClientDataCenter::clean); + connect(&DatabaseOperation::Singleton(), &DatabaseOperation::signal_DBstop, this, &ServerDataCenter::clean); } diff --git a/Client/databaseoperation.h b/Client/databaseoperation.h index 8b655d090186430785a909c1740636516712b3c0..58dd26847e197e86027cc463d3693049713e0f8a 100644 --- a/Client/databaseoperation.h +++ b/Client/databaseoperation.h @@ -27,12 +27,28 @@ public: bool isDBExist(QString str) const; void closeDB(); int getTableCount(const char * tableName) const; + QList findAllUsers(); + QList findAllSessions(); QList findAllMessages(); + bool insertUser(const char * username, const char * nickname, const char * password, const char * profile); + bool insertUser(const OnlineUserModel &user, const QString &password); + // 返回SessionID + int insertSessionBasicInfo(const char * sessionType, const char * profile); bool isRunning() const { return status == Status::Running; } + bool insertMember(int sessionID, const OnlineUserModel& user); + bool insertMember(int sessionID, const char * user); int insertNewMessage(int SessionId, const char *senderUsername, const char *MessageText, const char *profile); + QList queryMembersBySession(int sessionID); + QList querySessionsByMember(const char * username); + QList getMessageListBySessionID(int SessionId) const; + OnlineMessage * findMessage(int sessionId, int MessageId); + OnlineSession * findSession(int sessionID); + OnlineUserModel * findUser(QString username); + + bool attemptLogIn(QString username, QString password) const; signals: void signal_DBstop(); diff --git a/Client/kuang.ui b/Client/kuang.ui index ae630d6a6cfe9c01c77c298d32598df8e46ba043..113ff0c5d708ecb554c98dbf63464f40af349e23 100644 --- a/Client/kuang.ui +++ b/Client/kuang.ui @@ -1,51 +1,51 @@ - - - Kuang - - - - 0 - 0 - 200 - 75 - - - - Form - - - - - 10 - 15 - 50 - 50 - - - - QFrame::Box - - - profile - - - - - - 70 - 20 - 121 - 41 - - - - QFrame::Box - - - name - - - - - - + + + Kuang + + + + 0 + 0 + 200 + 75 + + + + Form + + + + + 10 + 15 + 50 + 50 + + + + QFrame::Box + + + profile + + + + + + 70 + 20 + 121 + 41 + + + + QFrame::Box + + + name + + + + + + diff --git a/Client/main.cpp b/Client/main.cpp index 743a7fb870f775e0c7d56721a011b9a982c5b25a..e83dba1b0c2a8971dff89cc30796d001be3db3cf 100644 --- a/Client/main.cpp +++ b/Client/main.cpp @@ -3,7 +3,6 @@ #include "mainwindow.h" #include "kuang.h" #include "clientmain.h" - #include #include #include diff --git a/Client/mainwindow.cpp b/Client/mainwindow.cpp index 86d54840f1295ee9d282bffdd81cc3b2a39f35c6..671fdf142d0eb23dc86de951b4908e8fb5f8b53f 100644 --- a/Client/mainwindow.cpp +++ b/Client/mainwindow.cpp @@ -2,13 +2,6 @@ #include "ui_mainwindow.h" #include #include -MainWindow::MainWindow(QWidget *parent) : - QMainWindow(parent), - ui(new Ui::MainWindow) -{ - ui->setupUi(this); -} - MainWindow::~MainWindow() { delete ui; @@ -94,13 +87,16 @@ void MainWindow::resizeEvent(QResizeEvent *event) } } -MainWindow::MainWindow(QWidget *parent,QJsonObject data): +MainWindow::MainWindow(QWidget *parent): QMainWindow(parent), ui(new Ui::MainWindow) { ui->setupUi(this); friendlayout = new QVBoxLayout(ui->frd); grouplayout = new QVBoxLayout(ui->group); +} + +void MainWindow::setup(QJsonObject data) { ui->userNameShow->setText(data["Username"].toString()); username = data["Username"].toString(); ui->nickNameShow->setText(data["Nickname"].toString()); diff --git a/Client/mainwindow.h b/Client/mainwindow.h index 46dbc53d8ed4b416ff81e06108442ee7687f678c..e031ce5b0329f11b5bf6a37df050f4379ef5a724 100644 --- a/Client/mainwindow.h +++ b/Client/mainwindow.h @@ -22,8 +22,9 @@ class MainWindow : public QMainWindow public: explicit MainWindow(QWidget *parent = nullptr); - MainWindow(QWidget *parent,QJsonObject data); + // MainWindow(QWidget *parent,QJsonObject data); ~MainWindow(); + void setup(QJsonObject data); void FriendSessionAdd(QJsonObject data); void GroupSessionAdd(QJsonObject data); QVBoxLayout *friendlayout; diff --git a/Client/mainwindow.ui b/Client/mainwindow.ui index de275638539310ce15685d0700442786e6c50a27..3f81eaaeafb9bf2ce6047e211689d1433acffd4b 100644 --- a/Client/mainwindow.ui +++ b/Client/mainwindow.ui @@ -1,581 +1,581 @@ - - - MainWindow - - - - 0 - 0 - 810 - 642 - - - - BICQ - - - - - 0 - - - 0 - - - 0 - - - 0 - - - - - QTabWidget::West - - - QTabWidget::Rounded - - - 0 - - - false - - - false - - - - Personal - - - - 0 - - - 0 - - - 0 - - - 0 - - - 0 - - - - - Qt::Horizontal - - - - 40 - 20 - - - - - - - - Qt::Horizontal - - - - 40 - 20 - - - - - - - - - 400 - 400 - - - - - - 100 - 200 - 200 - 40 - - - - QFrame::Box - - - - - - - - - 100 - 270 - 200 - 40 - - - - QFrame::Box - - - nick name - - - - - - 100 - 130 - 200 - 40 - - - - QFrame::Box - - - user name - - - - - - 150 - 0 - 120 - 120 - - - - profile photo - - - - - - 100 - 330 - 200 - 40 - - - - - - - - - Qt::Vertical - - - - 20 - 40 - - - - - - - - Qt::Vertical - - - - 20 - 40 - - - - - - - - - Chatting - - - - 0 - - - 0 - - - 0 - - - 0 - - - 0 - - - - - - 570 - 0 - - - - - 0 - - - 0 - - - 0 - - - 0 - - - 0 - - - - - - 370 - 30 - - - - - 16777215 - 30 - - - - QFrame::NoFrame - - - QFrame::Plain - - - - - - - - 0 - 0 - - - - - 370 - 350 - - - - QListWidget{background-color: rgb(247, 247, 247); color:rgb(51,51,51); border: 1px solid rgb(247, 247, 247);outline:0px;} -QListWidget::Item{background-color: rgb(247, 247, 247);} -QListWidget::Item:hover{background-color: rgb(247, 247, 247); } -QListWidget::item:selected{ - background-color: rgb(247, 247, 247); - color:black; - border: 1px solid rgb(247, 247, 247); -} -QListWidget::item:selected:!active{border: 1px solid rgb(247, 247, 247); background-color: rgb(247, 247, 247); color:rgb(51,51,51); } - - - QFrame::NoFrame - - - Qt::ScrollBarAsNeeded - - - Qt::ScrollBarAlwaysOff - - - - - - - - 0 - 40 - - - - - 0 - - - 0 - - - 0 - - - 0 - - - 0 - - - - - Qt::Horizontal - - - - 40 - 20 - - - - - - - - - 0 - 30 - - - - send - - - - - - - Qt::Horizontal - - - QSizePolicy::Fixed - - - - 40 - 20 - - - - - - - - - - - - 370 - 150 - - - - - 16777215 - 150 - - - - QFrame::NoFrame - - - QFrame::Plain - - - - - - - - - - - 200 - 0 - - - - - 200 - 16777215 - - - - 0 - - - - - 0 - 0 - 200 - 551 - - - - friend - - - - - - - 0 - 0 - 200 - 551 - - - - group - - - - 0 - - - 0 - - - 0 - - - 0 - - - - - - - 0 - 0 - 200 - 551 - - - - top session - - - - 0 - - - 0 - - - 0 - - - 0 - - - - - - - - - - Search - - - - 0 - - - 0 - - - 0 - - - 0 - - - 0 - - - - - Qt::Vertical - - - - 20 - 40 - - - - - - - - - 500 - 40 - - - - - - - - - - - Qt::Horizontal - - - - 40 - 20 - - - - - - - - Qt::Horizontal - - - - 40 - 20 - - - - - - - - Qt::Vertical - - - - 20 - 40 - - - - - - - - - - - - - - + + + MainWindow + + + + 0 + 0 + 810 + 642 + + + + BICQ + + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + QTabWidget::West + + + QTabWidget::Rounded + + + 0 + + + false + + + false + + + + Personal + + + + 0 + + + 0 + + + 0 + + + 0 + + + 0 + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + + 400 + 400 + + + + + + 100 + 200 + 200 + 40 + + + + QFrame::Box + + + + + + + + + 100 + 270 + 200 + 40 + + + + QFrame::Box + + + nick name + + + + + + 100 + 130 + 200 + 40 + + + + QFrame::Box + + + user name + + + + + + 150 + 0 + 120 + 120 + + + + profile photo + + + + + + 100 + 330 + 200 + 40 + + + + + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + + Chatting + + + + 0 + + + 0 + + + 0 + + + 0 + + + 0 + + + + + + 570 + 0 + + + + + 0 + + + 0 + + + 0 + + + 0 + + + 0 + + + + + + 370 + 30 + + + + + 16777215 + 30 + + + + QFrame::NoFrame + + + QFrame::Plain + + + + + + + + 0 + 0 + + + + + 370 + 350 + + + + QListWidget{background-color: rgb(247, 247, 247); color:rgb(51,51,51); border: 1px solid rgb(247, 247, 247);outline:0px;} +QListWidget::Item{background-color: rgb(247, 247, 247);} +QListWidget::Item:hover{background-color: rgb(247, 247, 247); } +QListWidget::item:selected{ + background-color: rgb(247, 247, 247); + color:black; + border: 1px solid rgb(247, 247, 247); +} +QListWidget::item:selected:!active{border: 1px solid rgb(247, 247, 247); background-color: rgb(247, 247, 247); color:rgb(51,51,51); } + + + QFrame::NoFrame + + + Qt::ScrollBarAsNeeded + + + Qt::ScrollBarAlwaysOff + + + + + + + + 0 + 40 + + + + + 0 + + + 0 + + + 0 + + + 0 + + + 0 + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + + 0 + 30 + + + + send + + + + + + + Qt::Horizontal + + + QSizePolicy::Fixed + + + + 40 + 20 + + + + + + + + + + + + 370 + 150 + + + + + 16777215 + 150 + + + + QFrame::NoFrame + + + QFrame::Plain + + + + + + + + + + + 200 + 0 + + + + + 200 + 16777215 + + + + 0 + + + + + 0 + 0 + 200 + 551 + + + + friend + + + + + + + 0 + 0 + 200 + 551 + + + + group + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + + + 0 + 0 + 200 + 551 + + + + top session + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + + + + + + Search + + + + 0 + + + 0 + + + 0 + + + 0 + + + 0 + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + + 500 + 40 + + + + + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + + + + + + + diff --git a/Client/message.cpp b/Client/message.cpp index 852cc6f191240088af26e278da818254f9998514..c5e2d0b0087101b19ff9cedaa95a19c7e8136831 100644 --- a/Client/message.cpp +++ b/Client/message.cpp @@ -1,212 +1,212 @@ -#include "message.h" -#include -#include -#include -Message::Message(QWidget *parent) : QWidget(parent) -{ - //设置字体 - QFont textFont("MicrosoftYaHei", 12); - this->setFont(textFont); - //缓冲按钮暂定 -} -void Message::setTextContent(QString text, QString time, QSize size, UserType type) -{ - m_msg = text; - m_userType = type; - m_time = time; - m_currentTime = QDateTime::fromTime_t(time.toInt()).toString("hh:mm");//时间戳 自1970年过去的秒数 - m_wholeSize = size; - - //自定义头像 next - m_meRightIcon = QPixmap(":/img/Image/Luffy.png"); - m_otherLeftIcon = QPixmap(":/img/Image/OnePiece.png"); - //需增加名称显示 next -} -QSize Message::setSize(QString str) -{ - int minHeight = 30; //聊天气泡最小高度 - int iconWidth = 40; //头像宽度 同头像长度 - int iconSpaceWidth = 20; //头像与聊天界面框 长度距离 - int iconSpaceHeight= 10; //头像与聊天界面框 高度距离 - int iconRectWidth = 5; //头像与小三角的距离 - int triWidth = 6; //小三角长度 - int kuangTMP = 20; //?????? - int textSpaceWidth = 12; //聊天气泡中 文本距单侧气泡框空白宽度 - - m_msg = str; - m_kuangWidth = this->width() - kuangTMP - 2 * (iconWidth + iconSpaceWidth + iconRectWidth); - m_textWidth = m_kuangWidth - 2 * textSpaceWidth; - m_spaceWidth = this->width() - m_textWidth; - m_iconLeftRect = QRect(iconSpaceWidth, iconSpaceHeight, iconWidth, iconWidth); - m_iconRightRect = QRect(this->width() - iconSpaceWidth - iconWidth, iconSpaceHeight, iconWidth, iconWidth); - - QSize size = getStringSize(m_msg); - int height = size.height() < minHeight ? minHeight : size.height(); - - m_triLeftRect = QRect(iconSpaceWidth + iconWidth + iconRectWidth, m_lineHeight / 2, triWidth, height - m_lineHeight);//??? - m_triRightRect = QRect(this->width() - iconSpaceWidth - iconWidth - iconRectWidth - triWidth, m_lineHeight / 2, triWidth, height - m_lineHeight); - - if(size.width() < (m_textWidth + m_spaceWidth)) - { - m_kuangLeftRect.setRect(m_triLeftRect.x() + m_triLeftRect.width(), m_lineHeight / 4 * 3, size.width() - m_spaceWidth + 2 * textSpaceWidth, height - m_lineHeight); - m_kuangRightRect.setRect(this->width() - size.width() + m_spaceWidth - 2 * textSpaceWidth - iconWidth - iconSpaceWidth - iconRectWidth - triWidth, - m_lineHeight / 4 * 3, size.width() - m_spaceWidth + 2 * textSpaceWidth, height - m_lineHeight); - } - else - { - m_kuangLeftRect.setRect(m_triLeftRect.x() + m_triLeftRect.width(), m_lineHeight / 4 * 3, m_kuangWidth, height - m_lineHeight); - m_kuangRightRect.setRect(iconWidth + kuangTMP + iconSpaceWidth + iconRectWidth - triWidth, m_lineHeight/4*3, m_kuangWidth, height-m_lineHeight); - } - - m_textLeftRect.setRect(m_kuangLeftRect.x() + textSpaceWidth, m_kuangLeftRect.y() + iconSpaceHeight, - m_kuangLeftRect.width() - 2 * textSpaceWidth, m_kuangLeftRect.height() - 2 * iconSpaceHeight); - m_textRightRect.setRect(m_kuangRightRect.x() + textSpaceWidth, m_kuangRightRect.y() + iconSpaceHeight, - m_kuangRightRect.width() - 2 * textSpaceWidth, m_kuangRightRect.height() -2 * iconSpaceHeight); - - return QSize(size.width(), height); -} -QSize Message::getStringSize(QString str) -{ - QFontMetricsF fm(this->font()); - m_lineHeight = fm.lineSpacing(); //行间距 - int nCount = str.count("\n"); //\n数量 - int textMaxWidth; - - if (nCount == 0) - {//实际文本无换行 - textMaxWidth = fm.width(str);//返回给定文本中字符的宽度 - if (textMaxWidth > m_textWidth) - {//实际文本宽度大于当前可实现文本宽度,根据当前文本宽度重新编辑文本 \n - textMaxWidth = m_textWidth; - int size = m_textWidth / fm.width(" ");//每行字符数 - int num = fm.width(str) / m_textWidth;//需换行数 - nCount += num; - QString strAfter = ""; - for (int i = 0; i < num ; i++) - { - strAfter += str.mid(i * size, (i + 1) * size) + "\n"; - } - str = strAfter; - } - } - else - {//实际文本有换行 - for (int i = 0; i <= nCount; i++) - { - QString strSplit = str.split("\n").at(i);//依据实际文本的换行符进行分割 - textMaxWidth = fm.width(strSplit) > textMaxWidth ? fm.width(strSplit) : textMaxWidth; - if (fm.width(strSplit) > m_textWidth) - { - textMaxWidth = m_textWidth; - int size = m_textWidth / fm.width(" ");//每行字符数 - int num = fm.width(strSplit) / m_textWidth;//需换行数 - num = ((i + num) * fm.width(" ") + fm.width(strSplit)) / m_textWidth; //??? - nCount += num; - QString strAfter = ""; - for (int i = 0; i < num; i++) - { - strAfter += strSplit.mid(i * size, (i + 1) * size) + "\n"; - } - str.replace(strSplit, strAfter);//需改进 - } - } - } - - //换行效果需增强 - return QSize(textMaxWidth + m_spaceWidth, (nCount + 1) * m_lineHeight + 2 * m_lineHeight); -} -void Message::paintEvent(QPaintEvent *event) -{ - Q_UNUSED(event); - - QPainter painter(this); - painter.setRenderHints(QPainter::Antialiasing | QPainter::SmoothPixmapTransform);//消锯齿 - painter.setPen(Qt::NoPen);//无线 - painter.setBrush(QBrush(Qt::gray));//形状填充为 灰色 纯色图案 - - if(m_userType == UserType::userOther) { // 聊天对象 - //放置头像 - painter.drawPixmap(m_iconLeftRect, m_otherLeftIcon); - - //框加边 - QColor col_KuangB(234, 234, 234); - painter.setBrush(QBrush(col_KuangB)); - painter.drawRoundedRect(m_kuangLeftRect.x() - 1,m_kuangLeftRect.y() - 1,m_kuangLeftRect.width()+2,m_kuangLeftRect.height()+2,4,4);//圆角框 - - //框 气泡 - QColor col_Kuang(255,255,255); - painter.setBrush(QBrush(col_Kuang)); - painter.drawRoundedRect(m_kuangLeftRect,4,4); - - //三角 - QPointF points[3] = { - QPointF(m_triLeftRect.x(), 30),// - QPointF(m_triLeftRect.x()+m_triLeftRect.width(), 25), - QPointF(m_triLeftRect.x()+m_triLeftRect.width(), 35), - }; - QPen pen; - pen.setColor(col_Kuang); - painter.setPen(pen); - painter.drawPolygon(points, 3);//画多边形 - - //三角加边 - QPen penSanJiaoBian; - penSanJiaoBian.setColor(col_KuangB); - painter.setPen(penSanJiaoBian); - painter.drawLine(QPointF(m_triLeftRect.x() - 1, 30), QPointF(m_triLeftRect.x()+m_triLeftRect.width(), 24)); - painter.drawLine(QPointF(m_triLeftRect.x() - 1, 30), QPointF(m_triLeftRect.x()+m_triLeftRect.width(), 36)); - - //内容 - QPen penText; - penText.setColor(QColor(51,51,51)); - painter.setPen(penText); - QTextOption option(Qt::AlignLeft | Qt::AlignVCenter);//左对齐 中心水平对齐 - option.setWrapMode(QTextOption::WrapAtWordBoundaryOrAnywhere);//包围字体 - painter.setFont(this->font()); - painter.drawText(m_textLeftRect, m_msg,option);//写文本 - } else if(m_userType == UserType::userMe) { // 自己 - //头像 - - - painter.drawPixmap(m_iconRightRect, m_meRightIcon); - qDebug() << "this->width()" << this->width(); - qDebug() << "x" << m_iconRightRect.x(); - qDebug() << "this->width()" << this->width(); - - //框 - QColor col_Kuang(75,164,242); - painter.setBrush(QBrush(col_Kuang)); - painter.drawRoundedRect(m_kuangRightRect,4,4); - - //三角 - QPointF points[3] = { - QPointF(m_triRightRect.x()+m_triRightRect.width(), 30), - QPointF(m_triRightRect.x(), 25), - QPointF(m_triRightRect.x(), 35), - }; - QPen pen; - pen.setColor(col_Kuang); - painter.setPen(pen); - painter.drawPolygon(points, 3); - - //内容 - QPen penText; - penText.setColor(Qt::white); - painter.setPen(penText); - QTextOption option(Qt::AlignLeft | Qt::AlignVCenter); - option.setWrapMode(QTextOption::WrapAtWordBoundaryOrAnywhere); - painter.setFont(this->font()); - painter.drawText(m_textRightRect,m_msg,option); - } else if(m_userType == UserType::userTime) { // 时间 - QPen penText; - penText.setColor(QColor(153,153,153)); - painter.setPen(penText); - QTextOption option(Qt::AlignCenter); - option.setWrapMode(QTextOption::WrapAtWordBoundaryOrAnywhere); - QFont te_font = this->font(); - te_font.setFamily("MicrosoftYaHei"); - te_font.setPointSize(10); - painter.setFont(te_font); - painter.drawText(this->rect(),m_currentTime,option); - } -} +#include "message.h" +#include +#include +#include +Message::Message(QWidget *parent) : QWidget(parent) +{ + //设置字体 + QFont textFont("MicrosoftYaHei", 12); + this->setFont(textFont); + //缓冲按钮暂定 +} +void Message::setTextContent(QString text, QString time, QSize size, UserType type) +{ + m_msg = text; + m_userType = type; + m_time = time; + m_currentTime = QDateTime::fromTime_t(time.toInt()).toString("hh:mm");//时间戳 自1970年过去的秒数 + m_wholeSize = size; + + //自定义头像 next + m_meRightIcon = QPixmap(":/img/Image/Luffy.png"); + m_otherLeftIcon = QPixmap(":/img/Image/OnePiece.png"); + //需增加名称显示 next +} +QSize Message::setSize(QString str) +{ + int minHeight = 30; //聊天气泡最小高度 + int iconWidth = 40; //头像宽度 同头像长度 + int iconSpaceWidth = 20; //头像与聊天界面框 长度距离 + int iconSpaceHeight= 10; //头像与聊天界面框 高度距离 + int iconRectWidth = 5; //头像与小三角的距离 + int triWidth = 6; //小三角长度 + int kuangTMP = 20; //?????? + int textSpaceWidth = 12; //聊天气泡中 文本距单侧气泡框空白宽度 + + m_msg = str; + m_kuangWidth = this->width() - kuangTMP - 2 * (iconWidth + iconSpaceWidth + iconRectWidth); + m_textWidth = m_kuangWidth - 2 * textSpaceWidth; + m_spaceWidth = this->width() - m_textWidth; + m_iconLeftRect = QRect(iconSpaceWidth, iconSpaceHeight, iconWidth, iconWidth); + m_iconRightRect = QRect(this->width() - iconSpaceWidth - iconWidth, iconSpaceHeight, iconWidth, iconWidth); + + QSize size = getStringSize(m_msg); + int height = size.height() < minHeight ? minHeight : size.height(); + + m_triLeftRect = QRect(iconSpaceWidth + iconWidth + iconRectWidth, m_lineHeight / 2, triWidth, height - m_lineHeight);//??? + m_triRightRect = QRect(this->width() - iconSpaceWidth - iconWidth - iconRectWidth - triWidth, m_lineHeight / 2, triWidth, height - m_lineHeight); + + if(size.width() < (m_textWidth + m_spaceWidth)) + { + m_kuangLeftRect.setRect(m_triLeftRect.x() + m_triLeftRect.width(), m_lineHeight / 4 * 3, size.width() - m_spaceWidth + 2 * textSpaceWidth, height - m_lineHeight); + m_kuangRightRect.setRect(this->width() - size.width() + m_spaceWidth - 2 * textSpaceWidth - iconWidth - iconSpaceWidth - iconRectWidth - triWidth, + m_lineHeight / 4 * 3, size.width() - m_spaceWidth + 2 * textSpaceWidth, height - m_lineHeight); + } + else + { + m_kuangLeftRect.setRect(m_triLeftRect.x() + m_triLeftRect.width(), m_lineHeight / 4 * 3, m_kuangWidth, height - m_lineHeight); + m_kuangRightRect.setRect(iconWidth + kuangTMP + iconSpaceWidth + iconRectWidth - triWidth, m_lineHeight/4*3, m_kuangWidth, height-m_lineHeight); + } + + m_textLeftRect.setRect(m_kuangLeftRect.x() + textSpaceWidth, m_kuangLeftRect.y() + iconSpaceHeight, + m_kuangLeftRect.width() - 2 * textSpaceWidth, m_kuangLeftRect.height() - 2 * iconSpaceHeight); + m_textRightRect.setRect(m_kuangRightRect.x() + textSpaceWidth, m_kuangRightRect.y() + iconSpaceHeight, + m_kuangRightRect.width() - 2 * textSpaceWidth, m_kuangRightRect.height() -2 * iconSpaceHeight); + + return QSize(size.width(), height); +} +QSize Message::getStringSize(QString str) +{ + QFontMetricsF fm(this->font()); + m_lineHeight = fm.lineSpacing(); //行间距 + int nCount = str.count("\n"); //\n数量 + int textMaxWidth; + + if (nCount == 0) + {//实际文本无换行 + textMaxWidth = fm.width(str);//返回给定文本中字符的宽度 + if (textMaxWidth > m_textWidth) + {//实际文本宽度大于当前可实现文本宽度,根据当前文本宽度重新编辑文本 \n + textMaxWidth = m_textWidth; + int size = m_textWidth / fm.width(" ");//每行字符数 + int num = fm.width(str) / m_textWidth;//需换行数 + nCount += num; + QString strAfter = ""; + for (int i = 0; i < num ; i++) + { + strAfter += str.mid(i * size, (i + 1) * size) + "\n"; + } + str = strAfter; + } + } + else + {//实际文本有换行 + for (int i = 0; i <= nCount; i++) + { + QString strSplit = str.split("\n").at(i);//依据实际文本的换行符进行分割 + textMaxWidth = fm.width(strSplit) > textMaxWidth ? fm.width(strSplit) : textMaxWidth; + if (fm.width(strSplit) > m_textWidth) + { + textMaxWidth = m_textWidth; + int size = m_textWidth / fm.width(" ");//每行字符数 + int num = fm.width(strSplit) / m_textWidth;//需换行数 + num = ((i + num) * fm.width(" ") + fm.width(strSplit)) / m_textWidth; //??? + nCount += num; + QString strAfter = ""; + for (int i = 0; i < num; i++) + { + strAfter += strSplit.mid(i * size, (i + 1) * size) + "\n"; + } + str.replace(strSplit, strAfter);//需改进 + } + } + } + + //换行效果需增强 + return QSize(textMaxWidth + m_spaceWidth, (nCount + 1) * m_lineHeight + 2 * m_lineHeight); +} +void Message::paintEvent(QPaintEvent *event) +{ + Q_UNUSED(event); + + QPainter painter(this); + painter.setRenderHints(QPainter::Antialiasing | QPainter::SmoothPixmapTransform);//消锯齿 + painter.setPen(Qt::NoPen);//无线 + painter.setBrush(QBrush(Qt::gray));//形状填充为 灰色 纯色图案 + + if(m_userType == UserType::userOther) { // 聊天对象 + //放置头像 + painter.drawPixmap(m_iconLeftRect, m_otherLeftIcon); + + //框加边 + QColor col_KuangB(234, 234, 234); + painter.setBrush(QBrush(col_KuangB)); + painter.drawRoundedRect(m_kuangLeftRect.x() - 1,m_kuangLeftRect.y() - 1,m_kuangLeftRect.width()+2,m_kuangLeftRect.height()+2,4,4);//圆角框 + + //框 气泡 + QColor col_Kuang(255,255,255); + painter.setBrush(QBrush(col_Kuang)); + painter.drawRoundedRect(m_kuangLeftRect,4,4); + + //三角 + QPointF points[3] = { + QPointF(m_triLeftRect.x(), 30),// + QPointF(m_triLeftRect.x()+m_triLeftRect.width(), 25), + QPointF(m_triLeftRect.x()+m_triLeftRect.width(), 35), + }; + QPen pen; + pen.setColor(col_Kuang); + painter.setPen(pen); + painter.drawPolygon(points, 3);//画多边形 + + //三角加边 + QPen penSanJiaoBian; + penSanJiaoBian.setColor(col_KuangB); + painter.setPen(penSanJiaoBian); + painter.drawLine(QPointF(m_triLeftRect.x() - 1, 30), QPointF(m_triLeftRect.x()+m_triLeftRect.width(), 24)); + painter.drawLine(QPointF(m_triLeftRect.x() - 1, 30), QPointF(m_triLeftRect.x()+m_triLeftRect.width(), 36)); + + //内容 + QPen penText; + penText.setColor(QColor(51,51,51)); + painter.setPen(penText); + QTextOption option(Qt::AlignLeft | Qt::AlignVCenter);//左对齐 中心水平对齐 + option.setWrapMode(QTextOption::WrapAtWordBoundaryOrAnywhere);//包围字体 + painter.setFont(this->font()); + painter.drawText(m_textLeftRect, m_msg,option);//写文本 + } else if(m_userType == UserType::userMe) { // 自己 + //头像 + + + painter.drawPixmap(m_iconRightRect, m_meRightIcon); + qDebug() << "this->width()" << this->width(); + qDebug() << "x" << m_iconRightRect.x(); + qDebug() << "this->width()" << this->width(); + + //框 + QColor col_Kuang(75,164,242); + painter.setBrush(QBrush(col_Kuang)); + painter.drawRoundedRect(m_kuangRightRect,4,4); + + //三角 + QPointF points[3] = { + QPointF(m_triRightRect.x()+m_triRightRect.width(), 30), + QPointF(m_triRightRect.x(), 25), + QPointF(m_triRightRect.x(), 35), + }; + QPen pen; + pen.setColor(col_Kuang); + painter.setPen(pen); + painter.drawPolygon(points, 3); + + //内容 + QPen penText; + penText.setColor(Qt::white); + painter.setPen(penText); + QTextOption option(Qt::AlignLeft | Qt::AlignVCenter); + option.setWrapMode(QTextOption::WrapAtWordBoundaryOrAnywhere); + painter.setFont(this->font()); + painter.drawText(m_textRightRect,m_msg,option); + } else if(m_userType == UserType::userTime) { // 时间 + QPen penText; + penText.setColor(QColor(153,153,153)); + painter.setPen(penText); + QTextOption option(Qt::AlignCenter); + option.setWrapMode(QTextOption::WrapAtWordBoundaryOrAnywhere); + QFont te_font = this->font(); + te_font.setFamily("MicrosoftYaHei"); + te_font.setPointSize(10); + painter.setFont(te_font); + painter.drawText(this->rect(),m_currentTime,option); + } +} diff --git a/Client/message.h b/Client/message.h index 752b77173c95ffa44b92c44e8a552c297322b4dd..90a6b8625eff1671cc5d62d87429e06dfe1d9b05 100644 --- a/Client/message.h +++ b/Client/message.h @@ -1,70 +1,70 @@ -#ifndef MESSAGE_H -#define MESSAGE_H - -#include - -class Message : public QWidget -{ - Q_OBJECT -public: - explicit Message(QWidget *parent = nullptr); - - //发送者类别 - enum UserType - { - userMe, //自己 - userOther, //他人 - userTime, //时间 - }; - - //设置基本属性 - void setTextContent(QString txt, QString t, QSize allSize, UserType type); - //计算聊天气泡的size - QSize setSize(QString str); - //设置文本动态换行 并获取发送文字的size - QSize getStringSize(QString str); - //画图事件 - void paintEvent(QPaintEvent *event); - - inline UserType userType() - { - return m_userType; - } - inline QString text() - { - return m_msg; - } - inline QString time() - { - return m_time; - } - -private: - UserType m_userType; - - QSize m_wholeSize; //总规格 ??? - QString m_msg; - QString m_time; - QString m_currentTime; - - QPixmap m_meRightIcon;//自己 右头像 - QPixmap m_otherLeftIcon;//他人 左头像 - - QRect m_iconRightRect;//左头像 所在矩形 - QRect m_iconLeftRect;//右头像 所在矩形 - QRect m_triRightRect;//右三角 - QRect m_triLeftRect;//左三角 - QRect m_kuangLeftRect;// ??? - QRect m_kuangRightRect;// ??? - QRect m_textLeftRect;//??? - QRect m_textRightRect;//??? - - int m_kuangWidth;//聊天框宽度 ??? - int m_textWidth;//聊天气泡内文本宽度 - int m_spaceWidth;//??? - int m_lineHeight;//基线之间的距离 ??? -signals: - -}; - -#endif // MESSAGE_H +#ifndef MESSAGE_H +#define MESSAGE_H + +#include + +class Message : public QWidget +{ + Q_OBJECT +public: + explicit Message(QWidget *parent = nullptr); + + //发送者类别 + enum UserType + { + userMe, //自己 + userOther, //他人 + userTime, //时间 + }; + + //设置基本属性 + void setTextContent(QString txt, QString t, QSize allSize, UserType type); + //计算聊天气泡的size + QSize setSize(QString str); + //设置文本动态换行 并获取发送文字的size + QSize getStringSize(QString str); + //画图事件 + void paintEvent(QPaintEvent *event); + + inline UserType userType() + { + return m_userType; + } + inline QString text() + { + return m_msg; + } + inline QString time() + { + return m_time; + } + +private: + UserType m_userType; + + QSize m_wholeSize; //总规格 ??? + QString m_msg; + QString m_time; + QString m_currentTime; + + QPixmap m_meRightIcon;//自己 右头像 + QPixmap m_otherLeftIcon;//他人 左头像 + + QRect m_iconRightRect;//左头像 所在矩形 + QRect m_iconLeftRect;//右头像 所在矩形 + QRect m_triRightRect;//右三角 + QRect m_triLeftRect;//左三角 + QRect m_kuangLeftRect;// ??? + QRect m_kuangRightRect;// ??? + QRect m_textLeftRect;//??? + QRect m_textRightRect;//??? + + int m_kuangWidth;//聊天框宽度 ??? + int m_textWidth;//聊天气泡内文本宽度 + int m_spaceWidth;//??? + int m_lineHeight;//基线之间的距离 ??? +signals: + +}; + +#endif // MESSAGE_H diff --git a/Client/message.ui b/Client/message.ui index dd2ffdd1c3e98c2474796fe39e894be0764a9877..852d53907862053818b254fb7b6e33ffd7f0f26b 100644 --- a/Client/message.ui +++ b/Client/message.ui @@ -1,45 +1,45 @@ - - - Message - - - - 0 - 0 - 390 - 152 - - - - Form - - - - - 50 - 30 - 101 - 101 - - - - TextLabel - - - - - - 190 - 70 - 131 - 16 - - - - TextLabel - - - - - - + + + Message + + + + 0 + 0 + 390 + 152 + + + + Form + + + + + 50 + 30 + 101 + 101 + + + + TextLabel + + + + + + 190 + 70 + 131 + 16 + + + + TextLabel + + + + + + diff --git a/Client/messagemodel.h b/Client/messagemodel.h index a160369d7bfdc2684c7aa4a6ad7ebf85686094bf..9b68c6276d3263c83067731efce6f9289e66a955 100644 --- a/Client/messagemodel.h +++ b/Client/messagemodel.h @@ -34,6 +34,7 @@ protected: int sessionID; QString text; QJsonObject profile; + }; class OnlineMessage : public MessageModel diff --git a/Client/rsc.qrc b/Client/rsc.qrc index f793e769ea534498754c98b44a0a2547307069aa..bb0c06fb64c84e4f426de300bf0536769987f94e 100644 --- a/Client/rsc.qrc +++ b/Client/rsc.qrc @@ -1,15 +1,15 @@ - - - Image/butterfly.png - Image/butterfly1.png - Image/down.png - Image/Frame.jpg - Image/Luffy.png - Image/LuffyQ.png - Image/mario.gif - Image/OnePiece.png - Image/Sunny.jpg - Image/sunny.png - Image/up.png - - + + + Image/butterfly.png + Image/butterfly1.png + Image/down.png + Image/Frame.jpg + Image/Luffy.png + Image/LuffyQ.png + Image/mario.gif + Image/OnePiece.png + Image/Sunny.jpg + Image/sunny.png + Image/up.png + + diff --git a/Client/userlogin.cpp b/Client/userlogin.cpp index b1d64ea4e83f94a8789bfd32eb9f710de7680103..dbfadfcb3694cf5c75c338243828b27af029d9ba 100644 --- a/Client/userlogin.cpp +++ b/Client/userlogin.cpp @@ -1,12 +1,16 @@ #include "userlogin.h" #include "ui_userlogin.h" -#include "clientmain.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + -#include -#include -#include -#include -#include UserLogin::UserLogin(QWidget *parent) : QWidget(parent) @@ -14,69 +18,28 @@ UserLogin::UserLogin(QWidget *parent) { ui->setupUi(this); - ClientMain & client = ClientMain::Singleton(); - //点击注册跳转到注册页面 connect(ui->btnRegister,&QPushButton::clicked,[=](){ this->hide(); emit registerButtonClicked(); }); - //注册完成后显示登录界面 -// connect(regist,&UserRegister::registfinished,this,&UserLogin::show); - - connect(&client, &ClientMain::serverConnected, this, [=]() { - this->ui->btnLogIn->setEnabled(true); - this->ui->btnRegister->setEnabled(true); - }); - - connect(&client, &ClientMain::serverDisconnected, this, [=]() { - this->ui->btnLogIn->setEnabled(false); - this->ui->btnRegister->setEnabled(false); - }); //点击登录发送登录信息 connect(ui->btnLogIn,&QPushButton::clicked,[=](){ - qDebug() << "Login Button Clicked"; - QJsonObject login = { {"Username",ui->lineEditUserName->text()},{"Password",ui->lineEditPassword->text()}}; - //send发送登录信息 -// send() + QJsonObject login = { {"MsgType","LogIn"},{"Username",ui->lineEditUserName->text()},{"Password",ui->lineEditPassword->text()}}; + emit sendlogindata(login); }); + main = new MainWindow(this); } -void UserLogin::login(QJsonObject data){ +void UserLogin::loginconfirm(QJsonObject data){ if(data["IsLegal"].toBool()==false){ QMessageBox::critical(this,"Error!","用户名或密码有误"); } else{ this->hide(); -// main = new MainWindow(this,data); -// main->show(); -// connect(this,&UserLogin::FriendSessionDataReceived,this->main,&MainWindow::FriendSessionAdd); -// connect(this,&UserLogin::GroupSessionDataReceived,this->main,&MainWindow::GroupSessionAdd); -// connect(this,&UserLogin::SessionMessageReceived,this->main,&MainWindow::AddMessage); - } -} - -void UserLogin::receivemessage(QJsonObject data){ - if(data["MsgType"].toString()=="UserData"){ - emit UserDataReceived(data); - } - if(data["MsgType"].toString()=="LogInConfirm"){ - this->login(data); - } -// if(data["MsgType"].toString()=="RegistConfirm"){ -// regist->registconfirm(data); -// } - if(data["MsgType"].toString()=="SessionMessage"){ - emit SessionMessageReceived(data); - } - if(data["MsgType"].toString()=="SessionData"){ - if(data["SessionType"].toString()=="FRIEND"){ - emit FriendSessionDataReceived(data); - } - if(data["SessionType"].toString()=="GROUP"){ - emit GroupSessionDataReceived(data); - } + main->setup(data); + main->show(); } } @@ -84,3 +47,4 @@ UserLogin::~UserLogin() { delete ui; } + diff --git a/Client/userlogin.h b/Client/userlogin.h index d3f8229986bbce7b889326e3bee3e3c061114e27..21fc7ab4176d27e6bed0042ece4067c8cd4b85a0 100644 --- a/Client/userlogin.h +++ b/Client/userlogin.h @@ -2,8 +2,10 @@ #define USERLOGIN_H #include -#include -#include +#include +#include +#include +#include QT_BEGIN_NAMESPACE namespace Ui { class UserLogin; } @@ -14,29 +16,13 @@ class UserLogin : public QWidget Q_OBJECT public: + MainWindow *main = nullptr; UserLogin(QWidget *parent = nullptr); ~UserLogin(); - void receivemessage(QJsonObject data); - void login(QJsonObject data); - //打包login请求 - QJsonObject wrapLoginRequest(); - //连接功能实现 - void connectserver(); - + void loginconfirm(QJsonObject data); signals: void registerButtonClicked(); - - - void UserDataReceived(QJsonObject data); - void LogInConfirmReceived(QJsonObject data); - void SessionMessageReceived(QJsonObject data); - void FriendSessionDataReceived(QJsonObject data); - void GroupSessionDataReceived(QJsonObject data); - void RegistConfirmReceived(QJsonObject data); - -public slots: - - + void sendlogindata(QJsonObject data); private: Ui::UserLogin *ui; }; diff --git a/Client/userlogin.ui b/Client/userlogin.ui index fea7c7b63e40f150df1a7c4807d6f8bcce5590b7..7131e2e7f85e765bf0f682e5d4193e801b028675 100644 --- a/Client/userlogin.ui +++ b/Client/userlogin.ui @@ -155,9 +155,6 @@ - - false - 登录 @@ -178,9 +175,6 @@ - - false - 注册 diff --git a/Client/userregister.cpp b/Client/userregister.cpp index b564f5d3a7c245d31707447e61c030d03bff8958..00314df5fd56ea898c7aa4e2da2edf6ad0004ad8 100644 --- a/Client/userregister.cpp +++ b/Client/userregister.cpp @@ -32,10 +32,11 @@ UserRegister::UserRegister(QWidget *parent) : } if(islegal){ //发送注册信息 - registration_info.insert("username",ui->lERgUserName->text()); - registration_info.insert("nickname",ui->lERgNickName->text()); - registration_info.insert("password",ui->lERgPassword1->text()); - //send + registration_info.insert("Username",ui->lERgUserName->text()); + registration_info.insert("Nickname",ui->lERgNickName->text()); + registration_info.insert("Password",ui->lERgPassword1->text()); + registration_info.insert("MsgType","Regist"); + sendregistdata(registration_info); } }); } diff --git a/Client/userregister.h b/Client/userregister.h index e1231bf90c3ca0b159a4ec89438d83c305361dee..7159159c357800a0ba688bdaba37f5017bd72346 100644 --- a/Client/userregister.h +++ b/Client/userregister.h @@ -3,6 +3,7 @@ #include #include +#include namespace Ui { class UserRegister; @@ -19,7 +20,7 @@ public: signals: void registfinished(); - + void sendregistdata(QJsonObject data); private: Ui::UserRegister *ui; }; diff --git a/Client/userregister.ui b/Client/userregister.ui index 7ffabdacf4f3b5595fc1d66d46579a93937296f5..4e1ef100d14c5702c2cafe0599f2a6ac1ad1b392 100644 --- a/Client/userregister.ui +++ b/Client/userregister.ui @@ -1,100 +1,100 @@ - - - UserRegister - - - - 0 - 0 - 400 - 300 - - - - - 400 - 300 - - - - - 400 - 300 - - - - Form - - - - - 90 - 70 - 226 - 128 - - - - - - - 用户名 - - - - - - - - - - 昵称 - - - - - - - QLineEdit::Password - - - - - - - 密码 - - - - - - - - - - 确认密码 - - - - - - - - - - - - 170 - 220 - 84 - 24 - - - - 注册 - - - - - - + + + UserRegister + + + + 0 + 0 + 400 + 300 + + + + + 400 + 300 + + + + + 400 + 300 + + + + Form + + + + + 90 + 70 + 226 + 128 + + + + + + + 用户名 + + + + + + + + + + 昵称 + + + + + + + QLineEdit::Password + + + + + + + 密码 + + + + + + + + + + 确认密码 + + + + + + + + + + + + 170 + 220 + 84 + 24 + + + + 注册 + + + + + + diff --git a/Database_Lyh/Database_Lyh.pro b/Database_Lyh/Database_Lyh.pro new file mode 100644 index 0000000000000000000000000000000000000000..eb7a5ba37e7786a47cf9faa00f449545f77b74f4 --- /dev/null +++ b/Database_Lyh/Database_Lyh.pro @@ -0,0 +1,31 @@ +QT += core gui sql + +greaterThan(QT_MAJOR_VERSION, 4): QT += widgets + +CONFIG += c++11 + +# The following define makes your compiler emit warnings if you use +# any Qt feature that has been marked deprecated (the exact warnings +# depend on your compiler). Please consult the documentation of the +# deprecated API in order to know how to port your code away from it. +DEFINES += QT_DEPRECATED_WARNINGS + +# You can also make your code fail to compile if it uses deprecated APIs. +# In order to do so, uncomment the following line. +# You can also select to disable deprecated APIs only up to a certain version of Qt. +#DEFINES += QT_DISABLE_DEPRECATED_BEFORE=0x060000 # disables all the APIs deprecated before Qt 6.0.0 + +SOURCES += \ + main.cpp \ + widget.cpp + +HEADERS += \ + widget.h + +FORMS += \ + widget.ui + +# Default rules for deployment. +qnx: target.path = /tmp/$${TARGET}/bin +else: unix:!android: target.path = /opt/$${TARGET}/bin +!isEmpty(target.path): INSTALLS += target diff --git a/Database_Lyh/main.cpp b/Database_Lyh/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e57e09ca21147eeccd4fff1bea4273e8d8a46b28 --- /dev/null +++ b/Database_Lyh/main.cpp @@ -0,0 +1,15 @@ +#include "widget.h" + +#include +#include +#include +#include +#include + +int main(int argc, char *argv[]) +{ + QApplication a(argc, argv); + Widget w; + w.show(); + return a.exec(); +} diff --git a/Database_Lyh/widget.cpp b/Database_Lyh/widget.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e32fbbb5f2e18cea962668e11882bfda951f548d --- /dev/null +++ b/Database_Lyh/widget.cpp @@ -0,0 +1,166 @@ +#include "widget.h" +#include "ui_widget.h" +#include +#include +#include +#include +#include +#include + +Widget::Widget(QWidget *parent) + : QWidget(parent) + , ui(new Ui::Widget) +{ + ui->setupUi(this); + createDB(); + createTable (); + insertData_User("124","Li","123456","..."); + insertData_User("123","Wang","123456","..."); + insertData_Dialog("friend","..."); + insertData_Message(1,"123","123to124","..."); + insertData_Message(1,"124","124to123","..."); + insertData_Member(1,"123"); + insertData_Member(2,"124"); +// insertData_Alldialog("123",1); +// insertData_Alldialog("124",1); + queryTable(); + database.close(); +} + +Widget::~Widget() +{ + delete ui; +} + + +//创建数据库 +void Widget::createDB(){ + + if (QSqlDatabase::contains("qt_sql_default_connection")) + { + database = QSqlDatabase::database("qt_sql_default_connection"); + } + else + { + database = QSqlDatabase::addDatabase("QSQLITE"); + database.setDatabaseName("MyDataBase.db"); + } + //建立连接 + if (!database.open()) + { + qDebug() << "Error: Failed to connect database." << database.lastError(); + } + else + { + qDebug() <<"数据库链接成功"; + } +} + + +//创建表格 +void Widget::createTable (void) { + // 构建创建数据表sql语句的字符串 + //先有成员、然后会话,然后才有消息 + //用户表 + QString userstr ("CREATE TABLE User(username TEXT PRIMARY KEY NOT NULL, nickname TEXT NOT NULL, password TEXT NOT NULL, profile TEXT NOT NULL)"); + //消息表 + QString messagestr ("CREATE TABLE Message(messageID INT PRIMARY KEY NOT NULL, sessionID INT NOT NULL, senderUsername TEXT NOT NULL, messageText TEXT NOT NULL, profile TEXT NOT NULL, foreign KEY (senderUsername) references User(username), foreign KEY (sessionID) references Dialog(sessionID))"); + //会话表 + QString dialoguestr ("CREATE TABLE Dialog(sessionID INT PRIMARY KEY NOT NULL, SessionType TEXT NOT NULL, profile TEXT NOT NULL)"); + //会话中成员表 + QString memberstr ("CREATE TABLE Member(sessionID INT NOT NULL, username TEXT NOT NULL)");//, primary key(sessionID,username), foreign KEY (sessionID) references Dialog(sessionID), foreign KEY (username) references User(username) + //用户所拥有的会话表 + //QString alldialogstr ("CREATE TABLE Alldialog(username TEXT NOT NULL, sessionID INT NOT NULL)");//, primary key(username,sessionID), foreign KEY (sessionID) references Dialog(sessionID), foreign KEY (username) references User(username) + // 执行sql语句 + QSqlQuery *query; + query = new QSqlQuery(); + query->exec (userstr); + query->exec (messagestr); + query->exec (dialoguestr); + query->exec (memberstr); + //query->exec (alldialogstr); +} + + +//查询所有User表中的数据 +void Widget::queryTable (void) { + QSqlQuery sqlQuery; + sqlQuery.exec("SELECT * FROM User"); + if(!sqlQuery.exec()) + { + qDebug() << "Error: Fail to query table. " << sqlQuery.lastError(); + } + else + { + while(sqlQuery.next()) + { + QString username = sqlQuery.value(0).toString(); + QString nickname = sqlQuery.value(1).toString(); + QString password = sqlQuery.value(2).toString(); + QString profile = sqlQuery.value(3).toString(); + qDebug()<prepare(insert_sql); + query->addBindValue(username); + query->addBindValue(nickname); + query->addBindValue(password); + query->addBindValue(profile); + query->exec(); +} +//往message表中插入数据 +void Widget::insertData_Message(int sessionID, const char* senderUsername, const char* messageText,const char* profile){ + QSqlQuery *query; + query = new QSqlQuery(); + QString insert_sql = "insert into Message values (?, ?, ?, ?, ?)"; + query->prepare(insert_sql); + query->addBindValue(++maxMessage); + query->addBindValue(sessionID); + query->addBindValue(senderUsername); + query->addBindValue(messageText); + query->addBindValue(profile); + query->exec(); +} +//往Dialog表中插入数据 +void Widget::insertData_Dialog(const char* SessionType, const char* profile){ + QSqlQuery *query; + query = new QSqlQuery(); + QString insert_sql = "insert into Dialog values(?, ?, ?)"; + query->prepare(insert_sql); + query->addBindValue(++maxDioalog); + query->addBindValue(SessionType); + query->addBindValue(profile); + query->exec(); +} +//记录一个会议的参加人员 +void Widget::insertData_Member(int sessionID, const char* username){ + QSqlQuery *query; + query = new QSqlQuery(); + QString insert_sql = "insert into Member values(?, ?)"; + query->prepare(insert_sql); + query->addBindValue(sessionID); + query->addBindValue(username); + if(!query->exec()){ + qDebug()<<"query error: "<lastError(); + } +} +//记录一人参加的会议 +//void Widget::insertData_Alldialog (const char* username, int sessionID){ +// QSqlQuery *query; +// query = new QSqlQuery(); +// QString insert_sql = "insert into Alldialog values(?, ?)"; +// query->addBindValue(username); +// query->addBindValue(sessionID); +// if(!query->exec()){ +// qDebug()<<"query error: "<lastError(); +// } +//} + diff --git a/Database_Lyh/widget.h b/Database_Lyh/widget.h new file mode 100644 index 0000000000000000000000000000000000000000..99a4ee789c1ea2462015031e946f9bfa32754413 --- /dev/null +++ b/Database_Lyh/widget.h @@ -0,0 +1,40 @@ +#ifndef WIDGET_H +#define WIDGET_H + +#include +#include +#include +#include +#include +#include + +QT_BEGIN_NAMESPACE +namespace Ui { class Widget; } +QT_END_NAMESPACE + +class Widget : public QWidget +{ + Q_OBJECT + +public: + Widget(QWidget *parent = nullptr); + ~Widget(); +private: + QSqlDatabase database;// 建立QT程序和数据的连接 + QSqlQueryModel model; // 保存和遍历查询结果 + int maxMessage=0; + int maxDioalog=0; +private: + void createDB (void);//创建数据库 + void createTable (void);//创建数据表 +public: + void queryTable (void);//查询数据 + void insertData_User (const char* username,const char* nickname,const char* password,const char* profile);//往User中插入数据 + void insertData_Message (int sessionID, const char* senderUsername, const char* messageText,const char* profile); + void insertData_Dialog (const char* SessionType, const char* profile); + void insertData_Member (int sessionID, const char* username); + //void insertData_Alldialog (const char* username, int sessionID); +private: + Ui::Widget *ui; +}; +#endif // WIDGET_H diff --git a/Client/client.ui b/Database_Lyh/widget.ui similarity index 60% rename from Client/client.ui rename to Database_Lyh/widget.ui index ced9b5e716efaab88d95caa6ea31fff4f24c8e79..c3fa28a3218bd25cc6bdae9e09c8e49ae36da0a2 100644 --- a/Client/client.ui +++ b/Database_Lyh/widget.ui @@ -1,17 +1,17 @@ - client - + Widget + 0 0 - 400 - 300 + 800 + 600 - Form + Widget diff --git a/README.md b/README.md index 2ac2f2467c338905dc020669c9c2e6273ab4a65c..09ffd554fe00e75cac13d7ca24117389f13e5237 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,19 @@ BIT ICQ, a Realtime Communicating Solution using C++ Qt framework 1. 类名、文件名使用大驼峰命名法,函数名使用小驼峰命名法 2. 仓库在 master 和 Develop 分支上设置了保护,一般工作在feature分支上,如果是可以独立开发的模块,强烈建议单独拉一条新分支出来。每个 working state 会合并到 Develop 分支,release 版本合并到 master 分支 +```json +{ + "MsgType": "JsonArray", + "MsgList": [ + { ... }, + { ... }, + { ... }, + ] +} + +``` + + ### 注册信息 ``` json @@ -139,7 +152,6 @@ BIT ICQ, a Realtime Communicating Solution using C++ Qt framework "Text": "...", "Profile": { "hasMentionInfo": false, - "ReadMark": true / false, "...": "...", } } diff --git a/Server/databaseoperation.cpp b/Server/databaseoperation.cpp index 18b00dc30332a8b9260c3439d10f3aff2d3d7346..be4ec323df4d69e2490b24cb0aa431ccbd9230ab 100644 --- a/Server/databaseoperation.cpp +++ b/Server/databaseoperation.cpp @@ -1,395 +1,399 @@ -#include -#include -#include -#include -#include - -#include "databaseoperation.h" - -QString json2str(QJsonObject json) { - return QString(QJsonDocument(json).toJson()); -} - -QJsonObject str2json(QString str) { - QJsonDocument jsonDoc = QJsonDocument::fromJson(str.toUtf8().data()); - if (jsonDoc.isNull()) { - qDebug() << "read json obj from str failed: str = " << str.toLocal8Bit().data(); - } - return jsonDoc.object(); -} - -DatabaseOperation::DatabaseOperation(QObject *parent) : QObject(parent) -{ - status = Status::Stop; -} - -void DatabaseOperation::startDatabaseConnection(QString dbfilename) { - if (status != Status::Stop) { - qDebug() << "db server already running..."; - throw "Already running error"; - } - QSqlDatabase db = QSqlDatabase::addDatabase("QSQLITE"); - db.setDatabaseName(dbfilename); //如果本目录下没有该文件,则会在本目录下生成,否则连接该文件 - if (!db.open()) { - qDebug() << db.lastError().text(); - throw "Database Error"; - } - status = Status::Running; - createTables(); - findAllUsers(); - findAllSessions(); - findAllMessages(); -} - -void DatabaseOperation::executeSqlStatement(QString str) { - QSqlQuery query (str); - if (!query.isActive()) { - qDebug() << query.lastError(); - } -} - -void DatabaseOperation::createTables() { - executeSqlStatement("CREATE TABLE User(Username TEXT PRIMARY KEY, Nickname TEXT NOT NULL, Password TEXT NOT NULL, Profile TEXT)"); - executeSqlStatement("CREATE TABLE Session(SessionID INT PRIMARY KEY, SessionType TEXT NOT NULL, Profile TEXT)"); - executeSqlStatement("CREATE TABLE Message(SessionID INT NOT NULL, MessageID INT NOT NULL, SenderUsername TEXT NOT NULL, MessageText TEXT NOT NULL, Profile TEXT, PRIMARY KEY(SessionID, MessageID), FOREIGN KEY(SessionID) REFERENCES Session(SessionID), FOREIGN KEY(SenderUserName) REFERENCES User(Username))"); - executeSqlStatement("CREATE TABLE IsMember(SessionID INT NOT NULL, Username TEXT NOT NULL, PRIMARY KEY (Username, SessionID), FOREIGN KEY (Username) REFERENCES User(username), FOREIGN KEY (SessionID) REFERENCES Session(SessionID))"); -} - -bool DatabaseOperation::isDBExist(QString dbfilename) const { - QFileInfo file(dbfilename); - return file.exists(); -} - -void DatabaseOperation::closeDB() { - if (status != Status::Running) throw "already closed"; - database.close(); - - status = Status::Stop; - emit signal_DBstop(); -} - -QList DatabaseOperation::findAllUsers() { - QList ret; - QSqlQuery sqlQuery; - sqlQuery.exec("SELECT Username, Nickname, Password, Profile FROM User"); - if(!sqlQuery.exec()) { - qDebug() << "Error: Fail to query table. " << sqlQuery.lastError(); - throw "DB read error"; - } - else { - auto & dcenter = ServerDataCenter::Singleton(); - while(sqlQuery.next()) { - QString username = sqlQuery.value(0).toString(); - QString nickname = sqlQuery.value(1).toString(); - QString profile = sqlQuery.value(3).toString(); - auto newuser = new OnlineUserModel(username, nickname, str2json(profile)); - dcenter.registerUser(newuser); - ret.append(newuser); - } - } - return ret; -} - -QList DatabaseOperation::findAllSessions() { - QList ret; - QSqlQuery query; - if (!query.exec("SELECT SessionID, Profile FROM Session")) { - qDebug() << "findAllSessions: " << query.lastError(); - throw "DB read error"; - } - auto & dcenter = ServerDataCenter::Singleton(); - while (query.next()) { - int sessionId = query.value(0).toInt(); - QJsonObject json = query.value(1).toJsonObject(); - QString sessionName = json.contains("SessionName") ? json["SessionName"].toString() : "None"; - - QList members = queryMembersBySession(sessionId); - - OnlineSession * session = new OnlineSession(sessionId, sessionName, json, members); - ret.append(session); - dcenter.registerSession(session); - } - return ret; -} - -QList DatabaseOperation::findAllMessages() { - QList ret; - QSqlQuery query; - if (!query.exec("SELECT SessionID, MessageID, senderUsername, MessageText, Profile FROM messages")) { - while (query.next()) { - int sessionID = query.value(0).toInt(); - int msgId = query.value(1).toInt(); - QString sender = query.value(2).toString(); - QString text = query.value(3).toString(); - QJsonObject profile = query.value(4).toJsonObject(); - - auto * msg = new OnlineMessage(msgId, sessionID, sender, text, profile); - ret.append(msg); - ServerDataCenter::Singleton().registerMessage(msg); - } - } - return ret; -} - - -//往User表中插入数据 -bool DatabaseOperation::insertUser(const char* username, const char* nickname, const char* password, const char* profile){ - QSqlQuery query; -// QString insert_sql = "insert into User(Username, Nickname, Password, Profile) values (?, ?, ?, ?)"; - QString insert_sql = QString("INSERT INTO User(Username, Nickname, Password, Profile) VALUES ('%1', '%2', '%3', '%4')") - .arg(username).arg(nickname).arg(password).arg(profile); - qDebug() << insert_sql; -// query.prepare(insert_sql); -// query.addBindValue(username); -// query.addBindValue(nickname); -// query.addBindValue(password); -// query.addBindValue(profile); - if (! query.exec(insert_sql) ) { - qDebug() << query.lastError(); - return false; - } - ServerDataCenter::Singleton().registerUser(new OnlineUserModel( - QString(username), QString(nickname), str2json(QString(profile)) )); - return true; -} - -bool DatabaseOperation::insertUser(const OnlineUserModel &user, const QString &password) { - return insertUser(user.getUsername().toUtf8().data(), - user.getNickname().toUtf8().data(), - password.toUtf8().data(), - json2str(user.getProfile()).toUtf8().data()); -} - - - -int DatabaseOperation::getTableCount(const char * tableName) const { - QSqlQuery query; - QString sql = QString("SELECT COUNT(*) FROM %1").arg(tableName); - query.addBindValue(tableName); - if (!query.exec(sql)) { - qDebug() << query.lastError(); - return -1; - } - if (!query.next()) return -1; - return query.value(0).toInt(); -} - -//往 Session 表中插入数据 -int DatabaseOperation::insertSessionBasicInfo(const char* SessionType, const char* profile) { - QSqlQuery query; - QString insert_sql = "insert into Session(SessionID, SessionType, Profile) values(?, ?, ?)"; - query.prepare(insert_sql); - int sessionId = getTableCount("Session") + 1; - query.addBindValue(sessionId); - query.addBindValue(SessionType); - query.addBindValue(profile); - if (!query.exec()) { - qDebug() << query.lastError(); - return -1; - } - return sessionId; -} - -bool DatabaseOperation::insertMember(int sessionID, const char * user){ - QSqlQuery query; - QString insert_sql = "insert into IsMember values(?, ?)"; - query.prepare(insert_sql); - query.addBindValue(sessionID); - query.addBindValue(user); - if(!query.exec()){ - qDebug()<<"query error: "< DatabaseOperation::queryMembersBySession(int sessionID){ - QList member_List; - QSqlQuery query; - QString query_sql = "SELECT SessionID, Username FROM IsMember WHERE sessionID = (?)"; - query.prepare(query_sql); - query.addBindValue(sessionID); - if (! query.exec()) { - qDebug() << "error occurred in queryMembersBySession, " << query.lastError(); - return member_List; - } - while(query.next()){ - member_List.append(query.value(1).toString()); - } - return member_List; -} - -//查询用户所拥有的session -QList DatabaseOperation::querySessionsByMember(const char * username){ - QList member_List; - QSqlQuery query; - QString query_sql = "SELECT SessionId, Username FROM IsMember WHERE username = (?)"; - query.prepare(query_sql); - query.addBindValue(username); - query.exec(); - while(query.next()){ - member_List.append(query.value(0).toInt()); - } - return member_List; -} - -bool DatabaseOperation::attemptLogIn(QString username, QString password) const { - //用户名检测 - QSqlQuery query(database); - query.prepare("select username, password from User where username=?"); - query.addBindValue(username); - bool ok = query.exec(); - if(!ok){ - qDebug() << "Fail query register username" << query.lastError(); - return false; - } - if(query.next()){ - //密码检测 - if (query.value(1).toString() == password) - return true; - qDebug() << "password incorrect"; - return false; - } - qDebug() << "Username not found"; - return false; -} - -int DatabaseOperation::insertNewMessage(int SessionId, const char *senderUsername, const char *MessageText, const char *profile) { - QSqlQuery query; - QString sql = "select count (*) from Message WHERE SessionId = ?"; - query.prepare(sql); - query.addBindValue(SessionId); - if (!query.exec() || !query.next()) { - qDebug() << "Error Occurred when querying Message Number" << query.lastError(); - return -1; - } - int msgId = query.value(0).toInt() + 1; - qDebug() << "Current MsgId for sessionId = " << msgId; - - sql = "insert into Message(SessionID, MessageID, SenderUsername, MessageText, Profile) VALUES (?, ?, ?, ?, ?)"; - query.prepare(sql); - query.addBindValue(SessionId); - query.addBindValue(msgId); - query.addBindValue(senderUsername); - query.addBindValue(MessageText); - query.addBindValue(profile); - if (!query.exec()) { - qDebug() << "insertNewMessage : " << query.lastError(); - return -1; - } - return msgId; -} - -QList DatabaseOperation::getMessageListBySessionID(int SessionId) const { - QList ret; - QSqlQuery query; - QString sql = "SELECT MessageID, SenderUsername, MessageText, Profile FROM Message WHERE SessionID = ?"; - query.prepare(sql); - query.addBindValue(SessionId); - if (!query.exec()) { - qDebug() << "getMessageListBySessionID: " << query.lastError(); - throw query.lastError(); - } - while(query.next()) { - int msgId = query.value(0).toInt(); - QString senderUsername = query.value(1).toString(); - QString messageText = query.value(2).toString(); - QJsonObject profile = query.value(3).toJsonObject(); - auto * msg = new OnlineMessage(msgId, SessionId, senderUsername, messageText, profile); - ret.append(msgId); - ServerDataCenter::Singleton().registerMessage(msg); - } - return ret; -} - -OnlineUserModel & ServerDataCenter::getUser(QString username) { - if (_getUser(username) == nullptr) throw "Not exist"; - return *users[username]; -} - -OnlineMessage & ServerDataCenter::getMessage(int SessionId, int MessageId) { - if (_getMessage(SessionId, MessageId)) throw "Not exist"; - return *messages[{SessionId, MessageId}]; -} - -OnlineSession & ServerDataCenter::getSession(int SessionId) { - if (_getSession(SessionId) == nullptr) throw "Not exist"; - return *sessions[SessionId]; -} - -OnlineUserModel* ServerDataCenter::_getUser(QString username) { - if (users.contains(username)) - return users[username]; - return nullptr; -} - -OnlineSession* ServerDataCenter::_getSession(int SessionId) { - if (sessions.contains(SessionId)) - return sessions[SessionId]; - return nullptr; -} - -OnlineMessage* ServerDataCenter::_getMessage(int SessionId, int MessageId) { - if (messages.contains({SessionId, MessageId})) - return messages[{SessionId, MessageId}]; - return nullptr; -} - -OnlineUserModel * DatabaseOperation::findUser(QString username) { - QSqlQuery query; - QString sql = "SELECT Username, Nickname, Profile FROM User WHERE Username = " + username; - if (!query.exec(sql) || !query.first()) { - qDebug() << "DBOps::findUser: " << query.lastError(); - return nullptr; - } - OnlineUserModel * ret = new OnlineUserModel(query.value(0).toString(), - query.value(1).toString(), - query.value(2).toJsonObject()); - ServerDataCenter::Singleton().registerUser(ret); - return ret; -} - -OnlineSession * DatabaseOperation::findSession(int sessionID) { - QSqlQuery query; - QString sql = "SELECT Profile FROM Session WHERE SessionID = " + QString::number(sessionID); - if (!query.exec(sql) || !query.first()) { - qDebug() << "DBOps::findSession: " << query.lastError(); - return nullptr; - } - - auto json = query.value(0).toJsonObject(); - QString SessionName = json.contains("SessionName") ? - json["SessionName"].toString() : "None"; - - OnlineSession * ret = new OnlineSession(sessionID, SessionName, json, - queryMembersBySession(sessionID)); - ServerDataCenter::Singleton().registerSession(ret); - return ret; -} - -OnlineMessage * DatabaseOperation::findMessage(int sessionId, int MessageId) { - QSqlQuery query; - QString sql = "SELECT SenderUsername, MessageText, Profile FROM Message WHERE SessionID = " + - QString::number(sessionId) + " and MessageID = " + QString::number(MessageId); - if (!query.exec(sql) || !query.first()) { - qDebug() << "DBOps::findMessage: " << query.lastError(); - return nullptr; - } - - QString sender = query.value(0).toString(); - QString text = query.value(1).toString(); - QJsonObject json = query.value(2).toJsonObject(); - OnlineMessage * ret = new OnlineMessage(MessageId, sessionId, sender, text, json); - ServerDataCenter::Singleton().registerMessage(ret); - return ret; -} - -ServerDataCenter::ServerDataCenter(QObject *parent) : QObject(parent) -{ - connect(&DatabaseOperation::Singleton(), &DatabaseOperation::signal_DBstop, this, &ServerDataCenter::clean); -} +#include +#include +#include +#include +#include + +#include "databaseoperation.h" + +QString json2str(QJsonObject json) { + return QString(QJsonDocument(json).toJson()); +} + +QJsonObject str2json(QString str) { + QJsonDocument jsonDoc = QJsonDocument::fromJson(str.toUtf8().data()); + if (jsonDoc.isNull()) { + qDebug() << "read json obj from str failed: str = " << str.toLocal8Bit().data(); + } + return jsonDoc.object(); +} + +DatabaseOperation::DatabaseOperation(QObject *parent) : QObject(parent) +{ + status = Status::Stop; +} + +void DatabaseOperation::startDatabaseConnection(QString dbfilename) { + if (status != Status::Stop) { + qDebug() << "db server already running..."; + throw "Already running error"; + } + QSqlDatabase db = QSqlDatabase::addDatabase("QSQLITE"); + db.setDatabaseName(dbfilename); //如果本目录下没有该文件,则会在本目录下生成,否则连接该文件 + if (!db.open()) { + qDebug() << db.lastError().text(); + throw "Database Error"; + } + status = Status::Running; + createTables(); + findAllUsers(); + findAllSessions(); + findAllMessages(); + emit signal_DB_ready(); +} + +void DatabaseOperation::executeSqlStatement(QString str) { + QSqlQuery query (str); + if (!query.isActive()) { + qDebug() << query.lastError(); + } +} + +void DatabaseOperation::createTables() { + executeSqlStatement("CREATE TABLE User(Username TEXT PRIMARY KEY, Nickname TEXT NOT NULL, Password TEXT NOT NULL, Profile TEXT)"); + executeSqlStatement("CREATE TABLE Session(SessionID INT PRIMARY KEY, SessionType TEXT NOT NULL, Profile TEXT)"); + executeSqlStatement("CREATE TABLE Message(SessionID INT NOT NULL, MessageID INT NOT NULL, SenderUsername TEXT NOT NULL, MessageText TEXT NOT NULL, Profile TEXT, PRIMARY KEY(SessionID, MessageID), FOREIGN KEY(SessionID) REFERENCES Session(SessionID), FOREIGN KEY(SenderUserName) REFERENCES User(Username))"); + executeSqlStatement("CREATE TABLE IsMember(SessionID INT NOT NULL, Username TEXT NOT NULL, PRIMARY KEY (Username, SessionID), FOREIGN KEY (Username) REFERENCES User(username), FOREIGN KEY (SessionID) REFERENCES Session(SessionID))"); +} + +bool DatabaseOperation::isDBExist(QString dbfilename) const { + QFileInfo file(dbfilename); + return file.exists(); +} + +void DatabaseOperation::closeDB() { + if (status != Status::Running) throw "already closed"; + database.close(); + + status = Status::Stop; + emit signal_DBstop(); +} + +QList DatabaseOperation::findAllUsers() { + QList ret; + QSqlQuery sqlQuery; + sqlQuery.exec("SELECT Username, Nickname, Password, Profile FROM User"); + if(!sqlQuery.exec()) { + qDebug() << "Error: Fail to query table. " << sqlQuery.lastError(); + throw "DB read error"; + } + else { + auto & dcenter = ServerDataCenter::Singleton(); + while(sqlQuery.next()) { + QString username = sqlQuery.value(0).toString(); + QString nickname = sqlQuery.value(1).toString(); + QString profile = sqlQuery.value(3).toString(); + auto newuser = new OnlineUserModel(username, nickname, str2json(profile)); + dcenter.registerUser(newuser); + ret.append(newuser); + } + } + return ret; +} + +QList DatabaseOperation::findAllSessions() { + QList ret; + QSqlQuery query; + if (!query.exec("SELECT SessionID, Profile FROM Session")) { + qDebug() << "findAllSessions: " << query.lastError(); + throw "DB read error"; + } + auto & dcenter = ServerDataCenter::Singleton(); + while (query.next()) { + int sessionId = query.value(0).toInt(); + QJsonObject json = query.value(1).toJsonObject(); + QString sessionName = json.contains("SessionName") ? json["SessionName"].toString() : "None"; + + QList members = queryMembersBySession(sessionId); + + OnlineSession * session = new OnlineSession(sessionId, sessionName, json, members); + ret.append(session); + dcenter.registerSession(session); + } + return ret; +} + +QList DatabaseOperation::findAllMessages() { + QList ret; + QSqlQuery query; + if (!query.exec("SELECT SessionID, MessageID, senderUsername, MessageText, Profile FROM messages")) { + while (query.next()) { + int sessionID = query.value(0).toInt(); + int msgId = query.value(1).toInt(); + QString sender = query.value(2).toString(); + QString text = query.value(3).toString(); + QJsonObject profile = query.value(4).toJsonObject(); + + auto * msg = new OnlineMessage(msgId, sessionID, sender, text, profile); + ret.append(msg); + ServerDataCenter::Singleton().registerMessage(msg); + } + } + return ret; +} + + +//往User表中插入数据 +bool DatabaseOperation::insertUser(const char* username, const char* nickname, const char* password, const char* profile){ + QSqlQuery query; +// QString insert_sql = "insert into User(Username, Nickname, Password, Profile) values (?, ?, ?, ?)"; + QString insert_sql = QString("INSERT INTO User(Username, Nickname, Password, Profile) VALUES ('%1', '%2', '%3', '%4')") + .arg(username).arg(nickname).arg(password).arg(profile); + qDebug() << insert_sql; +// query.prepare(insert_sql); +// query.addBindValue(username); +// query.addBindValue(nickname); +// query.addBindValue(password); +// query.addBindValue(profile); + if (! query.exec(insert_sql) ) { + qDebug() << query.lastError(); + return false; + } + ServerDataCenter::Singleton().registerUser(new OnlineUserModel( + QString(username), QString(nickname), str2json(QString(profile)) )); + return true; +} + +bool DatabaseOperation::insertUser(const OnlineUserModel &user, const QString &password) { + return insertUser(user.getUsername().toUtf8().data(), + user.getNickname().toUtf8().data(), + password.toUtf8().data(), + json2str(user.getProfile()).toUtf8().data()); +} + + + +int DatabaseOperation::getTableCount(const char * tableName) const { + QSqlQuery query; + QString sql = QString("SELECT COUNT(*) FROM %1").arg(tableName); + query.addBindValue(tableName); + if (!query.exec(sql)) { + qDebug() << query.lastError(); + return -1; + } + if (!query.next()) return -1; + return query.value(0).toInt(); +} + +//往 Session 表中插入数据 +int DatabaseOperation::insertSessionBasicInfo(const char* SessionType, const char* profile) { + QSqlQuery query; + QString insert_sql = "insert into Session(SessionID, SessionType, Profile) values(?, ?, ?)"; + query.prepare(insert_sql); + int sessionId = getTableCount("Session") + 1; + query.addBindValue(sessionId); + query.addBindValue(SessionType); + query.addBindValue(profile); + if (!query.exec()) { + qDebug() << query.lastError(); + return -1; + } + return sessionId; +} + +bool DatabaseOperation::insertMember(int sessionID, const char * user){ + QSqlQuery query; + QString insert_sql = "insert into IsMember values(?, ?)"; + query.prepare(insert_sql); + query.addBindValue(sessionID); + query.addBindValue(user); + if(!query.exec()){ + qDebug()<<"query error: "< DatabaseOperation::queryMembersBySession(int sessionID){ + QList member_List; + QSqlQuery query; + QString query_sql = "SELECT SessionID, Username FROM IsMember WHERE sessionID = (?)"; + query.prepare(query_sql); + query.addBindValue(sessionID); + if (! query.exec()) { + qDebug() << "error occurred in queryMembersBySession, " << query.lastError(); + return member_List; + } + while(query.next()){ + member_List.append(query.value(1).toString()); + } + return member_List; +} + +//查询用户所拥有的session +QList DatabaseOperation::querySessionsByMember(const char * username){ + QList member_List; + QSqlQuery query; + QString query_sql = "SELECT SessionId, Username FROM IsMember WHERE username = (?)"; + query.prepare(query_sql); + query.addBindValue(username); + query.exec(); + while(query.next()){ + member_List.append(query.value(0).toInt()); + } + return member_List; +} + +bool DatabaseOperation::attemptLogIn(QString username, QString password) const { + //用户名检测 + QSqlQuery query(database); + query.prepare("select username, password from User where username=?"); + query.addBindValue(username); + bool ok = query.exec(); + if(!ok){ + qDebug() << "Fail query register username" << query.lastError(); + return false; + } + if(query.next()){ + //密码检测 + if (query.value(1).toString() == password) + return true; + qDebug() << "password incorrect"; + return false; + } + qDebug() << "Username not found"; + return false; +} + +int DatabaseOperation::insertNewMessage(int SessionId, const char *senderUsername, const char *MessageText, const char *profile) { + QSqlQuery query; + QString sql = "select count (*) from Message WHERE SessionId = ?"; + query.prepare(sql); + query.addBindValue(SessionId); + if (!query.exec() || !query.next()) { + qDebug() << "Error Occurred when querying Message Number" << query.lastError(); + return -1; + } + int msgId = query.value(0).toInt() + 1; + qDebug() << "Current MsgId for sessionId = " << msgId; + + sql = "insert into Message(SessionID, MessageID, SenderUsername, MessageText, Profile) VALUES (?, ?, ?, ?, ?)"; + query.prepare(sql); + query.addBindValue(SessionId); + query.addBindValue(msgId); + query.addBindValue(senderUsername); + query.addBindValue(MessageText); + query.addBindValue(profile); + if (!query.exec()) { + qDebug() << "insertNewMessage : " << query.lastError(); + return -1; + } + ServerDataCenter::Singleton().registerMessage( + new OnlineMessage(msgId, SessionId, QString(senderUsername), QString(MessageText), str2json(QString(profile))) + ); + return msgId; +} + +QList DatabaseOperation::getMessageListBySessionID(int SessionId) const { + QList ret; + QSqlQuery query; + QString sql = "SELECT MessageID, SenderUsername, MessageText, Profile FROM Message WHERE SessionID = ?"; + query.prepare(sql); + query.addBindValue(SessionId); + if (!query.exec()) { + qDebug() << "getMessageListBySessionID: " << query.lastError(); + throw query.lastError(); + } + while(query.next()) { + int msgId = query.value(0).toInt(); + QString senderUsername = query.value(1).toString(); + QString messageText = query.value(2).toString(); + QJsonObject profile = query.value(3).toJsonObject(); + auto * msg = new OnlineMessage(msgId, SessionId, senderUsername, messageText, profile); + ret.append(msgId); + ServerDataCenter::Singleton().registerMessage(msg); + } + return ret; +} + +OnlineUserModel & ServerDataCenter::getUser(QString username) { + if (_getUser(username) == nullptr) throw "Not exist"; + return *users[username]; +} + +OnlineMessage & ServerDataCenter::getMessage(int SessionId, int MessageId) { + if (_getMessage(SessionId, MessageId)) throw "Not exist"; + return *messages[{SessionId, MessageId}]; +} + +OnlineSession & ServerDataCenter::getSession(int SessionId) { + if (_getSession(SessionId) == nullptr) throw "Not exist"; + return *sessions[SessionId]; +} + +OnlineUserModel* ServerDataCenter::_getUser(QString username) { + if (users.contains(username)) + return users[username]; + return nullptr; +} + +OnlineSession* ServerDataCenter::_getSession(int SessionId) { + if (sessions.contains(SessionId)) + return sessions[SessionId]; + return nullptr; +} + +OnlineMessage* ServerDataCenter::_getMessage(int SessionId, int MessageId) { + if (messages.contains({SessionId, MessageId})) + return messages[{SessionId, MessageId}]; + return nullptr; +} + +OnlineUserModel * DatabaseOperation::findUser(QString username) { + QSqlQuery query; + QString sql = "SELECT Username, Nickname, Profile FROM User WHERE Username = " + username; + if (!query.exec(sql) || !query.first()) { + qDebug() << "DBOps::findUser: " << query.lastError(); + return nullptr; + } + OnlineUserModel * ret = new OnlineUserModel(query.value(0).toString(), + query.value(1).toString(), + query.value(2).toJsonObject()); + ServerDataCenter::Singleton().registerUser(ret); + return ret; +} + +OnlineSession * DatabaseOperation::findSession(int sessionID) { + QSqlQuery query; + QString sql = "SELECT Profile FROM Session WHERE SessionID = " + QString::number(sessionID); + if (!query.exec(sql) || !query.first()) { + qDebug() << "DBOps::findSession: " << query.lastError(); + return nullptr; + } + + auto json = query.value(0).toJsonObject(); + QString SessionName = json.contains("SessionName") ? + json["SessionName"].toString() : "None"; + + OnlineSession * ret = new OnlineSession(sessionID, SessionName, json, + queryMembersBySession(sessionID)); + ServerDataCenter::Singleton().registerSession(ret); + return ret; +} + +OnlineMessage * DatabaseOperation::findMessage(int sessionId, int MessageId) { + QSqlQuery query; + QString sql = "SELECT SenderUsername, MessageText, Profile FROM Message WHERE SessionID = " + + QString::number(sessionId) + " and MessageID = " + QString::number(MessageId); + if (!query.exec(sql) || !query.first()) { + qDebug() << "DBOps::findMessage: " << query.lastError(); + return nullptr; + } + + QString sender = query.value(0).toString(); + QString text = query.value(1).toString(); + QJsonObject json = query.value(2).toJsonObject(); + OnlineMessage * ret = new OnlineMessage(MessageId, sessionId, sender, text, json); + ServerDataCenter::Singleton().registerMessage(ret); + return ret; +} + +ServerDataCenter::ServerDataCenter(QObject *parent) : QObject(parent) +{ + connect(&DatabaseOperation::Singleton(), &DatabaseOperation::signal_DBstop, this, &ServerDataCenter::clean); +} diff --git a/Server/databaseoperation.h b/Server/databaseoperation.h index f990432a9215b960d7d8fb29c6f1466dfa7c20e6..9b90e3e7ed79780bca3e995b3a505dc29b7015cd 100644 --- a/Server/databaseoperation.h +++ b/Server/databaseoperation.h @@ -52,6 +52,7 @@ public: signals: void signal_DBstop(); + void signal_DB_ready(); private: explicit DatabaseOperation(QObject *parent = nullptr); diff --git a/Server/handlesignal.h b/Server/handlesignal.h index 111256e98177cf04ec4aa995b84c0f4a5b51b476..6e1db5aa480ab139327486baa54ad6c37836b2a4 100644 --- a/Server/handlesignal.h +++ b/Server/handlesignal.h @@ -13,6 +13,7 @@ public: signals: void sendSignal(int); + }; #endif // HANDLESIGNAL_H diff --git a/Server/main.cpp b/Server/main.cpp index 5f32ea3ebab6be64e0e192870bc3adc06d6dc23d..e42d9cd220cef7e8516dc27ace312f82ddec218c 100644 --- a/Server/main.cpp +++ b/Server/main.cpp @@ -8,7 +8,6 @@ //#define TESTMODE - int main(int argc, char *argv[]) { #ifdef TESTMODE diff --git a/Server/operations.cpp b/Server/operations.cpp index 1b878eceb3a4f376742532ecd74b0e74a3fac7bb..4768df74511d3cbed6abf277a306b56b41f428fe 100644 --- a/Server/operations.cpp +++ b/Server/operations.cpp @@ -28,6 +28,7 @@ resp Operations::loginResponse(QJsonObject json) { } response["IsLegal"] = true; response["Username"] = username; + auto & user = dcenter.getUser(username); response["Nickname"] = user.getNickname(); response["Profile"] = user.getProfile(); @@ -65,3 +66,20 @@ resp Operations::registerResponse(QJsonObject json) { head["IsLegal"] = true; return {head}; } + +resp Operations::newMessageResponse(QJsonObject json) { + DatabaseOperation & db = DatabaseOperation::Singleton(); + int sessionID = json["SessionID"].toInt(); + QString senderUsername = json["SenderName"].toString(); + auto json1 = json["Body"].toObject(); + QString text = json1["Text"].toString(); + bool mention = json1["Profile"].toObject()["hasMentionInfo"].toBool(); + if (mention) { + throw "Not Implemented"; + } + int msgID = db.insertNewMessage(sessionID, senderUsername.toUtf8().data(), text.toUtf8().data(), json2str(json["Profile"].toObject()).toUtf8().data()); + json["MessageID"] = msgID; + emit newMessage(sessionID, json); + + return resp(); +} diff --git a/Server/operations.h b/Server/operations.h index fc1e3723f1e345d281131f9275eb8ba6d4c4f6d7..86c206d15b5b2f7d25fffc3e8f7901d1385642ec 100644 --- a/Server/operations.h +++ b/Server/operations.h @@ -21,6 +21,8 @@ public: QList request(QJsonObject json); signals: + void newMessage(int sessionId, QJsonObject msg); + public: explicit Operations(QObject *parent = nullptr); diff --git a/Server/sever.cpp b/Server/sever.cpp index 943d40a6feb626c35acc6cb2cba4c44f9ac769f9..4e2739db7ee00556c4ca54cb25c3bccfecb3612b 100644 --- a/Server/sever.cpp +++ b/Server/sever.cpp @@ -6,6 +6,7 @@ #include #include"databaseoperation.h" + Sever::Sever(QObject *parent) : QTcpServer (parent) { @@ -30,9 +31,13 @@ void Sever::incomingConnection(qintptr handle) newHandle->aaa(handle); }); + connect(socket,&QTcpSocket::disconnected,[=](){ + emit offLine (handle); + }); + emit sendChannel (handle); - emit linkMsg (socket->peerAddress ().toString ()+"上线了", 2); //2表示在登录框中显示 + emit linkMsg (socket->peerAddress ().toString ()+"连接成功!"); //2表示在登录框中显示 } @@ -67,17 +72,50 @@ void Sever::receiveMessage(int handle) auto returnList = QList(); auto &op = Operations::Singleton (); + + QString messageSender = ""; //标记消息来源用户名称 if(method == "login") { - returnList = op.loginResponse (recejson); + //首次收到登陆信息 + messageSender = recejson["Username"].toString(); + QMap >::iterator it1 = userToHandle.begin(); + int flag=0; + while(it1!=userToHandle.end ()) + { + if(it1.key ()==messageSender)//并非首次登陆 + { + it1.value ().append (handle); + emit linkMsg (messageSender+" "+QString::number (it1.value ().size ())+" "+"登陆了!"); + qDebug()<< it1.value ().size (); + flag=1; + break; + } + it1++; + } + if(flag==0)//首次登陆 + { + QList ll; + ll.append (handle); + userToHandle.insert (messageSender,ll); + emit linkMsg (messageSender+" 1 "+"登陆了!"); + } + returnList = op.loginResponse (recejson); + emit sendMsg (returnList, messageSender); } else if (method == "regist") { returnList = op.registerResponse(recejson); + emit sendMsg (returnList, messageSender); } else if (method == "info") { - qDebug() << recejson["Message"].toString(); - emit linkMsg(recejson["Message"].toString(), 1); + + emit linkMsg(recejson["Message"].toString()); + emit sendMsg (returnList, messageSender); } - qDebug() << returnList; - emit sendMsg (returnList, handle); //1表示在文本框中显示 + else if(method == "sessionmessage"){ + op.newMessageResponse (recejson); + //同时触发信号,到widget + } + + + } diff --git a/Server/sever.h b/Server/sever.h index 2c6d4689b9f77840f47c96eb1378ec2ab7432b18..575429c54b62e56009a13a68a688c816d560c831 100644 --- a/Server/sever.h +++ b/Server/sever.h @@ -12,8 +12,12 @@ class Sever : public QTcpServer public: explicit Sever(QObject *parent = nullptr); + //handle索引的socket QMap clientMap; + //username索引的handle,实现了一对多 + QMap> userToHandle; + private: QTcpSocket *sock; void incomingConnection(qintptr handle); @@ -22,11 +26,14 @@ private: public slots: void setIP(QString); + signals: - void linkMsg(QString, int); - void sendMsg(QList, int);//将tcp_server收到的信息作为信号发送给mianwindow + void linkMsg(QString); + void sendMsg(QList, QString);//将tcp_server收到的信息作为信号发送给mianwindow void ready_Read(qintptr); void sendChannel(int); + void offLine(int); + }; #endif // SEVER_H diff --git a/Server/testcases.cpp b/Server/testcases.cpp index 0527ced8a64383f897b62bffa8e4a6f83c8f243c..45cd8758d2dddcc945ef7637ce6e76d2fd670de3 100644 --- a/Server/testcases.cpp +++ b/Server/testcases.cpp @@ -88,7 +88,7 @@ ENDSUITE(SessionTest) TESTSUITE(UserTest) CASE(OfflineUserGenerateJsonObject) { - OfflineUserModel user("userA", &obj); + OfflineUserModel user("userA"); user.setNickname("nicknameA"); user.setSigniture("None"); auto json = user.generateUserModelJson(); @@ -100,13 +100,13 @@ CASE(OfflineUserGenerateJsonObject) { } CASE(NewOnlineUser_ReadJson_LoadProfile) { - OfflineUserModel user("userA", &obj); + OfflineUserModel user("userA"); user.setNickname("nicknameA"); user.setSigniture("None"); auto json = user.generateUserModelJson(); - OnlineUserModel newuser(json, &obj); + OnlineUserModel newuser(json); assertEqual(newuser.getType(), UserModel::Type::Online); assertEqual(newuser.getUsername(), "userA"); @@ -130,17 +130,17 @@ CASE(CanGetSingleton) { } CASE(getOnlineUserFromUsername) { - OfflineUserModel user("userA", &obj); + OfflineUserModel user("userA"); user.setNickname("nicknameA"); user.setSigniture("None"); auto json = user.generateUserModelJson(); - OnlineUserModel newuser(json, &obj); + OnlineUserModel newuser(json); - OfflineUserModel userB("userB", &obj); + OfflineUserModel userB("userB"); userB.setNickname("nicknameA"); userB.setSigniture("None"); json = userB.generateUserModelJson(); - OnlineUserModel newuserB(json, &obj); + OnlineUserModel newuserB(json); auto& dcenter = ServerDataCenter::Singleton(); dcenter.registerUser(& newuser); @@ -164,17 +164,17 @@ TESTSUITE(MessageTest) CASE(NewMessage_GeneratedBy_UserAndSessionObject) { - OfflineUserModel userA("userA", &obj); + OfflineUserModel userA("userA"); userA.setNickname("nicknameA"); userA.setSigniture("None"); auto json = userA.generateUserModelJson(); - OnlineUserModel userA_online(json, &obj); + OnlineUserModel userA_online(json); - OfflineUserModel userB("userB", &obj); + OfflineUserModel userB("userB"); userB.setNickname("nicknameA"); userB.setSigniture("None"); json = userB.generateUserModelJson(); - OnlineUserModel userB_online(json, &obj); + OnlineUserModel userB_online(json); ServerDataCenter & dcenter = ServerDataCenter::Singleton(); dcenter.registerUser(&userA_online); @@ -197,7 +197,7 @@ CASE(NewMessage_GeneratedBy_UserAndSessionObject) testlog("Constructed mock users and session."); - MessageModel msg(userA_online, session, "a->b text", QJsonObject(), &obj); + MessageModel msg(userA_online, session, "a->b text", QJsonObject()); testlog("Generated offline message"); assertEqual(msg.getType(), MessageModel::Type::Offline); diff --git a/Server/usermodel.cpp b/Server/usermodel.cpp index bf3f1543dd1aed49e2a84877e8eadf21a0f14fe7..9c4b200f749781bafca5c968333e64fccc83feea 100644 --- a/Server/usermodel.cpp +++ b/Server/usermodel.cpp @@ -1,17 +1,17 @@ #include "usermodel.h" -UserModel::UserModel(QObject *parent) : QObject(parent) +UserModel::UserModel() { } -OfflineUserModel::OfflineUserModel(QString Username, QObject *parent) :QObject (parent) +OfflineUserModel::OfflineUserModel(QString Username) { username = Username; } -OnlineUserModel::OnlineUserModel(QJsonObject &json, QObject *parent) : QObject (parent) +OnlineUserModel::OnlineUserModel(QJsonObject &json) { loadBasicInfoFromJson(json); } @@ -23,8 +23,7 @@ void OnlineUserModel::loadBasicInfoFromJson(QJsonObject &json) { } -OnlineUserModel::OnlineUserModel(QString usrname, QString nick, QJsonObject json, QObject * parent) : - UserModel(parent) +OnlineUserModel::OnlineUserModel(QString usrname, QString nick, QJsonObject json) { username = usrname; nickname = nick; diff --git a/Server/usermodel.h b/Server/usermodel.h index ccb7ab9ce1a89949e47833aae52e30afa3012b77..85983b2c24d4449ed9635a2e4f55677404a65756 100644 --- a/Server/usermodel.h +++ b/Server/usermodel.h @@ -10,7 +10,7 @@ class UserModel : virtual public QObject { Q_OBJECT public: - UserModel(QObject *parent = nullptr); + UserModel(); enum class Type { Offline, Online, None}; virtual Type getType() const { return Type::None; } const QString& getUsername() const { return username; } @@ -24,11 +24,11 @@ protected: QJsonObject profile; }; -class OfflineUserModel : virtual public UserModel, virtual public QObject +class OfflineUserModel : public UserModel { Q_OBJECT public: - OfflineUserModel(QString Username, QObject *parent = nullptr); + OfflineUserModel(QString Username); Type getType() const { return Type::Offline; } void setNickname(QString nname) { nickname = nname; } void setSigniture(QString sig) { profile["Signiture"] = sig; } @@ -41,12 +41,12 @@ signals: }; -class OnlineUserModel : virtual public UserModel, virtual public QObject +class OnlineUserModel : public UserModel { Q_OBJECT public: - OnlineUserModel(QJsonObject &json, QObject *parent = nullptr); - OnlineUserModel(QString usrname, QString nick, QJsonObject json, QObject * parent = nullptr); + OnlineUserModel(QJsonObject &json); + OnlineUserModel(QString usrname, QString nick, QJsonObject json); Type getType() const { return Type::Online; } const QString& getNickname() const { return nickname; } const QString getSigniture() const { return profile["Signiture"].toString(); } diff --git a/Server/widget.cpp b/Server/widget.cpp index 29a6129c7df51a6cd7b0a9e1d9749a599e249b82..683728bc94005fa8b6663131657bd18e8762fa69 100644 --- a/Server/widget.cpp +++ b/Server/widget.cpp @@ -3,6 +3,8 @@ #include "operations.h" #include #include +#include "databaseoperation.h" +#include "messagemodel.h" Widget::Widget(QWidget *parent) : QWidget(parent) @@ -13,11 +15,18 @@ Widget::Widget(QWidget *parent) //widget接受来自server的信息,并进行打印 connect(&sever, &Sever::linkMsg,this, &Widget::printLink); connect(&sever, &Sever::sendMsg,this, &Widget::printMsg); + connect(&sever, &Sever::offLine,this,&Widget::setOffline); + //socket首次连接之后设置handle connect(&sever, &Sever::sendChannel,this,&Widget::setChannel); //widget发出setip信号,给server设置ip,开启listen connect(this,&Widget::pushIP,&sever, &Sever::setIP); + + connect(&(DatabaseOperation::Singleton()), &DatabaseOperation::signal_DB_ready, this, &Widget::load_Users); + + connect(&(Operations::Singleton ()),&Operations::newMessage,this, &Widget::getMessageTargetandContent); + } Widget::~Widget() @@ -27,31 +36,33 @@ Widget::~Widget() } -void Widget::printLink(QString str, int type) +void Widget::printLink(QString str) { - if(type==1) - { - ui->listWidget->addItem (str); - ui->listWidget->setCurrentRow (ui->listWidget->count ()-1); - } - else - { - ui->listWidget_2->addItem (str); - ui->listWidget_2->setCurrentRow (ui->listWidget_2->count ()-1); - } - + ui->listWidget_2->addItem (str); + ui->listWidget_2->setCurrentRow (ui->listWidget_2->count ()-1); } -void Widget::printMsg(QList list, int handle) +void Widget::printMsg(QList list, QString messageSender) { + //回发客户端功能,实现同用户多设备回发 qDebug() << "in printMsg!"; - for(int i=0;iwrite(send2.data ()); } + } void Widget::on_btnSend_clicked() @@ -90,3 +101,68 @@ void Widget::on_btnSetServer_clicked() emit pushIP (ui->textEdit->toPlainText ()); } + +void Widget::setOffline(int handle) +{ + //下线功能,非常重要!!! + int index = ui->comboBox->findText(QString::number (handle)); + if(index != -1 ) + ui->comboBox->setItemText (index, "disconnected."); + QString str = sever.clientMap[handle]->peerAddress ().toString ()+"断开连接!"; + ui->listWidget_2->addItem (str); + ui->listWidget_2->setCurrentRow (ui->listWidget_2->count ()-1); + + sever.clientMap.remove (handle); + //没有删除usertohandle的表格 + + + +} +void Widget::load_Users() { + auto& db = DatabaseOperation::Singleton(); + auto list = db.findAllUsers(); + for (int i = 0; i < list.size(); i++) { + this->ui->userWidgets->addItem(list[i]->getNickname() + " @" + list[i]->getUsername()); + } +} + +void Widget::getMessageTargetandContent(int sessionID, QJsonObject msg) +{ + + auto& db = DatabaseOperation::Singleton(); + QList sessionList = db.queryMembersBySession (sessionID); + + + qDebug() <>::iterator it1 = sever.userToHandle.begin();it1!=sever.userToHandle.end ();it1++) + { + qDebug() <write (send2.data ()); + + } + +// break; + } + } + } + + +} diff --git a/Server/widget.h b/Server/widget.h index d79acad0588dd9968a58515f5dacde958f3f9b8e..7d368bc38cf18fedf3a35a98e4950073616620e0 100644 --- a/Server/widget.h +++ b/Server/widget.h @@ -19,8 +19,8 @@ public: private slots: // void on_pushButton_3_clicked() {} - void printLink(QString, int); //收到由server发来窗口的信号,并且打印到文本框 - void printMsg(QList,int); + void printLink(QString); //收到由server发来窗口的信号,并且打印到文本框 + void printMsg(QList, QString); void on_btnSend_clicked(); @@ -30,6 +30,10 @@ private slots: void on_btnSetServer_clicked(); + void setOffline(int); + void load_Users(); + + void getMessageTargetandContent(int, QJsonObject); signals: diff --git a/Server/widget.ui b/Server/widget.ui index c6d03914c01637543305447c934e039acf8b850d..c9d3bec88f8616507c3536b69aa11a7340e47559 100644 --- a/Server/widget.ui +++ b/Server/widget.ui @@ -100,7 +100,7 @@ p, li { white-space: pre-wrap; } 770 170 401 - 691 + 401 @@ -163,6 +163,29 @@ p, li { white-space: pre-wrap; } 上线信息 + + + + 770 + 640 + 401 + 251 + + + + + + + 770 + 600 + 151 + 16 + + + + Registered Users + + diff --git a/bitXiaoSha/main.py b/bitXiaoSha/main.py new file mode 100644 index 0000000000000000000000000000000000000000..1dbfafb44cc71656008b99f291a2b2ec65b09193 --- /dev/null +++ b/bitXiaoSha/main.py @@ -0,0 +1,368 @@ +# -*- coding: utf-8 -*- + +# Form implementation generated from reading ui file 'untitled.ui' +# +# Created by: PyQt5 UI code generator 5.13.0 +# +# WARNING! All changes made in this file will be lost! +import time +import sys +import socket +import threading +from PyQt5 import QtCore, QtWidgets +from PyQt5.QtWidgets import QMainWindow, QApplication +import itertools +import pickle +import random +import torch.nn as nn +import torch.nn.functional as F +import torch +import unicodedata +import re +import os +from torch import optim +import json + +# 预定义的token + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +PAD_token = 0 # 表示padding +SOS_token = 1 # 句子的开始 +EOS_token = 2 # 句子的结束 + + +class Voc: + def __init__(self, name): + self.name = name + self.trimmed = False + self.word2index = {} + self.word2count = {} + self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"} + self.num_words = 3 # 目前有SOS, EOS, PAD这3个token。 + + def addSentence(self, sentence): + for word in sentence.split(' '): + self.addWord(word) + + def addWord(self, word): + if word not in self.word2index: + self.word2index[word] = self.num_words + self.word2count[word] = 1 + self.index2word[self.num_words] = word + self.num_words += 1 + else: + self.word2count[word] += 1 + + # 删除频次小于min_count的token + def trim(self, min_count): + if self.trimmed: + return + self.trimmed = True + + keep_words = [] + + for k, v in self.word2count.items(): + if v >= min_count: + keep_words.append(k) + + print('keep_words {} / {} = {:.4f}'.format( + len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index) + )) + + # 重新构造词典 + self.word2index = {} + self.word2count = {} + self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"} + self.num_words = 3 # Count default tokens + + # 重新构造后词频就没有意义了(都是1) + for word in keep_words: + self.addWord(word) + + +MAX_LENGTH = 10 # 句子最大长度是10个词(包括EOS等特殊词) + + +def unicodeToAscii(s): + return ''.join( + c for c in unicodedata.normalize('NFD', s) + if unicodedata.category(c) != 'Mn' + ) + + +def normalizeString(s): + s = unicodeToAscii(s.lower().strip()) + s = re.sub(r"([.!?])", r" \1", s) + s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) + s = re.sub(r"\s+", r" ", s).strip() + return s + + +# Load/Assemble voc and pairs +save_dir = os.path.join("data", "save") +corpus_name = "cornell movie-dialogs corpus" +corpus = os.path.join("data", corpus_name) +datafile = os.path.join(corpus, "formatted_movie_lines.txt") + + +# 把句子的词变成ID +def indexesFromSentence(voc, sentence): + return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token] + + +# l是多个长度不同句子(list),使用zip_longest padding成定长,长度为最长句子的长度。 +def zeroPadding(l, fillvalue=PAD_token): + return list(itertools.zip_longest(*l, fillvalue=fillvalue)) + + + +#读取voc和pairs +voc=Voc(corpus_name) +with open("voc.pkl", 'rb') as file: + voc = pickle.loads(file.read()) +with open("pairs.pkl", "rb") as file: + pairs = pickle.load(file) + + + +class EncoderRNN(nn.Module): + def __init__(self, hidden_size, embedding, n_layers=1, dropout=0): + super(EncoderRNN, self).__init__() + self.n_layers = n_layers + self.hidden_size = hidden_size + self.embedding = embedding + self.gru = nn.GRU(hidden_size, hidden_size, n_layers, + dropout=(0 if n_layers == 1 else dropout), bidirectional=True) + + def forward(self, input_seq, input_lengths, hidden=None): + embedded = self.embedding(input_seq) + packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths.cpu()) + outputs, hidden = self.gru(packed, hidden) + outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) + outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] + return outputs, hidden + + +# Luong 注意力layer +class Attn(torch.nn.Module): + def __init__(self, method, hidden_size): + super(Attn, self).__init__() + self.method = method + if self.method not in ['dot', 'general', 'concat']: + raise ValueError(self.method, "is not an appropriate attention method.") + self.hidden_size = hidden_size + if self.method == 'general': + self.attn = torch.nn.Linear(self.hidden_size, hidden_size) + elif self.method == 'concat': + self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size) + self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)) + + def dot_score(self, hidden, encoder_output): + return torch.sum(hidden * encoder_output, dim=2) + + def general_score(self, hidden, encoder_output): + energy = self.attn(encoder_output) + return torch.sum(hidden * energy, dim=2) + + def concat_score(self, hidden, encoder_output): + energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), + encoder_output), 2)).tanh() + return torch.sum(self.v * energy, dim=2) + + def forward(self, hidden, encoder_outputs): + if self.method == 'general': + attn_energies = self.general_score(hidden, encoder_outputs) + elif self.method == 'concat': + attn_energies = self.concat_score(hidden, encoder_outputs) + elif self.method == 'dot': + # 计算内积,参考dot_score函数 + attn_energies = self.dot_score(hidden, encoder_outputs) + return F.softmax(attn_energies, dim=1).unsqueeze(1) + + +class LuongAttnDecoderRNN(nn.Module): + def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1): + super(LuongAttnDecoderRNN, self).__init__() + + # 保存到self里,attn_model就是前面定义的Attn类的对象。 + self.attn_model = attn_model + self.hidden_size = hidden_size + self.output_size = output_size + self.n_layers = n_layers + self.dropout = dropout + + # 定义Decoder的layers + self.embedding = embedding + self.embedding_dropout = nn.Dropout(dropout) + self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout)) + self.concat = nn.Linear(hidden_size * 2, hidden_size) + self.out = nn.Linear(hidden_size, output_size) + + self.attn = Attn(attn_model, hidden_size) + + def forward(self, input_step, last_hidden, encoder_outputs): + embedded = self.embedding(input_step) + embedded = self.embedding_dropout(embedded) + rnn_output, hidden = self.gru(embedded, last_hidden) + attn_weights = self.attn(rnn_output, encoder_outputs) + context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) + rnn_output = rnn_output.squeeze(0) + context = context.squeeze(1) + concat_input = torch.cat((rnn_output, context), 1) + concat_output = torch.tanh(self.concat(concat_input)) + output = self.out(concat_output) + output = F.softmax(output, dim=1) + return output, hidden + + + +class GreedySearchDecoder(nn.Module): + def __init__(self, encoder, decoder): + super(GreedySearchDecoder, self).__init__() + self.encoder = encoder + self.decoder = decoder + + def forward(self, input_seq, input_length, max_length): + encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length) + decoder_hidden = encoder_hidden[:decoder.n_layers] + decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token + all_tokens = torch.zeros([0], device=device, dtype=torch.long) + all_scores = torch.zeros([0], device=device) + for _ in range(max_length): + decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, + encoder_outputs) + decoder_scores, decoder_input = torch.max(decoder_output, dim=1) + all_tokens = torch.cat((all_tokens, decoder_input), dim=0) + all_scores = torch.cat((all_scores, decoder_scores), dim=0) + decoder_input = torch.unsqueeze(decoder_input, 0) + # 返回所有的词和得分。 + return all_tokens, all_scores + + +def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH): + ### 把输入的一个batch句子变成id + indexes_batch = [indexesFromSentence(voc, sentence)] + # 创建lengths tensor + lengths = torch.tensor([len(indexes) for indexes in indexes_batch]) + # 转置 + input_batch = torch.LongTensor(indexes_batch).transpose(0, 1) + # 放到合适的设备上(比如GPU) + input_batch = input_batch.to(device) + lengths = lengths.to(device) + # 用searcher解码 + tokens, scores = searcher(input_batch, lengths, max_length) + # ID变成词。 + decoded_words = [voc.index2word[token.item()] for token in tokens] + return decoded_words + + +# 配置模型 +model_name = 'cb_model' +attn_model = 'dot' +# attn_model = 'general' +# attn_model = 'concat' +hidden_size = 500 +encoder_n_layers = 2 +decoder_n_layers = 2 +dropout = 0.1 +batch_size = 64 +loadFilename = '.\\data\\save\\cb_model\\cornell movie-dialogs corpus\\2-2_500\\10000_checkpoint.tar' +checkpoint_iter = 4000 + +# 如果loadFilename不空,则从中加载模型 +if loadFilename: + checkpoint = torch.load(loadFilename) + encoder_sd = checkpoint['en'] + decoder_sd = checkpoint['de'] + encoder_optimizer_sd = checkpoint['en_opt'] + decoder_optimizer_sd = checkpoint['de_opt'] + embedding_sd = checkpoint['embedding'] + voc.__dict__ = checkpoint['voc_dict'] + +# 初始化word embedding +embedding = nn.Embedding(voc.num_words, hidden_size) +if loadFilename: + embedding.load_state_dict(embedding_sd) +# 初始化encoder和decoder模型 +encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout) +decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, + decoder_n_layers, dropout) +if loadFilename: + encoder.load_state_dict(encoder_sd) + decoder.load_state_dict(decoder_sd) +# 使用合适的设备 +encoder = encoder.to(device) +decoder = decoder.to(device) + +# 配置训练的超参数和优化器 +learning_rate = 0.0001 +decoder_learning_ratio = 5.0 + +# 设置进入训练模式,从而开启dropout +encoder.train() +decoder.train() + +# 初始化优化器 +encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate) +decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio) +if loadFilename: + encoder_optimizer.load_state_dict(encoder_optimizer_sd) + decoder_optimizer.load_state_dict(decoder_optimizer_sd) + + +# 进入eval模式,从而去掉dropout。 +encoder.eval() +decoder.eval() + +# 构造searcher对象 +searcher = GreedySearchDecoder(encoder, decoder) + + +def silly_AI(str, encoder=encoder, decoder=decoder, searcher=searcher, voc=voc): + try: + # 得到用户终端的输入 + input_sentence = str + # 是否退出 + if input_sentence == 'q' or input_sentence == 'quit': return + # 句子归一化 + input_sentence = normalizeString(input_sentence) + # 生成响应Evaluate sentence + output_words = evaluate(encoder, decoder, searcher, voc, input_sentence) + # 去掉EOS后面的内容 + words = [] + for word in output_words: + if word == 'EOS': + break + elif word != 'PAD': + words.append(word) + return ' '.join(words) + + except KeyError: + return "Error: Encountered unknown word." + + + + + +if __name__ == '__main__': + app = QApplication(sys.argv) + c_socket = socket.socket() # 创建套接字 + addr_1 = "10.194.41.126" # 获取服务器地址 + addr_2 = "8888" # 获取端口号 + addr = (addr_1, int(addr_2)) + c_socket.connect(addr) # 连接套接字 + c_socket.send('{"MsgType": "info", "Message": "I am silly_AI"}'.encode('GB2312')) # 发送消息 + print('silly_AI:','{"MsgType": "info", "Message": "I am silly_AI"}') + while 1: + rcv_msg = c_socket.recv(1024) # 接受消息 + print('server:',rcv_msg.decode('utf-8')) + answer = silly_AI(rcv_msg.decode('utf-8')) + resp = json.dumps({"MsgType": "info", "Message": answer}, sort_keys=False, indent=4, separators=(',', ':')) + if rcv_msg.decode('GB2312'): + if rcv_msg.decode('GB2312') == 'bye': + c_socket.send('{"MsgType": "info", "Message": "bye"}'.encode('utf-8')) + break + c_socket.send(resp.encode('utf-8')) + print('silly_AI:', resp) + sys.exit(app.exec_()) \ No newline at end of file diff --git a/bitXiaoSha/make_dic.py b/bitXiaoSha/make_dic.py new file mode 100644 index 0000000000000000000000000000000000000000..4cb37e2ff277f2af6d73e13d5315c81e06ef46cd --- /dev/null +++ b/bitXiaoSha/make_dic.py @@ -0,0 +1,183 @@ +import itertools +import pickle +import random +import torch.nn as nn +import torch.nn.functional as F +import torch +import unicodedata +import re +import os +from torch import optim + +# 预定义的token +PAD_token = 0 # 表示padding +SOS_token = 1 # 句子的开始 +EOS_token = 2 # 句子的结束 +save_dir = os.path.join("data", "save") +corpus_name = "cornell movie-dialogs corpus" +corpus = os.path.join("data", corpus_name) +datafile = os.path.join(corpus, "formatted_movie_lines.txt") + +class Voc: + def __init__(self, name): + self.name = name + self.trimmed = False + self.word2index = {} + self.word2count = {} + self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"} + self.num_words = 3 # 目前有SOS, EOS, PAD这3个token。 + + def addSentence(self, sentence): + for word in sentence.split(' '): + self.addWord(word) + + def addWord(self, word): + if word not in self.word2index: + self.word2index[word] = self.num_words + self.word2count[word] = 1 + self.index2word[self.num_words] = word + self.num_words += 1 + else: + self.word2count[word] += 1 + + # 删除频次小于min_count的token + def trim(self, min_count): + if self.trimmed: + return + self.trimmed = True + + keep_words = [] + + for k, v in self.word2count.items(): + if v >= min_count: + keep_words.append(k) + + print('keep_words {} / {} = {:.4f}'.format( + len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index) + )) + + # 重新构造词典 + self.word2index = {} + self.word2count = {} + self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"} + self.num_words = 3 # Count default tokens + + # 重新构造后词频就没有意义了(都是1) + for word in keep_words: + self.addWord(word) + + +MAX_LENGTH = 10 # 句子最大长度是10个词(包括EOS等特殊词) + + +# 把Unicode字符串变成ASCII +# 参考https://stackoverflow.com/a/518232/2809427 +def unicodeToAscii(s): + return ''.join( + c for c in unicodedata.normalize('NFD', s) + if unicodedata.category(c) != 'Mn' + ) + + +def normalizeString(s): + # 变成小写、去掉前后空格,然后unicode变成ascii + s = unicodeToAscii(s.lower().strip()) + # 在标点前增加空格,这样把标点当成一个词 + s = re.sub(r"([.!?])", r" \1", s) + # 字母和标点之外的字符都变成空格 + s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) + # 因为把不用的字符都变成空格,所以可能存在多个连续空格 + # 下面的正则替换把多个空格变成一个空格,最后去掉前后空格 + s = re.sub(r"\s+", r" ", s).strip() + return s + + +# 读取问答句对并且返回Voc词典对象 +def readVocs(datafile, corpus_name): + print("Reading lines...") + # 文件每行读取到list lines中。 + lines = open(datafile, encoding='utf-8'). \ + read().strip().split('\n') + # 每行用tab切分成问答两个句子,然后调用normalizeString函数进行处理。 + pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines] + voc = Voc(corpus_name) + return voc, pairs + + +def filterPair(p): + return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH + + +# 过滤太长的句对 +def filterPairs(pairs): + return [pair for pair in pairs if filterPair(pair)] + + +# 使用上面的函数进行处理,返回Voc对象和句对的list +def loadPrepareData(corpus, corpus_name, datafile): + print("Start preparing training data ...") + voc, pairs = readVocs(datafile, corpus_name) + print("Read {!s} sentence pairs".format(len(pairs))) + pairs = filterPairs(pairs) + print("Trimmed to {!s} sentence pairs".format(len(pairs))) + print("Counting words...") + for pair in pairs: + voc.addSentence(pair[0]) + voc.addSentence(pair[1]) + print("Counted words:", voc.num_words) + return voc, pairs + + +# Load/Assemble voc and pairs +# save_dir = os.path.join("data", "save") +voc, pairs = loadPrepareData(corpus, corpus_name, datafile) +# 输出一些句对 +print("\npairs:") +for pair in pairs[:10]: + print(pair) + +MIN_COUNT = 3 # 阈值为3 + + +def trimRareWords(voc, pairs, MIN_COUNT): + # 去掉voc中频次小于3的词 + voc.trim(MIN_COUNT) + # 保留的句对 + keep_pairs = [] + for pair in pairs: + input_sentence = pair[0] + output_sentence = pair[1] + keep_input = True + keep_output = True + # 检查问题 + for word in input_sentence.split(' '): + if word not in voc.word2index: + keep_input = False + break + # 检查答案 + for word in output_sentence.split(' '): + if word not in voc.word2index: + keep_output = False + break + # 如果问题和答案都只包含高频词,我们才保留这个句对 + if keep_input and keep_output: + keep_pairs.append(pair) + + print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), + len(keep_pairs), len(keep_pairs) / len(pairs))) + return keep_pairs + + +# 实际进行处理 +pairs = trimRareWords(voc, pairs, MIN_COUNT) + +#保存数据 +with open("pairs.pkl", "wb") as file: + list1 = pairs + pickle.dump(list1, file, True) + + +output_hal = open("voc.pkl", 'wb') +str = pickle.dumps(voc) +output_hal.write(str) +output_hal.close() \ No newline at end of file diff --git a/bitXiaoSha/pairs.pkl b/bitXiaoSha/pairs.pkl new file mode 100644 index 0000000000000000000000000000000000000000..f4fe1ccca5b34b6df687e3b03ad32271e54350ae Binary files /dev/null and b/bitXiaoSha/pairs.pkl differ diff --git a/bitXiaoSha/processing_Data.py b/bitXiaoSha/processing_Data.py new file mode 100644 index 0000000000000000000000000000000000000000..b33c5818cab4c2e1fb086c445c46a012be16ae73 --- /dev/null +++ b/bitXiaoSha/processing_Data.py @@ -0,0 +1,107 @@ +import os +import codecs +import csv + +# 把每一行都parse成一个dict,key是lineID、characterID、movieID、character和text +# 分别代表这一行的ID、人物ID、电影ID,人物名称和文本。 +# 最终输出一个dict,key是lineID,value是一个dict。 +# value这个dict的key是lineID、characterID、movieID、character和text +def loadLines(fileName, fields): + lines = {} + with open(fileName, 'r', encoding='iso-8859-1') as f: + for line in f: + values = line.split(" +++$+++ ") + # 抽取fields + lineObj = {} + for i, field in enumerate(fields): + lineObj[field] = values[i] + lines[lineObj['lineID']] = lineObj + return lines + + +# 根据movie_conversations.txt文件和上输出的lines,把utterance组成对话。 +# 最终输出一个list,这个list的每一个元素都是一个dict, +# key分别是character1ID、character2ID、movieID和utteranceIDs。 +# 分别表示这对话的第一个人物的ID,第二个的ID,电影的ID以及它包含的utteranceIDs +# 最后根据lines,还给每一行的dict增加一个key为lines,其value是个list, +# 包含所有utterance(上面得到的lines的value) +def loadConversations(fileName, lines, fields): + conversations = [] + with open(fileName, 'r', encoding='iso-8859-1') as f: + for line in f: + values = line.split(" +++$+++ ") + # 抽取fields + convObj = {} + for i, field in enumerate(fields): + convObj[field] = values[i] + # convObj["utteranceIDs"]是一个字符串,形如['L198', 'L199'] + # 我们用eval把这个字符串变成一个字符串的list。 + lineIds = eval(convObj["utteranceIDs"]) + # 根据lineIds构造一个数组,根据lineId去lines里检索出存储utterance对象。 + convObj["lines"] = [] + for lineId in lineIds: + convObj["lines"].append(lines[lineId]) + conversations.append(convObj) + return conversations + + +# 从对话中抽取句对 +# 假设一段对话包含s1,s2,s3,s4这4个utterance +# 那么会返回3个句对:s1-s2,s2-s3和s3-s4。 +def extractSentencePairs(conversations): + qa_pairs = [] + for conversation in conversations: + # 遍历对话中的每一个句子,忽略最后一个句子,因为没有答案。 + for i in range(len(conversation["lines"]) - 1): + inputLine = conversation["lines"][i]["text"].strip() + targetLine = conversation["lines"][i + 1]["text"].strip() + # 如果有空的句子就去掉 + if inputLine and targetLine: + qa_pairs.append([inputLine, targetLine]) + return qa_pairs + + +corpus_name = "cornell movie-dialogs corpus" +corpus = os.path.join("data", corpus_name) + + +def printLines(file, n=10): + with open(file, 'rb') as datafile: + lines = datafile.readlines() + for line in lines[:n]: + print(line) + + +printLines(os.path.join(corpus, "movie_lines.txt")) +# 定义新的文件 +datafile = os.path.join(corpus, "formatted_movie_lines.txt") + +delimiter = '\t' +# 对分隔符delimiter进行decode,这里对tab进行decode结果并没有变 +delimiter = str(codecs.decode(delimiter, "unicode_escape")) + +# 初始化dict lines,list conversations以及前面我们介绍过的field的id数组。 +lines = {} +conversations = [] +MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"] +MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"] + +# 首先使用loadLines函数处理movie_lines.txt +print("\nProcessing corpus...") +lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS) +# 接着使用loadConversations处理上一步的结果,得到conversations +print("\nLoading conversations...") +conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"), + lines, MOVIE_CONVERSATIONS_FIELDS) + +# 输出到一个新的csv文件 +print("\nWriting newly formatted file...") +with open(datafile, 'w', encoding='utf-8') as outputfile: + writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n') + # 使用extractSentencePairs从conversations里抽取句对。 + for pair in extractSentencePairs(conversations): + writer.writerow(pair) + +# 输出一些行用于检查 +print("\nSample lines from file:") +printLines(datafile) \ No newline at end of file diff --git a/bitXiaoSha/send.py b/bitXiaoSha/send.py new file mode 100644 index 0000000000000000000000000000000000000000..ca049474cc82c6da6fbaaaad6b995baae177cdc2 --- /dev/null +++ b/bitXiaoSha/send.py @@ -0,0 +1,17 @@ +import time +import socket + +s = socket.socket() # 创建套接字 +s.bind(('127.0.0.1', 21567)) # 绑定套接字 +s.listen(5) # 监听套接字 + +while 1: + print('waiting for connecting') # 未连接时打印等待连接 + c_socket, addr = s.accept() # 接受连接 + print('connect from: {}'.format(addr)) # 打印出连接的ip地址 + rcv_MSG = c_socket.recv(1024).decode('GB2312') # 接收信息 + send_MSG = "[{}]已收到消息:{}".format(time.ctime(), rcv_MSG) # 准备发送的信息 + print(rcv_MSG) # 打印收到的信息 + c_socket.send(send_MSG.encode('GB2312')) # 发送信息 + c_socket.close() # 关闭套接字 +s.close() \ No newline at end of file diff --git a/bitXiaoSha/train.py b/bitXiaoSha/train.py new file mode 100644 index 0000000000000000000000000000000000000000..ee5d984103116502fa7e8f9793f756fa4b1a8734 --- /dev/null +++ b/bitXiaoSha/train.py @@ -0,0 +1,633 @@ +import itertools +import pickle +import random +import torch.nn as nn +import torch.nn.functional as F +import torch +import unicodedata +import re +import os +from torch import optim + + +# 预定义的token + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +PAD_token = 0 # 表示padding +SOS_token = 1 # 句子的开始 +EOS_token = 2 # 句子的结束 + + +class Voc: + def __init__(self, name): + self.name = name + self.trimmed = False + self.word2index = {} + self.word2count = {} + self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"} + self.num_words = 3 # 目前有SOS, EOS, PAD这3个token。 + + def addSentence(self, sentence): + for word in sentence.split(' '): + self.addWord(word) + + def addWord(self, word): + if word not in self.word2index: + self.word2index[word] = self.num_words + self.word2count[word] = 1 + self.index2word[self.num_words] = word + self.num_words += 1 + else: + self.word2count[word] += 1 + + # 删除频次小于min_count的token + def trim(self, min_count): + if self.trimmed: + return + self.trimmed = True + + keep_words = [] + + for k, v in self.word2count.items(): + if v >= min_count: + keep_words.append(k) + + print('keep_words {} / {} = {:.4f}'.format( + len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index) + )) + + # 重新构造词典 + self.word2index = {} + self.word2count = {} + self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"} + self.num_words = 3 # Count default tokens + + # 重新构造后词频就没有意义了(都是1) + for word in keep_words: + self.addWord(word) + + +MAX_LENGTH = 10 # 句子最大长度是10个词(包括EOS等特殊词) + + +# 把Unicode字符串变成ASCII +#参考https://stackoverflow.com/a/518232/2809427 +def unicodeToAscii(s): + return ''.join( + c for c in unicodedata.normalize('NFD', s) + if unicodedata.category(c) != 'Mn' + ) + + +def normalizeString(s): + # 变成小写、去掉前后空格,然后unicode变成ascii + s = unicodeToAscii(s.lower().strip()) + # 在标点前增加空格,这样把标点当成一个词 + s = re.sub(r"([.!?])", r" \1", s) + # 字母和标点之外的字符都变成空格 + s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) + # 因为把不用的字符都变成空格,所以可能存在多个连续空格 + # 下面的正则替换把多个空格变成一个空格,最后去掉前后空格 + s = re.sub(r"\s+", r" ", s).strip() + return s + + +# Load/Assemble voc and pairs +save_dir = os.path.join("data", "save") +corpus_name = "cornell movie-dialogs corpus" +corpus = os.path.join("data", corpus_name) +datafile = os.path.join(corpus, "formatted_movie_lines.txt") + + +# 把句子的词变成ID +def indexesFromSentence(voc, sentence): + return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token] + + +# l是多个长度不同句子(list),使用zip_longest padding成定长,长度为最长句子的长度。 +def zeroPadding(l, fillvalue=PAD_token): + return list(itertools.zip_longest(*l, fillvalue=fillvalue)) + + + +#读取voc和pairs +voc=Voc(corpus_name) +with open("voc.pkl", 'rb') as file: + voc = pickle.loads(file.read()) +with open("pairs.pkl", "rb") as file: + pairs = pickle.load(file) + + +# 把句子的词变成ID +def indexesFromSentence(voc, sentence): + return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token] + + +# l是多个长度不同句子(list),使用zip_longest padding成定长,长度为最长句子的长度。 +def zeroPadding(l, fillvalue=PAD_token): + return list(itertools.zip_longest(*l, fillvalue=fillvalue)) + + +# l是二维的padding后的list +# 返回m和l的大小一样,如果某个位置是padding,那么值为0,否则为1 +def binaryMatrix(l, value=PAD_token): + m = [] + for i, seq in enumerate(l): + m.append([]) + for token in seq: + if token == PAD_token: + m[i].append(0) + else: + m[i].append(1) + return m + + +# 把输入句子变成ID,然后再padding,同时返回lengths这个list,标识实际长度。 +# 返回的padVar是一个LongTensor,shape是(batch, max_length), +# lengths是一个list,长度为(batch,),表示每个句子的实际长度。 +def inputVar(l, voc): + indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l] + lengths = torch.tensor([len(indexes) for indexes in indexes_batch]) + padList = zeroPadding(indexes_batch) + padVar = torch.LongTensor(padList) + return padVar, lengths + + +# 对输出句子进行padding,然后用binaryMatrix得到每个位置是padding(0)还是非padding, +# 同时返回最大最长句子的长度(也就是padding后的长度) +# 返回值padVar是LongTensor,shape是(batch, max_target_length) +# mask是ByteTensor,shape也是(batch, max_target_length) +def outputVar(l, voc): + indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l] + max_target_len = max([len(indexes) for indexes in indexes_batch]) + padList = zeroPadding(indexes_batch) + mask = binaryMatrix(padList) + mask = torch.ByteTensor(mask) + padVar = torch.LongTensor(padList) + return padVar, mask, max_target_len + + +# 处理一个batch的pair句对 +def batch2TrainData(voc, pair_batch): + # 按照句子的长度(词数)排序 + pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True) + input_batch, output_batch = [], [] + for pair in pair_batch: + input_batch.append(pair[0]) + output_batch.append(pair[1]) + inp, lengths = inputVar(input_batch, voc) + output, mask, max_target_len = outputVar(output_batch, voc) + return inp, lengths, output, mask.bool(), max_target_len + + +class EncoderRNN(nn.Module): + def __init__(self, hidden_size, embedding, n_layers=1, dropout=0): + super(EncoderRNN, self).__init__() + self.n_layers = n_layers + self.hidden_size = hidden_size + self.embedding = embedding + + # 初始化GRU,这里输入和hidden大小都是hidden_size,这里假设embedding层的输出大小是hidden_size + # 如果只有一层,那么不进行Dropout,否则使用传入的参数dropout进行GRU的Dropout。 + self.gru = nn.GRU(hidden_size, hidden_size, n_layers, + dropout=(0 if n_layers == 1 else dropout), bidirectional=True) + + def forward(self, input_seq, input_lengths, hidden=None): + # 输入是(max_length, batch),Embedding之后变成(max_length, batch, hidden_size) + embedded = self.embedding(input_seq) + # Pack padded batch of sequences for RNN module + # 因为RNN(GRU)要知道实际长度,所以PyTorch提供了函数pack_padded_sequence把输入向量和长度 + # pack到一个对象PackedSequence里,这样便于使用。 + packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths.cpu()) + # 通过GRU进行forward计算,需要传入输入和隐变量 + # 如果传入的输入是一个Tensor (max_length, batch, hidden_size) + # 那么输出outputs是(max_length, batch, hidden_size*num_directions)。 + # 第三维是hidden_size和num_directions的混合,它们实际排列顺序是num_directions在前面, + # 因此我们可以使用outputs.view(seq_len, batch, num_directions, hidden_size)得到4维的向量。 + # 其中第三维是方向,第四位是隐状态。 + + # 而如果输入是PackedSequence对象,那么输出outputs也是一个PackedSequence对象,我们需要用 + # 函数pad_packed_sequence把它变成shape为(max_length, batch, hidden*num_directions)的向量以及 + # 一个list,表示输出的长度,当然这个list和输入的input_lengths完全一样,因此通常我们不需要它。 + outputs, hidden = self.gru(packed, hidden) + # 参考前面的注释,我们得到outputs为(max_length, batch, hidden*num_directions) + outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) + # 我们需要把输出的num_directions双向的向量加起来 + # 因为outputs的第三维是先放前向的hidden_size个结果,然后再放后向的hidden_size个结果 + # 所以outputs[:, :, :self.hidden_size]得到前向的结果 + # outputs[:, :, self.hidden_size:]是后向的结果 + # 注意,如果bidirectional是False,则outputs第三维的大小就是hidden_size, + # 这时outputs[:, : ,self.hidden_size:]是不存在的,因此也不会加上去。 + # 对Python slicing不熟的读者可以看看下面的例子: + + # >>> a=[1,2,3] + # >>> a[:3] + # [1, 2, 3] + # >>> a[3:] + # [] + # >>> a[:3]+a[3:] + # [1, 2, 3] + + # 这样就不用写下面的代码了: + # if bidirectional: + # outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:] + outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] + # 返回最终的输出和最后时刻的隐状态。 + return outputs, hidden + + +# Luong 注意力layer +class Attn(torch.nn.Module): + def __init__(self, method, hidden_size): + super(Attn, self).__init__() + self.method = method + if self.method not in ['dot', 'general', 'concat']: + raise ValueError(self.method, "is not an appropriate attention method.") + self.hidden_size = hidden_size + if self.method == 'general': + self.attn = torch.nn.Linear(self.hidden_size, hidden_size) + elif self.method == 'concat': + self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size) + self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)) + + def dot_score(self, hidden, encoder_output): + # 输入hidden的shape是(1, batch=64, hidden_size=500) + # encoder_outputs的shape是(input_lengths=10, batch=64, hidden_size=500) + # hidden * encoder_output得到的shape是(10, 64, 500),然后对第3维求和就可以计算出score。 + return torch.sum(hidden * encoder_output, dim=2) + + def general_score(self, hidden, encoder_output): + energy = self.attn(encoder_output) + return torch.sum(hidden * energy, dim=2) + + def concat_score(self, hidden, encoder_output): + energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), + encoder_output), 2)).tanh() + return torch.sum(self.v * energy, dim=2) + + # 输入是上一个时刻的隐状态hidden和所有时刻的Encoder的输出encoder_outputs + # 输出是注意力的概率,也就是长度为input_lengths的向量,它的和加起来是1。 + def forward(self, hidden, encoder_outputs): + # 计算注意力的score,输入hidden的shape是(1, batch=64, hidden_size=500), + # 表示t时刻batch数据的隐状态 + # encoder_outputs的shape是(input_lengths=10, batch=64, hidden_size=500) + if self.method == 'general': + attn_energies = self.general_score(hidden, encoder_outputs) + elif self.method == 'concat': + attn_energies = self.concat_score(hidden, encoder_outputs) + elif self.method == 'dot': + # 计算内积,参考dot_score函数 + attn_energies = self.dot_score(hidden, encoder_outputs) + + # Transpose max_length and batch_size dimensions + # 把attn_energies从(max_length=10, batch=64)转置成(64, 10) + attn_energies = attn_energies.t() + + # 使用softmax函数把score变成概率,shape仍然是(64, 10),然后用unsqueeze(1)变成 + # (64, 1, 10) + return F.softmax(attn_energies, dim=1).unsqueeze(1) + + +class LuongAttnDecoderRNN(nn.Module): + def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1): + super(LuongAttnDecoderRNN, self).__init__() + + # 保存到self里,attn_model就是前面定义的Attn类的对象。 + self.attn_model = attn_model + self.hidden_size = hidden_size + self.output_size = output_size + self.n_layers = n_layers + self.dropout = dropout + + # 定义Decoder的layers + self.embedding = embedding + self.embedding_dropout = nn.Dropout(dropout) + self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout)) + self.concat = nn.Linear(hidden_size * 2, hidden_size) + self.out = nn.Linear(hidden_size, output_size) + + self.attn = Attn(attn_model, hidden_size) + + def forward(self, input_step, last_hidden, encoder_outputs): + # 注意:decoder每一步只能处理一个时刻的数据,因为t时刻计算完了才能计算t+1时刻。 + # input_step的shape是(1, 64),64是batch,1是当前输入的词ID(来自上一个时刻的输出) + # 通过embedding层变成(1, 64, 500),然后进行dropout,shape不变。 + embedded = self.embedding(input_step) + embedded = self.embedding_dropout(embedded) + # 把embedded传入GRU进行forward计算 + # 得到rnn_output的shape是(1, 64, 500) + # hidden是(2, 64, 500),因为是两层的GRU,所以第一维是2。 + rnn_output, hidden = self.gru(embedded, last_hidden) + # 计算注意力权重, 根据前面的分析,attn_weights的shape是(64, 1, 10) + attn_weights = self.attn(rnn_output, encoder_outputs) + + # encoder_outputs是(10, 64, 500) + # encoder_outputs.transpose(0, 1)后的shape是(64, 10, 500) + # attn_weights.bmm后是(64, 1, 500) + + # bmm是批量的矩阵乘法,第一维是batch,我们可以把attn_weights看成64个(1,10)的矩阵 + # 把encoder_outputs.transpose(0, 1)看成64个(10, 500)的矩阵 + # 那么bmm就是64个(1, 10)矩阵 x (10, 500)矩阵,最终得到(64, 1, 500) + context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) + # 把context向量和GRU的输出拼接起来 + # rnn_output从(1, 64, 500)变成(64, 500) + rnn_output = rnn_output.squeeze(0) + # context从(64, 1, 500)变成(64, 500) + context = context.squeeze(1) + # 拼接得到(64, 1000) + concat_input = torch.cat((rnn_output, context), 1) + # self.concat是一个矩阵(1000, 500), + # self.concat(concat_input)的输出是(64, 500) + # 然后用tanh把输出返回变成(-1,1),concat_output的shape是(64, 500) + concat_output = torch.tanh(self.concat(concat_input)) + + # out是(500, 词典大小=7826) + output = self.out(concat_output) + # 用softmax变成概率,表示当前时刻输出每个词的概率。 + output = F.softmax(output, dim=1) + # 返回 output和新的隐状态 + return output, hidden + + +def maskNLLLoss(inp, target, mask): + # 计算实际的词的个数,因为padding是0,非padding是1,因此sum就可以得到词的个数 + nTotal = mask.sum() + + crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1)) + loss = crossEntropy.masked_select(mask).mean() + loss = loss.to(device) + return loss, nTotal.item() + + +def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding, + encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH): + # 梯度清空 + encoder_optimizer.zero_grad() + decoder_optimizer.zero_grad() + # 设置device,从而支持GPU,当然如果没有GPU也能工作。 + input_variable = input_variable.to(device) + lengths = lengths.to(device) + target_variable = target_variable.to(device) + mask = mask.to(device) + # 初始化变量 + loss = 0 + print_losses = [] + n_totals = 0 + # encoder的Forward计算 + encoder_outputs, encoder_hidden = encoder(input_variable, lengths.cpu()) + # Decoder的初始输入是SOS,我们需要构造(1, batch)的输入,表示第一个时刻batch个输入。 + decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]]) + decoder_input = decoder_input.to(device) + # 注意:Encoder是双向的,而Decoder是单向的,因此从下往上取n_layers个 + decoder_hidden = encoder_hidden[:decoder.n_layers] + # 确定是否teacher forcing + use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False + # 一次处理一个时刻 + if use_teacher_forcing: + for t in range(max_target_len): + decoder_output, decoder_hidden = decoder( + decoder_input, decoder_hidden, encoder_outputs + ) + # Teacher forcing: 下一个时刻的输入是当前正确答案 + decoder_input = target_variable[t].view(1, -1) + # 计算累计的loss + mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t]) + loss += mask_loss + print_losses.append(mask_loss.item() * nTotal) + n_totals += nTotal + else: + for t in range(max_target_len): + decoder_output, decoder_hidden = decoder( + decoder_input, decoder_hidden, encoder_outputs + ) + # 不是teacher forcing: 下一个时刻的输入是当前模型预测概率最高的值 + _, topi = decoder_output.topk(1) + decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]]) + decoder_input = decoder_input.to(device) + # 计算累计的loss + mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t]) + loss += mask_loss + print_losses.append(mask_loss.item() * nTotal) + n_totals += nTotal + # 反向计算 + loss.backward() + # 对encoder和decoder进行梯度裁剪 + _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip) + _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip) + # 更新参数 + encoder_optimizer.step() + decoder_optimizer.step() + + return sum(print_losses) / n_totals + + +def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, + embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, + print_every, save_every, clip, corpus_name, loadFilename): + # 随机选择n_iteration个batch的数据(pair) + training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)]) + for _ in range(n_iteration)] + + # 初始化 + print('Initializing ...') + start_iteration = 1 + print_loss = 0 + if loadFilename: + start_iteration = checkpoint['iteration'] + 1 + + # 训练 + print("Training...") + for iteration in range(start_iteration, n_iteration + 1): + training_batch = training_batches[iteration - 1] + + input_variable, lengths, target_variable, mask, max_target_len = training_batch + + # 训练一个batch的数据 + loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder, + decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip) + print_loss += loss + + # 进度 + if iteration % print_every == 0: + print_loss_avg = print_loss / print_every + print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}" + .format(iteration, iteration / n_iteration * 100, print_loss_avg)) + print_loss = 0 + + # 保存checkpoint + if (iteration % save_every == 0): + directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}' + .format(encoder_n_layers, decoder_n_layers, hidden_size)) + if not os.path.exists(directory): + os.makedirs(directory) + torch.save({ + 'iteration': iteration, + 'en': encoder.state_dict(), + 'de': decoder.state_dict(), + 'en_opt': encoder_optimizer.state_dict(), + 'de_opt': decoder_optimizer.state_dict(), + 'loss': loss, + 'voc_dict': voc.__dict__, + 'embedding': embedding.state_dict() + }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint'))) + + +class GreedySearchDecoder(nn.Module): + def __init__(self, encoder, decoder): + super(GreedySearchDecoder, self).__init__() + self.encoder = encoder + self.decoder = decoder + + def forward(self, input_seq, input_length, max_length): + # Encoder的Forward计算 + encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length) + # 把Encoder最后时刻的隐状态作为Decoder的初始值 + decoder_hidden = encoder_hidden[:decoder.n_layers] + # 因为我们的函数都是要求(time,batch),因此即使只有一个数据,也要做出二维的。 + # Decoder的初始输入是SOS + decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token + # 用于保存解码结果的tensor + all_tokens = torch.zeros([0], device=device, dtype=torch.long) + all_scores = torch.zeros([0], device=device) + # 循环,这里只使用长度限制,后面处理的时候把EOS去掉了。 + for _ in range(max_length): + # Decoder forward一步 + decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, + encoder_outputs) + # decoder_outputs是(batch=1, vob_size) + # 使用max返回概率最大的词和得分 + decoder_scores, decoder_input = torch.max(decoder_output, dim=1) + # 把解码结果保存到all_tokens和all_scores里 + all_tokens = torch.cat((all_tokens, decoder_input), dim=0) + all_scores = torch.cat((all_scores, decoder_scores), dim=0) + # decoder_input是当前时刻输出的词的ID,这是个一维的向量,因为max会减少一维。 + # 但是decoder要求有一个batch维度,因此用unsqueeze增加batch维度。 + decoder_input = torch.unsqueeze(decoder_input, 0) + # 返回所有的词和得分。 + return all_tokens, all_scores + + +def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH): + ### 把输入的一个batch句子变成id + indexes_batch = [indexesFromSentence(voc, sentence)] + # 创建lengths tensor + lengths = torch.tensor([len(indexes) for indexes in indexes_batch]) + # 转置 + input_batch = torch.LongTensor(indexes_batch).transpose(0, 1) + # 放到合适的设备上(比如GPU) + input_batch = input_batch.to(device) + lengths = lengths.to(device) + # 用searcher解码 + tokens, scores = searcher(input_batch, lengths, max_length) + # ID变成词。 + decoded_words = [voc.index2word[token.item()] for token in tokens] + return decoded_words + + +def evaluateInput(encoder, decoder, searcher, voc): + input_sentence = '' + while (1): + try: + # 得到用户终端的输入 + input_sentence = input('> ') + # 是否退出 + if input_sentence == 'q' or input_sentence == 'quit': break + # 句子归一化 + input_sentence = normalizeString(input_sentence) + # 生成响应Evaluate sentence + output_words = evaluate(encoder, decoder, searcher, voc, input_sentence) + # 去掉EOS后面的内容 + words = [] + for word in output_words: + if word == 'EOS': + break + elif word != 'PAD': + words.append(word) + print('Bot:', ' '.join(words)) + + except KeyError: + print("Error: Encountered unknown word.") + + +# 配置模型 +model_name = 'cb_model' +attn_model = 'dot' +# attn_model = 'general' +# attn_model = 'concat' +hidden_size = 500 +encoder_n_layers = 2 +decoder_n_layers = 2 +dropout = 0.1 +batch_size = 64 +# 从哪个checkpoint恢复,如果是None,那么从头开始训练。 +loadFilename = '.\\data\\save\\cb_model\\cornell movie-dialogs corpus\\2-2_500\\10000_checkpoint.tar' +checkpoint_iter = 4000 + +# 如果loadFilename不空,则从中加载模型 +if loadFilename: + # 如果训练和加载是一条机器,那么直接加载 + checkpoint = torch.load(loadFilename) + # 否则比如checkpoint是在GPU上得到的,但是我们现在又用CPU来训练或者测试,那么注释掉下面的代码 + # checkpoint = torch.load(loadFilename, map_location=torch.device('cpu')) + encoder_sd = checkpoint['en'] + decoder_sd = checkpoint['de'] + encoder_optimizer_sd = checkpoint['en_opt'] + decoder_optimizer_sd = checkpoint['de_opt'] + embedding_sd = checkpoint['embedding'] + voc.__dict__ = checkpoint['voc_dict'] + +print('Building encoder and decoder ...') +# 初始化word embedding +embedding = nn.Embedding(voc.num_words, hidden_size) +if loadFilename: + embedding.load_state_dict(embedding_sd) +# 初始化encoder和decoder模型 +encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout) +decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, + decoder_n_layers, dropout) +if loadFilename: + encoder.load_state_dict(encoder_sd) + decoder.load_state_dict(decoder_sd) +# 使用合适的设备 +encoder = encoder.to(device) +decoder = decoder.to(device) +print('Models built and ready to go!') + +# 配置训练的超参数和优化器 +clip = 50.0 +teacher_forcing_ratio = 1.0 +learning_rate = 0.0001 +decoder_learning_ratio = 5.0 +n_iteration = 10000 +print_every = 1 +save_every = 500 + +# 设置进入训练模式,从而开启dropout +encoder.train() +decoder.train() + +# 初始化优化器 +print('Building optimizers ...') +encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate) +decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio) +if loadFilename: + encoder_optimizer.load_state_dict(encoder_optimizer_sd) + decoder_optimizer.load_state_dict(decoder_optimizer_sd) + +# 开始训练 +print("Starting Training!") +trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, + embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, + print_every, save_every, clip, corpus_name, loadFilename) + +# 进入eval模式,从而去掉dropout。 +encoder.eval() +decoder.eval() + +# 构造searcher对象 +searcher = GreedySearchDecoder(encoder, decoder) + +# 测试 +evaluateInput(encoder, decoder, searcher, voc) + diff --git a/bitXiaoSha/voc.pkl b/bitXiaoSha/voc.pkl new file mode 100644 index 0000000000000000000000000000000000000000..0f838c6ddd4caa9a52fd86f9885e0834be0cc58b Binary files /dev/null and b/bitXiaoSha/voc.pkl differ diff --git a/bitXiaoSha/xiaosha.py b/bitXiaoSha/xiaosha.py new file mode 100644 index 0000000000000000000000000000000000000000..54fdf993d7db36d03927e666243913357ab4a4db --- /dev/null +++ b/bitXiaoSha/xiaosha.py @@ -0,0 +1,435 @@ +import itertools +import pickle +import random +import torch.nn as nn +import torch.nn.functional as F +import torch +import unicodedata +import re +import os +from torch import optim + + +# 预定义的token + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +PAD_token = 0 # 表示padding +SOS_token = 1 # 句子的开始 +EOS_token = 2 # 句子的结束 + + +class Voc: + def __init__(self, name): + self.name = name + self.trimmed = False + self.word2index = {} + self.word2count = {} + self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"} + self.num_words = 3 # 目前有SOS, EOS, PAD这3个token。 + + def addSentence(self, sentence): + for word in sentence.split(' '): + self.addWord(word) + + def addWord(self, word): + if word not in self.word2index: + self.word2index[word] = self.num_words + self.word2count[word] = 1 + self.index2word[self.num_words] = word + self.num_words += 1 + else: + self.word2count[word] += 1 + + # 删除频次小于min_count的token + def trim(self, min_count): + if self.trimmed: + return + self.trimmed = True + + keep_words = [] + + for k, v in self.word2count.items(): + if v >= min_count: + keep_words.append(k) + + print('keep_words {} / {} = {:.4f}'.format( + len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index) + )) + + # 重新构造词典 + self.word2index = {} + self.word2count = {} + self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"} + self.num_words = 3 # Count default tokens + + # 重新构造后词频就没有意义了(都是1) + for word in keep_words: + self.addWord(word) + + +MAX_LENGTH = 10 # 句子最大长度是10个词(包括EOS等特殊词) + + +# 把Unicode字符串变成ASCII +#参考https://stackoverflow.com/a/518232/2809427 +def unicodeToAscii(s): + return ''.join( + c for c in unicodedata.normalize('NFD', s) + if unicodedata.category(c) != 'Mn' + ) + + +def normalizeString(s): + # 变成小写、去掉前后空格,然后unicode变成ascii + s = unicodeToAscii(s.lower().strip()) + # 在标点前增加空格,这样把标点当成一个词 + s = re.sub(r"([.!?])", r" \1", s) + # 字母和标点之外的字符都变成空格 + s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) + # 因为把不用的字符都变成空格,所以可能存在多个连续空格 + # 下面的正则替换把多个空格变成一个空格,最后去掉前后空格 + s = re.sub(r"\s+", r" ", s).strip() + return s + + +# Load/Assemble voc and pairs +save_dir = os.path.join("data", "save") +corpus_name = "cornell movie-dialogs corpus" +corpus = os.path.join("data", corpus_name) +datafile = os.path.join(corpus, "formatted_movie_lines.txt") + + +# 把句子的词变成ID +def indexesFromSentence(voc, sentence): + return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token] + + +# l是多个长度不同句子(list),使用zip_longest padding成定长,长度为最长句子的长度。 +def zeroPadding(l, fillvalue=PAD_token): + return list(itertools.zip_longest(*l, fillvalue=fillvalue)) + + + +#读取voc和pairs +voc=Voc(corpus_name) +with open("voc.pkl", 'rb') as file: + voc = pickle.loads(file.read()) +with open("pairs.pkl", "rb") as file: + pairs = pickle.load(file) + + + +class EncoderRNN(nn.Module): + def __init__(self, hidden_size, embedding, n_layers=1, dropout=0): + super(EncoderRNN, self).__init__() + self.n_layers = n_layers + self.hidden_size = hidden_size + self.embedding = embedding + + # 初始化GRU,这里输入和hidden大小都是hidden_size,这里假设embedding层的输出大小是hidden_size + # 如果只有一层,那么不进行Dropout,否则使用传入的参数dropout进行GRU的Dropout。 + self.gru = nn.GRU(hidden_size, hidden_size, n_layers, + dropout=(0 if n_layers == 1 else dropout), bidirectional=True) + + def forward(self, input_seq, input_lengths, hidden=None): + # 输入是(max_length, batch),Embedding之后变成(max_length, batch, hidden_size) + embedded = self.embedding(input_seq) + # Pack padded batch of sequences for RNN module + # 因为RNN(GRU)要知道实际长度,所以PyTorch提供了函数pack_padded_sequence把输入向量和长度 + # pack到一个对象PackedSequence里,这样便于使用。 + packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths.cpu()) + # 通过GRU进行forward计算,需要传入输入和隐变量 + # 如果传入的输入是一个Tensor (max_length, batch, hidden_size) + # 那么输出outputs是(max_length, batch, hidden_size*num_directions)。 + # 第三维是hidden_size和num_directions的混合,它们实际排列顺序是num_directions在前面, + # 因此我们可以使用outputs.view(seq_len, batch, num_directions, hidden_size)得到4维的向量。 + # 其中第三维是方向,第四位是隐状态。 + + # 而如果输入是PackedSequence对象,那么输出outputs也是一个PackedSequence对象,我们需要用 + # 函数pad_packed_sequence把它变成shape为(max_length, batch, hidden*num_directions)的向量以及 + # 一个list,表示输出的长度,当然这个list和输入的input_lengths完全一样,因此通常我们不需要它。 + outputs, hidden = self.gru(packed, hidden) + # 参考前面的注释,我们得到outputs为(max_length, batch, hidden*num_directions) + outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) + # 我们需要把输出的num_directions双向的向量加起来 + # 因为outputs的第三维是先放前向的hidden_size个结果,然后再放后向的hidden_size个结果 + # 所以outputs[:, :, :self.hidden_size]得到前向的结果 + # outputs[:, :, self.hidden_size:]是后向的结果 + # 注意,如果bidirectional是False,则outputs第三维的大小就是hidden_size, + # 这时outputs[:, : ,self.hidden_size:]是不存在的,因此也不会加上去。 + # 对Python slicing不熟的读者可以看看下面的例子: + + # >>> a=[1,2,3] + # >>> a[:3] + # [1, 2, 3] + # >>> a[3:] + # [] + # >>> a[:3]+a[3:] + # [1, 2, 3] + + # 这样就不用写下面的代码了: + # if bidirectional: + # outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:] + outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] + # 返回最终的输出和最后时刻的隐状态。 + return outputs, hidden + + +# Luong 注意力layer +class Attn(torch.nn.Module): + def __init__(self, method, hidden_size): + super(Attn, self).__init__() + self.method = method + if self.method not in ['dot', 'general', 'concat']: + raise ValueError(self.method, "is not an appropriate attention method.") + self.hidden_size = hidden_size + if self.method == 'general': + self.attn = torch.nn.Linear(self.hidden_size, hidden_size) + elif self.method == 'concat': + self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size) + self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)) + + def dot_score(self, hidden, encoder_output): + # 输入hidden的shape是(1, batch=64, hidden_size=500) + # encoder_outputs的shape是(input_lengths=10, batch=64, hidden_size=500) + # hidden * encoder_output得到的shape是(10, 64, 500),然后对第3维求和就可以计算出score。 + return torch.sum(hidden * encoder_output, dim=2) + + def general_score(self, hidden, encoder_output): + energy = self.attn(encoder_output) + return torch.sum(hidden * energy, dim=2) + + def concat_score(self, hidden, encoder_output): + energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), + encoder_output), 2)).tanh() + return torch.sum(self.v * energy, dim=2) + + # 输入是上一个时刻的隐状态hidden和所有时刻的Encoder的输出encoder_outputs + # 输出是注意力的概率,也就是长度为input_lengths的向量,它的和加起来是1。 + def forward(self, hidden, encoder_outputs): + # 计算注意力的score,输入hidden的shape是(1, batch=64, hidden_size=500), + # 表示t时刻batch数据的隐状态 + # encoder_outputs的shape是(input_lengths=10, batch=64, hidden_size=500) + if self.method == 'general': + attn_energies = self.general_score(hidden, encoder_outputs) + elif self.method == 'concat': + attn_energies = self.concat_score(hidden, encoder_outputs) + elif self.method == 'dot': + # 计算内积,参考dot_score函数 + attn_energies = self.dot_score(hidden, encoder_outputs) + + # Transpose max_length and batch_size dimensions + # 把attn_energies从(max_length=10, batch=64)转置成(64, 10) + attn_energies = attn_energies.t() + + # 使用softmax函数把score变成概率,shape仍然是(64, 10),然后用unsqueeze(1)变成 + # (64, 1, 10) + return F.softmax(attn_energies, dim=1).unsqueeze(1) + + +class LuongAttnDecoderRNN(nn.Module): + def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1): + super(LuongAttnDecoderRNN, self).__init__() + + # 保存到self里,attn_model就是前面定义的Attn类的对象。 + self.attn_model = attn_model + self.hidden_size = hidden_size + self.output_size = output_size + self.n_layers = n_layers + self.dropout = dropout + + # 定义Decoder的layers + self.embedding = embedding + self.embedding_dropout = nn.Dropout(dropout) + self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout)) + self.concat = nn.Linear(hidden_size * 2, hidden_size) + self.out = nn.Linear(hidden_size, output_size) + + self.attn = Attn(attn_model, hidden_size) + + def forward(self, input_step, last_hidden, encoder_outputs): + # 注意:decoder每一步只能处理一个时刻的数据,因为t时刻计算完了才能计算t+1时刻。 + # input_step的shape是(1, 64),64是batch,1是当前输入的词ID(来自上一个时刻的输出) + # 通过embedding层变成(1, 64, 500),然后进行dropout,shape不变。 + embedded = self.embedding(input_step) + embedded = self.embedding_dropout(embedded) + # 把embedded传入GRU进行forward计算 + # 得到rnn_output的shape是(1, 64, 500) + # hidden是(2, 64, 500),因为是两层的GRU,所以第一维是2。 + rnn_output, hidden = self.gru(embedded, last_hidden) + # 计算注意力权重, 根据前面的分析,attn_weights的shape是(64, 1, 10) + attn_weights = self.attn(rnn_output, encoder_outputs) + + # encoder_outputs是(10, 64, 500) + # encoder_outputs.transpose(0, 1)后的shape是(64, 10, 500) + # attn_weights.bmm后是(64, 1, 500) + + # bmm是批量的矩阵乘法,第一维是batch,我们可以把attn_weights看成64个(1,10)的矩阵 + # 把encoder_outputs.transpose(0, 1)看成64个(10, 500)的矩阵 + # 那么bmm就是64个(1, 10)矩阵 x (10, 500)矩阵,最终得到(64, 1, 500) + context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) + # 把context向量和GRU的输出拼接起来 + # rnn_output从(1, 64, 500)变成(64, 500) + rnn_output = rnn_output.squeeze(0) + # context从(64, 1, 500)变成(64, 500) + context = context.squeeze(1) + # 拼接得到(64, 1000) + concat_input = torch.cat((rnn_output, context), 1) + # self.concat是一个矩阵(1000, 500), + # self.concat(concat_input)的输出是(64, 500) + # 然后用tanh把输出返回变成(-1,1),concat_output的shape是(64, 500) + concat_output = torch.tanh(self.concat(concat_input)) + + # out是(500, 词典大小=7826) + output = self.out(concat_output) + # 用softmax变成概率,表示当前时刻输出每个词的概率。 + output = F.softmax(output, dim=1) + # 返回 output和新的隐状态 + return output, hidden + + + +class GreedySearchDecoder(nn.Module): + def __init__(self, encoder, decoder): + super(GreedySearchDecoder, self).__init__() + self.encoder = encoder + self.decoder = decoder + + def forward(self, input_seq, input_length, max_length): + # Encoder的Forward计算 + encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length) + # 把Encoder最后时刻的隐状态作为Decoder的初始值 + decoder_hidden = encoder_hidden[:decoder.n_layers] + # 因为我们的函数都是要求(time,batch),因此即使只有一个数据,也要做出二维的。 + # Decoder的初始输入是SOS + decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token + # 用于保存解码结果的tensor + all_tokens = torch.zeros([0], device=device, dtype=torch.long) + all_scores = torch.zeros([0], device=device) + # 循环,这里只使用长度限制,后面处理的时候把EOS去掉了。 + for _ in range(max_length): + # Decoder forward一步 + decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, + encoder_outputs) + # decoder_outputs是(batch=1, vob_size) + # 使用max返回概率最大的词和得分 + decoder_scores, decoder_input = torch.max(decoder_output, dim=1) + # 把解码结果保存到all_tokens和all_scores里 + all_tokens = torch.cat((all_tokens, decoder_input), dim=0) + all_scores = torch.cat((all_scores, decoder_scores), dim=0) + # decoder_input是当前时刻输出的词的ID,这是个一维的向量,因为max会减少一维。 + # 但是decoder要求有一个batch维度,因此用unsqueeze增加batch维度。 + decoder_input = torch.unsqueeze(decoder_input, 0) + # 返回所有的词和得分。 + return all_tokens, all_scores + + +def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH): + ### 把输入的一个batch句子变成id + indexes_batch = [indexesFromSentence(voc, sentence)] + # 创建lengths tensor + lengths = torch.tensor([len(indexes) for indexes in indexes_batch]) + # 转置 + input_batch = torch.LongTensor(indexes_batch).transpose(0, 1) + # 放到合适的设备上(比如GPU) + input_batch = input_batch.to(device) + lengths = lengths.to(device) + # 用searcher解码 + tokens, scores = searcher(input_batch, lengths, max_length) + # ID变成词。 + decoded_words = [voc.index2word[token.item()] for token in tokens] + return decoded_words + + +# 配置模型 +model_name = 'cb_model' +attn_model = 'dot' +# attn_model = 'general' +# attn_model = 'concat' +hidden_size = 500 +encoder_n_layers = 2 +decoder_n_layers = 2 +dropout = 0.1 +batch_size = 64 +# 从哪个checkpoint恢复,如果是None,那么从头开始训练。 +loadFilename = '.\\data\\save\\cb_model\\cornell movie-dialogs corpus\\2-2_500\\10000_checkpoint.tar' +checkpoint_iter = 4000 + +# 如果loadFilename不空,则从中加载模型 +if loadFilename: + # 如果训练和加载是一条机器,那么直接加载 + checkpoint = torch.load(loadFilename) + # 否则比如checkpoint是在GPU上得到的,但是我们现在又用CPU来训练或者测试,那么注释掉下面的代码 + # checkpoint = torch.load(loadFilename, map_location=torch.device('cpu')) + encoder_sd = checkpoint['en'] + decoder_sd = checkpoint['de'] + encoder_optimizer_sd = checkpoint['en_opt'] + decoder_optimizer_sd = checkpoint['de_opt'] + embedding_sd = checkpoint['embedding'] + voc.__dict__ = checkpoint['voc_dict'] + +# 初始化word embedding +embedding = nn.Embedding(voc.num_words, hidden_size) +if loadFilename: + embedding.load_state_dict(embedding_sd) +# 初始化encoder和decoder模型 +encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout) +decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, + decoder_n_layers, dropout) +if loadFilename: + encoder.load_state_dict(encoder_sd) + decoder.load_state_dict(decoder_sd) +# 使用合适的设备 +encoder = encoder.to(device) +decoder = decoder.to(device) + +# 配置训练的超参数和优化器 +learning_rate = 0.0001 +decoder_learning_ratio = 5.0 + +# 设置进入训练模式,从而开启dropout +encoder.train() +decoder.train() + +# 初始化优化器 +encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate) +decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio) +if loadFilename: + encoder_optimizer.load_state_dict(encoder_optimizer_sd) + decoder_optimizer.load_state_dict(decoder_optimizer_sd) + + +# 进入eval模式,从而去掉dropout。 +encoder.eval() +decoder.eval() + +# 构造searcher对象 +searcher = GreedySearchDecoder(encoder, decoder) + + +def silly_AI(str, encoder=encoder, decoder=decoder, searcher=searcher, voc=voc): + try: + # 得到用户终端的输入 + input_sentence = str + # 是否退出 + if input_sentence == 'q' or input_sentence == 'quit': return + # 句子归一化 + input_sentence = normalizeString(input_sentence) + # 生成响应Evaluate sentence + output_words = evaluate(encoder, decoder, searcher, voc, input_sentence) + # 去掉EOS后面的内容 + words = [] + for word in output_words: + if word == 'EOS': + break + elif word != 'PAD': + words.append(word) + return ' '.join(words) + + except KeyError: + return "Error: Encountered unknown word." + +question= 'Why?'.encode('utf-8').decode('utf-8') +result = silly_AI(question) +print(result) \ No newline at end of file diff --git a/gitTreeFunction/gitTreeFunction.pro b/gitTreeFunction/gitTreeFunction.pro new file mode 100644 index 0000000000000000000000000000000000000000..b695efb6826519086829e7609a364b4a071f96ac --- /dev/null +++ b/gitTreeFunction/gitTreeFunction.pro @@ -0,0 +1,31 @@ +QT += core gui network + +greaterThan(QT_MAJOR_VERSION, 4): QT += widgets + +CONFIG += c++11 + +# The following define makes your compiler emit warnings if you use +# any Qt feature that has been marked deprecated (the exact warnings +# depend on your compiler). Please consult the documentation of the +# deprecated API in order to know how to port your code away from it. +DEFINES += QT_DEPRECATED_WARNINGS + +# You can also make your code fail to compile if it uses deprecated APIs. +# In order to do so, uncomment the following line. +# You can also select to disable deprecated APIs only up to a certain version of Qt. +#DEFINES += QT_DISABLE_DEPRECATED_BEFORE=0x060000 # disables all the APIs deprecated before Qt 6.0.0 + +SOURCES += \ + main.cpp \ + widget.cpp + +HEADERS += \ + widget.h + +FORMS += \ + widget.ui + +# Default rules for deployment. +qnx: target.path = /tmp/$${TARGET}/bin +else: unix:!android: target.path = /opt/$${TARGET}/bin +!isEmpty(target.path): INSTALLS += target diff --git a/gitTreeFunction/main.cpp b/gitTreeFunction/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b0a4ec26478f6b9aba3e1747ec464ea0c26dd5b9 --- /dev/null +++ b/gitTreeFunction/main.cpp @@ -0,0 +1,11 @@ +#include "widget.h" + +#include + +int main(int argc, char *argv[]) +{ + QApplication a(argc, argv); + Widget w; + w.show(); + return a.exec(); +} diff --git a/gitTreeFunction/widget.cpp b/gitTreeFunction/widget.cpp new file mode 100644 index 0000000000000000000000000000000000000000..92ea075d467e3dc1a07f8901cefd0f86e31831d4 --- /dev/null +++ b/gitTreeFunction/widget.cpp @@ -0,0 +1,75 @@ +#include "widget.h" +#include "ui_widget.h" +#include +#include +#include +#include +#include +#include + +Widget::Widget(QWidget *parent) + : QWidget(parent) + , ui(new Ui::Widget) +{ + ui->setupUi(this); + + connect (ui->listWidget, &QListWidget::itemDoubleClicked, this, &Widget::showContents); +} + +Widget::~Widget() +{ + delete ui; + +} + + +void Widget::on_pushButton_clicked() +{ + QByteArray file = get(ui->lineEdit->text ()); + + QJsonDocument newjson = QJsonDocument::fromJson(file); + QJsonObject jsonObject = newjson.object (); + + arrayValue = jsonObject.value(QStringLiteral("tree")); + int length = 0; + if(arrayValue.isArray ()) + { + QJsonArray array = arrayValue.toArray(); + length = array.size (); + for(int i=0;ilistWidget->addItem (icon["path"].toString ()); + } + } + +} + +void Widget::showContents(QListWidgetItem * nowItem) +{ + + + if(arrayValue.isArray ()) + { + QJsonArray array = arrayValue.toArray(); + for(int i=0;itext ()) + { + QByteArray content = get(icon["url"].toString ()); + QJsonDocument newjson = QJsonDocument::fromJson(content); + QJsonObject jsonObject = newjson.object (); + ui->textEdit->setText ( QByteArray::fromBase64(jsonObject["content"].toString ().toUtf8 ())); + + } + + } + } +} + +// git commit try at: https://gitee.com/api/v5/repos/jasonliu233/gitstudy/commits/master +// git tree at: https://gitee.com/api/v5/repos/jasonliu233/zhishigongcheng/git/trees/master + diff --git a/gitTreeFunction/widget.h b/gitTreeFunction/widget.h new file mode 100644 index 0000000000000000000000000000000000000000..bc327382bfef6e1cf07702ce6f4975b17634c97f --- /dev/null +++ b/gitTreeFunction/widget.h @@ -0,0 +1,46 @@ +#ifndef WIDGET_H +#define WIDGET_H + +#include +#include +#include +#include +#include +#include + +QT_BEGIN_NAMESPACE +namespace Ui { class Widget; } +QT_END_NAMESPACE + +class Widget : public QWidget +{ + Q_OBJECT + +public: + Widget(QWidget *parent = nullptr); + ~Widget(); + QByteArray get(const QString &str_url){ + + const QUrl url = QUrl::fromUserInput(str_url); + QNetworkRequest qnr(url); + QNetworkAccessManager qnam; + QNetworkReply *reply = qnam.get(qnr); + QEventLoop eventloop; + QObject::connect(reply, &QNetworkReply::finished, &eventloop, &QEventLoop::quit); + eventloop.exec(QEventLoop::ExcludeUserInputEvents); + QByteArray reply_data = reply->readAll(); + reply->deleteLater(); + reply = nullptr; + return reply_data; + } + + +private slots: + void on_pushButton_clicked(); + void showContents(QListWidgetItem*); + +private: + Ui::Widget *ui; + QJsonValue arrayValue; +}; +#endif // WIDGET_H diff --git a/gitTreeFunction/widget.ui b/gitTreeFunction/widget.ui new file mode 100644 index 0000000000000000000000000000000000000000..fe2c8385505f5ba8d77bb902c663ae5a4a981a3b --- /dev/null +++ b/gitTreeFunction/widget.ui @@ -0,0 +1,101 @@ + + + Widget + + + + 0 + 0 + 1200 + 900 + + + + Widget + + + + + 120 + 20 + 921 + 61 + + + + + + + 1060 + 20 + 111 + 61 + + + + get! + + + + + + 20 + 20 + 91 + 51 + + + + URL + + + + + + 30 + 140 + 361 + 741 + + + + + + + 30 + 90 + 131 + 51 + + + + filename + + + + + + 440 + 140 + 741 + 741 + + + + + + + 440 + 80 + 181 + 61 + + + + content + + + + + + diff --git a/git_try_0827/git_try_0827.pro b/git_try_0827/git_try_0827.pro index 7b47cc2e7a1b366e46267db11c08bd2f5f91106b..b695efb6826519086829e7609a364b4a071f96ac 100644 --- a/git_try_0827/git_try_0827.pro +++ b/git_try_0827/git_try_0827.pro @@ -1,31 +1,31 @@ -QT += core gui network - -greaterThan(QT_MAJOR_VERSION, 4): QT += widgets - -CONFIG += c++11 - -# The following define makes your compiler emit warnings if you use -# any Qt feature that has been marked deprecated (the exact warnings -# depend on your compiler). Please consult the documentation of the -# deprecated API in order to know how to port your code away from it. -DEFINES += QT_DEPRECATED_WARNINGS - -# You can also make your code fail to compile if it uses deprecated APIs. -# In order to do so, uncomment the following line. -# You can also select to disable deprecated APIs only up to a certain version of Qt. -#DEFINES += QT_DISABLE_DEPRECATED_BEFORE=0x060000 # disables all the APIs deprecated before Qt 6.0.0 - -SOURCES += \ - main.cpp \ - widget.cpp - -HEADERS += \ - widget.h - -FORMS += \ - widget.ui - -# Default rules for deployment. -qnx: target.path = /tmp/$${TARGET}/bin -else: unix:!android: target.path = /opt/$${TARGET}/bin -!isEmpty(target.path): INSTALLS += target +QT += core gui network + +greaterThan(QT_MAJOR_VERSION, 4): QT += widgets + +CONFIG += c++11 + +# The following define makes your compiler emit warnings if you use +# any Qt feature that has been marked deprecated (the exact warnings +# depend on your compiler). Please consult the documentation of the +# deprecated API in order to know how to port your code away from it. +DEFINES += QT_DEPRECATED_WARNINGS + +# You can also make your code fail to compile if it uses deprecated APIs. +# In order to do so, uncomment the following line. +# You can also select to disable deprecated APIs only up to a certain version of Qt. +#DEFINES += QT_DISABLE_DEPRECATED_BEFORE=0x060000 # disables all the APIs deprecated before Qt 6.0.0 + +SOURCES += \ + main.cpp \ + widget.cpp + +HEADERS += \ + widget.h + +FORMS += \ + widget.ui + +# Default rules for deployment. +qnx: target.path = /tmp/$${TARGET}/bin +else: unix:!android: target.path = /opt/$${TARGET}/bin +!isEmpty(target.path): INSTALLS += target diff --git a/git_try_0827/main.cpp b/git_try_0827/main.cpp index c3efeb46c6278c0805a7e7764d78bd6bf01eea76..b0a4ec26478f6b9aba3e1747ec464ea0c26dd5b9 100644 --- a/git_try_0827/main.cpp +++ b/git_try_0827/main.cpp @@ -1,11 +1,11 @@ -#include "widget.h" - -#include - -int main(int argc, char *argv[]) -{ - QApplication a(argc, argv); - Widget w; - w.show(); - return a.exec(); -} +#include "widget.h" + +#include + +int main(int argc, char *argv[]) +{ + QApplication a(argc, argv); + Widget w; + w.show(); + return a.exec(); +} diff --git a/git_try_0827/widget.cpp b/git_try_0827/widget.cpp index 4f66be1f864c81c70313ea8256cdbc4dcd22b577..63b8f3513e019244c9fff21b6aac64a4f5575a27 100644 --- a/git_try_0827/widget.cpp +++ b/git_try_0827/widget.cpp @@ -1,28 +1,28 @@ -#include "widget.h" -#include "ui_widget.h" -#include -#include - -Widget::Widget(QWidget *parent) - : QWidget(parent) - , ui(new Ui::Widget) -{ - ui->setupUi(this); -} - -Widget::~Widget() -{ - delete ui; -} - - -void Widget::on_btnGet_clicked() -{ - QByteArray file = get(ui->lineEdit->text ()); - ui->textEdit->setText (file); - - QJsonDocument newjson = QJsonDocument::fromJson(file); - QJsonObject jsonObject = newjson.object (); - qDebug() << jsonObject["commit"].toObject ()["message"]; - -} +#include "widget.h" +#include "ui_widget.h" +#include +#include + +Widget::Widget(QWidget *parent) + : QWidget(parent) + , ui(new Ui::Widget) +{ + ui->setupUi(this); +} + +Widget::~Widget() +{ + delete ui; +} + + +void Widget::on_btnGet_clicked() +{ + QByteArray file = get(ui->lineEdit->text ()); + ui->textEdit->setText (file); + + QJsonDocument newjson = QJsonDocument::fromJson(file); + QJsonObject jsonObject = newjson.object (); + qDebug() << jsonObject["commit"].toObject ()["message"]; + +} diff --git a/git_try_0827/widget.h b/git_try_0827/widget.h index a4f2a7dd23873934777dcd02237f93b582d14982..31fa66c1b02b4e5bb9634ca26efb986725619158 100644 --- a/git_try_0827/widget.h +++ b/git_try_0827/widget.h @@ -1,41 +1,41 @@ -#ifndef WIDGET_H -#define WIDGET_H - -#include -#include -#include -#include - -QT_BEGIN_NAMESPACE -namespace Ui { class Widget; } -QT_END_NAMESPACE - -class Widget : public QWidget -{ - Q_OBJECT - -public: - Widget(QWidget *parent = nullptr); - ~Widget(); - QByteArray get(const QString &str_url){ - - const QUrl url = QUrl::fromUserInput(str_url); - QNetworkRequest qnr(url); - QNetworkAccessManager qnam; - QNetworkReply *reply = qnam.get(qnr); - QEventLoop eventloop; - QObject::connect(reply, &QNetworkReply::finished, &eventloop, &QEventLoop::quit); - eventloop.exec(QEventLoop::ExcludeUserInputEvents); - QByteArray reply_data = reply->readAll(); - reply->deleteLater(); - reply = nullptr; - return reply_data; - } - -private slots: - void on_btnGet_clicked(); - -private: - Ui::Widget *ui; -}; -#endif // WIDGET_H +#ifndef WIDGET_H +#define WIDGET_H + +#include +#include +#include +#include + +QT_BEGIN_NAMESPACE +namespace Ui { class Widget; } +QT_END_NAMESPACE + +class Widget : public QWidget +{ + Q_OBJECT + +public: + Widget(QWidget *parent = nullptr); + ~Widget(); + QByteArray get(const QString &str_url){ + + const QUrl url = QUrl::fromUserInput(str_url); + QNetworkRequest qnr(url); + QNetworkAccessManager qnam; + QNetworkReply *reply = qnam.get(qnr); + QEventLoop eventloop; + QObject::connect(reply, &QNetworkReply::finished, &eventloop, &QEventLoop::quit); + eventloop.exec(QEventLoop::ExcludeUserInputEvents); + QByteArray reply_data = reply->readAll(); + reply->deleteLater(); + reply = nullptr; + return reply_data; + } + +private slots: + void on_btnGet_clicked(); + +private: + Ui::Widget *ui; +}; +#endif // WIDGET_H diff --git a/git_try_0827/widget.ui b/git_try_0827/widget.ui index 4efb67538faa7a2b8ffb4a1486dc7d0705e3c955..43a5c186e46a3de24824f746fa72ca24a5d38622 100644 --- a/git_try_0827/widget.ui +++ b/git_try_0827/widget.ui @@ -1,78 +1,78 @@ - - - Widget - - - - 0 - 0 - 1200 - 900 - - - - Widget - - - - - 110 - 30 - 951 - 71 - - - - - - - 40 - 20 - 151 - 81 - - - - URL - - - - - - 1080 - 30 - 111 - 71 - - - - get! - - - - - - 30 - 170 - 1151 - 701 - - - - - - - 40 - 110 - 171 - 61 - - - - JSON file - - - - - - + + + Widget + + + + 0 + 0 + 1200 + 900 + + + + Widget + + + + + 110 + 30 + 951 + 71 + + + + + + + 40 + 20 + 151 + 81 + + + + URL + + + + + + 1080 + 30 + 111 + 71 + + + + get! + + + + + + 30 + 170 + 1151 + 701 + + + + + + + 40 + 110 + 171 + 61 + + + + JSON file + + + + + +