shithub: aubio

ref: 50e99cc1567a617ea03de844a3928637e38a8e39
dir: /python/bench-onset/

View raw version
#! /usr/bin/python

from aubio.bench.config import *
from aubio.bench.node import *

class onset_parameters:
	"""Settings for one onset-detection benchmark run.

	Every setting is a plain instance attribute, so callers may overwrite
	any of them after construction (the benchmark loops reassign `mode`
	and `threshold`).
	"""

	def __init__(self):
		"""Fill in the default value of every setting."""
		# Passed to getonsets as its silence argument (presumably a gate in
		# dB -- confirm against aubio.tasks).
		self.silence = -70
		# Flags forwarded unchanged to getonsets.
		self.derivate = self.localmin = False
		# Analysis window size, hop size and sampling rate.
		self.bufsize, self.hopsize, self.samplerate = 512, 256, 44100
		# Matching tolerance handed to onset_roc / onset_diffs (presumably
		# seconds -- TODO confirm).
		self.tol = 0.05
		# Duration of one hop in seconds; detections are multiplied by it.
		# NOTE(review): computed once here, so it goes stale if hopsize or
		# samplerate are changed afterwards.
		self.step = self.hopsize / float(self.samplerate)
		# Detection threshold and onset function used by default.
		self.threshold, self.mode = 0.1, 'dual'

class benchonset(bench):
	"""Onset-detection benchmark over a directory of annotated .wav files.

	Relies on the `bench` base class (from the aubio.bench star imports)
	for `datadir`, `resdir` and `pretty_print`.  Before running, callers
	must set `params` (an onset_parameters instance) and `titles`.

	Each .wav file is analysed with getonsets.  The detections are compared
	with the annotations in the sibling .txt file, and the per-file counts
	are summed over the directory.  The sums are reported as precision (P),
	recall (R) and F-measure (F), all in percent.
	"""

	@staticmethod
	def _percent(num, den):
		"""Return 100*num/den as a float, or 0.0 when `den` is zero.

		Empty data (no annotated onsets, no detections) would otherwise raise
		ZeroDivisionError and abort the whole benchmark.
		"""
		if den == 0:
			return 0.
		return 100 * float(num) / den

	def compute_results(self):
		"""Derive P, R, F and the printable result row from the counters.

		Reads the accumulated counters orig, expc, missed, merged, bad and
		doubled.  Sets self.P, self.R, self.F (percentages) and self.values,
		the row printed under self.titles.
		"""
		# Numerator shared by P and R: expected onsets neither missed nor merged.
		correct = self.expc - self.missed - self.merged
		self.P = self._percent(correct, correct + self.bad + self.doubled)
		# The recall denominator correct + missed + merged is simply expc.
		self.R = self._percent(correct, self.expc)
		if self.R < 0: self.R = 0
		# Harmonic mean of P and R; defined as 0 when both are 0.
		denom = self.P + self.R
		self.F = 2 * self.P * self.R / denom if denom else 0.

		good = self.orig - self.missed - self.merged
		pct = self._percent
		self.values = [self.params.mode,
			"%2.3f" % self.params.threshold,
			self.orig,
			self.expc,
			self.missed,
			self.merged,
			self.bad,
			self.doubled,
			good,                                                   # corrt
			"%2.3f" % pct(good, self.orig),                         # GD
			"%2.3f" % pct(self.bad + self.doubled, self.orig),      # FP
			"%2.3f" % pct(self.orig - self.missed, self.orig),      # GD-merged
			"%2.3f" % pct(self.bad, self.orig),                     # FP-pruned
			"%2.3f" % self.P,
			"%2.3f" % self.R,
			"%2.3f" % self.F]

	def compute_onset(self, input, output):
		"""Run onset detection on one sound file and score it.

		input -- path of a .wav file; its annotations are read from the file
		         with the same name and a .txt extension.
		output -- unused here (presumably the result path supplied by
		          act_on_data).

		In 'roc' mode the counts returned by onset_roc are added to the
		running totals.  In 'localisation' mode the mean of onset_diffs is
		printed instead.  Either way, compute_results is then called.
		"""
		import os
		from aubio.tasks import getonsets, get_onset_mode
		from aubio.onsetcompare import onset_roc, onset_diffs
		from aubio.txtfile import read_datafile
		# Analysis mode: 'roc' (accumulate counts) or 'localisation'.
		amode = 'roc'
		# Set to 'verbose' to log each run.
		vmode = ''
		lres, _ofunc = getonsets(input,
			self.params.threshold,
			self.params.silence,
			mode=get_onset_mode(self.params.mode),
			localmin=self.params.localmin,
			derivate=self.params.derivate,
			bufsize=self.params.bufsize,
			hopsize=self.params.hopsize,
			storefunc=False)

		# Scale detections by seconds-per-hop (assumes getonsets returns hop
		# indices -- TODO confirm).
		lres = [onset * self.params.step for onset in lres]
		# Swap only the extension: str.replace would also rewrite any '.wav'
		# that appears in a directory name.
		ltru = read_datafile(os.path.splitext(input)[0] + '.txt', depth=0)
		if vmode == 'verbose':
			print("Running with mode %s and threshold %f on file %s"
				% (self.params.mode, self.params.threshold, input))
		if amode == 'localisation':
			l = onset_diffs(ltru, lres, self.params.tol)
			# float() keeps the mean exact under Python 2 integer division.
			if len(l): print("%.3f" % (float(sum(l)) / len(l)))
			else: print("?0")
		elif amode == 'roc':
			orig, missed, merged, expc, bad, doubled = onset_roc(ltru, lres, self.params.tol)
			self.orig    += orig
			self.missed  += missed
			self.merged  += merged
			self.expc    += expc
			self.bad     += bad
			self.doubled += doubled
		self.compute_results()

	def compute_data(self):
		"""Reset the counters and score every file in self.datadir that
		matches the '*.wav' filter (via act_on_data)."""
		self.orig = self.missed = self.merged = 0
		self.expc = self.bad = self.doubled = 0
		act_on_data(self.compute_onset, self.datadir, self.resdir,
			suffix='', filter="f -name '*.wav'")

	def _evaluate(self, threshold):
		"""Benchmark the current mode at `threshold`, print its row, return F."""
		self.params.threshold = threshold
		self.compute_data()
		# Recompute explicitly so the row reflects the freshly reset counters
		# even when no file was processed.
		self.compute_results()
		self.pretty_print(self.values)
		return self.F

	def run_bench(self, modes=None, thresholds=None):
		"""Benchmark every (mode, threshold) pair, printing one row per pair.

		modes -- onset modes to test (default ['dual']).
		thresholds -- detection thresholds to test (default [0.5]).
		"""
		# None defaults: the lists are stored on self, so a shared mutable
		# default could leak changes between calls.
		self.modes = ['dual'] if modes is None else modes
		self.thresholds = [0.5] if thresholds is None else thresholds

		self.pretty_print(self.titles)
		for mode in self.modes:
			self.params.mode = mode
			for threshold in self.thresholds:
				self._evaluate(threshold)

	def auto_learn(self, modes=None, thresholds=None, steps=10):
		"""Search, for each mode, for the threshold that maximises F.

		modes -- onset modes to tune (default ['dual']).
		thresholds -- [lower, upper] starting bounds (default [0.1, 1.5]).
		steps -- maximum number of midpoint evaluations per mode (default 10).

		Both bounds are evaluated first.  The midpoint is then evaluated and
		replaces whichever bound it beats (the upper bound is checked first).
		Every evaluation prints a result row.  A mode stops early when the
		midpoint reaches F == 100, ties the upper bound's F, or beats neither
		bound.
		"""
		self.modes = ['dual'] if modes is None else modes
		if thresholds is None:
			thresholds = [0.1, 1.5]
		self.pretty_print(self.titles)
		for mode in self.modes:
			self.params.mode = mode
			lesst, topt = thresholds[0], thresholds[1]
			topF = self._evaluate(topt)
			lessF = self._evaluate(lesst)

			for _ in range(steps):
				mid = (lesst + topt) * .5
				midF = self._evaluate(mid)
				if midF == 100.0 or midF == topF:
					print("assuming we converged, stopping")
					break
				if topF < midF:
					topF, topt = midF, mid
				elif lessF < midF:
					lessF, lesst = midF, mid
				else:
					# lesst/topt are unchanged, so every remaining step would
					# re-evaluate this same midpoint (assuming evaluation is
					# deterministic) and only print duplicate rows.
					print("no improvement, stopping")
					break
				if topt == lesst:
					# Bounds collapsed onto each other: widen downwards.
					lesst /= 2.


# Onset detection functions to benchmark; narrow the list (e.g. to
# ['complex']) for a quicker run.
modes = ['complex', 'energy', 'phase', 'specdiff', 'kl', 'mkl', 'dual']
# Fixed thresholds for run_bench (a single value such as [1.5] also works).
# auto_learn, as called below, uses its own default bounds instead.
thresholds = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]

# Annotated sounds to analyse.  To cover the whole database, use
# DATADIR + '/onset/DB/*/' instead of the single sub-directory below.
datapath = "%s%s" % (DATADIR, '/onset/DB/PercussivePhrases/RobertRich')
# Results directory handed to the bench.
respath = '/var/tmp/DB-testings'

# NOTE: this rebinds the name `benchonset` from the class to its instance.
benchonset = benchonset(datapath, respath, checkres=True, checkanno=True)
benchonset.params = onset_parameters()

# Column headings and their printf-style cell formats (presumably consumed
# by pretty_print), one per entry of the result row.
benchonset.titles = ('mode thres orig expc missd mergd bad doubl corrt '
	'GD FP GD-merged FP-pruned prec recl dist').split()
benchonset.formats = ["%12s"] + ["| %6s"] * 8 + ["| %8s"] * 4 + ["| %6s"] * 3

# Learn a threshold per mode; to sweep the fixed list instead call
# benchonset.run_bench(modes=modes, thresholds=thresholds).
benchonset.auto_learn(modes=modes)

#        gatherdata
#act_on_data(my_print,datapath,respath,suffix='.txt',filter='f -name \'*.wav\'')
#        gatherthreshold
#        gathermodes
#        comparediffs
#        gatherdiffs