shithub: riscv

ref: d552fed38514dc799c7fea95dfb632c8816c3f60
dir: /sys/src/cmd/ssh/msg.c/

View raw version
#include "ssh.h"

/* CRC helper defined near the end of this file; declared here because recvmsg0 and sendmsg use it first */
static ulong sum32(ulong, void*, int);

/*
 * Printable names of the protocol messages, indexed by message type
 * number.  Used for debug output and for the diagnostics in badmsg.
 * The numbered comments mark every tenth index.
 */
char *msgnames[] =
{
/* 0 */
	"SSH_MSG_NONE",
	"SSH_MSG_DISCONNECT",
	"SSH_SMSG_PUBLIC_KEY",
	"SSH_CMSG_SESSION_KEY",
	"SSH_CMSG_USER",
	"SSH_CMSG_AUTH_RHOSTS",
	"SSH_CMSG_AUTH_RSA",
	"SSH_SMSG_AUTH_RSA_CHALLENGE",
	"SSH_CMSG_AUTH_RSA_RESPONSE",
	"SSH_CMSG_AUTH_PASSWORD",

/* 10 */
	"SSH_CMSG_REQUEST_PTY",
	"SSH_CMSG_WINDOW_SIZE",
	"SSH_CMSG_EXEC_SHELL",
	"SSH_CMSG_EXEC_CMD",
	"SSH_SMSG_SUCCESS",
	"SSH_SMSG_FAILURE",
	"SSH_CMSG_STDIN_DATA",
	"SSH_SMSG_STDOUT_DATA",
	"SSH_SMSG_STDERR_DATA",
	"SSH_CMSG_EOF",

/* 20 */
	"SSH_SMSG_EXITSTATUS",
	"SSH_MSG_CHANNEL_OPEN_CONFIRMATION",
	"SSH_MSG_CHANNEL_OPEN_FAILURE",
	"SSH_MSG_CHANNEL_DATA",
	"SSH_MSG_CHANNEL_INPUT_EOF",
	"SSH_MSG_CHANNEL_OUTPUT_CLOSED",
	"SSH_MSG_UNIX_DOMAIN_X11_FORWARDING (obsolete)",
	"SSH_SMSG_X11_OPEN",
	"SSH_CMSG_PORT_FORWARD_REQUEST",
	"SSH_MSG_PORT_OPEN",

/* 30 */
	"SSH_CMSG_AGENT_REQUEST_FORWARDING",
	"SSH_SMSG_AGENT_OPEN",
	"SSH_MSG_IGNORE",
	"SSH_CMSG_EXIT_CONFIRMATION",
	"SSH_CMSG_X11_REQUEST_FORWARDING",
	"SSH_CMSG_AUTH_RHOSTS_RSA",
	"SSH_MSG_DEBUG",
	"SSH_CMSG_REQUEST_COMPRESSION",
	"SSH_CMSG_MAX_PACKET_SIZE",
	"SSH_CMSG_AUTH_TIS",

/* 40 */
	"SSH_SMSG_AUTH_TIS_CHALLENGE",
	"SSH_CMSG_AUTH_TIS_RESPONSE",
	"SSH_CMSG_AUTH_KERBEROS",
	"SSH_SMSG_AUTH_KERBEROS_RESPONSE",
	"SSH_CMSG_HAVE_KERBEROS_TGT"
};

/*
 * Report an unexpected message through error().
 * m is the message that arrived, or nil if the read failed (the
 * current error string is then included in the report).  want is
 * the message type the caller was waiting for, or 0 if it had no
 * particular expectation.
 */
void
badmsg(Msg *m, int want)
{
	char tmp[20+ERRMAX], *got;

	got = tmp;
	if(m == nil)
		snprint(tmp, sizeof tmp, "<early eof: %r>");
	else if(m->type < 0 || m->type >= nelem(msgnames))
		snprint(tmp, sizeof tmp, "<unknown type %d>", m->type);
	else
		got = msgnames[m->type];

	if(want != 0)
		error("got %s message expecting %s", got, msgnames[want]);
	error("got unexpected %s message", got);
}

/*
 * Allocate an outgoing message of the given type on connection c
 * with room for len bytes of payload.  The extra 4+8+1+4 bytes
 * allocated past the payload leave room for sendmsg to slide the
 * data up and add the length word, up to 8 bytes of padding, the
 * type byte and the CRC in place.
 * Aborts if len is negative or larger than 256K.
 */
Msg*
allocmsg(Conn *c, int type, int len)
{
	uchar *p;
	Msg *m;

	if(len < 0 || len > 256*1024)
		abort();

	m = (Msg*)emalloc(sizeof(Msg)+4+8+1+len+4);
	setmalloctag(m, getcallerpc(&c));
	p = (uchar*)&m[1];
	m->c = c;
	m->bp = p;
	m->ep = p+len;
	m->rp = p;	/* keep the read pointer valid too, as recvmsg0 does */
	m->wp = p;
	m->type = type;
	return m;
}

/*
 * Push m back onto connection c so that the next recvmsg0 returns it
 * instead of reading from the network.  Only one message can be
 * held; a message already pushed back is freed and replaced.
 */
void
unrecvmsg(Conn *c, Msg *m)
{
	/* the type may have come off the wire: don't index msgnames with it unchecked */
	if(0 <= m->type && m->type < sizeof(msgnames)/sizeof(msgnames[0]))
		debug(DBG_PROTO, "unreceived %s len %zd\n", msgnames[m->type], m->ep - m->rp);
	else
		debug(DBG_PROTO, "unreceived <unknown type %d> len %zd\n", m->type, m->ep - m->rp);
	free(c->unget);
	c->unget = m;
}

/*
 * Fetch the next message for connection c: either the one pushed
 * back with unrecvmsg, or a fresh packet read from c->fd[0].
 *
 * As read here, a packet is a 4-byte length len, then 8 - len%8
 * bytes of padding, then len bytes consisting of the type byte, the
 * payload and a 4-byte CRC.  Everything after the length word is
 * decrypted if a cipher is set, and the CRC covers padding, type and
 * payload.
 *
 * On success the returned message has rp just past the type byte and
 * ep at the start of the CRC.  Returns nil with the error string set
 * on a short read, an implausible length or a CRC mismatch.
 */
static Msg*
recvmsg0(Conn *c)
{
	int pad;
	uchar *p, buf[4];
	ulong crc, crc0, len;
	Msg *m;

	if(c->unget){
		m = c->unget;
		c->unget = nil;
		return m;
	}

	if(readn(c->fd[0], buf, 4) != 4){
		werrstr("short net read: %r");
		return nil;
	}

	len = LONG(buf);
	if(len < 5){
		/*
		 * must hold at least the type byte and the crc;
		 * otherwise the type would be read from the crc or
		 * from beyond the end of the buffer.
		 */
		werrstr("packet size too small: %lud", len);
		return nil;
	}
	if(len > 256*1024){
		werrstr("packet size far too big: %.8lux", len);
		return nil;
	}

	pad = 8 - len%8;

	m = (Msg*)emalloc(sizeof(Msg)+pad+len);
	setmalloctag(m, getcallerpc(&c));
	m->c = c;
	m->bp = (uchar*)&m[1];
	m->ep = m->bp + pad+len-4;	/* -4: don't include crc */
	m->rp = m->bp;

	if(readn(c->fd[0], m->bp, pad+len) != pad+len){
		werrstr("short net read: %r");
		free(m);
		return nil;
	}

	if(c->cipher)
		c->cipher->decrypt(c->cstate, m->bp, len+pad);

	crc = sum32(0, m->bp, pad+len-4);
	p = m->bp + pad+len-4;
	crc0 = LONG(p);
	if(crc != crc0){
		werrstr("bad crc %#lux != %#lux (packet length %lud)", crc, crc0, len);
		free(m);
		return nil;
	}

	/* skip the padding and consume the type byte */
	m->rp += pad;
	m->type = *m->rp++;

	return m;
}

/*
 * Receive the next interesting message on connection c, silently
 * dropping SSH_MSG_IGNORE and SSH_MSG_DEBUG messages (the text of
 * the latter is passed to debug()).
 *
 * type selects how the result is checked:
 *	0	no checking; the result may be nil
 *	-1	any message will do, but nil reports Ehangup via error()
 *	other	the message must have exactly this type, else badmsg()
 */
Msg*
recvmsg(Conn *c, int type)
{
	Msg *m;

	while((m = recvmsg0(c)) != nil){
		/* m->type comes off the wire: don't index msgnames with it unchecked */
		if(0 <= m->type && m->type < sizeof(msgnames)/sizeof(msgnames[0]))
			debug(DBG_PROTO, "received %s len %zd\n", msgnames[m->type], m->ep - m->rp);
		else
			debug(DBG_PROTO, "received <unknown type %d> len %zd\n", m->type, m->ep - m->rp);
		if(m->type != SSH_MSG_DEBUG && m->type != SSH_MSG_IGNORE)
			break;
		if(m->type == SSH_MSG_DEBUG)
			debug(DBG_PROTO, "remote DEBUG: %s\n", getstring(m));
		free(m);
	}
	if(type == 0){
		/* no checking */
	}else if(type == -1){
		/* must not be nil */
		if(m == nil)
			error(Ehangup);
	}else{
		/* must be given type */
		if(m==nil || m->type!=type)
			badmsg(m, type);
	}
	setmalloctag(m, getcallerpc(&c));
	return m;
}

/*
 * Frame the payload accumulated in m (bp..wp), write it to the
 * connection's output descriptor and free m.
 *
 * The packet is built in place in m's buffer as: 4-byte length,
 * pad bytes of padding, the type byte, the payload and a 4-byte CRC.
 * The length word counts type + payload + CRC (datalen+5).
 * Returns 0 on success, -1 if the write was short; m is freed in
 * both cases.
 */
int
sendmsg(Msg *m)
{
	int i, pad;
	uchar *p;
	ulong datalen, len, crc;
	Conn *c;

	datalen = m->wp - m->bp;
	len = datalen + 5;	/* type byte + payload + crc */
	pad = 8 - len%8;	/* always between 1 and 8 */

	debug(DBG_PROTO, "sending %s len %lud\n", msgnames[m->type], datalen);

	p = m->bp;
	/* must happen before the header is written over the old payload position */
	memmove(m->bp+4+pad+1, m->bp, datalen);	/* slide data to correct position */

	PLONG(p, len);
	p += 4;

	/* random padding when cipher state is set, zeros otherwise */
	if(m->c->cstate){
		for(i=0; i<pad; i++)
			*p++ = fastrand();
	}else{
		memset(p, 0, pad);
		p += pad;
	}

	*p++ = m->type;

	/* data already in position */
	p += datalen;

	/* crc covers padding, type byte and payload, but not the length word */
	crc = sum32(0, m->bp+4, pad+1+datalen);
	PLONG(p, crc);
	p += 4;

	/* encrypt and write under the connection lock */
	c = m->c;
	qlock(c);
	if(c->cstate)
		c->cipher->encrypt(c->cstate, m->bp+4, len+pad);

	if(write(c->fd[1], m->bp, p - m->bp) != p-m->bp){
		qunlock(c);
		free(m);
		return -1;
	}
	qunlock(c);
	free(m);
	return 0;
}

/*
 * Consume and return the next byte of m.
 * Reports Edecode through error() if no bytes remain.
 */
uchar
getbyte(Msg *m)
{
	uchar b;

	if(m->ep <= m->rp)
		error(Edecode);
	b = *m->rp;
	m->rp++;
	return b;
}

/*
 * Consume the next two bytes of m and return them decoded with SHORT.
 * Reports Edecode through error() if fewer than two bytes remain.
 */
ushort
getshort(Msg *m)
{
	ushort v;

	if(m->ep < m->rp+2)
		error(Edecode);

	v = SHORT(m->rp);
	m->rp = m->rp+2;
	return v;
}

/*
 * Consume the next four bytes of m and return them decoded with LONG.
 * Reports Edecode through error() if fewer than four bytes remain.
 */
ulong
getlong(Msg *m)
{
	ulong v;

	if(m->ep < m->rp+4)
		error(Edecode);

	v = LONG(m->rp);
	m->rp = m->rp+4;
	return v;
}

/*
 * Consume a string (4-byte length followed by that many bytes) from
 * m and return it NUL-terminated.  The result points into m's own
 * buffer, so it is only valid until m is freed: the bytes are
 * shifted down by one over the last byte of the length word to make
 * room for the terminator.
 * Reports Edecode through error() if the string runs past the end
 * of the message.
 */
char*
getstring(Msg *m)
{
	char *p;
	ulong len;

	/* overwrites length to make room for NUL */
	len = getlong(m);
	/* compare against the bytes left rather than forming m->rp+len, which could wrap */
	if(len > (ulong)(m->ep - m->rp))
		error(Edecode);
	p = (char*)m->rp-1;
	memmove(p, m->rp, len);
	p[len] = '\0';
	/* step past the string so the next get* sees the following field */
	m->rp += len;
	return p;
}

/*
 * Consume n raw bytes from m and return a pointer to them.  The
 * pointer refers to m's own buffer; nothing is copied.
 * Reports Edecode through error() if n is negative or more than the
 * number of bytes left in the message.
 */
void*
getbytes(Msg *m, int n)
{
	uchar *p;

	/*
	 * A negative n (e.g. a large length from the wire stored in a
	 * signed int) would otherwise pass the check and move rp
	 * backwards.
	 */
	if(n < 0 || n > m->ep - m->rp)
		error(Edecode);
	p = m->rp;
	m->rp += n;
	return p;
}

/*
 * Consume a multi-precision integer from m: a 16-bit bit count
 * followed by the value in (bits+7)/8 bytes, converted with betomp.
 */
mpint*
getmpint(Msg *m)
{
	int nbits, nbytes;
	uchar *raw;

	nbits = getshort(m);	/* getshort returns # bits */
	nbytes = (nbits+7)/8;
	raw = getbytes(m, nbytes);
	return betomp(raw, nbytes, nil);
}

/*
 * Consume an RSA public key from m and return it in a freshly
 * allocated RSApub.  The leading long is read and discarded
 * (putRSApub stores the bit length of the modulus there); the
 * exponent comes next, then the modulus.
 * Reports Ememory through error() if the key cannot be allocated.
 */
RSApub*
getRSApub(Msg *m)
{
	RSApub *pub;

	getlong(m);	/* value not needed */
	if((pub = rsapuballoc()) == nil)
		error(Ememory);
	pub->ek = getmpint(m);
	pub->n = getmpint(m);
	setmalloctag(pub, getcallerpc(&m));
	return pub;
}

/*
 * Append the byte x to m.
 * Reports Eencode through error() if the message is already full.
 */
void
putbyte(Msg *m, uchar x)
{
	if(m->ep <= m->wp)
		error(Eencode);
	*m->wp = x;
	m->wp++;
}

/*
 * Append x to m as two bytes, encoded with PSHORT.
 * Reports Eencode through error() if fewer than two bytes are free.
 */
void
putshort(Msg *m, ushort x)
{
	if(m->ep < m->wp+2)
		error(Eencode);
	PSHORT(m->wp, x);
	m->wp = m->wp+2;
}

/*
 * Append x to m as four bytes, encoded with PLONG.
 * Reports Eencode through error() if fewer than four bytes are free.
 */
void
putlong(Msg *m, ulong x)
{
	if(m->ep < m->wp+4)
		error(Eencode);
	PLONG(m->wp, x);
	m->wp = m->wp+4;
}

/*
 * Append the NUL-terminated string s to m as a 4-byte length
 * followed by the characters; the terminating NUL is not sent.
 */
void
putstring(Msg *m, char *s)
{
	int n;

	n = strlen(s);
	putlong(m, n);
	putbytes(m, s, n);
}

/*
 * Append n bytes starting at a to m.
 * Reports Eencode through error() if they do not fit.
 */
void
putbytes(Msg *m, void *a, long n)
{
	uchar *dst;

	dst = m->wp;
	if(dst+n > m->ep)
		error(Eencode);
	memmove(dst, a, n);
	m->wp = dst+n;
}

/*
 * Append the multi-precision integer b to m in the form getmpint
 * reads: a 16-bit count of significant bits, then the value in
 * (bits+7)/8 bytes as produced by mptobe.
 * Reports Eencode through error() if the value does not fit.
 */
void
putmpint(Msg *m, mpint *b)
{
	int nbits, nbytes;

	nbits = mpsignif(b);
	nbytes = (nbits+7)/8;
	putshort(m, nbits);
	if(m->ep < m->wp+nbytes)
		error(Eencode);
	mptobe(b, m->wp, nbytes, nil);
	m->wp += nbytes;
}

/*
 * Append an RSA public key to m: the bit length of the modulus as a
 * long, then the exponent, then the modulus.
 */
void
putRSApub(Msg *m, RSApub *key)
{
	int modbits;

	modbits = mpsignif(key->n);
	putlong(m, modbits);
	putmpint(m, key->ek);
	putmpint(m, key->n);
}

/* byte-indexed lookup table for sum32, filled in by initsum32 */
static ulong crctab[256];

/*
 * Fill crctab: entry i is the result of pushing the byte value i
 * through eight right-shift steps, xoring in the polynomial
 * 0xEDB88320 whenever the bit shifted out is set.
 */
static void
initsum32(void)
{
	ulong v, poly;
	int byte, bit;

	poly = 0xEDB88320;
	for(byte = 0; byte < 256; byte++){
		v = byte;
		for(bit = 8; bit > 0; bit--)
			v = (v & 1) ? (v >> 1) ^ poly : v >> 1;
		crctab[byte] = v;
	}
}

/*
 * Fold the n bytes at buf into the running table-driven CRC lcrc and
 * return the new value.  No initial or final inversion is applied;
 * callers in this file start from 0.  The table is built on the
 * first call.
 */
static ulong
sum32(ulong lcrc, void *buf, int n)
{
	static int inited;
	uchar *p;
	int i;

	if(!inited){
		inited = 1;
		initsum32();
	}
	p = buf;
	for(i = 0; i < n; i++)
		lcrc = crctab[(lcrc ^ p[i]) & 0xff] ^ (lcrc >> 8);
	return lcrc;
}

/*
 * Pad the value b out to an n-byte block and return it as a new
 * mpint.  The block is laid out as
 *	00 02 <non-zero random bytes> 00 <bytes of b>
 * with the data at the tail.  At least three bytes of padding are
 * required (asserted).  Reports an error through error() if n
 * exceeds the local buffer.  The buffer is cleared before returning.
 */
mpint*
rsapad(mpint *b, int n)
{
	uchar buf[2560];
	int k, datalen, padlen;
	mpint *padded;

	if(n > sizeof buf)
		error("buffer too small in rsapad");

	datalen = (mpsignif(b)+7)/8;
	padlen = n - datalen;
	assert(padlen >= 3);

	/* convert at the front, then move the data to the tail of the block */
	mptobe(b, buf, datalen, nil);
	memmove(buf+padlen, buf, datalen);

	/* header, random filler that must not contain a zero, separator */
	buf[0] = 0;
	buf[1] = 2;
	k = 2;
	while(k < padlen-1){
		buf[k] = 1+fastrand()%255;
		k++;
	}
	buf[padlen-1] = 0;

	padded = betomp(buf, n, nil);
	memset(buf, 0, sizeof buf);
	return padded;
}

/*
 * Undo rsapad: return a new mpint holding the data that follows the
 * padding in b.  The first byte must be 2; everything from the first
 * zero byte after it onward is converted (the zero itself is just a
 * leading zero of the result).  Reports an error through error() if
 * b is too large for the local buffer or does not start with 2.
 */
mpint*
rsaunpad(mpint *b)
{
	uchar buf[2560];
	int len, skip;

	len = (mpsignif(b)+7)/8;
	if(len > sizeof buf)
		error("buffer too small in rsaunpad");
	mptobe(b, buf, len, nil);

	/* the initial zero has been eaten by the betomp -> mptobe sequence */
	if(buf[0] != 2)
		error("bad data in rsaunpad");

	/* find the zero separator that ends the random filler */
	skip = 1;
	while(skip < len && buf[skip] != 0)
		skip++;
	return betomp(buf+skip, len-skip, nil);
}

/*
 * Store b in the len-byte buffer buf, right-justified: if mptobe
 * produces fewer than len bytes, they are moved to the end of buf
 * and the front is filled with zeros.
 */
void
mptoberjust(mpint *b, uchar *buf, int len)
{
	int used, gap;

	used = mptobe(b, buf, len, nil);
	assert(used >= 0);
	gap = len - used;
	if(gap <= 0)
		return;
	memmove(buf+gap, buf, used);
	memset(buf, 0, gap);
}

/*
 * Encrypt the nbuf bytes at buf with the RSA public key: the bytes
 * are converted to an mpint, padded with rsapad to the byte length
 * of the modulus, and passed to rsaencrypt.  The intermediate values
 * are freed; the caller owns the returned mpint.
 */
mpint*
rsaencryptbuf(RSApub *key, uchar *buf, int nbuf)
{
	int modbytes;
	mpint *plain, *padded, *cipher;

	modbytes = (mpsignif(key->n)+7)/8;

	plain = betomp(buf, nbuf, nil);
	padded = rsapad(plain, modbytes);
	mpfree(plain);

	cipher = rsaencrypt(key, padded, nil);
	mpfree(padded);
	return cipher;
}