ref: 4ccc52b8ed3c89262a3ee72bfdd05de1a3d0441c
dir: /libstd/resolve.myr/
use "alloc.use"
use "chartype.use"
use "die.use"
use "endian.use"
use "error.use"
use "extremum.use"
use "fmt.use"
use "hashfuncs.use"
use "htab.use"
use "ipparse.use"
use "option.use"
use "slcp.use"
use "sleq.use"
use "slpush.use"
use "slurp.use"
use "strfind.use"
use "strsplit.use"
use "strstrip.use"
use "sys.use"
use "types.use"
use "utf.use"
pkg std =
/*
 Failure reasons reported by resolve().
 In the code of this file: `Badhost is returned when the name fails
 valid(), `Badsrv when no nameserver could be connected to, and
 `Badresp when the response id does not match the query id.
 `Badquery is declared but not produced by any code in this file.
 */
type resolveerr = union
`Badhost
`Badsrv
`Badquery
`Badresp
;;
/*
 One resolved address for a host name.
 fam   : address family (set to Afinet or Afinet6 for /etc/hosts entries)
 stype : socket type (set to 0 for /etc/hosts entries)
 ttl   : time to live taken from the DNS answer; 0 for /etc/hosts entries
 addr  : the address itself
 */
type hostinfo = struct
fam : sockfam
stype : socktype
ttl : uint32
addr : netaddr
/* fields that are currently commented out and therefore not available:
proto : uint32
flags : uint32
addr : sockaddr[:]
canon : byte[:]
next : hostinfo#
*/
;;
/* look up host, first in the hosts file table, then through DNS */
const resolve : (host : byte[:] -> error(hostinfo[:], resolveerr))
;;
/* static host table, read by loadhosts(); consulted before any DNS query */
const Hostfile = "/etc/hosts"
/* resolver configuration; loadresolv() picks its "nameserver" lines */
const Resolvfile = "/etc/resolv.conf"
/* host name -> hostinfo table; created in hostfind(), filled by addhosts() */
var hostmap : htab(byte[:], hostinfo)#
/* not referenced by any code visible in this file */
var search : byte[:][:]
/* addresses appended by addns(); dnsresolve() iterates over this list */
var nameservers : netaddr[:]
/* set to true by hostfind() once both configuration files were loaded */
var inited : bool = false
/*
 Resolves host to a list of hostinfo records.

 The table built from the hosts file is consulted first; a hit is
 returned as a freshly allocated one-element slice. Only when the
 name is not in that table is a DNS query attempted.
 */
const resolve = {host : byte[:] -> error(hostinfo[:], resolveerr)
	var res

	match hostfind(host)
	| `None:
		res = dnsresolve(host)
	| `Some found:
		res = `Success slpush([][:], found)
	;;
	-> res
}
/*
 Looks host up in the hosts-file table and returns the optional
 result of htget().

 On the first call the table is created and both configuration
 files are loaded; later calls go straight to the lookup.
 */
const hostfind = {host
	if inited
		-> htget(hostmap, host)
	;;

	/* first use: build the table and read the configuration */
	hostmap = mkht(strhash, streq)
	loadhosts()
	loadresolv()
	inited = true
	-> htget(hostmap, host)
}
/*
 Reads Hostfile and registers every "address name..." line through
 addhosts(). Text after a '#' is treated as a comment. Blank lines
 and lines whose first word is not a parsable address are skipped.
 If the file cannot be read, nothing is loaded.

 The slurped buffer is deliberately not freed here: the names that
 addhosts() stores are sub-slices handed out by word(), which
 presumably still alias that buffer (confirm against strsplit).
 */
const loadhosts = {
	var h
	var lines

	match slurp(Hostfile)
	| `Success d:	h = d
	| `Failure m:	->
	;;

	lines = strsplit(h, "\n")
	for l in lines
		/* trim comment */
		match strfind(l, "#")
		| `Some idx:	l = l[:idx]
		| `None:	/* no comment on this line: keep it whole */
		;;

		match word(l)
		| `Some (ip, rest):
			match ipparse(ip)
			| `Some addr:
				addhosts(addr, ip, rest)
			| `None:
				/* unparsable address: ignore the line */
			;;
		| `None:
			/* blank line */
		;;
	;;
	slfree(lines)
}
/*
 Registers every whitespace separated name in str as an alias for
 addr in hostmap.

 addr : parsed address; decides the family stored in the record
 as   : textual form of the address (not used here)
 str  : remainder of the hosts-file line holding the names

 A name that is already present keeps its existing entry.
 */
const addhosts = {addr, as, str
	var entry
	var family

	match addr
	| `Ipv6 _:	family = Afinet6
	| `Ipv4 _:	family = Afinet
	;;

	/* all names on the line share the same record */
	entry = [
		.fam = family,
		.stype = 0,
		.ttl = 0,
		.addr = addr
	]

	while true
		match word(str)
		| `None:
			->
		| `Some (alias, tail):
			if !hthas(hostmap, alias)
				htput(hostmap, alias, entry)
			;;
			str = tail
		;;
	;;
}
/*
 Reads Resolvfile and passes the argument of every "nameserver"
 line to addns(). Text after a '#' is treated as a comment; blank
 lines and other directives are ignored. If the file cannot be
 read, no nameservers are registered.
 */
const loadresolv = {
	var h
	var lines

	match slurp(Resolvfile)
	| `Success d:	h = d
	| `Failure m:	->
	;;

	lines = strsplit(h, "\n")
	for l in lines
		match strfind(l, "#")
		| `Some idx:	l = l[:idx]
		| `None:
		;;

		match word(l)
		| `Some (cmd, rest):
			if sleq(cmd, "nameserver")
				addns(rest)
			;;
		| `None:
			/* blank or comment-only line */
		;;
	;;
	slfree(lines)
}
/*
 Appends the address that starts rest to the nameservers list.

 rest : the text following the "nameserver" keyword

 A missing or unparsable address is ignored.
 */
const addns = {rest
	match word(rest)
	| `Some (name, _):
		match ipparse(name)
		| `Some addr:
			nameservers = slpush(nameservers, addr)
		| `None:
			/* not an address we can parse */
		;;
	| `None:
		/* "nameserver" with nothing after it */
	;;
}
/*
 Splits the first word off s.

 Leading characters removed by strfstrip() are dropped, then
 characters are consumed until a blank or an undecodable character
 is met. Returns `Some (word, remainder), or `None when no word
 character was found. Both results are sub-slices of s.
 */
const word = {s
	var ch, n

	s = strfstrip(s)
	n = 0
	ch = decode(s[n:])
	while ch != Badchar && !isblank(ch)
		n += charlen(ch)
		ch = decode(s[n:])
	;;

	if n != 0
		-> `Some (s[:n], s[n:])
	;;
	-> `None
}
/*
 Resolves host through DNS.

 Returns `Failure `Badhost when the name does not pass valid().
 Otherwise the configured nameservers are tried in order; the first
 one a socket can be connected to is queried and its result is
 returned. `Failure `Badsrv is returned when no nameserver could be
 connected to.

 The socket is closed once the query has completed, so lookups do
 not accumulate open descriptors.
 */
const dnsresolve = {host : byte[:]
	var nsrv
	var r

	if !valid(host)
		-> `Failure (`Badhost)
	;;

	for ns in nameservers
		nsrv = dnsconnect(ns)
		if nsrv >= 0
			r = dnsquery(nsrv, host)
			close(nsrv)
			-> r
		;;
	;;
	-> `Failure (`Badsrv)
}
/*
 Opens a socket connected to the nameserver ns.

 Returns the socket, or -1 when no connection could be set up.
 IPv6 nameservers are not supported yet; they yield -1 so that the
 caller can move on to the next configured server instead of the
 whole program being terminated.
 */
const dnsconnect = {ns
	match ns
	| `Ipv4 addr:	-> dnsconnectv4(addr)
	| `Ipv6 addr:	-> -1
	;;
}

/*
 Opens a datagram socket connected to port 53 of the IPv4 address
 addr (the four address bytes).

 Returns the socket, or -1 on failure. A socket that was created
 but could not be connected is closed before returning.
 */
const dnsconnectv4 = {addr
	var sa : sockaddr_in
	var s
	var status

	s = socket(Afinet, Sockdgram, 0)
	if s < 0
		-> -1
	;;

	sa.fam = Afinet
	sa.port = hosttonet(53)	/* DNS port, in network byte order */
	sa.addr = addr
	status = connect(s, (&sa) castto(sockaddr#), sizeof(sockaddr_in))
	if status < 0
		close(s)
		-> -1
	;;
	-> s
}
/*
 Sends a query for host over the connected socket srv and returns
 the parsed response. The id handed out by tquery() is passed on to
 rquery() so the response can be matched against the query.
 */
const dnsquery = {srv, host
	-> rquery(srv, tquery(srv, host))
}
/*
 Bits of the 16 bit flags word in the DNS header, as values of the
 word once it is read or written in network byte order (the leftmost
 bit of the header diagram is the most significant bit).
 */
const Qr : uint16 = 1 << 15	/* message is a response */
const Aa : uint16 = 1 << 10	/* authoritative answer */
const Tc : uint16 = 1 << 9	/* message was truncated */
const Rd : uint16 = 1 << 8	/* recursion desired */
const Ra : uint16 = 1 << 7	/* recursion available */

/* id placed in the next outgoing query */
var nextid : uint16 = 42

/*
 Builds a query for the A record of host and writes it to srv.

 Returns the id that was put into the packet; nextid is advanced
 afterwards. The result of write() is not checked.

 The packet consists of the six 16 bit header fields, the encoded
 name, and the query type and class. With host limited to 255 bytes
 by valid(), this fits the 512 byte buffer.
 */
const tquery = {srv, host
	var pkt : byte[512] /* big enough */
	var off : size

	/* header */
	off = 0
	off += pack16(pkt[:], off, nextid)	/* id */
	off += pack16(pkt[:], off, Rd)	/* flags: ask the server to recurse */
	off += pack16(pkt[:], off, 1)	/* qdcount */
	off += pack16(pkt[:], off, 0)	/* ancount */
	off += pack16(pkt[:], off, 0)	/* nscount */
	off += pack16(pkt[:], off, 0)	/* arcount */

	/* query */
	off += packname(pkt[:], off, host)	/* host */
	off += pack16(pkt[:], off, 0x1)	/* qtype: a record */
	off += pack16(pkt[:], off, 0x1)	/* qclass: inet4 */

	write(srv, pkt[:off])
	-> nextid++
}
/*
 Reads one response from srv and parses it with hosts().

 id : the id of the query the response has to answer

 Returns `Failure `Badresp when the read fails or delivers fewer
 than the 12 bytes of the fixed header (six 16 bit fields).
 */
const rquery = {srv, id
	var pktbuf : byte[1024]
	var pkt
	var n

	n = read(srv, pktbuf[:])
	/*
	 n < 12 covers a negative error return and short packets;
	 n > 1024 covers an error return showing up as a huge value
	 should the size type be unsigned.
	 */
	if n < 12 || n > 1024
		-> `Failure (`Badresp)
	;;
	pkt = pktbuf[:n]
	dumpresponse(pkt)
	-> hosts(pkt, id)
}
/*
 Extracts the IPv4 addresses from the DNS response pkt.

 id : id of the query this response must belong to

 Returns `Success with one hostinfo per answer record of type A
 carrying four bytes of data; other answer records (for example
 CNAME) are stepped over using their data length. Returns
 `Failure `Badresp when the packet is shorter than the header, the
 id does not match, or a record claims more data than the packet
 holds.

 NOTE(review): skipname() and the unpack helpers index pkt directly,
 so a packet truncated inside a name or a record header is not
 caught by the checks made here.
 */
const hosts = {pkt, id : uint16
	var off
	var v, q, a
	var t, rdlen
	var ttl
	var i
	var hi : hostinfo
	var hinf : hostinfo[:]

	/* the fixed header is six 16 bit fields */
	if pkt.len < 12
		-> `Failure (`Badresp)
	;;

	off = 0
	/* parse header */
	(v, off) = unpack16(pkt, off)	/* id */
	if v != id
		-> `Failure (`Badresp)
	;;
	(v, off) = unpack16(pkt, off)	/* flags */
	(q, off) = unpack16(pkt, off)	/* qdcount */
	(a, off) = unpack16(pkt, off)	/* ancount */
	(v, off) = unpack16(pkt, off)	/* nscount */
	(v, off) = unpack16(pkt, off)	/* arcount */

	/* skip past query records */
	for i = 0; i < q; i++
		off = skipname(pkt, off)	/* name */
		(v, off) = unpack16(pkt, off)	/* type */
		(v, off) = unpack16(pkt, off)	/* class */
	;;

	/* parse answer records */
	hinf = [][:]
	for i = 0; i < a; i++
		off = skipname(pkt, off)	/* name */
		(t, off) = unpack16(pkt, off)	/* type */
		(v, off) = unpack16(pkt, off)	/* class */
		(ttl, off) = unpack32(pkt, off)	/* ttl */
		(rdlen, off) = unpack16(pkt, off)	/* rdatalen */

		/* the record data has to lie within the packet */
		if off + (rdlen castto(size)) > pkt.len
			if hinf.len > 0
				slfree(hinf)
			;;
			-> `Failure (`Badresp)
		;;

		/* the thing we're interested in: our IP address */
		if t == 0x1 && rdlen == 4
			hi = [
				.fam = Afinet,
				.stype = 0,
				.ttl = ttl,
				.addr = `Ipv4 [pkt[off], pkt[off+1], pkt[off+2], pkt[off+3]]
			]
			hinf = slpush(hinf, hi)
		;;

		/* step over the data, whatever its length */
		off += (rdlen castto(size))
	;;
	-> `Success hinf
}
/*
 Debugging aid: walks the header, the query records and the answer
 records of the response pkt. The values read are discarded; the
 helpers called from here produce no output in their current form.
 */
const dumpresponse = {pkt
	var nquery, nans
	var off
	var v
	var i

	off = 0
	(v, off) = unpack16(pkt, off)
	(v, off) = unpack16(pkt, off)
	(nquery, off) = unpack16(pkt, off)
	(nans, off) = unpack16(pkt, off)
	(v, off) = unpack16(pkt, off)
	(v, off) = unpack16(pkt, off)

	i = 0
	while i < nquery
		off = dumpquery(pkt, off)
		i++
	;;

	i = 0
	while i < nans
		off = dumpans(pkt, off)
		i++
	;;
}
/*
 Debugging aid: steps over the query record at off, that is a name
 followed by two 16 bit fields, and returns the offset just past it.
 */
const dumpquery = {pkt, off
	var v
	var next

	next = printname(pkt, off)
	(v, next) = unpack16(pkt, next)
	(v, next) = unpack16(pkt, next)
	-> next
}
/*
 Debugging aid: steps over the answer record at off, that is a name
 followed by seven 16 bit reads (14 bytes), and returns the offset
 just past it.
 */
const dumpans = {pkt, off
	var v
	var i

	off = printname(pkt, off)
	for i = 0; i < 7; i++
		(v, off) = unpack16(pkt, off)
	;;
	-> off
}
/*
 Returns the offset just past the encoded name that starts at off
 in pkt.

 A name is a run of length-prefixed labels closed by a zero length
 byte. A length byte with its two top bits set is a two byte
 pointer and ends the name right there.
 */
const skipname = {pkt, off
	var sz

	sz = pkt[off] castto(size)
	while sz != 0
		/* ptr is 2 bytes */
		if sz & 0xC0 == 0xC0
			-> off + 2
		;;
		off += sz + 1
		sz = pkt[off] castto(size)
	;;
	/* step over the closing zero byte */
	-> off + 1
}
/*
 Walks the encoded name that starts at off in pkt and returns the
 offset just past it. Despite its name, nothing is printed.

 When a pointer (length byte with both top bits set) is met, the
 name it refers to is walked recursively before returning the
 offset after the two pointer bytes.
 */
const printname = {pkt, off
	var sz

	sz = pkt[off] castto(size)
	while sz != 0
		if sz & 0xC0 == 0xC0
			printname(pkt, ((sz & ~0xC0) << 8) | (pkt[off + 1] castto(size)))
			-> off + 2
		;;
		off += sz + 1
		sz = pkt[off] castto(size)
	;;
	-> off + 1
}
/*
 Stores the low 16 bits of v at buf[off], most significant byte
 first, and returns the number of bytes written.
 */
const pack16 = {buf, off, v
	var hi, lo

	hi = (v >> 8) & 0xff
	lo = v & 0xff
	buf[off] = hi castto(byte)
	buf[off+1] = lo castto(byte)
	/* we always write one uint16 */
	-> sizeof(uint16)
}
/*
 Reads a 16 bit value stored most significant byte first at
 buf[off]. Returns the value together with the offset just past it.
 */
const unpack16 = {buf, off
	var hi, lo

	hi = buf[off] castto(uint16)
	lo = buf[off + 1] castto(uint16)
	-> ((hi << 8) | lo, off + sizeof(uint16))
}
/*
 Reads a 32 bit value stored most significant byte first at
 buf[off]. Returns the value together with the offset just past it.
 */
const unpack32 = {buf, off
	var v

	v = (buf[off] castto(uint32)) << 24
	v |= (buf[off+1] castto(uint32)) << 16
	v |= (buf[off+2] castto(uint32)) << 8
	v |= (buf[off+3] castto(uint32))
	-> (v, off+sizeof(uint32))
}
/*
 Encodes host into buf at off in DNS wire form: every dot separated
 label is written with a leading length byte, and an empty label
 closes the name. A trailing dot on host is accepted and does not
 add a label. An empty host encodes as the closing label alone.

 Returns the number of bytes written.
 */
const packname = {buf, off : size, host
	var i
	var start
	var lastseg

	start = off
	lastseg = 0	/* index of the first byte of the current label */
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			off += addseg(buf, off, host[lastseg:i])
			lastseg = i + 1
		;;
	;;

	/* label after the last dot, if the name is not dot terminated */
	if lastseg < host.len
		off += addseg(buf, off, host[lastseg:])
	;;
	off += addseg(buf, off, "")	/* null terminating segment */
	-> off - start
}
/*
 Writes one label at buf[off]: a length byte followed by the bytes
 of str. Returns the number of bytes written, str.len + 1.
 */
const addseg = {buf, off, str
	var n

	n = str.len
	buf[off] = n castto(byte)
	slcp(buf[off + 1 : off + n + 1], str)
	-> n + 1
}
/*
 Checks whether host can be sent as a DNS name.

 Returns false when host is empty, longer than 255 bytes, contains
 a dot separated label of more than 63 bytes, or contains a byte
 with the high bit set. Returns true otherwise.
 */
const valid = {host : byte[:]
	var i
	var seglen

	/* maximum length: 255 chars */
	if host.len == 0 || host.len > 255
		-> false
	;;

	seglen = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			seglen = 0
		else
			seglen++
		;;
		/* a label must fit the length byte written by the encoder */
		if seglen > 63
			-> false
		;;
		/* only 7 bit characters are accepted */
		if (host[i] & 0x80) != 0
			-> false
		;;
	;;
	-> true
}