shithub: riscv

Download patch

ref: 6b0574e27e6abffb2328be29e2cc9f3e67e2655b
parent: 917d0fa9b42845d2b918d13ab4e73761e866409a
author: Jacob Moody <moody@posixcafe.org>
date: Sat Jan 27 19:01:56 EST 2024

ndb/dns: DoT support

--- a/sys/man/6/ndb
+++ b/sys/man/6/ndb
@@ -232,7 +232,13 @@
 pairs.
 .TP
 .B dns
-a DNS server to use (for DNS and DHCP)
+a DNS server to use for resolving (for DNS and DHCP)
+.TP
+.B dot
+a DNS-over-TLS (DoT) server to use for resolving (for DNS).
+If any are present,
+.B dns
+entries are ignored.
 .TP
 .B ntp
 an NTP server to use (for DHCP)
--- a/sys/man/8/ndb
+++ b/sys/man/8/ndb
@@ -62,6 +62,9 @@
 .B -a
 .I maxage
 ] [
+.B -c
+.I cert.pem
+] [
 .B -f
 .I dbfile
 ] [
@@ -393,11 +396,18 @@
 to complete lookups.
 If present,
 .B /env/DNSSERVER
-must be a space-separated list of such DNS servers' IP addresses,
+or
+.B /env/DOTSERVER
+must be a space-separated list of such DNS (or DoT) servers' IP addresses,
 otherwise optional
 .IR ndb (6)
 .B dns
 attributes name DNS servers to forward queries to.
+Note that when
+.B DOTSERVER
+is specified,
+.B DNSSERVER
+is ignored.
 .TP
 .B -R
 ignore the `recursive' bit on all incoming requests.
@@ -421,6 +431,12 @@
 are given,
 listen on any interface on network mount point
 .IR netmtpt .
+.TP
+.B -c
+When a certificate
+.I cert.pem
+is specified, also listen on TCP port 853 and handle
+DNS requests over TLS.
 .TP
 .B -x
 specifies the mount point of the network.
--- a/sys/src/cmd/ndb/dblookup.c
+++ b/sys/src/cmd/ndb/dblookup.c
@@ -916,7 +916,7 @@
 
 	/* check duplicate ip */
 	for(n = 0; n < i; n++){
-		snprint(buf, sizeof buf, "local#dns#server%d", n);
+		snprint(buf, sizeof buf, "%s#%d", dp->name, n);
 		nsdp = dnlookup(buf, class, 0);
 		if(nsdp == nil)
 			continue;
@@ -931,7 +931,7 @@
 		rrfreelist(rp);
 	}
 
-	snprint(buf, sizeof buf, "local#dns#server%d", i);
+	snprint(buf, sizeof buf, "%s#%d", dp->name, i);
 	nsdp = dnlookup(buf, class, 1);
 
 	/* ns record for name server, make up an impossible name */
@@ -967,6 +967,33 @@
 	RR *nsrp;
 	DN *dp;
 
+	/* try DoT servers first */
+	dp = dnlookup("local#dot#servers", class, 1);
+	nsrp = rrlookup(dp, Tns, NOneg);
+	if(nsrp != nil)
+		return nsrp;
+
+	p = getenv("DOTSERVER");		/* list of ip addresses */
+	if(p != nil && (n = tokenize(p, args, nelem(args))) > 0){
+		for(i = 0; i < n; i++)
+			addlocaldnsserver(dp, class, args[i], i);
+	} else {
+		/* having no DoT servers is not an error: fall back to dns below */
+		t = lookupinfo("@dot");		/* @dot=ip1 ... */
+		i = 0;
+		for(nt = t; nt != nil; nt = nt->entry){
+			addlocaldnsserver(dp, class, nt->val, i);
+			i++;
+		}
+		if(t != nil)
+			ndbfree(t);
+	}
+	free(p);
+	nsrp = rrlookup(dp, Tns, NOneg);
+	if(nsrp != nil)
+		return nsrp;
+
+	/* no usable DoT server: try regular local DNS servers */
 	dp = dnlookup("local#dns#servers", class, 1);
 	nsrp = rrlookup(dp, Tns, NOneg);
 	if(nsrp != nil)
--- a/sys/src/cmd/ndb/dnresolve.c
+++ b/sys/src/cmd/ndb/dnresolve.c
@@ -5,6 +5,8 @@
 #include <libc.h>
 #include <ip.h>
 #include <bio.h>
+#include <mp.h>
+#include <libsec.h>
 #include <ndb.h>
 #include "dns.h"
 
@@ -79,7 +81,7 @@
 	return strdup(lp+1);
 }
 
-void
+static void
 rrfreelistptr(RR **rpp)
 {
 	RR *rp;
@@ -267,9 +269,13 @@
 	 */
 	if(cfg.resolver){
 		nsrp = randomize(getdnsservers(class));
-		if(nsrp != nil)
+		if(nsrp != nil){
+			int dot = strncmp(nsrp->owner->name, "local#dot#server", 16) == 0 || strncmp(nsrp->owner->name, "override#dot#server", 19) == 0;
 			if(netqueryns(qp, nsrp) > Answnone)
 				return rrlookup(qp->dp, qp->type, OKneg);
+			else if(dot)
+				return nil;	/* never fall back to plaintext dns after a DoT failure */
+		}
 	}
 
 	/*
@@ -733,7 +739,7 @@
 /*
  *	return non-0 if first list includes second list
  */
-int
+static int
 contains(RR *rp1, RR *rp2)
 {
 	RR *trp1, *trp2;
@@ -753,7 +759,7 @@
 /*
  *  return multicast version if any
  */
-int
+static int
 ipisbm(uchar *ip)
 {
 	if(isv4(ip)){
@@ -1166,16 +1172,31 @@
 	return rv;
 }
 
+enum {
+	Maxfree = 4,	/* idle tcp/tls connections cached for reuse */
+};
+
+struct {
+	QLock lk;
+	struct {
+		uvlong when;	/* nowms when the connection was cached */
+		char *dest;	/* dial string; nil until the slot is first used */
+		int fd;		/* -1 once taken out for reuse */
+	} l[Maxfree];	/* most recently cached first */
+} tcpfree;
+
 /*
  * send a query via tcp to a single address
  * and read the answer(s) into mp->an.
  */
 static int
-tcpquery(Query *qp, uchar *pkt, int len, Dest *p, uvlong endms, DNSmsg *mp)
+tcpquery(Query *qp, uchar *pkt, int len, Dest *p, uvlong endms, DNSmsg *mp, int tls)
 {
 	char buf[NETPATHLEN];
-	int fd, rv;
+	int fd, rv, i, retry;
 	long ms;
+	TLSconn conn;
+	Thumbprint *thumb;
 
 	memset(mp, 0, sizeof *mp);
 
@@ -1185,28 +1206,125 @@
 	if(ms > Maxtcpdialtm)
 		ms = Maxtcpdialtm;
 
-	procsetname("tcp query to %I/%s for %s %s", p->a, p->s->name,
+	procsetname("%s query to %I/%s for %s %s", tls ? "tls" : "tcp", p->a, p->s->name,
 		qp->dp->name, rrname(qp->type, buf, sizeof buf));
 
-	snprint(buf, sizeof buf, "%s/tcp!%I!53", mntpt, p->a);
+	snprint(buf, sizeof buf, "%s/tcp!%I!%s", mntpt, p->a, tls ? "853" : "53");
 
+	fd = -1;
+	retry = 0;
+	qlock(&tcpfree.lk);
+	for(i = 0; i < nelem(tcpfree.l); i++){
+		if(tcpfree.l[i].dest == nil || tcpfree.l[i].fd == -1)
+			continue;
+		if(strcmp(tcpfree.l[i].dest, buf) != 0)
+			continue;
+		/* reuse only if cached within 5s; RFC does not specify a reuse timeout */
+		if(nowms - tcpfree.l[i].when < 5000){
+			fd = tcpfree.l[i].fd;
+			tcpfree.l[i].fd = -1;
+			retry++;	/* cached conn may be stale: allow one fresh redial */
+			break;
+		}
+	}
+	qunlock(&tcpfree.lk);
+	if(fd != -1)
+		goto Found;
+
+Retry:
 	alarm(ms);
 	fd = dial(buf, nil, nil, nil);
-	alarm(0);
-	if (fd < 0) {
+	if(fd < 0){
+		alarm(0);
 		dnslog("%d: can't dial %s for %I/%s: %r",
 			qp->req->id, buf, p->a, p->s->name);
 		return -1;
 	}
+	if(tls){
+		memset(&conn, 0, sizeof conn);
+		rv = tlsClient(fd, &conn);
+		alarm(0);
+		if(rv >= 0){
+			fd = rv;
+			thumb = initThumbprints("/sys/lib/tls/dns", nil, "x509");	/* only certs listed here are trusted */
+			if(thumb == nil || !okCertificate(conn.cert, conn.certlen, thumb)){
+				dnslog("%d: invalid fingerprint for %s; echo 'x509 %r' >>/sys/lib/tls/dns",
+					qp->req->id, buf);
+				rv = -1;
+			}
+			free(conn.cert);
+			free(conn.sessionID);
+			freeThumbprints(thumb);
+		}
+		if(rv < 0){
+			close(fd);
+			return -1;
+		}
+	} else {
+		alarm(0);
+	}
+
+Found:
 	rv = writenet(qp, Tcp, fd, pkt, len, p);
 	if(rv == 0){
 		timems();	/* account for time dialing and sending */
 		rv = readreply(qp, Tcp, fd, endms, mp, pkt);
 	}
-	close(fd);
+
+	if(rv < 0){
+		close(fd);
+		if(retry){
+			retry = 0;
+			goto Retry;
+		}
+		return rv;
+	}
+
+	qlock(&tcpfree.lk);	/* keep fd open for reuse, evicting the oldest entry */
+	if(tcpfree.l[nelem(tcpfree.l)-1].dest != nil){
+		close(tcpfree.l[nelem(tcpfree.l)-1].fd);
+		free(tcpfree.l[nelem(tcpfree.l)-1].dest);
+	}
+	memmove(tcpfree.l + 1, tcpfree.l, sizeof(tcpfree.l[0])*(nelem(tcpfree.l)-1));
+	tcpfree.l[0].when = nowms;
+	tcpfree.l[0].fd = fd;
+	tcpfree.l[0].dest = estrdup(buf);
+	qunlock(&tcpfree.lk);
+
 	return rv;
 }
 
+static int
+tlsqueryns(Query *qp, uchar *pkt, int len)	/* send pkt over tls to each server until one gives a useful answer */
+{
+	Dest dest[Maxdest], *p;
+	int rv, n;
+	uvlong endms;
+	DNSmsg m;
+
+	/* populates dest with v4 and v6 addresses. */
+	n = 0;
+	n = serveraddrs(qp, dest, n, Ta);
+	n = serveraddrs(qp, dest, n, Taaaa);
+	endms = nowms + 500;	/* NOTE(review): one 500ms budget shared by dial, tls handshake and reply for all servers; confirm */
+	for(p = dest; p < dest+n; p++){
+		if(tcpquery(qp, pkt, len, p, endms, &m, 1) == 0){
+			/* free or incorporate RRs in m */
+			rv = procansw(qp, p, &m);
+			if(rv > Answnone)
+				return rv;
+		}
+	}
+
+	/* if all servers returned failure, propagate it */
+	qp->dp->respcode = Rserver;
+	for(p = dest; p < dest+n; p++)
+		if(p->code != Rserver)
+			qp->dp->respcode = Rok;
+
+	return Answnone;
+}
+
 /*
  *  query name servers.  fill in pkt with on-the-wire representation of a
  *  DNSmsg derived from qp. if the name server returns a pointer to another
@@ -1213,30 +1331,15 @@
  *  name server, recurse.
  */
 static int
-udpqueryns(Query *qp, int fd, uchar *pkt)
+udpqueryns(Query *qp, int fd, uchar *pkt, int len)
 {
 	Dest dest[Maxdest], *edest, *p, *np;
-	int ndest, replywaits, len, flag, rv, n;
+	int ndest, replywaits, rv, n;
 	uchar srcip[IPaddrlen];
 	char buf[32];
 	uvlong endms;
 	DNSmsg m;
-	RR *rp;
 
-	/* prepare server RR's for incremental lookup */
-	for(rp = qp->nsrp; rp; rp = rp->next)
-		rp->marker = 0;
-
-	/* request recursion only for local/override dns servers */
-	flag = Oquery;
-	if(strncmp(qp->nsrp->owner->name, "local#", 6) == 0
-	|| strncmp(qp->nsrp->owner->name, "override#", 9) == 0)
-		flag |= Frecurse;
-
-	/* pack request into a udp message */
-	qp->id = rand();
-	len = mkreq(qp->dp, qp->type, pkt, flag, qp->id);
-
 	/* no destination yet */
 	edest = dest;
 
@@ -1307,7 +1410,7 @@
 			/* if response was truncated, try tcp */
 			if(m.flags & Ftrunc){
 				freeanswers(&m);
-				if(tcpquery(qp, pkt, len, p, endms, &m) < 0)
+				if(tcpquery(qp, pkt, len, p, endms, &m, 0) < 0)
 					break;	/* failed via tcp too */
 				if(m.flags & Ftrunc){
 					freeanswers(&m);
@@ -1336,27 +1439,44 @@
 	return Answnone;
 }
 
-/*
- * in principle we could use a single descriptor for a udp port
- * to send all queries and receive all the answers to them,
- * but we'd have to sort out the answers by dns-query id.
- */
 static int
-udpquery(Query *qp)
+doquery(Query *qp)
 {
-	int fd, rv;
+	int fd, rv, len, flag;
 	uchar *pkt;
+	RR *rp;
 
 	pkt = emalloc(Maxudp+Udphdrsize);
-	fd = udpport(mntpt);
-	if (fd < 0) {
-		dnslog("%d: can't get udpport for %s query of name %s: %r",
-			qp->req->id, mntpt, qp->dp->name);
-		rv = -1;
-		goto Out;
+	/* prepare server RR's for incremental lookup */
+	for(rp = qp->nsrp; rp; rp = rp->next)
+		rp->marker = 0;
+	/* request recursion only for local/override dns servers */
+	flag = Oquery;
+	if(strncmp(qp->nsrp->owner->name, "local#", 6) == 0
+	|| strncmp(qp->nsrp->owner->name, "override#", 9) == 0)
+		flag |= Frecurse;
+	/* pack the request once; the same packet is sent over udp, tcp or tls */
+	qp->id = rand();
+	len = mkreq(qp->dp, qp->type, pkt, flag, qp->id);
+	if(strncmp(qp->nsrp->owner->name, "local#dot#server", 16) == 0
+	|| strncmp(qp->nsrp->owner->name, "override#dot#server", 19) == 0){
+		rv = tlsqueryns(qp, pkt, len);	/* DoT servers are only ever queried over tls */
+	} else {
+		/*
+		 * in principle we could use a single descriptor for a udp port
+		 * to send all queries and receive all the answers to them,
+		 * but we'd have to sort out the answers by dns-query id.
+		 */
+		fd = udpport(mntpt);
+		if (fd < 0) {
+			dnslog("%d: can't get udpport for %s query of name %s: %r",
+				qp->req->id, mntpt, qp->dp->name);
+			rv = -1;
+			goto Out;
+		}
+		rv = udpqueryns(qp, fd, pkt, len);
+		close(fd);
 	}
-	rv = udpqueryns(qp, fd, pkt);
-	close(fd);
 Out:
 	free(pkt);
 	return rv;
@@ -1383,5 +1503,5 @@
 	if(!qp->req->isslave && strcmp(qp->req->from, "9p") == 0)
 		return Answnone;
 
-	return udpquery(qp);
+	return doquery(qp);
 }
--- a/sys/src/cmd/ndb/dns.c
+++ b/sys/src/cmd/ndb/dns.c
@@ -92,7 +92,7 @@
 void
 usage(void)
 {
-	fprint(2, "usage: %s [-FnrLR] [-a maxage] [-f ndb-file] [-N target] "
+	fprint(2, "usage: %s [-FnrLR] [-a maxage] [-c cert.pem] [-f ndb-file] [-N target] "
 		"[-x netmtpt] [-s [addrs...]]\n", argv0);
 	exits("usage");
 }
@@ -101,10 +101,12 @@
 main(int argc, char *argv[])
 {
 	char ext[Maxpath], servefile[Maxpath];
+	char *cert;	/* -c: certificate chain for the DoT listener, or nil */
 	Dir *dir;
 
 	setnetmtpt(mntpt, sizeof mntpt, nil);
 	ext[0] = 0;
+	cert = nil;
 	ARGBEGIN{
 	case 'a':
 		maxage = atol(EARGF(usage()));
@@ -141,6 +143,9 @@
 		cfg.serve = 1;		/* serve network */
 		cfg.cachedb = 1;
 		break;
+	case 'c':
+		cert = EARGF(usage());
+		break;
 	case 'x':
 		setnetmtpt(mntpt, sizeof mntpt, EARGF(usage()));
 		setext(ext, sizeof ext, mntpt);
@@ -181,11 +186,15 @@
 	if(cfg.serve){
 		if(argc == 0) {
 			dnudpserver(mntpt, "*");
-			dntcpserver(mntpt, "*");
+			dntcpserver(mntpt, "*", nil);
+			if(cert != nil)	/* additional DoT listener on tcp port 853 */
+				dntcpserver(mntpt, "*", cert);
 		} else {
 			while(argc-- > 0){
 				dnudpserver(mntpt, *argv);
-				dntcpserver(mntpt, *argv);
+				dntcpserver(mntpt, *argv, nil);
+				if(cert != nil)	/* additional DoT listener on tcp port 853 */
+					dntcpserver(mntpt, *argv, cert);
 				argv++;
 			}
 		}
--- a/sys/src/cmd/ndb/dns.h
+++ b/sys/src/cmd/ndb/dns.h
@@ -522,7 +522,7 @@
 void	dnudpserver(char*, char*);
 
 /* dntcpserver.c */
-void	dntcpserver(char*, char*);
+void	dntcpserver(char*, char*, char*);	/* mntpt, addr, cert (nil: plain tcp on 53) */
 
 /* dnnotify.c */
 void	dnnotify(DNSmsg*, DNSmsg*, Request*);
--- a/sys/src/cmd/ndb/dnsdebug.c
+++ b/sys/src/cmd/ndb/dnsdebug.c
@@ -172,14 +172,18 @@
 {
 	uchar ip[IPaddrlen];
 	DN *nsdp;
-	RR *rp;
+	RR *rp, *ns;
+	char name[64];	/* "override#dot#server" or "override#dns#server" */
 
 	if(servername == nil)
 		return dnsservers(class);
-	if(parseip(ip, servername) == -1){
-		nsdp = idnlookup(servername, class, 1);
+	/* servername keeps its prefix: '!' selects a DoT server, otherwise plain dns */
+	snprint(name, sizeof name, "override#%s#server", servername[0] == '!' ? "dot" : "dns");
+	ns = rralloc(Tns);
+	if(parseip(ip, servername+1) == -1){
+		nsdp = idnlookup(servername+1, class, 1);
 	} else {
-		nsdp = dnlookup("local#dns#server", class, 1);
+		nsdp = dnlookup(name, class, 1);
 		rp = rralloc(isv4(ip) ? Ta : Taaaa);
 		rp->owner = nsdp;
 		rp->ip = ipalookup(ip, class, 1);
@@ -187,10 +191,9 @@
 		rp->ttl = 10*Min;
 		rrattach(rp, Authoritative);
 	}
-	rp = rralloc(Tns);
-	rp->owner = dnlookup("override#dns#servers", class, 1);
-	rp->host = nsdp;
-	return rp;
+	ns->owner = dnlookup(name, class, 1);
+	ns->host = nsdp;
+	return ns;
 }
 
 int
@@ -201,7 +204,7 @@
 		servername = nil;
 		cfg.resolver = 0;
 	}
-	if(server == nil || *server == 0)
+	if(server == nil || server[0] == 0 || server[1] == 0)	/* "" or a bare "@"/"!": no override */
 		return 0;
 	servername = estrdup(server);
 	cfg.resolver = 1;
@@ -276,8 +279,8 @@
 	name = type = nil;
 	tmpsrv = 0;
 
-	if(*f[0] == '@') {
-		if(setserver(f[0]+1) < 0)
+	if(*f[0] == '@' || *f[0] == '!') {
+		if(setserver(f[0]) < 0)
 			return;
 
 		switch(n){
@@ -306,5 +309,5 @@
 	doquery(name, type);
 
 	if(tmpsrv)
-		setserver("");
+		setserver("@");
 }
--- a/sys/src/cmd/ndb/dntcpserver.c
+++ b/sys/src/cmd/ndb/dntcpserver.c
@@ -3,6 +3,8 @@
 #include <bio.h>
 #include <ndb.h>
 #include <ip.h>
+#include <mp.h>
+#include <libsec.h>
 #include "dns.h"
 
 enum {
@@ -12,10 +14,10 @@
 static int	readmsg(int, uchar*, int);
 static int	reply(int, uchar *, DNSmsg*, Request*, uchar*);
 static int	dnzone(int, uchar *, DNSmsg*, DNSmsg*, Request*, uchar*);
-static int	tcpannounce(char *mntpt, char *addr, char caller[128]);
+static int	tcpannounce(char *mntpt, char *addr, char caller[128], char *cert);
 
 void
-dntcpserver(char *mntpt, char *addr)
+dntcpserver(char *mntpt, char *addr, char *cert)
 {
 	volatile int fd, len, rcode, rv;
 	volatile long ms;
@@ -40,7 +42,7 @@
 	}
 
 	procsetname("%s: tcp server %s", mntpt, addr);
-	if((fd = tcpannounce(mntpt, addr, caller)) < 0){
+	if((fd = tcpannounce(mntpt, addr, caller, cert)) < 0){
 		warning("can't announce %s on %s: %r", addr, mntpt);
 		_exits(0);
 	}
@@ -259,13 +261,20 @@
 }
 
 static int
-tcpannounce(char *mntpt, char *addr, char caller[128])
+tcpannounce(char *mntpt, char *addr, char caller[128], char *cert)
 {
 	char adir[NETPATHLEN], ldir[NETPATHLEN], buf[128];
 	int acfd, lcfd, dfd, wfd, rfd, procs;
+	PEMChain *chain = nil;	/* non-nil: serve DoT on port 853 instead of plain tcp on 53 */
 
+	if(cert != nil){
+		chain = readcertchain(cert);
+		if(chain == nil)
+			return -1;
+	}
+
 	/* announce tcp dns port */
-	snprint(buf, sizeof(buf), "%s/tcp!%s!53", mntpt, addr);
+	snprint(buf, sizeof(buf), "%s/tcp!%s!%s", mntpt, addr, cert == nil ? "53" : "853");
 	acfd = announce(buf, adir);
 	if(acfd < 0)
 		return -1;
@@ -277,7 +286,6 @@
 		close(acfd);
 		return -1;
 	}
-
 	procs = 0;
 	for(;;) {
 		if(procs >= Maxprocs || (procs % 8) == 0){
@@ -314,7 +322,23 @@
 			close(lcfd);
 			if(dfd < 0)
 				_exits(0);
+			if(chain != nil){	/* DoT: tls handshake before any dns message is read */
+				TLSconn conn;
+				int fd;
 
+				memset(&conn, 0, sizeof conn);
+				conn.cert = emalloc(conn.certlen = chain->pemlen);
+				memmove(conn.cert, chain->pem, conn.certlen);
+				conn.chain = chain->next;	/* rest of the chain after the first cert */
+				fd = tlsServer(dfd, &conn);
+				if(fd < 0){
+					close(dfd);
+					_exits(0);
+				}
+				free(conn.cert);
+				free(conn.sessionID);
+				dfd = fd;
+			}
 			/* get the callers ip!port */
 			memset(caller, 0, 128);
 			snprint(buf, sizeof(buf), "%s/remote", ldir);
@@ -322,7 +346,6 @@
 				read(rfd, caller, 128-1);
 				close(rfd);
 			}
-
 			/* child returns */
 			return dfd;
 		default: