ref: 0459a572f592fb07376c480c1ebbf04c16090211
parent: ce28695844c12f43c31b4ee739749883c8b44b17
author: Jan Buethe <jbuethe@amazon.de>
date: Fri Sep 29 11:34:59 EDT 2023
updated PitchDNN export script
--- a/dnn/torch/neural-pitch/export_neuralpitch_weights.py
+++ b/dnn/torch/neural-pitch/export_neuralpitch_weights.py
@@ -44,7 +44,7 @@
import torch
import numpy as np
-from models import large_if_ccode
+from models import PitchDNN
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector
@@ -52,39 +52,51 @@
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
- enc_writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='nnpitch')
- enc_writer.header.write(
+ writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='PitchDNN')
+ writer.header.write(
f"""
#include "opus_types.h"
"""
)
-
- # encoder
- encoder_dense_layers = [
- ('initial' , 'initial', 'TANH'),
- ('upsample' , 'upsample', 'TANH')
+ layers = [
+ ('if_upsample.0', "dense_if_upsampler_1"),
+ ('if_upsample.2', "dense_if_upsampler_2"),
+ ('conv.1', "conv2d_1"),
+ ('conv.4', "conv2d_2"),
+ ('conv.7', "conv2d_3"),
+ ('downsample.0', "dense_downsampler"),
+ ("upsample.0", "dense_final_upsampler")
]
- for name, export_name, _ in encoder_dense_layers:
+
+ for name, export_name in layers:
layer = model.get_submodule(name)
- dump_torch_weights(enc_writer, layer, name=export_name, verbose=True)
+ dump_torch_weights(writer, layer, name=export_name, verbose=True)
- encoder_gru_layers = [
- ('gru' , 'gru', 'TANH'),
+ gru_layers = [
+ ("GRU", "gru_1"),
]
- enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
- for name, export_name, _ in encoder_gru_layers])
+ max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
+ for name, export_name in gru_layers])
- del enc_writer
+ writer.header.write(
+f"""
+#define PITCH_DNN_MAX_RNN_UNITS {max_rnn_units}
+"""
+ )
+
+ writer.close()
+
+
if __name__ == "__main__":
os.makedirs(args.output_dir, exist_ok=True)
- model = large_if_ccode()
- checkpoint = torch.load(args.checkpoint ,map_location='cpu')
+ model = PitchDNN()
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
c_export(args, model)
--
⑨