shithub: aubio

ref: ada5bafca49cca061545a85ec525fe180cbce1df
dir: /python/bench-onset/

View raw version
#! /usr/bin/python

from aubio.bench.node import *
from aubio.tasks import *

class benchonset(bench):
	""" Onset detection benchmark over a directory of annotated files.

	Per-file counts produced by self.task are accumulated in file_exec and
	turned into scores in dir_eval.  As the comments in auto_learn2 state,
	orig counts the annotated onsets and expc the detected ones.
	self.params and self.task (read by file_exec) and self.titles (printed
	by the run methods) must be set by the caller beforehand.
	"""

	@staticmethod
	def _percent(num, den):
		""" Return 100*num/den as a float, or 0.0 when den is zero. """
		if not den:
			return 0.
		return 100. * num / den

	def dir_eval(self):
		""" Compute directory-level scores and the row to be printed.

		Reads the accumulated counters (orig, expc, missed, merged, bad,
		doubled) and sets self.P (precision), self.R (recall), self.F
		(F-measure), all in percent, plus self.values, one row matching
		the column titles.  Undefined ratios (zero denominator) are 0.
		"""
		# correct detections, counted on the annotated side (orig); this is
		# the same quantity as the 'corrt' column below
		correct = self.orig - self.missed - self.merged
		hits = max(correct, 0)
		# precision: correct / (correct + false positives (bad + doubled))
		self.P = self._percent(hits, hits + self.bad + self.doubled)
		# recall: correct / (correct + missed + merged), i.e. the GD column
		self.R = self._percent(hits, hits + self.missed + self.merged)
		# F-measure is the harmonic mean of P and R; 0 when both are 0
		if self.P + self.R:
			self.F = 2 * self.P * self.R / (self.P + self.R)
		else:
			self.F = 0.

		self.values = [self.params.onsetmode,
			"%2.3f" % self.params.threshold,
			self.orig,
			self.expc,
			self.missed,
			self.merged,
			self.bad,
			self.doubled,
			correct,
			"%2.3f" % self._percent(correct, self.orig),
			"%2.3f" % self._percent(self.bad + self.doubled, self.orig),
			"%2.3f" % self._percent(self.orig - self.missed, self.orig),
			"%2.3f" % self._percent(self.bad, self.orig),
			"%2.3f" % self.P,
			"%2.3f" % self.R,
			"%2.3f" % self.F]

	def file_exec(self,input,output):
		""" Process one file and add its counts to the running totals.

		input  -- passed to self.task together with self.params
		output -- not used here

		NOTE(review): nothing in this class resets the counters; they are
		presumably zeroed by bench.dir_exec() before each pass -- confirm.
		"""
		filetask = self.task(input,params=self.params)
		# eval()'s return value is not used; the counters read below are
		# presumably filled in by it -- confirm in aubio.tasks
		filetask.eval(filetask.compute_all())
		self.orig    += filetask.orig
		self.missed  += filetask.missed
		self.merged  += filetask.merged
		self.expc    += filetask.expc
		self.bad     += filetask.bad
		self.doubled += filetask.doubled

	def _eval_at(self, threshold):
		""" Evaluate the directory at `threshold`, print the row, return F. """
		self.params.threshold = threshold
		self.dir_exec()
		self.dir_eval()
		self.pretty_print(self.values)
		return self.F

	def run_bench(self,modes=None,thresholds=None):
		""" Evaluate every (mode, threshold) pair, printing one row each.

		modes      -- onset modes to try (default ['dual'])
		thresholds -- thresholds to try for each mode (default [0.5])
		"""
		if modes is None: modes = ['dual']
		if thresholds is None: thresholds = [0.5]
		self.modes = modes
		self.thresholds = thresholds

		self.pretty_print(self.titles)
		for mode in self.modes:
			self.params.onsetmode = mode
			for threshold in self.thresholds:
				self._eval_at(threshold)

	def auto_learn(self,modes=None,thresholds=None,steps=10):
		""" Simple dichotomy-like search for the threshold maximising F.

		For each mode, F is evaluated at thresholds[1] then thresholds[0],
		then up to `steps` times at the midpoint of the current pair; the
		midpoint replaces whichever bound it beats.  Every evaluation is
		printed as one row.

		modes      -- onset modes to try (default ['dual'])
		thresholds -- initial [low, high] thresholds (default [0.1, 1.5])
		steps      -- maximum number of midpoint evaluations per mode
		"""
		if modes is None: modes = ['dual']
		if thresholds is None: thresholds = [0.1,1.5]
		self.modes = modes
		self.pretty_print(self.titles)
		for mode in self.modes:
			lesst, topt = thresholds[0], thresholds[1]
			self.params.onsetmode = mode

			topF = self._eval_at(topt)
			lessF = self._eval_at(lesst)

			for i in range(steps):
				midF = self._eval_at((lesst + topt) * .5)
				if midF == 100.0 or midF == topF:
					print("assuming we converged, stopping")
					break
				if topF < midF:
					topF, topt = midF, self.params.threshold
				elif lessF < midF:
					lessF, lesst = midF, self.params.threshold
				else:
					# neither bound improved: the bracket is unchanged, so
					# the next midpoint would be this same threshold again
					print("no improvement, stopping")
					break
				if topt == lesst:
					lesst /= 2.

	def auto_learn2(self,modes=None,thresholds=None,steps=10):
		""" Search for a threshold where detected == annotated onset count.

		For each mode, start at thresholds[0]; lower the threshold when
		fewer onsets are detected than annotated, raise it when more are
		detected, halving the step (initially thresholds[1]) after each
		move.  Stops after `steps` re-evaluations, or earlier once the
		counts match or F reaches 100.  Every evaluation is printed.

		modes      -- onset modes to try (default ['dual'])
		thresholds -- [start threshold, initial step] (default [0.1, 1.0])
		steps      -- maximum number of re-evaluations per mode
		"""
		if modes is None: modes = ['dual']
		if thresholds is None: thresholds = [0.1,1.0]
		self.modes = modes
		self.pretty_print(self.titles)
		for mode in self.modes:
			step = thresholds[1]
			self.params.onsetmode = mode
			self._eval_at(thresholds[0])

			for i in range(steps):
				if self.expc < self.orig:
					# fewer onsets detected than annotated
					self.params.threshold -= step
					step /= 2.
				elif self.expc > self.orig:
					# more onsets detected than annotated
					self.params.threshold += step
					step /= 2.
				self._eval_at(self.params.threshold)
				# a perfect F-measure or matching counts ends the search
				if self.F == 100.0 or self.orig == self.expc:
					print("assuming we converged, stopping")
					break

if __name__ == "__main__":
	# usage: bench-onset <datapath> [respath]
	import sys
	if len(sys.argv) < 2:
		# sys.exit with a string prints it to stderr and exits with status 1
		sys.exit("ERR: a path is required")
	datapath = sys.argv[1]
	# optional second argument overrides the default results directory
	if len(sys.argv) > 2: respath = sys.argv[2]
	else: respath = '/var/tmp/DB-testings'

	modes = ['complex', 'energy', 'phase', 'specdiff', 'kl', 'mkl', 'dual']
	# thresholds swept by the run_bench alternative below
	thresholds = [ 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]

	# named so that it does not shadow the benchonset class
	onsetbench = benchonset(datapath,respath,checkres=True,checkanno=True)
	onsetbench.params = taskparams()
	onsetbench.task = taskonset

	onsetbench.titles = [ 'mode', 'thres', 'orig', 'expc', 'missd', 'mergd',
	'bad', 'doubl', 'corrt', 'GD', 'FP', 'GD-merged', 'FP-pruned',
	'prec', 'recl', 'dist' ]
	onsetbench.formats = ["%12s" , "| %6s", "| %6s", "| %6s", "| %6s", "| %6s",
	"| %6s", "| %6s", "| %6s", "| %8s", "| %8s", "| %8s", "| %8s",
	"| %6s", "| %6s", "| %6s"]

	try:
		onsetbench.auto_learn2(modes=modes)
		# alternative: full grid sweep over modes x thresholds
		#onsetbench.run_bench(modes=modes,thresholds=thresholds)
	except KeyboardInterrupt:
		sys.exit(1)