Commit fa68280d authored by Ian Craggs's avatar Ian Craggs

Fix buffering of connect packet

parent bc11ecc0
......@@ -47,8 +47,8 @@
#define URI_TCP "tcp://"
#define BUILD_TIMESTAMP "##MQTTCLIENT_BUILD_TAG##"
#define CLIENT_VERSION "##MQTTCLIENT_VERSION_TAG##"
#define BUILD_TIMESTAMP "201403121114"
#define CLIENT_VERSION "1.0.0.2"
char* client_timestamp_eye = "MQTTAsyncV3_Timestamp " BUILD_TIMESTAMP;
char* client_version_eye = "MQTTAsyncV3_Version " CLIENT_VERSION;
......@@ -249,7 +249,7 @@ typedef struct
int serverURIcount;
char** serverURIs;
int currentURI;
int MQTTVersion;
int MQTTVersion;
} conn;
} details;
} MQTTAsync_command;
......@@ -2527,7 +2527,7 @@ int MQTTAsync_connecting(MQTTAsyncs* m)
goto exit;
}
else if (rc == 1)
{
{
rc = MQTTCLIENT_SUCCESS;
m->c->connect_state = 3;
if (MQTTPacket_send_connect(m->c, m->connect.details.conn.MQTTVersion) == SOCKET_ERROR)
......@@ -2535,7 +2535,7 @@ int MQTTAsync_connecting(MQTTAsyncs* m)
rc = SOCKET_ERROR;
goto exit;
}
if(!m->c->cleansession && m->c->session == NULL)
if (!m->c->cleansession && m->c->session == NULL)
m->c->session = SSL_get1_session(m->c->net.ssl);
}
}
......@@ -2570,7 +2570,9 @@ int MQTTAsync_connecting(MQTTAsyncs* m)
#endif
exit:
if ((rc != 0 && m->c->connect_state != 2) || (rc == SSL_FATAL))
if (rc == TCPSOCKET_INTERRUPTED)
printf("Interrupted connect encountered****\n");
if ((rc != 0 && rc != TCPSOCKET_INTERRUPTED && m->c->connect_state != 2) || (rc == SSL_FATAL))
{
if (MQTTAsync_checkConn(&m->connect))
{
......@@ -2649,7 +2651,7 @@ MQTTPacket* MQTTAsync_cycle(int* sock, unsigned long timeout, int* rc)
pack = MQTTPacket_Factory(&m->c->net, rc);
if ((m->c->connect_state == 3) && (*rc == SOCKET_ERROR))
{
Log(TRACE_MINIMUM, -1, "CONNECT sent but MQTTPacket_Factory has returned SOCKET_ERROR");
Log(LOG_ERROR, -1, "CONNECT sent but MQTTPacket_Factory has returned SOCKET_ERROR");
if (MQTTAsync_checkConn(&m->connect))
{
MQTTAsync_queuedCommand* conn;
......
......@@ -168,7 +168,7 @@ exit:
* @param buflen the length of the data in buffer to be written
* @return the completion code (TCPSOCKET_COMPLETE etc)
*/
int MQTTPacket_send(networkHandles* net, Header header, char* buffer, int buflen)
int MQTTPacket_send(networkHandles* net, Header header, char* buffer, int buflen, int free)
{
int rc, buf0len;
char *buf;
......@@ -189,10 +189,10 @@ int MQTTPacket_send(networkHandles* net, Header header, char* buffer, int buflen
#if defined(OPENSSL)
if (net->ssl)
rc = SSLSocket_putdatas(net->ssl, net->socket, buf, buf0len, 1, &buffer, &buflen);
rc = SSLSocket_putdatas(net->ssl, net->socket, buf, buf0len, 1, &buffer, &buflen, &free);
else
#endif
rc = Socket_putdatas(net->socket, buf, buf0len, 1, &buffer, &buflen);
rc = Socket_putdatas(net->socket, buf, buf0len, 1, &buffer, &buflen, &free);
if (rc == TCPSOCKET_COMPLETE)
time(&(net->lastContact));
......@@ -214,7 +214,7 @@ int MQTTPacket_send(networkHandles* net, Header header, char* buffer, int buflen
* @param buflens the lengths of the data in the array of buffers to be written
* @return the completion code (TCPSOCKET_COMPLETE etc)
*/
int MQTTPacket_sends(networkHandles* net, Header header, int count, char** buffers, int* buflens)
int MQTTPacket_sends(networkHandles* net, Header header, int count, char** buffers, int* buflens, int* frees)
{
int i, rc, buf0len, total = 0;
char *buf;
......@@ -236,10 +236,10 @@ int MQTTPacket_sends(networkHandles* net, Header header, int count, char** buffe
#endif
#if defined(OPENSSL)
if (net->ssl)
rc = SSLSocket_putdatas(net->ssl, net->socket, buf, buf0len, count, buffers, buflens);
rc = SSLSocket_putdatas(net->ssl, net->socket, buf, buf0len, count, buffers, buflens, frees);
else
#endif
rc = Socket_putdatas(net->socket, buf, buf0len, count, buffers, buflens);
rc = Socket_putdatas(net->socket, buf, buf0len, count, buffers, buflens, frees);
if (rc == TCPSOCKET_COMPLETE)
time(&(net->lastContact));
......@@ -460,7 +460,7 @@ int MQTTPacket_send_disconnect(networkHandles *net, char* clientID)
FUNC_ENTRY;
header.byte = 0;
header.bits.type = DISCONNECT;
rc = MQTTPacket_send(net, header, NULL, 0);
rc = MQTTPacket_send(net, header, NULL, 0, 0);
Log(LOG_PROTOCOL, 28, NULL, net->socket, clientID, rc);
FUNC_EXIT_RC(rc);
return rc;
......@@ -536,7 +536,7 @@ int MQTTPacket_send_ack(int type, int msgid, int dup, networkHandles *net)
if (type == PUBREL)
header.bits.qos = 1;
writeInt(&ptr, msgid);
if ((rc = MQTTPacket_send(net, header, buf, 2)) != TCPSOCKET_INTERRUPTED)
if ((rc = MQTTPacket_send(net, header, buf, 2, 1)) != TCPSOCKET_INTERRUPTED)
free(buf);
FUNC_EXIT_RC(rc);
return rc;
......@@ -683,10 +683,12 @@ int MQTTPacket_send_publish(Publish* pack, int dup, int qos, int retained, netwo
char *ptr = buf;
char* bufs[4] = {topiclen, pack->topic, buf, pack->payload};
int lens[4] = {2, strlen(pack->topic), 2, pack->payloadlen};
int frees[4] = {1, 0, 1, 0};
writeInt(&ptr, pack->msgId);
ptr = topiclen;
writeInt(&ptr, lens[1]);
rc = MQTTPacket_sends(net, header, 4, bufs, lens);
rc = MQTTPacket_sends(net, header, 4, bufs, lens, frees);
if (rc != TCPSOCKET_INTERRUPTED)
free(buf);
}
......@@ -695,8 +697,10 @@ int MQTTPacket_send_publish(Publish* pack, int dup, int qos, int retained, netwo
char* ptr = topiclen;
char* bufs[3] = {topiclen, pack->topic, pack->payload};
int lens[3] = {2, strlen(pack->topic), pack->payloadlen};
int frees[3] = {1, 0, 0};
writeInt(&ptr, lens[1]);
rc = MQTTPacket_sends(net, header, 3, bufs, lens);
rc = MQTTPacket_sends(net, header, 3, bufs, lens, frees);
}
if (rc != TCPSOCKET_INTERRUPTED)
free(topiclen);
......
......@@ -210,8 +210,8 @@ void writeUTF(char** pptr, char* string);
char* MQTTPacket_name(int ptype);
void* MQTTPacket_Factory(networkHandles* net, int* error);
int MQTTPacket_send(networkHandles* net, Header header, char* buffer, int buflen);
int MQTTPacket_sends(networkHandles* net, Header header, int count, char** buffers, int* buflens);
int MQTTPacket_send(networkHandles* net, Header header, char* buffer, int buflen, int free);
int MQTTPacket_sends(networkHandles* net, Header header, int count, char** buffers, int* buflens, int* frees);
void* MQTTPacket_header_only(unsigned char aHeader, char* data, int datalen);
int MQTTPacket_send_disconnect(networkHandles* net, char* clientID);
......
......@@ -99,10 +99,11 @@ int MQTTPacket_send_connect(Clients* client, int MQTTVersion)
if (client->password)
writeUTF(&ptr, client->password);
rc = MQTTPacket_send(&client->net, packet.header, buf, len);
rc = MQTTPacket_send(&client->net, packet.header, buf, len, 1);
Log(LOG_PROTOCOL, 0, NULL, client->net.socket, client->clientID, client->cleansession, rc);
exit:
free(buf);
if (rc != TCPSOCKET_INTERRUPTED)
free(buf);
FUNC_EXIT_RC(rc);
return rc;
}
......@@ -143,7 +144,7 @@ int MQTTPacket_send_pingreq(networkHandles* net, char* clientID)
FUNC_ENTRY;
header.byte = 0;
header.bits.type = PINGREQ;
rc = MQTTPacket_send(net, header, NULL, 0);
rc = MQTTPacket_send(net, header, NULL, 0, 0);
Log(LOG_PROTOCOL, 20, NULL, net->socket, clientID, rc);
FUNC_EXIT_RC(rc);
return rc;
......@@ -187,9 +188,10 @@ int MQTTPacket_send_subscribe(List* topics, List* qoss, int msgid, int dup, netw
writeUTF(&ptr, (char*)(elem->content));
writeChar(&ptr, *(int*)(qosElem->content));
}
rc = MQTTPacket_send(net, header, data, datalen);
rc = MQTTPacket_send(net, header, data, datalen, 1);
Log(LOG_PROTOCOL, 22, NULL, net->socket, clientID, msgid, rc);
free(data);
if (rc != TCPSOCKET_INTERRUPTED)
free(data);
FUNC_EXIT_RC(rc);
return rc;
}
......@@ -255,9 +257,10 @@ int MQTTPacket_send_unsubscribe(List* topics, int msgid, int dup, networkHandles
elem = NULL;
while (ListNextElement(topics, &elem))
writeUTF(&ptr, (char*)(elem->content));
rc = MQTTPacket_send(net, header, data, datalen);
rc = MQTTPacket_send(net, header, data, datalen, 1);
Log(LOG_PROTOCOL, 25, NULL, net->socket, clientID, msgid, rc);
free(data);
if (rc != TCPSOCKET_INTERRUPTED)
free(data);
FUNC_EXIT_RC(rc);
return rc;
}
......@@ -703,7 +703,7 @@ int SSLSocket_close(networkHandles* net)
/* No SSL_writev() provided by OpenSSL. Boo. */
int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, int buf0len, int count, char** buffers, int* buflens)
int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, int buf0len, int count, char** buffers, int* buflens, int* frees)
{
int rc = 0;
int i;
......@@ -735,22 +735,33 @@ int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, int buf0len, int count,
if (sslerror == SSL_ERROR_WANT_WRITE)
{
int* sockmem = (int*)malloc(sizeof(int));
int free = 1;
Log(TRACE_MIN, -1, "Partial write: incomplete write of %d bytes on SSL socket %d",
iovec.iov_len, socket);
SocketBuffer_pendingWrite(socket, ssl, 1, &iovec, iovec.iov_len, 0);
SocketBuffer_pendingWrite(socket, ssl, 1, &iovec, &free, iovec.iov_len, 0);
*sockmem = socket;
ListAppend(s.write_pending, sockmem, sizeof(int));
FD_SET(socket, &(s.pending_wset));
rc = TCPSOCKET_INTERRUPTED;
iovec.iov_base = NULL; /* don't free it because it hasn't been completely written yet */
//iovec.iov_base = NULL; /* don't free it because it hasn't been completely written yet */
}
else
rc = SOCKET_ERROR;
}
SSL_unlock_mutex(&sslCoreMutex);
if (iovec.iov_base)
if (rc != TCPSOCKET_INTERRUPTED)
free(iovec.iov_base);
if (rc == TCPSOCKET_INTERRUPTED)
{
int i;
for (i = 0; i < count; ++i)
{
if (frees[i])
free(buffers[i]);
}
}
FUNC_EXIT_RC(rc);
return rc;
}
......@@ -795,12 +806,6 @@ int SSLSocket_continueWrite(pending_writes* pw)
{
/* topic and payload buffers are freed elsewhere, when all references to them have been removed */
free(pw->iovecs[0].iov_base);
if (pw->count > 1)
{
free(pw->iovecs[1].iov_base);
if (pw->count == 5)
free(pw->iovecs[3].iov_base);
}
Log(TRACE_MIN, -1, "SSL continueWrite: partial write now complete for socket %d", pw->socket);
rc = 1;
}
......
......@@ -37,7 +37,7 @@ int SSLSocket_getch(SSL* ssl, int socket, char* c);
char *SSLSocket_getdata(SSL* ssl, int socket, int bytes, int* actual_len);
int SSLSocket_close(networkHandles* net);
int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, int buf0len, int count, char** buffers, int* buflens);
int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, int buf0len, int count, char** buffers, int* buflens, int* frees);
int SSLSocket_connect(SSL* ssl, int socket);
int SSLSocket_getPendingRead();
......
......@@ -433,10 +433,11 @@ int Socket_writev(int socket, iobuf* iovecs, int count, unsigned long* bytes)
* @param buflens an array of corresponding buffer lengths
* @return completion code, especially TCPSOCKET_INTERRUPTED
*/
int Socket_putdatas(int socket, char* buf0, int buf0len, int count, char** buffers, int* buflens)
int Socket_putdatas(int socket, char* buf0, int buf0len, int count, char** buffers, int* buflens, int* frees)
{
unsigned long bytes = 0L;
iobuf iovecs[5];
int frees1[5];
int rc = TCPSOCKET_INTERRUPTED, i, total = buf0len;
FUNC_ENTRY;
......@@ -452,10 +453,12 @@ int Socket_putdatas(int socket, char* buf0, int buf0len, int count, char** buffe
iovecs[0].iov_base = buf0;
iovecs[0].iov_len = buf0len;
frees1[0] = 1;
for (i = 0; i < count; i++)
{
iovecs[i+1].iov_base = buffers[i];
iovecs[i+1].iov_len = buflens[i];
frees1[i+1] = frees[i];
}
if ((rc = Socket_writev(socket, iovecs, count+1, &bytes)) != SOCKET_ERROR)
......@@ -468,9 +471,9 @@ int Socket_putdatas(int socket, char* buf0, int buf0len, int count, char** buffe
Log(TRACE_MIN, -1, "Partial write: %ld bytes of %d actually written on socket %d",
bytes, total, socket);
#if defined(OPENSSL)
SocketBuffer_pendingWrite(socket, NULL, count+1, iovecs, total, bytes);
SocketBuffer_pendingWrite(socket, NULL, count+1, iovecs, frees1, total, bytes);
#else
SocketBuffer_pendingWrite(socket, count+1, iovecs, total, bytes);
SocketBuffer_pendingWrite(socket, count+1, iovecs, frees1, total, bytes);
#endif
*sockmem = socket;
ListAppend(s.write_pending, sockmem, sizeof(int));
......@@ -740,10 +743,11 @@ int Socket_continueWrite(int socket)
pw->bytes += bytes;
if ((rc = (pw->bytes == pw->total)))
{ /* topic and payload buffers are freed elsewhere, when all references to them have been removed */
free(pw->iovecs[0].iov_base);
free(pw->iovecs[1].iov_base);
if (pw->count == 5)
free(pw->iovecs[3].iov_base);
for (i = 0; i < pw->count; i++)
{
if (pw->frees[i])
free(pw->iovecs[i].iov_base);
}
Log(TRACE_MIN, -1, "ContinueWrite: partial write now complete for socket %d", socket);
}
else
......
......@@ -112,7 +112,7 @@ void Socket_outTerminate(void);
int Socket_getReadySocket(int more_work, struct timeval *tp);
int Socket_getch(int socket, char* c);
char *Socket_getdata(int socket, int bytes, int* actual_len);
int Socket_putdatas(int socket, char* buf0, int buf0len, int count, char** buffers, int* buflens);
int Socket_putdatas(int socket, char* buf0, int buf0len, int count, char** buffers, int* buflens, int* frees);
void Socket_close(int socket);
int Socket_new(char* addr, int port, int* socket);
......
......@@ -302,9 +302,9 @@ void SocketBuffer_queueChar(int socket, char c)
* @param bytes actual data length that was written
*/
#if defined(OPENSSL)
void SocketBuffer_pendingWrite(int socket, SSL* ssl, int count, iobuf* iovecs, int total, int bytes)
void SocketBuffer_pendingWrite(int socket, SSL* ssl, int count, iobuf* iovecs, int* frees, int total, int bytes)
#else
void SocketBuffer_pendingWrite(int socket, int count, iobuf* iovecs, int total, int bytes)
void SocketBuffer_pendingWrite(int socket, int count, iobuf* iovecs, int* frees, int total, int bytes)
#endif
{
int i = 0;
......@@ -321,7 +321,10 @@ void SocketBuffer_pendingWrite(int socket, int count, iobuf* iovecs, int total,
pw->total = total;
pw->count = count;
for (i = 0; i < count; i++)
{
pw->iovecs[i] = iovecs[i];
pw->frees[i] = frees[i];
}
ListAppend(&writes, pw, sizeof(pw) + total);
FUNC_EXIT;
}
......
......@@ -52,6 +52,7 @@ typedef struct
#endif
unsigned long bytes;
iobuf iovecs[5];
int frees[5];
} pending_writes;
#define SOCKETBUFFER_COMPLETE 0
......@@ -70,9 +71,9 @@ char* SocketBuffer_complete(int socket);
void SocketBuffer_queueChar(int socket, char c);
#if defined(OPENSSL)
void SocketBuffer_pendingWrite(int socket, SSL* ssl, int count, iobuf* iovecs, int total, int bytes);
void SocketBuffer_pendingWrite(int socket, SSL* ssl, int count, iobuf* iovecs, int* frees, int total, int bytes);
#else
void SocketBuffer_pendingWrite(int socket, int count, iobuf* iovecs, int total, int bytes);
void SocketBuffer_pendingWrite(int socket, int count, iobuf* iovecs, int* frees, int total, int bytes);
#endif
pending_writes* SocketBuffer_getWrite(int socket);
int SocketBuffer_writeComplete(int socket);
......
......@@ -627,9 +627,9 @@ void test3_onFailure(void* context, MQTTAsync_failureData* response)
MQTTAsync_responseOptions opts = MQTTAsync_responseOptions_initializer;
int rc;
assert("Should have connected", 0, "failed to connect", NULL);
MyLog(LOGA_DEBUG, "In connect onFailure callback, \"%s\" rc %d\n", cd->clientid, response->code);
if (response->message)
assert("Should have connected", 0, "%s failed to connect\n", cd->clientid);
MyLog(LOGA_DEBUG, "In connect onFailure callback, \"%s\" rc %d\n", cd->clientid, response ? response->code : -999);
if (response && response->message)
MyLog(LOGA_DEBUG, "In connect onFailure callback, \"%s\"\n", response->message);
test_finished++;
......@@ -1144,7 +1144,7 @@ int main(int argc, char** argv)
for (options.test_no = 1; options.test_no < ARRAY_SIZE(tests); ++options.test_no)
{
failures = 0;
//MQTTAsync_setTraceLevel(MQTTASYNC_TRACE_ERROR);
MQTTAsync_setTraceLevel(MQTTASYNC_TRACE_ERROR);
rc += tests[options.test_no](options); /* return number of failures. 0 = test succeeded */
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment