ref: 6cb7a53ec0671722fc0f95f728f6f0556b644f75
dir: /libstd/resolve.myr/
use "alloc.use"
use "chartype.use"
use "die.use"
use "endian.use"
use "error.use"
use "extremum.use"
use "fmt.use"
use "hashfuncs.use"
use "htab.use"
use "ipparse.use"
use "option.use"
use "slcp.use"
use "sleq.use"
use "slpush.use"
use "slurp.use"
use "strfind.use"
use "strsplit.use"
use "strstrip.use"
use "sys.use"
use "types.use"
use "utf.use"
pkg std =
/*
 Reasons a lookup can fail. In the code visible in this file, `Badhost is
 returned for a name that fails validation, `Badsrv when no nameserver
 could be contacted, and `Badresp for a reply that cannot be accepted
 (for example, its id does not match the query). `Badquery is not
 produced by any code visible here.
*/
type resolveerr = union
`Badhost
`Badsrv
`Badquery
`Badresp
;;
/*
 One resolved address for a host. Entries loaded from the hosts file
 carry stype = 0 and ttl = 0; entries parsed from a DNS reply carry the
 ttl of the answer record.
*/
type hostinfo = struct
fam : sockfam /* address family of addr */
stype : socktype /* socket type; only ever set to 0 in this file */
ttl : uint32 /* time to live taken from the answer record, 0 for hosts-file entries */
addr : netaddr /* the resolved address */
/*
proto : uint32
flags : uint32
addr : sockaddr[:]
canon : byte[:]
next : hostinfo#
*/
;;
/*
 Looks up `host`, consulting the hosts file first and falling back to a
 DNS query. Returns a slice of hostinfo records on success.
*/
const resolve : (host : byte[:] -> error(hostinfo[:], resolveerr))
;;
/* static host table, read by loadhosts */
const Hostfile = "/etc/hosts"
/* resolver configuration, read by loadresolv */
const Resolvfile = "/etc/resolv.conf"
/* host name -> address table; created in hostfind and filled by addhosts */
var hostmap : htab(byte[:], hostinfo)#
/* not referenced by the code visible here; presumably a search-domain list -- TODO confirm */
var search : byte[:][:]
/* nameserver addresses collected by addns, tried in order by dnsresolve */
var nameservers : netaddr[:]
/* set to true by hostfind once the two files above have been loaded */
var inited : bool = false
/*
 Resolves `host` to a list of addresses. The local host table is
 consulted first; a hit is returned as a one-element slice. On a miss
 the lookup is handed to dnsresolve.
*/
const resolve = {host : byte[:] -> error(hostinfo[:], resolveerr)
	var found

	found = hostfind(host)
	match found
	| `None:
		put("********** Couldn't find host %s in hosts\n", host)
		-> dnsresolve(host)
	| `Some hinf:
		-> `Success slpush([][:], hinf)
	;;
}
/*
 Looks `host` up in the host table. The first call creates the table and
 loads both the hosts file and the resolver configuration; later calls go
 straight to the table.
*/
const hostfind = {host
	if inited
		-> htget(hostmap, host)
	;;

	/* first use: build the table and read the configuration files */
	hostmap = mkht(strhash, streq)
	loadhosts()
	loadresolv()
	inited = true
	-> htget(hostmap, host)
}
/*
 Reads Hostfile and registers every "address name..." line in hostmap.
 Text after a '#' is ignored, as are blank lines and lines whose first
 word does not parse as an address. If the file cannot be read the
 function returns without touching the table.

 The file contents are deliberately not freed: the names handed to
 addhosts are taken from the line text and end up as keys in hostmap.
*/
const loadhosts = {
	var h
	var lines

	match slurp(Hostfile)
	| `Success d:	h = d
	| `Failure m:	->
	;;

	lines = strsplit(h, "\n")
	for l in lines
		/* trim comment; lines without one are used whole */
		match strfind(l, "#")
		| `Some idx:	l = l[:idx]
		| `None:
		;;

		match word(l)
		| `Some (ip, rest):
			match ipparse(ip)
			| `Some addr:
				addhosts(addr, ip, rest)
			| `None:
				/* first word is not an address: skip the line */
			;;
		| `None:
			/* blank line */
		;;
	;;
	slfree(lines)
}
/*
 Registers every whitespace separated name in `str` as an alias for
 `addr` in hostmap. A name that is already present keeps its existing
 entry. `as` is accepted but not used.
*/
const addhosts = {addr, as, str
	var entry
	var family

	match addr
	| `Ipv4 _:	family = Afinet
	| `Ipv6 _:	family = Afinet6
	;;

	/* every alias on the line shares the same record */
	entry = [
		.fam = family,
		.stype = 0,
		.ttl = 0,
		.addr = addr
	]

	while true
		match word(str)
		| `None:
			->
		| `Some (name, rest):
			if !hthas(hostmap, name)
				htput(hostmap, name, entry)
			;;
			str = rest
		;;
	;;
}
/*
 Reads Resolvfile and hands the argument of every "nameserver" line to
 addns. Text after a '#' is ignored, as are blank lines and lines with
 other keywords. If the file cannot be read the function returns without
 adding any nameserver.
*/
const loadresolv = {
	var h
	var lines

	match slurp(Resolvfile)
	| `Success d:	h = d
	| `Failure m:	->
	;;

	lines = strsplit(h, "\n")
	for l in lines
		match strfind(l, "#")
		| `Some idx:	l = l[:idx]
		| `None:
		;;

		match word(l)
		| `Some (cmd, rest):
			if sleq(cmd, "nameserver")
				addns(rest)
			;;
		| `None:
			/* blank line */
		;;
	;;
	slfree(lines)
	/* addns stores parsed addresses, not text, so the file buffer can go */
	slfree(h)
}
/*
 Parses the first word of `rest` as an address and appends it to
 nameservers. An empty argument, or one that does not parse as an
 address, is ignored.
*/
const addns = {rest
	match word(rest)
	| `Some (name, _):
		match ipparse(name)
		| `Some addr:
			put("Adding nameserver %s\n", name)
			nameservers = slpush(nameservers, addr)
		| `None:
			/* not an address */
		;;
	| `None:
		/* "nameserver" with no argument */
	;;
}
/*
 Splits the first word off `s`. Leading whitespace is stripped with
 strfstrip, then characters are consumed until a blank or an undecodable
 character is reached. Returns `Some (word, remainder), or `None when no
 word is found.
*/
const word = {s
	var ch, n

	s = strfstrip(s)
	n = 0
	ch = decode(s[n:])
	while ch != Badchar && !isblank(ch)
		n += charlen(ch)
		ch = decode(s[n:])
	;;

	if n != 0
		-> `Some (s[:n], s[n:])
	;;
	-> `None
}
/*
 Resolves `host` over DNS. The name is validated first (`Badhost on
 failure); then the configured nameservers are tried in order and the
 first one that can be connected to is queried. The socket opened for
 the query is closed before the result is returned. If no nameserver can
 be connected to, `Badsrv is returned.
*/
const dnsresolve = {host : byte[:]
	var nsrv
	var r

	if !valid(host)
		-> `Failure (`Badhost)
	;;

	for ns in nameservers
		put("trying ns\n")
		nsrv = dnsconnect(ns)
		if nsrv >= 0
			r = dnsquery(nsrv, host)
			/* one socket per lookup: release it once the reply is parsed */
			close(nsrv)
			-> r
		;;
	;;
	-> `Failure (`Badsrv)
}
/*
 Opens a socket to the nameserver `ns`. Returns the socket on success
 and a negative value on failure, so the caller can move on to the next
 nameserver. IPv6 nameservers are not supported and are reported as a
 failure.
*/
const dnsconnect = {ns
	match ns
	| `Ipv4 addr:
		-> dnsconnectv4(addr)
	| `Ipv6 _:
		put("Warning: don't support ipv6 yet\n")
		-> -1
	;;
}

/*
 Creates a datagram socket and connects it to port 53 of the IPv4
 address `addr` (the four address bytes). Returns the socket, or -1 if
 either the socket or the connection could not be set up; the socket is
 closed again when connecting fails.
*/
const dnsconnectv4 = {addr
	var sa : sockaddr_in
	var s
	var status

	s = socket(Afinet, Sockdgram, 0)
	if s < 0
		put("Warning: Failed to open socket: %l\n", s)
		-> -1
	;;

	sa.fam = Afinet
	sa.port = hosttonet(53)	/* port 53 */
	sa.addr = addr		/* the nameserver read from the resolver configuration */
	status = connect(s, (&sa) castto(sockaddr#), sizeof(sockaddr_in))
	if status < 0
		put("Warning: Failed to connect to server: %l\n", status)
		close(s)
		-> -1
	;;
	-> s
}
/*
 Sends a query for `host` on the connected socket `srv`, waits for the
 matching reply, and returns the parsed result.
*/
const dnsquery = {srv, host
	var qid
	var result

	qid = tquery(srv, host)
	result = rquery(srv, qid)
	put("Got hosts. Returning\n")
	-> result
}
/*
 Masks for the flag word of a DNS header, expressed on the 16 bit value
 that pack16 writes most significant byte first. RFC 1035 numbers the
 header bits from the most significant end (QR is bit 0), so RFC bit n
 corresponds to 1 << (15 - n).
*/
const Qr : uint16 = 1 << 15	/* message is a response */
const Aa : uint16 = 1 << 10	/* authoritative answer */
const Tc : uint16 = 1 << 9	/* message was truncated */
const Rd : uint16 = 1 << 8	/* recursion desired */
const Ra : uint16 = 1 << 7	/* recursion available */

/* id placed in the next outgoing query */
var nextid : uint16 = 42

/*
 Builds a query for the A record of `host` and writes it to the socket
 `srv`. Returns the id that was placed in the query so the reply can be
 matched against it; nextid is advanced for the following query.
*/
const tquery = {srv, host
	var pkt : byte[512] /* big enough */
	var off : size

	put("Sending request for %s\n", host)

	/* header */
	off = 0
	off += pack16(pkt[:], off, nextid)	/* id */
	off += pack16(pkt[:], off, Rd)		/* flags: ask the server to recurse */
	off += pack16(pkt[:], off, 1)		/* qdcount */
	off += pack16(pkt[:], off, 0)		/* ancount */
	off += pack16(pkt[:], off, 0)		/* nscount */
	off += pack16(pkt[:], off, 0)		/* arcount */

	/* query */
	off += packname(pkt[:], off, host)	/* host */
	off += pack16(pkt[:], off, 0x1)		/* qtype: a record */
	off += pack16(pkt[:], off, 0x1)		/* qclass: inet4 */

	write(srv, pkt[:off])
	/* post-increment: hands back the id that was just sent */
	-> nextid++
}
/*
 Reads one reply from the socket `srv`, dumps it for debugging, and
 parses it with hosts, which checks it against the query id `id`.
 Returns `Badresp if the read fails or if the reply is too short to
 contain the 12 byte DNS header.
*/
const rquery = {srv, id
	var pktbuf : byte[1024]
	var pkt
	var n

	put("Waiting for response...\n")
	n = read(srv, pktbuf[:])
	if n < 0
		put("Warning: Failed to read from %z: %i\n", srv, n)
		/* n is negative here, so it must not be used as a slice bound */
		-> `Failure (`Badresp)
	;;
	/* both the dump and the parser read the whole fixed header unconditionally */
	if n < 12
		-> `Failure (`Badresp)
	;;

	pkt = pktbuf[:n]
	put("Got response:\n")
	dumpresponse(pkt)
	-> hosts(pkt, id)
}
/*
 Parses the DNS reply `pkt` and returns one hostinfo per IPv4 address
 record in its answer section. Answer records of any other type (for
 example CNAME) are skipped using their declared data length.

 Returns `Badresp when the packet is shorter than the 12 byte header,
 when its id differs from `id`, or when an answer record claims more
 data than the packet holds.
*/
const hosts = {pkt, id : uint16
	var off
	var v, q, a
	var rtype, rdlen, ttl
	var i
	var hinf : hostinfo[:]
	var info : hostinfo

	if pkt.len < 12
		-> `Failure (`Badresp)
	;;

	off = 0
	/* parse header */
	(v, off) = unpack16(pkt, off)	/* id */
	if v != id
		-> `Failure (`Badresp)
	;;
	put("Unpacking flags")
	(v, off) = unpack16(pkt, off)	/* flags */
	(q, off) = unpack16(pkt, off)	/* qdcount */
	(a, off) = unpack16(pkt, off)	/* ancount */
	(v, off) = unpack16(pkt, off)	/* nscount */
	(v, off) = unpack16(pkt, off)	/* arcount */

	/* skip past query records */
	for i = 0; i < q; i++
		put("Skipping query record")
		off = skipname(pkt, off)	/* name */
		(v, off) = unpack16(pkt, off)	/* type */
		(v, off) = unpack16(pkt, off)	/* class */
	;;

	/* parse answer records */
	hinf = [][:]
	for i = 0; i < a; i++
		off = skipname(pkt, off)		/* name */
		(rtype, off) = unpack16(pkt, off)	/* type */
		(v, off) = unpack16(pkt, off)		/* class */
		(ttl, off) = unpack32(pkt, off)		/* ttl */
		(rdlen, off) = unpack16(pkt, off)	/* rdatalen */

		/* the record data must lie inside the packet */
		if off + (rdlen castto(size)) > pkt.len
			if hinf.len > 0
				slfree(hinf)
			;;
			-> `Failure (`Badresp)
		;;

		/* the thing we're interested in: an A record carrying 4 address bytes */
		if rtype == 0x1 && rdlen == 4
			info = [
				.fam = Afinet,
				.stype = 0,
				.ttl = ttl,
				.addr = `Ipv4 [pkt[off], pkt[off+1], pkt[off+2], pkt[off+3]]
			]
			hinf = slpush(hinf, info)
		;;
		/* step over the data whatever its type was */
		off += (rdlen castto(size))
	;;
	-> `Success hinf
}
/*
 Debugging aid: prints the record counts from the header of the DNS
 message `pkt`, followed by each query record and each answer record.
*/
const dumpresponse = {pkt
	var ident, flags
	var qdcount, ancount, nscount, arcount
	var pos
	var n

	/* the id and the flag word are read but not printed */
	(ident, pos) = unpack16(pkt, 0)
	(flags, pos) = unpack16(pkt, pos)

	(qdcount, pos) = unpack16(pkt, pos)
	put("hdr.qdcount = %w\n", qdcount)
	(ancount, pos) = unpack16(pkt, pos)
	put("hdr.ancount = %w\n", ancount)
	(nscount, pos) = unpack16(pkt, pos)
	put("hdr.nscount = %w\n", nscount)
	(arcount, pos) = unpack16(pkt, pos)
	put("hdr.arcount = %w\n", arcount)

	put("Queries:\n")
	n = 0
	while n < qdcount
		put("i: %w\n", n)
		pos = dumpquery(pkt, pos)
		n++
	;;

	put("Answers:")
	n = 0
	while n < ancount
		put("i: %w\n", n)
		pos = dumpans(pkt, pos)
		n++
	;;
}
/*
 Debugging aid: prints the query record that starts at `off` in `pkt`
 (its name, type and class) and returns the offset just past it.
*/
const dumpquery = {pkt, off
	var qtype, qclass

	put("\tname = ")
	off = printname(pkt, off)

	(qtype, off) = unpack16(pkt, off)
	put("\tbody.type = %w\n", qtype)

	(qclass, off) = unpack16(pkt, off)
	put("\tbody.class = %w\n", qclass)

	-> off
}
/*
 Debugging aid: prints the answer record that starts at `off` in `pkt`
 and returns the offset just past what was printed. After the name, the
 record is read as seven consecutive 16 bit words: type, class, two ttl
 words, the data length, and two data words.
*/
const dumpans = {pkt, off
	var rtype, rclass
	var ttla, ttlb
	var rdlen
	var dataa, datab

	put("\tname = ")
	off = printname(pkt, off)

	(rtype, off) = unpack16(pkt, off)
	put("\tbody.type = %w\n", rtype)
	(rclass, off) = unpack16(pkt, off)
	put("\tbody.class = %w\n", rclass)

	(ttla, off) = unpack16(pkt, off)
	put("\tbody.ttl_lo = %w\n", ttla)
	(ttlb, off) = unpack16(pkt, off)
	put("\tbody.ttl_hi = %w\n", ttlb)

	(rdlen, off) = unpack16(pkt, off)
	put("\tbody.rdlength = %w\n", rdlen)

	(dataa, off) = unpack16(pkt, off)
	put("\tbody.rdata_lo = %w\n", dataa)
	(datab, off) = unpack16(pkt, off)
	put("\tbody.rdata_hi = %w\n", datab)

	-> off
}
/*
 Returns the offset just past the encoded name that starts at `off` in
 `pkt`. Length-prefixed labels are stepped over until either a zero
 length byte or a byte with both top bits set is found; the latter is
 treated as a two byte pointer that ends the name.
*/
const skipname = {pkt, off
	var len

	len = pkt[off] castto(size)
	while len != 0
		/* ptr is 2 bytes */
		if len & 0xC0 == 0xC0
			-> off + 2
		;;
		off += len + 1
		len = pkt[off] castto(size)
	;;
	/* step over the terminating zero byte */
	-> off + 1
}
/*
 Debugging aid: prints the encoded name that starts at `off` in `pkt`,
 one label at a time, each followed by a dot. A byte with both top bits
 set is treated as a two byte pointer: "PTR: " is printed, the name at
 the pointed-to offset is printed recursively, and the name ends there.
 Returns the offset just past the name.
*/
const printname = {pkt, off
	var len

	len = pkt[off] castto(size)
	while len != 0
		if len & 0xC0 == 0xC0
			put("PTR: ")
			/* target offset: low 6 bits of this byte, then the next byte */
			printname(pkt, ((len & ~0xC0) << 8) | (pkt[off + 1] castto(size)))
			-> off + 2
		;;
		put("%s.", pkt[off+1:off+len+1])
		off += len + 1
		len = pkt[off] castto(size)
	;;
	-> off + 1
}
/*
 Stores the 16 bit value `v` at `buf[off]`, most significant byte first.
 Returns the number of bytes written, which is always sizeof(uint16).
*/
const pack16 = {buf, off, v
	var hi, lo

	hi = ((v & 0xff00) >> 8) castto(byte)
	lo = (v & 0x00ff) castto(byte)
	buf[off] = hi
	buf[off+1] = lo
	-> sizeof(uint16)
}
/*
 Reads a 16 bit value stored most significant byte first at `buf[off]`.
 Returns the value together with the offset just past it.
*/
const unpack16 = {buf, off
	var hi, lo

	hi = buf[off] castto(uint16)
	lo = buf[off + 1] castto(uint16)
	-> ((hi << 8) | lo, off + sizeof(uint16))
}
/*
 Reads a 32 bit value stored most significant byte first at `buf[off]`.
 Returns the value together with the offset just past it.
*/
const unpack32 = {buf, off
	var v

	v = (buf[off] castto(uint32)) << 24
	v |= (buf[off+1] castto(uint32)) << 16	/* second byte occupies bits 16..23 */
	v |= (buf[off+2] castto(uint32)) << 8
	v |= (buf[off+3] castto(uint32))
	-> (v, off+sizeof(uint32))
}
/*
 Writes `host` into `buf` at `off` in DNS wire form: each dot separated
 label is emitted by addseg as a length byte followed by the label
 bytes, and the name is closed with a zero length label. A trailing dot
 in `host` does not produce an extra label, and an empty `host` encodes
 as just the closing label. Returns the number of bytes written.
*/
const packname = {buf, off : size, host
	var i
	var start
	var lastseg

	start = off
	lastseg = 0	/* index in host where the current label begins */
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			off += addseg(buf, off, host[lastseg:i])
			/* the next label starts right after this dot */
			lastseg = i + 1
		;;
	;;
	/* whatever follows the last dot, unless the name ended on a dot */
	if lastseg < host.len
		off += addseg(buf, off, host[lastseg:])
	;;
	off += addseg(buf, off, "")	/* null terminating segment */
	-> off - start
}
/*
 Writes one label into `buf` at `off`: a single length byte followed by
 the bytes of `str`. Returns the number of bytes written.
*/
const addseg = {buf, off, str
	var n

	n = str.len
	buf[off] = n castto(byte)
	slcp(buf[off + 1 : off + n + 1], str)
	-> n + 1
}
/*
 Checks whether `host` can be sent as a DNS query name. The name must be
 non-empty and at most 255 bytes long, no label (the text between dots)
 may be longer than 63 bytes, and no byte may have its top bit set.
*/
const valid = {host : byte[:]
	var i
	var seglen

	/* nothing to look up */
	if host.len == 0
		-> false
	;;
	/* maximum length: 255 chars */
	if host.len > 255
		-> false
	;;

	seglen = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			seglen = 0
		else
			/* count the bytes of the current label */
			seglen++
		;;
		if seglen > 63
			-> false
		;;
		if (host[i] & 0x80) != 0
			-> false
		;;
	;;
	-> true
}