shithub: musw

Download patch

ref: fd5dc301e4a69d7b7c1293aafe5b069b4ff400a4
parent: 62e75d8830eb56ab03bd4689d51ffd6d4150f461
author: rodri <rgl@antares-labs.eu>
date: Thu Feb 16 08:35:24 EST 2023

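implemented a keep-alive mechanism: the server nudges each connected client once a second and drops any client it hasn't heard from for more than ConnTimeout ms.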

also changed newframe to take a Udphdr* instead of a Frame*; callers now pass the seq and ack numbers explicitly.
verifyframe now returns 1 if the signature is valid and 0 otherwise.

--- a/dat.h
+++ b/dat.h
@@ -47,7 +47,8 @@
 enum {
 	ProtocolID	= 0x5753554d,	/* MUSW */
 	Framehdrsize	= 4+1+4+4+2+MD5dlen,
-	MTU		= 1024
+	MTU		= 1024,
+	ConnTimeout	= 10000		/* in ms */
 };
 
 typedef struct VModel VModel;
@@ -163,12 +164,14 @@
 	NCState state;
 	u32int lastseq;
 	u32int lastack;
+	ulong lastrecvts;	/* last time a packet was received (in ms) */
+	ulong lastnudgets;	/* last time a nudge was sent (in ms) */
 };
 
 struct Player
 {
 	char *name;
-	NetConn conn;
+	NetConn *conn;
 	ulong okdown, kdown;
 };
 
--- a/fns.h
+++ b/fns.h
@@ -52,7 +52,7 @@
 ulong dhgenkey(ulong, ulong, ulong);
 NetConn *newnetconn(NCState, Udphdr*);
 void delnetconn(NetConn*);
-Frame *newframe(Frame*, u8int, u32int, u32int, u16int, uchar*);
+Frame *newframe(Udphdr*, u8int, u32int, u32int, u16int, uchar*);
 void signframe(Frame*, ulong);
 int verifyframe(Frame*, ulong);
 void delframe(Frame*);
--- a/musw.c
+++ b/musw.c
@@ -268,7 +268,7 @@
 			case NShi:
 				unpack(frame->data, frame->len, "kk", &netconn.dh.p, &netconn.dh.g);
 
-				newf = newframe(frame, NCdhx, 0, 0, sizeof(ulong), nil);
+				newf = newframe(nil, NCdhx, frame->seq+1, frame->seq, sizeof(ulong), nil);
 
 				netconn.dh.sec = truerand();
 				pack(newf->data, newf->len, "k", dhgenkey(netconn.dh.g, netconn.dh.sec, netconn.dh.p));
@@ -290,7 +290,7 @@
 			}
 			break;
 		case NCSConnected:
-			if(verifyframe(frame, netconn.dh.priv) != 0){
+			if(!verifyframe(frame, netconn.dh.priv)){
 				if(debug)
 					fprint(2, "\tbad signature\n");
 				goto discard;
@@ -304,7 +304,7 @@
 					&universe->star.p);
 				break;
 			case NSnudge:
-				newf = newframe(frame, NCnudge, 0, 0, 0, nil);
+				newf = newframe(nil, NCnudge, frame->seq+1, frame->seq, 0, nil);
 				signframe(newf, netconn.dh.priv);
 
 				sendp(egress, newf);
--- a/muswd.c
+++ b/muswd.c
@@ -58,6 +58,32 @@
 }
 
 void
+nudgeconns(ulong curts)
+{
+	NetConn *nc;
+	Frame *f;
+	int i;
+
+	/* drop conns silent for over ConnTimeout ms, nudge the rest every
+	 * second.  walk backwards: assuming popconn only moves entries past
+	 * the removed one, the entries still to visit stay put. */
+	for(i = nconns; i > 0; i--){
+		nc = conns[i-1];
+		if(nc->state != NCSConnected)
+			continue;
+		if(curts - nc->lastrecvts > ConnTimeout){
+			popconn(nc);	/* use nc, not conns[i-1]: the slot may be reused */
+			delnetconn(nc);
+		}else if(curts - nc->lastnudgets > 1000){ /* every second */
+			f = newframe(&nc->udp, NSnudge, 0, 0, 0, nil);
+			signframe(f, nc->dh.priv);
+			sendp(egress, f);
+			nc->lastnudgets = curts;
+		}
+	}
+}
+
+void
 threadnetrecv(void *arg)
 {
 	uchar buf[MTU];
@@ -105,7 +131,7 @@
 				nc = newnetconn(NCSConnecting, &frame->udp);
 				putconn(nc);
 
-				newf = newframe(frame, NShi, 0, 0, 2*sizeof(ulong), nil);
+				newf = newframe(&frame->udp, NShi, frame->seq+1, frame->seq, 2*sizeof(ulong), nil);
 
 				dhgenpg(&nc->dh.p, &nc->dh.g);
 				pack(newf->data, newf->len, "kk", nc->dh.p, nc->dh.g);
@@ -117,6 +143,8 @@
 				goto discard;
 		}
 
+		nc->lastrecvts = nanosec()/1e6;
+
 		switch(nc->state){
 		case NCSConnecting:
 			switch(frame->type){
@@ -127,7 +155,7 @@
 				if(debug)
 					fprint(2, "\trcvd pubkey %ld\n", nc->dh.pub);
 
-				newf = newframe(frame, NSdhx, 0, 0, sizeof(ulong), nil);
+				newf = newframe(&frame->udp, NSdhx, frame->seq+1, frame->seq, sizeof(ulong), nil);
 
 				nc->dh.sec = truerand();
 				nc->dh.priv = dhgenkey(nc->dh.pub, nc->dh.sec, nc->dh.p);
@@ -141,7 +169,7 @@
 			}
 			break;
 		case NCSConnected:
-			if(verifyframe(frame, nc->dh.priv) != 0){
+			if(!verifyframe(frame, nc->dh.priv)){
 				if(debug)
 					fprint(2, "\tbad signature\n");
 				goto discard;
@@ -250,6 +278,7 @@
 		}
 
 		broadcaststate();
+		nudgeconns(now/1e6);
 
 		iosleep(io, HZ2MS(70));
 	}
--- a/net.c
+++ b/net.c
@@ -66,23 +66,18 @@
 /* Frame */
 
 Frame *
-newframe(Frame *pf, u8int type, u32int seq, u32int ack, u16int len, uchar *data)
+newframe(Udphdr *hdr, u8int type, u32int seq, u32int ack, u16int len, uchar *data)
 {
 	Frame *f;
 
 	f = emalloc(sizeof(Frame)+len);
 	memset(f, 0, sizeof(Frame));
+	if(hdr != nil)
+		memmove(&f->udp, hdr, Udphdrsize);
 	f->id = ProtocolID;
 	f->type = type;
-	if(pf != nil){
-		memmove(&f->udp, &pf->udp, Udphdrsize);
-		f->seq = pf->seq+1;
-		f->ack = pf->seq;
-	}else{
-		memset(&f->udp, 0, Udphdrsize);
-		f->seq = seq;
-		f->ack = ack;
-	}
+	f->seq = seq;
+	f->ack = ack;
 	f->len = len;
 	if(data != nil)
 		memmove(f->data, data, f->len);
@@ -120,8 +115,7 @@
 	memset(f->sig, 0, MD5dlen);
 	n = pack(msg, sizeof msg, "f", f);
 	hmac_md5(msg, n, k, sizeof k, h1, nil);
-	memmove(f->sig, h0, MD5dlen);
-	return memcmp(h0, h1, MD5dlen);
+	return memcmp(h0, h1, MD5dlen) == 0;
 }
 
 void