shithub: opus

ref: 35ee397e060283d30c098ae5e17836316bbec08b
dir: /dnn/torch/lpcnet/engine/lpcnet_engine.py/

View raw version
import torch
from tqdm import tqdm
import sys

def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, log_interval=10):
    """Train `model` for one pass over `dataloader` and return the mean batch loss.

    For every batch the gradients are cleared, the initial GRU states are reset
    to zero, a forward and backward pass is run, and then the optimizer, the
    learning-rate scheduler and the model's sparsifier are each invoked once.

    Args:
        model: network to train. Must expose `gru_a_units`, `gru_b_units` and
            `sparsify()`, and be callable as
            `model(features, periods, signals, gru_states)`.
        criterion: loss callable, invoked as
            `criterion(output.permute(0, 2, 1), target)`.
        optimizer: optimizer; `zero_grad()` and `step()` are called per batch.
        dataloader: yields dict batches containing the keys 'features',
            'periods', 'signals' and 'target'; every value is moved to
            `device`. Must provide `batch_size` and support `len()`.
        device: device the model, the GRU states and the batches are moved to.
        scheduler: learning-rate scheduler; `step()` is called once per batch.
        log_interval: the progress bar is refreshed whenever the batch index is
            a multiple of this value.

    Returns:
        The sum of all per-batch losses divided by `len(dataloader)`.

    Raises:
        ZeroDivisionError: if `len(dataloader)` is 0.
    """

    model.to(device)
    model.train()

    running_loss = 0

    # loss total and batch count as of the last progress-bar update; used to
    # report the mean loss over exactly the batches seen since that update
    logged_loss = 0
    logged_batches = 0

    # gru states: allocated once, re-zeroed in place before every batch
    # NOTE(review): sized from dataloader.batch_size, so presumably every batch
    # has to contain exactly that many items (e.g. drop_last=True) -- confirm.
    gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device)
    gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device)
    gru_states = [gru_a_state, gru_b_state]

    with tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

        for i, batch in enumerate(tepoch):

            # set gradients to zero
            optimizer.zero_grad()

            # zero out initial gru states
            gru_a_state.zero_()
            gru_b_state.zero_()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['features'], batch['periods'], batch['signals'], gru_states)

            # calculate loss
            loss = criterion(output.permute(0, 2, 1), target)

            # calculate gradients
            loss.backward()

            # update weights
            optimizer.step()

            # update learning rate
            scheduler.step()

            # call sparsifier
            model.sparsify()

            # update running loss
            running_loss += loss.item()

            # update status bar
            if i % log_interval == 0:
                batches_seen = i + 1
                # Divide by the number of batches actually accumulated since the
                # previous update: the very first update (i == 0) covers a
                # single batch, not log_interval batches.
                recent_batches = batches_seen - logged_batches
                tepoch.set_postfix(
                    running_loss=f"{running_loss / batches_seen:8.7f}",
                    current_loss=f"{(running_loss - logged_loss) / recent_batches:8.7f}",
                )
                logged_loss = running_loss
                logged_batches = batches_seen

    running_loss /= len(dataloader)

    return running_loss

def evaluate(model, criterion, dataloader, device, log_interval=10):
    """Run `model` over `dataloader` without gradients and return the mean batch loss.

    The model is switched to eval mode and every batch is processed under
    `torch.no_grad()`, with the initial GRU states reset to zero beforehand.

    Args:
        model: network to evaluate. Must expose `gru_a_units` and
            `gru_b_units`, and be callable as
            `model(features, periods, signals, gru_states)`.
        criterion: loss callable, invoked as
            `criterion(output.permute(0, 2, 1), target)`.
        dataloader: yields dict batches containing the keys 'features',
            'periods', 'signals' and 'target'; every value is moved to
            `device`. Must provide `batch_size` and support `len()`.
        device: device the model, the GRU states and the batches are moved to.
        log_interval: the progress bar is refreshed whenever the batch index is
            a multiple of this value.

    Returns:
        The sum of all per-batch losses divided by `len(dataloader)`.

    Raises:
        ZeroDivisionError: if `len(dataloader)` is 0.
    """

    model.to(device)
    model.eval()

    running_loss = 0

    # loss total and batch count as of the last progress-bar update; used to
    # report the mean loss over exactly the batches seen since that update
    logged_loss = 0
    logged_batches = 0

    # gru states: allocated once, re-zeroed in place before every batch
    # NOTE(review): sized from dataloader.batch_size, so presumably every batch
    # has to contain exactly that many items (e.g. drop_last=True) -- confirm.
    gru_a_state = torch.zeros(1, dataloader.batch_size, model.gru_a_units, device=device)
    gru_b_state = torch.zeros(1, dataloader.batch_size, model.gru_b_units, device=device)
    gru_states = [gru_a_state, gru_b_state]

    with torch.no_grad(), tqdm(dataloader, unit='batch', file=sys.stdout) as tepoch:

        for i, batch in enumerate(tepoch):

            # zero out initial gru states
            gru_a_state.zero_()
            gru_b_state.zero_()

            # push batch to device
            for key in batch:
                batch[key] = batch[key].to(device)

            target = batch['target']

            # calculate model output
            output = model(batch['features'], batch['periods'], batch['signals'], gru_states)

            # calculate loss
            loss = criterion(output.permute(0, 2, 1), target)

            # update running loss
            running_loss += loss.item()

            # update status bar
            if i % log_interval == 0:
                batches_seen = i + 1
                # Divide by the number of batches actually accumulated since the
                # previous update: the very first update (i == 0) covers a
                # single batch, not log_interval batches.
                recent_batches = batches_seen - logged_batches
                tepoch.set_postfix(
                    running_loss=f"{running_loss / batches_seen:8.7f}",
                    current_loss=f"{(running_loss - logged_loss) / recent_batches:8.7f}",
                )
                logged_loss = running_loss
                logged_batches = batches_seen

    running_loss /= len(dataloader)

    return running_loss