shithub: rtmp

ref: 996f6e5c036260a18fd4c46405027ef4a23a7923
dir: /rtmp.c/

View raw version
#include <u.h>
#include <libc.h>
#include <thread.h>
#include <bio.h>
#include <libsec.h>
#include "amf0.h"
#include "ivf.h"
#include "rtmp.h"
#include "util.h"

#define min(a,b) ((a)<(b)?(a):(b))

enum {
	Port = 1935,
	CSsz = 1536,
	ChunkDefault = 128,
	ChunkDesired = 4096,

	Type0 = 0,
	Type1,
	Type2,
	Type3,

	CSUserCtl = 2,
	CSCtl = 3,

	CbCommand = 0,
	CbTransID,
	CbObject,
	CbResponse,
	NumCb,

	/* UserControl */
	CtlStreamBegin = 0,
	CtlStreamEOF,
	CtlStreamDry,
	CtlSetBufferLen,
	CtlStreamIsRecorded,
	CtlPingRequest = 6,
	CtlPingResponse,

	/* Message.type */
	SetChunkSize = 1,
	Abort,
	Ack,
	UserControl,
	WindowAckSize,
	SetBandwidth,
	Audio = 8,
	Video,
	AMF3Metadata = 15,
	AMF3SharedObject,
	AMF3Command,
	AMF0Metadata,
	AMF0SharedObject,
	AMF0Command,
	Aggregate = 22,

	/* RTMP.bwlimit */
	LimitHard = 0,
	LimitSoft,
	LimitDynamic,

	Biobufsz = 64*1024, /* FIXME don't know if it helps with anything */
	Bufsz = 64*1024,
};

typedef struct Command Command;
typedef struct Message Message;

/* %T prints a message type (msgtypefmt), %M a whole Message (msgfmt) */
#pragma varargck type "T" int
#pragma varargck type "M" Message*

/* an AMF0 command that was sent and is waiting for its reply */
struct Command {
	/* run by loop with the parsed reply; ok is 0 when the reply is "_error" */
	void (*cb)(RTMP *r, int ok, Amf0 *a[NumCb], void *aux);
	void *aux;	/* opaque argument handed to cb */
	int tid;	/* transaction ID the reply is matched on */

	Command *prev, *next;	/* links in RTMP.cmds.w */
};

/* one RTMP message, either being built for sending or just received */
struct Message {
	int type;	/* Message.type constant; -1 when the received chunk header did not carry it */
	int fmt;	/* chunk header format, Type0..Type3 */
	int cs;	/* chunk stream ID */
	int sid;	/* message stream ID (written into Type0 headers by rtmpsend) */
	u32int ts;	/* timestamp */
	u8int *data;	/* payload */
	int sz;	/* payload size in bytes */
	Command cmd;	/* for an outgoing AMF0Command: reply handler queued by rtmpsend */
};

struct RTMP {
	Biobufhdr;	/* buffered output on the connection; lets r be passed to Bwrite/Bflush */
	QLock;	/* lets r be passed to qlock; serializes use of msg and b between loop and callers */
	Channel *c;	/* carries the connect result from connected to rtmpdial */
	char *app;
	char *path;
	char *tcurl;
	Message msg;	/* message being built, or the last one received */
	u8int *b, *p, *e;	/* buffer shared by send and receive: start, current position, end */
	int chunkin;	/* chunk size of incoming data */
	int chunkout;	/* chunk size used for outgoing data */
	int mode;	/* NOTE(review): not referenced anywhere in rtmp.c */
	int bsz;	/* allocated size of b */
	int i;	/* connection file descriptor */
	int winacksz;	/* window acknowledgement size announced by the peer */
	int bw;	/* peer bandwidth from SetBandwidth */
	u8int bwlimit;	/* LimitHard, LimitSoft or LimitDynamic */
	struct {
		int tid;	/* last transaction ID handed out by newmsg */
		Command *w;	/* commands waiting for their reply */
	}cmds;
	u8int biobuf[Biobufsz];	/* storage handed to Binits for the Biobufhdr */
};

/*
 * Append AMF0-encoded values to the outgoing message at r->p, never
 * past r->e.  They use a variable named r (RTMP*) from the caller's scope.
 * Each amf0* helper returns the new position; rtmpsend treats a nil
 * r->p as an error (presumably the buffer ran out of room).
 */
#define putnull() do{ r->p = amf0null(r->p, r->e); }while(0)
#define puti16(i) do{ r->p = amf0i16(r->p, r->e, i); }while(0)
#define puti24(i) do{ r->p = amf0i24(r->p, r->e, i); }while(0)
#define puti32(i) do{ r->p = amf0i32(r->p, r->e, i); }while(0)
#define putnum(v) do{ r->p = amf0num(r->p, r->e, v); }while(0)
#define putstr(s) do{ r->p = amf0str(r->p, r->e, s); }while(0)
#define putarr() do{ r->p = amf0arr(r->p, r->e); }while(0)
#define putobj() do{ r->p = amf0obj(r->p, r->e); }while(0)
#define putend() do{ r->p = amf0end(r->p, r->e); }while(0)
#define putkvnum(name, v) do{ r->p = amf0kvnum(r->p, r->e, name, v); }while(0)
#define putkvstr(name, s) do{ r->p = amf0kvstr(r->p, r->e, name, s); }while(0)
#define putkvbool(name, s) do{ r->p = amf0kvbool(r->p, r->e, name, s); }while(0)

/*
 * Start an AMF0 command: its name, the transaction ID assigned by
 * newmsg and an opening command object.  cb_ is the reply handler that
 * rtmpsend queues once the message is sent.
 */
#define putcommand(name, cb_) do { \
	putstr(name); \
	putnum(r->msg.cmd.tid); \
	putobj(); \
	r->msg.cmd.cb = cb_; \
}while(0)

/* message header size following the basic header, per chunk format (extended timestamp not included) */
static int szs[] = {
	[Type3] = 0,
	[Type2] = 3,
	[Type1] = 7,
	[Type0] = 11,
};

/* message type names for %T; unnamed slots are nil */
static char *msgtype2s[] = {
	[SetChunkSize] = "SetChunkSize",
	[Abort] = "Abort",
	[Ack] = "Ack",
	[UserControl] = "UserControl",
	[WindowAckSize] = "WindowAckSize",
	[SetBandwidth] = "SetBandwidth",
	[Audio] = "Audio",
	[Video] = "Video",
	[AMF3Metadata] = "AMF3Metadata",
	[AMF3SharedObject] = "AMF3SharedObject",
	[AMF3Command] = "AMF3Command",
	[AMF0Metadata] = "AMF0Metadata",
	[AMF0SharedObject] = "AMF0SharedObject",
	[AMF0Command] = "AMF0Command",
	[Aggregate] = "Aggregate",
};

/* UserControl event names for debug output (index 5 is unnamed) */
static char *ctl2s[] = {
	[CtlStreamBegin] = "StreamBegin",
	[CtlStreamEOF] = "StreamEOF",
	[CtlStreamDry] = "StreamDry",
	[CtlSetBufferLen] = "SetBufferLen",
	[CtlStreamIsRecorded] = "StreamIsRecorded",
	[CtlPingRequest] = "PingRequest",
	[CtlPingResponse] = "PingResponse",
};

/* SetBandwidth limit type names for debug output */
static char *bwlimit2s[] = {
	[LimitHard] = "hard",
	[LimitSoft] = "soft",
	[LimitDynamic] = "dynamic",
};

/* non-zero enables tracing of sent/received messages on fd 2; defined elsewhere */
extern int debug;

/*
 * newmsg starts building a new outgoing message: r->msg is cleared and
 * given the type, chunk format and chunk stream, and the write position
 * goes back to the start of the buffer.  AMF0 commands also get the next
 * transaction ID.
 */
static void
newmsg(RTMP *r, int type, int fmt, int cs)
{
	Message *m;

	m = &r->msg;
	memset(m, 0, sizeof(*m));
	m->cs = cs;
	m->fmt = fmt;
	m->type = type;
	if(type == AMF0Command){
		r->cmds.tid++;
		m->cmd.tid = r->cmds.tid;
	}
	r->p = r->b;
}

/*
 * bextend makes sure r->b holds at least bsz bytes.  When it does not,
 * the buffer is reallocated to twice bsz, r->e is moved to the new end
 * and r->p is kept at the same offset into the buffer.  Any other
 * pointers into the old buffer are invalid afterwards.
 */
static void
bextend(RTMP *r, int bsz)
{
	intptr off;

	if(r->bsz >= bsz)
		return;
	/*
	 * Take r->p's offset before reallocating: the old block may be
	 * freed, and the offset is r->p - r->b (the old code had it negated).
	 */
	off = -1;
	if(r->b != nil && r->p != nil)
		off = r->p - r->b;
	r->b = erealloc(r->b, bsz*2);
	if(off >= 0)
		r->p = r->b + off;
	r->bsz = bsz*2;
	r->e = r->b + r->bsz;
}

/*
 * rtmprecv reads one complete message from the connection into r->b and
 * describes it in r->msg (msg.data points into r->b).  Type, length and
 * stream ID are only known from Type0/Type1 headers; continuation chunks
 * must carry a one-byte Type3 basic header for the same chunk stream.
 * Returns 0 on success, -1 with the error string set otherwise.
 */
static int
rtmprecv(RTMP *r)
{
	int hsz, len, n, msid, off;
	u8int *h, *e, byte;
	u32int ts;

	memset(&r->msg, 0, sizeof(r->msg));

	/* basic header: format in the top two bits, chunk stream ID below */
	r->p = r->b;
	if(readn(r->i, r->p, 1) != 1)
		goto eof;
	r->msg.fmt = (r->p[0] & 0xc0)>>6;
	r->msg.cs = r->p[0] & 0x3f;
	/* stream IDs 0 and 1 mean one or two more ID bytes follow */
	n = r->msg.cs + 1;
	if(n <= 2){
		if(readn(r->i, r->p, n) != n)
			goto eof;
		r->msg.cs = 64 + r->p[0];
		if(n == 2)
			r->msg.cs += 256 * r->p[1];
	}

	/* message header, sized by the chunk format */
	hsz = szs[r->msg.fmt];
	if(readn(r->i, r->p, hsz) != hsz)
		goto eof;

	h = r->p;
	e = r->p + hsz;

	r->msg.type = -1;
	msid = 0;
	ts = 0;
	len = 0;
	if(hsz >= szs[Type2]){
		h = amf0i24get(h, e, (s32int*)&ts); /* FIXME proper timestamps? */
		if(hsz >= szs[Type1]){
			h = amf0i24get(h, e, &len);
			h = amf0byteget(h, e, &byte);
			r->msg.type = byte;
			if(hsz >= szs[Type0])
				h = amf0i32leget(h, e, &msid);
		}
	}
	if(h == nil)
		goto err;

	if(ts == 0xffffff){ /* extended timestamp */
		if(readn(r->i, h, 4) != 4)
			goto eof;
		if((h = amf0i32get(h, h+4, (s32int*)&ts)) == nil)
			goto err;
	}
	r->msg.ts = ts;
	r->msg.sid = msid;

	/*
	 * The payload goes right after the header bytes already in r->b, so
	 * the buffer has to fit both.  Growing it may move r->b, so h is
	 * recomputed from its offset afterwards.
	 */
	off = h - r->b;
	bextend(r, off + len);
	h = r->b + off;

	/* FIXME do all consecutive chunks use Type3? */
	r->msg.data = h;
	r->msg.sz = len;
	for(;;){
		n = min(len, r->chunkin);
		if(readn(r->i, h, n) != n)
			goto eof;
		len -= n;
		h += n;
		if(len < 1)
			break;
		/* the continuation header byte is overwritten by the next chunk's data */
		if(readn(r->i, h, 1) != 1)
			goto eof;
		if((r->msg.cs | Type3<<6) != *h){
			werrstr("cs/fmt does not match: %02x", *h);
			goto err;
		}
	}

	return 0;
eof:
	werrstr("eof");
err:
	werrstr("rtmprecv: %r");
	return -1;
}

/*
 * rtmpsend sends the message built in r->b (up to r->p), described by
 * r->msg, split into chunks of r->chunkout bytes, and flushes the
 * connection.  A sent AMF0 command has its reply handler (r->msg.cmd)
 * queued on r->cmds.w, where loop matches the reply.
 * A nil r->p (a put* helper failed while building) is an error.
 * Only chunk stream IDs below 64 fit the one-byte basic header used here.
 * Returns 0 on success, -1 with the error string set otherwise.
 */
static int
rtmpsend(RTMP *r)
{
	u8int *p, *h, *e, hdata[24];
	int len, n, hsz;
	Command *c;

	if(r->p == nil)
		goto err;

	r->msg.data = r->b;
	r->msg.sz = r->p - r->b;

	h = hdata;
	*h++ = r->msg.fmt<<6 | r->msg.cs;
	hsz = szs[r->msg.fmt];
	e = h + hsz;
	if(hsz >= szs[Type2]){
		h = amf0i24(h, e, 0); /* FIXME put actual timestamps? */
		if(hsz >= szs[Type1]){
			h = amf0i24(h, e, r->msg.sz);
			h = amf0byte(h, e, r->msg.type);
			/*
			 * NOTE(review): rtmprecv reads the message stream ID as
			 * little-endian (amf0i32leget); this writes it with
			 * amf0i32.  Harmless while sid is always 0.
			 */
			if(hsz >= szs[Type0])
				h = amf0i32(h, e, r->msg.sid);
		}
	}
	assert(h != nil);
	/* zero any header bytes left unfilled and send the whole header */
	memset(h, 0, e-h);
	if(Bwrite(r, hdata, e-hdata) < 0)
		goto err;

	for(p = r->msg.data, len = r->msg.sz; len > 0;){
		n = min(len, r->chunkout);
		if(Bwrite(r, p, n) < 0)
			goto err;
		p += n;
		len -= n;
		/* every following chunk of the message starts with a Type3 basic header */
		if(len > 0 && Bputc(r, r->msg.cs | Type3<<6) < 0)
			goto err;
	}

	if(Bflush(r) < 0)
		goto err;

	if(debug){
		fprint(2, "← %M", &r->msg);
		if(r->msg.type == AMF0Command){
			Amf0 *a;
			u8int *s, *end;
			fprint(2, ":");
			s = r->msg.data;
			end = s + r->msg.sz;
			while(s != nil && s != end){
				/* never hand amf0free a stale pointer if amf0parse fails */
				a = nil;
				if((s = amf0parse(&a, s, end)) != nil)
					fprint(2, " %A", a);
				else
					fprint(2, " %r");
				amf0free(a);
			}
		}
		fprint(2, "\n");
	}

	if(r->msg.type == AMF0Command){
		c = emalloc(sizeof(*c));
		*c = r->msg.cmd;
		assert(c->cb != nil);
		if((c->next = r->cmds.w) != nil)
			c->next->prev = c;
		r->cmds.w = c;
	}

	return 0;
err:
	werrstr("rtmpsend: %r");
	return -1;
}

/*
 * rtmpfree releases everything r owns: its strings, the message buffer,
 * the commands still waiting for a reply, the Biobuf and r itself.
 * If r->c is set, "done" is first offered on it and the channel freed.
 * The connection file descriptor is not closed here.
 */
static void
rtmpfree(RTMP *r)
{
	Command *c, *next;
	char *s;

	free(r->app);
	free(r->b);
	free(r->path);
	free(r->tcurl);
	for(c = r->cmds.w; c != nil; c = next){
		next = c->next;
		free(c);
	}
	if(r->c != nil){
		/*
		 * The receiver in rtmpdial frees whatever string it gets, so
		 * send an allocated copy rather than a literal.  If the send
		 * fails (e.g. the channel was closed), it is still ours to free.
		 */
		s = estrdup("done");
		if(sendp(r->c, s) != 1)
			free(s);
		chanfree(r->c);
	}
	Bterm(r);
	free(r);
}

/*
 * pong answers a PingRequest with a PingResponse that echoes the 32-bit
 * value n taken from the request.  Returns rtmpsend's result.
 */
static int
pong(RTMP *r, s32int n)
{
	newmsg(r, UserControl, Type0, CSUserCtl);
	r->p = amf0i16(r->p, r->e, CtlPingResponse);
	r->p = amf0i32(r->p, r->e, n);
	return rtmpsend(r);
}

/*
 * setchunksz announces sz as our outgoing chunk size and switches
 * r->chunkout to it once the message has gone out (whether or not the
 * send succeeded).  Returns rtmpsend's result.
 */
static int
setchunksz(RTMP *r, int sz)
{
	int rv;

	newmsg(r, SetChunkSize, Type0, CSUserCtl);
	r->p = amf0i32(r->p, r->e, sz);
	rv = rtmpsend(r);
	r->chunkout = sz;

	return rv;
}

/*
 * loop is the receiving proc started by rtmpdial; aux is the RTMP.
 * For each message it takes r's lock, reads the message with rtmprecv
 * and dispatches on its type:
 *  - AMF0Command: the reply is parsed into a[], matched to a pending
 *    Command by transaction ID, unlinked from r->cmds.w, and its
 *    callback is run (still under the lock);
 *  - SetChunkSize, WindowAckSize and SetBandwidth update r;
 *  - UserControl PingRequest is answered with pong;
 *  - other types are ignored.
 * Any error ends the loop.  r is then freed and threadexitsall is
 * called, which ends every thread of the program.
 *
 * NOTE(review): the lock is held while rtmprecv blocks for the next
 * message (r->b and r->msg are shared with the send path), so callers
 * such as rtmpstream can only send once some message has arrived.
 * NOTE(review): an AMF0 command whose transaction ID matches no pending
 * Command (e.g. one the server sends on its own) ends the loop.
 */
static void
loop(void *aux)
{
	int res, n, ok;
	Amf0 *a[NumCb];
	u8int *s, *e;
	s16int s16;
	Message *m;
	Command *c;
	RTMP *r;

	r = aux;
	m = &r->msg;
	res = 0;
	memset(a, 0, sizeof(a));
	for(;;){
		/* drop the values parsed from the previous message */
		for(n = 0; n < nelem(a); n++)
			amf0free(a[n]);
		memset(a, 0, sizeof(a));

		qlock(r);

		/* res is non-zero after a failure in the previous iteration */
		if(res != 0 || (res = rtmprecv(r)) != 0){
			if(debug)
				fprint(2, "rtmp loop: %r\n");
			for(n = 0; n < nelem(a); n++)
				amf0free(a[n]);
			break; /* r is still locked; unlocked after the loop */
		}

		s = r->msg.data;
		e = s + r->msg.sz;

		if(debug)
			fprint(2, "→ %M", &r->msg);

		switch(m->type){
		case AMF0Command:
			c = nil;
			ok = 1;
			/* parse name, transaction ID, object and response in order */
			for(n = 0; n < NumCb; n++){
				if((s = amf0parse(&a[n], s, e)) == nil)
					goto err;
				switch(n){
				case CbCommand:
					if(a[n]->type != Tstr){
						werrstr("command name is not a string: %A", a[n]);
						goto err;
					}
					if(strcmp(a[n]->str, "_error") == 0)
						ok = 0;
					/* other values: "_result", etc */
					break;
				case CbTransID:
					if(a[n]->type != Tnum){
						werrstr("transaction ID is not a number");
						goto err;
					}
					/* find the pending command this reply belongs to */
					for(c = r->cmds.w; c != nil && c->tid != a[n]->num; c = c->next);
					if(c == nil){
						werrstr("response to non-existent transaction %d", (int)a[n]->num);
						goto err;
					}
					break;
				}
			}
			if(debug)
				fprint(2, " tid=%A: %A %A %A\n", a[CbTransID], a[CbCommand], a[CbObject], a[CbResponse]);
			/* unlink c from the pending list before running its callback */
			if(c->prev != nil)
				c->prev->next = c->next;
			if(c->next != nil)
				c->next->prev = c->prev;
			if(r->cmds.w == c)
				r->cmds.w = c->next;
			c->cb(r, ok, a, c->aux);
			free(c);
			break;

		case SetChunkSize:
			if(amf0i32get(s, e, &r->chunkin) == nil)
				goto err;
			if(r->chunkin < 2){
				werrstr("invalid chunk size: %d", r->chunkin);
				goto err;
			}
			if(debug)
				fprint(2, ": %d\n", r->chunkin);
			break;

		case UserControl:
			/* event type, then an optional 32-bit value (-1 if missing) */
			if((s = amf0i16get(s, e, &s16)) == nil)
				goto err;
			if(amf0i32get(s, e, &n) == nil)
				n = -1;
			switch(s16){
			case CtlStreamBegin:
			case CtlStreamEOF:
			case CtlStreamDry:
			case CtlSetBufferLen:
			case CtlStreamIsRecorded:
				/* only PingRequest is answered; if(0) lets it share the debug print */
				if(0){
			case CtlPingRequest:
					if(pong(r, n) != 0)
						goto err;
				}
				if(debug)
					fprint(2, ": %s %d\n", ctl2s[s16], n);
				break;
			default:
				if(debug)
					fprint(2, ": ?%d? %d\n", s16, n);
				break;
			}
			break;

		case WindowAckSize:
			if(amf0i32get(s, e, &r->winacksz) == nil)
				goto err;
			if(debug)
				fprint(2, ": %d\n", r->winacksz);
			break;

		case SetBandwidth:
			if((s = amf0i32get(s, e, &r->bw)) == nil || amf0byteget(s, e, &r->bwlimit) == nil)
				goto err;
			if(debug)
				fprint(2, ": %d (%s)\n", r->bw, r->bwlimit < nelem(bwlimit2s) ? bwlimit2s[r->bwlimit] : "???");
			break;

		/* FIXME */
		case Aggregate:
		case Abort:
		case Ack:
		case Audio:
		case Video:
		case AMF0Metadata:
		case AMF0SharedObject:
			break;

		case AMF3Metadata:
		case AMF3SharedObject:
		case AMF3Command:
			if(debug)
				fprint(2, ": ignored\n");
			break;
		/* reached only by goto; the next iteration then leaves the loop */
err:
			res = -1;
			break;
		}

		qunlock(r);
	}

	qunlock(r);
	rtmpfree(r);

	threadexitsall(res == 0 ? nil : "error");
}

/*
 * handshake performs the plain RTMP handshake on f:
 *  1. send C0+C1;
 *  2. read S0+S1 and check that the version matches;
 *  3. echo S1 back as C2;
 *  4. read S2, which must echo C1.
 * Returns 0 on success, -1 with the error string set otherwise.
 */
static int
handshake(int f)
{
	u8int c[1+CSsz], s[1+CSsz];

	c[0] = 3; /* rtmp v3 */
	memset(c+1, 0, 4+4); /* timestamp + zero */
	prng(c+1+8, CSsz-4-4);
	if(write(f, c, sizeof(c)) != sizeof(c))
		goto err;
	if(readn(f, s, sizeof(s)) != sizeof(s))
		goto err;
	if(c[0] != s[0]){
		werrstr("expected version %d, got %d", c[0], s[0]);
		goto err;
	}
	if(write(f, s+1, CSsz) != CSsz)
		goto err;
	if(readn(f, s+1, CSsz) != CSsz)
		goto err;
	/*
	 * S2 must repeat C1's timestamp (bytes 0-3) and random data
	 * (bytes 8-).  Bytes 4-7 ("time2") hold the time the server read
	 * C1 and may legitimately differ from the zeros we sent.
	 */
	if(memcmp(c+1, s+1, 4) != 0 || memcmp(c+1+8, s+1+8, CSsz-8) != 0){
		werrstr("C1 != S2");
		goto err;
	}

	return 0;

err:
	werrstr("handshake: %r");
	return -1;
}

/*
 * streamcreated handles the reply to createStream (see rtmpstream).
 * aux is the Channel rtmpstream receives on.  On success the new stream
 * ID is sent on it.  Otherwise the problem is printed to fd 2 and the
 * channel is closed, so the receiver's recv fails.
 */
static void
streamcreated(RTMP *, int ok, Amf0 *a[NumCb], void *aux)
{
	Channel *sid;

	sid = aux;
	if(strcmp(a[CbCommand]->str, "_result") != 0)
		fprint(2, "createStream: expected '_result', got %#q\n", a[CbCommand]->str);
	else if(a[CbResponse]->type != Tnum)
		fprint(2, "createStream: expected stream ID, got NaN\n");
	else if(!ok)
		fprint(2, "createStream: %A\n", a[CbResponse]);
	else{
		/*
		 * rtmpstream frees the channel as soon as its recv returns,
		 * so the channel must not be touched after a successful send.
		 */
		sendul(sid, (ulong)a[CbResponse]->num);
		return;
	}

	chanclose(sid);
}

/*
 * rtmpstream sends a createStream command and waits for the reply,
 * storing the new message stream ID in *sid.
 * Returns 0 on success, -1 if the command could not be sent or no
 * stream ID came back.
 */
int
rtmpstream(RTMP *r, ulong *sid)
{
	Channel *reply;
	int rv;

	reply = chancreate(sizeof(ulong), 0);

	qlock(r);
	newmsg(r, AMF0Command, Type0, CSCtl);
	putstr("createStream");
	putnum(r->msg.cmd.tid);
	putnull();
	r->msg.cmd.cb = streamcreated;
	r->msg.cmd.aux = reply;
	rv = rtmpsend(r);
	qunlock(r);

	/* streamcreated either sends the ID or closes the channel */
	if(rv == 0)
		rv = recv(reply, sid) == 1 ? 0 : -1;
	chanfree(reply);

	return rv;
}

/*
 * connected handles the reply to the connect command and reports to
 * rtmpdial through r->c.  It sends nil on success, otherwise an
 * smprint'ed error string.  On success the outgoing chunk size is then
 * raised to ChunkDesired.
 */
static void
connected(RTMP *r, int ok, Amf0 *a[NumCb], void *)
{
	if(strcmp(a[CbCommand]->str, "_result") != 0){
		sendp(r->c, smprint("expected '_result', got %#q", a[CbCommand]->str));
		return;
	}
	if(!ok){
		sendp(r->c, smprint("%A", a[CbResponse]));
		return;
	}
	sendp(r->c, nil);
	setchunksz(r, ChunkDesired);
}

/*
 * connect sends the AMF0 "connect" command, with the app, tcUrl and the
 * advertised codecs.  The reply is handled by connected.
 * Returns rtmpsend's result.
 */
static int
connect(RTMP *r)
{
	newmsg(r, AMF0Command, Type0, CSCtl);

	/* command name, transaction ID, then the command object */
	putstr("connect");
	putnum(r->msg.cmd.tid);
	putobj();
	r->msg.cmd.cb = connected;

	putkvstr("app", r->app);
	putkvstr("tcUrl", r->tcurl);
	putkvbool("fpad", 0); /* no proxy */
	putkvnum("audioCodecs", 0x4 | 0x400); /* mp3 + aac */
	putkvnum("videoCodecs", 0x80); /* h.264 */
	putkvnum("videoFunction", 0); /* no frame-accurate seek */
	putkvnum("objectEncoding", 0); /* AMF0 */
	putend();

	return rtmpsend(r);
}

/* msgtypefmt is the %T verb: a message type's name, or its number if unnamed. */
static int
msgtypefmt(Fmt *f)
{
	int t;

	t = va_arg(f->args, int);
	if(t < 0 || t >= nelem(msgtype2s) || msgtype2s[t] == nil)
		return fmtprint(f, "%d", t);

	return fmtprint(f, "%s", msgtype2s[t]);
}

/* msgfmt is the %M verb: a one-line summary of a Message for tracing. */
static int
msgfmt(Fmt *f)
{
	Message *msg;

	msg = va_arg(f->args, Message*);
	fmtprint(f, "type=%T cs=%d ts=%ud sz=%d",
		msg->type, msg->cs, msg->ts, msg->sz);
	return 0;
}

RTMP *
rtmpdial(char *url)
{
	char *s, *e, *path, *app;
	int f, port, ctl;
	RTMP *r;

	fmtinstall('A', amf0fmt);
	fmtinstall('T', msgtypefmt);
	fmtinstall('M', msgfmt);
	quotefmtinstall();

	r = nil;
	f = -1;
	url = estrdup(url); /* since we're changing it in-place */
	if(memcmp(url, "rtmp://", 7) != 0){
		werrstr("invalid url");
		goto err;
	}
	s = url + 7;
	if((e = strpbrk(s, ":/")) == nil){
		werrstr("no path");
		goto err;
	}
	port = 1935;
	if(*e == ':'){
		if((port = strtol(e+1, &path, 10)) < 1 || path == e+1 || *path != '/'){
			werrstr("invalid port");
			goto err;
		}
	}else{
		path = e;
	}
	while(*(++path) == '/');

	s = smprint("tcp!%.*s!%d", (int)(e-s), s, port);
	f = dial(s, nil, nil, &ctl);
	free(s);
	if(f < 0)
		goto err;

	app = path;
	if((s = strchr(path, '/')) == nil){
		werrstr("no path");
		goto err;
	}
	if((e = strchr(s+1, '/')) != nil){
		/* at this point it can be app instance if there is another slash following */
		if((s = strchr(e+1, '/')) == nil){
			/* no, just path leftovers */
			s = e;
		}
		*s = 0;
		path = s+1;
	}else{
		path = nil;
	}

	if(handshake(f) != 0)
		goto err;

	r = ecalloc(1, sizeof(*r));
	r->i = f;
	r->chunkin = ChunkDefault;
	r->chunkout = ChunkDefault;
	r->tcurl = url;
	url = nil;
	r->c = chancreate(sizeof(void*), 0);
	r->app = estrdup(app);
	r->path = path == nil ? nil : estrdup(path);
	bextend(r, Bufsz);
	Binits(r, f, OWRITE, r->biobuf, sizeof(r->biobuf));

	if(connect(r) != 0)
		goto err;
	if(proccreate(loop, r, mainstacksize) < 0)
		goto err;

	/* wait for the connect call to finish */
	if((s = recvp(r->c)) != nil){
		rtmpclose(r);
		werrstr("rtmpdial: %s", s);
		free(s);
		r = nil;
	}

	return r;

err:
	werrstr("rtmpdial: %r");
	if(r != nil)
		rtmpfree(r);
	if(f >= 0)
		close(f);
	free(url);
	return nil;
}

/*
 * rtmpclose closes the connection's file descriptor and r->c.  Once the
 * read fails, the receiving proc (if running) frees r; see loop.
 * A nil r is ignored.
 */
void
rtmpclose(RTMP *r)
{
	if(r != nil){
		if(r->i >= 0)
			close(r->i);
		if(r->c != nil)
			chanclose(r->c);
	}
}