diff --git a/connection.c b/connection.c index b70658a78e3bbc03aa278084d4e696ffe1d74182..112a56fabaef5b9e9ba03447b8cbc5952a7970e7 100644 --- a/connection.c +++ b/connection.c @@ -66,6 +66,8 @@ * at a time */ #define STMT_INCREMENT 16 +#define MAX_PARTS 16 +#define EXTRA_ROOM 128 /* Extra room for string additional splicing */ #define MAX_CN 128 /* the maximum number of CN is 128 */ #define EMPTY 0 @@ -104,7 +106,8 @@ CnEntry orig_entry; CnEntry pgxc_entry; static int LIBPQ_connect(ConnectionClass *self); - +int split_host_or_port_with_limit(const char *str, char *result_array[MAX_PARTS], int* total_length); +char* generate_conninfo_URL_by_ConnInfo(ConnInfo* ci, int* host_number, int* port_number); #ifdef WIN32 DWORD WINAPI read_pgxc_node(LPVOID arg) #else @@ -3322,6 +3325,88 @@ static void CC_getOSUser(char *username, int usernameLen) #define PROTOCOL3_OPTS_MAX 30 +int split_host_or_port_with_limit(const char *str, char *result_array[MAX_PARTS], int* total_length) +{ + int count = 0; + char *lasts = NULL; + char *dup_str = strdup(str); + + char *token = strtok_r(dup_str, ",", &lasts); + while (token && count < MAX_PARTS) { + result_array[count] = strdup(token); + if (!result_array[count]) { + for (int i = 0; i < count; i++) { + free(result_array[i]); + } + free(dup_str); + return -1; + } + *total_length += strlen(result_array[count]); + count++; + token = strtok_r(NULL, ",", &lasts); + } + + free(dup_str); + for (int i = count;i < MAX_PARTS;i++) { + result_array[i] = NULL; + } + return count; +} + +char* generate_conninfo_URL_by_ConnInfo(ConnInfo* ci, int* host_number, int* port_number) +{ + char *host_array[MAX_PARTS] = {0}; + char *port_array[MAX_PARTS] = {0}; + int host_length = 0; + int port_length = 0; + if (!ci->server || !ci->port) { + return NULL; + } + *host_number = split_host_or_port_with_limit(ci->server, host_array, &host_length); + *port_number = split_host_or_port_with_limit(ci->port, port_array, &port_length); + + if ((-1 == *host_number || -1 == *port_number) || + ((*host_number != *port_number) && *host_number != 1 && *port_number != 1)) { + for (int i = 0; i < MAX_PARTS; i++) { + if (host_array[i] != NULL) { + free(host_array[i]); + } + if (port_array[i] != NULL) { + free(port_array[i]); + } + } + return NULL; + } + + size_t total_length = host_length + port_length + *host_number * 2 + *port_number * 2 + + (ci->username ? strlen(ci->username) : 0) + + (ci->password.name ? strlen(ci->password.name) : 0) + + (ci->database ? strlen(ci->database) : 0) + EXTRA_ROOM; + char* temp_URL = (char*)malloc(total_length); + if (!temp_URL) { + return NULL; + } + memset(temp_URL, 0, total_length); + int valid_count = *host_number > *port_number ? *host_number : *port_number; + + (void)snprintf(temp_URL, -1, "postgres://%s@", ci->username); + for (int i = 0; i < valid_count; i++) { + strcat(temp_URL, (*host_number == 1) ? host_array[0] : host_array[i]); + strcat(temp_URL, ":"); + strcat(temp_URL, (*port_number == 1) ? port_array[0] : port_array[i]); + if (i != valid_count - 1) { + strcat(temp_URL, ","); + } + } + strcat(temp_URL, "/"); + strcat(temp_URL, ci->database); + (*host_number == 1 && *port_number == 1) ? + strcat(temp_URL, "?target_session_attrs=any") : strcat(temp_URL, "?target_session_attrs=read-write"); + strcat(temp_URL, "&password="); + strcat(temp_URL, ci->password.name); + return temp_URL; +} + static int LIBPQ_connect(ConnectionClass *self) { @@ -3339,6 +3424,9 @@ LIBPQ_connect(ConnectionClass *self) char keepalive_interval_str[20]; char *errmsg = NULL; char local_conninfo[8192]; + int host_number = 1; + int port_number = 1; + char* URL = NULL; MYLOG(0, "connecting to the database using %s as the server and pqopt={%s}\n", self->connInfo.server, SAFE_NAME(ci->pqopt)); @@ -3353,145 +3441,145 @@ LIBPQ_connect(ConnectionClass *self) CC_set_error(self, CONN_OPENDB_ERROR, emsg, func); goto cleanup; } - /* Build arrays of keywords & values, for PQconnectDBParams */ - cnt = 0; - if (ci->server[0]) - { - opts[cnt] = "host"; vals[cnt++] = ci->server; - } - if (ci->port[0]) - { - opts[cnt] = "port"; vals[cnt++] = ci->port; - } - if (ci->database[0]) - { - opts[cnt] = "dbname"; vals[cnt++] = ci->database; - } - if (ci->username[0]) - { - opts[cnt] = "user"; vals[cnt++] = ci->username; - } - switch (ci->sslmode[0]) - { - case '\0': - break; - case SSLLBYTE_VERIFY: - opts[cnt] = "sslmode"; - switch (ci->sslmode[1]) - { - case 'f': - vals[cnt++] = SSLMODE_VERIFY_FULL; - break; - case 'c': - vals[cnt++] = SSLMODE_VERIFY_CA; - break; - default: - vals[cnt++] = ci->sslmode; - } - break; - default: - opts[cnt] = "sslmode"; - vals[cnt++] = ci->sslmode; - } - if (NAME_IS_VALID(ci->password)) - { - opts[cnt] = "password"; vals[cnt++] = SAFE_NAME(ci->password); - } - if (ci->disable_keepalive) - { - opts[cnt] = "keepalives"; vals[cnt++] = "0"; - } - if (self->login_timeout > 0) - { - SPRINTF_FIXED(login_timeout_str, "%u", (unsigned int) self->login_timeout); - opts[cnt] = "connect_timeout"; vals[cnt++] = login_timeout_str; - } - if (self->connInfo.keepalive_idle > 0) - { - ITOA_FIXED(keepalive_idle_str, self->connInfo.keepalive_idle); - opts[cnt] = "keepalives_idle"; vals[cnt++] = keepalive_idle_str; - } - if (self->connInfo.keepalive_interval > 0) - { - ITOA_FIXED(keepalive_interval_str, self->connInfo.keepalive_interval); - opts[cnt] = "keepalives_interval"; vals[cnt++] = keepalive_interval_str; - } - if ((odbcVersionString != NULL) && (odbcVersionString[0] != '\0')) - { - if (self->connInfo.connection_extra_info > 0) - { - char libpath[4096] = {'\0'}; - char username[128] = {'\0'}; - - (void)CC_getLibpath(libpath, sizeof(libpath)); - (void)CC_getOSUser(username, sizeof(username)); - - snprintf(local_conninfo, sizeof(local_conninfo), - "{\"driver_name\":\"ODBC\",\"driver_version\":\"%s\",\"driver_path\":\"%s\",\"os_user\":\"%s\"}", - odbcVersionString, libpath, username); - } - else - { - snprintf(local_conninfo, sizeof(local_conninfo), - "{\"driver_name\":\"ODBC\",\"driver_version\":\"%s\"}", - odbcVersionString); - } - opts[cnt] = "connection_info"; vals[cnt++] = local_conninfo; - } - - opts[cnt] = "target_session_attrs"; - vals[cnt++] = "primary"; - - if (conninfoOption != NULL) - { - const char *keyword, *val; - int j; - - for (i = 0, pqopt = conninfoOption; (keyword = pqopt->keyword) != NULL; i++, pqopt++) - { - if ((val = pqopt->val) != NULL) - { - for (j = 0; j < cnt; j++) - { - if (stricmp(opts[j], keyword) == 0) - { - char emsg[100]; - - if (vals[j] != NULL && strcmp(vals[j], val) == 0) - continue; - SPRINTF_FIXED(emsg, "%s parameter in pqopt option conflicts with other ordinary option", keyword); - CC_set_error(self, CONN_OPENDB_ERROR, emsg, func); - goto cleanup; - } - } - if (j >= cnt && cnt < PROTOCOL3_OPTS_MAX - 1) - { - opts[cnt] = keyword; vals[cnt++] = val; - } - } - } - } - - opts[cnt] = vals[cnt] = NULL; - /* Ok, we're all set to connect */ - - if (get_qlog() > 0 || get_mylog() > 0) - { - const char **popt, **pval; - const char* pwdKey = "password"; - - QLOG(0, "PQconnectdbParams:"); - - for (popt = opts, pval = vals; *popt; popt++, pval++) { - if (strcmp(pwdKey, *popt) == 0) { - QPRINTF(0, " %s='xxxxx'", *popt); - } else { - QPRINTF(0, " %s='%s'", *popt, *pval); + /* multiple_hostip or multiple_port from DSN */ + URL = generate_conninfo_URL_by_ConnInfo(ci, &host_number, &port_number); + if (!URL) { + if (!ci->server || !ci->port) { + CC_set_error(self, CONN_INVALID_ARGUMENT_NO, "The server or port should not be empty.", func); + } else if (-1 == host_number || -1 == port_number) { + CC_set_error(self, CONN_NO_MEMORY_ERROR, "Memory allocation failure when resolving address.", func); + } else if ((host_number != port_number) && host_number != 1 && port_number != 1) { + CC_set_error(self, CONN_VALUE_OUT_OF_RANGE, + "The number of hosts should be the same as the number of ports when both are multiple.", func); + } else { + CC_set_error(self, CONN_NO_MEMORY_ERROR, + "Memory allocation failure when Trying to splice strings.", func); + } + } + if (host_number > 1 || port_number > 1) { + MYLOG(0, "connecting to the database using URL: %s\n", URL); + pqconn = PQconnectdb(URL); + } else { + /* Build arrays of keywords & values, for PQconnectDBParams */ + cnt = 0; + if (ci->server[0]) { + opts[cnt] = "host"; + vals[cnt++] = ci->server; + } + if (ci->port[0]) { + opts[cnt] = "port"; + vals[cnt++] = ci->port; + } + if (ci->database[0]) { + opts[cnt] = "dbname"; + vals[cnt++] = ci->database; + } + if (ci->username[0]) { + opts[cnt] = "user"; + vals[cnt++] = ci->username; + } + switch (ci->sslmode[0]) { + case '\0': + break; + case SSLLBYTE_VERIFY: + opts[cnt] = "sslmode"; + switch (ci->sslmode[1]) { + case 'f': + vals[cnt++] = SSLMODE_VERIFY_FULL; + break; + case 'c': + vals[cnt++] = SSLMODE_VERIFY_CA; + break; + default: + vals[cnt++] = ci->sslmode; + } + break; + default: + opts[cnt] = "sslmode"; + vals[cnt++] = ci->sslmode; + } + if (NAME_IS_VALID(ci->password)) { + opts[cnt] = "password"; + vals[cnt++] = SAFE_NAME(ci->password); + } + if (ci->disable_keepalive) { + opts[cnt] = "keepalives"; + vals[cnt++] = "0"; + } + if (self->login_timeout > 0) { + SPRINTF_FIXED(login_timeout_str, "%u", (unsigned int) self->login_timeout); + opts[cnt] = "connect_timeout"; + vals[cnt++] = login_timeout_str; + } + if (self->connInfo.keepalive_idle > 0) { + ITOA_FIXED(keepalive_idle_str, self->connInfo.keepalive_idle); + opts[cnt] = "keepalives_idle"; + vals[cnt++] = keepalive_idle_str; + } + if (self->connInfo.keepalive_interval > 0) { + ITOA_FIXED(keepalive_interval_str, self->connInfo.keepalive_interval); + opts[cnt] = "keepalives_interval"; + vals[cnt++] = keepalive_interval_str; + } + if ((odbcVersionString != NULL) && (odbcVersionString[0] != '\0')) { + if (self->connInfo.connection_extra_info > 0) { + char libpath[4096] = {'\0'}; + char username[128] = {'\0'}; + + (void)CC_getLibpath(libpath, sizeof(libpath)); + (void)CC_getOSUser(username, sizeof(username)); + + snprintf(local_conninfo, sizeof(local_conninfo), + "{\"driver_name\":\"ODBC\", \"driver_version\":\"%s\", \"driver_path\":\"%s\", \"os_user\":\"%s\"}", + odbcVersionString, libpath, username); + } else { + snprintf(local_conninfo, sizeof(local_conninfo), + "{\"driver_name\":\"ODBC\",\"driver_version\":\"%s\"}", + odbcVersionString); + } + opts[cnt] = "connection_info"; + vals[cnt++] = local_conninfo; + } + if (conninfoOption != NULL) { + const char *keyword, *val; + int j; + + for (i = 0, pqopt = conninfoOption; (keyword = pqopt->keyword) != NULL; i++, pqopt++) { + if ((val = pqopt->val) != NULL) { + for (j = 0; j < cnt; j++) { + if (stricmp(opts[j], keyword) == 0) { + char emsg[100]; + if (vals[j] != NULL && strcmp(vals[j], val) == 0) { + continue; + } + SPRINTF_FIXED(emsg, + "%s parameter in pqopt option conflicts with other ordinary option", keyword); + CC_set_error(self, CONN_OPENDB_ERROR, emsg, func); + goto cleanup; + } + } + if (j >= cnt && cnt < PROTOCOL3_OPTS_MAX - 1) { + opts[cnt] = keyword; + vals[cnt++] = val; } } - QPRINTF(0, "\n"); - } - pqconn = PQconnectdbParams(opts, vals, FALSE); + } + } + opts[cnt] = vals[cnt] = NULL; + /* Ok, we're all set to connect */ + + if (get_qlog() > 0 || get_mylog() > 0) { + const char **popt, **pval; + QLOG(0, "PQconnectdbParams:"); + for (popt = opts, pval = vals; *popt; popt++, pval++) { + QPRINTF(0, " %s='%s'", *popt, *pval); + } + QPRINTF(0, "\n"); + } + pqconn = PQconnectdbParams(opts, vals, FALSE); + } + free(URL); + if (!pqconn) { CC_set_error(self, CONN_OPENDB_ERROR, "PQconnectdb error", func);