shithub: musw

Download patch

ref: 775309861b51dd1f340d82074b7e9234f3e0675e
parent: c1cdf0f980b92193953f05bf444b7c78f369a122
author: rodri <rgl@antares-labs.eu>
date: Fri Feb 10 19:26:30 EST 2023

implemented connection establishment with per-client DHX (Diffie-Hellman key exchange).

--- a/dat.h
+++ b/dat.h
@@ -14,6 +14,12 @@
 	WEDGE
 } Kind;
 
+typedef enum {
+	NCSDisconnected,	/* idle; the client starts a handshake from here */
+	NCSConnecting,		/* handshake (NChi → NShi → NCdhx → NSdhx) in progress */
+	NCSConnected		/* DHX done; dh.priv holds the shared key */
+} NCState;
+
 enum {
 	SCRW	= 640,
 	SCRH	= 480,
@@ -22,24 +28,25 @@
 };
 
 enum {
-	NChi	= 10,	/* C wants to connect */
-	NShi,		/* S accepts */
-	NCdhx0	= 12,	/* C asks for p and g */
-	NSdhx0,		/* S sends them. it's not a negotiation */
-	NCdhx1	= 14,	/* C shares pubkey */
-	NSdhx1,		/* S shares pubkey */
-	NCnudge	= 16,
-	NSnudge,	/* check the pulse of the line */
+	NChi		= 10,	/* C wants to connect */
+	NShi,			/* S accepts. sends P and G for DHX */
+	NCdhx		= 12,	/* C shares pubkey */
+	NSdhx,			/* S shares pubkey */
+	NCnudge		= 16,	/* C answers S's nudge */
+	NSnudge,		/* check the pulse of the line */
 
-	NCinput	= 20,	/* C sends player input state */
-	NSsimstate,	/* S sends current simulation state */
+	NCinput		= 20,	/* C sends player input state */
+	NSsimstate,		/* S sends current simulation state */
 
 	NCbuhbye	= 30,
-	NSbuhbye
+	NSbuhbye,
+
+	NSerror 	= 66	/* report an error */
 };
 
 enum {
-	Framehdrsize	= 1+4+4+2,
+	ProtocolID	= 0x5753554d,	/* MUSW */
+	Framehdrsize	= 4+1+4+4+2,	/* id, type, seq, ack, len */
 	MTU		= 1024
 };
 
@@ -53,6 +60,7 @@
 typedef struct Derivative Derivative;
 
 typedef struct Frame Frame;
+typedef struct DHparams DHparams;
 typedef struct NetConn NetConn;
 typedef struct Player Player;
 typedef struct Party Party;
@@ -63,7 +71,7 @@
 struct VModel
 {
 	Point2 *pts;
-	ulong npts;
+	usize npts;
 	/* WIP
 	 * l(ine) → takes 2 points
 	 * c(urve) → takes 3 points
@@ -134,6 +142,7 @@
 struct Frame
 {
 	Udphdr udp;
+	u32int id;	/* ProtocolID */
 	u8int type;
 	u32int seq;
 	u32int ack;
@@ -141,10 +150,18 @@
 	uchar data[];
 };
 
+struct DHparams
+{
+	ulong p, g, pub, sec, priv;	/* modulus, generator, peer's pubkey, own secret, shared key */
+};
+
 struct NetConn
 {
 	Udphdr udp;
-	int isconnected;
+	DHparams dh;
+	NCState state;
+	u32int lastseq;	/* not referenced yet */
+	u32int lastack;	/* not referenced yet */
 };
 
 struct Player
--- a/fns.h
+++ b/fns.h
@@ -39,8 +39,18 @@
 void inituniverse(Universe*);
 
 /*
- *	sprite
+ * sprite
  */
 Sprite *newsprite(Image*, Point, Rectangle, int, ulong);
 Sprite *readsprite(char*, Point, Rectangle, int, ulong);
 void delsprite(Sprite*);
+
+/*
+ * net
+ */
+void dhgenpg(ulong*, ulong*);	/* fixed DH modulus p and generator g */
+ulong dhgenkey(ulong, ulong, ulong);	/* (g, k, p) → g^k mod p */
+NetConn *newnetconn(NCState, Udphdr*);	/* Udphdr* may be nil */
+void delnetconn(NetConn*);
+Frame *newframe(Frame*, u8int, u32int, u32int, u16int, uchar*);	/* (parent or nil, type, seq, ack, len, data or nil) */
+void delframe(Frame*);
--- a/mkfile
+++ b/mkfile
@@ -14,6 +14,7 @@
 	party.$O\
 	universe.$O\
 	sprite.$O\
+	net.$O\
 
 HFILES=\
 	dat.h\
--- a/musw.c
+++ b/musw.c
@@ -37,6 +37,7 @@
 Image *skymap;
 Channel *ingress;
 Channel *egress;
+NetConn netconn;
 char winspec[32];
 int debug;
 
@@ -85,7 +86,7 @@
 			continue;
 		case 'v':
 			if(tokenize(s, args, nelem(args)) != nelem(args)){
-				werrstr("syntax error: %s:%lud 'v' expects %d args",
+				werrstr("syntax error: %s:%ld 'v' expects %d args",
 					file, lineno, nelem(args));
 				free(mdl);
 				Bterm(bin);
@@ -144,15 +145,24 @@
 }
 
 void
+initconn(void)
+{
+	Frame *frame;
+
+	frame = newframe(nil, NChi, 0, 0, 0, nil);
+	netconn.state = NCSConnecting;	/* before sendp, which may yield to threadnetppu */
+	sendp(egress, frame);
+}
+
+void
 sendkeys(ulong kdown)
 {
 	Frame *frame;
 
-	frame = emalloc(sizeof(Frame)+sizeof(kdown));
-	frame->type = NCinput;
-	frame->seq = 0;
-	frame->ack = 0;
-	frame->len = sizeof(kdown);
+	if(netconn.state != NCSConnected)
+		return;
+
+	frame = newframe(nil, NCinput, 0, 0, sizeof(kdown), nil);
 	pack(frame->data, frame->len, "k", kdown);
 	sendp(egress, frame);
 }
@@ -214,6 +224,7 @@
 {
 	uchar buf[MTU];
 	int fd, n;
+	ushort rport, lport;
 	Ioproc *io;
 	Frame *frame;
 
@@ -223,9 +234,22 @@
 	io = ioproc();
 
 	while((n = ioread(io, fd, buf, sizeof buf)) > 0){
-		frame = emalloc(sizeof(Frame)+(n-Framehdrsize));
+		/* too short to even hold a frame header; drop it */
+		if(n < Framehdrsize)
+			continue;
+		frame = newframe(nil, 0, 0, 0, n-Framehdrsize, nil);
 		unpack(buf, n, "f", frame);
-		sendp(ingress, frame);
+
+		/* log before handing off: once sent, threadnetppu owns and frees the frame */
+		if(debug){
+			rport = frame->udp.rport[0]<<8 | frame->udp.rport[1];
+			lport = frame->udp.lport[0]<<8 | frame->udp.lport[1];
+			fprint(2, "%I!%ud ← %I!%ud | rcvd type %ud seq %ud ack %ud len %ud\n",
+				frame->udp.laddr, lport, frame->udp.raddr, rport,
+				frame->type, frame->seq, frame->ack, frame->len);
+		}
+
+		sendp(ingress, frame);
 	}
 	closeioproc(io);
 }
@@ -233,20 +257,62 @@
 void
 threadnetppu(void *)
 {
-	Frame *frame;
+	Frame *frame, *newf;
 
 	threadsetname("threadnetppu");
 
 	while((frame = recvp(ingress)) != nil){
-		switch(frame->type){
-		case NSsimstate:
-			unpack(frame->data, frame->len, "PdPdP",
-				&universe->ships[0].p, &universe->ships[0].θ,
-				&universe->ships[1].p, &universe->ships[1].θ,
-				&universe->star.p);
+		if(frame->id != ProtocolID)
+			goto discard;
+
+		switch(netconn.state){
+		case NCSConnecting:
+			switch(frame->type){
+			case NShi:
+				unpack(frame->data, frame->len, "kk", &netconn.dh.p, &netconn.dh.g);
+
+				newf = newframe(frame, NCdhx, 0, 0, sizeof(ulong), nil);
+
+				netconn.dh.sec = truerand();
+				pack(newf->data, newf->len, "k", dhgenkey(netconn.dh.g, netconn.dh.sec, netconn.dh.p));
+				sendp(egress, newf);
+
+				if(debug)
+					fprint(2, "\tsent pubkey %lud\n", dhgenkey(netconn.dh.g, netconn.dh.sec, netconn.dh.p));
+
+				break;
+			case NSdhx:
+				unpack(frame->data, frame->len, "k", &netconn.dh.pub);
+				netconn.state = NCSConnected;
+
+				if(debug)
+					fprint(2, "\trecvd pubkey %lud\n", netconn.dh.pub);
+
+				netconn.dh.priv = dhgenkey(netconn.dh.pub, netconn.dh.sec, netconn.dh.p);
+				break;
+			}
 			break;
-		}
+		case NCSConnected:
+			switch(frame->type){
+			case NSsimstate:
+				unpack(frame->data, frame->len, "PdPdP",
+					&universe->ships[0].p, &universe->ships[0].θ,
+					&universe->ships[1].p, &universe->ships[1].θ,
+					&universe->star.p);
+				break;
+			case NSnudge:
+				newf = newframe(frame, NCnudge, 0, 0, 0, nil);
 
+				sendp(egress, newf);
+
+				break;
+			case NSbuhbye:
+				netconn.state = NCSDisconnected;
+				break;
+			}
+			break;
+		}
+discard:
 		free(frame);
 	}
 }
@@ -256,6 +322,7 @@
 {
 	uchar buf[MTU];
 	int fd, n;
+	ushort rport, lport;
 	Frame *frame;
 
 	threadsetname("threadnetsend");
@@ -264,9 +331,18 @@
 
 	while((frame = recvp(egress)) != nil){
 		n = pack(buf, sizeof buf, "f", frame);
-		free(frame);
 		if(write(fd, buf, n) != n)
 			sysfatal("write: %r");
+
+		if(debug){
+			rport = frame->udp.rport[0]<<8 | frame->udp.rport[1];
+			lport = frame->udp.lport[0]<<8 | frame->udp.lport[1];
+			fprint(2, "%I!%ud → %I!%ud | sent type %ud seq %ud ack %ud len %ud\n",
+				frame->udp.laddr, lport, frame->udp.raddr, rport,
+				frame->type, frame->seq, frame->ack, frame->len);
+		}
+
+		free(frame);
 	}
 }
 
@@ -366,6 +442,7 @@
 	Ioproc *io;
 
 	GEOMfmtinstall();
+	fmtinstall('I', eipfmt);
 	ARGBEGIN{
 	case 'd':
 		debug++;
@@ -415,6 +492,7 @@
 	ingress = chancreate(sizeof(Frame*), 8);
 	egress = chancreate(sizeof(Frame*), 8);
 	threadcreate(threadnetrecv, &fd, mainstacksize);
+	threadcreate(threadnetppu, nil, mainstacksize);
 	threadcreate(threadnetsend, &fd, mainstacksize);
 	threadcreate(threadresize, mc, mainstacksize);
 
@@ -428,6 +506,9 @@
 		universe->star.spr->step(universe->star.spr, frametime/1e6);
 
 		redraw();
+
+		if(netconn.state == NCSDisconnected)
+			initconn();
 
 		iosleep(io, HZ2MS(30));
 	}
--- a/muswd.c
+++ b/muswd.c
@@ -11,15 +11,56 @@
 int mainstacksize = 24*1024;
 
 Party theparty;
+NetConn **conns;
+usize nconns;
+usize maxconns;
 Channel *ingress;
 Channel *egress;
 
 
 void
+putconn(NetConn *nc)
+{
+	if(++nconns > maxconns){
+		conns = erealloc(conns, sizeof(NetConn*)*nconns);
+		maxconns = nconns;
+	}
+	conns[nconns-1] = nc;
+}
+
+NetConn *
+getconn(Frame *f)
+{
+	NetConn **nc;
+
+	for(nc = conns; nc < conns+nconns; nc++)
+		if(memcmp(&(*nc)->udp, &f->udp, Udphdrsize) == 0)	/* peer identified by its whole udp header */
+			return *nc;
+	return nil;
+}
+
+int
+popconn(NetConn *nc)
+{
+	NetConn **ncp, **ncpe;
+
+	ncpe = conns+nconns;
+
+	for(ncp = conns; ncp < ncpe; ncp++)
+		if(*ncp == nc){
+			memmove(ncp, ncp+1, sizeof(NetConn*)*(ncpe-ncp-1));
+			nconns--;
+			return 0;
+		}
+	return -1;
+}
+
+void
 threadnetrecv(void *arg)
 {
 	uchar buf[MTU];
 	int fd, n;
+	ushort rport, lport;
 	Ioproc *io;
 	Frame *frame;
 
@@ -29,9 +70,22 @@
 	io = ioproc();
 
 	while((n = ioread(io, fd, buf, sizeof buf)) > 0){
-		frame = emalloc(sizeof(Frame)+(n-Udphdrsize-Framehdrsize));
+		/* too short to hold the udp and frame headers; drop it */
+		if(n < Udphdrsize+Framehdrsize)
+			continue;
+		frame = newframe(nil, 0, 0, 0, n-Udphdrsize-Framehdrsize, nil);
 		unpack(buf, n, "F", frame);
-		sendp(ingress, frame);
+
+		/* log before handing off: once sent, threadnetppu owns and frees the frame */
+		if(debug){
+			rport = frame->udp.rport[0]<<8 | frame->udp.rport[1];
+			lport = frame->udp.lport[0]<<8 | frame->udp.lport[1];
+			fprint(2, "%I!%ud ← %I!%ud | rcvd type %ud seq %ud ack %ud len %ud\n",
+				frame->udp.laddr, lport, frame->udp.raddr, rport,
+				frame->type, frame->seq, frame->ack, frame->len);
+		}
+
+		sendp(ingress, frame);
 	}
 	closeioproc(io);
 }
@@ -39,29 +93,74 @@
 void
 threadnetppu(void *)
 {
-	ushort rport, lport;
 	ulong kdown;
-	Frame *frame;
+	Frame *frame, *newf;
+	NetConn *nc;
 
 	threadsetname("threadnetppu");
 
 	while((frame = recvp(ingress)) != nil){
-		rport = frame->udp.rport[0]<<8 | frame->udp.rport[1];
-		lport = frame->udp.lport[0]<<8 | frame->udp.lport[1];
+		if(frame->id != ProtocolID)
+			goto discard;
 
-		switch(frame->type){
-		case NCinput:
-			unpack(frame->data, frame->len, "k", &kdown);
+		nc = getconn(frame);
+		if(nc == nil){
+			if(frame->type == NChi){
+				nc = newnetconn(NCSConnecting, &frame->udp);
+				putconn(nc);
+
+				newf = newframe(frame, NShi, 0, 0, 2*sizeof(ulong), nil);
+
+				dhgenpg(&nc->dh.p, &nc->dh.g);
+				pack(newf->data, newf->len, "kk", nc->dh.p, nc->dh.g);
+				sendp(egress, newf);
+
+				if(debug)
+					fprint(2, "\tsent p %lud g %lud\n", nc->dh.p, nc->dh.g);
+			}else
+				goto discard;
+		}
 
-			if(debug){
-				fprint(2, "%I!%d ← %I!%d | rcvd type %ud seq %ud ack %ud len %ud %.*lub\n",
-					frame->udp.laddr, lport, frame->udp.raddr, rport,
-					frame->type, frame->seq, frame->ack, frame->len,
-					sizeof(kdown)*8, kdown);
+		switch(nc->state){
+		case NCSConnecting:
+			switch(frame->type){
+			case NCdhx:
+				unpack(frame->data, frame->len, "k", &nc->dh.pub);
+				nc->state = NCSConnected;
+
+				if(debug)
+					fprint(2, "\trecvd pubkey %lud\n", nc->dh.pub);
+
+				newf = newframe(frame, NSdhx, 0, 0, sizeof(ulong), nil);
+
+				nc->dh.sec = truerand();
+				nc->dh.priv = dhgenkey(nc->dh.pub, nc->dh.sec, nc->dh.p);
+				pack(newf->data, newf->len, "k", dhgenkey(nc->dh.g, nc->dh.sec, nc->dh.p));
+				sendp(egress, newf);
+
+				if(debug)
+					fprint(2, "\tsent pubkey %lud\n", dhgenkey(nc->dh.g, nc->dh.sec, nc->dh.p));
+
+				break;
 			}
 			break;
-		}
+		case NCSConnected:
+			switch(frame->type){
+			case NCinput:
+				unpack(frame->data, frame->len, "k", &kdown);
 
+				if(debug)
+					fprint(2, "\t%.*lub\n", (int)(sizeof(kdown)*8), kdown);
+
+				break;
+			case NCbuhbye:
+				popconn(nc);
+				delnetconn(nc);
+				break;
+			}
+			break;
+		}
+discard:
 		free(frame);
 	}
 }
@@ -71,6 +170,7 @@
 {
 	uchar buf[MTU];
 	int fd, n;
+	ushort rport, lport;
 	Frame *frame;
 
 	threadsetname("threadnetsend");
@@ -79,9 +179,18 @@
 
 	while((frame = recvp(egress)) != nil){
 		n = pack(buf, sizeof buf, "F", frame);
-		free(frame);
 		if(write(fd, buf, n) != n)
 			sysfatal("write: %r");
+
+		if(debug){
+			rport = frame->udp.rport[0]<<8 | frame->udp.rport[1];
+			lport = frame->udp.lport[0]<<8 | frame->udp.lport[1];
+			fprint(2, "%I!%ud → %I!%ud | sent type %ud seq %ud ack %ud len %ud\n",
+				frame->udp.laddr, lport, frame->udp.raddr, rport,
+				frame->type, frame->seq, frame->ack, frame->len);
+		}
+
+		free(frame);
 	}
 }
 
@@ -148,19 +257,16 @@
 void
 fprintstats(int fd)
 {
-	ulong nparties = 0;
+	usize nparties = 0;
 	Party *p;
 
 	for(p = theparty.next; p != &theparty; p = p->next)
 		nparties++;
 
-//	fprint(fd, "curplayers	%lud\n"
-//		   "totplayers	%lud\n"
-//		   "maxplayers	%lud\n"
-//		   "curparties	%lud\n"
-//		   "totparties	%lud\n",
-//		lobby->nseats, 0UL, lobby->cap,
-//		nparties, 0UL);
+	fprint(fd, "curconns	%llud\n"
+		   "maxconns	%llud\n"
+		   "nparties	%llud\n",
+		(uvlong)nconns, (uvlong)maxconns, (uvlong)nparties);
 }
 
 void
@@ -172,10 +278,10 @@
 
 	for(p = theparty.next; p != &theparty; p = p->next, i++){
 		for(s = &p->u->ships[0]; s-p->u->ships < nelem(p->u->ships); s++){
-			fprint(fd, "%lud s%lld k%d p %v v %v θ %g ω %g m %g f %d\n",
+			fprint(fd, "%ld s%lld k%d p %v v %v θ %g ω %g m %g f %d\n",
 				i, s-p->u->ships, s->kind, s->p, s->v, s->θ, s->ω, s->mass, s->fuel);
 		}
-		fprint(fd, "%lud S p %v m %g\n", i, p->u->star.p, p->u->star.mass);
+		fprint(fd, "%ld S p %v m %g\n", i, p->u->star.p, p->u->star.mass);
 	}
 }
 
--- /dev/null
+++ b/net.c
@@ -1,0 +1,114 @@
+#include <u.h>
+#include <libc.h>
+#include <ip.h>
+#include <thread.h>
+#include <draw.h>
+#include <geometry.h>
+#include "dat.h"
+#include "fns.h"
+
+/* DHX */
+
+/*
+ * fill in the Diffie-Hellman modulus p and generator g.
+ * these are fixed toy values: with p = 97 there are at most
+ * 96 distinct keys, so the exchange offers no real secrecy.
+ */
+void
+dhgenpg(ulong *p, ulong *g)
+{
+	static ulong P = 97;
+	static ulong G = 71;
+
+	*p = P;
+	*g = G;
+}
+
+/*
+ * x = g^k mod p, by square-and-multiply.
+ * intermediate products are uvlong, so any 32-bit modulus
+ * works without overflow.  the client takes p off the wire,
+ * so a zero modulus yields 0 instead of a division by zero.
+ */
+ulong
+dhgenkey(ulong g, ulong k, ulong p)
+{
+	uvlong b, y;
+
+	if(p == 0)
+		return 0;
+
+	y = 1 % p;	/* 0 when p == 1 */
+	b = g % p;
+	while(k > 0){
+		if(k & 1)
+			y = y*b % p;
+		b = b*b % p;
+		k >>= 1;
+	}
+	return y;
+}
+
+/* NetConn */
+
+/*
+ * allocate a zeroed connection in state s.
+ * if u is not nil, the peer's udp header is copied from it.
+ */
+NetConn *
+newnetconn(NCState s, Udphdr *u)
+{
+	NetConn *nc;
+
+	nc = emalloc(sizeof(NetConn));
+	memset(nc, 0, sizeof(NetConn));
+	if(u != nil)
+		memmove(&nc->udp, u, Udphdrsize);
+	nc->state = s;
+
+	return nc;
+}
+
+void
+delnetconn(NetConn *nc)
+{
+	free(nc);
+}
+
+/* Frame */
+
+/*
+ * allocate a frame of the given type with room for len bytes of payload.
+ * if pf (the frame being answered) is not nil, its udp header is copied
+ * and seq/ack are derived from pf->seq; the seq and ack arguments are
+ * then ignored.  if data is not nil, len bytes are copied from it.
+ */
+Frame *
+newframe(Frame *pf, u8int type, u32int seq, u32int ack, u16int len, uchar *data)
+{
+	Frame *f;
+
+	f = emalloc(sizeof(Frame)+len);
+	f->id = ProtocolID;
+	f->type = type;
+	if(pf != nil){
+		memmove(&f->udp, &pf->udp, Udphdrsize);
+		f->seq = pf->seq+1;
+		f->ack = pf->seq;
+	}else{
+		memset(&f->udp, 0, Udphdrsize);
+		f->seq = seq;
+		f->ack = ack;
+	}
+	f->len = len;
+	if(data != nil)
+		memmove(f->data, data, f->len);
+
+	return f;
+}
+
+void
+delframe(Frame *f)
+{
+	free(f);
+}
--- a/pack.c
+++ b/pack.c
@@ -1,6 +1,5 @@
 #include <u.h>
 #include <libc.h>
-#include <pool.h>
 #include <ip.h>
 #include <draw.h>
 #include <geometry.h>
@@ -87,6 +86,7 @@
 			if(p+Framehdrsize+F->len > e)
 				goto err;
 
+			put4(p, F->id), p += 4;	/* ProtocolID magic leads every frame */
 			*p++ = F->type;
 			put4(p, F->seq), p += 4;
 			put4(p, F->ack), p += 4;
@@ -156,6 +156,7 @@
 			if(F == nil)
 				F = va_arg(a, Frame*);
 
+			F->id = get4(p), p += 4;	/* checked against ProtocolID by the receiver */
 			F->type = *p++;
 			F->seq = get4(p), p += 4;
 			F->ack = get4(p), p += 4;