shithub: riscv

Download patch

ref: 40360a992d03ccccf69a36fa20359ad029b3afcf
parent: a1bbf39c341e0b7ae0c999b0a34c85ab157aa6c9
author: cinap_lenrek <cinap_lenrek@felloff.net>
date: Wed May 20 22:26:57 EDT 2015

libsec: implement tlsClient support for RFC6066 server name indication (SNI)

tlsClient() can now optionally send the server_name extension in the
ClientHello message when TLSconn.serverName is set. Some https sites
that serve several host names from one address require this.

--- a/sys/include/libsec.h
+++ b/sys/include/libsec.h
@@ -383,6 +383,7 @@
 	uchar	*sessionKey;
 	int	sessionKeylen;
 	char	*sessionConst;
+	char	*serverName;
 } TLSconn;
 
 /* tlshand.c */
--- a/sys/man/2/pushtls
+++ b/sys/man/2/pushtls
@@ -107,6 +107,7 @@
 	uchar *sessionKey;	/* opt IN/OUT session key */
 	int	sessionKeylen;	/* opt IN  session key length */
 	char	*sessionConst;	/* opt IN  session constant */
+	char	*serverName;	/* opt IN  server name */
 } TLSconn;
 .EE
 .PP
--- a/sys/src/libsec/port/tlshand.c
+++ b/sys/src/libsec/port/tlshand.c
@@ -96,13 +96,15 @@
 			Bytes*	sid;
 			Ints*	ciphers;
 			Bytes*	compressors;
+			Bytes*	extensions;
 		} clientHello;
 		struct {
 			int version;
-			uchar 	random[RandomSize];
+			uchar	random[RandomSize];
 			Bytes*	sid;
-			int cipher;
-			int compressor;
+			int	cipher;
+			int	compressor;
+			Bytes*	extensions;
 		} serverHello;
 		struct {
 			int ncert;
@@ -266,8 +268,8 @@
 	CompressionNull,
 };
 
-static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain);
-static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen,  int (*trace)(char*fmt, ...));
+static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chain);
+static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, int (*trace)(char*fmt, ...));
 static void	msgClear(Msg *m);
 static char* msgPrint(char *buf, int n, Msg *m);
 static int	msgRecv(TlsConnection *c, Msg *m);
@@ -390,6 +392,33 @@
 	return data;
 }
 
+static uchar*
+tlsClientExtensions(TLSconn *conn, int *plen)
+{
+	uchar *b, *p;	/* b: buffer handed back to the caller, which frees it */
+	int n, m;
+	/* Build the ClientHello extension list; *plen gets its length (nil/0 when empty). */
+	p = b = nil;
+
+	// RFC6066 - Server Name Indication; omitted rather than sent malformed
+	n = conn->serverName != nil ? strlen(conn->serverName) : 0;
+	if(n > 0 && n <= 255){	/* HostName<1..2^16-1>; RFC1035: names are <= 255 octets */
+		m = p - b;
+		b = erealloc(b, m+2+2+2+1+2+n);
+		p = b + m;
+		put16(p, 0), p += 2;		/* Type: server_name */
+		put16(p, 2+1+2+n), p += 2;	/* Length */
+		put16(p, 1+2+n), p += 2;	/* Server Name list length */
+		*p++ = 0;			/* Server Name Type: host_name */
+		put16(p, n), p += 2;		/* Server Name length */
+		memmove(p, conn->serverName, n);
+		p += n;
+	}
+
+	*plen = p - b;
+	return b;
+}
+
 //	push TLS onto fd, returning new (application) file descriptor
 //		or -1 if error.
 int
@@ -399,6 +428,7 @@
 	char dname[64];
 	int n, data, ctl, hand;
 	TlsConnection *tls;
+	uchar *ext;
 
 	if(conn == nil)
 		return -1;
@@ -426,7 +456,10 @@
 		return -1;
 	}
 	fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
-	tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, conn->trace);
+	ext = tlsClientExtensions(conn, &n);
+	tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, 
+		ext, n, conn->trace);
+	free(ext);
 	close(hand);
 	close(ctl);
 	if(tls == nil){
@@ -466,7 +499,7 @@
 }
 
 static TlsConnection *
-tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chp)
+tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chp)
 {
 	TlsConnection *c;
 	Msg m;
@@ -531,12 +564,12 @@
 		tlsError(c, EHandshakeFailure, "can't initialize security: %r");
 		goto Err;
 	}
-	c->sec->rpc = factotum_rsa_open(cert, ncert);
+	c->sec->rpc = factotum_rsa_open(cert, certlen);
 	if(c->sec->rpc == nil){
 		tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
 		goto Err;
 	}
-	c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0);
+	c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
 	if(c->sec->rsapub == nil){
 		tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
 		goto Err;
@@ -558,7 +591,7 @@
 	numcerts = countchain(chp);
 	m.u.certificate.ncert = 1 + numcerts;
 	m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
-	m.u.certificate.certs[0] = makebytes(cert, ncert);
+	m.u.certificate.certs[0] = makebytes(cert, certlen);
 	for (i = 0; i < numcerts && chp; i++, chp = chp->next)
 		m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
 	if(!msgSend(c, &m, AQueue))
@@ -702,7 +735,8 @@
 }
 
 static TlsConnection *
-tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, int (*trace)(char*fmt, ...))
+tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen,
+	int (*trace)(char*fmt, ...))
 {
 	TlsConnection *c;
 	Msg m;
@@ -735,6 +769,7 @@
 	m.u.clientHello.sid = makebytes(csid, ncsid);
 	m.u.clientHello.ciphers = makeciphers();
 	m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
+	m.u.clientHello.extensions = makebytes(ext, extlen);
 	if(!msgSend(c, &m, AFlush))
 		goto Err;
 	msgClear(&m);
@@ -1015,6 +1050,15 @@
 		p[0] = n;
 		memmove(p+1, m->u.clientHello.compressors->data, n);
 		p += n+1;
+
+		if(m->u.clientHello.extensions == nil)
+			break;
+		n = m->u.clientHello.extensions->len;
+		if(n == 0)
+			break;
+		put16(p, n);
+		memmove(p+2, m->u.clientHello.extensions->data, n);
+		p += n+2;
 		break;
 	case HServerHello:
 		put16(p, m->u.serverHello.version);
@@ -1035,6 +1079,15 @@
 		p += 2;
 		p[0] = m->u.serverHello.compressor;
 		p += 1;
+
+		if(m->u.serverHello.extensions == nil)
+			break;
+		n = m->u.serverHello.extensions->len;
+		if(n == 0)
+			break;
+		put16(p, n);
+		memmove(p+2, m->u.serverHello.extensions->data, n);
+		p += n+2;
 		break;
 	case HServerHelloDone:
 		break;
@@ -1249,9 +1302,17 @@
 		if(n < 1 || n < p[0]+1 || p[0] == 0)
 			goto Short;
 		nn = p[0];
-		m->u.clientHello.compressors = newbytes(nn);
-		memmove(m->u.clientHello.compressors->data, p+1, nn);
+		m->u.clientHello.compressors = makebytes(p+1, nn);
+		p += nn + 1;
 		n -= nn + 1;
+
+		if(n < 2)
+			break;
+		nn = get16(p);
+		if(nn > n-2)
+			goto Short;
+		m->u.clientHello.extensions = makebytes(p+2, nn);
+		n -= nn + 2;
 		break;
 	case HServerHello:
 		if(n < 2)
@@ -1276,7 +1337,16 @@
 			goto Short;
 		m->u.serverHello.cipher = get16(p);
 		m->u.serverHello.compressor = p[2];
+		p += 3;
 		n -= 3;
+
+		if(n < 2)
+			break;
+		nn = get16(p);
+		if(nn > n-2)
+			goto Short;
+		m->u.serverHello.extensions = makebytes(p+2, nn);
+		n -= nn + 2;
 		break;
 	case HCertificate:
 		if(n < 3)
@@ -1409,7 +1479,7 @@
 		break;
 	}
 
-	if(type != HClientHello && n != 0)
+	if(type != HClientHello && type != HServerHello && n != 0)
 		goto Short;
 Ok:
 	if(c->trace){
@@ -1440,9 +1510,11 @@
 		freebytes(m->u.clientHello.sid);
 		freeints(m->u.clientHello.ciphers);
 		freebytes(m->u.clientHello.compressors);
+		freebytes(m->u.clientHello.extensions);
 		break;
 	case HServerHello:
-		freebytes(m->u.clientHello.sid);
+		freebytes(m->u.serverHello.sid);
+		freebytes(m->u.serverHello.extensions);
 		break;
 	case HCertificate:
 		for(i=0; i<m->u.certificate.ncert; i++)
@@ -1534,6 +1606,8 @@
 		bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
 		bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
 		bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
+		if(m->u.clientHello.extensions != nil)
+			bs = bytesPrint(bs, be, "\textensions: ", m->u.clientHello.extensions, "\n");
 		break;
 	case HServerHello:
 		bs = seprint(bs, be, "ServerHello\n");
@@ -1545,6 +1619,8 @@
 		bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
 		bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
 		bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
+		if(m->u.serverHello.extensions != nil)
+			bs = bytesPrint(bs, be, "\textensions: ", m->u.serverHello.extensions, "\n");
 		break;
 	case HCertificate:
 		bs = seprint(bs, be, "Certificate\n");
--