shithub: riscv

Download patch

ref: c8008e1ffd0bdcf8aa1152aa9a81157ef1688926
parent: 6c68876db6d25b8c646295fecc75a6363d0bdc75
author: cinap_lenrek <cinap_lenrek@felloff.net>
date: Sat Sep 13 22:30:46 EDT 2014

libsec: experimental DHE client support for tls and cleanups

--- a/sys/src/libsec/port/tlshand.c
+++ b/sys/src/libsec/port/tlshand.c
@@ -117,6 +117,12 @@
 			Bytes *key;
 		} clientKeyExchange;
 		struct {
+			Bytes *dh_p;	/* DH modulus */
+			Bytes *dh_g;	/* DH generator */
+			Bytes *dh_Ys;	/* server's DH public value */
+			Bytes *dh_signature;	/* server's signature over the DH parameters, if sent */
+		} serverKeyExchange;
+		struct {
 			Bytes *signature;
 		} certificateVerify;		
 		Finished finished;
@@ -251,6 +257,8 @@
 	{"3des_ede_cbc","sha1",	2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
 	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
 	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA},
+	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_128_CBC_SHA},	/* DHE: skipped by the server-side cipher choice */
+	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_256_CBC_SHA},	/* DHE: skipped by the server-side cipher choice */
 };
 
 static uchar compressors[] = {
@@ -275,20 +283,21 @@
 static int initCiphers(void);
 static Ints* makeciphers(void);
 
-static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
-static int	tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
+static TlsSec*	tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
+static int	tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm);	/* sets the master secret only; caller runs setSecrets */
 static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
-static int	tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
+static Bytes*	tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers);	/* encrypted premaster secret, nil on error */
+static Bytes*	tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys);	/* client DH public value, nil on error */
 static int	tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
 static void	tlsSecOk(TlsSec *sec);
 static void	tlsSecKill(TlsSec *sec);
 static void	tlsSecClose(TlsSec *sec);
 static void	setMasterSecret(TlsSec *sec, Bytes *pm);
-static void	serverMasterSecret(TlsSec *sec, uchar *epm, int nepm);
+static void	serverMasterSecret(TlsSec *sec, Bytes *epm);
 static void	setSecrets(TlsSec *sec, uchar *kd, int nkd);
-static int	clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
-static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
-static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
+static Bytes*	clientMasterSecret(TlsSec *sec, RSApub *pub);
+static Bytes*	pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
+static Bytes*	pkcs1_decrypt(TlsSec *sec, Bytes *cipher);
 static void	tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
 static void	sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
 static void	sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
@@ -310,6 +319,7 @@
 static Bytes* newbytes(int len);
 static Bytes* makebytes(uchar* buf, int len);
 static Bytes* mptobytes(mpint* big);
+static mpint* bytestomp(Bytes* bytes);
 static void freebytes(Bytes* b);
 static Ints* newints(int len);
 static Ints* makeints(int* buf, int len);
@@ -565,10 +575,11 @@
 		tlsError(c, EUnexpectedMessage, "expected a client key exchange");
 		goto Err;
 	}
-	if(tlsSecSecrets(c->sec, c->version, m.u.clientKeyExchange.key->data, m.u.clientKeyExchange.key->len, kd, c->nsecret) < 0){
+	if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){
 		tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
 		goto Err;
 	}
+	setSecrets(c->sec, kd, c->nsecret);	/* fill kd here; tlsSecRSAs does not */
 	if(trace)
 		trace("tls secrets\n");
 	secrets = (char*)emalloc(2*c->nsecret);
@@ -629,15 +640,172 @@
 	return 0;
 }
 
+/*
+ * returns 1 if tlsid names an ephemeral Diffie-Hellman
+ * (DHE_DSS or DHE_RSA) cipher suite, 0 otherwise.
+ */
+static int
+isDHE(int tlsid)
+{
+	switch(tlsid){
+	case TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA:
+	case TLS_DHE_DSS_WITH_DES_CBC_SHA:
+	case TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA:
+	case TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA:
+	case TLS_DHE_RSA_WITH_DES_CBC_SHA:
+	case TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA:
+	case TLS_DHE_DSS_WITH_AES_128_CBC_SHA:
+	case TLS_DHE_RSA_WITH_AES_128_CBC_SHA:
+	case TLS_DHE_DSS_WITH_AES_256_CBC_SHA:
+	case TLS_DHE_RSA_WITH_AES_256_CBC_SHA:
+		return 1;
+	}
+	return 0;
+}
+
+/*
+ * client side of the DHE key exchange: make a DH key pair
+ * for the server's p and g, set the master secret from the
+ * value shared with Ys and return our public value to be sent
+ * in the ClientKeyExchange. does not authenticate p, g or Ys;
+ * the caller must check them with verifyDHparams. nil on error.
+ */
+static Bytes*
+tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, 
+	Bytes *p, Bytes *g, Bytes *Ys)
+{
+	mpint *G, *P, *Y, *K;
+	Bytes *epm;
+	DHstate dh;
+
+	if(p == nil || g == nil || Ys == nil)
+		return nil;
+
+	memmove(sec->srandom, srandom, RandomSize);
+	if(setVers(sec, vers) < 0)
+		return nil;
+
+	epm = nil;
+	P = bytestomp(p);
+	G = bytestomp(g);
+	Y = bytestomp(Ys);
+	K = nil;
+
+	if(P == nil || G == nil || Y == nil || dh_new(&dh, P, G) == nil)
+		goto Out;
+	epm = mptobytes(dh.y);
+	K = dh_finish(&dh, Y);
+	if(K == nil){
+		freebytes(epm);
+		epm = nil;
+		goto Out;
+	}
+	setMasterSecret(sec, mptobytes(K));
+
+Out:
+	mpfree(K);
+	mpfree(Y);
+	mpfree(G);
+	mpfree(P);
+
+	return epm;
+}
+
+/*
+ * check the server's RSA signature over its DH parameters
+ * (SSL3/TLS 1.0/1.1, RFC 4346 section 7.4.3): a PKCS#1 block
+ * type 1 holding the md5 and sha1 hashes of client_random +
+ * server_random + ServerDHParams, made with the key of the
+ * server certificate in c->cert.
+ * returns 0 if the signature is valid, -1 with errstr set otherwise.
+ */
+static int
+verifyDHparams(TlsConnection *c, Bytes *p, Bytes *g, Bytes *Ys, Bytes *sig)
+{
+	uchar hashes[MD5dlen+SHA1dlen];
+	DigestState *s;
+	RSApub *pub;
+	Bytes *par, *em, *b[3];
+	mpint *x, *y;
+	int i, k, n, nh, modlen, rv;
+
+	if(p == nil || g == nil || Ys == nil || sig == nil){
+		werrstr("missing DH parameters or signature");
+		return -1;
+	}
+
+	/* rebuild ServerDHParams as sent: length-prefixed p, g and Ys */
+	b[0] = p;
+	b[1] = g;
+	b[2] = Ys;
+	n = 0;
+	for(i = 0; i < nelem(b); i++)
+		n += 2 + b[i]->len;
+	par = newbytes(n);
+	n = 0;
+	for(i = 0; i < nelem(b); i++){
+		put16(par->data+n, b[i]->len);
+		memmove(par->data+n+2, b[i]->data, b[i]->len);
+		n += 2 + b[i]->len;
+	}
+	s = md5(c->sec->crandom, RandomSize, nil, nil);
+	s = md5(c->srandom, RandomSize, nil, s);
+	md5(par->data, par->len, hashes, s);
+	s = sha1(c->sec->crandom, RandomSize, nil, nil);
+	s = sha1(c->srandom, RandomSize, nil, s);
+	sha1(par->data, par->len, hashes+MD5dlen, s);
+	freebytes(par);
+
+	pub = X509toRSApub(c->cert->data, c->cert->len, nil, 0);
+	if(pub == nil){
+		werrstr("invalid x509/rsa certificate");
+		return -1;
+	}
+	modlen = (mpsignif(pub->n)+7)/8;
+	em = nil;
+	if(sig->len == modlen){
+		x = bytestomp(sig);
+		if(x != nil && mpcmp(x, pub->n) < 0){
+			y = rsaencrypt(pub, x, nil);
+			em = mptobytes(y);
+			mpfree(y);
+		}
+		mpfree(x);
+	}
+	rsapubfree(pub);
+
+	rv = -1;
+	nh = sizeof(hashes);
+	if(em != nil){
+		/* skip leading zeros; the rest must be 01 ff..ff 00 hashes */
+		for(k = 0; k < em->len && em->data[k] == 0; k++)
+			;
+		n = em->len - k;
+		if(n == modlen-1 && n >= 1+8+1+nh
+		&& em->data[k] == 0x01 && em->data[em->len-nh-1] == 0x00
+		&& memcmp(em->data+em->len-nh, hashes, nh) == 0){
+			rv = 0;
+			for(i = k+1; i < em->len-nh-1; i++)
+				if(em->data[i] != 0xff)
+					rv = -1;
+		}
+		freebytes(em);
+	}
+	if(rv < 0)
+		werrstr("bad signature");
+	return rv;
+}
+
 static TlsConnection *
 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, int (*trace)(char*fmt, ...))
 {
 	TlsConnection *c;
 	Msg m;
-	uchar kd[MaxKeyData], *epm;
+	uchar kd[MaxKeyData];
 	char *secrets;
-	int creq, nepm, rv;
-	mpint *signedMP, *paddedHashes; 
+	int creq, dhx, rv, cipher;
+	mpint *signedMP, *paddedHashes;
+	Bytes *epm;
 
 	if(!initCiphers())
 		return nil;
@@ -683,7 +851,8 @@
 		tlsError(c, EIllegalParameter, "invalid server session identifier");
 		goto Err;
 	}
-	if(!setAlgs(c, m.u.serverHello.cipher)) {
+	cipher = m.u.serverHello.cipher;
+	if(!setAlgs(c, cipher)) {
 		tlsError(c, EIllegalParameter, "invalid cipher suite");
 		goto Err;
 	}
@@ -705,14 +874,37 @@
 	c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
 	msgClear(&m);
 
-	/* server key exchange (optional) */
+	/* server key exchange */
+	dhx = isDHE(cipher);
 	if(!msgRecv(c, &m))
 		goto Err;
 	if(m.tag == HServerKeyExchange) {
-		tlsError(c, EUnexpectedMessage, "got an server key exchange");
+		if(!dhx){
+			tlsError(c, EUnexpectedMessage, "got an server key exchange");
+			goto Err;
+		}
+		/* authenticate p, g and Ys with the server certificate first */
+		if(verifyDHparams(c, m.u.serverKeyExchange.dh_p,
+				m.u.serverKeyExchange.dh_g,
+				m.u.serverKeyExchange.dh_Ys,
+				m.u.serverKeyExchange.dh_signature) < 0){
+			tlsError(c, EHandshakeFailure, "bad server key exchange: %r");
+			goto Err;
+		}
+		epm = tlsSecDHEc(c->sec, c->srandom, c->version,
+			m.u.serverKeyExchange.dh_p,
+			m.u.serverKeyExchange.dh_g,
+			m.u.serverKeyExchange.dh_Ys);
+		if(epm == nil){
+			tlsError(c, EHandshakeFailure, "DHE key exchange failed: %r");
+			goto Err;
+		}
+		msgClear(&m);
+		if(!msgRecv(c, &m))
+			goto Err;
+	} else if(dhx){
+		tlsError(c, EUnexpectedMessage, "expected server key exchange");
 		goto Err;
-		// If implementing this later, watch out for rollback attack
-		// described in Wagner Schneier 1996, section 4.4.
 	}
 
 	/* certificate request (optional) */
@@ -730,12 +922,16 @@
 	}
 	msgClear(&m);
 
-	if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom,
-			c->cert->data, c->cert->len, c->version, &epm, &nepm,
-			kd, c->nsecret) < 0){
+	if(!dhx)
+		epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom,
+			c->cert->data, c->cert->len, c->version);
+
+	if(epm == nil){	/* only the RSA path can get here with nil */
 		tlsError(c, EBadCertificate, "bad certificate: %r");
 		goto Err;
 	}
+
+	setSecrets(c->sec, kd, c->nsecret);
 	secrets = (char*)emalloc(2*c->nsecret);
 	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
 	rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
@@ -761,8 +852,7 @@
 
 	/* client key exchange */
 	m.tag = HClientKeyExchange;
-	m.u.clientKeyExchange.key = makebytes(epm, nepm);
-	free(epm);
+	m.u.clientKeyExchange.key = epm;	/* m takes ownership; msgClear frees it */
 	epm = nil;
 	if(m.u.clientKeyExchange.key == nil) {
 		tlsError(c, EHandshakeFailure, "can't set secret: %r");
@@ -832,9 +922,7 @@
 	}
 	m.tag = HFinished;
 	m.u.finished = c->finished;
-
 	if(!msgSend(c, &m, AFlush)) {
-		fprint(2, "tlsClient nepm=%d\n", nepm);
 		tlsError(c, EInternalError, "can't flush after client Finished: %r");
 		goto Err;
 	}
@@ -841,17 +929,14 @@
 	msgClear(&m);
 
 	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
-		fprint(2, "tlsClient nepm=%d\n", nepm);
 		tlsError(c, EInternalError, "can't set finished 0: %r");
 		goto Err;
 	}
 	if(!msgRecv(c, &m)) {
-		fprint(2, "tlsClient nepm=%d\n", nepm);
 		tlsError(c, EInternalError, "can't read server Finished: %r");
 		goto Err;
 	}
 	if(m.tag != HFinished) {
-		fprint(2, "tlsClient nepm=%d\n", nepm);
 		tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
 		goto Err;
 	}
@@ -1112,7 +1197,6 @@
 		m->u.clientHello.compressors->data[0] = CompressionNull;
 		goto Ok;
 	}
-
 	md5(p, 4, 0, &c->hsmd5);
 	sha1(p, 4, 0, &c->hssha1);
 
@@ -1259,6 +1343,43 @@
 		break;
 	case HServerHelloDone:
 		break;
+	case HServerKeyExchange:	/* DH p, g and Ys, each with a 16-bit length, then an optional signature */
+		if(n < 2)
+			goto Short;
+		nn = get16(p);
+		p += 2, n -= 2;
+		if(nn < 1 || nn > n)
+			goto Short;
+		m->u.serverKeyExchange.dh_p = makebytes(p, nn);
+		p += nn, n -= nn;
+
+		if(n < 2)
+			goto Short;
+		nn = get16(p);
+		p += 2, n -= 2;
+		if(nn < 1 || nn > n)
+			goto Short;
+		m->u.serverKeyExchange.dh_g = makebytes(p, nn);
+		p += nn, n -= nn;
+
+		if(n < 2)
+			goto Short;
+		nn = get16(p);
+		p += 2, n -= 2;
+		if(nn < 1 || nn > n)
+			goto Short;
+		m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
+		p += nn, n -= nn;
+
+		if(n >= 2){	/* the signature is kept only when its length fits */
+			nn = get16(p);
+			p += 2, n -= 2;
+			if(nn > 0 && nn <= n){
+				m->u.serverKeyExchange.dh_signature = makebytes(p, nn);
+				n -= nn;
+			}
+		}
+		break;		
 	case HClientKeyExchange:
 		/*
 		 * this message depends upon the encryption selected
@@ -1298,7 +1419,7 @@
 	}
 	return 1;
 Short:
-	tlsError(c, EDecodeError, "handshake message has invalid length");
+	tlsError(c, EDecodeError, "handshake message (%d) has invalid length", type);
 Err:
 	msgClear(m);
 	return 0;
@@ -1338,6 +1459,12 @@
 		break;
 	case HServerHelloDone:
 		break;
+	case HServerKeyExchange:	/* fields may be nil; freebytes is also applied to the optional signature */
+		freebytes(m->u.serverKeyExchange.dh_p);
+		freebytes(m->u.serverKeyExchange.dh_g);
+		freebytes(m->u.serverKeyExchange.dh_Ys);
+		freebytes(m->u.serverKeyExchange.dh_signature);
+		break;
 	case HClientKeyExchange:
 		freebytes(m->u.clientKeyExchange.key);
 		break;
@@ -1354,12 +1481,13 @@
 
 	if(s0)
 		bs = seprint(bs, be, "%s", s0);
-	bs = seprint(bs, be, "[");
 	if(b == nil)
-		bs = seprint(bs, be, "nil");
+		bs = seprint(bs, be, "[nil");	/* opening bracket for the "]" printed below */
-	else
+	else {
+		bs = seprint(bs, be, "<%d> [", b->len);
 		for(i=0; i<b->len; i++)
 			bs = seprint(bs, be, "%.2x ", b->data[i]);
+	}
 	bs = seprint(bs, be, "]");
 	if(s1)
 		bs = seprint(bs, be, "%s", s1);
@@ -1436,6 +1564,13 @@
 	case HServerHelloDone:
 		bs = seprint(bs, be, "ServerHelloDone\n");
 		break;
+	case HServerKeyExchange:
+		bs = seprint(bs, be, "HServerKeyExchange\n");
+		bs = bytesPrint(bs, be, "\tdh_p: ", m->u.serverKeyExchange.dh_p, "\n");
+		bs = bytesPrint(bs, be, "\tdh_g: ", m->u.serverKeyExchange.dh_g, "\n");
+		bs = bytesPrint(bs, be, "\tdh_Ys: ", m->u.serverKeyExchange.dh_Ys, "\n");
+		bs = bytesPrint(bs, be, "\tdh_signature: ", m->u.serverKeyExchange.dh_signature, "\n");
+		break;
 	case HClientKeyExchange:
 		bs = seprint(bs, be, "HClientKeyExchange\n");
 		bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
@@ -1574,6 +1709,8 @@
 			weak = 0;
 		else
 			weak &= weakCipher[c];
+		if(isDHE(c))
+			continue;	/* TODO: dhe not implemented for server */
 		for(j = 0; j < nelem(cipherAlgs); j++)
 			if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
 				return c;
@@ -1862,17 +1999,16 @@
 }
 
 static int
-tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd)
+tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm)	/* sets the master secret from epm (or only checks vers if epm is nil); kd is left to setSecrets */
 {
 	if(epm != nil){
 		if(setVers(sec, vers) < 0)
 			goto Err;
-		serverMasterSecret(sec, epm, nepm);
+		serverMasterSecret(sec, epm);
 	}else if(sec->vers != vers){
 		werrstr("mismatched session versions");
 		goto Err;
 	}
-	setSecrets(sec, kd, nkd);
 	return 0;
 Err:
 	sec->ok = -1;
@@ -1890,37 +2026,30 @@
 	return sec;
 }
 
-static int
-tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd)
+static Bytes*
+tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers)	/* returns the encrypted premaster secret; nil (and sec->ok = -1) on error */
 {
 	RSApub *pub;
+	Bytes *epm;
 
-	pub = nil;
-
 	USED(sid);
 	USED(nsid);
 	
 	memmove(sec->srandom, srandom, RandomSize);
-
 	if(setVers(sec, vers) < 0)
 		goto Err;
-
 	pub = X509toRSApub(cert, ncert, nil, 0);
 	if(pub == nil){
 		werrstr("invalid x509/rsa certificate");
 		goto Err;
 	}
-	if(clientMasterSecret(sec, pub, epm, nepm) < 0)
-		goto Err;
+	epm = clientMasterSecret(sec, pub);
 	rsapubfree(pub);
-	setSecrets(sec, kd, nkd);
-	return 0;
-
+	if(epm != nil)
+		return epm;
 Err:
-	if(pub != nil)
-		rsapubfree(pub);
 	sec->ok = -1;
-	return -1;
+	return nil;
 }
 
 static int
@@ -1997,21 +2126,25 @@
 }
 
 /*
- * set the master secret from the pre-master secret.
+ * set the master secret from the pre-master secret,
+ * destroys premaster.
  */
 static void
 setMasterSecret(TlsSec *sec, Bytes *pm)
 {
-	(*sec->prf)(sec->sec, MasterSecretSize, pm->data, MasterSecretSize, "master secret",
+	(*sec->prf)(sec->sec, MasterSecretSize, pm->data, pm->len, "master secret",
 			sec->crandom, RandomSize, sec->srandom, RandomSize);
+
+	memset(pm->data, 0, pm->len);	
+	freebytes(pm);
 }
 
 static void
-serverMasterSecret(TlsSec *sec, uchar *epm, int nepm)
+serverMasterSecret(TlsSec *sec, Bytes *epm)
 {
 	Bytes *pm;
 
-	pm = pkcs1_decrypt(sec, epm, nepm);
+	pm = pkcs1_decrypt(sec, epm);
 
 	// if the client messed up, just continue as if everything is ok,
 	// to prevent attacks to check for correctly formatted messages.
@@ -2018,7 +2151,7 @@
	// Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
-	if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
+	if(sec->ok < 0 || pm == nil || pm->len != MasterSecretSize || get16(pm->data) != sec->clientVers){
 		fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
-			sec->ok, pm, pm ? get16(pm->data) : -1, sec->clientVers, nepm);
+			sec->ok, pm, pm != nil && pm->len >= 2 ? get16(pm->data) : -1, sec->clientVers, epm->len);
 		sec->ok = -1;
 		if(pm != nil)
 			freebytes(pm);
@@ -2025,42 +2158,21 @@
		pm = newbytes(MasterSecretSize);
 		genrandom(pm->data, MasterSecretSize);
 	}
+	assert(pm->len == MasterSecretSize);	/* always true: a wrong-length pm is replaced above */
 	setMasterSecret(sec, pm);
-	memset(pm->data, 0, pm->len);	
-	freebytes(pm);
 }
 
-static int
-clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
+static Bytes*
+clientMasterSecret(TlsSec *sec, RSApub *pub)
 {
-	Bytes *pm, *key;
+	Bytes *pm, *epm;
 
 	pm = newbytes(MasterSecretSize);
 	put16(pm->data, sec->clientVers);
 	genrandom(pm->data+2, MasterSecretSize - 2);
-
+	epm = pkcs1_encrypt(pm, pub, 2);	/* nil on failure; encrypt before setMasterSecret frees pm */
 	setMasterSecret(sec, pm);
-
-	key = pkcs1_encrypt(pm, pub, 2);
-	memset(pm->data, 0, pm->len);
-	freebytes(pm);
-	if(key == nil){
-		werrstr("tls pkcs1_encrypt failed");
-		return -1;
-	}
-
-	*nepm = key->len;
-	*epm = malloc(*nepm);
-	if(*epm == nil){
-		freebytes(key);
-		werrstr("out of memory");
-		return -1;
-	}
-	memmove(*epm, key->data, *nepm);
-
-	freebytes(key);
-
-	return 1;
+	return epm;
 }
 
 static void
@@ -2236,7 +2348,7 @@
 // decrypt data according to PKCS#1, with given key.
 // expect a block type of 2.
 static Bytes*
-pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm)
+pkcs1_decrypt(TlsSec *sec, Bytes *cipher)	/* cipher must be exactly as long as the RSA modulus */
 {
 	Bytes *eb, *ans = nil;
 	int i, modlen;
@@ -2243,9 +2355,9 @@
 	mpint *x, *y;
 
 	modlen = (mpsignif(sec->rsapub->n)+7)/8;
-	if(nepm != modlen)
+	if(cipher->len != modlen)
 		return nil;
-	x = betomp(epm, nepm, nil);
+	x = bytestomp(cipher);
 	y = factotum_rsa_decrypt(sec->rpc, x);
 	if(y == nil)
 		return nil;