shithub: opus

Download patch

ref: 0fa7150454740a0c2157d33b7ccf80d217684841
parent: f50058f3e3221d858a021c64296f9efaf329a06e
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Fri Jul 13 10:19:27 EDT 2018

Implement FFTNet too

--- a/dnn/train_wavenet.py
+++ b/dnn/train_wavenet.py
@@ -16,9 +16,9 @@
 #set_session(tf.Session(config=config))
 
 nb_epochs = 40
-batch_size = 32
+batch_size = 64
 
-model = wavenet.new_wavenet_model()
+model = wavenet.new_wavenet_model(fftnet=True)
 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
 model.summary()
 
@@ -62,7 +62,7 @@
 # f.create_dataset('data', data=in_data[:50000, :, :])
 # f.create_dataset('feat', data=features[:50000, :, :])
 
-checkpoint = ModelCheckpoint('wavenet3a_{epoch:02d}.h5')
+checkpoint = ModelCheckpoint('wavenet3c_{epoch:02d}.h5')
 
 #model.load_weights('wavernn1c_01.h5')
 model.compile(optimizer=Adam(0.001, amsgrad=True, decay=2e-4), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
--- a/dnn/wavenet.py
+++ b/dnn/wavenet.py
@@ -10,13 +10,13 @@
 import sys
 from causalconv import CausalConv
 
-units=128
+units=256
 pcm_bits = 8
 pcm_levels = 2**pcm_bits
 nb_used_features = 38
 
 
-def new_wavenet_model():
+def new_wavenet_model(fftnet=False):
     pcm = Input(shape=(None, 1))
     pitch = Input(shape=(None, 1))
     feat = Input(shape=(None, nb_used_features))
@@ -36,8 +36,9 @@
     for k in range(10):
         res = tmp
         tmp = Concatenate()([tmp, rfeat])
-        c1 = CausalConv(units, 2, dilation_rate=2**k, activation='tanh')
-        c2 = CausalConv(units, 2, dilation_rate=2**k, activation='sigmoid')
+        dilation = 9-k if fftnet else k
+        c1 = CausalConv(units, 2, dilation_rate=2**dilation, activation='tanh')
+        c2 = CausalConv(units, 2, dilation_rate=2**dilation, activation='sigmoid')
         tmp = Multiply()([c1(tmp), c2(tmp)])
         tmp = Dense(units, activation='relu')(tmp)
         if k != 0:
--