shithub: riscv

Download patch

ref: a278545e3c492d97d32192e20de7c62c8b0ef4a2
parent: 3bb180463108545274592b53940274c52b1d9186
author: cinap_lenrek <cinap_lenrek@felloff.net>
date: Wed Apr 3 06:49:47 EDT 2019

sshnet: fix eof and close handling, use proper packet size, cleanup

--- a/sys/src/cmd/sshnet.c
+++ b/sys/src/cmd/sshnet.c
@@ -61,6 +61,7 @@
 	int state;
 	int num;
 	int servernum;
+	int sentclose;	/* nonzero once MSG_CHANNEL_CLOSE has been sent for this channel */
 	char *connect;
 
 	int sendpkt;
@@ -68,6 +69,8 @@
 	int recvwin;
 	int recvacc;
 
+	int eof;	/* set on server EOF or hangup; reads left unmatched then complete with 0 bytes */
+
 	Req *wq;
 	Req **ewq;
 
@@ -91,7 +94,8 @@
 	MSG_CHANNEL_SUCCESS,
 	MSG_CHANNEL_FAILURE,
 
-	MaxPacket = 1<<15,
+	Overhead = 256,	/* extra room in Msg.buf beyond MaxPacket bytes of payload */
+	MaxPacket = (1<<15)-Overhead,	/* 32K is maxatomic for pipe */
 	WinPackets = 8,
 
 	SESSIONCHAN = 1<<24,
@@ -104,7 +108,7 @@
 	uchar	*rp;
 	uchar	*wp;
 	uchar	*ep;
-	uchar	buf[MaxPacket];
+	uchar	buf[MaxPacket + Overhead];	/* payload plus overhead: 1<<15 (32K) bytes in total */
 };
 
 #define PUT4(p, u) (p)[0] = (u)>>24, (p)[1] = (u)>>16, (p)[2] = (u)>>8, (p)[3] = (u)
@@ -116,6 +120,7 @@
 int sshfd;
 int localport;
 char localip[] = "::";
+char Ehangup[] = "hangup on network connection";	/* error string for queued requests failed by hangupclient */
 
 int
 vpack(uchar *p, int n, char *fmt, va_list a)
@@ -341,12 +346,10 @@
 	Msg *m;
 	int n, rm;
 
-	while(c->rq != nil && c->mq != nil){
-		r = c->rq;
+	while((r = c->rq) != nil && (m = c->mq) != nil){	/* pair the head read request with the head message */
 		c->rq = r->aux;
-
+		r->aux = nil;	/* unlink r from the read queue */
 		rm = 0;
-		m = c->mq;
 		n = r->ifcall.count;
 		if(n >= m->wp - m->rp){
 			n = m->wp - m->rp;
@@ -362,6 +365,15 @@
 		respond(r, nil);
 		adjustwin(c, n);
 	}
+
+	if(c->eof){	/* eof: reads that found no queued data complete with 0 bytes */
+		while((r = c->rq) != nil){
+			c->rq = r->aux;
+			r->aux = nil;
+			r->ofcall.count = 0;
+			respond(r, nil);
+		}
+	}
 }
 
 void
@@ -438,57 +450,48 @@
 }
 
 void
-teardownclient(Client *c)
-{
-	c->state = Teardown;
-	sendmsg(pack(nil, "bu", MSG_CHANNEL_EOF, c->servernum));
-}
-
-void
 hangupclient(Client *c)
 {
-	Req *r, *next;
-	Msg *m, *mnext;
+	Req *r;
 
-	c->state = Closed;
-	for(m=c->mq; m; m=mnext){
-		mnext = m->link;
-		free(m);
+	c->eof = 1;	/* leftover reads will complete with 0 bytes */
+	c->recvwin = 0;	/* close both flow-control windows */
+	c->sendwin = 0;
+	while((r = c->wq) != nil){	/* fail every pending write */
+		c->wq = r->aux;
+		r->aux = nil;
+		respond(r, Ehangup);
 	}
-	c->mq = nil;
-	for(r=c->rq; r; r=next){
-		next = r->aux;
-		respond(r, "hangup on network connection");
+	if(c->state == Established){	/* first hangup: enter Teardown and flush waiting reads */
+		c->state = Teardown;
+		matchrmsgs(c);
+		return;
 	}
-	c->rq = nil;
-	for(r=c->wq; r; r=next){
-		next = r->aux;
-		respond(r, "hangup on network connection");
-	}
-	c->wq = nil;
+	c->state = Closed;
 }
 
 void
+teardownclient(Client *c)
+{
+	hangupclient(c);
+	if(c->sentclose++ == 0)	/* send MSG_CHANNEL_CLOSE at most once per channel */
+		sendmsg(pack(nil, "bu", MSG_CHANNEL_CLOSE, c->servernum));
+}
+
+void
 closeclient(Client *c)
 {
-	Msg *m, *next;
+	Msg *m;
 
 	if(--c->ref)
 		return;
-
-	if(c->rq != nil || c->wq != nil)
-		sysfatal("ref count reached zero with requests pending (BUG)");
-
-	for(m=c->mq; m; m=next){
-		next = m->link;
+	if(c->state >= Established)	/* channel was opened: hang up and send CLOSE if not yet sent */
+		teardownclient(c);
+	while((m = c->mq) != nil){	/* discard unread messages */
+		c->mq = m->link;
 		free(m);
 	}
-	c->mq = nil;
-
-	if(c->state != Closed)
-		teardownclient(c);
 }
-
 	
 void
 sshreadproc(void*)
@@ -810,6 +813,7 @@
 		nf = getfields(f[1], f, nelem(f), 0, "!");
 		if(nf != 2)
 			goto Badarg;
+		c->eof = 0;	/* reset together with the window counters below */
 		c->sendwin = MaxPacket;
 		c->recvwin = WinPackets * MaxPacket;
 		c->recvacc = 0;
@@ -831,7 +835,7 @@
 static void
 dataread(Req *r, Client *c)
 {
-	if(c->state != Established){
+	if(c->state < Established){	/* states ordered after Established (presumably Teardown) may still drain queued data */
 		respond(r, "not connected");
 		return;
 	}
@@ -1028,7 +1032,7 @@
 static void
 handlemsg(Msg *m)
 {
-	int chan, win, pkt, n, l;
+	int chan, win, pkt, n;
 	Client *c;
 	char *s;
 
@@ -1037,7 +1041,7 @@
 		if(unpack(m, "_uu", &chan, &n) < 0)
 			break;
 		c = getclient(chan);
-		if(c != nil && c->state==Established){
+		if(c != nil && c->state == Established){
 			c->sendwin += n;
 			procwreqs(c);
 		}
@@ -1046,7 +1050,9 @@
 		if(unpack(m, "_us", &chan, &s, &n) < 0)
 			break;
 		c = getclient(chan);
-		if(c != nil && c->state==Established){
+		if(c != nil && c->state == Established){
+			if(c->recvwin <= 0)	/* window exhausted or zeroed by EOF: do not queue */
+				break;
 			c->recvwin -= n;
 			m->rp = (uchar*)s;
 			queuermsg(c, m);
@@ -1058,11 +1064,10 @@
 		if(unpack(m, "_u", &chan) < 0)
 			break;
 		c = getclient(chan);
-		if(c != nil){
-			hangupclient(c);
-			m->rp = m->wp = m->buf;
-			sendmsg(pack(m, "bu", MSG_CHANNEL_CLOSE, c->servernum));
-			return;
+		if(c != nil && c->state == Established){
+			c->eof = 1;	/* server sent EOF: no more data on this channel */
+			c->recvwin = 0;
+			matchrmsgs(c);	/* complete waiting reads (with 0 bytes once the queue is empty) */
 		}
 		break;
 	case MSG_CHANNEL_CLOSE:
@@ -1069,7 +1074,7 @@
 		if(unpack(m, "_u", &chan) < 0)
 			break;
 		c = getclient(chan);
-		if(c != nil)
+		if(c != nil && c->state >= Established)
 			hangupclient(c);
 		break;
 	case MSG_CHANNEL_OPEN_CONFIRMATION:
@@ -1087,20 +1092,20 @@
 		c->sendpkt = pkt;
 		c->sendwin = win;
 		c->servernum = n;
+		c->sentclose = 0;	/* channel just confirmed: no CLOSE sent yet */
 		c->state = Established;
 		dialedclient(c);
 		break;
 	case MSG_CHANNEL_OPEN_FAILURE:
-		if(unpack(m, "_uus", &chan, &n, &s, &l) < 0)
+		if(unpack(m, "_u____s", &chan, &s, &n) < 0)	/* ____ skips the 4-byte reason code; s/n is the description */
 			break;
 		if(chan == SESSIONCHAN){
-			sendp(ssherrchan, smprint("%.*s", utfnlen(s, l), s));
+			sendp(ssherrchan, smprint("%.*s", utfnlen(s, n), s));
 			break;
 		}
 		c = getclient(chan);
 		if(c == nil || c->state != Dialing)
 			break;
-		c->servernum = n;
 		c->state = Closed;
 		dialedclient(c);
 		break;