ref: 0fa7150454740a0c2157d33b7ccf80d217684841
parent: f50058f3e3221d858a021c64296f9efaf329a06e
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Fri Jul 13 10:19:27 EDT 2018
Implement FFTNet too
--- a/dnn/train_wavenet.py
+++ b/dnn/train_wavenet.py
@@ -16,9 +16,9 @@
#set_session(tf.Session(config=config))
nb_epochs = 40
-batch_size = 32
+batch_size = 64
-model = wavenet.new_wavenet_model()
+model = wavenet.new_wavenet_model(fftnet=True)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
model.summary()
@@ -62,7 +62,7 @@
# f.create_dataset('data', data=in_data[:50000, :, :])
# f.create_dataset('feat', data=features[:50000, :, :])
-checkpoint = ModelCheckpoint('wavenet3a_{epoch:02d}.h5')
+checkpoint = ModelCheckpoint('wavenet3c_{epoch:02d}.h5')
#model.load_weights('wavernn1c_01.h5')
model.compile(optimizer=Adam(0.001, amsgrad=True, decay=2e-4), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
--- a/dnn/wavenet.py
+++ b/dnn/wavenet.py
@@ -10,13 +10,13 @@
import sys
from causalconv import CausalConv
-units=128
+units=256
pcm_bits = 8
pcm_levels = 2**pcm_bits
nb_used_features = 38
-def new_wavenet_model():
+def new_wavenet_model(fftnet=False):
pcm = Input(shape=(None, 1))
pitch = Input(shape=(None, 1))
feat = Input(shape=(None, nb_used_features))
@@ -36,8 +36,9 @@
for k in range(10):
res = tmp
tmp = Concatenate()([tmp, rfeat])
- c1 = CausalConv(units, 2, dilation_rate=2**k, activation='tanh')
- c2 = CausalConv(units, 2, dilation_rate=2**k, activation='sigmoid')
+ dilation = 9-k if fftnet else k
+ c1 = CausalConv(units, 2, dilation_rate=2**dilation, activation='tanh')
+ c2 = CausalConv(units, 2, dilation_rate=2**dilation, activation='sigmoid')
tmp = Multiply()([c1(tmp), c2(tmp)])
tmp = Dense(units, activation='relu')(tmp)
if k != 0:
--
⑨