ref: 43866763e399e223a6e9b6f16e0c09bb16be19fa
parent: 15e68cc285cf082696ab68faa16f4662f50306c1
author: cinap_lenrek <cinap_lenrek@felloff.net>
date: Mon Apr 3 21:59:45 EDT 2017
tlshand: sync with 9front
--- a/libsec/tlshand.c
+++ b/libsec/tlshand.c
@@ -17,7 +17,6 @@
TLSFinishedLen = 12,
SSL3FinishedLen = MD5dlen+SHA1dlen,
MaxKeyData = 160, // amount of secret we may need
- MaxChunk = 1<<15,
MAXdlen = SHA2_512dlen,
RandomSize = 32,
MasterSecretSize = 48,
@@ -100,13 +99,8 @@
HandshakeHash handhash;
Finished finished;
- // input buffer for handshake messages
- uchar recvbuf[MaxChunk];
- uchar *rp, *ep;
-
- // output buffer
- uchar sendbuf[MaxChunk];
- uchar *sendp;
+ uchar *sendp, *recvp, *recvw; // sendp = end of queued output; recvp = next unread input byte; recvw = end of received input
+ uchar buf[1<<16]; // shared buffer: output grows up from buf[0], input is kept at the top end
} TlsConnection;
typedef struct Msg{
@@ -443,7 +437,7 @@
static int get16(uchar *p);
static Bytes* newbytes(int len);
static Bytes* makebytes(uchar* buf, int len);
-static Bytes* mptobytes(mpint* big);
+static Bytes* mptobytes(mpint* big, int len); // len = output size in bytes (0 is treated as 1)
static mpint* bytestomp(Bytes* bytes);
static void freebytes(Bytes* b);
static Ints* newints(int len);
@@ -695,6 +689,8 @@
c->hand = hand;
c->trace = trace;
c->version = ProtocolVersion;
+ c->sendp = c->buf; // output is assembled from the start of buf
+ c->recvp = c->recvw = &c->buf[sizeof(c->buf)]; // no input buffered yet: empty region at the end of buf
memset(&m, 0, sizeof(m));
if(!msgRecv(c, &m)){
@@ -894,6 +890,7 @@
DHstate *dh = &sec->dh;
mpint *G, *P, *Y, *K;
Bytes *Yc;
+ int n;
if(p == nil || g == nil || Ys == nil)
return nil;
@@ -906,7 +903,8 @@
if(dh_new(dh, P, nil, G) == nil)
goto Out;
- Yc = mptobytes(dh->y);
+ n = (mpsignif(P)+7)/8; // byte length of the prime P
+ Yc = mptobytes(dh->y, n); // our public value is sent at the prime's full length
K = dh_finish(dh, Y); /* zeros dh */
if(K == nil){
freebytes(Yc);
@@ -913,7 +911,7 @@
Yc = nil;
goto Out;
}
- setMasterSecret(sec, mptobytes(K));
+ setMasterSecret(sec, mptobytes(K, (mpsignif(K)+7)/8)); // RFC 5246 8.1.2: leading zero bytes of Z are stripped, unlike ECDH
Out:
mpfree(K);
@@ -933,6 +931,7 @@
ECpub *pub;
ECpoint K;
Bytes *Yc;
+ int n;
if(Ys == nil)
return nil;
@@ -958,8 +957,10 @@
ecgen(dom, Q);
ecmul(dom, pub, Q->d, &K);
- setMasterSecret(sec, mptobytes(K.x));
- Yc = newbytes(1 + 2*((mpsignif(dom->p)+7)/8));
+
+ n = (mpsignif(dom->p)+7)/8; // byte length of a field element
+ setMasterSecret(sec, mptobytes(K.x, n)); // x coordinate encoded at the full field length
+ Yc = newbytes(1 + 2*n); // room for 1 + two field elements; ecencodepub sets the actual length
Yc->len = ecencodepub(dom, (ECpub*)Q, Yc->data, Yc->len);
mpfree(K.x);
@@ -993,6 +994,8 @@
c->hand = hand;
c->trace = trace;
c->cert = nil;
+ c->sendp = c->buf; // output is assembled from the start of buf
+ c->recvp = c->recvw = &c->buf[sizeof(c->buf)]; // no input buffered yet: empty region at the end of buf
c->version = ProtocolVersion;
tlsSecInitc(c->sec, c->version);
@@ -1256,14 +1259,13 @@
static int
msgSend(TlsConnection *c, Msg *m, int act)
{
- uchar *p; // sendp = start of new message; p = write pointer
- int nn, n, i;
+ uchar *p, *e; // sendp = start of new message; p = write pointer; e = write limit (start of unread input)
+ int n, i;
- if(c->sendp == nil)
- c->sendp = c->sendbuf;
p = c->sendp;
+ e = c->recvp; // output must not run into buffered input
if(c->trace)
- c->trace("send %s", msgPrint((char*)p, (sizeof(c->sendbuf)) - (p - c->sendbuf), m));
+ c->trace("send %s", msgPrint((char*)p, e - p, m)); // trace text uses the free space; the message overwrites it below
p[0] = m->tag; // header - fill in size later
p += 4;
@@ -1273,119 +1275,111 @@
tlsError(c, EInternalError, "can't encode a %d", m->tag);
goto Err;
case HClientHello:
- // version
- put16(p, m->u.clientHello.version);
- p += 2;
-
- // random
+ if(p+2+RandomSize > e) // version (2) + random
+ goto Overflow;
+ put16(p, m->u.clientHello.version), p += 2;
memmove(p, m->u.clientHello.random, RandomSize);
p += RandomSize;
- // sid
- n = m->u.clientHello.sid->len;
- p[0] = n;
- memmove(p+1, m->u.clientHello.sid->data, n);
- p += n+1;
+ if(p+1+(n = m->u.clientHello.sid->len) > e) // 1-byte length + session id
+ goto Overflow;
+ *p++ = n;
+ memmove(p, m->u.clientHello.sid->data, n);
+ p += n;
- n = m->u.clientHello.ciphers->len;
- put16(p, n*2);
- p += 2;
- for(i=0; i<n; i++) {
- put16(p, m->u.clientHello.ciphers->data[i]);
- p += 2;
- }
+ if(p+2+2*(n = m->u.clientHello.ciphers->len) > e) // 2-byte length + n cipher suites of 2 bytes each
+ goto Overflow;
+ put16(p, n*2), p += 2;
+ for(i=0; i<n; i++)
+ put16(p, m->u.clientHello.ciphers->data[i]), p += 2;
- n = m->u.clientHello.compressors->len;
- p[0] = n;
- memmove(p+1, m->u.clientHello.compressors->data, n);
- p += n+1;
+ if(p+1+(n = m->u.clientHello.compressors->len) > e) // 1-byte count + compression methods
+ goto Overflow;
+ *p++ = n;
+ memmove(p, m->u.clientHello.compressors->data, n);
+ p += n;
- if(m->u.clientHello.extensions == nil)
+ if(m->u.clientHello.extensions == nil
+ || (n = m->u.clientHello.extensions->len) == 0) // extensions are omitted entirely when empty
break;
- n = m->u.clientHello.extensions->len;
- if(n == 0)
- break;
- put16(p, n);
- memmove(p+2, m->u.clientHello.extensions->data, n);
- p += n+2;
+ if(p+2+n > e) // 2-byte length + extensions
+ goto Overflow;
+ put16(p, n), p += 2;
+ memmove(p, m->u.clientHello.extensions->data, n);
+ p += n;
break;
case HServerHello:
- put16(p, m->u.serverHello.version);
- p += 2;
-
- // random
+ if(p+2+RandomSize > e) // version (2) + random
+ goto Overflow;
+ put16(p, m->u.serverHello.version), p += 2;
memmove(p, m->u.serverHello.random, RandomSize);
p += RandomSize;
- // sid
- n = m->u.serverHello.sid->len;
- p[0] = n;
- memmove(p+1, m->u.serverHello.sid->data, n);
- p += n+1;
+ if(p+1+(n = m->u.serverHello.sid->len) > e) // 1-byte length + session id
+ goto Overflow;
+ *p++ = n;
+ memmove(p, m->u.serverHello.sid->data, n);
+ p += n;
- put16(p, m->u.serverHello.cipher);
- p += 2;
- p[0] = m->u.serverHello.compressor;
- p += 1;
+ if(p+2+1 > e) // cipher (2) + compressor (1)
+ goto Overflow;
+ put16(p, m->u.serverHello.cipher), p += 2;
+ *p++ = m->u.serverHello.compressor;
- if(m->u.serverHello.extensions == nil)
- break;
- n = m->u.serverHello.extensions->len;
- if(n == 0)
+ if(m->u.serverHello.extensions == nil
+ || (n = m->u.serverHello.extensions->len) == 0) // extensions are omitted entirely when empty
break;
- put16(p, n);
- memmove(p+2, m->u.serverHello.extensions->data, n);
- p += n+2;
+ if(p+2+n > e) // 2-byte length + extensions
+ goto Overflow;
+ put16(p, n), p += 2;
+ memmove(p, m->u.serverHello.extensions->data, n);
+ p += n;
break;
case HServerHelloDone:
break;
case HCertificate:
- nn = 0;
+ n = 0; // total size of the certificate list
for(i = 0; i < m->u.certificate.ncert; i++)
- nn += 3 + m->u.certificate.certs[i]->len;
- if(p + 3 + nn - c->sendbuf > sizeof(c->sendbuf)) {
- tlsError(c, EInternalError, "output buffer too small for certificate");
- goto Err;
- }
- put24(p, nn);
- p += 3;
+ n += 3 + m->u.certificate.certs[i]->len; // 3-byte length + data per certificate
+ if(p+3+n > e) // 3-byte list length + the whole list
+ goto Overflow;
+ put24(p, n), p += 3;
for(i = 0; i < m->u.certificate.ncert; i++){
- put24(p, m->u.certificate.certs[i]->len);
- p += 3;
- memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
- p += m->u.certificate.certs[i]->len;
+ n = m->u.certificate.certs[i]->len; // n is reused for each certificate's length
+ put24(p, n), p += 3;
+ memmove(p, m->u.certificate.certs[i]->data, n);
+ p += n;
}
break;
case HCertificateVerify:
- if(m->u.certificateVerify.sigalg != 0){
- put16(p, m->u.certificateVerify.sigalg);
- p += 2;
- }
- put16(p, m->u.certificateVerify.signature->len);
- p += 2;
- memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len);
- p += m->u.certificateVerify.signature->len;
+ if(p+2+2+(n = m->u.certificateVerify.signature->len) > e) // room for optional sigalg (2), length (2) and signature
+ goto Overflow;
+ if(m->u.certificateVerify.sigalg != 0)
+ put16(p, m->u.certificateVerify.sigalg), p += 2;
+ put16(p, n), p += 2;
+ memmove(p, m->u.certificateVerify.signature->data, n);
+ p += n;
break;
case HServerKeyExchange:
if(m->u.serverKeyExchange.pskid != nil){
- n = m->u.serverKeyExchange.pskid->len;
- put16(p, n);
- p += 2;
+ if(p+2+(n = m->u.serverKeyExchange.pskid->len) > e) // 2-byte length + psk identity
+ goto Overflow;
+ put16(p, n), p += 2;
memmove(p, m->u.serverKeyExchange.pskid->data, n);
p += n;
}
if(m->u.serverKeyExchange.dh_parameters == nil)
break;
- n = m->u.serverKeyExchange.dh_parameters->len;
+ if(p+(n = m->u.serverKeyExchange.dh_parameters->len) > e) // parameters are copied as-is, no length prefix added here
+ goto Overflow;
memmove(p, m->u.serverKeyExchange.dh_parameters->data, n);
p += n;
if(m->u.serverKeyExchange.dh_signature == nil)
break;
- if(c->version >= TLS12Version){
- put16(p, m->u.serverKeyExchange.sigalg);
- p += 2;
- }
- n = m->u.serverKeyExchange.dh_signature->len;
+ if(p+2+2+(n = m->u.serverKeyExchange.dh_signature->len) > e) // sigalg (2, TLS 1.2+ only) + length (2) + signature
+ goto Overflow;
+ if(c->version >= TLS12Version)
+ put16(p, m->u.serverKeyExchange.sigalg), p += 2;
put16(p, n), p += 2;
memmove(p, m->u.serverKeyExchange.dh_signature->data, n);
p += n;
@@ -1392,15 +1386,16 @@
break;
case HClientKeyExchange:
if(m->u.clientKeyExchange.pskid != nil){
- n = m->u.clientKeyExchange.pskid->len;
- put16(p, n);
- p += 2;
+ if(p+2+(n = m->u.clientKeyExchange.pskid->len) > e) // 2-byte length + psk identity
+ goto Overflow;
+ put16(p, n), p += 2;
memmove(p, m->u.clientKeyExchange.pskid->data, n);
p += n;
}
if(m->u.clientKeyExchange.key == nil)
break;
- n = m->u.clientKeyExchange.key->len;
+ if(p+2+(n = m->u.clientKeyExchange.key->len) > e) // worst case: 2-byte length prefix + key
+ goto Overflow;
if(isECDHE(c->cipher))
*p++ = n;
else if(isDHE(c->cipher) || c->version != SSL3Version)
@@ -1409,6 +1404,8 @@
p += n;
break;
case HFinished:
+ if(p+m->u.finished.n > e) // verify data is written without a length prefix
+ goto Overflow;
memmove(p, m->u.finished.verify, m->u.finished.n);
p += m->u.finished.n;
break;
@@ -1416,7 +1413,6 @@
// go back and fill in size
n = p - c->sendp;
- assert(n <= sizeof(c->sendbuf));
put24(c->sendp+1, n-4);
// remember hash of Handshake messages
@@ -1425,8 +1421,8 @@
c->sendp = p;
if(act == AFlush){
- c->sendp = c->sendbuf;
- if(write(c->hand, c->sendbuf, p - c->sendbuf) < 0){
+ c->sendp = c->buf; // flushing writes everything queued since buf[0] and empties the queue
+ if(write(c->hand, c->buf, p - c->buf) < 0){
fprint(2, "write error: %r\n");
goto Err;
}
@@ -1433,6 +1429,8 @@
}
msgClear(m);
return 1;
+Overflow: // a case found that the encoded message would pass e; falls through to Err
+ tlsError(c, EInternalError, "not enough send buffer for message (%d)", m->tag);
Err:
msgClear(m);
return 0;
@@ -1441,25 +1439,28 @@
static uchar*
tlsReadN(TlsConnection *c, int n)
{
- uchar *p;
- int nn, nr;
+ uchar *p, *e; // p = start of the returned n bytes; e = end of buf
- nn = c->ep - c->rp;
- if(nn < n){
- if(c->rp != c->recvbuf){
- memmove(c->recvbuf, c->rp, nn);
- c->rp = c->recvbuf;
- c->ep = &c->recvbuf[nn];
- }
- for(; nn < n; nn += nr) {
- nr = read(c->hand, &c->rp[nn], n - nn);
- if(nr <= 0)
- return nil;
- c->ep += nr;
- }
+ p = c->recvp;
+ if(n <= c->recvw - p){ // fast path: all n bytes are already buffered
+ c->recvp += n;
+ return p;
}
- p = c->rp;
- c->rp += n;
+ e = &c->buf[sizeof(c->buf)];
+ if(n > e - c->sendp){ // must fit above queued output; checked before forming e - n so no out-of-array pointer is made
+ tlsError(c, EDecodeError, "handshake message too long %d", n);
+ return nil;
+ }
+ c->recvp = e - n;
+ memmove(c->recvp, p, c->recvw - p); // slide the partial input so the complete message will end at e
+ c->recvw -= p - c->recvp;
+ p = c->recvp;
+ c->recvp += n;
+ while(c->recvw < c->recvp){ // read until all n bytes have arrived
+ if((n = read(c->hand, c->recvw, e - c->recvw)) <= 0)
+ return nil;
+ c->recvw += n;
+ }
return p;
}
@@ -1485,11 +1486,6 @@
}
}
- if(n > sizeof(c->recvbuf)) {
- tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->recvbuf));
- return 0;
- }
-
if(type == HSSL2ClientHello){
/* Cope with an SSL3 ClientHello expressed in SSL2 record format.
This is sent by some clients that we must interoperate
@@ -1512,10 +1508,8 @@
p += 6;
n -= 6;
if(nsid != 0 /* no sid's, since shouldn't restart using ssl2 header */
- || nrandom < 16 || nn % 3)
+ || nrandom < 16 || nn % 3 || n - nrandom < nn) // reject when nn + nrandom exceeds the n remaining bytes
goto Err;
- if(c->trace && (n - nrandom != nn))
- c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
/* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
nciph = 0;
for(i = 0; i < nn; i += 3)
@@ -1805,15 +1799,11 @@
break;
}
- if(type != HClientHello && type != HServerHello && n != 0)
+ if(n != 0 && type != HClientHello && type != HServerHello)
goto Short;
Ok:
- if(c->trace){
- char *buf;
- buf = emalloc(8000);
- c->trace("recv %s", msgPrint(buf, 8000, m));
- free(buf);
- }
+ if(c->trace)
+ c->trace("recv %s", msgPrint((char*)c->sendp, c->recvp - c->sendp, m)); // trace text goes in buf between queued output (sendp) and the next unread input (recvp)
return 1;
Short:
tlsError(c, EDecodeError, "handshake message (%d) has invalid length", type);
@@ -2623,8 +2613,9 @@
K.y = mpnew(0);
ecmul(dom, Y, Q->d, &K);
- setMasterSecret(sec, mptobytes(K.x));
+ setMasterSecret(sec, mptobytes(K.x, (mpsignif(dom->p)+7)/8)); // x coordinate encoded at the field's byte length
+
mpfree(K.x);
mpfree(K.y);
@@ -2857,7 +2848,7 @@
y = factotum_rsa_decrypt(sec->rpc, bytestomp(data));
if(y == nil)
return nil;
- data = mptobytes(y);
+ data = mptobytes(y, (mpsignif(y)+7)/8); // minimal byte length of y (same sizing as the old mptobytes)
if((data->len = pkcs1unpadbuf(data->data, data->len, sec->rsapub->n, 2)) < 0){
freebytes(data);
return nil;
@@ -2883,10 +2874,11 @@
werrstr("bad digest algorithm");
return nil;
}
+
signedMP = factotum_rsa_decrypt(sec->rpc, pkcs1padbuf(buf, digestlen, sec->rsapub->n, 1));
if(signedMP == nil)
return nil;
- signature = mptobytes(signedMP);
+ signature = mptobytes(signedMP, (mpsignif(sec->rsapub->n)+7)/8); // signature sized to the modulus byte length
mpfree(signedMP);
return signature;
}
@@ -2998,14 +2990,12 @@
* Convert mpint* to Bytes, putting high order byte first.
*/
static Bytes*
-mptobytes(mpint* big)
+mptobytes(mpint *big, int len) // len = number of output bytes, chosen by the caller
{
Bytes* ans;
- int n;
- n = (mpsignif(big)+7)/8;
- if(n == 0) n = 1;
- ans = newbytes(n);
+ if(len == 0) len++; // never allocate an empty Bytes
+ ans = newbytes(len);
mptober(big, ans->data, ans->len);
return ans;
}