shithub: opus

Download patch

ref: d5b6087f4862e8001beaa39ccdb1a0562e9d3473
parent: b24e53fdfaf6a81585e45488c9c8b7cf7e909db3
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Tue Oct 12 22:53:03 EDT 2021

Add optional TensorBoard logging (--logdir) and remove the CSV logger

--- a/dnn/training_tf2/train_lpcnet.py
+++ b/dnn/training_tf2/train_lpcnet.py
@@ -52,6 +52,7 @@
 parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
 parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
 parser.add_argument('--gamma', metavar='<gamma>', type=float, help='adjust u-law compensation (default 2.0, should not be less than 1.0)')
+parser.add_argument('--logdir', metavar='<log dir>', help='directory for tensorboard log files')
 
 
 args = parser.parse_args()
@@ -185,6 +186,13 @@
     grub_sparsify = lpcnet.SparsifyGRUB(2000, 40000, 400, args.grua_size, grub_density)
 
 model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
-csv_logger = CSVLogger('training_vals.log')
+
 loader = LPCNetLoader(data, features, periods, batch_size, lpc_out=flag_e2e)
-model.fit(loader, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])
+
+callbacks = [checkpoint, sparsify, grub_sparsify]
+if args.logdir is not None:
+    logdir = '{}/{}_{}_logs'.format(args.logdir, args.output, args.grua_size)
+    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
+    callbacks.append(tensorboard_callback)
+
+model.fit(loader, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)
--