shithub: riscv

ref: 168077885110cb40635ffe6afe5ecab2e4000a87
dir: /sys/src/libsec/port/tlshand.c/

View raw version
#include <u.h>
#include <libc.h>
#include <auth.h>
#include <mp.h>
#include <libsec.h>

// The main groups of functions are:
//		client/server - main handshake protocol definition
//		message functions - formatting handshake messages
//		cipher choices - catalog of digest and encrypt algorithms
//		security functions - PKCS#1, sslHMAC, session keygen
//		general utility functions - malloc, serialization
// The handshake protocol builds on the TLS/SSL3 record layer protocol,
// which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.

// Sizes and msgSend action codes shared by the handshake code below.
enum {
	// NOTE(review): presumably the verify length of a TLS (non-SSL3) Finished message — confirm
	TLSFinishedLen = 12,
	// MD5 digest plus SHA1 digest; also sizes Finished.verify
	SSL3FinishedLen = MD5dlen+SHA1dlen,
	MaxKeyData = 160,	// amount of secret we may need
	// size of the digest buffers used when signing and verifying
	MAXdlen = SHA2_512dlen,
	// bytes in the client and server random values (TlsSec.crandom/srandom)
	RandomSize = 32,
	// bytes in the master secret (TlsSec.sec)
	MasterSecretSize = 48,
	// values for the act argument of msgSend
	// NOTE(review): presumably AQueue only buffers and AFlush writes out what is queued — confirm in msgSend
	AQueue = 0,
	AFlush = 1,
};

// Counted byte string: len bytes stored inline in the flexible array data.
typedef struct Bytes{
	int len;
	uchar data[];
} Bytes;

// Counted array of ints: len entries stored inline in data
// (used for lists such as the ClientHello cipher suites).
typedef struct Ints{
	int len;
	int data[];
} Ints;

// One row of the cipherAlgs table: the algorithms behind a cipher suite.
typedef struct Algs{
	char *enc;	// name of the encryption algorithm
	char *digest;	// name of the digest algorithm
	int nsecret;	// amount of secret data needed to init keys
	int tlsid;	// cipher suite id (one of the TLS_* / GOOGLE_* constants)
	int ok;		// NOTE(review): not initialized in the table; presumably set by initCiphers — confirm
} Algs;

// An elliptic curve we can negotiate.
typedef struct Namedcurve{
	int tlsid;	// curve id as sent in the elliptic_curves extension
	// fills in the domain parameters; handed to ecdominit
	void (*init)(mpint *p, mpint *a, mpint *b, mpint *x, mpint *y, mpint *n, mpint *h);
} Namedcurve;

// Verify data of a Finished message.
typedef struct Finished{
	uchar verify[SSL3FinishedLen];	// sized for the larger SSL3 form
	int n;	// length passed to tlsSecFinished as nfin
} Finished;

// Running digests over the handshake messages, fed by msgHash.
typedef struct HandshakeHash {
	MD5state	md5;
	SHAstate	sha1;
	SHA2_256state	sha2_256;	// only updated when the version is >= TLS12Version
} HandshakeHash;

// Key material and key-exchange state for one connection.
typedef struct TlsSec TlsSec;
struct TlsSec {
	RSApub *rsapub;	// public key taken from our own certificate (X509toRSApub)
	AuthRpc *rpc;	// factotum for rsa private key
	uchar *psk;	// pre-shared key
	int psklen;
	int clientVers;			// version in ClientHello
	uchar sec[MasterSecretSize];	// master secret
	uchar srandom[RandomSize];	// server random
	uchar crandom[RandomSize];	// client random

	Namedcurve *nc; // selected curve for ECDHE
	// diffie hellman state
	DHstate dh;
	struct {
		ECdomain dom;
		ECpriv Q;
	} ec;
	uchar X[32];	// state handed to curve25519_dh_new/curve25519_dh_finish

	// byte generation and handshake checksum
	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int);
	void (*setFinished)(TlsSec*, HandshakeHash, uchar*, int);
	int nfin;	// NOTE(review): presumably the Finished verify length for the negotiated version — confirm
};

// Everything the handshake needs to know about one connection.
typedef struct TlsConnection{
	TlsSec sec[1];	// security management goo
	int hand, ctl;	// record layer file descriptors
	int erred;		// set when tlsError called
	int (*trace)(char*fmt, ...); // for debugging
	int version;	// protocol we are speaking
	Bytes *cert;	// server certificate; only last - no chain

	int cipher;
	int nsecret;	// amount of secret data to init keys
	char *digest;	// name of digest algorithm to use
	char *enc;	// name of encryption algorithm to use

	// for finished messages
	HandshakeHash	handhash;
	Finished	finished;

	uchar *sendp;	// start of the next outgoing message within buf
	uchar buf[1<<16];	// buffer in which msgSend encodes messages
} TlsConnection;

// In-memory form of one handshake message.
// tag holds a handshake type (H* constant) and selects which member
// of u is meaningful.
typedef struct Msg{
	int tag;
	union {
		struct {
			int	version;
			uchar 	random[RandomSize];
			Bytes*	sid;
			Ints*	ciphers;
			Bytes*	compressors;
			Bytes*	extensions;
		} clientHello;
		struct {
			int	version;
			uchar	random[RandomSize];
			Bytes*	sid;
			int	cipher;
			int	compressor;
			Bytes*	extensions;
		} serverHello;
		struct {
			int ncert;
			Bytes **certs;
		} certificate;
		struct {
			Bytes *types;
			Ints *sigalgs;
			int nca;
			Bytes **cas;
		} certificateRequest;
		struct {
			Bytes *pskid;
			Bytes *key;
		} clientKeyExchange;
		struct {
			Bytes *pskid;
			Bytes *dh_p;
			Bytes *dh_g;
			Bytes *dh_Ys;
			Bytes *dh_parameters;
			Bytes *dh_signature;
			int sigalg;
			int curve;
		} serverKeyExchange;
		struct {
			int sigalg;
			Bytes *signature;
		} certificateVerify;		
		Finished finished;
	} u;
} Msg;


// protocol version numbers, as carried in the hello messages
enum {
	SSL3Version	= 0x0300,
	TLS10Version	= 0x0301,
	TLS11Version	= 0x0302,
	TLS12Version	= 0x0303,
	ProtocolVersion	= TLS12Version,	// maximum version we speak
	MinProtoVersion	= 0x0300,	// limits on version we accept
	MaxProtoVersion	= 0x03ff,
};

// handshake type: values for Msg.tag, written by msgSend as the
// first byte of each encoded message
enum {
	HHelloRequest,
	HClientHello,
	HServerHello,
	HSSL2ClientHello = 9,  /* local convention;  see devtls.c */
	HCertificate = 11,
	HServerKeyExchange,
	HCertificateRequest,
	HServerHelloDone,
	HCertificateVerify,
	HClientKeyExchange,
	HFinished = 20,
	HMax
};

// alerts: the err codes handed to tlsError
enum {
	ECloseNotify = 0,
	EUnexpectedMessage = 10,
	EBadRecordMac = 20,
	EDecryptionFailed = 21,
	ERecordOverflow = 22,
	EDecompressionFailure = 30,
	EHandshakeFailure = 40,
	ENoCertificate = 41,
	EBadCertificate = 42,
	EUnsupportedCertificate = 43,
	ECertificateRevoked = 44,
	ECertificateExpired = 45,
	ECertificateUnknown = 46,
	EIllegalParameter = 47,
	EUnknownCa = 48,
	EAccessDenied = 49,
	EDecodeError = 50,
	EDecryptError = 51,
	EExportRestriction = 60,
	EProtocolVersion = 70,
	EInsufficientSecurity = 71,
	EInternalError = 80,
	EInappropriateFallback = 86,
	EUserCanceled = 90,
	ENoRenegotiation = 100,
	EUnknownPSKidentity = 115,
	EMax = 256
};

// cipher suites: 16-bit ids exchanged in the hello messages
// (see cipherAlgs for the algorithms behind each one)
enum {
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0X000A,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0X0016,

	TLS_RSA_WITH_AES_128_CBC_SHA		= 0X002F,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0X0033,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0X0035,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0X0039,
	TLS_RSA_WITH_AES_128_CBC_SHA256		= 0X003C,
	TLS_RSA_WITH_AES_256_CBC_SHA256		= 0X003D,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA256	= 0X0067,

	TLS_RSA_WITH_AES_128_GCM_SHA256		= 0x009C,
	TLS_DHE_RSA_WITH_AES_128_GCM_SHA256	= 0x009E,

	TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA	= 0xC013,
	TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA	= 0xC014,
	TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256	= 0xC023,
	TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256	= 0xC027,

	TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xC02B,
	TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256	= 0xC02F,

	// paired with "ccpoly64_aead" in cipherAlgs
	// NOTE(review): presumably the pre-standard ChaCha20-Poly1305 code points — confirm
	GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305		= 0xCC13,
	GOOGLE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305	= 0xCC14,
	GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305		= 0xCC15,

	// paired with "ccpoly96_aead" in cipherAlgs
	TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305	= 0xCCA8,
	TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305	= 0xCCA9,
	TLS_DHE_RSA_WITH_CHACHA20_POLY1305	= 0xCCAA,

	TLS_PSK_WITH_CHACHA20_POLY1305		= 0xCCAB,
	TLS_PSK_WITH_AES_128_CBC_SHA256		= 0x00AE,
	TLS_PSK_WITH_AES_128_CBC_SHA		= 0x008C,

	// not a real suite: tlsServer2 rejects a ClientHello that lists it
	// while negotiating a version below ProtocolVersion
	TLS_FALLBACK_SCSV = 0x5600,
};

// compression methods; tlsClient2 accepts only CompressionNull
// from the server
enum {
	CompressionNull = 0,
	CompressionMax
};


// curves: ids with dedicated handling; X25519 goes through
// curve25519_dh_new/curve25519_dh_finish rather than a Namedcurve init function
enum {
	X25519 = 0x001d,
};

// extensions: type codes of the hello extensions built by
// tlsClientExtensions and parsed by checkClientExtensions
enum {
	Extsni = 0x0000,	// server_name
	Extec = 0x000a,		// elliptic_curves / supported_groups
	Extecp = 0x000b,	// ec_point_formats
	Extsigalgs = 0x000d,	// signature algorithms
};

// Supported cipher suites, grouped by key exchange.
// Each row is {enc, digest, nsecret, tlsid}; the ok field is left zero here.
// NOTE(review): nsecret looks like 2*(key + IV/nonce + MAC key) bytes, one
// set per direction, and the row order presumably expresses preference —
// confirm against setSecrets and okCipher/makeciphers.
static Algs cipherAlgs[] = {
	// ECDHE-ECDSA
	{"ccpoly96_aead", "clear", 2*(32+12), TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305},
	{"ccpoly64_aead", "clear", 2*32, GOOGLE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305},
	{"aes_128_gcm_aead", "clear", 2*(16+4), TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
	{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256},

	// ECDHE-RSA
	{"ccpoly96_aead", "clear", 2*(32+12), TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305},
	{"ccpoly64_aead", "clear", 2*32, GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305},
	{"aes_128_gcm_aead", "clear", 2*(16+4), TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
	{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256},
	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},

	// DHE-RSA
	{"ccpoly96_aead", "clear", 2*(32+12), TLS_DHE_RSA_WITH_CHACHA20_POLY1305},
	{"ccpoly64_aead", "clear", 2*32, GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305},
	{"aes_128_gcm_aead", "clear", 2*(16+4), TLS_DHE_RSA_WITH_AES_128_GCM_SHA256},
	{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_DHE_RSA_WITH_AES_128_CBC_SHA256},
	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_128_CBC_SHA},
	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_256_CBC_SHA},
	{"3des_ede_cbc","sha1",	2*(4*8+SHA1dlen), TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA},

	// RSA
	{"aes_128_gcm_aead", "clear", 2*(16+4), TLS_RSA_WITH_AES_128_GCM_SHA256},
	{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_RSA_WITH_AES_128_CBC_SHA256},
	{"aes_256_cbc", "sha256", 2*(32+16+SHA2_256dlen), TLS_RSA_WITH_AES_256_CBC_SHA256},
	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
	{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA},
	{"3des_ede_cbc","sha1",	2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},

	// PSK
	{"ccpoly96_aead", "clear", 2*(32+12), TLS_PSK_WITH_CHACHA20_POLY1305},
	{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_PSK_WITH_AES_128_CBC_SHA256},
	{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_PSK_WITH_AES_128_CBC_SHA},
};

// compression methods offered in our ClientHello
static uchar compressors[] = {
	CompressionNull,
};

// Curves we offer as a client and accept as a server, as {tlsid, init} pairs.
// When matching a ClientHello, an earlier entry wins over a later one.
// X25519 has no init function: tlsSecECDHEc handles it with curve25519_dh_*
// before ever consulting this table.
static Namedcurve namedcurves[] = {
	X25519, nil,
	0x0017, secp256r1,
	0x0018, secp384r1,
};

// point formats listed in the ec_point_formats extension;
// CompressionNull (0) doubles as the code for the uncompressed format
static uchar pointformats[] = {
	CompressionNull /* support of uncompressed point format is mandatory */
};

// Digest function and digest length for each hash code.
// NOTE(review): presumably indexed by the high byte of a sigalg value
// (compare the comments in sigalgs) — confirm in digestDHparams.
static struct {
	DigestState* (*fun)(uchar*, ulong, uchar*, DigestState*);
	int len;
} hashfun[] = {
/*	[0x00]  is reserved for MD5+SHA1 for < TLS1.2 */
	[0x01]	{md5,		MD5dlen},
	[0x02]	{sha1,		SHA1dlen},
	[0x03]	{sha2_224,	SHA2_224dlen},
	[0x04]	{sha2_256,	SHA2_256dlen},
	[0x05]	{sha2_384,	SHA2_384dlen},
	[0x06]	{sha2_512,	SHA2_512dlen},
};

// signature algorithms (only RSA and ECDSA at the moment)
// sent, in this order, in the signature algorithms extension of
// our ClientHello; high byte is the hash, low byte the signature type
static int sigalgs[] = {
	0x0603,		/* SHA512 ECDSA */
	0x0503,		/* SHA384 ECDSA */
	0x0403,		/* SHA256 ECDSA */
	0x0203,		/* SHA1 ECDSA */

	0x0601,		/* SHA512 RSA */
	0x0501,		/* SHA384 RSA */
	0x0401,		/* SHA256 RSA */
	0x0201,		/* SHA1 RSA */
};

static TlsConnection *tlsServer2(int ctl, int hand,
	uchar *cert, int certlen,
	char *pskid, uchar *psk, int psklen,
	int (*trace)(char*fmt, ...), PEMChain *chain);
static TlsConnection *tlsClient2(int ctl, int hand,
	uchar *cert, int certlen,
	char *pskid, uchar *psk, int psklen,
	uchar *ext, int extlen, int (*trace)(char*fmt, ...));
static void	msgClear(Msg *m);
static char* msgPrint(char *buf, int n, Msg *m);
static int	msgRecv(TlsConnection *c, Msg *m);
static int	msgSend(TlsConnection *c, Msg *m, int act);
static void	tlsError(TlsConnection *c, int err, char *msg, ...);
#pragma	varargck argpos	tlsError 3
static int setVersion(TlsConnection *c, int version);
static int setSecrets(TlsConnection *c, int isclient);
static int finishedMatch(TlsConnection *c, Finished *f);
static void tlsConnectionFree(TlsConnection *c);

static int isDHE(int tlsid);
static int isECDHE(int tlsid);
static int isPSK(int tlsid);
static int isECDSA(int tlsid);

static int setAlgs(TlsConnection *c, int a);
static int okCipher(Ints *cv, int ispsk, int canec);
static int okCompression(Bytes *cv);
static int initCiphers(void);
static Ints* makeciphers(int ispsk);

static AuthRpc* factotum_rsa_open(RSApub *rsapub);
static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
static void factotum_rsa_close(AuthRpc *rpc);

static void	tlsSecInits(TlsSec *sec, int cvers, uchar *crandom);
static int	tlsSecRSAs(TlsSec *sec, Bytes *epm);
static Bytes*	tlsSecECDHEs1(TlsSec *sec);
static int	tlsSecECDHEs2(TlsSec *sec, Bytes *Yc);
static void	tlsSecInitc(TlsSec *sec, int cvers);
static Bytes*	tlsSecRSAc(TlsSec *sec, uchar *cert, int ncert);
static Bytes*	tlsSecDHEc(TlsSec *sec, Bytes *p, Bytes *g, Bytes *Ys);
static Bytes*	tlsSecECDHEc(TlsSec *sec, int curve, Bytes *Ys);
static void	tlsSecVers(TlsSec *sec, int v);
static int	tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient);
static void	setMasterSecret(TlsSec *sec, Bytes *pm);
static int	digestDHparams(TlsSec *sec, Bytes *par, uchar digest[MAXdlen], int sigalg);
static char*	verifyDHparams(TlsSec *sec, Bytes *par, Bytes *cert, Bytes *sig, int sigalg);

static Bytes*	pkcs1_encrypt(Bytes* data, RSApub* key);
static Bytes*	pkcs1_decrypt(TlsSec *sec, Bytes *data);
static Bytes*	pkcs1_sign(TlsSec *sec, uchar *digest, int digestlen, int sigalg);

static void* emalloc(int);
static void* erealloc(void*, int);
static void put32(uchar *p, u32int);
static void put24(uchar *p, int);
static void put16(uchar *p, int);
static int get24(uchar *p);
static int get16(uchar *p);
static Bytes* newbytes(int len);
static Bytes* makebytes(uchar* buf, int len);
static Bytes* mptobytes(mpint* big, int len);
static mpint* bytestomp(Bytes* bytes);
static void freebytes(Bytes* b);
static Ints* newints(int len);
static void freeints(Ints* b);
static int lookupid(Ints* b, int id);

//================= client/server ========================

/*
 * Push TLS onto fd as the server side of the connection.
 *
 * A new conversation is cloned from #a/tls and its directory name is
 * stored in conn->dir.  The handshake runs in tlsServer2 using the
 * certificate, chain and pre-shared key found in conn.
 *
 * On success the data file of the conversation is returned, fd is
 * closed, conn->cert is freed and the certificate and session id
 * fields of conn are cleared.  If conn->sessionKey is set and
 * conn->sessionType is "ttls", conn->sessionKey is filled from the
 * master secret via the negotiated prf.
 *
 * Returns -1 on any error; fd is left open in that case.
 */
int
tlsServer(int fd, TLSconn *conn)
{
	char id[8], path[64];
	uchar seed[2*RandomSize];
	int nr, ctlfd, handfd, datafd;
	TlsConnection *tc;

	if(conn == nil)
		return -1;

	/* reading the clone file yields the name of the new conversation */
	ctlfd = open("#a/tls/clone", ORDWR|OCEXEC);
	if(ctlfd < 0)
		return -1;
	nr = read(ctlfd, id, sizeof(id)-1);
	if(nr < 0){
		close(ctlfd);
		return -1;
	}
	id[nr] = 0;
	snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", id);

	snprint(path, sizeof(path), "#a/tls/%s/hand", id);
	handfd = open(path, ORDWR|OCEXEC);
	if(handfd < 0){
		close(ctlfd);
		return -1;
	}

	/* hand fd to the record layer, then run the handshake */
	fprint(ctlfd, "fd %d 0x%x", fd, ProtocolVersion);
	tc = tlsServer2(ctlfd, handfd,
		conn->cert, conn->certlen,
		conn->pskID, conn->psk, conn->psklen,
		conn->trace, conn->chain);

	/* the data file is only opened once the handshake has succeeded */
	datafd = -1;
	if(tc != nil){
		snprint(path, sizeof(path), "#a/tls/%s/data", id);
		datafd = open(path, ORDWR);
	}
	close(handfd);
	close(ctlfd);
	if(datafd < 0){
		tlsConnectionFree(tc);
		return -1;
	}

	free(conn->cert);
	conn->cert = nil;  // client certificates are not yet implemented
	conn->certlen = 0;
	conn->sessionIDlen = 0;
	conn->sessionID = nil;

	/* export session key material for ttls */
	if(conn->sessionKey != nil && conn->sessionType != nil){
		if(strcmp(conn->sessionType, "ttls") == 0){
			memmove(seed, tc->sec->crandom, RandomSize);
			memmove(seed+RandomSize, tc->sec->srandom, RandomSize);
			tc->sec->prf(
				conn->sessionKey, conn->sessionKeylen,
				tc->sec->sec, MasterSecretSize,
				conn->sessionConst,
				seed, sizeof(seed));
		}
	}

	tlsConnectionFree(tc);
	close(fd);
	return datafd;
}

/*
 * Build the extensions block for a ClientHello:
 *	server_name (only if conn->serverName is a non-empty string),
 *	elliptic_curves and ec_point_formats,
 *	signature algorithms.
 * The buffer is grown with erealloc one extension group at a time.
 * Returns the buffer, which the caller frees, and stores its length
 * in *plen.
 */
static uchar*
tlsClientExtensions(TLSconn *conn, int *plen)
{
	uchar *buf, *w;
	int k, cnt, used;

	buf = nil;
	used = 0;	/* bytes of buf filled in so far */

	/* RFC6066 - Server Name Identification */
	if(conn->serverName != nil){
		cnt = strlen(conn->serverName);
		if(cnt > 0){
			buf = erealloc(buf, used + 2+2+2+1+2+cnt);
			w = buf + used;

			put16(w, Extsni);	/* Type: server_name */
			w += 2;
			put16(w, 2+1+2+cnt);	/* Length */
			w += 2;
			put16(w, 1+2+cnt);	/* Server Name list length */
			w += 2;
			*w++ = 0;		/* Server Name Type: host_name */
			put16(w, cnt);		/* Server Name length */
			w += 2;
			memmove(w, conn->serverName, cnt);
			w += cnt;

			used = w - buf;
		}
	}

	/* Elliptic Curves (also called Supported Groups) and point formats */
	if(ProtocolVersion >= TLS10Version){
		buf = erealloc(buf, used + 2+2+2+nelem(namedcurves)*2 + 2+2+1+nelem(pointformats));
		w = buf + used;

		cnt = nelem(namedcurves);
		put16(w, Extec);	/* Type: elliptic_curves / supported_groups */
		w += 2;
		put16(w, (cnt+1)*2);	/* Length */
		w += 2;
		put16(w, cnt*2);	/* Elliptic Curves Length */
		w += 2;
		for(k = 0; k < cnt; k++, w += 2)
			put16(w, namedcurves[k].tlsid);

		cnt = nelem(pointformats);
		put16(w, Extecp);	/* Type: ec_point_formats */
		w += 2;
		put16(w, cnt+1);	/* Length */
		w += 2;
		*w++ = cnt;		/* EC point formats Length */
		for(k = 0; k < cnt; k++)
			*w++ = pointformats[k];

		used = w - buf;
	}

	/* signature algorithms */
	if(ProtocolVersion >= TLS12Version){
		cnt = nelem(sigalgs);
		buf = erealloc(buf, used + 2+2+2+cnt*2);
		w = buf + used;

		put16(w, Extsigalgs);
		w += 2;
		put16(w, cnt*2 + 2);
		w += 2;
		put16(w, cnt*2);
		w += 2;
		for(k = 0; k < cnt; k++, w += 2)
			put16(w, sigalgs[k]);

		used = w - buf;
	}

	*plen = used;
	return buf;
}

/*
 * Push TLS onto fd as the client side of the connection.
 *
 * A new conversation is cloned from #a/tls and its directory name is
 * stored in conn->dir.  The hand and data files are opened before the
 * handshake, which runs in tlsClient2 with the extensions built by
 * tlsClientExtensions.
 *
 * On success the data file of the conversation is returned, fd is
 * closed, and conn->cert/conn->certlen are replaced by a copy of the
 * server certificate (nil/0 if there was none).  If conn->sessionKey
 * is set and conn->sessionType is "ttls", conn->sessionKey is filled
 * from the master secret via the negotiated prf.
 *
 * Returns -1 on any error; fd is left open in that case.
 */
int
tlsClient(int fd, TLSconn *conn)
{
	char id[8], path[64];
	uchar seed[2*RandomSize];
	uchar *exts;
	int nr, nexts, ctlfd, handfd, datafd;
	TlsConnection *tc;

	if(conn == nil)
		return -1;

	/* reading the clone file yields the name of the new conversation */
	ctlfd = open("#a/tls/clone", ORDWR|OCEXEC);
	if(ctlfd < 0)
		return -1;
	nr = read(ctlfd, id, sizeof(id)-1);
	if(nr < 0){
		close(ctlfd);
		return -1;
	}
	id[nr] = 0;
	snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", id);

	snprint(path, sizeof(path), "#a/tls/%s/hand", id);
	handfd = open(path, ORDWR|OCEXEC);
	if(handfd < 0){
		close(ctlfd);
		return -1;
	}
	snprint(path, sizeof(path), "#a/tls/%s/data", id);
	datafd = open(path, ORDWR);
	if(datafd < 0){
		close(handfd);
		close(ctlfd);
		return -1;
	}

	/* hand fd to the record layer, then run the handshake */
	fprint(ctlfd, "fd %d 0x%x", fd, ProtocolVersion);
	exts = tlsClientExtensions(conn, &nexts);
	tc = tlsClient2(ctlfd, handfd,
		conn->cert, conn->certlen,
		conn->pskID, conn->psk, conn->psklen,
		exts, nexts, conn->trace);
	free(exts);
	close(handfd);
	close(ctlfd);
	if(tc == nil){
		close(datafd);
		return -1;
	}

	/* replace our certificate in conn with a copy of the server's */
	free(conn->cert);
	conn->cert = nil;
	conn->certlen = 0;
	if(tc->cert != nil){
		conn->certlen = tc->cert->len;
		conn->cert = emalloc(conn->certlen);
		memcpy(conn->cert, tc->cert->data, conn->certlen);
	}
	conn->sessionIDlen = 0;
	conn->sessionID = nil;

	/* export session key material for ttls */
	if(conn->sessionKey != nil && conn->sessionType != nil){
		if(strcmp(conn->sessionType, "ttls") == 0){
			memmove(seed, tc->sec->crandom, RandomSize);
			memmove(seed+RandomSize, tc->sec->srandom, RandomSize);
			tc->sec->prf(
				conn->sessionKey, conn->sessionKeylen,
				tc->sec->sec, MasterSecretSize,
				conn->sessionConst,
				seed, sizeof(seed));
		}
	}

	tlsConnectionFree(tc);
	close(fd);
	return datafd;
}

/* Return the number of entries on the PEMChain list p (0 for nil). */
static int
countchain(PEMChain *p)
{
	int n;

	for(n = 0; p != nil; p = p->next)
		n++;
	return n;
}

/*
 * Walk the extensions of a ClientHello.  Only the elliptic curves
 * extension is interpreted: unless a curve is already selected,
 * c->sec->nc is set to the first entry of namedcurves that the
 * client also lists.  All other extensions are skipped.
 * Returns 0 on success (including ext == nil); on a malformed
 * length, raises EDecodeError through tlsError and returns -1.
 */
static int
checkClientExtensions(TlsConnection *c, Bytes *ext)
{
	uchar *cur, *end;
	int type, len, ci, off;

	if(ext == nil)
		return 0;

	cur = ext->data;
	end = cur + ext->len;
	while(cur < end){
		/* each extension starts with a 2 byte type and a 2 byte length */
		if(end - cur < 4)
			goto Short;
		type = get16(cur);
		len = get16(cur+2);
		cur += 4;
		if(end - cur < len)
			goto Short;

		if(type == Extec){
			/* a 2 byte list length followed by at least one 2 byte curve id */
			if(len < 4 || len % 2)
				goto Short;
			len -= 2;
			if(get16(cur) != len)
				goto Short;
			cur += 2;

			/* our preference order decides, not the client's */
			for(ci = 0; ci < nelem(namedcurves) && c->sec->nc == nil; ci++){
				for(off = 0; off < len; off += 2){
					if(namedcurves[ci].tlsid == get16(cur+off)){
						c->sec->nc = &namedcurves[ci];
						break;
					}
				}
			}
		}
		cur += len;
	}
	return 0;

Short:
	tlsError(c, EDecodeError, "clienthello extensions has invalid length");
	return -1;
}

/*
 * Run the server side of the handshake over the record layer
 * descriptors ctl and hand.
 *
 *	cert, certlen	server certificate; when certlen > 0 its RSA public
 *			key is extracted and the matching private key is
 *			reached through factotum
 *	pskid		if non-nil, the identity the client must present
 *	psk, psklen	pre-shared key, used when psklen > 0
 *	trace		optional debugging print function
 *	chp		additional certificates sent after cert
 *
 * Returns the connection state on success.  On any failure the message
 * and connection are released and nil is returned.
 */
static TlsConnection *
tlsServer2(int ctl, int hand,
	uchar *cert, int certlen,
	char *pskid, uchar *psk, int psklen,
	int (*trace)(char*fmt, ...), PEMChain *chp)
{
	int cipher, compressor, numcerts, i;
	TlsConnection *c;
	Msg m;

	if(trace)
		trace("tlsServer2\n");
	if(!initCiphers())
		return nil;

	c = emalloc(sizeof(TlsConnection));
	c->ctl = ctl;
	c->hand = hand;
	c->trace = trace;
	c->version = ProtocolVersion;
	c->sendp = c->buf;

	/* client hello */
	memset(&m, 0, sizeof(m));
	if(!msgRecv(c, &m)){
		if(trace)
			trace("initial msgRecv failed\n");
		goto Err;
	}
	if(m.tag != HClientHello) {
		tlsError(c, EUnexpectedMessage, "expected a client hello");
		goto Err;
	}
	if(trace)
		trace("ClientHello version %x\n", m.u.clientHello.version);
	if(setVersion(c, m.u.clientHello.version) < 0) {
		tlsError(c, EIllegalParameter, "incompatible version");
		goto Err;
	}
	/* refuse a downgraded hello that carries the fallback marker */
	if(c->version < ProtocolVersion
	&& lookupid(m.u.clientHello.ciphers, TLS_FALLBACK_SCSV) >= 0){
		tlsError(c, EInappropriateFallback, "inappropriate fallback");
		goto Err;
	}
	tlsSecInits(c->sec, m.u.clientHello.version, m.u.clientHello.random);
	tlsSecVers(c->sec, c->version);
	if(psklen > 0){
		c->sec->psk = psk;
		c->sec->psklen = psklen;
	}
	if(certlen > 0){
		/* server certificate */
		c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
		if(c->sec->rsapub == nil){
			tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
			goto Err;
		}
		c->sec->rpc = factotum_rsa_open(c->sec->rsapub);
		if(c->sec->rpc == nil){
			tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
			goto Err;
		}
	}
	/* may select c->sec->nc, which okCipher is told about via its last argument */
	if(checkClientExtensions(c, m.u.clientHello.extensions) < 0)
		goto Err;
	cipher = okCipher(m.u.clientHello.ciphers, psklen > 0, c->sec->nc != nil);
	if(cipher < 0 || !setAlgs(c, cipher)) {
		tlsError(c, EHandshakeFailure, "no matching cipher suite");
		goto Err;
	}
	compressor = okCompression(m.u.clientHello.compressors);
	if(compressor < 0) {
		tlsError(c, EHandshakeFailure, "no matching compressor");
		goto Err;
	}
	if(trace)
		trace("  cipher %x, compressor %x\n", cipher, compressor);
	msgClear(&m);

	/* server hello, with an empty session id */
	m.tag = HServerHello;
	m.u.serverHello.version = c->version;
	memmove(m.u.serverHello.random, c->sec->srandom, RandomSize);
	m.u.serverHello.cipher = cipher;
	m.u.serverHello.compressor = compressor;
	m.u.serverHello.sid = makebytes(nil, 0);
	if(!msgSend(c, &m, AQueue))
		goto Err;

	/* certificate: ours first, then the chain */
	if(certlen > 0){
		m.tag = HCertificate;
		numcerts = countchain(chp);
		m.u.certificate.ncert = 1 + numcerts;
		m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
		m.u.certificate.certs[0] = makebytes(cert, certlen);
		for (i = 0; i < numcerts && chp; i++, chp = chp->next)
			m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
		if(!msgSend(c, &m, AQueue))
			goto Err;
	}

	/* server key exchange, only for ECDHE suites */
	if(isECDHE(cipher)){
		m.tag = HServerKeyExchange;
		m.u.serverKeyExchange.curve = c->sec->nc->tlsid;
		m.u.serverKeyExchange.dh_parameters = tlsSecECDHEs1(c->sec);
		if(m.u.serverKeyExchange.dh_parameters == nil){
			tlsError(c, EInternalError, "can't set DH parameters");
			goto Err;
		}

		/* sign the DH parameters */
		if(certlen > 0){
			uchar digest[MAXdlen];
			int digestlen;

			/* below TLS 1.2 sigalg is left as it is in m */
			if(c->version >= TLS12Version)
				m.u.serverKeyExchange.sigalg = 0x0401;	/* RSA SHA256 */
			digestlen = digestDHparams(c->sec, m.u.serverKeyExchange.dh_parameters,
				digest, m.u.serverKeyExchange.sigalg);
			if((m.u.serverKeyExchange.dh_signature = pkcs1_sign(c->sec, digest, digestlen,
				m.u.serverKeyExchange.sigalg)) == nil){
				tlsError(c, EHandshakeFailure, "pkcs1_sign: %r");
				goto Err;
			}
		}
		if(!msgSend(c, &m, AQueue))
			goto Err;
	}

	/* server hello done; AFlush here where the earlier messages used AQueue */
	m.tag = HServerHelloDone;
	if(!msgSend(c, &m, AFlush))
		goto Err;

	/* client key exchange */
	if(!msgRecv(c, &m))
		goto Err;
	if(m.tag != HClientKeyExchange) {
		tlsError(c, EUnexpectedMessage, "expected a client key exchange");
		goto Err;
	}
	/* the identity must match pskid exactly, byte for byte */
	if(pskid != nil){
		if(m.u.clientKeyExchange.pskid == nil
		|| m.u.clientKeyExchange.pskid->len != strlen(pskid)
		|| memcmp(pskid, m.u.clientKeyExchange.pskid->data, m.u.clientKeyExchange.pskid->len) != 0){
			tlsError(c, EUnknownPSKidentity, "unknown or missing pskid");
			goto Err;
		}
	}
	/* set the master secret: ECDHE, else RSA, else pre-shared key only */
	if(isECDHE(cipher)){
		if(tlsSecECDHEs2(c->sec, m.u.clientKeyExchange.key) < 0){
			tlsError(c, EHandshakeFailure, "couldn't set keys: %r");
			goto Err;
		}
	} else if(certlen > 0){
		if(tlsSecRSAs(c->sec, m.u.clientKeyExchange.key) < 0){
			tlsError(c, EHandshakeFailure, "couldn't set keys: %r");
			goto Err;
		}
	} else if(psklen > 0){
		setMasterSecret(c->sec, newbytes(psklen));
	} else {
		tlsError(c, EInternalError, "no psk or certificate");
		goto Err;
	}

	if(trace)
		trace("tls secrets\n");
	if(setSecrets(c, 0) < 0){
		tlsError(c, EHandshakeFailure, "can't set secrets: %r");
		goto Err;
	}

	/* no CertificateVerify; skip to Finished */
	/* compute the client's (isclient=1) Finished value, then compare with what it sends */
	if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
		tlsError(c, EInternalError, "can't set finished: %r");
		goto Err;
	}
	if(!msgRecv(c, &m))
		goto Err;
	if(m.tag != HFinished) {
		tlsError(c, EUnexpectedMessage, "expected a finished");
		goto Err;
	}
	if(!finishedMatch(c, &m.u.finished)) {
		tlsError(c, EHandshakeFailure, "finished verification failed");
		goto Err;
	}
	msgClear(&m);

	/* change cipher spec */
	if(fprint(c->ctl, "changecipher") < 0){
		tlsError(c, EInternalError, "can't enable cipher: %r");
		goto Err;
	}

	/* our own (isclient=0) Finished */
	if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
		tlsError(c, EInternalError, "can't set finished: %r");
		goto Err;
	}
	m.tag = HFinished;
	m.u.finished = c->finished;
	if(!msgSend(c, &m, AFlush))
		goto Err;
	if(trace)
		trace("tls finished\n");

	if(fprint(c->ctl, "opened") < 0)
		goto Err;
	return c;

Err:
	msgClear(&m);
	tlsConnectionFree(c);
	return nil;
}

/*
 * Client half of a finite field Diffie-Hellman exchange.
 * p, g and Ys are the prime, generator and server public value taken
 * from the ServerKeyExchange.  On success the master secret in sec is
 * set from the shared key and our public value is returned, padded to
 * the byte length of the prime.  Returns nil if an argument is
 * missing, the prime is too short, or the exchange fails.
 */
static Bytes*
tlsSecDHEc(TlsSec *sec, Bytes *p, Bytes *g, Bytes *Ys)
{
	DHstate *dh;
	mpint *P, *G, *Y, *shared;
	Bytes *pubkey;
	int nbytes;

	if(p == nil || g == nil || Ys == nil)
		return nil;

	/* reject primes of 1024 bits or less: they are susceptible to logjam */
	if(p->len <= 1024/8)
		return nil;

	dh = &sec->dh;
	pubkey = nil;
	shared = nil;
	P = bytestomp(p);
	G = bytestomp(g);
	Y = bytestomp(Ys);

	if(dh_new(dh, P, nil, G) != nil){
		nbytes = (mpsignif(P)+7)/8;
		pubkey = mptobytes(dh->y, nbytes);
		shared = dh_finish(dh, Y);	/* zeros dh */
		if(shared != nil)
			setMasterSecret(sec, mptobytes(shared, nbytes));
		else{
			/* no shared key: do not hand back a public value either */
			freebytes(pubkey);
			pubkey = nil;
		}
	}

	mpfree(shared);
	mpfree(Y);
	mpfree(G);
	mpfree(P);

	return pubkey;
}

/*
 * Client half of an elliptic curve Diffie-Hellman exchange.
 *
 *	curve	curve id announced in the ServerKeyExchange
 *	Ys	the server's public value
 *
 * X25519 is handled with curve25519_dh_new/curve25519_dh_finish and
 * needs a 32 byte Ys.  Any other id is looked up in namedcurves and
 * processed with the generic EC routines, using the state in sec->ec.
 *
 * On success the master secret in sec is set from the shared value and
 * our encoded public value is returned.  Returns nil if Ys is missing
 * or malformed, or if the curve is not one we know.
 */
static Bytes*
tlsSecECDHEc(TlsSec *sec, int curve, Bytes *Ys)
{
	ECdomain *dom = &sec->ec.dom;
	ECpriv *Q = &sec->ec.Q;
	ECpub *pub;
	ECpoint K;
	Namedcurve *nc;
	Bytes *Yc;
	Bytes *Z;
	int n;

	if(Ys == nil)
		return nil;

	if(curve == X25519){
		if(Ys->len != 32)
			return nil;
		Yc = newbytes(32);
		curve25519_dh_new(sec->X, Yc->data);
		Z = newbytes(32);
		if(!curve25519_dh_finish(sec->X, Ys->data, Z->data)){
			freebytes(Yc);
			freebytes(Z);
			return nil;
		}
		setMasterSecret(sec, Z);
	}else{
		/*
		 * Check the bound before looking at the entry: the curve id
		 * comes from the peer, and an unknown id must not make us
		 * read past the end of namedcurves.
		 */
		for(nc = namedcurves; nc < &namedcurves[nelem(namedcurves)]; nc++)
			if(nc->tlsid == curve)
				break;
		/* an entry without an init function cannot be used with ecdominit */
		if(nc == &namedcurves[nelem(namedcurves)] || nc->init == nil)
			return nil;
		ecdominit(dom, nc->init);
		pub = ecdecodepub(dom, Ys->data, Ys->len);
		if(pub == nil)
			return nil;

		/* our ephemeral key pair lives in sec->ec.Q */
		memset(Q, 0, sizeof(*Q));
		Q->x = mpnew(0);
		Q->y = mpnew(0);
		Q->d = mpnew(0);

		memset(&K, 0, sizeof(K));
		K.x = mpnew(0);
		K.y = mpnew(0);

		/* shared point K = d * (server public point) */
		ecgen(dom, Q);
		ecmul(dom, pub, Q->d, &K);

		/* the x coordinate, padded to the field size, is the premaster secret */
		n = (mpsignif(dom->p)+7)/8;
		setMasterSecret(sec, mptobytes(K.x, n));
		Yc = newbytes(1 + 2*n);
		Yc->len = ecencodepub(dom, Q, Yc->data, Yc->len);

		mpfree(K.x);
		mpfree(K.y);

		ecpubfree(pub);
	}
	return Yc;
}

/*
 * Run the client side of the handshake over the record layer
 * descriptors ctl and hand.
 *
 *	cert, certlen	optional client certificate; when certlen > 0 its RSA
 *			public key is extracted and the matching private key
 *			is reached through factotum
 *	pskid		identity sent with a pre-shared key ("" if nil)
 *	psk, psklen	pre-shared key, used when psklen > 0
 *	ext, extlen	encoded extensions for the ClientHello
 *	trace		optional debugging print function
 *
 * Returns the connection state on success; c->cert then holds a copy
 * of the first certificate the server sent, if any.  On any failure
 * everything is released and nil is returned.
 */
static TlsConnection *
tlsClient2(int ctl, int hand,
	uchar *cert, int certlen,
	char *pskid, uchar *psk, int psklen,
	uchar *ext, int extlen,
	int (*trace)(char*fmt, ...))
{
	int creq, dhx, cipher;
	TlsConnection *c;
	Bytes *epm;
	Msg m;

	if(!initCiphers())
		return nil;

	/* epm: key exchange payload for the ClientKeyExchange message */
	epm = nil;
	memset(&m, 0, sizeof(m));
	c = emalloc(sizeof(TlsConnection));

	c->ctl = ctl;
	c->hand = hand;
	c->trace = trace;
	c->cert = nil;
	c->sendp = c->buf;

	c->version = ProtocolVersion;
	tlsSecInitc(c->sec, c->version);
	if(psklen > 0){
		c->sec->psk = psk;
		c->sec->psklen = psklen;
	}
	if(certlen > 0){
		/* client certificate */
		c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
		if(c->sec->rsapub == nil){
			tlsError(c, EInternalError, "invalid X509/rsa certificate");
			goto Err;
		}
		c->sec->rpc = factotum_rsa_open(c->sec->rsapub);
		if(c->sec->rpc == nil){
			tlsError(c, EInternalError, "factotum_rsa_open: %r");
			goto Err;
		}
	}

	/* client hello */
	m.tag = HClientHello;
	m.u.clientHello.version = c->version;
	memmove(m.u.clientHello.random, c->sec->crandom, RandomSize);
	m.u.clientHello.sid = makebytes(nil, 0);
	m.u.clientHello.ciphers = makeciphers(psklen > 0);
	m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
	m.u.clientHello.extensions = makebytes(ext, extlen);
	if(!msgSend(c, &m, AFlush))
		goto Err;

	/* server hello */
	if(!msgRecv(c, &m))
		goto Err;
	if(m.tag != HServerHello) {
		tlsError(c, EUnexpectedMessage, "expected a server hello");
		goto Err;
	}
	if(setVersion(c, m.u.serverHello.version) < 0) {
		tlsError(c, EIllegalParameter, "incompatible version: %r");
		goto Err;
	}
	tlsSecVers(c->sec, c->version);
	memmove(c->sec->srandom, m.u.serverHello.random, RandomSize);

	/* a PSK suite is required exactly when we have a pre-shared key */
	cipher = m.u.serverHello.cipher;
	if((psklen > 0) != isPSK(cipher) || !setAlgs(c, cipher)) {
		tlsError(c, EIllegalParameter, "invalid cipher suite");
		goto Err;
	}
	if(m.u.serverHello.compressor != CompressionNull) {
		tlsError(c, EIllegalParameter, "invalid compression");
		goto Err;
	}
	/* dhx: the suite needs a ServerKeyExchange carrying (EC)DH parameters */
	dhx = isDHE(cipher) || isECDHE(cipher);

	/* certificate: required unless we are using a pre-shared key */
	if(!msgRecv(c, &m))
		goto Err;
	if(m.tag == HCertificate){
		if(m.u.certificate.ncert < 1) {
			tlsError(c, EIllegalParameter, "runt certificate");
			goto Err;
		}
		/* keep only the first certificate */
		c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
		if(!msgRecv(c, &m))
			goto Err;
	} else if(psklen == 0) {
		tlsError(c, EUnexpectedMessage, "expected a certificate");
		goto Err;
	}

	/* server key exchange */
	if(m.tag == HServerKeyExchange) {
		if(dhx){
			/* below TLS 1.2 the sigalg field of the message is ignored and 0x01 is used */
			char *err = verifyDHparams(c->sec,
				m.u.serverKeyExchange.dh_parameters,
				c->cert,
				m.u.serverKeyExchange.dh_signature,
				c->version<TLS12Version ? 0x01 : m.u.serverKeyExchange.sigalg);
			if(err != nil){
				tlsError(c, EBadCertificate, "can't verify DH parameters: %s", err);
				goto Err;
			}
			/* both helpers set the master secret and return our public value */
			if(isECDHE(cipher))
				epm = tlsSecECDHEc(c->sec,
					m.u.serverKeyExchange.curve,
					m.u.serverKeyExchange.dh_Ys);
			else
				epm = tlsSecDHEc(c->sec,
					m.u.serverKeyExchange.dh_p, 
					m.u.serverKeyExchange.dh_g,
					m.u.serverKeyExchange.dh_Ys);
			if(epm == nil){
				tlsError(c, EHandshakeFailure, "bad DH parameters");
				goto Err;
			}
		} else if(psklen == 0){
			tlsError(c, EUnexpectedMessage, "got an server key exchange");
			goto Err;
		}
		if(!msgRecv(c, &m))
			goto Err;
	} else if(dhx){
		tlsError(c, EUnexpectedMessage, "expected server key exchange");
		goto Err;
	}

	/* certificate request (optional) */
	creq = 0;
	if(m.tag == HCertificateRequest) {
		creq = 1;
		if(!msgRecv(c, &m))
			goto Err;
	}

	if(m.tag != HServerHelloDone) {
		tlsError(c, EUnexpectedMessage, "expected a server hello done");
		goto Err;
	}
	msgClear(&m);

	/* without (EC)DH: RSA key exchange if we got a certificate, else pre-shared key only */
	if(!dhx){
		if(c->cert != nil){
			epm = tlsSecRSAc(c->sec, c->cert->data, c->cert->len);
			if(epm == nil){
				tlsError(c, EBadCertificate, "bad certificate: %r");
				goto Err;
			}
		} else if(psklen > 0){
			setMasterSecret(c->sec, newbytes(psklen));
		} else {
			tlsError(c, EInternalError, "no psk or certificate");
			goto Err;
		}
	}

	if(trace)
		trace("tls secrets\n");
	if(setSecrets(c, 1) < 0){
		tlsError(c, EHandshakeFailure, "can't set secrets: %r");
		goto Err;
	}

	/* answer a certificate request; certs is left unset when we have none */
	if(creq) {
		m.tag = HCertificate;
		if(certlen > 0){
			m.u.certificate.ncert = 1;
			m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
			m.u.certificate.certs[0] = makebytes(cert, certlen);
		}		
		if(!msgSend(c, &m, AFlush))
			goto Err;
	}

	/* client key exchange */
	m.tag = HClientKeyExchange;
	if(psklen > 0){
		if(pskid == nil)
			pskid = "";
		m.u.clientKeyExchange.pskid = makebytes((uchar*)pskid, strlen(pskid));
	}
	/* the message now owns epm; clear ours so Err does not free it as well */
	m.u.clientKeyExchange.key = epm;
	epm = nil;
	 
	if(!msgSend(c, &m, AFlush))
		goto Err;

	/* certificate verify */
	if(creq && certlen > 0) {
		HandshakeHash hsave;
		uchar digest[MAXdlen];
		int digestlen;

		/* save the state for the Finish message */
		hsave = c->handhash;
		if(c->version < TLS12Version){
			md5(nil, 0, digest, &c->handhash.md5);
			sha1(nil, 0, digest+MD5dlen, &c->handhash.sha1);
			digestlen = MD5dlen+SHA1dlen;
		} else {
			m.u.certificateVerify.sigalg = 0x0401;	/* RSA SHA256 */
			sha2_256(nil, 0, digest, &c->handhash.sha2_256);
			digestlen = SHA2_256dlen;
		}
		c->handhash = hsave;

		if((m.u.certificateVerify.signature = pkcs1_sign(c->sec, digest, digestlen,
			m.u.certificateVerify.sigalg)) == nil){
			tlsError(c, EHandshakeFailure, "pkcs1_sign: %r");
			goto Err;
		}

		m.tag = HCertificateVerify;
		if(!msgSend(c, &m, AFlush))
			goto Err;
	} 

	/* change cipher spec */
	if(fprint(c->ctl, "changecipher") < 0){
		tlsError(c, EInternalError, "can't enable cipher: %r");
		goto Err;
	}

	// Cipherchange must occur immediately before Finished to avoid
	// potential hole;  see section 4.3 of Wagner Schneier 1996.
	if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
		tlsError(c, EInternalError, "can't set finished 1: %r");
		goto Err;
	}
	m.tag = HFinished;
	m.u.finished = c->finished;
	if(!msgSend(c, &m, AFlush)) {
		tlsError(c, EInternalError, "can't flush after client Finished: %r");
		goto Err;
	}

	/* compute the server's (isclient=0) Finished value, then compare with what it sends */
	if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
		tlsError(c, EInternalError, "can't set finished 0: %r");
		goto Err;
	}
	if(!msgRecv(c, &m)) {
		tlsError(c, EInternalError, "can't read server Finished: %r");
		goto Err;
	}
	if(m.tag != HFinished) {
		tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
		goto Err;
	}

	if(!finishedMatch(c, &m.u.finished)) {
		tlsError(c, EHandshakeFailure, "finished verification failed");
		goto Err;
	}
	msgClear(&m);

	if(fprint(c->ctl, "opened") < 0){
		if(trace)
			trace("unable to do final open: %r\n");
		goto Err;
	}
	return c;

Err:
	free(epm);
	msgClear(&m);
	tlsConnectionFree(c);
	return nil;
}


//================= message functions ========================

// Fold n bytes of handshake data at p into the running transcript
// digests kept in c->handhash.  MD5 and SHA1 are always updated;
// SHA2-256 is updated only when the connection version is TLS 1.2
// or later.
static void
msgHash(TlsConnection *c, uchar *p, int n)
{
	HandshakeHash *hh;

	hh = &c->handhash;
	md5(p, n, nil, &hh->md5);
	sha1(p, n, nil, &hh->sha1);
	if(c->version < TLS12Version)
		return;
	sha2_256(p, n, nil, &hh->sha2_256);
}

// Serialize handshake message m into the connection's send buffer at
// c->sendp, prefixing it with the 4 byte handshake header (1 byte tag,
// 24 bit body length), and add the encoded bytes to the handshake hash.
// When act is AFlush, everything queued in c->buf so far is written to
// c->hand and the send pointer is reset to the start of the buffer;
// otherwise the message is only queued.
// The message is always released with msgClear(), on success and on
// failure alike.  Returns 1 on success, 0 on error.
static int
msgSend(TlsConnection *c, Msg *m, int act)
{
	uchar *p, *e; // sendp = start of new message;  p = write pointer; e = end pointer
	int n, i;

	p = c->sendp;
	e = &c->buf[sizeof(c->buf)];
	// the trace text is formatted into the still unused part of the send buffer
	if(c->trace)
		c->trace("send %s", msgPrint((char*)p, e - p, m));

	p[0] = m->tag;	// header - fill in size later
	p += 4;

	switch(m->tag) {
	default:
		tlsError(c, EInternalError, "can't encode a %d", m->tag);
		goto Err;
	case HClientHello:
		if(p+2+RandomSize > e)
			goto Overflow;
		put16(p, m->u.clientHello.version), p += 2;
		memmove(p, m->u.clientHello.random, RandomSize);
		p += RandomSize;

		if(p+1+(n = m->u.clientHello.sid->len) > e)
			goto Overflow;
		*p++ = n;
		memmove(p, m->u.clientHello.sid->data, n);
		p += n;

		// every cipher suite takes two bytes on the wire, so the
		// list needs 2 bytes of length plus 2*n bytes of data
		n = m->u.clientHello.ciphers->len;
		if(p+2+2*n > e)
			goto Overflow;
		put16(p, n*2), p += 2;
		for(i=0; i<n; i++)
			put16(p, m->u.clientHello.ciphers->data[i]), p += 2;

		if(p+1+(n = m->u.clientHello.compressors->len) > e)
			goto Overflow;
		*p++ = n;
		memmove(p, m->u.clientHello.compressors->data, n);
		p += n;

		// the extensions block is omitted entirely when empty
		if(m->u.clientHello.extensions == nil
		|| (n = m->u.clientHello.extensions->len) == 0)
			break;
		if(p+2+n > e)
			goto Overflow;
		put16(p, n), p += 2;
		memmove(p, m->u.clientHello.extensions->data, n);
		p += n;
		break;
	case HServerHello:
		if(p+2+RandomSize > e)
			goto Overflow;
		put16(p, m->u.serverHello.version), p += 2;
		memmove(p, m->u.serverHello.random, RandomSize);
		p += RandomSize;

		if(p+1+(n = m->u.serverHello.sid->len) > e)
			goto Overflow;
		*p++ = n;
		memmove(p, m->u.serverHello.sid->data, n);
		p += n;

		if(p+2+1 > e)
			goto Overflow;
		put16(p, m->u.serverHello.cipher), p += 2;
		*p++ = m->u.serverHello.compressor;

		// the extensions block is omitted entirely when empty
		if(m->u.serverHello.extensions == nil
		|| (n = m->u.serverHello.extensions->len) == 0)
			break;
		if(p+2+n > e)
			goto Overflow;
		put16(p, n), p += 2;
		memmove(p, m->u.serverHello.extensions->data, n);
		p += n;
		break;
	case HServerHelloDone:
		break;
	case HCertificate:
		// total size of the list: a 3 byte length in front of each certificate
		n = 0;
		for(i = 0; i < m->u.certificate.ncert; i++)
			n += 3 + m->u.certificate.certs[i]->len;
		if(p+3+n > e)
			goto Overflow;
		put24(p, n), p += 3;
		for(i = 0; i < m->u.certificate.ncert; i++){
			n = m->u.certificate.certs[i]->len;
			put24(p, n), p += 3;
			memmove(p, m->u.certificate.certs[i]->data, n);
			p += n;
		}
		break;
	case HCertificateVerify:
		if(p+2+2+(n = m->u.certificateVerify.signature->len) > e)
			goto Overflow;
		// the signature algorithm is only emitted when one was set
		if(m->u.certificateVerify.sigalg != 0)
			put16(p, m->u.certificateVerify.sigalg), p += 2;
		put16(p, n), p += 2;
		memmove(p, m->u.certificateVerify.signature->data, n);
		p += n;
		break;
	case HServerKeyExchange:
		if(m->u.serverKeyExchange.pskid != nil){
			if(p+2+(n = m->u.serverKeyExchange.pskid->len) > e)
				goto Overflow;
			put16(p, n), p += 2;
			memmove(p, m->u.serverKeyExchange.pskid->data, n);
			p += n;
		}
		if(m->u.serverKeyExchange.dh_parameters == nil)
			break;
		// dh_parameters is copied verbatim, without a length prefix
		if(p+(n = m->u.serverKeyExchange.dh_parameters->len) > e)
			goto Overflow;
		memmove(p, m->u.serverKeyExchange.dh_parameters->data, n);
		p += n;
		if(m->u.serverKeyExchange.dh_signature == nil)
			break;
		if(p+2+2+(n = m->u.serverKeyExchange.dh_signature->len) > e)
			goto Overflow;
		if(c->version >= TLS12Version)
			put16(p, m->u.serverKeyExchange.sigalg), p += 2;
		put16(p, n), p += 2;
		memmove(p, m->u.serverKeyExchange.dh_signature->data, n);
		p += n;
		break;
	case HClientKeyExchange:
		if(m->u.clientKeyExchange.pskid != nil){
			if(p+2+(n = m->u.clientKeyExchange.pskid->len) > e)
				goto Overflow;
			put16(p, n), p += 2;
			memmove(p, m->u.clientKeyExchange.pskid->data, n);
			p += n;
		}
		if(m->u.clientKeyExchange.key == nil)
			break;
		if(p+2+(n = m->u.clientKeyExchange.key->len) > e)
			goto Overflow;
		// length prefix: 1 byte for ECDHE, 2 bytes for DHE and for
		// every version other than SSL3, none otherwise
		if(isECDHE(c->cipher))
			*p++ = n;
		else if(isDHE(c->cipher) || c->version != SSL3Version)
			put16(p, n), p += 2;
		memmove(p, m->u.clientKeyExchange.key->data, n);
		p += n;
		break;
	case HFinished:
		if(p+m->u.finished.n > e)
			goto Overflow;
		memmove(p, m->u.finished.verify, m->u.finished.n);
		p += m->u.finished.n;
		break;
	}

	// go back and fill in size
	n = p - c->sendp;
	put24(c->sendp+1, n-4);

	// remember hash of Handshake messages
	if(m->tag != HHelloRequest)
		msgHash(c, c->sendp, n);

	c->sendp = p;
	if(act == AFlush){
		c->sendp = c->buf;
		if(write(c->hand, c->buf, p - c->buf) < 0){
			fprint(2, "write error: %r\n");
			goto Err;
		}
	}
	msgClear(m);
	return 1;
Overflow:
	tlsError(c, EInternalError, "not enougth send buffer for message (%d)", m->tag);
Err:
	msgClear(m);
	return 0;
}

// Read exactly n bytes from c->hand into the tail of c->buf and
// return a pointer to the first of them.  The data is placed at the
// very end of the buffer; the request is rejected with a decode error
// if it would reach below c->sendp.
// Returns nil if the request does not fit or if a read fails or
// hits end of file.
static uchar*
tlsReadN(TlsConnection *c, int n)
{
	uchar *p, *w, *e;

	e = &c->buf[sizeof(c->buf)];
	// validate n before doing any pointer arithmetic with it, so an
	// out of range pointer is never formed
	if(n < 0 || n > e - c->sendp){
		tlsError(c, EDecodeError, "handshake message too long %d", n);
		return nil;
	}
	p = e - n;
	// read() may return short counts; keep going until the tail is full
	for(w = p; w < e; w += n)
		if((n = read(c->hand, w, e - w)) <= 0)
			return nil;
	return p;
}

// Read the next handshake message from c->hand and decode it into m.
// Any previous content of m is released first.  Empty HelloRequest
// messages are skipped.  All received bytes are added to the handshake
// hash.  A ClientHello arriving in SSL2 record format is converted
// into a regular HClientHello.
// Returns 1 with m filled in on success.  On failure returns 0 with m
// cleared; most failure paths also report the problem via tlsError().
static int
msgRecv(TlsConnection *c, Msg *m)
{
	uchar *p, *s;
	int type, n, nn, i;

	msgClear(m);
	for(;;) {
		// handshake header: 1 byte type, 24 bit body length
		p = tlsReadN(c, 4);
		if(p == nil)
			return 0;
		type = p[0];
		n = get24(p+1);

		if(type != HHelloRequest)
			break;
		if(n != 0) {
			tlsError(c, EDecodeError, "invalid hello request during handshake");
			return 0;
		}
	}

	if(type == HSSL2ClientHello){
		/* Cope with an SSL3 ClientHello expressed in SSL2 record format.
			This is sent by some clients that we must interoperate
			with, such as Java's JSSE and Microsoft's Internet Explorer. */
		int nsid, nrandom, nciph;

		p = tlsReadN(c, n);
		if(p == nil)
			return 0;
		msgHash(c, p, n);
		m->tag = HClientHello;
		if(n < 22)
			goto Short;
		m->u.clientHello.version = get16(p+1);
		p += 3;
		n -= 3;
		nn = get16(p); /* cipher_spec_len */
		nsid = get16(p + 2);
		nrandom = get16(p + 4);
		p += 6;
		n -= 6;
		if(nsid != 0 	/* no sid's, since shouldn't restart using ssl2 header */
		|| nrandom < 16 || nn % 3 || n - nrandom < nn)
			goto Err;
		/* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
		nciph = 0;
		for(i = 0; i < nn; i += 3)
			if(p[i] == 0)
				nciph++;
		m->u.clientHello.ciphers = newints(nciph);
		nciph = 0;
		for(i = 0; i < nn; i += 3)
			if(p[i] == 0)
				m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
		p += nn;
		m->u.clientHello.sid = makebytes(nil, 0);
		/* the challenge is right aligned in random and zero padded on the left */
		if(nrandom > RandomSize)
			nrandom = RandomSize;
		memset(m->u.clientHello.random, 0, RandomSize - nrandom);
		memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
		m->u.clientHello.compressors = newbytes(1);
		m->u.clientHello.compressors->data[0] = CompressionNull;
		goto Ok;
	}
	// the header is hashed first: the body read reuses the buffer tail
	msgHash(c, p, 4);

	p = tlsReadN(c, n);
	if(p == nil)
		return 0;

	msgHash(c, p, n);

	m->tag = type;

	// in every case below, p walks the body and n counts the bytes left
	switch(type) {
	default:
		tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
		goto Err;
	case HClientHello:
		if(n < 2)
			goto Short;
		m->u.clientHello.version = get16(p);
		p += 2, n -= 2;

		if(n < RandomSize)
			goto Short;
		memmove(m->u.clientHello.random, p, RandomSize);
		p += RandomSize, n -= RandomSize;
		if(n < 1 || n < p[0]+1)
			goto Short;
		m->u.clientHello.sid = makebytes(p+1, p[0]);
		p += m->u.clientHello.sid->len+1;
		n -= m->u.clientHello.sid->len+1;

		if(n < 2)
			goto Short;
		nn = get16(p);
		p += 2, n -= 2;

		if(nn % 2 || n < nn || nn < 2)
			goto Short;
		m->u.clientHello.ciphers = newints(nn >> 1);
		for(i = 0; i < nn; i += 2)
			m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
		p += nn, n -= nn;

		if(n < 1 || n < p[0]+1 || p[0] == 0)
			goto Short;
		nn = p[0];
		m->u.clientHello.compressors = makebytes(p+1, nn);
		p += nn + 1, n -= nn + 1;

		// extensions are optional
		if(n < 2)
			break;
		nn = get16(p);
		if(nn > n-2)
			goto Short;
		m->u.clientHello.extensions = makebytes(p+2, nn);
		n -= nn + 2;
		break;
	case HServerHello:
		if(n < 2)
			goto Short;
		m->u.serverHello.version = get16(p);
		p += 2, n -= 2;

		if(n < RandomSize)
			goto Short;
		memmove(m->u.serverHello.random, p, RandomSize);
		p += RandomSize, n -= RandomSize;

		if(n < 1 || n < p[0]+1)
			goto Short;
		m->u.serverHello.sid = makebytes(p+1, p[0]);
		p += m->u.serverHello.sid->len+1;
		n -= m->u.serverHello.sid->len+1;

		if(n < 3)
			goto Short;
		m->u.serverHello.cipher = get16(p);
		m->u.serverHello.compressor = p[2];
		p += 3, n -= 3;

		// extensions are optional
		if(n < 2)
			break;
		nn = get16(p);
		if(nn > n-2)
			goto Short;
		m->u.serverHello.extensions = makebytes(p+2, nn);
		n -= nn + 2;
		break;
	case HCertificate:
		if(n < 3)
			goto Short;
		nn = get24(p);
		p += 3, n -= 3;
		if(nn == 0 && n > 0)
			goto Short;
		/* certs */
		i = 0;
		while(n > 0) {
			if(n < 3)
				goto Short;
			nn = get24(p);
			p += 3, n -= 3;
			if(nn > n)
				goto Short;
			// ncert is kept current so msgClear() frees what was parsed so far
			m->u.certificate.ncert = i+1;
			m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes*));
			m->u.certificate.certs[i] = makebytes(p, nn);
			p += nn, n -= nn;
			i++;
		}
		break;
	case HCertificateRequest:
		if(n < 1)
			goto Short;
		nn = p[0];
		p++, n--;
		if(nn > n)
			goto Short;
		m->u.certificateRequest.types = makebytes(p, nn);
		p += nn, n -= nn;
		if(c->version >= TLS12Version){
			if(n < 2)
				goto Short;
			nn = get16(p);
			p += 2, n -= 2;
			// the list length comes from the peer: make sure it lies
			// within the message before reading the entries
			if(nn % 2 || nn > n)
				goto Short;
			m->u.certificateRequest.sigalgs = newints(nn>>1);
			for(i = 0; i < nn; i += 2)
				m->u.certificateRequest.sigalgs->data[i >> 1] = get16(&p[i]);
			p += nn, n -= nn;

		}
		if(n < 2)
			goto Short;
		nn = get16(p);
		p += 2, n -= 2;
		/* nn == 0 can happen; yahoo's servers do it */
		if(nn != n)
			goto Short;
		/* cas */
		i = 0;
		while(n > 0) {
			if(n < 2)
				goto Short;
			nn = get16(p);
			p += 2, n -= 2;
			if(nn < 1 || nn > n)
				goto Short;
			// nca is kept current so msgClear() frees what was parsed so far
			m->u.certificateRequest.nca = i+1;
			m->u.certificateRequest.cas = erealloc(
				m->u.certificateRequest.cas, (i+1)*sizeof(Bytes*));
			m->u.certificateRequest.cas[i] = makebytes(p, nn);
			p += nn, n -= nn;
			i++;
		}
		break;
	case HServerHelloDone:
		break;
	case HServerKeyExchange:
		if(isPSK(c->cipher)){
			if(n < 2)
				goto Short;
			nn = get16(p);
			p += 2, n -= 2;
			if(nn > n)
				goto Short;
			m->u.serverKeyExchange.pskid = makebytes(p, nn);
			p += nn, n -= nn;
			if(n == 0)
				break;
		}
		if(n < 2)
			goto Short;
		// s marks the start of the raw key exchange parameters
		s = p;
		if(isECDHE(c->cipher)){
			nn = *p;
			p++, n--;
			if(nn != 3 || nn > n) /* not a named curve */
				goto Short;
			nn = get16(p);
			p += 2, n -= 2;
			m->u.serverKeyExchange.curve = nn;

			nn = *p++, n--;
			if(nn < 1 || nn > n)
				goto Short;
			m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
			p += nn, n -= nn;
		}else if(isDHE(c->cipher)){
			nn = get16(p);
			p += 2, n -= 2;
			if(nn < 1 || nn > n)
				goto Short;
			m->u.serverKeyExchange.dh_p = makebytes(p, nn);
			p += nn, n -= nn;

			if(n < 2)
				goto Short;
			nn = get16(p);
			p += 2, n -= 2;
			if(nn < 1 || nn > n)
				goto Short;
			m->u.serverKeyExchange.dh_g = makebytes(p, nn);
			p += nn, n -= nn;

			if(n < 2)
				goto Short;
			nn = get16(p);
			p += 2, n -= 2;
			if(nn < 1 || nn > n)
				goto Short;
			m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
			p += nn, n -= nn;
		} else {
			/* should not happen */
			goto Short;
		}
		// keep a verbatim copy of the parameter bytes just parsed
		m->u.serverKeyExchange.dh_parameters = makebytes(s, p - s);
		if(n >= 2){
			m->u.serverKeyExchange.sigalg = 0;
			if(c->version >= TLS12Version){
				m->u.serverKeyExchange.sigalg = get16(p);
				p += 2, n -= 2;
				if(n < 2)
					goto Short;
			}
			nn = get16(p);
			p += 2, n -= 2;
			// a bad signature length leaves n != 0, which is
			// rejected by the final length check
			if(nn > 0 && nn <= n){
				m->u.serverKeyExchange.dh_signature = makebytes(p, nn);
				n -= nn;
			}
		}
		break;
	case HClientKeyExchange:
		if(isPSK(c->cipher)){
			if(n < 2)
				goto Short;
			nn = get16(p);
			p += 2, n -= 2;
			if(nn > n)
				goto Short;
			m->u.clientKeyExchange.pskid = makebytes(p, nn);
			p += nn, n -= nn;
			if(n == 0)
				break;
		}
		if(n < 2)
			goto Short;
		// length prefix: 1 byte for ECDHE, 2 bytes for DHE and for every
		// version other than SSL3; otherwise the key is the rest of the body
		if(isECDHE(c->cipher))
			nn = *p++, n--;
		else if(isDHE(c->cipher) || c->version != SSL3Version)
			nn = get16(p), p += 2, n -= 2;
		else
			nn = n;
		if(n < nn)
			goto Short;
		m->u.clientKeyExchange.key = makebytes(p, nn);
		n -= nn;
		break;
	case HFinished:
		// the expected length is the one negotiated for this connection
		m->u.finished.n = c->finished.n;
		if(n < m->u.finished.n)
			goto Short;
		memmove(m->u.finished.verify, p, m->u.finished.n);
		n -= m->u.finished.n;
		break;
	}

	// trailing bytes are tolerated only in the hello messages
	if(n != 0 && type != HClientHello && type != HServerHello)
		goto Short;
Ok:
	if(c->trace)
		c->trace("recv %s", msgPrint((char*)c->sendp, &c->buf[sizeof(c->buf)] - c->sendp, m));
	return 1;
Short:
	tlsError(c, EDecodeError, "handshake message (%d) has invalid length", type);
Err:
	msgClear(m);
	return 0;
}

// Release the heap data owned by the message variant selected by
// m->tag, then zero the whole Msg so it can be reused.
// Tags that own no heap data need no case of their own.
static void
msgClear(Msg *m)
{
	int k;

	switch(m->tag) {
	case HClientHello:
		freebytes(m->u.clientHello.sid);
		freeints(m->u.clientHello.ciphers);
		freebytes(m->u.clientHello.compressors);
		freebytes(m->u.clientHello.extensions);
		break;
	case HServerHello:
		freebytes(m->u.serverHello.sid);
		freebytes(m->u.serverHello.extensions);
		break;
	case HCertificate:
		for(k = 0; k < m->u.certificate.ncert; k++)
			freebytes(m->u.certificate.certs[k]);
		free(m->u.certificate.certs);
		break;
	case HCertificateRequest:
		freebytes(m->u.certificateRequest.types);
		freeints(m->u.certificateRequest.sigalgs);
		for(k = 0; k < m->u.certificateRequest.nca; k++)
			freebytes(m->u.certificateRequest.cas[k]);
		free(m->u.certificateRequest.cas);
		break;
	case HCertificateVerify:
		freebytes(m->u.certificateVerify.signature);
		break;
	case HServerKeyExchange:
		freebytes(m->u.serverKeyExchange.pskid);
		freebytes(m->u.serverKeyExchange.dh_p);
		freebytes(m->u.serverKeyExchange.dh_g);
		freebytes(m->u.serverKeyExchange.dh_Ys);
		freebytes(m->u.serverKeyExchange.dh_parameters);
		freebytes(m->u.serverKeyExchange.dh_signature);
		break;
	case HClientKeyExchange:
		freebytes(m->u.clientKeyExchange.pskid);
		freebytes(m->u.clientKeyExchange.key);
		break;
	}
	memset(m, 0, sizeof(*m));
}

// Append a printable form of b to the text being built at bs (limit
// be): optional prefix s0, then "nil" or "<len> [ xx xx ... ]" in hex,
// then optional suffix s1.  Returns the new end of the text.
static char *
bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
{
	uchar *p, *ep;

	if(s0 != nil)
		bs = seprint(bs, be, "%s", s0);
	if(b != nil){
		bs = seprint(bs, be, "<%d> [ ", b->len);
		ep = b->data + b->len;
		for(p = b->data; p < ep; p++)
			bs = seprint(bs, be, "%.2x ", *p);
		bs = seprint(bs, be, "]");
	}else
		bs = seprint(bs, be, "nil");
	if(s1 != nil)
		bs = seprint(bs, be, "%s", s1);
	return bs;
}

// Append a printable form of the integer vector b to the text being
// built at bs (limit be): optional prefix s0, then "nil" or
// "[ x x ... ]" in hex, then optional suffix s1.
// Returns the new end of the text.
static char *
intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
{
	int *p, *ep;

	if(s0 != nil)
		bs = seprint(bs, be, "%s", s0);
	if(b != nil){
		bs = seprint(bs, be, "[ ");
		ep = b->data + b->len;
		for(p = b->data; p < ep; p++)
			bs = seprint(bs, be, "%x ", *p);
		bs = seprint(bs, be, "]");
	}else
		bs = seprint(bs, be, "nil");
	if(s1 != nil)
		bs = seprint(bs, be, "%s", s1);
	return bs;
}

// Format a human readable dump of handshake message m into buf, which
// holds n bytes, for tracing.  The layout depends on m->tag; tags
// without a dedicated case are printed as "unknown".
// Returns buf.
static char*
msgPrint(char *buf, int n, Msg *m)
{
	char *w, *lim;
	int k;

	w = buf;
	lim = buf+n;

	switch(m->tag) {
	default:
		w = seprint(w, lim, "unknown %d\n", m->tag);
		break;
	case HClientHello:
		w = seprint(w, lim, "ClientHello\n");
		w = seprint(w, lim, "\tversion: %.4x\n", m->u.clientHello.version);
		w = seprint(w, lim, "\trandom: ");
		for(k = 0; k < RandomSize; k++)
			w = seprint(w, lim, "%.2x", m->u.clientHello.random[k]);
		w = seprint(w, lim, "\n");
		w = bytesPrint(w, lim, "\tsid: ", m->u.clientHello.sid, "\n");
		w = intsPrint(w, lim, "\tciphers: ", m->u.clientHello.ciphers, "\n");
		w = bytesPrint(w, lim, "\tcompressors: ", m->u.clientHello.compressors, "\n");
		if(m->u.clientHello.extensions != nil)
			w = bytesPrint(w, lim, "\textensions: ", m->u.clientHello.extensions, "\n");
		break;
	case HServerHello:
		w = seprint(w, lim, "ServerHello\n");
		w = seprint(w, lim, "\tversion: %.4x\n", m->u.serverHello.version);
		w = seprint(w, lim, "\trandom: ");
		for(k = 0; k < RandomSize; k++)
			w = seprint(w, lim, "%.2x", m->u.serverHello.random[k]);
		w = seprint(w, lim, "\n");
		w = bytesPrint(w, lim, "\tsid: ", m->u.serverHello.sid, "\n");
		w = seprint(w, lim, "\tcipher: %.4x\n", m->u.serverHello.cipher);
		w = seprint(w, lim, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
		if(m->u.serverHello.extensions != nil)
			w = bytesPrint(w, lim, "\textensions: ", m->u.serverHello.extensions, "\n");
		break;
	case HCertificate:
		w = seprint(w, lim, "Certificate\n");
		for(k = 0; k < m->u.certificate.ncert; k++)
			w = bytesPrint(w, lim, "\t", m->u.certificate.certs[k], "\n");
		break;
	case HCertificateRequest:
		w = seprint(w, lim, "CertificateRequest\n");
		w = bytesPrint(w, lim, "\ttypes: ", m->u.certificateRequest.types, "\n");
		if(m->u.certificateRequest.sigalgs != nil)
			w = intsPrint(w, lim, "\tsigalgs: ", m->u.certificateRequest.sigalgs, "\n");
		w = seprint(w, lim, "\tcertificateauthorities\n");
		for(k = 0; k < m->u.certificateRequest.nca; k++)
			w = bytesPrint(w, lim, "\t\t", m->u.certificateRequest.cas[k], "\n");
		break;
	case HCertificateVerify:
		w = seprint(w, lim, "HCertificateVerify\n");
		if(m->u.certificateVerify.sigalg != 0)
			w = seprint(w, lim, "\tsigalg: %.4x\n", m->u.certificateVerify.sigalg);
		w = bytesPrint(w, lim, "\tsignature: ", m->u.certificateVerify.signature,"\n");
		break;
	case HServerHelloDone:
		w = seprint(w, lim, "ServerHelloDone\n");
		break;
	case HServerKeyExchange:
		w = seprint(w, lim, "HServerKeyExchange\n");
		if(m->u.serverKeyExchange.pskid != nil)
			w = bytesPrint(w, lim, "\tpskid: ", m->u.serverKeyExchange.pskid, "\n");
		// nothing more to show when no key exchange parameters are present
		if(m->u.serverKeyExchange.dh_parameters == nil)
			break;
		if(m->u.serverKeyExchange.curve == 0){
			w = bytesPrint(w, lim, "\tdh_p: ", m->u.serverKeyExchange.dh_p, "\n");
			w = bytesPrint(w, lim, "\tdh_g: ", m->u.serverKeyExchange.dh_g, "\n");
		} else {
			w = seprint(w, lim, "\tcurve: %.4x\n", m->u.serverKeyExchange.curve);
		}
		w = bytesPrint(w, lim, "\tdh_Ys: ", m->u.serverKeyExchange.dh_Ys, "\n");
		if(m->u.serverKeyExchange.sigalg != 0)
			w = seprint(w, lim, "\tsigalg: %.4x\n", m->u.serverKeyExchange.sigalg);
		w = bytesPrint(w, lim, "\tdh_parameters: ", m->u.serverKeyExchange.dh_parameters, "\n");
		w = bytesPrint(w, lim, "\tdh_signature: ", m->u.serverKeyExchange.dh_signature, "\n");
		break;
	case HClientKeyExchange:
		w = seprint(w, lim, "HClientKeyExchange\n");
		if(m->u.clientKeyExchange.pskid != nil)
			w = bytesPrint(w, lim, "\tpskid: ", m->u.clientKeyExchange.pskid, "\n");
		if(m->u.clientKeyExchange.key != nil)
			w = bytesPrint(w, lim, "\tkey: ", m->u.clientKeyExchange.key, "\n");
		break;
	case HFinished:
		w = seprint(w, lim, "HFinished\n");
		for(k = 0; k < m->u.finished.n; k++)
			w = seprint(w, lim, "%.2x", m->u.finished.verify[k]);
		w = seprint(w, lim, "\n");
		break;
	}
	USED(w);
	return buf;
}

// Record a handshake failure on connection c: format the message,
// pass it to the trace hook if one is set, store it as the error
// string (or, if the connection already erred, print a "double error"
// note on fd 2 instead), mark the connection as erred and write
// "alert <err>" to the control file c->ctl.
static void
tlsError(TlsConnection *c, int err, char *fmt, ...)
{
	char text[512];
	va_list ap;

	va_start(ap, fmt);
	vseprint(text, text+sizeof(text), fmt, ap);
	va_end(ap);

	if(c->trace != nil)
		c->trace("tlsError: %s\n", text);

	// only the first error becomes the error string
	if(!c->erred)
		errstr(text, sizeof(text));
	else
		fprint(2, "double error: %r, %s", text);
	c->erred = 1;

	fprint(c->ctl, "alert %d", err);
}

// Commit the connection to a protocol version.
// Versions outside [MinProtoVersion, MaxProtoVersion] are refused
// with -1.  A version above the current c->version is lowered to
// c->version.  The length of the Finished verify data follows the
// chosen version.  Returns the result of writing the version to the
// control file.
static int
setVersion(TlsConnection *c, int version)
{
	if(version < MinProtoVersion || version > MaxProtoVersion)
		return -1;
	if(c->version < version)
		version = c->version;

	c->version = version;
	c->finished.n = (version == SSL3Version) ? SSL3FinishedLen : TLSFinishedLen;

	return fprint(c->ctl, "version 0x%x", version);
}

// Check a received Finished message f against the value this side
// computed (c->finished.verify), comparing f->n bytes with tsmemcmp.
// Returns 1 when they are equal, 0 otherwise.
static int
finishedMatch(TlsConnection *c, Finished *f)
{
	if(tsmemcmp(f->verify, c->finished.verify, f->n) != 0)
		return 0;
	return 1;
}

// Release everything owned by a TlsConnection and the structure
// itself, wiping it first.  The TLS channel itself is not closed.
// A nil c is accepted and ignored.
static void
tlsConnectionFree(TlsConnection *c)
{
	TlsSec *sec;

	if(c != nil){
		sec = c->sec;

		// key exchange state
		dh_finish(&sec->dh, nil);
		mpfree(sec->ec.Q.x);
		mpfree(sec->ec.Q.y);
		mpfree(sec->ec.Q.d);
		ecdomfree(&sec->ec.dom);

		// RSA state and certificate
		factotum_rsa_close(sec->rpc);
		rsapubfree(sec->rsapub);
		freebytes(c->cert);

		// clear the structure before handing it back to the allocator
		memset(c, 0, sizeof(*c));
		free(c);
	}
}


//================= cipher choices ========================

// Report whether cipher suite id tlsid is one of the DHE_RSA suites
// listed below.  Returns 1 if so, 0 otherwise.
static int
isDHE(int tlsid)
{
	int yes;

	yes = 0;
	switch(tlsid){
	case TLS_DHE_RSA_WITH_AES_128_GCM_SHA256:
	case TLS_DHE_RSA_WITH_AES_128_CBC_SHA256:
	case TLS_DHE_RSA_WITH_AES_128_CBC_SHA:
	case TLS_DHE_RSA_WITH_AES_256_CBC_SHA:
	case TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA:
	case TLS_DHE_RSA_WITH_CHACHA20_POLY1305:
	case GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305:
		yes = 1;
		break;
	}
	return yes;
}

// Report whether cipher suite id tlsid is one of the ECDHE suites
// listed below.  Returns 1 if so, 0 otherwise.
static int
isECDHE(int tlsid)
{
	int yes;

	yes = 0;
	switch(tlsid){
	// ChaCha20-Poly1305
	case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305:
	case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305:
	case GOOGLE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305:
	case GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305:
	// AES-GCM
	case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
	case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
	// AES-CBC
	case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256:
	case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
	case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
	case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
		yes = 1;
		break;
	}
	return yes;
}

// Report whether cipher suite id tlsid is one of the pre-shared key
// suites listed below.  Returns 1 if so, 0 otherwise.
static int
isPSK(int tlsid)
{
	int yes;

	yes = 0;
	switch(tlsid){
	case TLS_PSK_WITH_CHACHA20_POLY1305:
	case TLS_PSK_WITH_AES_128_CBC_SHA256:
	case TLS_PSK_WITH_AES_128_CBC_SHA:
		yes = 1;
		break;
	}
	return yes;
}

// Report whether cipher suite id tlsid is one of the ECDSA suites
// listed below.  Returns 1 if so, 0 otherwise.
static int
isECDSA(int tlsid)
{
	int yes;

	yes = 0;
	switch(tlsid){
	case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305:
	case GOOGLE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305:
	case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
	case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256:
		yes = 1;
		break;
	}
	return yes;
}

// Select cipher suite a for connection c by copying its encryption
// algorithm, digest and secret size from the cipherAlgs table.
// Returns 1 on success; 0 if the suite is not in the table or if it
// needs more than MaxKeyData bytes of secret (in which case the
// connection fields have already been set).
static int
setAlgs(TlsConnection *c, int a)
{
	int k;

	for(k = 0; k < nelem(cipherAlgs); k++){
		if(cipherAlgs[k].tlsid != a)
			continue;
		c->cipher = a;
		c->enc = cipherAlgs[k].enc;
		c->digest = cipherAlgs[k].digest;
		c->nsecret = cipherAlgs[k].nsecret;
		return c->nsecret <= MaxKeyData;
	}
	return 0;
}

// Pick a cipher suite from the list cv.  The cipherAlgs table is
// walked in order and the first acceptable entry that also appears in
// cv wins.  An entry is acceptable when it is marked ok, is neither an
// ECDSA nor a DHE suite, matches ispsk in being a PSK suite or not,
// and is not an ECDHE suite unless canec is set.
// Returns the suite id, or -1 when nothing matches.
static int
okCipher(Ints *cv, int ispsk, int canec)
{
	int k, id;

	for(k = 0; k < nelem(cipherAlgs); k++) {
		id = cipherAlgs[k].tlsid;
		if(cipherAlgs[k].ok
		&& !isECDSA(id) && !isDHE(id)
		&& isPSK(id) == ispsk
		&& (canec || !isECDHE(id))
		&& lookupid(cv, id) >= 0)
			return id;
	}
	return -1;
}

// Pick a compression method from the list cv: the first entry of the
// compressors table that occurs in cv.
// Returns the method, or -1 when there is none in common.
static int
okCompression(Bytes *cv)
{
	int k, method;

	for(k = 0; k < nelem(compressors); k++) {
		method = compressors[k];
		if(memchr(cv->data, method, cv->len) == nil)
			continue;
		return method;
	}
	return -1;
}

static Lock	ciphLock;
static int	nciphers;

// Work out which entries of cipherAlgs are usable, by matching their
// encryption and digest names against the lists read from
// #a/tls/encalgs and #a/tls/hashalgs.  Each entry's ok flag is set
// accordingly and the number of usable entries is kept in nciphers.
// The work is done under ciphLock and is skipped once nciphers is
// non-zero.  Returns nciphers; on failure the error string is set and
// the (unchanged) count is returned.
static int
initCiphers(void)
{
	enum {MaxAlgF = 1024, MaxAlgs = 10};
	char s[MaxAlgF], *flds[MaxAlgs];
	int fd, i, k, n, found;

	lock(&ciphLock);
	if(nciphers != 0)
		goto out;

	// pass 1: encryption algorithms
	fd = open("#a/tls/encalgs", OREAD|OCEXEC);
	if(fd < 0){
		werrstr("can't open #a/tls/encalgs: %r");
		goto out;
	}
	n = read(fd, s, MaxAlgF-1);
	close(fd);
	if(n <= 0){
		werrstr("nothing in #a/tls/encalgs: %r");
		goto out;
	}
	s[n] = 0;
	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
	for(i = 0; i < nelem(cipherAlgs); i++){
		found = 0;
		for(k = 0; k < n && !found; k++)
			found = strcmp(cipherAlgs[i].enc, flds[k]) == 0;
		cipherAlgs[i].ok = found;
	}

	// pass 2: digest algorithms; an entry stays ok only if both match
	fd = open("#a/tls/hashalgs", OREAD|OCEXEC);
	if(fd < 0){
		werrstr("can't open #a/tls/hashalgs: %r");
		goto out;
	}
	n = read(fd, s, MaxAlgF-1);
	close(fd);
	if(n <= 0){
		werrstr("nothing in #a/tls/hashalgs: %r");
		goto out;
	}
	s[n] = 0;
	n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
	for(i = 0; i < nelem(cipherAlgs); i++){
		found = 0;
		for(k = 0; k < n && !found; k++)
			found = strcmp(cipherAlgs[i].digest, flds[k]) == 0;
		cipherAlgs[i].ok &= found;
		if(cipherAlgs[i].ok)
			nciphers++;
	}
out:
	unlock(&ciphLock);
	return nciphers;
}

// Build the list of cipher suite ids to offer: every cipherAlgs entry
// that is marked ok and whose PSK-ness equals ispsk, in table order.
// The vector is allocated with room for nciphers entries and its len
// is set to the number actually stored.
static Ints*
makeciphers(int ispsk)
{
	Ints *list;
	int k;

	list = newints(nciphers);
	list->len = 0;
	for(k = 0; k < nelem(cipherAlgs); k++){
		if(!cipherAlgs[k].ok)
			continue;
		if(isPSK(cipherAlgs[k].tlsid) != ispsk)
			continue;
		list->data[list->len++] = cipherAlgs[k].tlsid;
	}
	return list;
}


//================= security functions ========================

// Given a public key, open an RPC conversation with factotum
// (/mnt/factotum/rpc, proto=rsa service=tls role=client) and step
// through the values it returns until one equals the modulus of
// rsapub.  Returns the open rpc handle positioned on that key, or nil
// if factotum cannot be reached or no matching key is found.
static AuthRpc*
factotum_rsa_open(RSApub *rsapub)
{
	AuthRpc *rpc;
	mpint *mod;
	char *attr;
	int afd;

	// start talking to factotum
	afd = open("/mnt/factotum/rpc", ORDWR|OCEXEC);
	if(afd < 0)
		return nil;
	rpc = auth_allocrpc(afd);
	if(rpc == nil){
		close(afd);
		return nil;
	}

	attr = "proto=rsa service=tls role=client";
	if(auth_rpc(rpc, "start", attr, strlen(attr)) != ARok){
		factotum_rsa_close(rpc);
		return nil;
	}

	// roll factotum keyring around to match public key
	mod = mpnew(0);
	for(;;){
		if(auth_rpc(rpc, "read", nil, 0) != ARok)
			break;
		if(strtomp(rpc->arg, nil, 16, mod) == nil)
			continue;
		if(mpcmp(mod, rsapub->n) != 0)
			continue;
		mpfree(mod);
		return rpc;
	}
	mpfree(mod);
	factotum_rsa_close(rpc);
	return nil;
}

// Send cipher to factotum over rpc as a hexadecimal string and parse
// the reply into a new mpint, which is returned.
// cipher is freed by this function.  Returns nil if cipher is nil, if
// the conversion to text fails or if either rpc step fails.
static mpint*
factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
{
	char *hex;
	int rv;

	if(cipher == nil)
		return nil;

	hex = mptoa(cipher, 16, nil, 0);
	mpfree(cipher);
	if(hex == nil)
		return nil;

	rv = auth_rpc(rpc, "write", hex, strlen(hex));
	free(hex);
	if(rv != ARok)
		return nil;
	if(auth_rpc(rpc, "read", nil, 0) != ARok)
		return nil;

	return strtomp(rpc->arg, nil, 16, nil);
}

// Close the factotum conversation: close its file descriptor and free
// the rpc structure.  A nil rpc is accepted and ignored.
static void
factotum_rsa_close(AuthRpc *rpc)
{
	if(rpc != nil){
		close(rpc->afd);
		auth_freerpc(rpc);
	}
}

// buf ^= prf
// XOR nbuf bytes of expansion output into buf.  x is the keyed digest
// function to use (the callers in this file pass hmac_md5, hmac_sha1
// and hmac_sha2_256) and xlen is the size of its output, which must
// not exceed SHA2_256dlen.
// A chain value is started as x(key, label+seed) and replaced by
// x(key, chain) after every round; each round produces
// x(key, chain+label+seed) as the next xlen bytes of output.
static void
tlsP(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed, int nseed,
	DigestState* (*x)(uchar*, ulong, uchar*, ulong, uchar*, DigestState*), int xlen)
{
	uchar chain[SHA2_256dlen], out[SHA2_256dlen];
	DigestState *ds;
	int take, k;

	assert(xlen <= sizeof(chain) && xlen <= sizeof(out));

	// first chain value
	ds = x(label, nlabel, key, nkey, nil, nil);
	x(seed, nseed, key, nkey, chain, ds);

	for(; nbuf > 0; buf += take, nbuf -= take) {
		// output block for this round
		ds = x(chain, xlen, key, nkey, nil, nil);
		ds = x(label, nlabel, key, nkey, nil, ds);
		x(seed, nseed, key, nkey, out, ds);

		take = xlen < nbuf ? xlen : nbuf;
		for(k = 0; k < take; k++)
			buf[k] ^= out[k];

		// advance the chain; out is reused as scratch space
		x(chain, xlen, key, nkey, out, nil);
		memmove(chain, out, xlen);
	}
}


// fill buf with md5(args)^sha1(args)
// buf is zeroed and then receives the XOR of two tlsP expansions:
// hmac_md5 keyed with the first half of key and hmac_sha1 keyed with
// the last half.  Each half is (nkey+1)>>1 bytes long, so the halves
// share the middle byte when nkey is odd.
static void
tls10PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed, int nseed)
{
	uchar *lab, *k1, *k2;
	int nlabel, half;

	lab = (uchar*)label;
	nlabel = strlen(label);
	half = (nkey + 1) >> 1;
	k1 = key;
	k2 = key+nkey-half;

	memset(buf, 0, nbuf);
	tlsP(buf, nbuf, k1, half, lab, nlabel, seed, nseed, hmac_md5, MD5dlen);
	tlsP(buf, nbuf, k2, half, lab, nlabel, seed, nseed, hmac_sha1, SHA1dlen);
}

// Fill buf with nbuf bytes of the hmac_sha2_256 based expansion of
// key, label and seed: buf is zeroed and a single tlsP pass is XORed
// into it.
static void
tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed, int nseed)
{
	int nlabel;

	nlabel = strlen(label);
	memset(buf, 0, nbuf);
	tlsP(buf, nbuf, key, nkey, (uchar*)label, nlabel, seed, nseed,
		hmac_sha2_256, SHA2_256dlen);
}

// Fill buf with nbuf bytes of key material; label is not used.
// Round k (counting from 1) produces
//	md5(key + sha1(prefix + key + seed))
// where prefix is k copies of the k'th capital letter ("A", "BB",
// "CCC", ...).  Each round yields up to MD5dlen bytes.  Generation
// stops silently after 26 rounds, leaving the rest of buf untouched.
static void
sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed, int nseed)
{
	uchar inner[SHA1dlen], outer[MD5dlen], prefix[26];
	DigestState *ds;
	int round, take;

	USED(label);
	for(round = 1; nbuf > 0; round++){
		if(round > 26)
			return;
		memset(prefix, 'A' - 1 + round, round);

		ds = sha1(prefix, round, nil, nil);
		ds = sha1(key, nkey, nil, ds);
		sha1(seed, nseed, inner, ds);

		ds = md5(key, nkey, nil, nil);
		md5(inner, SHA1dlen, outer, ds);

		take = nbuf < MD5dlen ? nbuf : MD5dlen;
		memmove(buf, outer, take);
		buf += take;
		nbuf -= take;
	}
}

// Compute the SSL3 Finished value into finished: MD5dlen bytes of MD5
// result followed by SHA1dlen bytes of SHA1 result.
// For each digest the handshake hash is continued with the sender
// label ("CLNT" or "SRVR"), the master secret and a pad of 0x36 bytes;
// that result is then hashed again after the master secret and a pad
// of 0x5C bytes.  The pad is 48 bytes for MD5 and 40 bytes for SHA1.
// hsh is passed by value, so the caller's running hash is not changed
// by this function's use of its copy.
static void
sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isclient)
{
	enum { MD5pad = 48, SHA1pad = 40 };
	uchar inmd5[MD5dlen], insha1[SHA1dlen], pad[MD5pad];
	DigestState *ds;
	uchar *sender;

	sender = (uchar*)(isclient ? "CLNT" : "SRVR");

	// MD5 half
	md5(sender, 4, nil, &hsh.md5);
	md5(sec->sec, MasterSecretSize, nil, &hsh.md5);
	memset(pad, 0x36, MD5pad);
	md5(pad, MD5pad, nil, &hsh.md5);
	md5(nil, 0, inmd5, &hsh.md5);
	memset(pad, 0x5C, MD5pad);
	ds = md5(sec->sec, MasterSecretSize, nil, nil);
	ds = md5(pad, MD5pad, nil, ds);
	md5(inmd5, MD5dlen, finished, ds);

	// SHA1 half
	sha1(sender, 4, nil, &hsh.sha1);
	sha1(sec->sec, MasterSecretSize, nil, &hsh.sha1);
	memset(pad, 0x36, SHA1pad);
	sha1(pad, SHA1pad, nil, &hsh.sha1);
	sha1(nil, 0, insha1, &hsh.sha1);
	memset(pad, 0x5C, SHA1pad);
	ds = sha1(sec->sec, MasterSecretSize, nil, nil);
	ds = sha1(pad, SHA1pad, nil, ds);
	sha1(insha1, SHA1dlen, finished + MD5dlen, ds);
}

// fill "finished" arg with md5(args)^sha1(args)
// The MD5 and SHA1 handshake digests are finalized on the by-value
// copy hsh and concatenated; tls10PRF then expands the master secret
// with the label "client finished" or "server finished" and that seed
// into TLSFinishedLen bytes.
static void
tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isclient)
{
	uchar seed[MD5dlen+SHA1dlen];
	char *label;

	label = isclient ? "client finished" : "server finished";

	// get current hash value, but allow further messages to be hashed in
	md5(nil, 0, seed, &hsh.md5);
	sha1(nil, 0, seed+MD5dlen, &hsh.sha1);

	tls10PRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, seed, sizeof(seed));
}

// Compute the TLS 1.2 Finished value into finished: the SHA2-256
// handshake digest is finalized on the by-value copy hsh, and
// tls12PRF expands the master secret with the label "client finished"
// or "server finished" and that digest into TLSFinishedLen bytes.
static void
tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isclient)
{
	uchar digest[SHA2_256dlen];
	char *label;

	label = isclient ? "client finished" : "server finished";

	// get current hash value, but allow further messages to be hashed in
	sha2_256(nil, 0, digest, &hsh.sha2_256);

	tls12PRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, digest, sizeof(digest));
}

// Initialize the security state for the server side: clear sec,
// record the version from the ClientHello and the client random, and
// generate the server random.
static void
tlsSecInits(TlsSec *sec, int cvers, uchar *crandom)
{
	memset(sec, 0, sizeof *sec);
	sec->clientVers = cvers;
	memmove(sec->crandom, crandom, sizeof sec->crandom);

	// all RandomSize bytes are random: putting time()'s output
	// into the first 4 bytes is no longer recommended and is not useful
	genrandom(sec->srandom, sizeof sec->srandom);
}

// Server side of the RSA key exchange: decrypt the encrypted
// premaster secret epm and derive the master secret from it.
// Returns -1 only when epm is nil; otherwise 0.
static int
tlsSecRSAs(TlsSec *sec, Bytes *epm)
{
	Bytes *pm;
	int bad;

	if(epm == nil){
		werrstr("no encrypted premaster secret");
		return -1;
	}

	pm = pkcs1_decrypt(sec, epm);
	bad = pm == nil
		|| pm->len != MasterSecretSize
		|| get16(pm->data) != sec->clientVers;
	if(bad){
		// if the client messed up, just continue as if everything is ok,
		// to prevent attacks to check for correctly formatted messages:
		// a random premaster secret is substituted.
		freebytes(pm);
		pm = newbytes(MasterSecretSize);
		genrandom(pm->data, pm->len);
	}
	setMasterSecret(sec, pm);
	return 0;
}

// Server side ECDHE, step 1: generate the server's ephemeral key for
// the selected curve sec->nc and return the encoded parameters:
//	byte 3, 16 bit curve id, 1 byte public key length, public key.
// For X25519 the private key is kept in sec->X; for the other curves
// the domain and key pair are kept in sec->ec.
// Returns nil when no curve has been selected.
static Bytes*
tlsSecECDHEs1(TlsSec *sec)
{
	enum { Hdr = 1+2+1 };	// type byte, curve id, key length
	ECdomain *dom;
	ECpriv *Q;
	Bytes *par;
	int n;

	if(sec->nc == nil)
		return nil;

	if(sec->nc->tlsid != X25519){
		dom = &sec->ec.dom;
		Q = &sec->ec.Q;
		ecdominit(dom, sec->nc->init);
		memset(Q, 0, sizeof(*Q));
		Q->x = mpnew(0);
		Q->y = mpnew(0);
		Q->d = mpnew(0);
		ecgen(dom, Q);

		// reserve room for the encoded point, then shrink to what
		// ecencodepub actually produced
		n = 1 + 2*((mpsignif(dom->p)+7)/8);
		par = newbytes(Hdr+n);
		par->data[0] = 3;
		put16(par->data+1, sec->nc->tlsid);
		n = ecencodepub(dom, Q, par->data+Hdr, par->len-Hdr);
		par->data[3] = n;
		par->len = Hdr+n;
		return par;
	}

	par = newbytes(Hdr+32);
	par->data[0] = 3;
	put16(par->data+1, X25519);
	par->data[3] = 32;
	curve25519_dh_new(sec->X, par->data+Hdr);
	return par;
}

// Server side ECDHE, step 2: combine the client's public key Yc with
// the ephemeral key made in step 1 and derive the master secret from
// the shared value.
// Returns 0 on success; -1 with the error string set when Yc is
// missing or malformed, or when the X25519 exchange is rejected.
static int
tlsSecECDHEs2(TlsSec *sec, Bytes *Yc)
{
	ECdomain *dom;
	ECpriv *Q;
	ECpoint K;
	ECpub *Y;
	Bytes *Z;

	if(Yc == nil){
		werrstr("no public key");
		return -1;
	}

	if(sec->nc->tlsid != X25519){
		dom = &sec->ec.dom;
		Q = &sec->ec.Q;

		Y = ecdecodepub(dom, Yc->data, Yc->len);
		if(Y == nil){
			werrstr("bad public key");
			return -1;
		}

		memset(&K, 0, sizeof(K));
		K.x = mpnew(0);
		K.y = mpnew(0);

		ecmul(dom, Y, Q->d, &K);

		// the x coordinate, sized like the field prime, is the premaster secret
		setMasterSecret(sec, mptobytes(K.x, (mpsignif(dom->p)+7)/8));

		mpfree(K.x);
		mpfree(K.y);

		ecpubfree(Y);
		return 0;
	}

	if(Yc->len != 32){
		werrstr("bad public key");
		return -1;
	}
	Z = newbytes(32);
	if(!curve25519_dh_finish(sec->X, Yc->data, Z->data)){
		werrstr("unlucky shared key");
		freebytes(Z);
		return -1;
	}
	setMasterSecret(sec, Z);
	return 0;
}

// Initialize the security state for the client side: clear sec,
// record the version to be sent in the ClientHello and generate the
// client random.
static void
tlsSecInitc(TlsSec *sec, int cvers)
{
	memset(sec, 0, sizeof *sec);
	sec->clientVers = cvers;

	// all RandomSize bytes are random; no timestamp is mixed in
	genrandom(sec->crandom, sizeof sec->crandom);
}

// Client side of the RSA key exchange: build a premaster secret (the
// 16 bit client version followed by random bytes), encrypt it with the
// RSA key taken from the X.509 certificate cert, and derive the
// master secret from it.
// Returns the result of pkcs1_encrypt for the caller to send, or nil
// with the error string set if no RSA key can be extracted from cert.
static Bytes*
tlsSecRSAc(TlsSec *sec, uchar *cert, int ncert)
{
	RSApub *pub;
	Bytes *pm, *epm;

	pub = X509toRSApub(cert, ncert, nil, 0);
	if(pub == nil){
		werrstr("invalid x509/rsa certificate");
		return nil;
	}

	pm = newbytes(MasterSecretSize);
	put16(&pm->data[0], sec->clientVers);
	genrandom(&pm->data[2], MasterSecretSize - 2);

	epm = pkcs1_encrypt(pm, pub);
	rsapubfree(pub);

	setMasterSecret(sec, pm);
	return epm;
}

/*
 * Compute the Finished verify data into fin using the version
 * specific setFinished routine.  nfin must match the length
 * selected by tlsSecVers; otherwise -1 is returned with werrstr set.
 * hsh is passed by value, so the caller's running hashes are not
 * modified here.
 */
static int
tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient)
{
	if(sec->nfin == nfin){
		/*
		 * NOTE(review): presumably cleared so the digest routines
		 * never try to free these by-value copies -- confirm.
		 */
		hsh.sha2_256.malloced = 0;
		hsh.sha1.malloced = 0;
		hsh.md5.malloced = 0;

		(*sec->setFinished)(sec, hsh, fin, isclient);
		return 0;
	}

	werrstr("invalid finished exchange");
	return -1;
}

/*
 * Select the PRF, the Finished computation and the Finished length
 * that go with protocol version v.
 */
static void
tlsSecVers(TlsSec *sec, int v)
{
	if(v == SSL3Version){
		sec->prf = sslPRF;
		sec->setFinished = sslSetFinished;
		sec->nfin = SSL3FinishedLen;
		return;
	}

	/* every TLS version shares the same Finished length */
	sec->nfin = TLSFinishedLen;
	if(v >= TLS12Version){
		sec->prf = tls12PRF;
		sec->setFinished = tls12SetFinished;
	}else{
		sec->prf = tls10PRF;
		sec->setFinished = tls10SetFinished;
	}
}

/*
 * Expand the master secret into c->nsecret bytes of key material
 * and hand it, encoded with enc64, to the connection's ctl file
 * in a "secret" command.  Returns the result of that fprint.
 */
static int
setSecrets(TlsConnection *c, int isclient)
{
	uchar keydata[MaxKeyData], seed[2*RandomSize];
	TlsSec *sec;
	char *text;
	int ntext, rv;

	assert(c->nsecret <= sizeof(keydata));
	ntext = 2*c->nsecret;
	text = emalloc(ntext);
	sec = c->sec;

	/* key expansion seeds with server random first, then client random */
	memmove(seed, sec->srandom, RandomSize);
	memmove(seed+RandomSize, sec->crandom, RandomSize);

	/*
	 * The amount of key material depends on the cipher selection,
	 * but it is always produced by the same PRF call.
	 */
	(*sec->prf)(keydata, c->nsecret, sec->sec, MasterSecretSize, "key expansion",
			seed, sizeof(seed));

	enc64(text, ntext, keydata, c->nsecret);
	memset(keydata, 0, c->nsecret);	/* wipe the raw key material */

	rv = fprint(c->ctl, "secret %s %s %d %s", c->digest, c->enc, isclient, text);

	/* wipe the encoded copy as well before releasing it */
	memset(text, 0, ntext);
	free(text);

	return rv;
}

/*
 * Derive the master secret (sec->sec) from the pre-master secret pm.
 * When a pre-shared key is configured, the input to the PRF becomes
 *	len16(pm) | pm | len16(psk) | psk
 * pm is always zeroed and freed before returning.
 */
static void
setMasterSecret(TlsSec *sec, Bytes *pm)
{
	uchar seed[2*RandomSize];
	Bytes *combined;
	uchar *w;

	if(sec->psklen > 0){
		/* concatenate psk to pre-master secret */
		combined = newbytes(4 + pm->len + sec->psklen);
		w = combined->data;
		put16(w, pm->len);
		w += 2;
		memmove(w, pm->data, pm->len);
		w += pm->len;
		put16(w, sec->psklen);
		w += 2;
		memmove(w, sec->psk, sec->psklen);

		/* the plain pre-master secret is no longer needed */
		memset(pm->data, 0, pm->len);
		freebytes(pm);
		pm = combined;
	}

	/* the master secret seeds with client random first, then server random */
	memmove(seed, sec->crandom, RandomSize);
	memmove(seed+RandomSize, sec->srandom, RandomSize);
	(*sec->prf)(sec->sec, MasterSecretSize, pm->data, pm->len, "master secret",
			seed, sizeof(seed));

	memset(pm->data, 0, pm->len);
	freebytes(pm);
}

/*
 * Hash  client random | server random | par  into digest.
 * The hash is picked by the high byte of sigalg: 0 means the
 * concatenation MD5 | SHA1, anything else indexes hashfun[].
 *
 * Returns the number of digest bytes written, or -1 when the
 * hash algorithm is unknown.
 */
static int
digestDHparams(TlsSec *sec, Bytes *par, uchar digest[MAXdlen], int sigalg)
{
	Bytes *msg;
	uchar *w;
	int h, n;

	h = (sigalg>>8) & 0xFF;

	msg = newbytes(2*RandomSize + par->len);
	w = msg->data;
	memmove(w, sec->crandom, RandomSize);
	w += RandomSize;
	memmove(w, sec->srandom, RandomSize);
	w += RandomSize;
	memmove(w, par->data, par->len);

	n = -1;
	if(h == 0){
		md5(msg->data, msg->len, digest, nil);
		sha1(msg->data, msg->len, digest+MD5dlen, nil);
		n = MD5dlen+SHA1dlen;
	}else if(h < nelem(hashfun) && hashfun[h].fun != nil){
		(*hashfun[h].fun)(msg->data, msg->len, digest, nil);
		n = hashfun[h].len;
	}

	freebytes(msg);
	return n;
}

/*
 * Check the signature sig over the key exchange parameters par
 * against the public key in certificate cert.
 *
 * The high byte of sigalg picks the hash (see digestDHparams), the
 * low byte the signature scheme: 0x01 is RSA, 0x03 is ECDSA.
 * A missing signature is accepted only when a pre-shared key is
 * in use (sec->psklen > 0).
 *
 * Returns nil on success, otherwise a string describing the error
 * (either one of the literals below or whatever the X509 verify
 * routine returned).
 */
static char*
verifyDHparams(TlsSec *sec, Bytes *par, Bytes *cert, Bytes *sig, int sigalg)
{
	uchar digest[MAXdlen];
	int digestlen;
	ECdomain dom;
	ECpub *ecpk;
	RSApub *rsapk;
	char *err;

	if(par == nil || par->len <= 0)
		return "no DH parameters";

	if(sig == nil || sig->len <= 0){
		/* with a pre-shared key the parameters may be unsigned */
		if(sec->psklen > 0)
			return nil;
		return "no signature";
	}

	if(cert == nil)
		return "no certificate";

	digestlen = digestDHparams(sec, par, digest, sigalg);
	if(digestlen <= 0)
		return "unknown signature digest algorithm";

	switch(sigalg & 0xFF){
	case 0x01:	/* RSA */
		rsapk = X509toRSApub(cert->data, cert->len, nil, 0);
		if(rsapk == nil)
			return "bad certificate";
		err = X509rsaverifydigest(sig->data, sig->len, digest, digestlen, rsapk);
		rsapubfree(rsapk);
		break;
	case 0x03:	/* ECDSA */
		ecpk = X509toECpub(cert->data, cert->len, nil, 0, &dom);
		if(ecpk == nil)
			return "bad certificate";
		err = X509ecdsaverifydigest(sig->data, sig->len, digest, digestlen, &dom, ecpk);
		ecdomfree(&dom);
		ecpubfree(ecpk);
		break;
	default:
		err = "signature algorithm not RSA or ECDSA";
	}

	return err;
}

// PKCS#1 encryption of data under key, see /lib/rfc/rfc2437 9.1.2.1.
// Returns a new Bytes as long as the modulus, or nil if padding fails.
// The input buffer is left untouched.
static Bytes*
pkcs1_encrypt(Bytes* data, RSApub* key)
{
	mpint *padded, *cipher;
	Bytes *out;

	padded = pkcs1padbuf(data->data, data->len, key->n, 2);
	if(padded == nil)
		return nil;
	cipher = rsaencrypt(key, padded, nil);
	mpfree(padded);

	// result is emitted big-endian, sized to the modulus
	out = newbytes((mpsignif(key->n)+7)/8);
	mptober(cipher, out->data, out->len);
	mpfree(cipher);
	return out;
}

// PKCS#1 decryption of data through factotum (sec->rpc).
// Returns the unpadded plaintext in a new Bytes, or nil when the
// ciphertext length does not match the modulus, the decryption
// fails, or the padding is rejected.
static Bytes*
pkcs1_decrypt(TlsSec *sec, Bytes *data)
{
	mpint *plain;
	Bytes *out;
	int n;

	// ciphertext must be exactly as long as the modulus
	if(data->len != (mpsignif(sec->rsapub->n)+7)/8)
		return nil;

	plain = factotum_rsa_decrypt(sec->rpc, bytestomp(data));
	if(plain == nil)
		return nil;
	out = mptobytes(plain, (mpsignif(plain)+7)/8);
	mpfree(plain);

	// strip the padding in place; a negative length signals bad padding
	n = pkcs1unpadbuf(out->data, out->len, sec->rsapub->n, 2);
	out->len = n;
	if(n < 0){
		freebytes(out);
		return nil;
	}
	return out;
}

/*
 * Sign digest with the RSA private key held by factotum (sec->rpc).
 *
 * If the high byte of sigalg names a known hash whose length equals
 * digestlen, the digest is first wrapped by asn1encodedigest;
 * otherwise only a raw MD5|SHA1 sized digest is accepted.
 *
 * Returns the signature, sized to the modulus, or nil on failure.
 * werrstr is set here only for an unusable digest.
 */
static Bytes*
pkcs1_sign(TlsSec *sec, uchar *digest, int digestlen, int sigalg)
{
	uchar buf[128];
	mpint *sig;
	Bytes *out;
	int h, n;

	h = (sigalg>>8)&0xFF;
	n = -1;
	if(h > 0 && h < nelem(hashfun) && hashfun[h].len == digestlen)
		n = asn1encodedigest(hashfun[h].fun, digest, buf, sizeof(buf));
	else if(digestlen == MD5dlen+SHA1dlen){
		memmove(buf, digest, digestlen);
		n = digestlen;
	}
	if(n <= 0){
		werrstr("bad digest algorithm");
		return nil;
	}

	/* pad as a signature block (type 1) and let factotum do the private key operation */
	sig = factotum_rsa_decrypt(sec->rpc, pkcs1padbuf(buf, n, sec->rsapub->n, 1));
	if(sig == nil)
		return nil;
	out = mptobytes(sig, (mpsignif(sec->rsapub->n)+7)/8);
	mpfree(sig);
	return out;
}


//================= general utility functions ========================

/*
 * Allocate n zeroed bytes (at least one); never returns nil,
 * the process exits via sysfatal when memory runs out.
 * The block is tagged with the caller's pc.
 */
static void *
emalloc(int n)
{
	void *mem;

	if(n == 0)
		n = 1;
	if((mem = malloc(n)) == nil)
		sysfatal("out of memory");
	memset(mem, 0, n);
	setmalloctag(mem, getcallerpc(&n));
	return mem;
}

/*
 * Resize ReallocP to ReallocN bytes (at least one); a nil pointer
 * is served by emalloc instead.  Never returns nil: running out
 * of memory ends the process via sysfatal.  The block is tagged
 * with the caller's pc.
 */
static void *
erealloc(void *ReallocP, int ReallocN)
{
	if(ReallocN == 0)
		ReallocN = 1;
	if(ReallocP != nil){
		ReallocP = realloc(ReallocP, ReallocN);
		if(ReallocP == nil)
			sysfatal("out of memory");
	}else
		ReallocP = emalloc(ReallocN);
	setrealloctag(ReallocP, getcallerpc(&ReallocP));
	return ReallocP;
}

/* store the 32 bits of x at p[0..3], most significant byte first */
static void
put32(uchar *p, u32int x)
{
	p[3] = x & 0xFF;
	p[2] = (x>>8) & 0xFF;
	p[1] = (x>>16) & 0xFF;
	p[0] = (x>>24) & 0xFF;
}

/* store the low 24 bits of x at p[0..2], most significant byte first */
static void
put24(uchar *p, int x)
{
	p[2] = x & 0xFF;
	p[1] = (x>>8) & 0xFF;
	p[0] = (x>>16) & 0xFF;
}

/* store the low 16 bits of x at p[0..1], most significant byte first */
static void
put16(uchar *p, int x)
{
	p[1] = x & 0xFF;
	p[0] = (x>>8) & 0xFF;
}

/* read the 24 bit value stored at p[0..2], most significant byte first */
static int
get24(uchar *p)
{
	int v;

	v = p[0];
	v = (v<<8) | p[1];
	v = (v<<8) | p[2];
	return v;
}

/* read the 16 bit value stored at p[0..1], most significant byte first */
static int
get16(uchar *p)
{
	int v;

	v = p[0];
	v = (v<<8) | p[1];
	return v;
}

/*
 * Allocate a Bytes holding len zeroed data bytes.
 * A negative len aborts the process.
 */
static Bytes*
newbytes(int len)
{
	Bytes *b;

	if(len < 0)
		abort();
	b = emalloc(sizeof(*b) + len);
	b->len = len;
	return b;
}

/*
 * Allocate a Bytes of the given length and fill it with
 * a copy of the len bytes at buf.
 */
static Bytes*
makebytes(uchar* buf, int len)
{
	Bytes *copy = newbytes(len);

	memmove(copy->data, buf, len);
	return copy;
}

/*
 * Release a Bytes obtained from newbytes/makebytes/mptobytes.
 * The contents are not cleared; callers holding secrets wipe
 * b->data themselves first (as setMasterSecret does).
 */
static void
freebytes(Bytes* b)
{
	free(b);
}

/*
 * Convert the contents of bytes to an mpint using betomp;
 * the nil argument asks betomp for a freshly allocated result.
 */
static mpint*
bytestomp(Bytes* bytes)
{
	return betomp(bytes->data, bytes->len, nil);
}

/*
 * Render big into a freshly allocated Bytes of len bytes,
 * high order byte first.  A len of 0 is treated as 1.
 */
static Bytes*
mptobytes(mpint *big, int len)
{
	Bytes *out;

	out = newbytes(len == 0 ? 1 : len);
	mptober(big, out->data, out->len);
	return out;
}

/*
 * Allocate an Ints with room for len zeroed ints (len counts
 * ints, not bytes).  Aborts on a negative len or one too large
 * for the byte count to be computed.
 */
static Ints*
newints(int len)
{
	Ints *v;

	if(len < 0)
		abort();
	if(len > ((uint)-1>>1)/sizeof(int))
		abort();
	v = emalloc(sizeof(*v) + len*sizeof(int));
	v->len = len;
	return v;
}

/* release an Ints obtained from newints */
static void
freeints(Ints* b)
{
	free(b);
}

/*
 * Linear search of b for id.
 * Returns the index of the first match, or -1 if id is absent.
 */
static int
lookupid(Ints* b, int id)
{
	int i, n;

	n = b->len;
	i = 0;
	while(i < n){
		if(b->data[i] == id)
			return i;
		i++;
	}
	return -1;
}