shithub: opus

Download patch

ref: 8af5c6b4a13cb66e0f3dcd465c246d2d2e4128c7
parent: b6095cf22d501cb1950685e46b334b0a2ca7e78b
author: Jan Buethe <jbuethe@amazon.de>
date: Tue Nov 7 06:54:22 EST 2023

Add transposed 1D convolution support to wexchange

--- a/dnn/torch/weight-exchange/wexchange/c_export/__init__.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/__init__.py
@@ -28,4 +28,4 @@
 */
 """
 
-from .common import print_gru_layer, print_dense_layer, print_conv1d_layer, print_conv2d_layer, print_vector
\ No newline at end of file
+from .common import print_gru_layer, print_dense_layer, print_conv1d_layer, print_tconv1d_layer, print_conv2d_layer, print_vector
\ No newline at end of file
--- a/dnn/torch/weight-exchange/wexchange/c_export/common.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py
@@ -361,3 +361,55 @@
     writer.header.write(f"\n#define {name.upper()}_STATE_SIZE {N}\n")
 
     return N
+
+
+def print_tconv1d_layer(writer : CWriter,
+                       name : str,
+                       weight : np.ndarray,
+                       bias : np.ndarray,
+                       stride: int,
+                       scale=1/128,
+                       quantize=False):
+    """Write a transposed 1D convolution as an equivalent linear layer.
+
+    The kernel is flattened into a matrix with in_channels rows and
+    kernel_size * out_channels columns, where column k * out_channels + o
+    holds kernel tap k of output channel o. The bias is repeated once per
+    kernel tap so that it lines up with those columns.
+
+    Args:
+        writer: CWriter handed to print_linear_layer; its header also
+            receives the layer's #define lines.
+        name: layer name; upper-cased to prefix the emitted #defines.
+        weight: array of shape (in_channels, out_channels, kernel_size).
+        bias: array of shape (out_channels,).
+        stride: stride of the layer, written to the header as <NAME>_STRIDE.
+        scale: forwarded unchanged to print_linear_layer.
+        quantize: forwarded unchanged to print_linear_layer.
+
+    Raises:
+        ValueError: if weight is not 3-dimensional or bias does not hold
+            exactly one value per output channel.
+    """
+
+    if weight.ndim != 3:
+        raise ValueError(
+            f"print_tconv1d_layer: expected weight of shape (in_channels, out_channels, kernel_size), got {weight.shape}")
+
+    in_channels, out_channels, kernel_size = weight.shape
+
+    if bias.shape != (out_channels,):
+        raise ValueError(
+            f"print_tconv1d_layer: expected bias of shape ({out_channels},), got {bias.shape}")
+
+    # (in, out, k) -> (k, out, in) -> (k * out, in) -> (in, k * out)
+    linear_weight = weight.transpose(2, 1, 0).reshape(kernel_size * out_channels, in_channels).transpose(1, 0)
+    # one copy of the bias per kernel tap, flattened in (k, out) order to match the columns above
+    linear_bias = np.repeat(bias[np.newaxis, :], kernel_size, 0).flatten()
+
+    print_linear_layer(writer, name, linear_weight, linear_bias, scale=scale, quantize=quantize)
+
+    writer.header.write(f"\n#define {name.upper()}_KERNEL_SIZE {kernel_size}\n")
+    writer.header.write(f"\n#define {name.upper()}_STRIDE {stride}\n")
+    writer.header.write(f"\n#define {name.upper()}_IN_CHANNELS {in_channels}\n")
+    writer.header.write(f"\n#define {name.upper()}_OUT_CHANNELS {out_channels}\n")
--- a/dnn/torch/weight-exchange/wexchange/torch/torch.py
+++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py
@@ -32,7 +32,7 @@
 import torch
 import numpy as np
 
-from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer, print_conv2d_layer
+from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer, print_tconv1d_layer, print_conv2d_layer
 
 def dump_torch_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128):
 
@@ -162,6 +162,61 @@
                 conv.bias.set_(torch.from_numpy(b))
 
 
+def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False):
+    """Export the weights of a transposed 1D convolution.
+
+    Args:
+        where: a CWriter, in which case the layer is written as C code via
+            print_tconv1d_layer, or a directory path that receives
+            'weight_oik.npy' and 'bias.npy' (created if missing).
+        conv: layer to export; weight, bias, out_channels and stride[0] are
+            read from it.
+        name: layer name used for the C export.
+        scale: forwarded to print_tconv1d_layer.
+        quantize: forwarded to print_tconv1d_layer.
+
+    Returns:
+        The result of print_tconv1d_layer for a CWriter target, None otherwise.
+    """
+
+    w = conv.weight.detach().cpu().numpy().copy()
+    if conv.bias is None:
+        # layers without bias are exported with an all-zero bias
+        b = np.zeros(conv.out_channels, dtype=w.dtype)
+    else:
+        b = conv.bias.detach().cpu().numpy().copy()
+
+    if isinstance(where, CWriter):
+        # NOTE(review): only weight, bias and stride[0] reach the C export;
+        # groups, padding and dilation are not checked here - confirm callers
+        # only pass layers for which that is sufficient.
+        return print_tconv1d_layer(where, name, w, b, conv.stride[0], scale=scale, quantize=quantize)
+
+    os.makedirs(where, exist_ok=True)
+    # the file name is shared with load_torch_tconv1d_weights; the array is
+    # stored exactly as the layer holds it, without any axis permutation
+    np.save(os.path.join(where, 'weight_oik.npy'), w)
+    np.save(os.path.join(where, 'bias.npy'), b)
+
+
+def load_torch_tconv1d_weights(where, conv):
+    """Load weights saved by dump_torch_tconv1d_weights into conv in place.
+
+    Args:
+        where: directory containing 'weight_oik.npy' and, for layers with a
+            bias, 'bias.npy'.
+        conv: layer whose weight (and bias, if it has one) is overwritten.
+    """
+
+    with torch.no_grad():
+        w = np.load(os.path.join(where, 'weight_oik.npy'))
+        conv.weight.set_(torch.from_numpy(w))
+        # bias.npy is only read for layers that actually have a bias
+        if conv.bias is not None:
+            b = np.load(os.path.join(where, 'bias.npy'))
+            conv.bias.set_(torch.from_numpy(b))
+
+
 def dump_torch_conv2d_weights(where, conv, name='conv', scale=1/128, quantize=False):
     w = conv.weight.detach().cpu().permute(0, 1, 3, 2).numpy().copy()
     if conv.bias is None:
@@ -228,6 +258,8 @@
         return dump_torch_conv2d_weights(where, module, name, **kwargs)
     elif isinstance(module, torch.nn.Embedding):
         return dump_torch_embedding_weights(where, module)
+    elif isinstance(module, torch.nn.ConvTranspose1d):
+        return dump_torch_tconv1d_weights(where, module, name, **kwargs)
     else:
         raise ValueError(f'dump_torch_weights: layer of type {type(module)} not supported')
 
@@ -243,5 +275,7 @@
         load_torch_conv2d_weights(where, module)
     elif isinstance(module, torch.nn.Embedding):
         load_torch_embedding_weights(where, module)
+    elif isinstance(module, torch.nn.ConvTranspose1d):
+        return load_torch_tconv1d_weights(where, module)
     else:
-        raise ValueError(f'dump_torch_weights: layer of type {type(module)} not supported')
+        raise ValueError(f'load_torch_weights: layer of type {type(module)} not supported')
--