shithub: opus

ref: 627aa7f5b3688ba787c69e55e199ba82e2013be0
parent: 7d328f5bfaa321d823ff4d11b62d5357c99e0693
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Thu Dec 21 10:34:33 EST 2023

Packet loss generation model

--- /dev/null
+++ b/dnn/torch/lossgen/lossgen.py
@@ -1,0 +1,32 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+class LossGen(nn.Module):
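+    """Generative model of packet loss: two stacked GRUs predict the
+    logit of the probability that the next packet is lost, given the
+    previous loss flag and the conditioning loss percentage."""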
+    def __init__(self, gru1_size=16, gru2_size=16):
+        super(LossGen, self).__init__()
+
+        self.gru1_size = gru1_size
+        self.gru2_size = gru2_size
+        self.gru1 = nn.GRU(2, self.gru1_size, batch_first=True)
+        self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
+        self.dense_out = nn.Linear(self.gru2_size, 1)
+
+    def forward(self, loss, perc, states=None):
+        device = loss.device
+        batch_size = loss.size(0)
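+        # start from zero GRU states when none are passed in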
+        if states is None:
+            gru1_state = torch.zeros((1, batch_size, self.gru1_size), device=device)
+            gru2_state = torch.zeros((1, batch_size, self.gru2_size), device=device)
+        else:
+            gru1_state = states[0]
+            gru2_state = states[1]
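+        # input features: previous loss flag and conditioning loss percentage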
+        x = torch.cat([loss, perc], dim=-1)
+        gru1_out, gru1_state = self.gru1(x, gru1_state)
+        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
+        return self.dense_out(gru2_out), [gru1_state, gru2_state]
--- /dev/null
+++ b/dnn/torch/lossgen/process_data.sh
@@ -1,0 +1,20 @@
+#!/bin/sh
+
+# directory containing the loss trace files (one 0/1 loss flag per line)
+datadir=$1
+
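+# compute the average loss rate of each trace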
+for i in "$datadir"/*_is_lost.txt
+do
+	perc=$(awk '{a+=$1} END {print a/NR}' "$i")
+	echo "$perc $i"
+done > percentage_list.txt
+
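+# sort the traces by increasing loss rate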
+sort -n percentage_list.txt | awk '{print $2}' > percentage_sorted.txt
+
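+# concatenate the sorted traces into a single training file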
+for i in $(cat percentage_sorted.txt)
+do
+	cat "$i"
+done > loss_sorted.txt
--- /dev/null
+++ b/dnn/torch/lossgen/test_lossgen.py
@@ -1,0 +1,49 @@
+import lossgen
+import os
+import argparse
+import torch
+import numpy as np
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('model', type=str, help='lossgen model checkpoint')
+parser.add_argument('percentage', type=float, help='percentage loss')
+parser.add_argument('output', type=str, help='path to output file (ascii)')
+
+parser.add_argument('--length', type=int, help="length of sequence to generate", default=500)
+
+args = parser.parse_args()
+
+
+
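+# rebuild the model from the constructor arguments stored in the checkpoint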
+checkpoint = torch.load(args.model, map_location='cpu')
+
+model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+
+
+model.load_state_dict(checkpoint['state_dict'], strict=False)
+
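+# initial conditions: no previous loss, constant target loss percentage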
+states = None
+last = torch.zeros((1,1,1))
+perc = torch.tensor((args.percentage,))[None,None,:]
+seq = torch.zeros((0,1,1))
+
+one = torch.ones((1,1,1))
+zero = torch.zeros((1,1,1))
+
+if __name__ == '__main__':
+    for i in range(args.length):
+        prob, states = model(last, perc, states=states)
+        prob = torch.sigmoid(prob)
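+        # detach the states so the autograd graph does not grow across steps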
+        states[0] = states[0].detach()
+        states[1] = states[1].detach()
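+        # sample the next loss flag from the predicted probability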
+        loss = one if np.random.rand() < prob else zero
+        last = loss
+        seq = torch.cat([seq, loss])
+
+    np.savetxt(args.output, seq[:,:,0].numpy().astype('int'), fmt='%d')
--- /dev/null
+++ b/dnn/torch/lossgen/train_lossgen.py
@@ -1,0 +1,100 @@
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+import tqdm
+from scipy.signal import lfilter
+import os
+import lossgen
+
+class LossDataset(torch.utils.data.Dataset):
+    def __init__(self,
+                loss_file,
+                sequence_length=997):
+
+        self.sequence_length = sequence_length
+
+        self.loss = np.loadtxt(loss_file, dtype='float32')
+
+        self.nb_sequences = self.loss.shape[0]//self.sequence_length
+        self.loss = self.loss[:self.nb_sequences*self.sequence_length]
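+        # running loss percentage: exponential moving average of the loss flags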
+        self.perc = lfilter(np.array([.001], dtype='float32'), np.array([1., -.999], dtype='float32'), self.loss)
+
+        self.loss = np.reshape(self.loss, (self.nb_sequences, self.sequence_length, 1))
+        self.perc = np.reshape(self.perc, (self.nb_sequences, self.sequence_length, 1))
+
+    def __len__(self):
+        return self.nb_sequences
+
+    def __getitem__(self, index):
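+        # add per-sequence (r0) and per-sample (r1) noise to the conditioning percentage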
+        r0 = np.random.normal(scale=.02, size=(1,1)).astype('float32')
+        r1 = np.random.normal(scale=.02, size=(self.sequence_length,1)).astype('float32')
+        return [self.loss[index, :, :], self.perc[index, :, :]+r0+r1]
+
+
+adam_betas = [0.8, 0.99]
+adam_eps = 1e-8
+batch_size = 512
+lr_decay = 0.0001
+lr = 0.001
+epsilon = 1e-5
+epochs = 20
+checkpoint_dir = 'checkpoint'
+os.makedirs(checkpoint_dir, exist_ok=True)
+checkpoint = dict()
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+checkpoint['model_args']    = ()
+checkpoint['model_kwargs']  = {'gru1_size': 16, 'gru2_size': 48}
+model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
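+# loss_sorted.txt is produced by process_data.sh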
+dataset = LossDataset('loss_sorted.txt')
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
+
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps)
+
+
+# learning rate scheduler
+scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x))
+
+
+if __name__ == '__main__':
+    model.to(device)
+
+    for epoch in range(1, epochs + 1):
+
+        running_loss = 0
+
+        print(f"training epoch {epoch}...")
+        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
+            for i, (loss, perc) in enumerate(tepoch):
+                optimizer.zero_grad()
+                loss = loss.to(device)
+                perc = perc.to(device)
+
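+                # predict the next loss flag from the current one (teacher forcing)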
+                out, _ = model(loss, perc)
+                out = torch.sigmoid(out[:,:-1,:])
+                target = loss[:,1:,:]
+
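+                # binary cross-entropy, with epsilon to guard against log(0)
+                # (note: 'loss' is reused here to hold the training loss)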
+                loss = torch.mean(-target*torch.log(out+epsilon) - (1-target)*torch.log(1-out+epsilon))
+
+                loss.backward()
+                optimizer.step()
+
+                scheduler.step()
+
+                running_loss += loss.detach().cpu().item()
+                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}")
+
+        # save checkpoint
+        checkpoint_path = os.path.join(checkpoint_dir, f'lossgen_{epoch}.pth')
+        checkpoint['state_dict'] = model.state_dict()
+        checkpoint['loss'] = running_loss / len(dataloader)
+        checkpoint['epoch'] = epoch
+        torch.save(checkpoint, checkpoint_path)
--