shithub: riscv

Download patch

ref: 7d7650dffc20c3853eb1acd551c9a02f73751ae3
parent: 873850c33d77cdd3ed20db94af670ec631ea4aa1
author: cinap_lenrek <cinap_lenrek@felloff.net>
date: Mon Aug 17 17:16:58 EDT 2015

libsec: TLS1.2 client support

--- a/sys/src/libsec/port/tlshand.c
+++ b/sys/src/libsec/port/tlshand.c
@@ -63,6 +63,17 @@
 	int n;
 } Finished;
 
+/*
+ * Running digests of the handshake messages, fed by msgHash().
+ * md5 and sha1 serve SSL3 and TLS1.0/1.1 (Finished and the client's
+ * CertificateVerify); sha2_256 serves only the TLS1.2 Finished.
+ */
+typedef struct HandshakeHash {
+	MD5state	md5;
+	SHAstate	sha1;
+	SHA2_256state	sha2_256;
+} HandshakeHash;
+
 typedef struct TlsConnection{
 	TlsSec *sec;	// security management goo
 	int hand, ctl;	// record layer file descriptors
@@ -95,8 +101,7 @@
 	int nsecret;	// amount of secret data to init keys
 
 	// for finished messages
-	MD5state	hsmd5;	// handshake hash
-	SHAstate	hssha1;	// handshake hash
+	HandshakeHash	handhash;
 	Finished	finished;
 } TlsConnection;
 
@@ -157,7 +162,7 @@
 	int vers;			// final version
 	// byte generation and handshake checksum
 	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
-	void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
+	void (*setFinished)(TlsSec*, HandshakeHash, uchar*, int);
 	int nfin;
 } TlsSec;
 
@@ -166,7 +171,8 @@
 	SSL3Version	= 0x0300,
 	TLS10Version	= 0x0301,
 	TLS11Version	= 0x0302,
-	ProtocolVersion	= TLS11Version,	// maximum version we speak
+	TLS12Version	= 0x0303,
+	ProtocolVersion	= TLS11Version,	// maximum version we speak (server)
 	MinProtoVersion	= 0x0300,	// limits on version we accept
 	MaxProtoVersion	= 0x03ff,
 };
@@ -331,7 +337,7 @@
 static Bytes*	tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers);
 static Bytes*	tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys);
 static Bytes*	tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys);
-static int	tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
+static int	tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient);
 static void	tlsSecOk(TlsSec *sec);
 static void	tlsSecKill(TlsSec *sec);
 static void	tlsSecClose(TlsSec *sec);
@@ -341,8 +347,9 @@
 static Bytes*	clientMasterSecret(TlsSec *sec, RSApub *pub);
 static Bytes*	pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
 static Bytes*	pkcs1_decrypt(TlsSec *sec, Bytes *cipher);
-static void	tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
-static void	sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
+static void	tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
+static void	tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
+static void	sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
 static void	sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
 			uchar *seed0, int nseed0, uchar *seed1, int nseed1);
 static int setVers(TlsSec *sec, int version);
@@ -693,7 +700,7 @@
 	msgClear(&m);
 
 	/* no CertificateVerify; skip to Finished */
-	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
+	if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
 		tlsError(c, EInternalError, "can't set finished: %r");
 		goto Err;
 	}
@@ -715,7 +722,7 @@
 		goto Err;
 	}
 
-	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
+	if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
 		tlsError(c, EInternalError, "can't set finished: %r");
 		goto Err;
 	}
@@ -961,7 +968,12 @@
 		return nil;
 	epm = nil;
 	c = emalloc(sizeof(TlsConnection));
-	c->version = ProtocolVersion;
+	c->version = TLS12Version;
+
+	// client certificate signature not implemented for TLS1.2
+	if(cert != nil && certlen > 0)
+		c->version = TLS11Version;
+
 	c->ctl = ctl;
 	c->hand = hand;
 	c->trace = trace;
@@ -1114,26 +1126,17 @@
 		goto Err;
 	msgClear(&m);
 
-	/* CertificateVerify */
-	/*XXX I should only send this when it is not DH right? 
-		Also we need to know which TLS key 
-		we have to use in case there are more than one*/
-	if(cert){
-		m.tag = HCertificateVerify;
+	/* certificate verify */
+	if(creq && cert != nil && certlen > 0) {
 		uchar hshashes[MD5dlen+SHA1dlen]; /* content of signature */
-		MD5state	hsmd5_save;
-		SHAstate	hssha1_save;
-	
+		HandshakeHash hsave;
+
 		/* save the state for the Finish message */
+		hsave = c->handhash;
+		md5(nil, 0, hshashes, &c->handhash.md5);
+		sha1(nil, 0, hshashes+MD5dlen, &c->handhash.sha1);
+		c->handhash = hsave;
 
-		hsmd5_save = c->hsmd5;
-		hssha1_save = c->hssha1;
-		md5(nil, 0, hshashes, &c->hsmd5);
-		sha1(nil, 0, hshashes+MD5dlen, &c->hssha1);
-	
-		c->hsmd5 = hsmd5_save;
-		c->hssha1 = hssha1_save;
-
 		c->sec->rpc = factotum_rsa_open(cert, certlen);
 		if(c->sec->rpc == nil){
 			tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
@@ -1154,6 +1157,7 @@
 		m.u.certificateVerify.signature = mptobytes(signedMP);
 		mpfree(signedMP);
 
+		m.tag = HCertificateVerify;
 		if(!msgSend(c, &m, AFlush))
 			goto Err;
 		msgClear(&m);
@@ -1167,7 +1171,7 @@
 
 	// Cipherchange must occur immediately before Finished to avoid
 	// potential hole;  see section 4.3 of Wagner Schneier 1996.
-	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
+	if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
 		tlsError(c, EInternalError, "can't set finished 1: %r");
 		goto Err;
 	}
@@ -1179,7 +1183,7 @@
 	}
 	msgClear(&m);
 
-	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
+	if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
 		tlsError(c, EInternalError, "can't set finished 0: %r");
 		goto Err;
 	}
@@ -1216,6 +1220,20 @@
 
 //================= message functions ========================
 
+/*
+ * Add n bytes of handshake message data to the running handshake hashes.
+ * MD5 and SHA1 are always updated; SHA-256 is only updated while
+ * c->version >= TLS12Version, since only tls12SetFinished reads it.
+ */
+static void
+msgHash(TlsConnection *c, uchar *p, int n)
+{
+	md5(p, n, nil, &c->handhash.md5);
+	sha1(p, n, nil, &c->handhash.sha1);
+	if(c->version >= TLS12Version)
+		sha2_256(p, n, nil, &c->handhash.sha2_256);
+}
+
 static int
 msgSend(TlsConnection *c, Msg *m, int act)
 {
@@ -1352,10 +1365,8 @@
 	put24(c->sendp+1, n-4);
 
 	// remember hash of Handshake messages
-	if(m->tag != HHelloRequest) {
-		md5(c->sendp, n, 0, &c->hsmd5);
-		sha1(c->sendp, n, 0, &c->hssha1);
-	}
+	if(m->tag != HHelloRequest)
+		msgHash(c, c->sendp, n);
 
 	c->sendp = p;
 	if(act == AFlush){
@@ -1430,8 +1441,7 @@
 		p = tlsReadN(c, n);
 		if(p == nil)
 			return 0;
-		md5(p, n, 0, &c->hsmd5);
-		sha1(p, n, 0, &c->hssha1);
+		msgHash(c, p, n);
 		m->tag = HClientHello;
 		if(n < 22)
 			goto Short;
@@ -1468,15 +1478,13 @@
 		m->u.clientHello.compressors->data[0] = CompressionNull;
 		goto Ok;
 	}
-	md5(p, 4, 0, &c->hsmd5);
-	sha1(p, 4, 0, &c->hssha1);
+	msgHash(c, p, 4);
 
 	p = tlsReadN(c, n);
 	if(p == nil)
 		return 0;
 
-	md5(p, n, 0, &c->hsmd5);
-	sha1(p, n, 0, &c->hssha1);
+	msgHash(c, p, n);
 
 	m->tag = type;
 
@@ -1678,6 +1686,12 @@
 			break;
 		}
 		if(n >= 2){
+			if(c->version >= TLS12Version){
+				/* signature hash algorithm */
+				p += 2, n -= 2;
+				if(n < 2)
+					goto Short;
+			}
 			nn = get16(p);
 			p += 2, n -= 2;
 			if(nn > 0 && nn <= n){
@@ -2265,20 +2279,70 @@
 	}
 }
 
+/*
+ * TLS1.2 P_SHA256 (RFC 5246, section 5): fill buf[0..nbuf) with
+ * HMAC_SHA256(key, A(i) + label + seed) for i = 1, 2, ..., where
+ * A(0) = label + seed and A(i) = HMAC_SHA256(key, A(i-1)).
+ */
+static void
+p_sha256(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed, int nseed)
+{
+	uchar ai[SHA2_256dlen], tmp[SHA2_256dlen];
+	DigestState *s;
+	int n;
+
+	// A(1) = HMAC(key, label + seed)
+	s = hmac_sha2_256(label, nlabel, key, nkey, nil, nil);
+	hmac_sha2_256(seed, nseed, key, nkey, ai, s);
+
+	while(nbuf > 0) {
+		// output block i: HMAC(key, A(i) + label + seed)
+		s = hmac_sha2_256(ai, SHA2_256dlen, key, nkey, nil, nil);
+		s = hmac_sha2_256(label, nlabel, key, nkey, nil, s);
+		hmac_sha2_256(seed, nseed, key, nkey, tmp, s);
+		n = SHA2_256dlen;
+		if(n > nbuf)
+			n = nbuf;
+		memmove(buf, tmp, n);
+		buf += n;
+		nbuf -= n;
+		// A(i+1) = HMAC(key, A(i))
+		hmac_sha2_256(ai, SHA2_256dlen, key, nkey, tmp, nil);
+		memmove(ai, tmp, SHA2_256dlen);
+	}
+
+	// don't leave secret-derived PRF state behind on the stack
+	memset(ai, 0, sizeof(ai));
+	memset(tmp, 0, sizeof(tmp));
+}
+
 // fill buf with md5(args)^sha1(args)
 static void
-tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
+tls10PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
 {
-	int i;
 	int nlabel = strlen(label);
 	int n = (nkey + 1) >> 1;
 
-	for(i = 0; i < nbuf; i++)
-		buf[i] = 0;
+	memset(buf, 0, nbuf);
 	tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
 	tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
 }
 
+/*
+ * TLS1.2 PRF: P_SHA256(key, label + seed0 + seed1).  The two seed
+ * pieces are joined first because p_sha256 takes a single seed.
+ */
+static void
+tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
+{
+	uchar seed[2*RandomSize];
+
+	assert(nseed0+nseed1 <= sizeof(seed));
+	memmove(seed, seed0, nseed0);
+	memmove(seed+nseed0, seed1, nseed1);
+	p_sha256(buf, nbuf, key, nkey, (uchar*)label, strlen(label), seed, nseed0+nseed1);
+}
+
 /*
  * for setting server session id's
  */
@@ -2369,7 +2418,7 @@
 }
 
 static int
-tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
+tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient)
 {
 	if(sec->nfin != nfin){
 		sec->ok = -1;
@@ -2376,9 +2425,10 @@
 		werrstr("invalid finished exchange");
 		return -1;
 	}
-	md5.malloced = 0;
-	sha1.malloced = 0;
-	(*sec->setFinished)(sec, md5, sha1, fin, isclient);
+	hsh.md5.malloced = 0;
+	hsh.sha1.malloced = 0;
+	hsh.sha2_256.malloced = 0;
+	(*sec->setFinished)(sec, hsh, fin, isclient);
 	return 1;
 }
 
@@ -2415,10 +2465,14 @@
 		sec->setFinished = sslSetFinished;
 		sec->nfin = SSL3FinishedLen;
 		sec->prf = sslPRF;
-	}else{
-		sec->setFinished = tlsSetFinished;
+	}else if(v < TLS12Version) {
+		sec->setFinished = tls10SetFinished;
 		sec->nfin = TLSFinishedLen;
-		sec->prf = tlsPRF;
+		sec->prf = tls10PRF;
+	}else {
+		sec->setFinished = tls12SetFinished;
+		sec->nfin = TLSFinishedLen;
+		sec->prf = tls12PRF;
 	}
 	sec->vers = v;
 	return 0;
@@ -2488,7 +2542,7 @@
 }
 
 static void
-sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
+sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
 {
 	DigestState *s;
 	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
@@ -2499,21 +2553,21 @@
 	else
 		label = "SRVR";
 
-	md5((uchar*)label, 4, nil, &hsmd5);
-	md5(sec->sec, MasterSecretSize, nil, &hsmd5);
+	md5((uchar*)label, 4, nil, &hsh.md5);
+	md5(sec->sec, MasterSecretSize, nil, &hsh.md5);
 	memset(pad, 0x36, 48);
-	md5(pad, 48, nil, &hsmd5);
-	md5(nil, 0, h0, &hsmd5);
+	md5(pad, 48, nil, &hsh.md5);
+	md5(nil, 0, h0, &hsh.md5);
 	memset(pad, 0x5C, 48);
 	s = md5(sec->sec, MasterSecretSize, nil, nil);
 	s = md5(pad, 48, nil, s);
 	md5(h0, MD5dlen, finished, s);
 
-	sha1((uchar*)label, 4, nil, &hssha1);
-	sha1(sec->sec, MasterSecretSize, nil, &hssha1);
+	sha1((uchar*)label, 4, nil, &hsh.sha1);
+	sha1(sec->sec, MasterSecretSize, nil, &hsh.sha1);
 	memset(pad, 0x36, 40);
-	sha1(pad, 40, nil, &hssha1);
-	sha1(nil, 0, h1, &hssha1);
+	sha1(pad, 40, nil, &hsh.sha1);
+	sha1(nil, 0, h1, &hsh.sha1);
 	memset(pad, 0x5C, 40);
 	s = sha1(sec->sec, MasterSecretSize, nil, nil);
 	s = sha1(pad, 40, nil, s);
@@ -2522,22 +2576,43 @@
 
 // fill "finished" arg with md5(args)^sha1(args)
 static void
-tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
+tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
 {
 	uchar h0[MD5dlen], h1[SHA1dlen];
 	char *label;
 
 	// get current hash value, but allow further messages to be hashed in
-	md5(nil, 0, h0, &hsmd5);
-	sha1(nil, 0, h1, &hssha1);
+	md5(nil, 0, h0, &hsh.md5);
+	sha1(nil, 0, h1, &hsh.sha1);
 
 	if(isClient)
 		label = "client finished";
 	else
 		label = "server finished";
-	tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
+	tls10PRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
 }
 
+/*
+ * TLS1.2 Finished: fill finished with TLSFinishedLen bytes of
+ * P_SHA256(master secret, label + SHA-256 of the handshake so far).
+ * hsh is a by-value copy, so the caller's running hash keeps going.
+ */
+static void
+tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
+{
+	uchar seed[SHA2_256dlen];
+	char *label;
+
+	// get current hash value, but allow further messages to be hashed in
+	sha2_256(nil, 0, seed, &hsh.sha2_256);
+
+	if(isClient)
+		label = "client finished";
+	else
+		label = "server finished";
+	p_sha256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), seed, SHA2_256dlen);
+}
+
 static void
 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
 {