shithub: riscv

Download patch

ref: d566a5ca6b3d105b2aa5778dc5cb08113b48bd50
parent: 778e2af7befb1d9071995b104f8b35476b0d2091
author: cinap_lenrek <cinap_lenrek@felloff.net>
date: Wed Apr 19 17:03:26 EDT 2017

ssh: fix locking, and key reexchange handling

when we initiate a re-key exchange, data packets can still
come in. so instead we handle everything that can come in at
any time in the dispatch() function (including KEXINIT) and have
the receiver process just call that in a loop. prevent dispatch
and the sender proc from corrupting each other's sendpkt() calls
with the QLock sl.

--- a/sys/src/cmd/ssh.c
+++ b/sys/src/cmd/ssh.c
@@ -44,6 +44,7 @@
 
 typedef struct
 {
+	int		pid;
 	u32int		seq;
 	u32int		kex;
 	Chachastate	cs1;
@@ -59,19 +60,18 @@
 int nsid;
 uchar sid[256];
 
-int fd, pid1, pid2, intr, raw, debug;
+int fd, intr, raw, debug;
 char *user, *status, *host, *cmd;
 
 Oneway recv, send;
+void dispatch(void);
 
 void
 shutdown(void)
 {
-	int pid = getpid();
-	if(pid1 && pid1 != pid)
-		postnote(PNPROC, pid1, "shutdown");
-	if(pid2 && pid2 != pid)
-		postnote(PNPROC, pid2, "shutdown");
+	recv.eof = send.eof = 1;
+	if(send.pid > 0)
+		postnote(PNPROC, send.pid, "shutdown");
 }
 
 void
@@ -353,35 +353,6 @@
 	return recv.r[0];
 }
 
-void
-unexpected(char *info)
-{
-	char *s;
-	int n, c;
-
-	switch(recv.r[0]){
-	case MSG_DISCONNECT:
-		if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
-			break;
-		sysfatal("disconnect: (%d) %.*s", c, n, s);
-		break;
-	case MSG_IGNORE:
-	case MSG_GLOBAL_REQUEST:
-		return;
-	case MSG_DEBUG:
-		if(unpack(recv.r, recv.w-recv.r, "__sb", &s, &n, &c) < 0)
-			break;
-		if(c != 0) fprint(2, "%s: %.*s\n", argv0, n, s);
-		return;
-	case MSG_USERAUTH_BANNER:
-		if(unpack(recv.r, recv.w-recv.r, "_s", &s, &n) < 0)
-			break;
-		if(raw) write(2, s, n);
-		return;
-	}
-	sysfatal("%s got: %.*H", info, (int)(recv.w - recv.r), recv.r);
-}
-
 static char sshrsa[] = "ssh-rsa";
 
 int
@@ -538,7 +509,7 @@
 	if(!gotkexinit){
 	Next0:	switch(recvpkt()){
 		default:
-			unexpected("KEXINIT");
+			dispatch();
 			goto Next0;
 		case MSG_KEXINIT:
 			break;
@@ -570,8 +541,10 @@
 	sendpkt("bs", MSG_ECDH_INIT, yc, sizeof(yc));
 Next1:	switch(recvpkt()){
 	default:
-		unexpected("ECDH_INIT");
+		dispatch();
 		goto Next1;
+	case MSG_KEXINIT:
+		sysfatal("inception");
 	case MSG_ECDH_REPLY:
 		if(unpack(recv.r, recv.w-recv.r, "_sss", &ks, &nks, &ys, &nys, &sig, &nsig) < 0)
 			sysfatal("bad ECDH_REPLY");
@@ -607,8 +580,10 @@
 	sendpkt("b", MSG_NEWKEYS);
 Next2:	switch(recvpkt()){
 	default:
-		unexpected("NEWKEYS");
+		dispatch();
 		goto Next2;
+	case MSG_KEXINIT:
+		sysfatal("inception");
 	case MSG_NEWKEYS:
 		break;
 	}
@@ -647,7 +622,7 @@
 	sendpkt("bs", MSG_SERVICE_REQUEST, sshuserauth, sizeof(sshuserauth)-1);
 Next0:	switch(recvpkt()){
 	default:
-		unexpected("SERVICE_REQUEST");
+		dispatch();
 		goto Next0;
 	case MSG_SERVICE_ACCEPT:
 		break;
@@ -690,7 +665,7 @@
 			pk, npk);
 Next1:		switch(recvpkt()){
 		default:
-			unexpected("USERAUTH_REQUEST");
+			dispatch();
 			goto Next1;
 		case MSG_USERAUTH_FAILURE:
 			continue;
@@ -733,7 +708,7 @@
 			sig, nsig);
 Next2:		switch(recvpkt()){
 		default:
-			unexpected("USERAUTH_REQUEST");
+			dispatch();
 			goto Next2;
 		case MSG_USERAUTH_FAILURE:
 			continue;
@@ -751,6 +726,83 @@
 	return -1;	
 }
 
+void
+dispatch(void)
+{
+	char *s;
+	uchar *p;
+	int n, b, c;
+
+	switch(recv.r[0]){
+	case MSG_IGNORE:
+	case MSG_GLOBAL_REQUEST:
+	case MSG_CHANNEL_WINDOW_ADJUST:
+		return;
+	case MSG_DISCONNECT:
+		if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
+			break;
+		sysfatal("disconnect: (%d) %.*s", c, n, s);
+		return;
+	case MSG_DEBUG:
+		if(unpack(recv.r, recv.w-recv.r, "__sb", &s, &n, &c) < 0)
+			break;
+		if(c != 0 || debug) fprint(2, "%s: %.*s\n", argv0, n, s);
+		return;
+	case MSG_USERAUTH_BANNER:
+		if(unpack(recv.r, recv.w-recv.r, "_s", &s, &n) < 0)
+			break;
+		if(raw) write(2, s, n);
+		return;
+	case MSG_CHANNEL_DATA:
+		if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
+			break;
+		if(c != 0)
+			break;
+		if(write(1, s, n) != n)
+			sysfatal("write out: %r");
+	Winadjust:
+		sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, c, n);
+		return;
+	case MSG_CHANNEL_EXTENDED_DATA:
+		if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
+			break;
+		if(c != 0)
+			break;
+		if(b == 1) write(2, s, n);
+		goto Winadjust;
+	case MSG_CHANNEL_REQUEST:
+		if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0)
+			break;
+		if(c != 0)
+			break;
+		if(n == 11 && memcmp(s, "exit-signal", n) == 0){
+			if(unpack(p, recv.w-p, "s", &s, &n) < 0)
+				break;
+			if(n != 0 && status == nil)
+				status = smprint("%.*s", n, s);
+		} else if(n == 11 && memcmp(s, "exit-status", n) == 0){
+			if(unpack(p, recv.w-p, "u", &n) < 0)
+				break;
+			if(n != 0 && status == nil)
+				status = smprint("%d", n);
+		} else if(debug) {
+			fprint(2, "%s: channel request: %.*s\n", argv0, n, s);
+		}
+		return;
+	case MSG_CHANNEL_EOF:
+		recv.eof = 1;
+		if(!raw) write(1, "", 0);
+		return;
+	case MSG_CHANNEL_CLOSE:
+		shutdown();
+		return;
+	case MSG_KEXINIT:
+		kex(1);
+		return;
+	}
+	sysfatal("got: %.*H", (int)(recv.w - recv.r), recv.r);
+}
+
 char*
 readline(void)
 {
@@ -830,7 +882,6 @@
 	static QLock sl;
 	int b, n, c;
 	char *s;
-	uchar *p;
 
 	quotefmtinstall();
 	fmtinstall('B', mpfmt);
@@ -889,7 +940,6 @@
 	recv.v = strdup(recv.v);
 
 	kex(0);
-
 	if(user == nil)
 		user = getuser();
 	if(auth(user, "ssh-connection") < 0)
@@ -902,125 +952,92 @@
 		sizeof(buf),
 		sizeof(buf));
 
-	while((send.eof | recv.eof) == 0){
-		if((int)(send.kex - send.seq) <= 0 || (int)(recv.kex - recv.seq) <= 0){
-			qlock(&sl);
-			kex(0);
+Next0:	switch(recvpkt()){
+	default:
+		dispatch();
+		goto Next0;
+	case MSG_CHANNEL_OPEN_FAILURE:
+		if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
+			n = strlen(s = "???");
+		sysfatal("channel open failure: (%d) %.*s", b, n, s);
+	case MSG_CHANNEL_OPEN_CONFIRMATION:
+		break;
+	}
+
+	notify(catch);
+	atexit(shutdown);
+
+	recv.pid = getpid();
+	n = rfork(RFPROC|RFMEM);
+	if(n < 0)
+		sysfatal("fork: %r");
+
+	/* parent reads and dispatches packets */
+	if(n > 0) {
+		send.pid = n;
+		while((send.eof|recv.eof) == 0){
+			recvpkt();
+			qlock(&sl);					
+			dispatch();
+			if((int)(send.kex - send.seq) <= 0 || (int)(recv.kex - recv.seq) <= 0)
+				kex(0);
 			qunlock(&sl);
 		}
-		switch(recvpkt()){
-		default:
-			unexpected("CHANNEL");
-			continue;
-		case MSG_KEXINIT:
-			qlock(&sl);
-			kex(1);
-			qunlock(&sl);
-			continue;
-		case MSG_CHANNEL_WINDOW_ADJUST:
-			continue;
-		case MSG_CHANNEL_EXTENDED_DATA:
-			if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
-				unexpected("CHANNEL_EXTENDED_DATA");
-			if(b == 1) write(2, s, n);
-			sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, c, n);
-			continue;
-		case MSG_CHANNEL_DATA:
-			if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
-				unexpected("CHANNEL_DATA");
-			write(1, s, n);
-			sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, c, n);
-			continue;
-		case MSG_CHANNEL_EOF:
-			recv.eof = 1;
-			if(!raw) write(1, "", 0);
-			continue;
-		case MSG_CHANNEL_OPEN_FAILURE:
-			if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
-				unexpected("CHANNEL_OPEN_FAILURE");
-			sysfatal("channel open failure: (%d) %.*s", b, n, s);
+		exits(status);
+	}
+
+	/* child reads input and sends packets */
+	qlock(&sl);
+	if(raw) {
+		rawon();
+		sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST,
+			0,
+			"pty-req", 7,
+			0,
+			tty.term, strlen(tty.term),
+			tty.cols,
+			tty.lines,
+			tty.xpixels,
+			tty.ypixels,
+			"", 0);
+	}
+	if(cmd == nil){
+		sendpkt("busb", MSG_CHANNEL_REQUEST,
+			0,
+			"shell", 5,
+			0);
+	} else {
+		sendpkt("busbs", MSG_CHANNEL_REQUEST,
+			0,
+			"exec", 4,
+			0,
+			cmd, strlen(cmd));
+	}
+	for(;;){
+		qunlock(&sl);
+		n = read(0, buf, sizeof(buf));
+		qlock(&sl);
+		if(send.eof)
 			break;
-		case MSG_CHANNEL_OPEN_CONFIRMATION:
-			if(raw) {
-				rawon();
-				sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST,
-					0,
-					"pty-req", 7,
-					0,
-					tty.term, strlen(tty.term),
-					tty.cols,
-					tty.lines,
-					tty.xpixels,
-					tty.ypixels,
-					"", 0);
-			}
-			if(cmd == nil){
-				sendpkt("busb", MSG_CHANNEL_REQUEST,
-					0,
-					"shell", 5,
-					0);
-			} else {
-				sendpkt("busbs", MSG_CHANNEL_REQUEST,
-					0,
-					"exec", 4,
-					0,
-					cmd, strlen(cmd));
-			}
-			if(pid2)
-				continue;
-			pid1 = getpid();
-			notify(catch);
-			atexit(shutdown);
-			n = rfork(RFPROC|RFMEM);
-			if(n){
-				pid2 = n;
-				continue;
-			}
-			qlock(&sl);
-			for(;;){
-				qunlock(&sl);
-				n = read(0, buf, sizeof(buf));
-				qlock(&sl);
-				if(n < 0 && wasintr()){
-					sendpkt("busbs", MSG_CHANNEL_REQUEST,
-						0,
-						"signal", 6,
-						0,
-						"INT", 3);
-					intr = 0;
-					continue;
-				}
-				if(n <= 0)
-					break;
-				sendpkt("bus", MSG_CHANNEL_DATA,
-					0,
-					buf, n);
-			}
-			send.eof = 1;
-			sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, 0);
-			qunlock(&sl);
-			break;
-		case MSG_CHANNEL_REQUEST:
-			if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0)
-				unexpected("CHANNEL_REQUEST");
-			if(n == 11 && memcmp(s, "exit-signal", n) == 0){
-				if(unpack(p, recv.w-p, "s", &s, &n) < 0)
-					continue;
-				if(n != 0 && status == nil)
-					status = smprint("%.*s", n, s);
-			} else if(n == 11 && memcmp(s, "exit-status", n) == 0){
-				if(unpack(p, recv.w-p, "u", &n) < 0)
-					continue;
-				if(n != 0 && status == nil)
-					status = smprint("%d", n);
-			} else {
-				fprint(2, "%s: channel request: %.*s\n", argv0, n, s);
-			}
+		if(n < 0 && wasintr()){
+			if(!raw) break;
+			sendpkt("busbs", MSG_CHANNEL_REQUEST,
+				0,
+				"signal", 6,
+				0,
+				"INT", 3);
+			intr = 0;
 			continue;
-		case MSG_CHANNEL_CLOSE:
-			break;
 		}
-		break;
+		if(n <= 0)
+			break;
+		sendpkt("bus", MSG_CHANNEL_DATA,
+			0,
+			buf, n);
 	}
-	exits(status);
+	if(send.eof++ == 0)
+		sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, 0);
+	qunlock(&sl);
+
+	exits(nil);
 }