shithub: riscv

Download patch

ref: 9d3bc1646915d7e7a69dcf614690fdcec6e6aa7a
parent: 079d3f4002526164e0d430cca83383af8f213a84
author: cinap_lenrek <cinap_lenrek@felloff.net>
date: Sun Jun 7 18:14:01 EDT 2015

libsec/tlshand: implement client side ECDHE (many thanks to pr!)

--- a/sys/src/libsec/port/tlshand.c
+++ b/sys/src/libsec/port/tlshand.c
@@ -46,6 +46,18 @@
 	int ok;
 } Algs;
 
+typedef struct Namedcurve{	/* curve domain parameters, kept as hex strings */
+	int tlsid;	/* NamedCurve id, sent in the elliptic_curves extension */
+	char *name;	/* e.g. "secp256r1" */
+
+	char *p;	/* loaded into ECdomain.p with strtomp(.., 16, ..) */
+	char *a;	/* -> ECdomain.a */
+	char *b;	/* -> ECdomain.b */
+	char *G;	/* base point, parsed by strtoec */
+	char *n;	/* -> ECdomain.n */
+	char *h;	/* -> ECdomain.h */
+} Namedcurve;
+
 typedef struct Finished{
 	uchar verify[SSL3FinishedLen];
 	int n;
@@ -77,6 +89,7 @@
 	uchar crandom[RandomSize];	// client random
 	uchar srandom[RandomSize];	// server random
 	int clientVersion;	// version in ClientHello
+	int cipher;	// tlsid of the selected cipherAlgs entry
 	char *digest;	// name of digest algorithm to use
 	char *enc;		// name of encryption algorithm to use
 	int nsecret;	// amount of secret data to init keys
@@ -123,6 +136,7 @@
 			Bytes *dh_g;
 			Bytes *dh_Ys;
 			Bytes *dh_signature;
+			int curve;	/* NamedCurve id; only set when parsing an ECDHE key exchange */
 		} serverKeyExchange;
 		struct {
 			Bytes *signature;
@@ -244,6 +258,11 @@
 	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0X0038,
 	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0X0039,
 	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0X003A,
+	/* ECDHE suites (RFC 4492) */
+	TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA	= 0xC013,
+	TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA	= 0xC014,
+	TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA  = 0xC009,
+	TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A,	/* NOTE(review): CipherMax now follows 0xC00A; any table sized by CipherMax (weakCipher is indexed by suite id) grows to ~49k entries - confirm this is acceptable */
 	CipherMax
 };
 
@@ -262,6 +281,11 @@
 	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA},
 	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_128_CBC_SHA},
 	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_256_CBC_SHA},
+	/* ECDHE suites; the server-side selection still skips them (client side only) */
+	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
+	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},
+	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA},
+	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
 };
 
 static uchar compressors[] = {
@@ -268,6 +292,20 @@
 	CompressionNull,
 };
 
+static Namedcurve namedcurves[] = {	/* curves offered in the elliptic_curves extension; tlsSecECDHEc accepts only these */
+{0x0017, "secp256r1",	/* tlsid, name, then p, a, b, G, n, h in hex */
+	"FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF",
+	"FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC",
+	"5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B",
+	"046B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C2964FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5",
+	"FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551",
+	"1"}
+};
+
+static uchar pointformats[] = {
+	CompressionNull /* stands in for ECPointFormat uncompressed(0), assuming CompressionNull == 0; its support is mandatory */
+};
+
 static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chain);
 static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, int (*trace)(char*fmt, ...));
 static void	msgClear(Msg *m);
@@ -291,6 +329,7 @@
 static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
 static Bytes*	tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers);
 static Bytes*	tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys);
+static Bytes*	tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys);	/* curve: NamedCurve id; Ys: server EC point */
 static int	tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
 static void	tlsSecOk(TlsSec *sec);
 static void	tlsSecKill(TlsSec *sec);
@@ -395,7 +434,7 @@
 tlsClientExtensions(TLSconn *conn, int *plen)
 {
 	uchar *b, *p;
-	int n, m;
+	int i, n, m;
 
 	p = b = nil;
 
@@ -402,9 +441,11 @@
 	// RFC6066 - Server Name Identification
 	if(conn->serverName != nil){
 		n = strlen(conn->serverName);
+
 		m = p - b;
-		b = erealloc(b, m+2+2+2+1+2+n);
+		b = erealloc(b, m + 2+2+2+1+2+n);
 		p = b + m;
+
 		put16(p, 0), p += 2;		/* Type: server_name */
 		put16(p, 2+1+2+n), p += 2;	/* Length */
 		put16(p, 1+2+n), p += 2;	/* Server Name list length */
@@ -414,6 +455,29 @@
 		p += n;
 	}
 
+	// RFC4492 - ECDHE: supported elliptic curves and point formats
+	if(1){	/* always sent */
+		m = p - b;
+		b = erealloc(b, m + 2+2+2+nelem(namedcurves)*2 + 2+2+1+nelem(pointformats));	/* elliptic_curves ext, then ec_point_formats ext */
+		p = b + m;
+
+		n = nelem(namedcurves);
+		put16(p, 0x000a), p += 2;	/* Type: elliptic_curves */
+		put16(p, (n+1)*2), p += 2;	/* Length: 2-byte list length + n 2-byte ids */
+		put16(p, n*2), p += 2;		/* Elliptic Curves Length */
+		for(i=0; i < n; i++){		/* Elliptic curves */
+			put16(p, namedcurves[i].tlsid);
+			p += 2;
+		}
+
+		n = nelem(pointformats);
+		put16(p, 0x000b), p += 2;	/* Type: ec_point_formats */
+		put16(p, n+1), p += 2;		/* Length: 1-byte list length + n formats */
+		*p++ = n;			/* EC point formats Length */
+		for(i=0; i < n; i++)		/* Elliptic curves point formats */
+			*p++ = pointformats[i];
+	}
+	
 	*plen = p - b;
 	return b;
 }
@@ -692,6 +756,19 @@
 	return 0;
 }
 
+static int
+isECDHE(int tlsid)
+{
+	switch(tlsid){	/* 1 for the ECDHE cipher suites, 0 otherwise */
+	case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
+	case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
+	case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA:
+	case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
+		return 1;
+	}
+	return 0;	/* any other suite, DHE ones included */
+}
+
 static Bytes*
 tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, 
 	Bytes *p, Bytes *g, Bytes *Ys)
@@ -733,6 +810,186 @@
 	return epm;
 }
 
+/*
+ * Decode an ECPoint octet string received from the peer into ret,
+ * a point on dom, via the hex form that strtoec parses.  Returns
+ * the result of strtoec, which is nil if the encoding is rejected.
+ */
+static ECpoint*
+bytestoec(ECdomain *dom, Bytes *bp, ECpoint *ret)
+{
+	char *hex = "0123456789ABCDEF";
+	char *s;
+	int i;
+
+	s = emalloc(2*bp->len + 1);
+	for(i=0; i < bp->len; i++){
+		s[2*i] = hex[bp->data[i]>>4 & 15];
+		s[2*i+1] = hex[bp->data[i] & 15];
+	}
+	s[2*bp->len] = '\0';
+	ret = strtoec(dom, s, nil, ret);
+	free(s);
+	return ret;
+}
+
+/*
+ * Big-endian encoding of big, left-padded with zeros to len bytes.
+ * RFC 4492 encodes field elements (point coordinates, the shared
+ * secret) at full field size; mptobytes drops leading zero bytes.
+ */
+static Bytes*
+mptofixedbytes(mpint *big, int len)
+{
+	Bytes *b, *r;
+
+	b = mptobytes(big);
+	if(b->len >= len)
+		return b;
+	r = newbytes(len);
+	memset(r->data, 0, len - b->len);
+	memmove(r->data + len - b->len, b->data, b->len);
+	freebytes(b);
+	return r;
+}
+
+/*
+ * Encode p as an ECPoint octet string: the format byte type
+ * followed by the x and y coordinates, each padded to len bytes.
+ */
+static Bytes*
+ectobytes(int type, ECpoint *p, int len)
+{
+	Bytes *bx, *by, *bp;
+
+	bx = mptofixedbytes(p->x, len);
+	by = mptofixedbytes(p->y, len);
+	bp = newbytes(1 + bx->len + by->len);
+	bp->data[0] = type;
+	memmove(bp->data+1, bx->data, bx->len);
+	memmove(bp->data+1+bx->len, by->data, by->len);
+	freebytes(bx);
+	freebytes(by);
+	return bp;
+}
+
+/*
+ * Client half of an ECDHE key exchange.  curve is the NamedCurve
+ * id and Ys the server's public point from ServerKeyExchange.
+ * Generates an ephemeral key pair on that curve, sets the master
+ * secret from the x coordinate of the shared point and returns
+ * our public point, to be sent in ClientKeyExchange, or nil on
+ * failure (unsupported curve, invalid point, setVers error, ...).
+ */
+static Bytes*
+tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys)
+{
+	Namedcurve *nc, *enc;
+	Bytes *epm;
+	ECdomain dom;
+	ECpoint G, K, Y;
+	ECpriv Q;
+	int flen;
+
+	if(Ys == nil)
+		return nil;
+
+	enc = &namedcurves[nelem(namedcurves)];
+	for(nc = namedcurves; nc != enc; nc++)
+		if(nc->tlsid == curve)
+			break;
+
+	if(nc == enc)
+		return nil;
+
+	memmove(sec->srandom, srandom, RandomSize);
+	if(setVers(sec, vers) < 0)
+		return nil;
+
+	epm = nil;
+
+	memset(&dom, 0, sizeof(dom));
+	dom.p = strtomp(nc->p, nil, 16, nil);
+	dom.a = strtomp(nc->a, nil, 16, nil);
+	dom.b = strtomp(nc->b, nil, 16, nil);
+	dom.n = strtomp(nc->n, nil, 16, nil);
+	dom.h = strtomp(nc->h, nil, 16, nil);
+
+	memset(&G, 0, sizeof(G));
+	G.x = mpnew(0);
+	G.y = mpnew(0);
+
+	memset(&Q, 0, sizeof(Q));
+	Q.x = mpnew(0);
+	Q.y = mpnew(0);
+	Q.d = mpnew(0);
+
+	memset(&K, 0, sizeof(K));
+	K.x = mpnew(0);
+	K.y = mpnew(0);
+
+	memset(&Y, 0, sizeof(Y));
+	Y.x = mpnew(0);
+	Y.y = mpnew(0);
+
+	if(dom.p == nil || dom.a == nil || dom.b == nil || dom.n == nil || dom.h == nil)
+		goto Out;
+	if(Q.x == nil || Q.y == nil || Q.d == nil)
+		goto Out;
+	if(G.x == nil || G.y == nil)
+		goto Out;
+	if(K.x == nil || K.y == nil)
+		goto Out;
+	if(Y.x == nil || Y.y == nil)
+		goto Out;
+
+	/* bytes per field element (32 for secp256r1) */
+	flen = (mpsignif(dom.p)+7)/8;
+
+	dom.G = strtoec(&dom, nc->G, nil, &G);
+	if(dom.G == nil)
+		goto Out;
+
+	/* refuse server points that are not proper points of the group */
+	if(bytestoec(&dom, Ys, &Y) == nil || !ecpubverify(&dom, &Y))
+		goto Out;
+
+	if(ecgen(&dom, &Q) == nil)
+		goto Out;
+
+	ecmul(&dom, &Y, Q.d, &K);
+	if(K.inf)
+		goto Out;
+
+	/* premaster secret is K.x with leading zeros kept (RFC 4492 5.10) */
+	setMasterSecret(sec, mptofixedbytes(K.x, flen));
+
+	/* 0x04 = uncompressed public key */
+	epm = ectobytes(0x04, &Q, flen);
+
+Out:
+	mpfree(Y.x);
+	mpfree(Y.y);
+
+	mpfree(K.x);
+	mpfree(K.y);
+
+	mpfree(Q.x);
+	mpfree(Q.y);
+	mpfree(Q.d);
+
+	mpfree(G.x);
+	mpfree(G.y);
+
+	mpfree(dom.p);
+	mpfree(dom.a);
+	mpfree(dom.b);
+	mpfree(dom.n);
+	mpfree(dom.h);
+
+	return epm;
+}
+
 static TlsConnection *
 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen,
 	int (*trace)(char*fmt, ...))
@@ -814,7 +1025,7 @@
 	msgClear(&m);
 
 	/* server key exchange */
-	dhx = isDHE(cipher);
+	dhx = isDHE(cipher) || isECDHE(cipher);	/* (EC)DHE suites carry a ServerKeyExchange */
 	if(!msgRecv(c, &m))
 		goto Err;
 	if(m.tag == HServerKeyExchange) {
@@ -822,10 +1033,15 @@
 			tlsError(c, EUnexpectedMessage, "got an server key exchange");
 			goto Err;
 		}
-		epm = tlsSecDHEc(c->sec, c->srandom, c->version,
-			m.u.serverKeyExchange.dh_p, 
-			m.u.serverKeyExchange.dh_g,
-			m.u.serverKeyExchange.dh_Ys);
+		if(isECDHE(cipher))	/* named curve id + server EC point */
+			epm = tlsSecECDHEc(c->sec, c->srandom, c->version,
+				m.u.serverKeyExchange.curve,
+				m.u.serverKeyExchange.dh_Ys);
+		else	/* finite-field DHE: p, g, Ys */
+			epm = tlsSecDHEc(c->sec, c->srandom, c->version,
+				m.u.serverKeyExchange.dh_p, 
+				m.u.serverKeyExchange.dh_g,
+				m.u.serverKeyExchange.dh_Ys);	/* NOTE(review): dh_signature is not checked here; confirm the server signature is verified elsewhere */
 		if(epm == nil)
 			goto Badcert;
 		msgClear(&m);
@@ -1116,8 +1332,10 @@
 	case HClientKeyExchange:
 		n = m->u.clientKeyExchange.key->len;
 		if(c->version != SSL3Version){
-			put16(p, n);
-			p += 2;
+			if(isECDHE(c->cipher))	/* ECPoint: 1-byte length prefix */
+				*p++ = n;	/* assumes n <= 255 */
+			else	/* other key exchanges: 2-byte length prefix */
+				put16(p, n), p += 2;
 		}
 		memmove(p, m->u.clientKeyExchange.key->data, n);
 		p += n;
@@ -1416,31 +1634,49 @@
 	case HServerKeyExchange:
 		if(n < 2)
 			goto Short;
-		nn = get16(p);
-		p += 2, n -= 2;
-		if(nn < 1 || nn > n)
-			goto Short;
-		m->u.serverKeyExchange.dh_p = makebytes(p, nn);
-		p += nn, n -= nn;
+		if(isECDHE(c->cipher)){	/* curve_type, NamedCurve id, ECPoint */
+			nn = *p;
+			p++, n--;
+			if(nn != 3 || nn > n) /* must be named_curve (3); with nn == 3, nn > n also ensures the 2-byte curve id and the point length byte are present */
+				goto Short;
+			nn = get16(p);
+			p += 2, n -= 2;
+			m->u.serverKeyExchange.curve = nn;	/* NamedCurve id */
 
-		if(n < 2)
-			goto Short;
-		nn = get16(p);
-		p += 2, n -= 2;
-		if(nn < 1 || nn > n)
-			goto Short;
-		m->u.serverKeyExchange.dh_g = makebytes(p, nn);
-		p += nn, n -= nn;
-
-		if(n < 2)
-			goto Short;
-		nn = get16(p);
-		p += 2, n -= 2;
-		if(nn < 1 || nn > n)
-			goto Short;
-		m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
-		p += nn, n -= nn;
-
+			nn = *p++, n--;	/* ECPoint length (1 byte) */
+			if(nn < 1 || nn > n)
+				goto Short;
+			m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);	/* server EC point, stored in dh_Ys */
+			p += nn, n -= nn;
+		}else if(isDHE(c->cipher)){
+			nn = get16(p);
+			p += 2, n -= 2;
+			if(nn < 1 || nn > n)
+				goto Short;
+			m->u.serverKeyExchange.dh_p = makebytes(p, nn);
+			p += nn, n -= nn;
+	
+			if(n < 2)
+				goto Short;
+			nn = get16(p);
+			p += 2, n -= 2;
+			if(nn < 1 || nn > n)
+				goto Short;
+			m->u.serverKeyExchange.dh_g = makebytes(p, nn);
+			p += nn, n -= nn;
+	
+			if(n < 2)
+				goto Short;
+			nn = get16(p);
+			p += 2, n -= 2;
+			if(nn < 1 || nn > n)
+				goto Short;
+			m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
+			p += nn, n -= nn;
+		} else {
+			/* should not happen: neither ECDHE nor DHE; nothing is parsed, presumably the caller rejects the message */
+			break;
+		}
 		if(n >= 2){
 			nn = get16(p);
 			p += 2, n -= 2;
@@ -1642,8 +1878,12 @@
 		break;
 	case HServerKeyExchange:
 		bs = seprint(bs, be, "HServerKeyExchange\n");
-		bs = bytesPrint(bs, be, "\tdh_p: ", m->u.serverKeyExchange.dh_p, "\n");
-		bs = bytesPrint(bs, be, "\tdh_g: ", m->u.serverKeyExchange.dh_g, "\n");
+		if(m->u.serverKeyExchange.curve != 0){	/* curve is only set by the ECDHE parser; assumes the Msg starts zeroed */
+			bs = seprint(bs, be, "\tcurve: %.4x\n", m->u.serverKeyExchange.curve);
+		} else {
+			bs = bytesPrint(bs, be, "\tdh_p: ", m->u.serverKeyExchange.dh_p, "\n");
+			bs = bytesPrint(bs, be, "\tdh_g: ", m->u.serverKeyExchange.dh_g, "\n");
+		}
 		bs = bytesPrint(bs, be, "\tdh_Ys: ", m->u.serverKeyExchange.dh_Ys, "\n");
 		bs = bytesPrint(bs, be, "\tdh_signature: ", m->u.serverKeyExchange.dh_signature, "\n");
 		break;
@@ -1762,6 +2002,7 @@
 
 	for(i = 0; i < nelem(cipherAlgs); i++){
 		if(cipherAlgs[i].tlsid == a){
+			c->cipher = a;	/* later tested with isECDHE(c->cipher) when parsing/packing key exchange messages */
 			c->enc = cipherAlgs[i].enc;
 			c->digest = cipherAlgs[i].digest;
 			c->nsecret = cipherAlgs[i].nsecret;
@@ -1785,8 +2026,8 @@
 			weak = 0;
 		else
 			weak &= weakCipher[c];
-		if(isDHE(c))
-			continue;	/* TODO: dhe not implemented for server */
+		if(isDHE(c) || isECDHE(c))
+			continue;	/* TODO: (EC)DHE key exchange not implemented for server */
 		for(j = 0; j < nelem(cipherAlgs); j++)
 			if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
 				return c;