diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/Base.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/Base.m index 498e7b0..9545574 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/Base.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/Base.m @@ -53,8 +53,9 @@ function process( obj, wavFilepath ) %% ------------------------------------------------------------------------------- % override of Core.IdProcInterface's method - function fileProcessed = hasFileAlreadyBeenProcessed( ~, ~ ) + function [fileProcessed,cacheDir] = hasFileAlreadyBeenProcessed( ~, ~ ) fileProcessed = false; + cacheDir = []; end %% ------------------------------------------------------------------------------- @@ -130,11 +131,12 @@ function delete( obj ) %% ------------------------------------------------------------------------------- function processInternal( obj, varargin ) obj.inputProc.sceneId = obj.sceneId; - in = obj.loadInputData( obj.curWavFilepath, 'afeData', 'annotations' ); if nargin < 2 || any( strcmpi( 'afeBlocks', varargin ) ) + in = obj.loadInputData( obj.curWavFilepath, 'afeData', 'annotations' ); [obj.blockAnnotations,obj.afeBlocks] = ... obj.blockify( in.afeData, in.annotations ); else + in = obj.loadInputData( obj.curWavFilepath, 'annotations' ); obj.blockAnnotations = obj.blockify( in.afeData, in.annotations ); end end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/FullFileBlockCreator.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/FullFileBlockCreator.m new file mode 100644 index 0000000..36a391a --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/FullFileBlockCreator.m @@ -0,0 +1,55 @@ +classdef FullFileBlockCreator < BlockCreators.Base + % + %% ----------------------------------------------------------------------------------- + properties (SetAccess = private) + end + + %% ----------------------------------------------------------------------------------- + methods + + function obj = FullFileBlockCreator() + obj = obj@BlockCreators.Base( inf, 0 ); + end + %% ------------------------------------------------------------------------------- + + end + + %% ----------------------------------------------------------------------------------- + methods (Access = protected) + + function outputDeps = getBlockCreatorInternOutputDependencies( obj ) + outputDeps.v = 1; + end + %% ------------------------------------------------------------------------------- + + function [blockAnnots,afeBlocks] = blockify( obj, afeData, annotations ) + currentDependencies = obj.getOutputDependencies(); + sceneConfig = currentDependencies.preceding.preceding.sceneCfg; + annotations = BlockCreators.StandardBlockCreator.extendAnnotations( ... + sceneConfig, annotations ); + anyAFEsignal = afeData(1); + if isa( anyAFEsignal, 'cell' ), anyAFEsignal = anyAFEsignal{1}; end; + streamLen_s = double( size( anyAFEsignal.Data, 1 ) ) / anyAFEsignal.FsHz; + if nargout > 1 + afeBlocks = {afeData}; + end + blockAnnots = annotations; + blockAnnots.blockOnset = 0; + blockAnnots.blockOffset = streamLen_s; + end + %% ------------------------------------------------------------------------------- + + end + %% ----------------------------------------------------------------------------------- + + methods (Static) + + %% ------------------------------------------------------------------------------- + %% ------------------------------------------------------------------------------- + + end + +end + + + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/MeanStandardBlockCreator.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/MeanStandardBlockCreator.m index ccf611e..cd2e836 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/MeanStandardBlockCreator.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/MeanStandardBlockCreator.m @@ -20,7 +20,7 @@ function outputDeps = getBlockCreatorInternOutputDependencies( obj ) outputDeps.sbc = getBlockCreatorInternOutputDependencies@... BlockCreators.StandardBlockCreator( obj ); - outputDeps.v = 1; + outputDeps.v = 3; end %% ------------------------------------------------------------------------------- @@ -32,6 +32,10 @@ blockAnnots = blockify@BlockCreators.StandardBlockCreator( ... obj, afeData, annotations ); end + [blockAnnots(:).nrj] = deal(struct('t',[],'nrj',[])); + [blockAnnots(:).nrjOthers] = deal(struct('t',[],'nrjOthers',[])); + [blockAnnots(:).srcSNRactive] = deal(struct('t',[],'srcSNRactive',[])); + [blockAnnots(:).srcSNR2] = deal(struct('t',[],'srcSNR2',[])); aFields = fieldnames( blockAnnots ); isSequenceAnnotation = cellfun( @(af)(... isstruct( blockAnnots(1).(af) ) && ... @@ -40,30 +44,74 @@ ), aFields ); sequenceAfields = aFields(isSequenceAnnotation); for ii = 1 : numel( blockAnnots ) + blockAnnots(ii) = ... + BlockCreators.MeanStandardBlockCreator.adjustPreMeanAnnotations( ... + blockAnnots(ii) ); for jj = 1 : numel( sequenceAfields ) seqAname = sequenceAfields{jj}; annot = blockAnnots(ii).(seqAname); - if length( annot.t ) == size( annot.(seqAname), 1 ) - if iscell( annot.(seqAname) ) + annotSeq = annot.(seqAname); + if length( annot.t ) == size( annotSeq, 1 ) + if iscell( annotSeq ) + as_szs = cellfun( @(c)( size( c, 2 ) ), annotSeq(1,:) ); blockAnnots(ii).(seqAname) = ... - cellSqueezeFun( @mean, annot.(seqAname), 1, true ); + mat2cell( nanmean( cell2mat( annotSeq ), 1 ), 1, as_szs ); else - blockAnnots(ii).(seqAname) = mean( annot.(seqAname), 1 ); + blockAnnots(ii).(seqAname) = nanmean( annotSeq, 1 ); end else error( 'unexpected annotations sequence structure' ); end end + blockAnnots(ii) = ... + BlockCreators.MeanStandardBlockCreator.extendMeanAnnotations( ... + blockAnnots(ii) ); end end - %% ------------------------------------------------------------------------------- + %% ------------------------------------------------------------------------------- end %% ----------------------------------------------------------------------------------- methods (Static) + %% ------------------------------------------------------------------------------- + + % TODO: this is the wrong place for the annotation computation; it + % should be done in SceneEarSignalProc -- and is now here, for the + % moment, to avoid recomputation with SceneEarSignalProc. + + function avgdBlockAnnots = extendMeanAnnotations( avgdBlockAnnots ) + srcsGlobalRefEnergyMeanChannel = cellfun( ... + @(c)(sum(c) ./ 2 ), avgdBlockAnnots.globalSrcEnergy ); + srcsGlobalRefEnergyMeanChannel_db = 10 * log10( srcsGlobalRefEnergyMeanChannel ); + haveSrcsEnergy = srcsGlobalRefEnergyMeanChannel_db > -40; + isAmbientSource = isnan( avgdBlockAnnots.srcAzms ); + haveSrcsEnergy(isAmbientSource) = []; + avgdBlockAnnots.nActivePointSrcs = single( sum( haveSrcsEnergy ) ); + avgdBlockAnnots.srcSNR2 = 10 * log10( avgdBlockAnnots.nrj ./ avgdBlockAnnots.nrjOthers ); + avgdBlockAnnots.nrj = 10 * log10( avgdBlockAnnots.nrj ); + avgdBlockAnnots.nrjOthers = 10 * log10( avgdBlockAnnots.nrjOthers ); + avgdBlockAnnots.globalSrcEnergy = cellfun( @(c)(10 * log10( c )), ... + avgdBlockAnnots.globalSrcEnergy, 'UniformOutput', false ); + end %% ------------------------------------------------------------------------------- + + % TODO: this is the wrong place for the annotation computation; it + % should be done in SceneEarSignalProc -- and is now here, for the + % moment, to avoid recomputation with SceneEarSignalProc. + + function annotations = adjustPreMeanAnnotations( annotations ) + annotations.srcSNRactive.t = annotations.globalSrcEnergy.t; + annotations.srcSNRactive.srcSNRactive = annotations.srcSNR_db.srcSNR_db; + allSrcsInactive = annotations.nActivePointSrcs.nActivePointSrcs == 0; + annotations.srcSNRactive.srcSNRactive(allSrcsInactive,:) = nan; + annotations.srcSNRactive.srcSNRactive = annotations.srcSNRactive.srcSNRactive; + annotations.nrj.t = annotations.globalSrcEnergy.t; + annotations.nrj.nrj = 10.^(annotations.nrj_db.nrj_db./10); + annotations.nrjOthers.t = annotations.globalSrcEnergy.t; + annotations.nrjOthers.nrjOthers = 10.^(annotations.nrjOthers_db.nrjOthers_db./10); + end %% ------------------------------------------------------------------------------- end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/StandardBlockCreator.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/StandardBlockCreator.m index 455b8f3..e9e40d6 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/StandardBlockCreator.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+BlockCreators/StandardBlockCreator.m @@ -18,11 +18,15 @@ methods (Access = protected) function outputDeps = getBlockCreatorInternOutputDependencies( obj ) - outputDeps.v = 2; + outputDeps.v = 4; end %% ------------------------------------------------------------------------------- function [blockAnnots,afeBlocks] = blockify( obj, afeData, annotations ) + currentDependencies = obj.getOutputDependencies(); + sceneConfig = currentDependencies.preceding.preceding.sceneCfg; + annotations = BlockCreators.StandardBlockCreator.extendAnnotations( ... + sceneConfig, annotations ); anyAFEsignal = afeData(1); if isa( anyAFEsignal, 'cell' ), anyAFEsignal = anyAFEsignal{1}; end; streamLen_s = double( size( anyAFEsignal.Data, 1 ) ) / anyAFEsignal.FsHz; @@ -80,13 +84,60 @@ blockAnnots = flipud( blockAnnots ); end %% ------------------------------------------------------------------------------- - + end %% ----------------------------------------------------------------------------------- methods (Static) %% ------------------------------------------------------------------------------- + + % TODO: this is the wrong place for the annotation computation; it + % should be done in SceneEarSignalProc -- and is now here, for the + % moment, to avoid recomputation with SceneEarSignalProc. + + function annotations = extendAnnotations( sceneConfig, annotations ) + annotations.srcSNR_db.t = annotations.globalSrcEnergy.t; + annotations.srcSNR_db.srcSNR_db = zeros( size( annotations.globalSrcEnergy.globalSrcEnergy ) ); + annotations.nrj_db.t = annotations.globalSrcEnergy.t; + annotations.nrj_db.nrj_db = zeros( size( annotations.globalSrcEnergy.globalSrcEnergy ) ); + annotations.nrjOthers_db.t = annotations.globalSrcEnergy.t; + annotations.nrjOthers_db.nrjOthers_db = zeros( size( annotations.globalSrcEnergy.globalSrcEnergy ) ); + annotations.nActivePointSrcs.t = annotations.globalSrcEnergy.t; + annotations.nActivePointSrcs.nActivePointSrcs = zeros( size( annotations.globalSrcEnergy.globalSrcEnergy ) ); + if std( sceneConfig.snrRefs ) ~= 0 + error( 'AMLTTP:usage:snrRefMustBeSame', 'different snrRefs not supported' ); + end + snrRef = sceneConfig.snrRefs(1); + nSrcs = size( annotations.globalSrcEnergy.globalSrcEnergy, 2 ); + srcsGlobalRefEnergyMeanChannel = zeros( ... + size( annotations.globalSrcEnergy.globalSrcEnergy ) ); + for ss = 1 : nSrcs + srcsGlobalRefEnergyMeanChannel(:,ss) = mean( ... + cell2mat( annotations.globalSrcEnergy.globalSrcEnergy(:,ss) ), 2 ); + end + srcsGlobalRefEnergyMeanChannel_db = 10 * log10( srcsGlobalRefEnergyMeanChannel ); + snrRefNrjOffsets = cell2mat( annotations.globalNrjOffsets.globalNrjOffsets ) ... + - annotations.globalNrjOffsets.globalNrjOffsets{snrRef}; + annotations.globalNrjOffsets = snrRefNrjOffsets; + for ss = 1 : nSrcs + otherIdxs = 1 : size( srcsGlobalRefEnergyMeanChannel, 2 ); + otherIdxs(ss) = []; + srcsCurrentSrcRefEnergy_db = srcsGlobalRefEnergyMeanChannel_db ... + - snrRefNrjOffsets(ss); + srcsCurrentSrcRefEnergy = 10.^(srcsCurrentSrcRefEnergy_db./10); + sumOtherSrcsEnergy = sum( srcsCurrentSrcRefEnergy(:,otherIdxs), 2 ); + sumOthersSrcsEnergy_db = 10 * log10( sumOtherSrcsEnergy ); + annotations.nrjOthers_db.nrjOthers_db(:,ss) = single( sumOthersSrcsEnergy_db ); + annotations.nrj_db.nrj_db(:,ss) = single( srcsCurrentSrcRefEnergy_db(:,ss) ); + annotations.srcSNR_db.srcSNR_db(:,ss) = single( ... + srcsCurrentSrcRefEnergy_db(:,ss) - sumOthersSrcsEnergy_db ); + end + haveSrcsEnergy = srcsGlobalRefEnergyMeanChannel_db > -40; + isAmbientSource = all( isnan( annotations.srcAzms.srcAzms ), 1 ); + haveSrcsEnergy(:,isAmbientSource) = []; + annotations.nActivePointSrcs.nActivePointSrcs = single( sum( haveSrcsEnergy, 2 ) ); + end %% ------------------------------------------------------------------------------- end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/DataPipeProc.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/DataPipeProc.m index a87e13d..5ed0a13 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/DataPipeProc.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/DataPipeProc.m @@ -42,7 +42,7 @@ function connectToOutputFrom( obj, outputtingProc ) end %% ---------------------------------------------------------------- - function checkDataFiles( obj, otherOverlay ) + function cacheDirs = checkDataFiles( obj, otherOverlay ) fprintf( '\nChecking file list: %s\n%s\n', ... obj.dataFileProcessor.procName, ... repmat( '=', 1, 20 + numel( obj.dataFileProcessor.procName ) ) ); @@ -64,8 +64,13 @@ function checkDataFiles( obj, otherOverlay ) obj.dataFileProcessor.getSingleProcessCacheAccess(); DataProcs.MultiSceneCfgsIdProcWrapper.doEarlyHasProcessedStop( true, false ); end - fileHasBeenProcessed = ... - obj.dataFileProcessor.hasFileAlreadyBeenProcessed( dataFile.fileName ); + if nargout > 0 && ~exist( 'cacheDirs', 'var' ) + [fileHasBeenProcessed,cacheDirs] = ... + obj.dataFileProcessor.hasFileAlreadyBeenProcessed( dataFile.fileName ); + else + fileHasBeenProcessed = ... + obj.dataFileProcessor.hasFileAlreadyBeenProcessed( dataFile.fileName ); + end if ii == 1 DataProcs.MultiSceneCfgsIdProcWrapper.doEarlyHasProcessedStop( true, true ); obj.dataFileProcessor.saveCacheDirectory(); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdCacheDirectory.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdCacheDirectory.m index abac0f3..6929315 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdCacheDirectory.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdCacheDirectory.m @@ -148,7 +148,7 @@ function loadCacheDirectory( obj, filename ) end %% ------------------------------------------------------------------------------- - function maintenance( obj ) + function maintenance( obj, deleteEmpties ) cDirs = dir( [obj.topCacheDirectory filesep 'cache.*'] ); cacheDirs = cell( 0, 3 ); fprintf( '-> read cache folders\n' ); @@ -159,7 +159,7 @@ function maintenance( obj ) pause; else cdContents = dir( [obj.topCacheDirectory filesep cDirs(ii).name filesep '*.mat'] ); - if all( strcmpi( 'cfg.mat', {cdContents.name} ) | strcmpi( 'fdesc.mat', {cdContents.name} ) ) + if deleteEmpties && all( strcmpi( 'cfg.mat', {cdContents.name} ) | strcmpi( 'fdesc.mat', {cdContents.name} ) ) rmdir( [obj.topCacheDirectory filesep cDirs(ii).name], 's' ); fprintf( 'deleting empty cache folder ' ); else @@ -310,11 +310,12 @@ function maintenance( obj ) end %% ------------------------------------------------------------------------------- - function standaloneMaintain( cacheTopDir ) + function standaloneMaintain( cacheTopDir, deleteEmpties ) + if nargin < 2, deleteEmpties = true; end cache = Core.IdCacheDirectory(); cache.setCacheTopDir( cacheTopDir ); cache.loadCacheDirectory(); - cache.maintenance(); + cache.maintenance( deleteEmpties ); end %% ------------------------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdProcInterface.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdProcInterface.m index f60560d..8ac9dbb 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdProcInterface.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdProcInterface.m @@ -118,9 +118,17 @@ function init( obj ) end %% ------------------------------------------------------------------------------- - function fileProcessed = hasFileAlreadyBeenProcessed( obj, wavFilepath ) + function [fileProcessed,cacheDir] = hasFileAlreadyBeenProcessed( obj, wavFilepath ) if isempty( wavFilepath ), fileProcessed = false; return; end - fileProcessed = exist( obj.getOutputFilepath( wavFilepath ), 'file' ); + cacheFile = obj.getOutputFilepath( wavFilepath ); + if obj.forceCacheRewrite + fileProcessed = false; + else + fileProcessed = exist( cacheFile, 'file' ); + end + if nargout > 1 + cacheDir = fileparts( cacheFile ); + end end %% ------------------------------------------------------------------------------- @@ -222,6 +230,22 @@ function save( obj, wavFilepath, out ) end + %% -------------------------------------------------------------------- + methods (Static) + + function b = forceCacheRewrite( newValue ) + persistent fcrw; + if isempty( fcrw ) + fcrw = false; + end + if nargin > 0 + fcrw = newValue; + else + b = fcrw; + end + end + + end %% ----------------------------------------------------------------------------------- methods (Abstract) process( obj, wavFilepath ) diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdentTrainPipeData.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdentTrainPipeData.m index f0f9d7d..41245bc 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdentTrainPipeData.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdentTrainPipeData.m @@ -43,7 +43,7 @@ if isa( fileSubScript, 'char' ) if all( fileSubScript == ':' ) fIdx = 1 : length( obj.data ); - elseif strcmp( fileSubScript, 'fileLabel' ) + elseif strcmpi( fileSubScript, 'fileLabel' ) dataElemFieldIdxPos = 3; labels = S.subs{1,2}; if iscell( labels ) && cellfun( @iscell, labels ) @@ -86,9 +86,30 @@ end varargout{1:nargout} = ... vertcat( obj.data(fIdx).(dSubScript)(xyIdx,:,:,:,:) ); - elseif any( strcmp( dSubScript, {'fileName','blockAnnotsCacheFile'} ) ) + elseif any( strcmpi( dSubScript, {'fileName','blockAnnotsCacheFile'} ) ) varargout{1:nargout} = { obj.data(fIdx).(dSubScript) }'; - elseif strcmp( dSubScript, 'pointwiseFileIdxs' ) + elseif any( strcmpi( dSubScript, {'blockAnnotations'} ) ) + isEmpty_bas_bb = arrayfun( @(c)(isempty(c.blockAnnotations)), obj.data(fIdx) ); + for bb = find( isEmpty_bas_bb ) + bas_bb = []; + bacfIdxs = obj.data(bb).bacfIdxs; + bacfs = obj.data(bb).blockAnnotsCacheFile; + for mm = 1 : numel( bacfs ) + bIdxs = obj.data(bb).bIdxs(bacfIdxs==mm); + bacf = load( bacfs{mm}, 'blockAnnotations' ); + bas_ = bacf.blockAnnotations(bIdxs); + bas_ = Core.IdentTrainPipeDataElem.addPPtoBas( ... + bas_, obj.data(bb).y(bacfIdxs==mm) ); + if isempty( bas_bb ) + bas_bb = bas_; + else + bas_bb = vertcat( bas_bb, bas_ ); + end + end + obj.data(bb).blockAnnotations = bas_bb; + end + varargout{1:nargout} = vertcat( obj.data(fIdx).blockAnnotations ); + elseif strcmpi( dSubScript, 'pointwiseFileIdxs' ) out = []; for ff = fIdx out = [out; repmat( ff, size( obj.data(ff).x, 1 ), 1 )]; diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdentTrainPipeDataElem.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdentTrainPipeDataElem.m index 493d467..f9963f3 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdentTrainPipeDataElem.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdentTrainPipeDataElem.m @@ -5,10 +5,12 @@ fileName; x; y; + ysi; % assignment of label to source index bIdxs; bacfIdxs; blockAnnotsCacheFile; fileAnnotations = struct; + blockAnnotations; end %% ----------------------------------------------------------------------------------- @@ -39,7 +41,26 @@ function readFileAnnotations( obj ) obj.fileAnnotations.type = IdEvalFrame.readEventClass( obj.fileName ); end %% ------------------------------------------------------------------------------- + + end + + %% ----------------------------------------------------------------------------------- + methods (Static) + function bas = addPPtoBas( bas, y ) + bons = cat( 1, bas.blockOnset ); + bofs = cat( 1, bas.blockOffset ); + pos_bons = bons(y == +1); + pos_bofs = bofs(y == +1); + ba_pp = zeros( size( bas ) ); + for ii = 1 : sum( y == +1 ) + ba_pp(bons == pos_bons(ii) & bofs == pos_bofs(ii)) = 1; + end + ba_pp = num2cell( ba_pp ); + [bas(:).posPresent] = deal( ba_pp{:} ); + end + %% ------------------------------------------------------------------------------- + end end \ No newline at end of file diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdentificationTrainingPipeline.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdentificationTrainingPipeline.m index 3bfe343..86bf368 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdentificationTrainingPipeline.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+Core/IdentificationTrainingPipeline.m @@ -113,25 +113,51 @@ function splitIntoTrainAndTestSets( obj, trainSetShare ) % nGenAssessFolds: number of folds of generalization assessment through % cross validation (default: 0 - no folds) % - function modelPath = run( obj, varargin ) + function [modelPath, model, testPerfresults] = run( obj, varargin ) ip = inputParser; ip.addOptional( 'nGenAssessFolds', 0 ); ip.addOptional( 'modelPath', ['amlttpRun' buildCurrentTimeString()] ); ip.addOptional( 'modelName', 'amlttp' ); ip.addOptional( 'runOption', [] ); + ip.addOptional( 'startWithProc', 1 ); + ip.addOptional( 'filterPipeInput', [] ); ip.addOptional( 'debug', false ); ip.parse( varargin{:} ); - + cleaner = onCleanup( @() obj.finish() ); modelPath = obj.createFilesDir( ip.Results.modelPath ); + modelFilename = [ip.Results.modelName '.model.mat']; + testPerfresults = []; + model = []; - successiveProcFileFilter = []; - for ii = length( obj.dataPipeProcs ) : -1 : 1 - obj.dataPipeProcs{ii}.checkDataFiles( successiveProcFileFilter ); - successiveProcFileFilter = obj.dataPipeProcs{ii}.fileListOverlay; + successiveProcFileFilter = ip.Results.filterPipeInput; + gcpMode = strcmpi( ip.Results.runOption, 'getCachePathes' ); + rwcMode = strcmpi( ip.Results.runOption, 'rewriteCache' ); + if rwcMode + Core.IdProcInterface.forceCacheRewrite( true ); + else + Core.IdProcInterface.forceCacheRewrite( false ); + end + cacheDirs = cell( numel( obj.dataPipeProcs ), 1 ); + for ii = numel( obj.dataPipeProcs ) : -1 : ip.Results.startWithProc + if ~gcpMode + obj.dataPipeProcs{ii}.checkDataFiles( successiveProcFileFilter ); + else + gcpFileFilter = false( length( obj.data(:) ), 1 ); + gcpFileFilter(1) = true; + cacheDirs{ii} = obj.dataPipeProcs{ii}.checkDataFiles( gcpFileFilter ); + end + if ~gcpMode && ~rwcMode + successiveProcFileFilter = obj.dataPipeProcs{ii}.fileListOverlay; + end + end + if gcpMode + save( modelFilename, ... + 'cacheDirs' ); + return; end errs = {}; - for ii = 1 : length( obj.dataPipeProcs ) + for ii = ip.Results.startWithProc : numel( obj.dataPipeProcs ) if ~ip.Results.debug try obj.dataPipeProcs{ii}.run(); @@ -155,6 +181,7 @@ function splitIntoTrainAndTestSets( obj, trainSetShare ) end if strcmp(ip.Results.runOption, 'onlyGenCache'), return; end; + if rwcMode, return; end; featureCreator = obj.featureCreator; lastDataProcParams = ... @@ -174,6 +201,12 @@ function splitIntoTrainAndTestSets( obj, trainSetShare ) save( 'dataStoreUni.mat', ... 'x', 'y', 'featureNames', '-v7.3' ); return; + elseif strcmp( ip.Results.runOption, 'dataStoreGT' ) + bIdxs = obj.data(:,'bIdxs'); + y = obj.data(:,'y'); + save( 'dataStoreGT.mat', ... + 'bIdxs', 'y', '-v7.3' ); + return; end; fprintf( ['\n\n===================================\n',... @@ -194,11 +227,10 @@ function splitIntoTrainAndTestSets( obj, trainSetShare ) obj.trainer.run(); trainTime = toc; testTime = nan; - testPerfresults = []; if ~isempty( obj.testSet ) fprintf( '\n== Testing model on testSet... \n\n' ); tic; - testPerfresults = obj.trainer.getPerformance( 'datapointInfo' ); + testPerfresults = obj.trainer.getPerformance( true ); testTime = toc; if numel( testPerfresults ) == 1 fprintf( ['\n\n===================================\n',... @@ -213,7 +245,6 @@ function splitIntoTrainAndTestSets( obj, trainSetShare ) end end model = obj.trainer.getModel(); - modelFilename = [ip.Results.modelName '.model.mat']; save( modelFilename, ... 'model', 'featureCreator', 'blockCreator', ... 'testPerfresults', 'trainTime', 'testTime', 'lastDataProcParams' ); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/BlackboardKsWrapper_AnnotationWriter.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/BlackboardKsWrapper_AnnotationWriter.m new file mode 100644 index 0000000..b66c401 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/BlackboardKsWrapper_AnnotationWriter.m @@ -0,0 +1,89 @@ +classdef BlackboardKsWrapper_AnnotationWriter < Core.IdProcInterface + % Abstract base class for wrapping KS into an emulated blackboard + %% ----------------------------------------------------------------------------------- + properties (SetAccess = protected) + kss; + bbs; + afeDataIndexOffset; + out; + end + + %% ----------------------------------------------------------------------------------- + methods (Abstract) + postproc( obj, afeData, blockAnnotations ) + outputDeps = getKsInternOutputDependencies( obj ) + end + + %% ----------------------------------------------------------------------------------- + methods + + function obj = BlackboardKsWrapper_AnnotationWriter( kss ) + obj = obj@Core.IdProcInterface(); + if ~iscell( kss ) + obj.kss = {kss}; + else + obj.kss = kss; + end + obj.bbs = BlackboardSystem( false ); + for ii = 1 : numel( kss ) + obj.kss{ii}.setBlackboardAccess( obj.bbs.blackboard, obj.bbs ); + end + end + %% ------------------------------------------------------------------------------- + + function [afeRequests, ksReqHashes] = getAfeRequests( obj ) + afeRequests = []; + ksReqHashes = []; + for ii = 1 : numel( obj.kss ) + afeRequests = [afeRequests obj.kss{ii}.requests]; + ksReqHashes = [ksReqHashes obj.kss{ii}.reqHashs]; + end + end + %% ------------------------------------------------------------------------------- + + function obj = setAfeDataIndexOffset( obj, afeDataIndexOffset ) + obj.afeDataIndexOffset = afeDataIndexOffset; + end + %% ------------------------------------------------------------------------------- + + function process( obj, wavFilepath ) + warning( 'off', 'BB:tNotIncreasing' ); + obj.inputProc.sceneId = obj.sceneId; + inData = obj.loadInputData( wavFilepath, 'blockAnnotations' ); + selfData = obj.loadProcessedData( wavFilepath, 'afeBlocks', 'blockAnnotations' ); + obj.out = struct( 'afeBlocks', {selfData.afeBlocks}, ... + 'blockAnnotations', {selfData.blockAnnotations} ); + obj.postproc( inData.blockAnnotations ); + end + %% ------------------------------------------------------------------------------- + + end + + %% ----------------------------------------------------------------------------------- + methods (Access = protected) + + function outputDeps = getInternOutputDependencies( obj ) + outputDeps.v = 1; + outputDeps.ksProc = obj.getKsInternOutputDependencies(); + end + %% ------------------------------------------------------------------------------- + + function out = getOutput( obj, varargin ) + out.afeBlocks = obj.out.afeBlocks; + out.blockAnnotations = obj.out.blockAnnotations; + end + %% ------------------------------------------------------------------------------- + + end + %% ----------------------------------------------------------------------------------- + + methods (Static) + + %% ------------------------------------------------------------------------------- + + end + +end + + + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/DnnLocKsWrapper.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/DnnLocKsWrapper.m index d62f0ec..402c84c 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/DnnLocKsWrapper.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/DnnLocKsWrapper.m @@ -84,9 +84,9 @@ function postproc( obj, afeData, blockAnnotations ) function afeData = addLocDecisionData( afeData, locDecisionData ) % assumes location data has been refined by LocalisationDecisionKS locFakeAFEsignal = struct(); - locFakeAFEsignal.Data = locDecisionData.sourcesPosteriors(:)'; + locFakeAFEsignal.Data = locDecisionData.sourcesDistribution(:)'; locFakeAFEsignal.Name = 'DnnLocationDistribution'; - locFakeAFEsignal.azms = locDecisionData.sourceAzimuths(:)'; + locFakeAFEsignal.azms = locDecisionData.azimuths(:)'; afeData(afeData.Count+1) = locFakeAFEsignal; end %% ------------------------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/GatherFeaturesProc.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/GatherFeaturesProc.m index cc7a569..a372892 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/GatherFeaturesProc.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/GatherFeaturesProc.m @@ -3,6 +3,10 @@ %% ----------------------------------------------------------------------------------- properties (SetAccess = private, Transient) sceneCfgDataUseRatio = 1; + sceneCfgPrioDataUseRatio = 1; + dataSelector; + selectPrioClass = []; + loadBlockAnnotations = false; prioClass = []; end @@ -13,39 +17,62 @@ %% ----------------------------------------------------------------------------------- methods (Access = public) - function obj = GatherFeaturesProc() + function obj = GatherFeaturesProc( loadBlockAnnotations ) obj = obj@Core.IdProcInterface(); + if nargin >= 1 + obj.loadBlockAnnotations = loadBlockAnnotations; + end end %% ------------------------------------------------------------------------------- - function setSceneCfgDataUseRatio( obj, sceneCfgDataUseRatio, prioClass ) + function setSceneCfgDataUseRatio( obj, sceneCfgDataUseRatio, dataSelector, ... + sceneCfgPrioDataUseRatio, selectPrioClass ) obj.sceneCfgDataUseRatio = sceneCfgDataUseRatio; - if nargin < 3, prioClass = []; end - obj.prioClass = prioClass; + if nargin < 3, dataSelector = DataSelectors.IgnorantSelector(); end + if nargin < 4, sceneCfgPrioDataUseRatio = 1; end + if nargin < 5, selectPrioClass = []; end + obj.dataSelector = dataSelector; + obj.sceneCfgPrioDataUseRatio = sceneCfgPrioDataUseRatio; + obj.selectPrioClass = selectPrioClass; end %% ------------------------------------------------------------------------------- function process( obj, wavFilepath ) obj.inputProc.sceneId = obj.sceneId; - xy = obj.loadInputData( wavFilepath, 'x', 'y' ); + if obj.loadBlockAnnotations + xy = obj.loadInputData( wavFilepath, 'x', 'y', 'ysi', 'a' ); + xy.blockAnnotations = Core.IdentTrainPipeDataElem.addPPtoBas( xy.a, xy.y ); + xy = rmfield( xy, 'a' ); + sceneCfgDeps = obj.inputProc.getOutputDependencies(); + while ~(isstruct( sceneCfgDeps ) && isfield( sceneCfgDeps, 'sceneCfg' ) ) + sceneCfgDeps = sceneCfgDeps.preceding; + end + npssc = numel( sceneCfgDeps.sceneCfg.sources ); + [xy.blockAnnotations(:).nPointSrcsSceneConfig] = deal( npssc ); + else + xy = obj.loadInputData( wavFilepath, 'x', 'y', 'ysi' ); + end obj.inputProc.inputProc.sceneId = obj.sceneId; inDataFilepath = obj.inputProc.inputProc.getOutputFilepath( wavFilepath ); dataFile = obj.idData(wavFilepath); fprintf( '.' ); - if obj.sceneCfgDataUseRatio < 1 && ... - ~strcmp( obj.prioClass, dataFile.getFileAnnotation( 'type' ) ) - nUsePoints = round( size( xy.x, 1 ) * obj.sceneCfgDataUseRatio ); - useIdxs = randperm( size( xy.x, 1 ) ); - useIdxs(nUsePoints+1:end) = []; + if ~isempty( obj.selectPrioClass ) && any( xy.y == obj.selectPrioClass ) + nUsePoints = round( size( xy.x, 1 ) * obj.sceneCfgPrioDataUseRatio ); else - useIdxs = 1 : size( xy.x, 1 ); + nUsePoints = round( size( xy.x, 1 ) * obj.sceneCfgDataUseRatio ); end + obj.dataSelector.connectData( xy ); + useIdxs = obj.dataSelector.getDataSelection( 1:size( xy.x, 1 ), nUsePoints ); dataFile.x = [dataFile.x; xy.x(useIdxs,:)]; dataFile.y = [dataFile.y; xy.y(useIdxs,:)]; + dataFile.ysi = [dataFile.ysi; xy.ysi(useIdxs)']; dataFile.bIdxs = [dataFile.bIdxs; xy.bIdxs(useIdxs)']; dataFile.bacfIdxs = [dataFile.bacfIdxs; ... - repmat( numel(dataFile.blockAnnotsCacheFile ) + 1, numel(useIdxs), 1 )]; + repmat( numel(dataFile.blockAnnotsCacheFile ) + 1, sum(useIdxs), 1 )]; dataFile.blockAnnotsCacheFile = [dataFile.blockAnnotsCacheFile; {inDataFilepath}]; + if obj.loadBlockAnnotations + dataFile.blockAnnotations = [dataFile.blockAnnotations; xy.blockAnnotations(useIdxs)]; + end fprintf( '.' ); end %% ------------------------------------------------------------------------------- @@ -64,8 +91,9 @@ function process( obj, wavFilepath ) %% ------------------------------------------------------------------------------- % override of Core.IdProcInterface's method - function fileProcessed = hasFileAlreadyBeenProcessed( ~, ~ ) + function [fileProcessed,cacheDir] = hasFileAlreadyBeenProcessed( ~, ~ ) fileProcessed = false; + cacheDir = []; end %% ------------------------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/IdSimConvRoomWrapper.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/IdSimConvRoomWrapper.m index 676e213..15c58fe 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/IdSimConvRoomWrapper.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/IdSimConvRoomWrapper.m @@ -9,7 +9,6 @@ earSout; annotsOut; srcAzimuth; - brirSrcPos; outFs; end @@ -195,9 +194,6 @@ function setupSceneConfig( obj, sceneConfig ) obj.IRDataset.dir = ... simulator.DirectionalIR( sceneConfig.sources(1).brirFName ); warning( 'on', 'all' ); - obj.brirSrcPos = SOFAconvertCoordinates( ... - brirSofa.EmitterPosition(1,:) - brirSofa.ListenerPosition, ... - 'cartesian','spherical' ); else warning( 'off', 'all' ); % avoid messy "SOFA experimental" warning obj.IRDataset.dir = simulator.DirectionalIR( ... @@ -205,16 +201,13 @@ function setupSceneConfig( obj, sceneConfig ) sceneConfig.sources(1).speakerId ); warning( 'on', 'all' ); obj.IRDataset.speakerId = sceneConfig.sources(1).speakerId; - obj.brirSrcPos = SOFAconvertCoordinates( ... - brirSofa.EmitterPosition(sceneConfig.sources(1).speakerId,:) ... - - brirSofa.ListenerPosition, 'cartesian','spherical' ); end obj.IRDataset.isbrir = true; obj.IRDataset.fname = sceneConfig.sources(1).brirFName; end obj.convRoomSim.Sources{1}.IRDataset = obj.IRDataset.dir; obj.convRoomSim.rotateHead( headOrientation(1), 'absolute' ); - obj.srcAzimuth = obj.brirSrcPos(1) - headOrientation(1); + obj.srcAzimuth = sceneConfig.sources(1).azimuth; else % ~is diffuse obj.convRoomSim.Sources{1} = simulator.source.Binaural(); channelMapping = [1 2]; diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/MultiSceneCfgsIdProcWrapper.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/MultiSceneCfgsIdProcWrapper.m index 911e77f..5848dbd 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/MultiSceneCfgsIdProcWrapper.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/MultiSceneCfgsIdProcWrapper.m @@ -44,12 +44,19 @@ function setSceneConfig( obj, multiSceneCfgs ) %% ---------------------------------------------------------------- % override of Core.IdProcInterface's method - function fileProcessed = hasFileAlreadyBeenProcessed( obj, wavFilepath ) + function [fileProcessed,cacheDirs] = hasFileAlreadyBeenProcessed( obj, wavFilepath ) fileProcessed = true; + if nargout > 1 + cacheDirs = cell( numel( obj.sceneConfigurations ), 1 ); + end for ii = 1 : numel( obj.sceneConfigurations ) obj.sceneProc.setSceneConfig( obj.sceneConfigurations(ii) ); obj.wrappedProcs{1}.sceneId = ii; - processed = obj.wrappedProcs{1}.hasFileAlreadyBeenProcessed( wavFilepath ); + if nargout > 1 + [processed,cacheDirs{ii}] = obj.wrappedProcs{1}.hasFileAlreadyBeenProcessed( wavFilepath ); + else + processed = obj.wrappedProcs{1}.hasFileAlreadyBeenProcessed( wavFilepath ); + end fileProcessed = fileProcessed && processed; % not stopping early because hasFileAlreadyBeenProcessed triggers cache % directory creation diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/ParallelRequestsAFEmodule.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/ParallelRequestsAFEmodule.m index 49e608b..44b2718 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/ParallelRequestsAFEmodule.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/ParallelRequestsAFEmodule.m @@ -85,8 +85,12 @@ function process( obj, wavFilepath ) [tmpOut, outFilepath] = ... loadProcessedData@Core.IdProcInterface( obj, wavFilepath, 'indivFiles' ); obj.indivFiles = tmpOut.indivFiles; + if nargin == 3 && strcmpi( varargin{1}, 'indivFiles' ) + out = tmpOut; + return; + end try - out = obj.getOutput; + out = obj.getOutput( varargin{:} ); catch err if strcmp( 'AMLTTP:dataprocs:cacheFileCorrupt', err.identifier ) error( 'AMLTTP:dataprocs:cacheFileCorrupt', ... @@ -97,6 +101,17 @@ function process( obj, wavFilepath ) end end %% ------------------------------------------------------------------------------- + + % override of Core.IdProcInterface's method + function [fileProcessed,cacheDirs] = hasFileAlreadyBeenProcessed( obj, wavFilepath ) + fileProcessed = ... + hasFileAlreadyBeenProcessed@Core.IdProcInterface( obj, wavFilepath ); + if nargout > 1 + obj.loadProcessedData( wavFilepath, 'indivFiles' ); + cacheDirs = cellfun( @fileparts, obj.indivFiles, 'UniformOutput', false ); + end + end + %% ------------------------------------------------------------------------------- % override of DataProcs.IdProcInterface's method function save( obj, wavFilepath, ~ ) @@ -130,7 +145,12 @@ function save( obj, wavFilepath, ~ ) % override of Core.IdProcInterface's method function out = getOutput( obj, varargin ) out.afeData = containers.Map( 'KeyType', 'int32', 'ValueType', 'any' ); - [~,ia,ic] = unique( obj.indivFiles ); + if nargin < 2 || any( strcmpi( 'afeData', varargin ) ) + [~,ia,ic] = unique( obj.indivFiles ); + else + ia = 1; + ic = 1; + end for ii = ia' if ~exist( obj.indivFiles{ii}, 'file' ) error( 'AMLTTP:dataprocs:cacheFileCorrupt', '%s not found.', obj.indivFiles{ii} ); @@ -138,9 +158,11 @@ function save( obj, wavFilepath, ~ ) tmp = load( obj.indivFiles{ii}, 'afeData', 'annotations' ); out.afeData(ii) = tmp.afeData(1); end - for ii = 1 : numel( obj.indivFiles ) - if any( ii == ia ), continue; end - out.afeData(ii) = out.afeData(ia(ic(ii))); + if nargin < 2 || any( strcmpi( 'afeData', varargin ) ) + for ii = 1 : numel( obj.indivFiles ) + if any( ii == ia ), continue; end + out.afeData(ii) = out.afeData(ia(ic(ii))); + end end out.annotations = tmp.annotations; % if individual AFE modules produced % individual annotations, they would have diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/SceneEarSignalProc.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/SceneEarSignalProc.m index bf0d1c6..5697da2 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/SceneEarSignalProc.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/SceneEarSignalProc.m @@ -166,26 +166,22 @@ function process( obj, pipeWavFilepath ) splitAzms{ii}, obj.annotsOut.srcAzms.t, 'next', 'extrap' ) ); end - obj.annotsOut.srcEnergy = struct( 't', {[]} ); - for ss = 1 : numSrcs - [energy1, tEnergy] = DataProcs.SceneEarSignalProc.runningEnergy( ... - obj.getDataFs(), ... - double(splitEarSignals{ss}(:,1)), ... - 20e-3, 10e-3 ); - [energy2, ~] = DataProcs.SceneEarSignalProc.runningEnergy( ... - obj.getDataFs(), ... - double(splitEarSignals{ss}(:,2)), ... - 20e-3, 10e-3 ); - if numel( tEnergy ) > numel( obj.annotsOut.srcEnergy.t ) - obj.annotsOut.srcEnergy.t = single( tEnergy ); - end - obj.annotsOut.srcEnergy.srcEnergy(:,ss) = ... - arrayfun( @(e1,e2)( {single( [e1,e2] )} ), energy1', energy2' ); - end - +% obj.annotsOut.srcEnergy = struct( 't', {[]} ); +% obj.annotsOut.srcEnergy_db = struct( 't', {[]} ); +% for ss = 1 : numSrcs +% obj.annotsOut.srcEnergy = obj.annotateNrj( splitEarSignals{ss}, ... +% obj.annotsOut.srcEnergy, ss,... +% 'srcEnergy', false ); +% obj.annotsOut.srcEnergy_db = obj.annotateNrj( splitEarSignals{ss}, ... +% obj.annotsOut.srcEnergy_db, ss,... +% 'srcEnergy_db' ); +% end +% obj.earSout = zeros( mixLen, 2 ); + obj.annotsOut.globalSrcEnergy = struct( 't', {[]} ); + obj.annotsOut.globalSrcEnergy_db = struct( 't', {[]} ); + q = {[]}; for srcIdx = 1 : numel( splitEarSignals ) - srcNsignal = splitEarSignals{srcIdx}; srcSidx = obj.sceneConfig.snrRefs(srcIdx); if srcSidx == srcIdx srcNsignal = splitEarSignals{srcSidx}; @@ -195,14 +191,23 @@ function process( obj, pipeWavFilepath ) obj.getDataFs(), ... srcSsignal, ... 'energy', ... - srcNsignal, ... + splitEarSignals{srcIdx}, ... obj.sceneConfig.SNRs(srcIdx).value ); end maxSignalsLen = min( mixLen, length( srcNsignal ) ); obj.earSout(1:maxSignalsLen,:) = ... obj.earSout(1:maxSignalsLen,:) + srcNsignal(1:maxSignalsLen,:); + [obj.annotsOut.globalSrcEnergy,q{srcIdx}] = obj.annotateNrj( ... + srcNsignal(1:maxSignalsLen,:), ... + obj.annotsOut.globalSrcEnergy, ... + srcIdx,'globalSrcEnergy', false, q{1} ); + obj.annotsOut.globalSrcEnergy_db = obj.annotateNrj( ... + srcNsignal(1:maxSignalsLen,:), ... + obj.annotsOut.globalSrcEnergy_db, ... + srcIdx,'globalSrcEnergy_db', true, q{1} ); fprintf( '.' ); end + obj.annotsOut.globalNrjOffsets.globalNrjOffsets = q; if obj.sceneConfig.normalize earSoutRMS = max( rms( obj.earSout ) ); obj.earSout = obj.earSout * obj.sceneConfig.normalizeLevel / earSoutRMS; @@ -220,7 +225,33 @@ function process( obj, pipeWavFilepath ) obj.annotsOut.mixEnergy.t = single( tEnergy ); obj.annotsOut.mixEnergy.mixEnergy = single( [energy1',energy2'] ); end + %% ------------------------------------------------------------------------------- + function [nrjAnnots,qself] = annotateNrj( obj, signal, nrjAnnots, signalId, annotsName, returnDb, q ) + [energy1,tEnergy,q1] = DataProcs.SceneEarSignalProc.runningEnergy( ... + obj.getDataFs(), ... + double(signal(:,1)), ... + 20e-3, 10e-3 ); + [energy2,~,q2] = DataProcs.SceneEarSignalProc.runningEnergy( ... + obj.getDataFs(), ... + double(signal(:,2)), ... + 20e-3, 10e-3 ); + qself = 0.5*q1+0.5*q2; + if nargin < 7 || isempty( q ), q = qself; end + energy1 = energy1 + (q1 - q); + energy2 = energy2 + (q2 - q); + if nargin >= 6 && ~returnDb + energy1 = 10.^(energy1./10); + energy2 = 10.^(energy2./10); + end + if numel( tEnergy ) > numel( nrjAnnots.t ) + nrjAnnots.t = single( tEnergy ); + end + nrjAnnots.(annotsName)(:,signalId) = ... + arrayfun( @(e1,e2)( {single( [e1,e2] )} ), energy1', energy2' ); + end + %% ------------------------------------------------------------------------------- + end %% -------------------------------------------------------------------- @@ -300,7 +331,7 @@ function process( obj, pipeWavFilepath ) end %% ---------------------------------------------------------------- - function [energy, tFramesSec] = runningEnergy( fs, signal, blockSec, stepSec ) + function [energy, tFramesSec, q] = runningEnergy( fs, signal, blockSec, stepSec ) blockSize = 2 * round(fs * blockSec / 2); stepSize = round(fs * stepSec); frames = frameData(signal,blockSize,stepSize,'rectwin'); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/SegmentKsWrapper.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/SegmentKsWrapper.m index b7f9137..1e0a060 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/SegmentKsWrapper.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/SegmentKsWrapper.m @@ -2,8 +2,9 @@ % Wrapping the SegmentationKS %% ----------------------------------------------------------------------------------- properties (SetAccess = public) - varAzmPrior; - currentVarAzms; + varAzmSigma; + azmsGroundTruth; + segSrcAssignmentMethod; dnnHash; nfHash; useDnnLocKs = false; @@ -12,6 +13,12 @@ dnnLocKs; nsrcsKs; idKss; + energeticBaidxs; + nsrcsBias; + nsrcsRndPlusMinusBias; + isNsrcsFixed; + isAzmFixedUniform; + softMaskExponent = 10; end %% ----------------------------------------------------------------------------------- @@ -27,6 +34,11 @@ ip.addOptional( 'useDnnLocKs', false ); ip.addOptional( 'useNsrcsKs', false ); ip.addOptional( 'nsrcsParams', {} ); + ip.addOptional( 'segSrcAssignmentMethod', 'minDistance' ); + ip.addOptional( 'varAzmSigma', 0 ); + ip.addOptional( 'nsrcsBias', 0 ); + ip.addOptional( 'nsrcsRndPlusMinusBias', 0 ); + ip.addOptional( 'softMaskExponent', 10 ); ip.parse( varargin{:} ); segmentKs = StreamSegregationKS( paramFilepath ); fprintf( '.' ); @@ -72,45 +84,114 @@ end wrappedKss{end+1} = segmentKs; obj = obj@DataProcs.BlackboardKsWrapper( wrappedKss ); - obj.varAzmPrior = 0; + obj.varAzmSigma = ip.Results.varAzmSigma; + obj.azmsGroundTruth = []; + obj.segSrcAssignmentMethod = ip.Results.segSrcAssignmentMethod; obj.dnnHash = dnnHash; obj.nfHash = nfHash; obj.useDnnLocKs = ip.Results.useDnnLocKs; obj.useNsrcsKs = ip.Results.useNsrcsKs; + if obj.useNsrcsKs && ~obj.useDnnLocKs + error( 'AMLTTP:usage:unsupportedOptionSetting', ... + ['nSrcs model employment only supported if also using ' ... + 'location model.'] ); + end obj.segmentKs = segmentKs; obj.dnnLocKs = dnnLocKs; obj.idKss = idKss; obj.nsrcsKs = nsrcsKs; + obj.energeticBaidxs = []; + obj.isNsrcsFixed = false; + obj.isAzmFixedUniform = false; + obj.nsrcsBias = ip.Results.nsrcsBias; + if obj.useNsrcsKs && (obj.nsrcsBias ~= 0) + error( 'AMLTTP:usage:unsupportedOptionSetting', ... + ['nSrcs bias only supported if using ' ... + 'nSrcs ground truth.'] ); + end + if ischar( obj.nsrcsBias ) + if strfind( obj.nsrcsBias, 'fixed' ) == 1 + obj.isNsrcsFixed = true; + obj.nsrcsBias = str2double( obj.nsrcsBias(6:end) ); + else + error( 'AMLTTP:usage:unsupportedOptionSetting', ... + ['unrecognized nSrcs bias flag.'] ); + end + end + obj.nsrcsRndPlusMinusBias = ip.Results.nsrcsRndPlusMinusBias; + if obj.useNsrcsKs && (obj.nsrcsRndPlusMinusBias ~= 0) + error( 'AMLTTP:usage:unsupportedOptionSetting', ... + ['nSrcs random bias only supported if using ' ... + 'nSrcs ground truth.'] ); + end + if ischar( obj.varAzmSigma ) + if strfind( obj.varAzmSigma, 'fixedUniform' ) == 1 + obj.isAzmFixedUniform = true; + else + error( 'AMLTTP:usage:unsupportedOptionSetting', ... + ['unrecognized azm bias flag.'] ); + end + end + obj.softMaskExponent = ip.Results.softMaskExponent; fprintf( '.\n' ); end %% ------------------------------------------------------------------------------- function procBlock = preproc( obj, blockAnnotations ) procBlock = true; - absAzms = blockAnnotations.srcAzms; - if isstruct( absAzms ) || size( absAzms, 1 ) > 1 + obj.azmsGroundTruth = blockAnnotations.srcAzms; + if isstruct( obj.azmsGroundTruth ) || size( obj.azmsGroundTruth, 1 ) > 1 error( 'AMLTTP:procBinding:singleValueBlockAnnotationsNeeded', ... 'SegmentKsWrapper can only handle one azm value per source per block.' ); end - absAzms(isnan(absAzms)) = []; - if isempty( absAzms ) - procBlock = false; - return; - end - if ~obj.useDnnLocKs - azmVar = obj.varAzmPrior * (2*rand( size( absAzms ) ) - 1); - obj.currentVarAzms = wrapTo180( absAzms + azmVar ); - obj.segmentKs.setFixedAzimuths( obj.currentVarAzms ); + srcsGlobalRefEnergyMeanChannel = cellfun( ... + @(c)(sum(10.^(c./10)) ./ 2 ), blockAnnotations.globalSrcEnergy ); + srcsGlobalRefEnergyMeanChannel_db = 10 * log10( srcsGlobalRefEnergyMeanChannel ); + srcsHaveEnergy = srcsGlobalRefEnergyMeanChannel_db > -40; + obj.energeticBaidxs = 1 : numel( blockAnnotations.globalSrcEnergy ); + obj.energeticBaidxs(isnan(obj.azmsGroundTruth)) = []; + srcsHaveEnergy(isnan(obj.azmsGroundTruth)) = []; + obj.azmsGroundTruth(isnan(obj.azmsGroundTruth)) = []; + if any( srcsHaveEnergy ) + obj.azmsGroundTruth(~srcsHaveEnergy) = []; + obj.energeticBaidxs(~srcsHaveEnergy) = []; else - obj.currentVarAzms = wrapTo180( absAzms ); - obj.segmentKs.setFixedAzimuths( [] ); - warning( 'off', 'BBS:badBlockTimeRequest' ); + rndIdx = randi( numel( obj.azmsGroundTruth ) ); + obj.azmsGroundTruth = obj.azmsGroundTruth(rndIdx); + obj.energeticBaidxs = obj.energeticBaidxs(rndIdx); end if ~obj.useNsrcsKs - obj.segmentKs.setFixedNoSrcs( numel( absAzms ) ); + rndNbias = randi( obj.nsrcsRndPlusMinusBias*2 + 1 ) ... + - obj.nsrcsRndPlusMinusBias - 1; + if obj.isNsrcsFixed + setNsrcs = max( 1, obj.nsrcsBias + rndNbias ); + else + setNsrcs = max( 1, sum( srcsHaveEnergy ) + obj.nsrcsBias + rndNbias ); + end + obj.segmentKs.setFixedNoSrcs( setNsrcs ); else obj.segmentKs.setFixedNoSrcs( [] ); end + if ~obj.useDnnLocKs + if obj.isAzmFixedUniform + azmStep = round( 360 / setNsrcs ); + currentVarAzms = round( azmStep/2 ) : azmStep : 360; + else + azmVar = obj.varAzmSigma * randn( size( obj.azmsGroundTruth ) ); + currentVarAzms = wrapTo180( obj.azmsGroundTruth + azmVar ); + setNsrcsDiff = setNsrcs - numel( currentVarAzms ); + if setNsrcsDiff > 0 + currentVarAzms = [currentVarAzms 360*rand( 1, setNsrcsDiff )]; + elseif setNsrcsDiff < 0 + rndidxs = randperm( numel( currentVarAzms ) ); + currentVarAzms(rndidxs(1:abs(setNsrcsDiff))) = []; + end + end + obj.segmentKs.setFixedAzimuths( wrapTo180( currentVarAzms ) ); + else + obj.segmentKs.setFixedAzimuths( [] ); + warning( 'off', 'BBS:badBlockTimeRequest' ); + end obj.segmentKs.setBlocksize( blockAnnotations.blockOffset ... - blockAnnotations.blockOnset ); end @@ -118,26 +199,44 @@ function postproc( obj, afeData, blockAnnotations ) segHypos = obj.bbs.blackboard.getLastData( 'segmentationHypotheses' ); - nMasks = numel( segHypos.data ); - nTrue = numel( obj.currentVarAzms ); - hypCurAzmDists = zeros( nMasks, nTrue ); - for ii = 1 : nMasks - hypAzm = wrapTo180( segHypos.data(ii).refAzm ); - hypCurAzmDists(ii,:) = ... - abs( wrapTo180( obj.currentVarAzms - hypAzm ) ); - end - [~,estObjMinAzmDistIdx] = min( hypCurAzmDists, [], 1 ); -% [~,trueObjMinAzmDistIdx] = min( hypCurAzmDists, [], 2 ); -% trueObjMinAzmDistIdx = num2cell( trueObjMinAzmDistIdx ); -% for ii = 1 : numel( estObjMinAzmDistIdx ) -% trueObjMinAzmDistIdx{estObjMinAzmDistIdx(ii)}(end+1) = ii; -% end - for ii = 1 : nMasks - obj.out.afeBlocks{end+1,1} = obj.softmaskAFE( afeData, segHypos, ii ); -% baIdxs = unique( trueObjMinAzmDistIdx{ii} ); - baIdxs = find( estObjMinAzmDistIdx == ii ); - maskedBlockAnnotations = obj.maskBA( blockAnnotations, baIdxs ); - maskedBlockAnnotations.estAzm = segHypos.data(ii).refAzm; + nSegments = numel( segHypos.data ); + nTrue = numel( obj.azmsGroundTruth ); + hypAzms = repmat( wrapTo180( [segHypos.data.refAzm]' ), 1, nTrue ); + gtAzms = repmat( wrapTo180( obj.azmsGroundTruth ), nSegments, 1 ); + hypAzmGtDists = abs( wrapTo180( gtAzms - hypAzms ) ); + switch obj.segSrcAssignmentMethod + case 'minPermutedDistance' + segIdxs = []; + while numel( segIdxs ) < nTrue + segIdxs = [segIdxs 1:nSegments]; %#ok + end + distCombinations = nchoosek( segIdxs, nTrue ); + distPermutations = []; + for ii = 1 : size( distCombinations, 1 ) + distPermutations = [distPermutations; ... + perms( distCombinations(ii,:) )]; %#ok + end + distPermutations = unique( distPermutations, 'rows' ); + distances = zeros( size( distPermutations ) ); + for tt = 1 : size( distPermutations, 2 ) + distances(:,tt) = hypAzmGtDists(distPermutations(:,tt),tt); + end + permutedDistances = sum( distances, 2 ); + [~,minPermutedDistanceIdx] = min( permutedDistances ); + segSrcAssignment = distPermutations(minPermutedDistanceIdx,:); + case 'minDistance' + [~,segSrcAssignment] = min( hypAzmGtDists, [], 1 ); + end + for ss = 1 : nSegments + segData = segHypos.data(ss); + softmask = (segData.softMask) .^ obj.softMaskExponent; + obj.out.afeBlocks{end+1,1} = SegmentIdentityKS.maskAFEData( ... + afeData, softmask, segData.cfHz, segData.hopSize ); + srcIdxs = find( segSrcAssignment == ss ); + srcIdxs = obj.energeticBaidxs(srcIdxs); %#ok + maskedBlockAnnotations = obj.maskBA( blockAnnotations, srcIdxs ); + maskedBlockAnnotations.estAzm = segData.refAzm; + maskedBlockAnnotations.nSrcs_estimationError = nSegments - nTrue; if isempty(obj.out.blockAnnotations) obj.out.blockAnnotations = maskedBlockAnnotations; else @@ -149,13 +248,17 @@ function postproc( obj, afeData, blockAnnotations ) %% ------------------------------------------------------------------------------- function outputDeps = getKsInternOutputDependencies( obj ) - outputDeps.v = 7; + outputDeps.v = 16; outputDeps.useDnnLocKs = obj.useDnnLocKs; outputDeps.useNsrcsKs = obj.useNsrcsKs; outputDeps.useIdModels = ~isempty( obj.idKss ); outputDeps.params = obj.kss{end}.observationModel.trainingParameters; [~,outputDeps.afeHashs] = obj.getAfeRequests(); - outputDeps.varAzmPrior = obj.varAzmPrior; + outputDeps.varAzmSigma = obj.varAzmSigma; + outputDeps.segSrcAssignmentMethod = obj.segSrcAssignmentMethod; + outputDeps.nsrcsBias = obj.nsrcsBias; + outputDeps.nsrcsRndPlusMinusBias = obj.nsrcsRndPlusMinusBias; + outputDeps.softMaskExponent = obj.softMaskExponent; end %% ------------------------------------------------------------------------------- @@ -166,15 +269,7 @@ function postproc( obj, afeData, blockAnnotations ) %% ------------------------------------------------------------------------------- - function afeBlock = softmaskAFE( obj, afeBlock, segHypos, idx_mask ) - afeBlock = SegmentIdentityKS.maskAFEData( afeBlock, ... - segHypos.data(idx_mask).softMask, ... - segHypos.data(idx_mask).cfHz, ... - segHypos.data(idx_mask).hopSize ); - end - %% ------------------------------------------------------------------------------- - - function blockAnnotations = maskBA( obj, blockAnnotations, srcIdxs ) + function blockAnnotations = maskBA( ~, blockAnnotations, srcIdxs ) rSrcIdxs = 1:max( srcIdxs ); rSrcIdxs(srcIdxs) = 1:numel(srcIdxs); baFields = fieldnames( blockAnnotations ); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/SegmentKsWrapper_AnnotationWriter.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/SegmentKsWrapper_AnnotationWriter.m new file mode 100644 index 0000000..8b4d412 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataProcs/SegmentKsWrapper_AnnotationWriter.m @@ -0,0 +1,224 @@ +classdef SegmentKsWrapper_AnnotationWriter < DataProcs.BlackboardKsWrapper_AnnotationWriter + % Wrapping the SegmentationKS + %% ----------------------------------------------------------------------------------- + properties (SetAccess = public) + varAzmSigma; + azmsGroundTruth; + segSrcAssignmentMethod; + dnnHash; + nfHash; + useDnnLocKs = false; + useNsrcsKs = false; + segmentKs; + dnnLocKs; + nsrcsKs; + idKss; + energeticBaidxs; + nsrcsBias; + nsrcsRndPlusMinusBias; + isNsrcsFixed; + isAzmFixedUniform; + end + + %% ----------------------------------------------------------------------------------- + methods (Abstract) + end + + %% ----------------------------------------------------------------------------------- + methods + + function obj = SegmentKsWrapper_AnnotationWriter( paramFilepath, varargin ) + fprintf( 'Building SegmentKsWrapper...' ); + ip = inputParser(); + ip.addOptional( 'useDnnLocKs', false ); + ip.addOptional( 'useNsrcsKs', false ); + ip.addOptional( 'nsrcsParams', {} ); + ip.addOptional( 'segSrcAssignmentMethod', 'minDistance' ); + ip.addOptional( 'varAzmSigma', 0 ); + ip.addOptional( 'nsrcsBias', 0 ); + ip.addOptional( 'nsrcsRndPlusMinusBias', 0 ); + ip.parse( varargin{:} ); + segmentKs = StreamSegregationKS( paramFilepath ); + fprintf( '.' ); + wrappedKss = {}; + if ip.Results.useDnnLocKs + dnnLocKs = DnnLocationKS(); + dnnHash = calcDataHash( dnnLocKs.DNNs ); + nfHash = calcDataHash( dnnLocKs.normFactors ); + wrappedKss{end+1} = dnnLocKs; + else + dnnLocKs = []; + dnnHash = []; + nfHash = []; + end + fprintf( '.' ); + idKss = []; + if ip.Results.useDnnLocKs && ip.Results.useNsrcsKs + ipns = inputParser(); + ipns.addOptional( 'modelPath', './nsrcs.model.mat' ); + ipns.addOptional( 'useIdModels', false ); + ipns.addOptional( 'idModelpathes', {} ); + ipns.parse( ip.Results.nsrcsParams{:} ); + if ipns.Results.useIdModels + idKss = {}; + mnames = {}; + for ii = 1 : numel( ipns.Results.idModelpathes ) + [mdir, mname] = fileparts( ipns.Results.idModelpathes{ii} ); + [~, mnames{ii}] = fileparts( mname ); + idKss{ii} = IdentityKS( mnames{ii}, mdir, false ); + fprintf( '.' ); + end + [~,idSort] = sort( mnames ); + idKss = idKss(idSort); + wrappedKss = [wrappedKss idKss]; + end + [mdir, mname] = fileparts( ipns.Results.modelPath ); + [~, mname] = fileparts( mname ); + nsrcsKs = NumberOfSourcesKS( mname, mdir, false, 'useIdModels', ipns.Results.useIdModels ); + fprintf( '.' ); + wrappedKss{end+1} = nsrcsKs; + else + nsrcsKs = []; + end + wrappedKss{end+1} = segmentKs; + obj = obj@DataProcs.BlackboardKsWrapper_AnnotationWriter( wrappedKss ); + obj.varAzmSigma = ip.Results.varAzmSigma; + obj.azmsGroundTruth = []; + obj.segSrcAssignmentMethod = ip.Results.segSrcAssignmentMethod; + obj.dnnHash = dnnHash; + obj.nfHash = nfHash; + obj.useDnnLocKs = ip.Results.useDnnLocKs; + obj.useNsrcsKs = ip.Results.useNsrcsKs; + if obj.useNsrcsKs && ~obj.useDnnLocKs + error( 'AMLTTP:usage:unsupportedOptionSetting', ... + ['nSrcs model employment only supported if also using ' ... + 'location model.'] ); + end + obj.segmentKs = segmentKs; + obj.dnnLocKs = dnnLocKs; + obj.idKss = idKss; + obj.nsrcsKs = nsrcsKs; + obj.energeticBaidxs = []; + obj.isNsrcsFixed = false; + obj.isAzmFixedUniform = false; + obj.nsrcsBias = ip.Results.nsrcsBias; + if obj.useNsrcsKs && (obj.nsrcsBias ~= 0) + error( 'AMLTTP:usage:unsupportedOptionSetting', ... + ['nSrcs bias only supported if using ' ... + 'nSrcs ground truth.'] ); + end + if ischar( obj.nsrcsBias ) + if strfind( obj.nsrcsBias, 'fixed' ) == 1 + obj.isNsrcsFixed = true; + obj.nsrcsBias = str2double( obj.nsrcsBias(6:end) ); + else + error( 'AMLTTP:usage:unsupportedOptionSetting', ... + ['unrecognized nSrcs bias flag.'] ); + end + end + obj.nsrcsRndPlusMinusBias = ip.Results.nsrcsRndPlusMinusBias; + if obj.useNsrcsKs && (obj.nsrcsRndPlusMinusBias ~= 0) + error( 'AMLTTP:usage:unsupportedOptionSetting', ... + ['nSrcs random bias only supported if using ' ... + 'nSrcs ground truth.'] ); + end + if ischar( obj.varAzmSigma ) + if strfind( obj.varAzmSigma, 'fixedUniform' ) == 1 + obj.isAzmFixedUniform = true; + else + error( 'AMLTTP:usage:unsupportedOptionSetting', ... + ['unrecognized azm bias flag.'] ); + end + end + fprintf( '.\n' ); + end + %% ------------------------------------------------------------------------------- + + function postproc( obj, blockAnnotations ) + blockAnnotations = rmfield( blockAnnotations, {'srcType','srcFile','mixEnergy','oneVsAllAvgSnrs','nSrcs_sceneConfig','nActivePointSrcs'} ); + newBAfields = fieldnames( blockAnnotations ); + for ii = 1 : numel( obj.out.blockAnnotations ) + if isempty( obj.out.blockAnnotations(ii).srcAzms ), continue; end + bon = obj.out.blockAnnotations(ii).blockOnset; + bof = obj.out.blockAnnotations(ii).blockOffset; + newBAidx = find( abs( [blockAnnotations.blockOnset] - bon ) <= 0.05 & ... + abs( [blockAnnotations.blockOffset] - bof ) <= 0.05 ); + if numel( newBAidx ) ~= 1 + error( ['newBAidx == ' num2str( newBAidx ) ', bon == ' num2str( bon ) ... + ', bof == ' num2str( bof )] ); + end + bazms = obj.out.blockAnnotations(ii).srcAzms; + newBA_srcIdxs = arrayfun( ... + @(a)( find( a == blockAnnotations(newBAidx).srcAzms ) ), ... + bazms, 'UniformOutput', false ); + newBA_srcIdxs = [newBA_srcIdxs{:}]; + assert( numel( newBA_srcIdxs ) == numel( bazms ) ); + maskedNewBA = obj.maskBA( blockAnnotations(newBAidx), newBA_srcIdxs ); + for jj = 1 : numel( newBAfields ) + obj.out.blockAnnotations(ii).(newBAfields{jj}) = maskedNewBA.(newBAfields{jj}); + end + end + end + %% ------------------------------------------------------------------------------- + + function outputDeps = getKsInternOutputDependencies( obj ) + outputDeps.v = 15; + outputDeps.useDnnLocKs = obj.useDnnLocKs; + outputDeps.useNsrcsKs = obj.useNsrcsKs; + outputDeps.useIdModels = ~isempty( obj.idKss ); + outputDeps.params = obj.kss{end}.observationModel.trainingParameters; + [~,outputDeps.afeHashs] = obj.getAfeRequests(); + outputDeps.varAzmSigma = obj.varAzmSigma; + outputDeps.segSrcAssignmentMethod = obj.segSrcAssignmentMethod; + outputDeps.nsrcsBias = obj.nsrcsBias; + outputDeps.nsrcsRndPlusMinusBias = obj.nsrcsRndPlusMinusBias; + end + %% ------------------------------------------------------------------------------- + + end + + %% ----------------------------------------------------------------------------------- + methods (Access = protected) + + %% ------------------------------------------------------------------------------- + + function blockAnnotations = maskBA( ~, blockAnnotations, srcIdxs ) + rSrcIdxs = 1:max( srcIdxs ); + rSrcIdxs(srcIdxs) = 1:numel(srcIdxs); + baFields = fieldnames( blockAnnotations ); + for ff = 1 : numel( baFields ) + if isstruct( blockAnnotations.(baFields{ff}) ) + baSrcs = blockAnnotations.(baFields{ff}).(baFields{ff})(:,2); + baIsSrcIdEq = cellfun( @(x)( any( x == srcIdxs) ), baSrcs ); + blockAnnotations.(baFields{ff}).t.onset(~baIsSrcIdEq) = []; + blockAnnotations.(baFields{ff}).t.offset(~baIsSrcIdEq) = []; + blockAnnotations.(baFields{ff}).(baFields{ff})(~baIsSrcIdEq,:) = []; + blockAnnotations.(baFields{ff}).(baFields{ff})(:,2) = ... + cellfun( @(x)(rSrcIdxs(x)), ... + blockAnnotations.(baFields{ff}).(baFields{ff})(:,2), ... + 'UniformOutput', false ); + elseif ~strcmpi('mixEnergy',baFields{ff}) && ... + (iscell( blockAnnotations.(baFields{ff}) ) ... + || numel( blockAnnotations.(baFields{ff}) ) > 1) + baIsSrcIdEq = false( size( blockAnnotations.(baFields{ff}) ) ); + baIsSrcIdEq(srcIdxs) = true; + blockAnnotations.(baFields{ff})(~baIsSrcIdEq) = []; + end + end +% blockAnnotations.mixEnergy = blockAnnotations.srcEnergy{1}; + end + %% ------------------------------------------------------------------------------- + + end + %% ----------------------------------------------------------------------------------- + + methods (Static) + + %% ------------------------------------------------------------------------------- + + end + +end + + + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataSelectors/BAC_NPP_NS_Selector.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataSelectors/BAC_NPP_NS_Selector.m new file mode 100644 index 0000000..7d13b43 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataSelectors/BAC_NPP_NS_Selector.m @@ -0,0 +1,68 @@ +classdef BAC_NPP_NS_Selector < DataSelectors.Base + + %% -------------------------------------------------------------------- + properties (SetAccess = protected) + discardNsNotNa = true; + end + + %% -------------------------------------------------------------------- + methods + + function obj = BAC_NPP_NS_Selector( discardNsNotNa ) + obj = obj@DataSelectors.Base(); + if nargin >= 1 + obj.discardNsNotNa = discardNsNotNa; + end + end + % ----------------------------------------------------------------- + + function [selectFilter] = getDataSelection( obj, sampleIdsIn, maxDataSize ) + selectFilter = true( size( sampleIdsIn ) ); + ba = obj.getData( 'blockAnnotations' ); + ba = ba(sampleIdsIn); + ba_ns = cat( 1, ba.nActivePointSrcs ); + if obj.discardNsNotNa + % nPointSrcsSceneConfig is only in blockAnnotations if they + % are loaded through GatherFeaturesProc + ba_ns_scp = cat( 1, ba.nPointSrcsSceneConfig ); + nsNotNa = (ba_ns ~= ba_ns_scp) & ~(ba_ns == 0 & ba_ns_scp == 1); + selectFilter(nsNotNa) = false; + sampleIdsIn(nsNotNa) = []; + ba(nsNotNa) = []; + ba_ns(nsNotNa) = []; + end + ba_pp = cat( 1, ba.posPresent ); + clear ba; + y = obj.getData( 'y' ); + y = y(sampleIdsIn); + y_ = y .* (ba_ns+1) .* (1 + ~ba_pp * 9); + y_Idxs = find( selectFilter ); + shouldNotExistPos = (y_ == 1); % pos although ba_ns==0 + shouldNotExistPos_l = false( size( selectFilter ) ); + shouldNotExistPos_l(y_Idxs) = shouldNotExistPos; + selectFilter = selectFilter & ~shouldNotExistPos_l; + y_(shouldNotExistPos) = []; + [throwoutIdxs,nClassSamples,nPerLabel,labels] = ... + DataSelectors.BAC_Selector.getBalThrowoutIdxs( y_, maxDataSize ); + selectFilter(y_Idxs(throwoutIdxs)) = false; + obj.verboseOutput = sprintf( ['\nOut of a pool of %d samples,\n' ... + 'discard %d where na ~= ns, and\n'], ... + numel( nsNotNa ), sum( nsNotNa ) ); + for ii = 1 : numel( nClassSamples ) + trueLabel = unique( y(y_==labels(ii)) ); + obj.verboseOutput = sprintf( ['%s' ... + 'randomly select %d/%d of class %d (%d)\n'], ... + obj.verboseOutput, ... + nClassSamples(ii), nPerLabel(ii), labels(ii), trueLabel ); + end + end + % ----------------------------------------------------------------- + + end + % --------------------------------------------------------------------- + + methods (Static) + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataSelectors/BAC_Selector.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataSelectors/BAC_Selector.m new file mode 100644 index 0000000..27eb76e --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataSelectors/BAC_Selector.m @@ -0,0 +1,61 @@ +classdef BAC_Selector < DataSelectors.Base + + %% -------------------------------------------------------------------- + properties (SetAccess = protected) + end + + %% -------------------------------------------------------------------- + methods + + function obj = BAC_Selector() + obj = obj@DataSelectors.Base(); + end + % ----------------------------------------------------------------- + + function [selectFilter] = getDataSelection( obj, sampleIdsIn, maxDataSize ) + selectFilter = true( size( sampleIdsIn ) ); + y = obj.getData( 'y' ); + y = y(sampleIdsIn); + [throwoutIdxs,nClassSamples,nPerLabel,labels] = ... + DataSelectors.BAC_Selector.getBalThrowoutIdxs( y, maxDataSize ); + selectFilter(throwoutIdxs) = false; + obj.verboseOutput = sprintf( '\nOut of a pool of %d samples,\n', ... + numel( sampleIdsIn ) ); + for ii = 1 : numel( nClassSamples ) + obj.verboseOutput = sprintf( ['%s' ... + 'randomly select %d/%d of class %d\n'], ... + obj.verboseOutput, ... + nClassSamples(ii), nPerLabel(ii), labels(ii) ); + end + end + % ----------------------------------------------------------------- + + end + % --------------------------------------------------------------------- + + methods (Static) + + function [throwoutIdxs,nClassSamples,nPerLabel,labels] = getBalThrowoutIdxs( y, maxDataSize ) + labels = unique( y ); + nPerLabel = arrayfun( @(l)(sum( l == y )), labels ); + [~, labelOrder] = sort( nPerLabel ); + nLabels = numel( labels ); + nClassSamples = zeros( nLabels, 1 ); + nRemaining = maxDataSize; + throwoutIdxs = []; + for ii = labelOrder' + nKeep = min( int32( nRemaining/nLabels ), nPerLabel(ii) ); + nClassSamples(ii) = nKeep; + nRemaining = nRemaining - nKeep; + nLabels = nLabels - 1; + lIdxs = find( y == labels(ii) ); + lIdxs = lIdxs(randperm(nPerLabel(ii))); + throwoutIdxs = [throwoutIdxs; lIdxs(nKeep+1:end)]; + end + end + % ----------------------------------------------------------------- + + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataSelectors/Base.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataSelectors/Base.m new file mode 100644 index 0000000..7e5b88b --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataSelectors/Base.m @@ -0,0 +1,36 @@ +classdef (Abstract) Base < handle + + %% -------------------------------------------------------------------- + properties (SetAccess = protected) + data; + verboseOutput = ''; + end + + %% -------------------------------------------------------------------- + methods + + function obj = connectData( obj, data ) + obj.data = data; + end + % ----------------------------------------------------------------- + + function d = getData( obj, dataField ) + if isa( obj.data, 'Core.IdentTrainPipeData' ) + d = obj.data(:,dataField); + elseif isstruct( obj.data ) && isfield( obj.data, dataField ) + d = obj.data.(dataField); + else + error( 'AMLTTP:ApiUsage', 'improper usage of DataSelectors API' ); + end + end + % ----------------------------------------------------------------- + + end + + %% -------------------------------------------------------------------- + methods (Abstract) + [selectFilter] = getDataSelection( obj, sampleIdsIn, maxDataSize ) + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+DataSelectors/IgnorantSelector.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataSelectors/IgnorantSelector.m new file mode 100644 index 0000000..e26632f --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+DataSelectors/IgnorantSelector.m @@ -0,0 +1,28 @@ +classdef IgnorantSelector < DataSelectors.Base + + %% -------------------------------------------------------------------- + properties (SetAccess = protected) + end + + %% -------------------------------------------------------------------- + methods + + function obj = IgnorantSelector() + obj = obj@DataSelectors.Base(); + end + % ----------------------------------------------------------------- + + function [selectFilter] = getDataSelection( obj, sampleIdsIn, maxDataSize ) + selectFilter = true( size( sampleIdsIn ) ); + rndIdxs = randperm( numel( sampleIdsIn ) ); + selectFilter(rndIdxs(maxDataSize+1:end)) = false; + obj.verboseOutput = sprintf( ['Out of a pool of %d samples, ' ... + 'randomly select %d...\n'], ... + numel( sampleIdsIn ), maxDataSize ); + end + % ----------------------------------------------------------------- + + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/Base.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/Base.m index 2e32f49..798b861 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/Base.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/Base.m @@ -39,7 +39,7 @@ function process( obj, wavFilepath ) for afeBlock = inData.afeBlocks' obj.afeData = afeBlock{1}; xd = obj.constructVector(); - obj.x(end+1,:) = xd{1}; + obj.x(end+1,:,:) = xd{1}; fprintf( '.' ); if obj.descriptionBuilt, continue; end obj.description = xd{2}; @@ -70,8 +70,8 @@ function process( obj, wavFilepath ) if ~obj.descriptionBuilt if exist( fdescFilepath, 'file' ) fdescFileSema = setfilesemaphore( fdescFilepath, 'semaphoreOldTime', 30 ); - load( fdescFilepath, 'description' ); - obj.description = description; + ld = load( fdescFilepath, 'description' ); + obj.description = ld.description; removefilesemaphore( fdescFileSema ); obj.descriptionBuilt = true; else @@ -89,7 +89,7 @@ function save( obj, wavFilepath, ~ ) save@Core.IdProcInterface( obj, wavFilepath, out ); fdescFilepath = [obj.getCurrentFolder() filesep 'fdesc.mat']; if obj.descriptionBuilt && ~exist( fdescFilepath, 'file' ) - description = obj.description; + description = obj.description; %#ok fdescFileSema = setfilesemaphore( fdescFilepath, 'semaphoreOldTime', 30 ); save( fdescFilepath, 'description' ); removefilesemaphore( fdescFileSema ); @@ -221,7 +221,31 @@ function save( obj, wavFilepath, ~ ) for ii = 1 : size( grps, 2 ) grps{1,ii} = cat( 2, grps{:,ii} ); end - grps(2,:) = []; + grps(2:end,:) = []; + grps = FeatureCreators.Base.removeGrpDuplicates( grps ); + b{2} = grps; + end + %% ------------------------------------------------------------------------------- + + function b = reshape2timeSeriesFeatVec( obj, bl ) + b{1} = reshape( bl{1}, size( bl{1}, 1), [] ); + if obj.descriptionBuilt, return; end + for ii = 2 : size( bl, 2 ) - 1 + bl{ii} = bl{ii+1}; + end + bl(end) = []; + for ii = 1 : size( bl, 2 ) - 1 + blszii = size( bl{1} ); + blszii(ii+1) = 1; + blszii(1) = []; + dgprs{ii} = repmat( shiftdim( bl{ii+1}, 2-ii ), blszii ); + dgprs{ii} = reshape( dgprs{ii}, 1, [] ); + end + grps = cat( 1, dgprs{:} ); + for ii = 1 : size( grps, 2 ) + grps{1,ii} = cat( 2, grps{:,ii} ); + end + grps(2:end,:) = []; grps = FeatureCreators.Base.removeGrpDuplicates( grps ); b{2} = grps; end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/Base_AnnotationWriter.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/Base_AnnotationWriter.m new file mode 100644 index 0000000..8307db1 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/Base_AnnotationWriter.m @@ -0,0 +1,96 @@ +classdef Base_AnnotationWriter < Core.IdProcInterface + % Base Abstract base class for specifying features sets with which features + % are extracted. + %% ----------------------------------------------------------------------------------- + properties (SetAccess = protected) + x; + blockAnnotations; + afeData; % current AFE signals used for vector construction + end + + %% ----------------------------------------------------------------------------------- + methods (Abstract) + afeRequests = getAFErequests( obj ) + outputDeps = getFeatureInternOutputDependencies( obj ) + x = constructVector( obj ) % has to return a cell, first item the feature vector, + % second item the features description. + end + + %% ----------------------------------------------------------------------------------- + methods + + function obj = Base_AnnotationWriter() + obj = obj@Core.IdProcInterface(); + end + %% ------------------------------------------------------------------------------- + + function setAfeData( obj, afeData ) + obj.afeData = afeData; + end + %% ------------------------------------------------------------------------------- + + function process( obj, wavFilepath ) + obj.inputProc.sceneId = obj.sceneId; + inData = obj.loadInputData( wavFilepath, 'blockAnnotations' ); + obj.blockAnnotations = inData.blockAnnotations; + selfData = obj.loadProcessedData( wavFilepath, 'x' ); + obj.x = selfData.x; + end + %% ------------------------------------------------------------------------------- + + %% ------------------------------------------------------------------------------- + + % override of Core.IdProcInterface's method + function [out, outFilepath] = loadProcessedData( obj, wavFilepath, varargin ) + [tmpOut, outFilepath] = loadProcessedData@Core.IdProcInterface( ... + obj, wavFilepath ); + obj.x = tmpOut.x; + if nargin < 3 || any( strcmpi( 'blockAnnotations', varargin ) ) + if isfield( tmpOut, 'blockAnnotations' ) % new version + obj.blockAnnotations = tmpOut.blockAnnotations; + else % old version; ba was saved in blockCreator cache + obj.inputProc.sceneId = obj.sceneId; + inData = obj.loadInputData( wavFilepath, 'blockAnnotations' ); + obj.blockAnnotations = inData.blockAnnotations; + obj.save( wavFilepath ); + end + end + out = obj.getOutput( varargin{:} ); + end + %% ------------------------------------------------------------------------------- + + % override of Core.IdProcInterface's method + function save( obj, wavFilepath, ~ ) + out.x = obj.x; + out.blockAnnotations = obj.blockAnnotations; + save@Core.IdProcInterface( obj, wavFilepath, out ); + end + %% ------------------------------------------------------------------------------- + + end + + %% ----------------------------------------------------------------------------------- + methods (Access = protected) + + function outputDeps = getInternOutputDependencies( obj ) + outputDeps.v = 4; + outputDeps.featureProc = obj.getFeatureInternOutputDependencies(); + end + %% ------------------------------------------------------------------------------- + + function out = getOutput( obj, varargin ) + if nargin < 2 || any( strcmpi( 'blockAnnotations', varargin ) ) + out.blockAnnotations = obj.blockAnnotations; + end + if nargin < 2 || any( strcmpi( 'x', varargin ) ) + out.x = obj.x; + end + end + + end + %% ----------------------------------------------------------------------------------- + +end + + + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet4Blockmean.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet4Blockmean.m index ad50948..f804026 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet4Blockmean.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet4Blockmean.m @@ -107,13 +107,13 @@ % afeIdx 2: spectralFeatures spfR = obj.makeBlockFromAfe( 2, 1, ... @(a)(compressAndScale( a.Data, 0.33 )), ... - {@(a)(a.Name),'24-ch', ... + {@(a)(a.Name),'32-ch', ... @(a)(a.Channel)}, ... {'t'}, ... {@(a)(a.fList)} ); spfL = obj.makeBlockFromAfe( 2, 2, ... @(a)(compressAndScale( a.Data, 0.33 )), ... - {@(a)(a.Name),'24-ch',... + {@(a)(a.Name),'32-ch',... @(a)(a.Channel)}, ... {'t'}, ... {@(a)(a.fList)} ); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet5Blockmean.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet5Blockmean.m index 0725d90..5c7993b 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet5Blockmean.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet5Blockmean.m @@ -70,13 +70,13 @@ % afeIdx 2: spectralFeatures spfR = obj.makeBlockFromAfe( 2, 1, ... @(a)(compressAndScale( a.Data, 0.33 )), ... - {@(a)(a.Name),'24-ch', ... + {@(a)(a.Name),'32-ch', ... @(a)(a.Channel)}, ... {'t'}, ... {@(a)(a.fList)} ); spfL = obj.makeBlockFromAfe( 2, 2, ... @(a)(compressAndScale( a.Data, 0.33 )), ... - {@(a)(a.Name),'24-ch',... + {@(a)(a.Name),'32-ch',... @(a)(a.Channel)}, ... {'t'}, ... {@(a)(a.fList)} ); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet5Blockmean_AnnotationWriter.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet5Blockmean_AnnotationWriter.m new file mode 100644 index 0000000..fa509d6 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet5Blockmean_AnnotationWriter.m @@ -0,0 +1,69 @@ +classdef FeatureSet5Blockmean_AnnotationWriter < FeatureCreators.Base_AnnotationWriter +% FeatureSet5Blockmean Specifies a feature set consisting of: +% see FeatureSet5Blockmean.getAFErequests() + + %% -------------------------------------------------------------------- + properties (SetAccess = private) + deltasLevels; + end + + %% -------------------------------------------------------------------- + methods (Static) + end + + %% -------------------------------------------------------------------- + methods (Access = public) + + function obj = FeatureSet5Blockmean_AnnotationWriter( ) + obj = obj@FeatureCreators.Base_AnnotationWriter(); + obj.deltasLevels = 2; + end + %% ---------------------------------------------------------------- + + function afeRequests = getAFErequests( obj ) + commonParams = FeatureCreators.LCDFeatureSet.getCommonAFEParams(); + afeRequests{1}.name = 'amsFeatures'; + afeRequests{1}.params = genParStruct( ... + commonParams{:}, ... + 'fb_nChannels', 16, ... + 'ams_fbType', 'log', ... + 'ams_nFilters', 8, ... + 'ams_lowFreqHz', 2, ... + 'ams_highFreqHz', 256', ... + 'ams_wSizeSec', 128e-3, ... + 'ams_hSizeSec', 32e-3 ... + ); + afeRequests{2}.name = 'spectralFeatures'; + afeRequests{2}.params = genParStruct( ... + commonParams{:}, ... + 'fb_nChannels', 32 ... + ); + afeRequests{3}.name = 'ratemap'; + afeRequests{3}.params = genParStruct( ... + commonParams{:}, ... + 'fb_nChannels', 32 ... + ); + end + %% ---------------------------------------------------------------- + + function x = constructVector( obj ) + % noop + error( 'Not to get called' ); + end + %% ---------------------------------------------------------------- + + function outputDeps = getFeatureInternOutputDependencies( obj ) + outputDeps.deltasLevels = obj.deltasLevels; + outputDeps.featureProc = 'FeatureSet5Blockmean'; + outputDeps.v = 1; + end + %% ---------------------------------------------------------------- + + end + + %% -------------------------------------------------------------------- + methods (Access = protected) + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet5aBlockmean.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet5aBlockmean.m new file mode 100644 index 0000000..f2118e9 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet5aBlockmean.m @@ -0,0 +1,136 @@ +classdef FeatureSet5aBlockmean < FeatureCreators.Base + + %% -------------------------------------------------------------------- + properties (SetAccess = private) + deltasLevels; + compressor = 10; + end + + %% -------------------------------------------------------------------- + methods (Static) + end + + %% -------------------------------------------------------------------- + methods (Access = public) + + function obj = FeatureSet5aBlockmean( ) + obj = obj@FeatureCreators.Base(); + obj.deltasLevels = 2; + end + %% ---------------------------------------------------------------- + + function afeRequests = getAFErequests( obj ) + commonParams = FeatureCreators.LCDFeatureSet.getCommonAFEParams(); + afeRequests{1}.name = 'amsFeatures'; + afeRequests{1}.params = genParStruct( ... + commonParams{:}, ... + 'fb_nChannels', 16, ... + 'ams_fbType', 'log', ... + 'ams_nFilters', 8, ... + 'ams_lowFreqHz', 2, ... + 'ams_highFreqHz', 256', ... + 'ams_wSizeSec', 128e-3, ... + 'ams_hSizeSec', 32e-3 ... + ); + afeRequests{2}.name = 'ratemap'; + afeRequests{2}.params = genParStruct( ... + commonParams{:}, ... + 'fb_nChannels', 32 ... + ); + end + %% ---------------------------------------------------------------- + + function x = constructVector( obj ) + % constructVector for each feature: compress, scale, average + % over left and right channels, construct individual feature names + % returned flattened feature vector for entire block + % The AFE data is indexed according to the order in which the requests + % where made + % + % See getAFErequests + + % afeIdx 2: rm + rmR = obj.makeBlockFromAfe( 2, 1, ... + @(a)(compressAndScale( a.Data, 1/obj.compressor, @(x)(median( x(x>0.01) )), 0 )), ... + {@(a)(a.Name), @(a)([num2str(numel(a.cfHz)) '-ch']), @(a)(a.Channel)}, ... + {'t'}, ... + {@(a)(strcat('f', arrayfun(@(f)(num2str(f)), a.cfHz, 'UniformOutput', false)))} ); + rmL = obj.makeBlockFromAfe( 2, 2, ... + @(a)(compressAndScale( a.Data, 1/obj.compressor, @(x)(median( x(x>0.01) )), 0 )), ... + {@(a)(a.Name), @(a)([num2str(numel(a.cfHz)) '-ch']), @(a)(a.Channel)}, ... + {'t'}, ... + {@(a)(strcat('f', arrayfun(@(f)(num2str(f)), a.cfHz, 'UniformOutput', false)))} ); + rm = obj.combineBlocks( @(b1,b2)(0.5*b1+0.5*b2), 'LRmean', rmR, rmL ); + x = obj.block2feat( rm, ... + @(b)(lMomentAlongDim( b, [1,2,3], 1, true )), ... + 2, @(idxs)(sort([idxs idxs idxs])),... + {{'1.LMom',@(idxs)(idxs(1:3:end))},... + {'2.LMom',@(idxs)(idxs(2:3:end))},... + {'3.LMom',@(idxs)(idxs(3:3:end))}} ); + for ii = 1:obj.deltasLevels + rm = obj.transformBlock( rm, 1, ... + @(b)(b(2:end,:) - b(1:end-1,:)), ... + @(idxs)(idxs(1:end-1)),... + {[num2str(ii) '.delta']} ); + xtmp = obj.block2feat( rm, ... + @(b)(lMomentAlongDim( b, [1,2], 1, true )), ... + 2, @(idxs)(sort([idxs idxs])),... + {{'1.LMom',@(idxs)(idxs(1:2:end))},... + {'2.LMom',@(idxs)(idxs(2:2:end))}} ); + x = obj.concatFeats( x, xtmp ); + end + % afeIdx 1: amsFeatures and generate corresponding feature names + modR = obj.makeBlockFromAfe( 1, 1, ... + @(a)(compressAndScale( a.Data, 1/obj.compressor )), ... + {@(a)(a.Name), @(a)([num2str(numel(a.cfHz)) '-ch']), @(a)(a.Channel)}, ... + {'t'}, ..., + {@(a)(strcat('f', arrayfun(@(f)(num2str(f)), a.cfHz,'UniformOutput', false)))}, ... + {@(a)(strcat('mf', arrayfun(@(f)(num2str(f)), a.modCfHz,'UniformOutput', false)))} ); + modL = obj.makeBlockFromAfe( 1, 2, ... + @(a)(compressAndScale( a.Data, 1/obj.compressor )), ... + {@(a)(a.Name), @(a)([num2str(numel(a.cfHz)) '-ch']), @(a)(a.Channel)}, ... % groups + {'t'}, ... % varargin: time index + {@(a)(strcat('f', arrayfun(@(f)(num2str(f)), a.cfHz,'UniformOutput', false)))}, ... % varargin: freq. bins + {@(a)(strcat('mf', arrayfun(@(f)(num2str(f)), a.modCfHz,'UniformOutput', false)))} ); % vararing: modulation frequencies + % average between right and left channels + mod = obj.combineBlocks( @(b1,b2)(0.5*b1+0.5*b2), 'LRmean', modR, modL ); + mod = obj.reshapeBlock( mod, 1 ); % flatten + % append l-moments + x = obj.concatFeats( x, obj.block2feat( mod, ... + @(b)(lMomentAlongDim( b, [1,2], 1, true )), ... + 2, @(idxs)(sort([idxs idxs])),... + {{'1.LMom', @(idxs)(idxs(1:2:end))},... + {'2.LMom', @(idxs)(idxs(2:2:end))}} ) ); + % append first derivative + for ii = 1:obj.deltasLevels + mod = obj.transformBlock( mod, 1, ... + @(b)(b(2:end,:) - b(1:end-1,:)), ... + @(idxs)(idxs(1:end-1)),... + {[num2str(ii) '.delta']} ); + x = obj.concatFeats( x, obj.block2feat( mod, ... + @(b)(lMomentAlongDim( b, [1,2], 1, true )), ... + 2, @(idxs)(sort([idxs idxs])),... + {{'1.LMom', @(idxs)(idxs(1:2:end))},... + {'2.LMom', @(idxs)(idxs(2:2:end))}} ) ); + end + end + %% ---------------------------------------------------------------- + + function outputDeps = getFeatureInternOutputDependencies( obj ) + outputDeps.deltasLevels = obj.deltasLevels; + classInfo = metaclass( obj ); + [classname1, classname2] = strtok( classInfo.Name, '.' ); + if isempty( classname2 ), outputDeps.featureProc = classname1; + else outputDeps.featureProc = classname2(2:end); end + outputDeps.v = 1; + end + %% ---------------------------------------------------------------- + + end + + %% -------------------------------------------------------------------- + methods (Access = protected) + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet5aRawTimeSeries.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet5aRawTimeSeries.m new file mode 100644 index 0000000..10ea6bd --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/FeatureSet5aRawTimeSeries.m @@ -0,0 +1,102 @@ +classdef FeatureSet5aRawTimeSeries < FeatureCreators.TimeSeriesFeatureCreator + + %% -------------------------------------------------------------------- + properties (SetAccess = private) + compressor = 10; + end + + %% -------------------------------------------------------------------- + methods (Access = public) + + function obj = FeatureSet5aRawTimeSeries( ) + afeRequests = FeatureCreators.LCDFeatureSet.getCommonAFEParams(); + afeRequests = genParStruct( afeRequests{:} ); + targetFsHz = 1 / afeRequests.rm_hSizeSec; + obj = obj@FeatureCreators.TimeSeriesFeatureCreator( targetFsHz ); + end + %% ---------------------------------------------------------------- + + function afeRequests = getAFErequests( obj ) + commonParams = FeatureCreators.LCDFeatureSet.getCommonAFEParams(); + afeRequests{1}.name = 'amsFeatures'; + afeRequests{1}.params = genParStruct( ... + commonParams{:}, ... + 'fb_nChannels', 16, ... + 'ams_fbType', 'log', ... + 'ams_nFilters', 8, ... + 'ams_lowFreqHz', 2, ... + 'ams_highFreqHz', 256', ... + 'ams_wSizeSec', 128e-3, ... + 'ams_hSizeSec', 32e-3 ... + ); + afeRequests{2}.name = 'ratemap'; + afeRequests{2}.params = genParStruct( ... + commonParams{:}, ... + 'fb_nChannels', 32 ... + ); + end + %% ---------------------------------------------------------------- + + function x = constructTSvector( obj ) + rmAfeData = obj.afeData(2); + modAfeData = obj.afeData(1); + modFsHz = modAfeData{1}.FsHz; + % afeIdx 1: ams + modR = obj.makeBlockFromAfe( 1, 1, ... + @(a)(compressAndScale( ... + FeatureCreators.TimeSeriesFeatureCreator.resampleDataBlock(a.Data,modFsHz,obj.targetFsHz,size(rmAfeData{1}.Data,1)), ... + 1/obj.compressor )), ... + {@(a)(a.Name), @(a)([num2str(numel(a.cfHz)) '-ch']), @(a)(a.Channel)}, ... + {@(a)('t1')}, ... + {@(a)(strcat('f', arrayfun(@(f)(num2str(f)), a.cfHz,'UniformOutput', false)))}, ... + {@(a)(strcat('mf', arrayfun(@(f)(num2str(f)), a.modCfHz,'UniformOutput', false)))} ); + fprintf( '.' ); + modL = obj.makeBlockFromAfe( 1, 2, ... + @(a)(compressAndScale( ... + FeatureCreators.TimeSeriesFeatureCreator.resampleDataBlock(a.Data,modFsHz,obj.targetFsHz,size(rmAfeData{1}.Data,1)), ... + 1/obj.compressor )), ... + {@(a)(a.Name), @(a)([num2str(numel(a.cfHz)) '-ch']), @(a)(a.Channel)}, ... % groups + {@(a)('t1')}, ... + {@(a)(strcat('f', arrayfun(@(f)(num2str(f)), a.cfHz,'UniformOutput', false)))}, ... % varargin: freq. bins + {@(a)(strcat('mf', arrayfun(@(f)(num2str(f)), a.modCfHz,'UniformOutput', false)))} ); % vararing: modulation frequencies + fprintf( '.' ); + % afeIdx 2: rm + rmR = obj.makeBlockFromAfe( 2, 1, ... + @(a)(compressAndScale( a.Data, 1/obj.compressor, @(x)(median( x(x>0.01) )), 0 )), ... + {@(a)(a.Name), @(a)([num2str(numel(a.cfHz)) '-ch']), @(a)(a.Channel)}, ... + {@(a)(strcat('t', arrayfun(@(t)(num2str(t)),1:size(a.Data,1),'UniformOutput',false)))}, ... + {@(a)(strcat('f', arrayfun(@(f)(num2str(f)), a.cfHz, 'UniformOutput', false)))} ); + fprintf( '.' ); + rmL = obj.makeBlockFromAfe( 2, 2, ... + @(a)(compressAndScale( a.Data, 1/obj.compressor, @(x)(median( x(x>0.01) )), 0 )), ... + {@(a)(a.Name), @(a)([num2str(numel(a.cfHz)) '-ch']), @(a)(a.Channel)}, ... + {@(a)(strcat('t', arrayfun(@(t)(num2str(t)),1:size(a.Data,1),'UniformOutput',false)))}, ... + {@(a)(strcat('f', arrayfun(@(f)(num2str(f)), a.cfHz, 'UniformOutput', false)))} ); + fprintf( '.' ); + % average between right and left channels + rm = obj.combineBlocks( @(b1,b2)(0.5*b1+0.5*b2), 'LRmean', rmR, rmL ); + mod = obj.combineBlocks( @(b1,b2)(0.5*b1+0.5*b2), 'LRmean', modR, modL ); + + x = obj.concatFeats( obj.reshape2timeSeriesFeatVec( rm ), obj.reshape2timeSeriesFeatVec( mod ) ); + fprintf( ':' ); + end + %% ---------------------------------------------------------------- + + function outputDeps = getTSfeatureInternOutputDependencies( obj ) + outputDeps.compressor = obj.compressor; + classInfo = metaclass( obj ); + [classname1, classname2] = strtok( classInfo.Name, '.' ); + if isempty( classname2 ), outputDeps.featureProc = classname1; + else outputDeps.featureProc = classname2(2:end); end + outputDeps.v = 1; + end + %% ---------------------------------------------------------------- + + end + + %% -------------------------------------------------------------------- + methods (Static) + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/TimeSeriesFeatureCreator.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/TimeSeriesFeatureCreator.m new file mode 100644 index 0000000..73f4242 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+FeatureCreators/TimeSeriesFeatureCreator.m @@ -0,0 +1,98 @@ +classdef TimeSeriesFeatureCreator < FeatureCreators.Base + + %% -------------------------------------------------------------------- + properties (SetAccess = private) + targetFsHz; + end + + %% ----------------------------------------------------------------------------------- + methods (Abstract) + outputDeps = getTSfeatureInternOutputDependencies( obj ) + x = constructTSvector( obj ) % has to return a cell, first item the feature vector, + % second item the features description. + end + + %% -------------------------------------------------------------------- + methods (Access = public) + + function obj = TimeSeriesFeatureCreator( targetFsHz ) + obj = obj@FeatureCreators.Base(); + obj.targetFsHz = targetFsHz; + end + %% ---------------------------------------------------------------- + + function x = constructVector( obj ) + x = obj.constructTSvector(); + T = size( x{1}, 1 ); + bas = obj.blockAnnotations; + aFields = fieldnames( bas ); + isSequenceAnnotation = cellfun( @(af)(... + isstruct( bas.(af) ) && isfield( bas.(af), 't' ) ... + ), aFields ); + sequenceAfields = aFields(isSequenceAnnotation); + for jj = 1 : numel( sequenceAfields ) + fprintf( '.' ); + seqAname = sequenceAfields{jj}; + annot = bas.(seqAname); + if ~isstruct( annot.t ) % time series + if length( annot.t ) == size( annot.(seqAname), 1 ) + if iscell( annot.(seqAname) ) + asc_sz = size( annot.(seqAname) ); + asc_num = cell2mat( annot.(seqAname) ); + asc_num = interp1( annot.t, asc_num, (1:T)'/obj.targetFsHz, 'pchip' ); + annot.(seqAname) = mat2cell( asc_num, ones( 1, size( asc_num, 1 ) ), repmat( size( asc_num, 2 )/asc_sz(2), 1, asc_sz(2) ) ); + else + annot.(seqAname) = interp1( annot.t, annot.(seqAname), (1:T)/obj.targetFsHz, 'pchip' ); + end + annot.t = 1:T; + else + error( 'unexpected annotations sequence structure' ); + end + elseif all( isfield( annot.t, {'onset','offset'} ) ) % event series + if isequal( size( annot.t.onset ), size( annot.t.offset ) ) && ... + length( annot.t.onset ) == size( annot.(seqAname), 1 ) + annot.t.onset = round( annot.t.onset * obj.targetFsHz ); + annot.t.onset = min( [annot.t.onset;repmat( T, size( annot.t.onset ) )], [], 1 ); + annot.t.offset = round( annot.t.offset * obj.targetFsHz ); + annot.t.offset = min( [annot.t.offset;repmat( T, size( annot.t.offset ) )], [], 1 ); + else + error( 'unexpected annotations sequence structure' ); + end + else + error( 'unexpected annotations sequence structure' ); + end + bas.(seqAname) = annot; + end + obj.blockAnnotations = bas; + fprintf( ';' ); + end + %% ---------------------------------------------------------------- + + function outputDeps = getFeatureInternOutputDependencies( obj ) + outputDeps.tsFeatureProc = obj.getTSfeatureInternOutputDependencies(); + outputDeps.targetFsHz = obj.targetFsHz; + outputDeps.v = 1; + end + %% ---------------------------------------------------------------- + + end + + %% -------------------------------------------------------------------- + methods (Static) + + function dataBlockResampled = resampleDataBlock( dataBlock, srcFsHz, targetFsHz, targetNt ) + [nT, ~] = size(dataBlock); + srcTs = 0 : 1 / srcFsHz : (nT-1) / srcFsHz; + targetTs = 0 : 1 / targetFsHz : srcTs(end); + nTargetTsMissing = targetNt - numel( targetTs ); + % pchip interpolation... + dataBlockResampled = interp1( srcTs, dataBlock, targetTs, 'pchip' ); + % ...with 'last-datapoint' extrapolation. + dataBlockResampled(end+1:end+nTargetTsMissing,:) = ... + repmat( dataBlockResampled(end,:), nTargetTsMissing, 1 ); + end + + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ImportanceWeighters/BAC_NS_NPP_Weighter.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ImportanceWeighters/BAC_NS_NPP_Weighter.m new file mode 100644 index 0000000..b4d962f --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ImportanceWeighters/BAC_NS_NPP_Weighter.m @@ -0,0 +1,55 @@ +classdef BAC_NS_NPP_Weighter < ImportanceWeighters.Base + + %% -------------------------------------------------------------------- + properties (SetAccess = protected) +% labelWeights; + end + + %% -------------------------------------------------------------------- + methods + + function obj = BAC_NS_NPP_Weighter( labelWeights ) + obj = obj@ImportanceWeighters.Base(); +% if nargin >= 1 +% obj.labelWeights = labelWeights; +% end + end + % ----------------------------------------------------------------- + + function [importanceWeights] = getImportanceWeights( obj, sampleIds ) + importanceWeights = ones( size( sampleIds ) ); + y = obj.data(:,'y'); + y = y(sampleIds,:); + assert( size( y, 2 ) == 1 ); + ba = obj.data(:,'blockAnnotations'); + ba = ba(sampleIds); + ba_ns = cat( 1, ba.nActivePointSrcs ); + ba_pp = cat( 1, ba.posPresent ); + clear ba; + y_ = y .* (ba_ns+1) .* (1 + ~ba_pp * 9); + y_unique = unique( y_ ); + for ii = 1 : numel( y_unique ) + y_unique_ii_lidxs = y_ == y_unique(ii); + lw = numel( sampleIds ) / sum( y_unique_ii_lidxs ); + if y_unique(ii) > 0 + lw = lw * 2; % because their is p vs (npp+nnp) + end + importanceWeights(y_unique_ii_lidxs) = lw; + end + importanceWeights = importanceWeights / min( importanceWeights ); + obj.verboseOutput = '\nWeighting samples of \n'; + for ii = 1 : numel( y_unique ) + trueLabel = unique( y(y_==y_unique(ii)) ); + labelWeight = unique( importanceWeights(y_==y_unique(ii)) ); + obj.verboseOutput = sprintf( ['%s' ... + ' class %d (%d) with %f\n'], ... + obj.verboseOutput, ... + y_unique(ii), trueLabel, labelWeight ); + end + end + % ----------------------------------------------------------------- + + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ImportanceWeighters/BAC_Weighter.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ImportanceWeighters/BAC_Weighter.m new file mode 100644 index 0000000..5fde490 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ImportanceWeighters/BAC_Weighter.m @@ -0,0 +1,52 @@ +classdef BAC_Weighter < ImportanceWeighters.Base + + %% -------------------------------------------------------------------- + properties (SetAccess = protected) + labelWeights; + end + + %% -------------------------------------------------------------------- + methods + + function obj = BAC_Weighter( labelWeights ) + obj = obj@ImportanceWeighters.Base(); + if nargin >= 1 + obj.labelWeights = labelWeights; + end + end + % ----------------------------------------------------------------- + + function [importanceWeights] = getImportanceWeights( obj, sampleIds ) + importanceWeights = ones( size( sampleIds ) ); + y = obj.data(:,'y'); + y = y(sampleIds,:); + for cc = 1 : size( y, 2 ) + labels = unique( y(:,cc) ); + lw = obj.labelWeights; + if isempty( lw ) + lw = ones( size( labels ) ); + elseif numel( lw ) ~= numel( labels ) + error( 'AMLTTP:usage', 'number of label weights must equal number of unique labels' ); + end + for ii = 1 : numel( labels ) + y_label_ii = y(:,cc) == labels(ii); + labelWeight(ii) = lw(ii) * numel( sampleIds ) / sum( y_label_ii ); %#ok + importanceWeights(y_label_ii,cc) = labelWeight(ii); + end + end + importanceWeights = mean( importanceWeights, 2 ); + importanceWeights = importanceWeights / min( importanceWeights ); + obj.verboseOutput = '\nWeighting samples of \n'; + for ii = 1 : numel( labels ) + obj.verboseOutput = sprintf( ['%s' ... + ' class %d with %f\n'], ... + obj.verboseOutput, ... + labels(ii), labelWeight(ii) ); + end + end + % ----------------------------------------------------------------- + + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ImportanceWeighters/Base.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ImportanceWeighters/Base.m new file mode 100644 index 0000000..ceebfe3 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ImportanceWeighters/Base.m @@ -0,0 +1,25 @@ +classdef (Abstract) Base < handle + + %% -------------------------------------------------------------------- + properties (SetAccess = protected) + data; + verboseOutput = ''; + end + + %% -------------------------------------------------------------------- + methods + + function obj = connectData( obj, data ) + obj.data = data; + end + % ----------------------------------------------------------------- + + end + + %% -------------------------------------------------------------------- + methods (Abstract) + [importanceWeights] = getImportanceWeights( obj, sampleIds ) + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ImportanceWeighters/IgnorantWeighter.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ImportanceWeighters/IgnorantWeighter.m new file mode 100644 index 0000000..02362c8 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ImportanceWeighters/IgnorantWeighter.m @@ -0,0 +1,23 @@ +classdef IgnorantWeighter < ImportanceWeighters.Base + + %% -------------------------------------------------------------------- + properties (SetAccess = protected) + end + + %% -------------------------------------------------------------------- + methods + + function obj = IgnorantWeighter() + obj = obj@ImportanceWeighters.Base(); + end + % ----------------------------------------------------------------- + + function [importanceWeights] = getImportanceWeights( ~, sampleIds ) + importanceWeights = ones( size( sampleIds ) ); + end + % ----------------------------------------------------------------- + + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/AzmDistributionLabeler.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/AzmDistributionLabeler.m index 1f5b3b8..3da154c 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/AzmDistributionLabeler.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/AzmDistributionLabeler.m @@ -32,12 +32,14 @@ end %% ------------------------------------------------------------------------------- - function y = labelEnergeticBlock( obj, blockAnnotations ) + function [y, ysi] = labelEnergeticBlock( obj, blockAnnotations ) srcAzms = blockAnnotations.srcAzms(obj.sourceIds,:); srcAzmIdxs = LabelCreators.AzmDistributionLabeler.azimToIndex( ... srcAzms, obj.angularResolution, obj.nAngles ); y = zeros( 1, obj.nAngles ); y(srcAzmIdxs) = 1; + [~,srcAzmIdxsOrder] = sort( srcAzmIdxs ); + ysi = {srcAzmIdxsOrder}; % TODO what if sourceIds ~= ':'? end end @@ -47,7 +49,7 @@ %% ----------------------------------------------------------------------------------- function outputDeps = getLabelInternOutputDependencies( obj ) outputDeps.angularResolution = obj.angularResolution; - outputDeps.v = 1; + outputDeps.v = 2; end end %% ----------------------------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/AzmLabeler.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/AzmLabeler.m index 22aabe0..5befb8b 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/AzmLabeler.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/AzmLabeler.m @@ -24,8 +24,9 @@ end %% ------------------------------------------------------------------------------- - function y = labelEnergeticBlock( obj, blockAnnotations ) + function [y,ysi] = labelEnergeticBlock( obj, blockAnnotations ) y = blockAnnotations.srcAzms(obj.sourceIds); + ysi = {obj.sourceIds}; end %% ------------------------------------------------------------------------------- @@ -35,7 +36,7 @@ methods (Access = protected) function outputDeps = getLabelInternOutputDependencies( obj ) - outputDeps.v = 1; + outputDeps.v = 2; end %% ------------------------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/Base.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/Base.m index f0ac03f..19c677c 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/Base.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/Base.m @@ -3,6 +3,7 @@ %% ----------------------------------------------------------------------------------- properties (SetAccess = protected) y; + ysi; x; blockAnnotations; labelBlockSize_s; @@ -13,7 +14,7 @@ %% ----------------------------------------------------------------------------------- methods (Abstract, Access = protected) outputDeps = getLabelInternOutputDependencies( obj ) - y = label( obj, annotations ) + [y, ysi] = label( obj, annotations ) end %% ----------------------------------------------------------------------------------- @@ -23,10 +24,13 @@ obj = obj@Core.IdProcInterface(); ip = inputParser; ip.addOptional( 'labelBlockSize_s', [] ); - ip.addOptional( 'removeUnclearBlocks', true ); + ip.addOptional( 'removeUnclearBlocks', 'block-wise' ); % 'false','block-wise','time-wise' ip.parse( varargin{:} ); obj.labelBlockSize_s = ip.Results.labelBlockSize_s; obj.removeUnclearBlocks = ip.Results.removeUnclearBlocks; + if ~any( strcmpi( obj.removeUnclearBlocks, {'false','block-wise','time-wise'} ) ) + error( 'AMLTTP:usage:unsupportedOptionSetting', 'use one of ''false'',''block-wise'',''time-wise''.' ); + end if isempty( obj.labelBlockSize_s ) obj.labelBlockSize_auto = true; else @@ -39,12 +43,13 @@ function process( obj, wavFilepath ) obj.inputProc.sceneId = obj.sceneId; in = obj.loadInputData( wavFilepath, 'blockAnnotations' ); obj.y = []; + obj.ysi = {}; for blockAnnotation = in.blockAnnotations' if obj.labelBlockSize_auto obj.labelBlockSize_s = ... blockAnnotation.blockOffset - blockAnnotation.blockOnset; end - obj.y(end+1,:) = obj.label( blockAnnotation ); + [obj.y(end+1,:), obj.ysi{end+1}] = obj.label( blockAnnotation ); if obj.labelBlockSize_auto obj.labelBlockSize_s = []; end @@ -56,14 +61,15 @@ function process( obj, wavFilepath ) % override of DataProcs.IdProcInterface's method function [out, outFilepath] = loadProcessedData( obj, wavFilepath, varargin ) [tmpOut, outFilepath] = loadProcessedData@Core.IdProcInterface( ... - obj, wavFilepath, 'y' ); + obj, wavFilepath, 'y', 'ysi' ); obj.y = tmpOut.y; + obj.ysi = tmpOut.ysi; obj.inputProc.sceneId = obj.sceneId; - if nargin < 3 || (any( strcmpi( 'x', varargin ) ) && any( strcmpi( 'a', varargin ) )) + if nargin < 3 || (any( strcmpi( 'x', varargin ) ) && (any( strcmpi( 'a', varargin ) ) || strcmpi( obj.removeUnclearBlocks, 'time-wise' ))) inData = obj.loadInputData( wavFilepath, 'x', 'blockAnnotations' ); obj.x = inData.x; obj.blockAnnotations = inData.blockAnnotations; - elseif any( strcmpi( 'a', varargin ) ) + elseif any( strcmpi( 'a', varargin ) ) || strcmpi( obj.removeUnclearBlocks, 'time-wise' ) inData = obj.loadInputData( wavFilepath, 'blockAnnotations' ); obj.blockAnnotations = inData.blockAnnotations; elseif any( strcmpi( 'x', varargin ) ) @@ -76,7 +82,7 @@ function process( obj, wavFilepath ) % override of Core.IdProcInterface's method function out = saveOutput( obj, wavFilepath ) - out = obj.getOutput( 'y' ); + out = obj.getOutput( 'y', 'ysi', 'noRemoveNanBlocks' ); obj.save( wavFilepath, out ); end %% ------------------------------------------------------------------------------- @@ -84,6 +90,7 @@ function process( obj, wavFilepath ) % override of DataProcs.IdProcInterface's method function save( obj, wavFilepath, ~ ) out.y = obj.y; + out.ysi = obj.ysi; save@Core.IdProcInterface( obj, wavFilepath, out ); end %% ------------------------------------------------------------------------------- @@ -94,7 +101,7 @@ function save( obj, wavFilepath, ~ ) methods (Access = protected) function outputDeps = getInternOutputDependencies( obj ) - outputDeps.v = 1; + outputDeps.v = 2; outputDeps.labelBlockSize = obj.labelBlockSize_s; outputDeps.labelBlockSize_auto = obj.labelBlockSize_auto; outputDeps.labelProc = obj.getLabelInternOutputDependencies(); @@ -103,23 +110,32 @@ function save( obj, wavFilepath, ~ ) function out = getOutput( obj, varargin ) out.y = obj.y; - out.bIdxs = 1 : numel( out.y ); + out.bIdxs = 1 : size( out.y, 1 ); + removeNanBlocks = strcmpi( obj.removeUnclearBlocks, {'block-wise','time-wise'} ); + if ~any( removeNanBlocks ) || any( strcmpi( 'noRemoveNanBlocks', varargin ) ) + removeNanBlocks_lidx = []; + else + removeNanBlocks_lidx = any(isnan(out.y),2); + if removeNanBlocks(2) + [~,~,sameTimeIdxs] = unique( [obj.blockAnnotations.blockOffset] ); + nanTimeIdxs = sameTimeIdxs(removeNanBlocks_lidx); + removeNanBlocks_lidx = ismember( sameTimeIdxs, nanTimeIdxs ); + end + end if nargin < 2 || any( strcmpi( 'x', varargin ) ) out.x = obj.x; - if obj.removeUnclearBlocks - out.x(any(isnan(out.y),2),:) = []; - end + out.x(removeNanBlocks_lidx,:,:) = []; end if nargin < 2 || any( strcmpi( 'a', varargin ) ) out.a = obj.blockAnnotations; - if obj.removeUnclearBlocks - out.a(any(isnan(out.y),2)) = []; - end + out.a(removeNanBlocks_lidx) = []; end - if obj.removeUnclearBlocks - out.bIdxs(any(isnan(out.y),2)) = []; - out.y(any(isnan(out.y),2),:) = []; + if nargin < 2 || any( strcmpi( 'ysi', varargin ) ) + out.ysi = obj.ysi; + out.ysi(removeNanBlocks_lidx) = []; end + out.bIdxs(removeNanBlocks_lidx) = []; + out.y(removeNanBlocks_lidx,:,:) = []; end %% ------------------------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/EnergyDependentLabeler.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/EnergyDependentLabeler.m index f5d1b87..6b6aff0 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/EnergyDependentLabeler.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/EnergyDependentLabeler.m @@ -8,7 +8,7 @@ %% ----------------------------------------------------------------------------------- methods (Abstract) - y = labelEnergeticBlock( obj, blockAnnotations ) + [y, ysi] = labelEnergeticBlock( obj, blockAnnotations ) end %% ----------------------------------------------------------------------------------- @@ -35,17 +35,18 @@ outputDeps = getInternOutputDependencies@LabelCreators.Base( obj ); outputDeps.sourcesMinEnergy = obj.sourcesMinEnergy; outputDeps.sourceIds = obj.sourceIds; - outputDeps.v = 1; + outputDeps.v = 2; end %% ------------------------------------------------------------------------------- - function y = label( obj, blockAnnotations ) + function [y, ysi] = label( obj, blockAnnotations ) rejectBlock = LabelCreators.EnergyDependentLabeler.isEnergyTooLow( ... blockAnnotations, obj.sourceIds, obj.sourcesMinEnergy ); if rejectBlock y = NaN; + ysi = {}; else - y = obj.labelEnergeticBlock( blockAnnotations ); + [y, ysi] = obj.labelEnergeticBlock( blockAnnotations ); end end %% ------------------------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/IdAzmDistributionLabeler.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/IdAzmDistributionLabeler.m index 28cd518..0b26d6e 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/IdAzmDistributionLabeler.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/IdAzmDistributionLabeler.m @@ -39,7 +39,7 @@ %% ----------------------------------------------------------------------------------- methods (Access = protected) - function y = label( obj, blockAnnotations ) + function [y,ysi] = label( obj, blockAnnotations ) [activeTypes, ~, activeSrcIdxs] = getActiveTypes( obj, blockAnnotations ); if ~isempty(obj.nrgSrcsFilter) srcAzms = blockAnnotations.srcAzms(obj.nrgSrcsFilter, :); @@ -61,6 +61,7 @@ end end y = reshape(y, 1, numel(obj.types) * (obj.nAzimuthBins + 1)); + ysi = {}; end %% ----------------------------------------------------------------------------------- @@ -68,7 +69,7 @@ outputDeps = getLabelInternOutputDependencies@LabelCreators.MultiEventTypeLabeler(obj); outputDeps.nAzimuthBins = obj.nAzimuthBins; outputDeps.angularResolution = obj.angularResolution; - outputDeps.v = 2; + outputDeps.v = 3; end end %% ----------------------------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/MultiEventTypeLabeler.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/MultiEventTypeLabeler.m index dc9b77e..b2037c7 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/MultiEventTypeLabeler.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/MultiEventTypeLabeler.m @@ -7,12 +7,17 @@ types; negOut; srcPrioMethod; + segIdTargetSrcFilter; srcTypeFilterOut; nrgSrcsFilter; fileFilterOut; sourcesMinEnergy; end + %% ----------------------------------------------------------------------------------- + properties (Access = public) + end + %% ----------------------------------------------------------------------------------- methods (Abstract) end @@ -25,10 +30,11 @@ ip.addOptional( 'minBlockToEventRatio', 0.75 ); ip.addOptional( 'maxNegBlockToEventRatio', 0 ); ip.addOptional( 'labelBlockSize_s', [] ); - ip.addOptional( 'removeUnclearBlocks', true ); + ip.addOptional( 'removeUnclearBlocks', 'block-wise' ); ip.addOptional( 'types', {{'Type1'},{'Type2'}} ); ip.addOptional( 'negOut', 'rest' ); % rest, none ip.addOptional( 'srcPrioMethod', 'order' ); % energy, order, time + ip.addOptional( 'segIdTargetSrcFilter', [] ); % e.g. [1,1;3,2]: throw away time-aggregate blocks with type 1 on other than src 1 and type 2 on other than src 3 ip.addOptional( 'srcTypeFilterOut', [] ); % e.g. [2,1;3,2]: throw away type 1 blocks from src 2 and type 2 blocks from src 3 ip.addOptional( 'nrgSrcsFilter', [] ); % idxs of srcs to be account for block-filtering based on too low energy. If empty, do not use ip.addOptional( 'fileFilterOut', {} ); % blocks containing these files get filtered out @@ -42,10 +48,12 @@ obj.types = ip.Results.types; obj.negOut = ip.Results.negOut; obj.srcPrioMethod = ip.Results.srcPrioMethod; + obj.segIdTargetSrcFilter = ip.Results.segIdTargetSrcFilter; obj.srcTypeFilterOut = ip.Results.srcTypeFilterOut; obj.nrgSrcsFilter = ip.Results.nrgSrcsFilter; obj.sourcesMinEnergy = ip.Results.sourcesMinEnergy; obj.fileFilterOut = sort( ip.Results.fileFilterOut ); + obj.procName = [obj.procName '(' strcat( obj.types{1}{:} ) ')']; end %% ------------------------------------------------------------------------------- @@ -63,8 +71,9 @@ outputDeps.nrgSrcsFilter = obj.nrgSrcsFilter; outputDeps.sourcesMinEnergy = obj.sourcesMinEnergy; outputDeps.srcTypeFilterOut = sortrows( obj.srcTypeFilterOut ); + outputDeps.segIdTargetSrcFilter = sortrows( obj.segIdTargetSrcFilter ); outputDeps.fileFilterOut = obj.fileFilterOut; - outputDeps.v = 7; + outputDeps.v = 9; end %% ------------------------------------------------------------------------------- @@ -73,13 +82,14 @@ end %% ------------------------------------------------------------------------------- - function y = label( obj, blockAnnotations ) + function [y, ysi] = label( obj, blockAnnotations ) [activeTypes, relBlockEventOverlap, srcIdxs] = obj.getActiveTypes( blockAnnotations ); [maxPosRelOverlap,maxTimeTypeIdx] = max( relBlockEventOverlap ); + ysi = {}; if any( activeTypes ) switch obj.srcPrioMethod case 'energy' - eSrcs = cellfun( @mean, blockAnnotations.srcEnergy(:,:) ); % mean over channels + eSrcs = cellfun( @mean, blockAnnotations.globalSrcEnergy ); % mean over channels for ii = 1 : numel( activeTypes ) if activeTypes(ii) eTypes(ii) = 1/sum( 1./eSrcs([srcIdxs{ii}]) ); @@ -97,6 +107,7 @@ 'Use ''energy'' or ''order''.'], obj.srcPrioMethod ); end y = labelTypeIdx; + ysi = srcIdxs(y); elseif strcmp( obj.negOut, 'rest' ) && ... (maxPosRelOverlap <= obj.maxNegBlockToEventRatio) y = -1; @@ -104,6 +115,20 @@ y = NaN; return; end + if ~isempty( obj.segIdTargetSrcFilter ) + for ii = 1 : size( obj.segIdTargetSrcFilter, 1 ) + srcf = obj.segIdTargetSrcFilter(ii,1); + typef = obj.segIdTargetSrcFilter(ii,2); + srcfAzm = obj.lastConfig{obj.sceneId}.preceding.preceding.preceding.preceding.preceding.sceneCfg.sources(srcf).azimuth; + if isa( srcfAzm, 'SceneConfig.ValGen' ) + srcfAzm = srcfAzm.val; + end + if activeTypes(typef) && (any( abs( blockAnnotations.srcAzms(srcIdxs{typef}) - srcfAzm ) >= 0.1 ) || any( abs( blockAnnotations.globalNrjOffsets(srcIdxs{typef}) ) >= 0.1 )) + y = NaN; + return; + end + end + end for ii = 1 : size( obj.srcTypeFilterOut, 1 ) srcfo = obj.srcTypeFilterOut(ii,1); typefo = obj.srcTypeFilterOut(ii,2); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/MultiEventTypeTimeSeriesLabeler.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/MultiEventTypeTimeSeriesLabeler.m new file mode 100644 index 0000000..fef54fc --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/MultiEventTypeTimeSeriesLabeler.m @@ -0,0 +1,177 @@ +classdef MultiEventTypeTimeSeriesLabeler < LabelCreators.TimeSeriesLabelCreator + % class for multi-class labeling blocks by event + %% ----------------------------------------------------------------------------------- + properties (SetAccess = protected) + types; + negOut; + srcPrioMethod; + segIdTargetSrcFilter = []; + srcTypeFilterOut; + fileFilterOut = {}; + end + + %% ----------------------------------------------------------------------------------- + properties (Access = public) + end + + %% ----------------------------------------------------------------------------------- + methods (Abstract) + end + + %% ----------------------------------------------------------------------------------- + methods + + function obj = MultiEventTypeTimeSeriesLabeler( varargin ) + ip = inputParser; + ip.addOptional( 'removeUnclearBlocks', 'sequence-wise' ); + ip.addOptional( 'types', {{'Type1'},{'Type2'}} ); + ip.addOptional( 'negOut', 'rest' ); % rest, none + ip.addOptional( 'srcPrioMethod', 'order' ); % energy, order +% ip.addOptional( 'segIdTargetSrcFilter', [] ); % e.g. [1,1;3,2]: throw away time-aggregate blocks with type 1 on other than src 1 and type 2 on other than src 3 + ip.addOptional( 'srcTypeFilterOut', [] ); % e.g. [2,1;3,2]: throw away type 1 blocks from src 2 and type 2 blocks from src 3 +% ip.addOptional( 'fileFilterOut', {} ); % blocks containing these files get filtered out + ip.parse( varargin{:} ); + obj = obj@LabelCreators.TimeSeriesLabelCreator( 'removeUnclearBlocks', ... + ip.Results.removeUnclearBlocks ); + obj.types = ip.Results.types; + obj.negOut = ip.Results.negOut; + obj.srcPrioMethod = ip.Results.srcPrioMethod; +% obj.segIdTargetSrcFilter = ip.Results.segIdTargetSrcFilter; + obj.srcTypeFilterOut = ip.Results.srcTypeFilterOut; +% obj.fileFilterOut = sort( ip.Results.fileFilterOut ); + obj.procName = [obj.procName '(' strcat( obj.types{1}{:} ) ')']; + end + %% ------------------------------------------------------------------------------- + + end + + %% ----------------------------------------------------------------------------------- + methods (Access = protected) + + function outputDeps = getLabelInternOutputDependencies( obj ) + outputDeps.types = obj.types; + outputDeps.negOut = obj.negOut; + outputDeps.srcPrioMethod = obj.srcPrioMethod; + outputDeps.srcTypeFilterOut = sortrows( obj.srcTypeFilterOut ); + outputDeps.segIdTargetSrcFilter = sortrows( obj.segIdTargetSrcFilter ); + outputDeps.fileFilterOut = obj.fileFilterOut; + outputDeps.v = 1; + end + %% ------------------------------------------------------------------------------- + + function eit = eventIsType( obj, typeIdx, type ) + eit = any( strcmp( type, obj.types{typeIdx} ) ); + end + %% ------------------------------------------------------------------------------- + + function [y, ysi] = label( obj, blockAnnotations ) + [activeTypes, srcIdxs] = obj.getActiveTypes( blockAnnotations ); + y = zeros( size( activeTypes, 1 ), 1 ); + ysi = cell( size( activeTypes, 1 ), 1 ); + if any( activeTypes(:) ) + switch obj.srcPrioMethod + case 'energy' + error( 'AMLTTP:notImplemented', 'energy PrioMethod for time-series not implemented yet' ); + for ss = 1 : size( blockAnnotations.globalSrcEnergy.globalSrcEnergy, 2 ) + eSrcs(:,ss) = mean( cell2mat( ... + blockAnnotations.globalSrcEnergy.globalSrcEnergy(:,ss) ), 2 ); + end + eTypes = eSrcs .* activeTypes; + [~,labelTypeIdxs] = max( eTypes, [], 2 ); + y = labelTypeIdxs .* double( any( activeTypes, 2 ) ); + case 'order' + labelTypeIdxs = activeTypes .* repmat( 1:size( activeTypes, 2), size( activeTypes, 1), 1 ); + y = min( labelTypeIdxs, [], 2 ); + otherwise + error( 'AMLTTP:unknownOptionValue', ['%s: unknown option value.'... + 'Use ''energy'' or ''order''.'], obj.srcPrioMethod ); + end + ysi_ = y; + ysi_(ysi_==0) = 1; + ysi = srcIdxs(sub2ind( size( srcIdxs ), 1:size( srcIdxs, 1 ), ysi_' )); + end + if strcmp( obj.negOut, 'rest' ) + y(y==0) = -1; + else + y(y==0) = NaN; + return; + end +% if ~isempty( obj.segIdTargetSrcFilter ) +% for ii = 1 : size( obj.segIdTargetSrcFilter, 1 ) +% srcf = obj.segIdTargetSrcFilter(ii,1); +% typef = obj.segIdTargetSrcFilter(ii,2); +% srcfAzm = obj.lastConfig{obj.sceneId}.preceding.preceding.preceding.preceding.preceding.sceneCfg.sources(srcf).azimuth; +% if isa( srcfAzm, 'SceneConfig.ValGen' ) +% srcfAzm = srcfAzm.val; +% end +% if activeTypes(typef) && any( abs( blockAnnotations.srcAzms(srcIdxs{typef}) - srcfAzm ) >= 0.1 ) +% y = NaN; +% return; +% end +% end +% end + for ii = 1 : size( obj.srcTypeFilterOut, 1 ) + srcfo = obj.srcTypeFilterOut(ii,1); + typefo = obj.srcTypeFilterOut(ii,2); + fo_lidxs = activeTypes(:,typefo) ... + & cellfun( @(si)(any( si == srcfo )), srcIdxs(:,typefo) ); + y(fo_lidxs) = NaN; + end +% for ii = 1 : numel( obj.fileFilterOut ) +% if any( strcmpi( obj.fileFilterOut{ii}, blockAnnotations.srcFile.srcFile(:,1) ) ) +% y = NaN; +% return; +% end +% end + end + %% ------------------------------------------------------------------------------- + + function [activeTypes, srcIdxs] = getActiveTypes( obj, blockAnnotations ) + ts = blockAnnotations.globalSrcEnergy.t; + activeTypes = zeros( numel( ts ), numel( obj.types ) ); + srcIdxs = cell( numel( ts ), numel( obj.types ) ); + eventOnsets = blockAnnotations.srcType.t.onset; + eventOffsets = blockAnnotations.srcType.t.offset; + for tt = 1 : numel( obj.types ) + eventsAreType = cellfun( @(ba)(obj.eventIsType( tt, ba )), ... + blockAnnotations.srcType.srcType(:,1) ); + srcIdxs_tt = [blockAnnotations.srcType.srcType{eventsAreType,2}]; + eventOnOffs_tt = [eventOnsets(eventsAreType)',eventOffsets(eventsAreType)']; + eventOnOffs_tt = eventOnOffs_tt - ts(1) + 1; + if ~isempty( eventOnOffs_tt ) + for ii = 1:2 + eventOnOffs_tt(:,ii) = max( ... + [zeros( size(eventOnOffs_tt, 1), 1 ), ... + eventOnOffs_tt(:,ii)], [], 2 ); + eventOnOffs_tt(:,ii) = min( ... + [repmat( numel( ts ), size(eventOnOffs_tt, 1), 1 ), ... + eventOnOffs_tt(:,ii)], [], 2 ); + end + else + eventOnOffs_tt = []; + end + for jj = 1 : size( eventOnOffs_tt, 1 ) + event_jj_idxs = eventOnOffs_tt(jj,1) : eventOnOffs_tt(jj,2); + activeTypes(event_jj_idxs,tt) = 1; + srcIdxs(event_jj_idxs,tt) = cellfun( @(a,b)([a,b]), ... + srcIdxs(event_jj_idxs,tt), ... + repmat( {srcIdxs_tt(jj)}, numel( event_jj_idxs ), 1 ), ... + 'UniformOutput', false ); + end + end + end + %% ------------------------------------------------------------------------------- + + end + %% ----------------------------------------------------------------------------------- + + methods (Static) + + %% ------------------------------------------------------------------------------- + + end + +end + + + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/MultiLabeler.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/MultiLabeler.m index 1b51402..f58d9ba 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/MultiLabeler.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/MultiLabeler.m @@ -23,11 +23,14 @@ %% ----------------------------------------------------------------------------------- methods (Access = protected) - function y = label( obj, blockAnnotations ) + function [y,ysi] = label( obj, blockAnnotations ) y = []; + ysi = {}; for ii = 1 : numel( obj.individualLabelers ) obj.individualLabelers{ii}.labelBlockSize_s = obj.labelBlockSize_s; - y = [y, obj.individualLabelers{ii}.label( blockAnnotations )]; + [yii,ysiii] = obj.individualLabelers{ii}.label( blockAnnotations ); + y = [y, yii]; %#ok + ysi = [ysi, ysiii]; %#ok end end %% ------------------------------------------------------------------------------- @@ -38,7 +41,7 @@ outputDeps.(outDepName) = ... obj.individualLabelers{ii}.getLabelInternOutputDependencies; end - outputDeps.v = 1; + outputDeps.v = 2; end %% ------------------------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/NumberOfSourcesLabeler.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/NumberOfSourcesLabeler.m index f988e33..4f97118 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/NumberOfSourcesLabeler.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/NumberOfSourcesLabeler.m @@ -29,14 +29,16 @@ function outputDeps = getLabelInternOutputDependencies( obj ) outputDeps.srcMinEnergy = obj.srcMinEnergy; - outputDeps.v = 1; + outputDeps.v = 2; end %% ------------------------------------------------------------------------------- - function y = label( obj, blockAnnotations ) + function [y,ysi] = label( obj, blockAnnotations ) pointSrcIdxs = ~isnan( blockAnnotations.srcAzms ) ; - srcsBlockEnergies = cellfun( @mean, blockAnnotations.srcEnergy(pointSrcIdxs) ); - y = sum( srcsBlockEnergies > obj.srcMinEnergy ); + srcsBlockEnergies = cellfun( @mean, blockAnnotations.globalSrcEnergy(pointSrcIdxs) ); + activeSources = srcsBlockEnergies > obj.srcMinEnergy; + y = sum( activeSources ); + ysi = {find( activeSources )}; end %% ------------------------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/TimeSeriesLabelCreator.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/TimeSeriesLabelCreator.m new file mode 100644 index 0000000..9aedd38 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+LabelCreators/TimeSeriesLabelCreator.m @@ -0,0 +1,88 @@ +classdef TimeSeriesLabelCreator < LabelCreators.Base + % Base Abstract base class for labeling blocks + %% ----------------------------------------------------------------------------------- + properties (SetAccess = protected) + end + + %% ----------------------------------------------------------------------------------- + methods (Abstract, Access = protected) + outputDeps = getLabelInternOutputDependencies( obj ) + [y, ysi] = label( obj, annotations ) + end + + %% ----------------------------------------------------------------------------------- + methods + + function obj = TimeSeriesLabelCreator( varargin ) + ip = inputParser; + ip.addOptional( 'removeUnclearBlocks', 'sequence-wise' ); % 'sequence-wise', 'time-wise', 'false' + ip.parse( varargin{:} ); + obj = obj@LabelCreators.Base(); + obj.removeUnclearBlocks = ip.Results.removeUnclearBlocks; + if ~any( strcmpi( obj.removeUnclearBlocks, {'false','sequence-wise','time-wise'} ) ) + error( 'AMLTTP:usage:unsupportedOptionSetting', 'use one of ''false'',''block-wise'',''time-wise''.' ); + end + end + %% ------------------------------------------------------------------------------- + + end + + %% ----------------------------------------------------------------------------------- + methods (Access = protected) + + % override of LabelCreators.Base's method + function outputDeps = getInternOutputDependencies( obj ) + outputDeps.base = getInternOutputDependencies@LabelCreators.Base( obj ); + outputDeps.ts_v = 1; + outputDeps.labelProc = obj.getLabelInternOutputDependencies(); + end + %% ------------------------------------------------------------------------------- + + % override of LabelCreators.Base's method + function out = getOutput( obj, varargin ) + out.y = obj.y; + out.bIdxs = 1 : size( out.y, 1 ); + removeNanBlocks = strcmpi( obj.removeUnclearBlocks, ... + {'sequence-wise','time-wise'} ); + if ~any( removeNanBlocks ) || any( strcmpi( 'noRemoveNanBlocks', varargin ) ) + removeNanBlocks_lidx = []; + else + error( 'AMLTTP:notImplemented', 'data removal for time-series not implemented yet' ); + removeNanBlocks_lidx = any( isnan( out.y ), 3 ); + if removeNanBlocks(2) + error( 'AMLTTP:notImplemented', 'time-wise data removal for time-series not implemented yet' ); + [~,~,sameTimeIdxs] = unique( [obj.blockAnnotations.blockOffset] ); + nanTimeIdxs = sameTimeIdxs(removeNanBlocks_lidx); + removeNanBlocks_lidx = ismember( sameTimeIdxs, nanTimeIdxs ); + end + end + if nargin < 2 || any( strcmpi( 'x', varargin ) ) + out.x = obj.x; + out.x(removeNanBlocks_lidx,:,:) = []; + end + if nargin < 2 || any( strcmpi( 'a', varargin ) ) + out.a = obj.blockAnnotations; + out.a(removeNanBlocks_lidx) = []; + end + if nargin < 2 || any( strcmpi( 'ysi', varargin ) ) + out.ysi = obj.ysi; + out.ysi(removeNanBlocks_lidx) = []; + end + out.bIdxs(removeNanBlocks_lidx) = []; + out.y(removeNanBlocks_lidx,:,:) = []; + end + %% ------------------------------------------------------------------------------- + + end + %% ----------------------------------------------------------------------------------- + + methods (Static) + + %% ------------------------------------------------------------------------------- + + end + +end + + + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/BGMMmodelSelectTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/BGMMmodelSelectTrainer.m index 37087fe..f3528a7 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/BGMMmodelSelectTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/BGMMmodelSelectTrainer.m @@ -40,7 +40,7 @@ function run( obj ) end %% ---------------------------------------------------------------- - function buildModel( obj, ~, ~ ) + function buildModel( obj, ~, ~, ~ ) comps = obj.nComp; thrs = obj.thr; for nt=1:numel(thrs) diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/BGmmNetTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/BGmmNetTrainer.m index b47f177..b5513f9 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/BGmmNetTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/BGmmNetTrainer.m @@ -29,7 +29,7 @@ end %% ---------------------------------------------------------------- - function buildModel( obj, x, y ) + function buildModel( obj, x, y, iw ) obj.model = Models.BGmmNetModel(); xScaled = obj.model.scale2zeroMeanUnitVar( x, 'saveScalingFactors' ); gmmOpts.nComp = obj.nComp; diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/BMFATrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/BMFATrainer.m index b306bda..e08ac3f 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/BMFATrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/BMFATrainer.m @@ -25,7 +25,7 @@ end %% ---------------------------------------------------------------- - function buildModel( obj, x, y ) + function buildModel( obj, x, y, iw ) obj.model = Models.BMFAModel(); xScaled = obj.model.scale2zeroMeanUnitVar( x, 'saveScalingFactors' ); gmmOpts.mfaK = obj.nComp; diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/Base.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/Base.m index db40d65..e537592 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/Base.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/Base.m @@ -1,4 +1,4 @@ -classdef (Abstract) Base < handle +classdef (Abstract) Base < handle & Parameterized %% -------------------------------------------------------------------- properties (SetAccess = protected) @@ -9,11 +9,32 @@ properties (SetAccess = {?ModelTrainers.Base, ?Parameterized}) performanceMeasure; maxDataSize; + dataSelector; + importanceWeighter; end %% -------------------------------------------------------------------- methods + function obj = Base( varargin ) + pds{1} = struct( 'name', 'performanceMeasure', ... + 'default', @PerformanceMeasures.BAC2, ... + 'valFun', @(x)(isa( x, 'function_handle' )), ... + 'setCallback', @(ob, n, o)(ob.setPerformanceMeasure( n )) ); + pds{2} = struct( 'name', 'maxDataSize', ... + 'default', inf, ... + 'valFun', @(x)(isinf(x) || (rem(x,1) == 0 && x > 0)) ); + pds{3} = struct( 'name', 'dataSelector', ... + 'default', DataSelectors.IgnorantSelector(), ... + 'valFun', @(x)(isa( x, 'DataSelectors.Base') ) ); + pds{4} = struct( 'name', 'importanceWeighter', ... + 'default', ImportanceWeighters.IgnorantWeighter(), ... + 'valFun', @(x)(isa( x, 'ImportanceWeighters.Base') ) ); + obj = obj@Parameterized( pds ); + obj.setParameters( true, varargin{:} ); + end + %% ---------------------------------------------------------------- + function setData( obj, trainSet, testSet ) obj.trainSet = trainSet; if ~exist( 'testSet', 'var' ), testSet = []; end @@ -58,49 +79,52 @@ function setPerformanceMeasure( obj, newPerformanceMeasure ) %% ------------------------------------------------------------------------------- function performance = getPerformance( obj, getDatapointInfo ) - if nargin < 2, getDatapointInfo = 'noInfo'; end + if nargin < 2, getDatapointInfo = false; end verboseFprintf( obj, 'Applying model to test set...\n' ); model = obj.getModel(); model.verbose( obj.verbose ); performance = Models.Base.getPerformance( ... model, obj.testSet, obj.performanceMeasure, ... - obj.maxDataSize, true, getDatapointInfo ); + obj.maxDataSize, obj.dataSelector, obj.importanceWeighter, getDatapointInfo ); end %% ---------------------------------------------------------------- function run( obj ) - [x,y] = obj.getPermutedTrainingData(); + obj.dataSelector.connectData( obj.trainSet ); + obj.importanceWeighter.connectData( obj.trainSet ); + [x,y,sampleIds] = obj.getPermutedTrainingData(); nanXidxs = any( isnan( x ), 2 ); infXidxs = any( isinf( x ), 2 ); if any( nanXidxs ) || any( infXidxs ) warning( 'There are NaNs or INFs in the data -- throwing those vectors away!' ); x(nanXidxs | infXidxs,:) = []; y(nanXidxs | infXidxs,:) = []; + sampleIds(nanXidxs | infXidxs) = []; end - if numel( y ) > obj.maxDataSize - if ModelTrainers.Base.balMaxData - throwoutIdxs = ModelTrainers.Base.getBalThrowoutIdxs( y, obj.maxDataSize ); - x(throwoutIdxs,:) = []; - y(throwoutIdxs,:) = []; - else - x(obj.maxDataSize+1:end,:) = []; - y(obj.maxDataSize+1:end,:) = []; - end + if size( y, 1 ) > obj.maxDataSize + selectFilter = obj.dataSelector.getDataSelection( sampleIds, obj.maxDataSize ); + verboseFprintf( obj, obj.dataSelector.verboseOutput ); + x = x(selectFilter,:); + y = y(selectFilter,:); + sampleIds = sampleIds(selectFilter); end - obj.buildModel( x, y ); + iw = obj.importanceWeighter.getImportanceWeights( sampleIds ); + verboseFprintf( obj, obj.importanceWeighter.verboseOutput ); + obj.buildModel( x, y, iw ); end %% ---------------------------------------------------------------- - function [x,y] = getPermutedTrainingData( obj ) + function [x,y,permutationIdxs] = getPermutedTrainingData( obj ) x = obj.trainSet(:,'x'); if isempty( x ) warning( 'There is no data to train the model.' ); y = []; + permutationIdxs = []; return; else y = obj.trainSet(:,'y'); end - % apply the mask + % apply feature mask, if set fmask = ModelTrainers.Base.featureMask; if ~isempty( fmask ) p_feat = size( x, 2 ); @@ -109,7 +133,7 @@ function run( obj ) x = x(:,fmask); end % permute data - permutationIdxs = randperm( length( y ) ); + permutationIdxs = randperm( size( y, 1 ) )'; x = x(permutationIdxs,:); y = y(permutationIdxs,:); end @@ -120,7 +144,7 @@ function run( obj ) %% -------------------------------------------------------------------- methods (Abstract) - buildModel( obj, x, y ) + buildModel( obj, x, y, iw ) end %% -------------------------------------------------------------------- @@ -130,35 +154,7 @@ function run( obj ) %% -------------------------------------------------------------------- methods (Static) - - function throwoutIdxs = getBalThrowoutIdxs( y, maxDataSize ) - labels = unique( y ); - nPerLabel = arrayfun( @(l)(sum( l == y )), labels ); - [~, labelOrder] = sort( nPerLabel ); - nLabels = numel( labels ); - nRemaining = maxDataSize; - throwoutIdxs = []; - for ii = labelOrder' - nKeep = min( int32( nRemaining/nLabels ), nPerLabel(ii) ); - nRemaining = nRemaining - nKeep; - nLabels = nLabels - 1; - lIdxs = find( y == labels(ii) ); - lIdxs = lIdxs(randperm(nPerLabel(ii))); - throwoutIdxs = [throwoutIdxs; lIdxs(nKeep+1:end)]; - end - end - function b = balMaxData( setNewValue, newValue ) - persistent balMaxD; - if isempty( balMaxD ) - balMaxD = false; - end - if nargin > 0 && setNewValue - balMaxD = newValue; - end - b = balMaxD; - end - function fm = featureMask( setNewMask, newmask ) % Set/Reset the featureMask and return it. % featureMask() reset the featurMask diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/CVtrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/CVtrainer.m index 30098c8..865249f 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/CVtrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/CVtrainer.m @@ -39,7 +39,7 @@ function run( obj ) end %% ---------------------------------------------------------------- - function buildModel( obj, ~, ~ ) + function buildModel( obj, ~, ~, ~ ) obj.trainer.setPerformanceMeasure( obj.performanceMeasure ); obj.createFolds(); obj.foldsPerformance = ones( obj.nFolds, 1 ); @@ -50,7 +50,7 @@ function buildModel( obj, ~, ~ ) obj.trainer.run(); obj.models{ff} = obj.trainer.getModel(); obj.foldsPerformance(ff) = double( obj.trainer.getPerformance() ); - verboseFprintf( obj, 'Done. Performance = %f\n', obj.foldsPerformance(ff) ); + verboseFprintf( obj, 'Done. Performance = %f\n\n', obj.foldsPerformance(ff) ); maxPossiblePerf = mean( obj.foldsPerformance ); if (ff < obj.nFolds) && (maxPossiblePerf <= obj.abortPerfMin) break; diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GMMmodelSelectTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GMMmodelSelectTrainer.m index 0282517..c8765d0 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GMMmodelSelectTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GMMmodelSelectTrainer.m @@ -40,7 +40,7 @@ function run( obj ) end %% ---------------------------------------------------------------- - function buildModel( obj, ~, ~ ) + function buildModel( obj, ~, ~, ~ ) comps = obj.nComp; thrs = obj.thr; for nt=1:numel(thrs) diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmGroupLambdaSelectTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmGroupLambdaSelectTrainer.m index 7b7f206..e87acf8 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmGroupLambdaSelectTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmGroupLambdaSelectTrainer.m @@ -51,7 +51,7 @@ function run(self) end %% BUILD MODEL - function buildModel(self, ~, ~) + function buildModel(self, ~, ~, ~) verboseFprintf(self, '\nRun on full trainSet...\n'); % run core trainer once to determine the lambda path self.trainer_core = ModelTrainers.GlmGroupTrainer( ... diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmGroupTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmGroupTrainer.m index 981603a..5f71405 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmGroupTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmGroupTrainer.m @@ -41,7 +41,7 @@ end %% BUILD MODEL - function buildModel(self, x, y) + function buildModel(self, x, y, iw) % init self.model = Models.GlmGroupModel(); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmNetLambdaSelectTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmNetLambdaSelectTrainer.m index 842978a..3fd34a3 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmNetLambdaSelectTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmNetLambdaSelectTrainer.m @@ -14,24 +14,16 @@ family; nLambda; % number of lambdas on the regularization path cvFolds; % no. of folds for cross validation - labelWeights; end %% -------------------------------------------------------------------- methods function obj = GlmNetLambdaSelectTrainer( varargin ) - pds{1} = struct( 'name', 'performanceMeasure', ... - 'default', @PerformanceMeasures.BAC2, ... - 'valFun', @(x)(isa( x, 'function_handle' )), ... - 'setCallback', @(ob, n, o)(ob.setPerformanceMeasure( n )) ); - pds{2} = struct( 'name', 'maxDataSize', ... - 'default', inf, ... - 'valFun', @(x)(isinf(x) || (rem(x,1) == 0 && x > 0)) ); - pds{3} = struct( 'name', 'alpha', ... + pds{1} = struct( 'name', 'alpha', ... 'default', 1, ... 'valFun', @(x)(isfloat(x) && x >= 0 && x <= 1.0) ); - pds{4} = struct( 'name', 'family', ... + pds{2} = struct( 'name', 'family', ... 'default', 'binomial', ... 'valFun', @(x)(ischar(x) && any(strcmpi(x, ... {'binomial',... @@ -39,16 +31,14 @@ 'multinomialGrouped',... 'gaussian',... 'poisson'}))) ); - pds{5} = struct( 'name', 'nLambda', ... + pds{3} = struct( 'name', 'nLambda', ... 'default', 100, ... 'valFun', @(x)(rem(x,1) == 0 && x >= 0) ); - pds{6} = struct( 'name', 'cvFolds', ... + pds{4} = struct( 'name', 'cvFolds', ... 'default', 10, ... 'valFun', @(x)(rem(x,1) == 0 && x >= 0) ); - pds{7} = struct( 'name', 'labelWeights', ... - 'default', [], ... - 'valFun', @(x)(isempty(x) || isfloat(x)) ); obj = obj@Parameterized( pds ); + obj = obj@ModelTrainers.Base( varargin{:} ); obj.setParameters( true, varargin{:} ); end %% ------------------------------------------------------------------------------- @@ -58,15 +48,16 @@ function run( obj ) end %% ---------------------------------------------------------------- - function buildModel( obj, ~, ~ ) + function buildModel( obj, ~, ~, ~ ) verboseFprintf( obj, '\nRun on full trainSet...\n' ); obj.coreTrainer = ModelTrainers.GlmNetTrainer( ... 'performanceMeasure', obj.performanceMeasure, ... 'maxDataSize', obj.maxDataSize, ... + 'dataSelector', obj.dataSelector, ... 'alpha', obj.alpha, ... 'family', obj.family, ... 'nLambda', obj.nLambda, ... - 'labelWeights', obj.labelWeights ); + 'importanceWeighter', obj.importanceWeighter ); obj.coreTrainer.setData( obj.trainSet, obj.testSet ); obj.coreTrainer.run(); obj.fullSetModel = obj.coreTrainer.getModel(); @@ -83,10 +74,10 @@ function buildModel( obj, ~, ~ ) lPerfs = zeros( numel( lambdas ), numel( cvModels ) ); for ii = 1 : numel( cvModels ) cvModels{ii}.setLambda( [] ); - thisFoldPerfs = Models.Base.getPerformance( ... - cvModels{ii}, obj.cvTrainer.folds{ii}, ... - obj.performanceMeasure ); - lPerfs(1:numel(thisFoldPerfs),ii) = thisFoldPerfs; + foldPerfs = Models.Base.getPerformance( ... + cvModels{ii}, obj.cvTrainer.folds{ii}, obj.performanceMeasure, ... + obj.maxDataSize, obj.dataSelector, obj.importanceWeighter, false ); + lPerfs(1:numel(foldPerfs),ii) = foldPerfs; verboseFprintf( obj, '.' ); end obj.fullSetModel.lPerfsMean = nanMean( lPerfs, 2 ); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmNetModelSelectTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmNetModelSelectTrainer.m index 1f6894c..7e7a525 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmNetModelSelectTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmNetModelSelectTrainer.m @@ -16,11 +16,9 @@ obj = obj@ModelTrainers.HpsTrainer( varargin{:} ); obj.setParameters( true, ... 'buildCoreTrainer', @GlmNetLambdaSelectTrainer, ... - 'hpsCoreTrainerParams', {'cvFolds', 2,}, ... - varargin{:} ); - obj.setParameters( false, ... - 'finalCoreTrainerParams', ... - {'cvFolds', 2,} ); + 'hpsCoreTrainerParams', {'cvFolds', 2,}, ... + varargin{:} ); + obj.setParameters( false, 'finalCoreTrainerParams', {'cvFolds', 2,} ); end %% ------------------------------------------------------------------------------- @@ -41,21 +39,20 @@ %% ------------------------------------------------------------------------------- function refinedHpsTrainer = refineGridTrainer( obj, hps ) - refinedHpsTrainer = GlmNetModelSelectTrainer( 'makeProbModel', obj.makeProbModel, ... - 'buildCoreTrainer', obj.buildCoreTrainer, ... - 'hpsCoreTrainerParams', obj.hpsCoreTrainerParams, ... - 'finalCoreTrainerParams', obj.finalCoreTrainerParams, ... - 'hpsMaxDataSize', obj.hpsMaxDataSize, ... - 'hpsRefineStages', obj.hpsRefineStages, ... - 'hpsSearchBudget', obj.hpsSearchBudget, ... - 'hpsCvFolds', obj.hpsCvFolds, ... - 'hpsMethod', obj.hpsMethod, ... - 'performanceMeasure', obj.performanceMeasure ); + refinedHpsTrainer = GlmNetModelSelectTrainer( ... + 'buildCoreTrainer', obj.buildCoreTrainer, ... + 'hpsCoreTrainerParams', obj.hpsCoreTrainerParams, ... + 'finalCoreTrainerParams', obj.finalCoreTrainerParams, ... + 'hpsMaxDataSize', obj.hpsMaxDataSize, ... + 'hpsRefineStages', obj.hpsRefineStages, ... + 'hpsSearchBudget', obj.hpsSearchBudget, ... + 'hpsCvFolds', obj.hpsCvFolds, ... + 'hpsMethod', obj.hpsMethod, ... + 'performanceMeasure', obj.performanceMeasure ); best3LogMean = @(fn)(mean( log10( [hps.params(end-2:end).(fn)] ) )); aRefinedRange = 10.^getCenteredHalfRange( ... - log10(obj.hpsAlphaRange), best3LogMean('alpha') ); - refinedHpsTrainer.setParameters( false, ... - 'hpsAlphaRange', aRefinedRange ); + log10(obj.hpsAlphaRange), best3LogMean('alpha') ); + refinedHpsTrainer.setParameters( false, 'hpsAlphaRange', aRefinedRange ); end %% ------------------------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmNetTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmNetTrainer.m index a890292..a43c8cd 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmNetTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GlmNetTrainer.m @@ -7,21 +7,16 @@ family; nLambda; lambda; - labelWeights; end %% -------------------------------------------------------------------- methods function obj = GlmNetTrainer( varargin ) - pds{1} = struct( 'name', 'performanceMeasure', ... - 'default', @PerformanceMeasures.BAC2, ... - 'valFun', @(x)(isa( x, 'function_handle' )), ... - 'setCallback', @(ob, n, o)(ob.setPerformanceMeasure( n )) ); - pds{2} = struct( 'name', 'alpha', ... + pds{1} = struct( 'name', 'alpha', ... 'default', 1, ... 'valFun', @(x)(isfloat(x) && x >= 0 && x <= 1.0) ); - pds{3} = struct( 'name', 'family', ... + pds{2} = struct( 'name', 'family', ... 'default', 'binomial', ... 'valFun', @(x)(ischar(x) && any(strcmpi(x, ... {'binomial',... @@ -29,25 +24,20 @@ 'multinomialGrouped',... 'gaussian',... 'poisson'}))) ); - pds{4} = struct( 'name', 'nLambda', ... + pds{3} = struct( 'name', 'nLambda', ... 'default', 100, ... 'valFun', @(x)(rem(x,1) == 0 && x >= 0) ); - pds{5} = struct( 'name', 'lambda', ... - 'default', [], ... - 'valFun', @(x)(isempty(x) || isfloat(x)) ); - pds{6} = struct( 'name', 'maxDataSize', ... - 'default', inf, ... - 'valFun', @(x)(isinf(x) || (rem(x,1) == 0 && x > 0)) ); - pds{7} = struct( 'name', 'labelWeights', ... + pds{4} = struct( 'name', 'lambda', ... 'default', [], ... 'valFun', @(x)(isempty(x) || isfloat(x)) ); obj = obj@Parameterized( pds ); + obj = obj@ModelTrainers.Base( varargin{:} ); obj.setParameters( true, varargin{:} ); end %% ---------------------------------------------------------------- - function buildModel( obj, x, y ) - glmOpts.weights = obj.setDataWeights( y ); + function buildModel( obj, x, y, iw ) + glmOpts.weights = iw; obj.model = Models.GlmNetModel(); x(isnan(x)) = 0; x(isinf(x)) = 0; @@ -64,7 +54,7 @@ function buildModel( obj, x, y ) else family = obj.family; end - verboseFprintf( obj, 'GlmNet training with alpha=%f\n', glmOpts.alpha ); + verboseFprintf( obj, '\nGlmNet training with alpha=%f\n', glmOpts.alpha ); verboseFprintf( obj, '\tsize(x) = %dx%d\n', size(xScaled,1), size(xScaled,2) ); obj.model.model = glmnet( xScaled, y, family, glmOpts ); verboseFprintf( obj, '\n' ); @@ -81,24 +71,6 @@ function buildModel( obj, x, y ) end %% ---------------------------------------------------------------- - function wp = setDataWeights( obj, y ) - wp = ones( size(y) ); - for cc = 1 : size( y, 2 ) - labels = unique( y(:,cc) ); - lw = obj.labelWeights; - if numel( lw ) ~= numel( labels ) - lw = ones( size( labels ) ); - end - for ii = 1 : numel( labels ) - labelShare = sum( y(:,cc) == labels(ii) ) / size( y, 1 ); - labelWeight = lw(ii) / labelShare; - wp(y(:,cc)==labels(ii),cc) = labelWeight; - end - end - wp = mean( wp, 2 ); - end - %% ---------------------------------------------------------------- - end end \ No newline at end of file diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GmmNetTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GmmNetTrainer.m index 92c85dd..9253335 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GmmNetTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/GmmNetTrainer.m @@ -29,7 +29,7 @@ end %% ---------------------------------------------------------------- - function buildModel( obj, x, y ) + function buildModel( obj, x, y, iw ) obj.model = Models.GmmNetModel(); xScaled = obj.model.scale2zeroMeanUnitVar( x, 'saveScalingFactors' ); gmmOpts.nComp = obj.nComp; diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/HpsTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/HpsTrainer.m index f1639ef..51a87c5 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/HpsTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/HpsTrainer.m @@ -22,40 +22,37 @@ methods function obj = HpsTrainer( varargin ) - pds{1} = struct( 'name', 'performanceMeasure', ... - 'default', @PerformanceMeasures.BAC2, ... - 'valFun', @(x)(isa( x, 'function_handle' )), ... - 'setCallback', @(ob, n, o)(ob.setPerformanceMeasure( n )) ); - pds{2} = struct( 'name', 'buildCoreTrainer', ... + pds{1} = struct( 'name', 'buildCoreTrainer', ... 'default', [], ... 'valFun', @(x)(~isempty( x ) && ... isa( x, 'function_handle' )) ); - pds{3} = struct( 'name', 'hpsCoreTrainerParams', ... + pds{2} = struct( 'name', 'hpsCoreTrainerParams', ... 'default', {{}}, ... 'valFun', @(x)(iscell( x )) ); - pds{4} = struct( 'name', 'finalCoreTrainerParams', ... + pds{3} = struct( 'name', 'finalCoreTrainerParams', ... 'default', {{}}, ... 'valFun', @(x)(iscell( x )) ); - pds{5} = struct( 'name', 'hpsMaxDataSize', ... + pds{4} = struct( 'name', 'hpsMaxDataSize', ... 'default', inf, ... 'valFun', @(x)(isinf(x) || (rem(x,1) == 0 && x > 0)) ); - pds{6} = struct( 'name', 'hpsRefineStages', ... + pds{5} = struct( 'name', 'hpsRefineStages', ... 'default', 1, ... 'valFun', @(x)(rem(x,1) == 0 && x >= 0) ); - pds{7} = struct( 'name', 'hpsSearchBudget', ... + pds{6} = struct( 'name', 'hpsSearchBudget', ... 'default', 8, ... 'valFun', @(x)(rem(x,1) == 0 && x >= 0) ); - pds{8} = struct( 'name', 'hpsCvFolds', ... + pds{7} = struct( 'name', 'hpsCvFolds', ... 'default', 4, ... 'valFun', @(x)(rem(x,1) == 0 && x >= 0) ); - pds{9} = struct( 'name', 'hpsMethod', ... + pds{8} = struct( 'name', 'hpsMethod', ... 'default', 'grid', ... 'valFun', @(x)(... ischar(x) && any(strcmpi(x, {'grid','random'}))) ); - pds{10} = struct( 'name', 'finalMaxDataSize', ... + pds{9} = struct( 'name', 'finalMaxDataSize', ... 'default', inf, ... 'valFun', @(x)(isinf(x) || (rem(x,1) == 0 && x > 0)) ); obj = obj@Parameterized( pds ); + obj = obj@ModelTrainers.Base( varargin{:} ); obj.setParameters( true, varargin{:} ); end %% ------------------------------------------------------------------------------- @@ -65,7 +62,7 @@ function run( obj ) end %% ---------------------------------------------------------------- - function buildModel( obj, ~, ~ ) + function buildModel( obj, ~, ~, ~ ) obj.coreTrainer = obj.buildCoreTrainer(); obj.createHpsTrainer(); hps.params = obj.determineHyperparameterSets(); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/LoadModelNoopTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/LoadModelNoopTrainer.m index 85de982..cb7f5fb 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/LoadModelNoopTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/LoadModelNoopTrainer.m @@ -10,23 +10,17 @@ methods function obj = LoadModelNoopTrainer( modelPath, varargin ) - pds{1} = struct( 'name', 'performanceMeasure', ... - 'default', @PerformanceMeasures.BAC2, ... - 'valFun', @(x)(isa( x, 'function_handle' )), ... - 'setCallback', @(ob, n, o)(ob.setPerformanceMeasure( n )) ); - pds{2} = struct( 'name', 'modelParams', ... + pds{1} = struct( 'name', 'modelParams', ... 'default', struct(), ... 'valFun', @(x)(isstruct( x )) ); - pds{3} = struct( 'name', 'maxDataSize', ... - 'default', inf, ... - 'valFun', @(x)(isinf(x) || (rem(x,1) == 0 && x > 0)) ); obj = obj@Parameterized( pds ); + obj = obj@ModelTrainers.Base( varargin{:} ); obj.setParameters( true, varargin{:} ); obj.modelPath = modelPath; end %% ---------------------------------------------------------------- - function buildModel( ~, ~, ~ ) + function buildModel( ~, ~, ~, ~ ) % noop end %% ---------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MBFTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MBFTrainer.m index b36c6a6..d50dc9d 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MBFTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MBFTrainer.m @@ -25,7 +25,7 @@ end %% ---------------------------------------------------------------- - function buildModel( obj, x, y ) + function buildModel( obj, x, y, iw ) obj.model = Models.MbfModel(); xScaled = obj.model.scale2zeroMeanUnitVar( x, 'saveScalingFactors' ); gmmOpts.nComp = obj.nComp; diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MBFmodelSelectTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MBFmodelSelectTrainer.m index 0d115a2..19b8d41 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MBFmodelSelectTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MBFmodelSelectTrainer.m @@ -40,7 +40,7 @@ function run( obj ) end %% ---------------------------------------------------------------- - function buildModel( obj, ~, ~ ) + function buildModel( obj, ~, ~, ~ ) comps = obj.nComp; thrs = obj.thr; for nt=1:numel(thrs) diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MFATrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MFATrainer.m index 0395dcf..8cb1cc2 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MFATrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MFATrainer.m @@ -21,7 +21,7 @@ end %% ---------------------------------------------------------------- - function buildModel( obj, x, y ) + function buildModel( obj, x, y, iw ) obj.model = Models.MFAModel(); xScaled = obj.model.scale2zeroMeanUnitVar( x, 'saveScalingFactors' ); gmmOpts.mfaK = 10;%0.5*size(xScaled,2); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MFAmodelSelectTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MFAmodelSelectTrainer.m index 3332897..6210aac 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MFAmodelSelectTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MFAmodelSelectTrainer.m @@ -40,7 +40,7 @@ function run( obj ) end %% ---------------------------------------------------------------- - function buildModel( obj, ~, ~ ) + function buildModel( obj, ~, ~, ~ ) comps = obj.nComp; nDims = obj.nDim; for nt=1:numel(nDims) diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MbfNetTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MbfNetTrainer.m index d9b8c22..962dfac 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MbfNetTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MbfNetTrainer.m @@ -29,7 +29,7 @@ end %% ---------------------------------------------------------------- - function buildModel( obj, x, y ) + function buildModel( obj, x, y, iw ) % glmOpts.weights = obj.setDataWeights( y ); obj.model = Models.MbfNetModel(); xScaled = obj.model.scale2zeroMeanUnitVar( x, 'saveScalingFactors' ); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MfaNetTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MfaNetTrainer.m index 3e73e5b..3cc57d0 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MfaNetTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/MfaNetTrainer.m @@ -29,7 +29,7 @@ end %% ---------------------------------------------------------------- - function buildModel( obj, x, y ) + function buildModel( obj, x, y, iw ) obj.model = Models.MfaNetModel(); xScaled = obj.model.scale2zeroMeanUnitVar( x, 'saveScalingFactors' ); mbfOpts.nComp = obj.nComp; diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/SVMmodelSelectTrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/SVMmodelSelectTrainer.m index 6897456..dfd5490 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/SVMmodelSelectTrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/SVMmodelSelectTrainer.m @@ -32,11 +32,10 @@ obj = obj@ModelTrainers.HpsTrainer( varargin{:} ); obj.setParameters( true, ... 'buildCoreTrainer', @ModelTrainers.SVMtrainer, ... - 'hpsCoreTrainerParams', {'makeProbModel', false}, ... - varargin{:} ); - obj.setParameters( false, ... - 'finalCoreTrainerParams', ... - {'makeProbModel', obj.makeProbModel} ); + 'hpsCoreTrainerParams', {'makeProbModel', false}, ... + varargin{:} ); + obj.setParameters( false, 'finalCoreTrainerParams', ... + {'makeProbModel', obj.makeProbModel} ); end %% ------------------------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/SVMtrainer.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/SVMtrainer.m index c1333fb..6b86647 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/SVMtrainer.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/SVMtrainer.m @@ -14,39 +14,39 @@ methods function obj = SVMtrainer( varargin ) - pds{1} = struct( 'name', 'performanceMeasure', ... - 'default', @PerformanceMeasures.BAC2, ... - 'valFun', @(x)(isa( x, 'function_handle' )), ... - 'setCallback', @(ob, n, o)(ob.setPerformanceMeasure( n )) ); - pds{2} = struct( 'name', 'epsilon', ... + pds{1} = struct( 'name', 'epsilon', ... 'default', 0.001, ... 'valFun', @(x)(isfloat(x) && x > 0) ); - pds{3} = struct( 'name', 'kernel', ... + pds{2} = struct( 'name', 'kernel', ... 'default', 0, ... 'valFun', @(x)(rem(x,1) == 0 && all(x == 0 | x == 2)) ); - pds{4} = struct( 'name', 'c', ... + pds{3} = struct( 'name', 'c', ... 'default', 1, ... 'valFun', @(x)(isfloat(x) && x > 0) ); - pds{5} = struct( 'name', 'gamma', ... + pds{4} = struct( 'name', 'gamma', ... 'default', 0.1, ... 'valFun', @(x)(isfloat(x) && x > 0) ); - pds{6} = struct( 'name', 'maxDataSize', ... - 'default', inf, ... - 'valFun', @(x)(isinf(x) || (rem(x,1) == 0 && x > 0)) ); - pds{7} = struct( 'name', 'makeProbModel', ... + pds{5} = struct( 'name', 'makeProbModel', ... 'default', false, ... 'valFun', @islogical ); obj = obj@Parameterized( pds ); + obj = obj@ModelTrainers.Base( varargin{:} ); obj.setParameters( true, varargin{:} ); end %% ---------------------------------------------------------------- - function buildModel( obj, x, y ) + function buildModel( obj, x, y, iw ) + if ~all( iw ) + warning( 'AMLTTP:usage:unsupported', ... + ['SVmtrainer can''t use individual sample importance weights '... + 'produced bei ImportanceWeighter. '... + 'Instead, class-wide weights will be used.'] ); + end [x, y, cp] = obj.prepareData( x, y ); obj.model = Models.SVMmodel(); obj.model.useProbModel = obj.makeProbModel; xScaled = obj.model.scale2zeroMeanUnitVar( x, 'saveScalingFactors' ); - m = ceil( prod( size( x ) ) * 8 / (1024 * 1024) ); + m = ceil( numel( x ) * 8 / (1024 * 1024) ); m = min( 2*m, 2000 ); svmParamStrScheme = '-t %d -g %e -c %e -w-1 1 -w1 %e -e %e -m %d -b %d -h 0'; svmParamStr = sprintf( svmParamStrScheme, ... diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/featureSelectionPCA2.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/featureSelectionPCA2.m deleted file mode 100644 index a8436dc..0000000 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+ModelTrainers/featureSelectionPCA2.m +++ /dev/null @@ -1,21 +0,0 @@ -function [idFeature] = featureSelectionPCA2(x,thr) -if nargin<2 - thr = 0.9; -end -[PC,~, e] = princomp(x); -en = e/sum(e); -area = 0; -i = 1; -if thr~=1 - while area maxDataSize - if balMaxData - throwoutIdxs = ModelTrainers.Base.getBalThrowoutIdxs( yTrue, maxDataSize ); - else - throwoutIdxs = randperm(numel( yTrue ) ); - throwoutIdxs(1:maxDataSize) = []; - end - x(throwoutIdxs,:) = []; - yTrue(throwoutIdxs,:) = []; + if size( yTrue, 1 ) > maxDataSize + selectFilter = dataSelector.getDataSelection( sampleIds, maxDataSize ); + verboseFprintf( model, dataSelector.verboseOutput ); + x = x(selectFilter,:); + yTrue = yTrue(selectFilter,:); + sampleIds = sampleIds(selectFilter); end - if strcmpi( getDatapointInfo, 'datapointInfo' ) + iw = importanceWeighter.getImportanceWeights( sampleIds ); + verboseFprintf( model, importanceWeighter.verboseOutput ); + if getDatapointInfo dpi.fileIdxs = testSet(:,'pointwiseFileIdxs'); - dpi.fileIdxs(throwoutIdxs) = []; + dpi.fileIdxs = dpi.fileIdxs(sampleIds); ufidxs = unique( dpi.fileIdxs ); dpi.blockAnnotsCacheFiles(ufidxs) = testSet(ufidxs,'blockAnnotsCacheFile'); dpi.fileNames(ufidxs) = testSet(ufidxs,'fileName'); dpi.bIdxs = testSet(:,'bIdxs'); - dpi.bIdxs(throwoutIdxs) = []; + dpi.bIdxs = dpi.bIdxs(sampleIds); dpi.bacfIdxs = testSet(:,'bacfIdxs'); - dpi.bacfIdxs(throwoutIdxs) = []; - dpiarg = {dpi}; + dpi.bacfIdxs = dpi.bacfIdxs(sampleIds); else - dpiarg = {}; + dpi = struct.empty; end if isempty( x ), error( 'There is no data to test the model.' ); end yModel = model.applyModel( x ); for ii = 1 : size( yModel, 2 ) - perf(ii) = perfMeasure( yTrue, yModel(:,ii), dpiarg{:} ); + perf(ii) = perfMeasure( yTrue, yModel(:,ii), iw, dpi, testSet ); end end %% ---------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/BAC.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/BAC.m index 874892d..0a65125 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/BAC.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/BAC.m @@ -14,13 +14,8 @@ %% -------------------------------------------------------------------- methods - function obj = BAC( yTrue, yPred, datapointInfo ) - if nargin < 3 - dpiarg = {}; - else - dpiarg = {datapointInfo}; - end - obj = obj@PerformanceMeasures.Base( yTrue, yPred, dpiarg{:} ); + function obj = BAC( yTrue, yPred, varargin ) + obj = obj@PerformanceMeasures.Base( yTrue, yPred, varargin{:} ); end % ----------------------------------------------------------------- @@ -49,16 +44,15 @@ end % ----------------------------------------------------------------- - function [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, dpi ) + function [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, iw, dpi, ~ ) tps = yTrue == 1 & yPred > 0; tns = yTrue == -1 & yPred < 0; fps = yTrue == -1 & yPred > 0; fns = yTrue == 1 & yPred < 0; - if nargin < 4 - dpi = struct.empty; - else + if ~isempty( dpi ) dpi.yTrue = yTrue; dpi.yPred = yPred; + dpi.iw = iw; end obj.tp = sum( tps ); obj.tn = sum( tns ); @@ -82,38 +76,6 @@ performance = 0.5 * obj.sensitivity + 0.5 * obj.specificity; end % ----------------------------------------------------------------- - - function [dpiext, compiled] = makeDatapointInfoStats( obj, fieldname, compiledPerfField ) - if isempty( obj.datapointInfo ), dpiext = []; return; end - if ~isfield( obj.datapointInfo, fieldname ) - error( '%s is not a field of datapointInfo', fieldname ); - end - if nargin < 3, compiledPerfField = 'performance'; end - uniqueDpiFieldElems = unique( obj.datapointInfo.(fieldname) ); - for ii = 1 : numel( uniqueDpiFieldElems ) - if iscell( uniqueDpiFieldElems ) - udfe = uniqueDpiFieldElems{ii}; - udfeIdxs = strcmp( obj.datapointInfo.(fieldname), ... - udfe ); - else - udfe = uniqueDpiFieldElems(ii); - udfeIdxs = obj.datapointInfo.(fieldname) == udfe; - end - for fn = fieldnames( obj.datapointInfo )' - if any( size( obj.datapointInfo.(fn{1}) ) ~= size( udfeIdxs ) ) - iiDatapointInfo.(fn{1}) = obj.datapointInfo.(fn{1}); - continue - end - iiDatapointInfo.(fn{1}) = obj.datapointInfo.(fn{1})(udfeIdxs); - end - dpiext(ii) = PerformanceMeasures.BAC( iiDatapointInfo.yTrue, ... - iiDatapointInfo.yPred,... - iiDatapointInfo ); - compiled{ii,1} = udfe; - compiled{ii,2} = dpiext(ii).(compiledPerfField); - end - end - % ----------------------------------------------------------------- end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/BAC2.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/BAC2.m index 2f50cc2..5086e2f 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/BAC2.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/BAC2.m @@ -8,13 +8,8 @@ %% -------------------------------------------------------------------- methods - function obj = BAC2( yTrue, yPred, datapointInfo ) - if nargin < 3 - dpiarg = {}; - else - dpiarg = {datapointInfo}; - end - obj = obj@PerformanceMeasures.BAC( yTrue, yPred, dpiarg{:} ); + function obj = BAC2( yTrue, yPred, varargin ) + obj = obj@PerformanceMeasures.BAC( yTrue, yPred, varargin{:} ); end % ----------------------------------------------------------------- @@ -43,51 +38,14 @@ end % ----------------------------------------------------------------- - function [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, dpi ) - if nargin < 4 - dpiarg = {}; - else - dpiarg = {dpi}; - end + function [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, iw, dpi, ~ ) [obj, performance, dpi] = ... - calcPerformance@PerformanceMeasures.BAC( obj, yTrue, yPred, dpiarg{:} ); + calcPerformance@PerformanceMeasures.BAC( obj, yTrue, yPred, iw, dpi, [] ); obj.bac = performance; performance = 1 - (((1 - obj.sensitivity)^2 + (1 - obj.specificity)^2) / 2)^0.5; end % ----------------------------------------------------------------- - function [dpiext, compiled] = makeDatapointInfoStats( obj, fieldname, compiledPerfField ) - if isempty( obj.datapointInfo ), dpiext = []; return; end - if ~isfield( obj.datapointInfo, fieldname ) - error( '%s is not a field of datapointInfo', fieldname ); - end - if nargin < 3, compiledPerfField = 'performance'; end - uniqueDpiFieldElems = unique( obj.datapointInfo.(fieldname) ); - for ii = 1 : numel( uniqueDpiFieldElems ) - if iscell( uniqueDpiFieldElems ) - udfe = uniqueDpiFieldElems{ii}; - udfeIdxs = strcmp( obj.datapointInfo.(fieldname), ... - udfe ); - else - udfe = uniqueDpiFieldElems(ii); - udfeIdxs = obj.datapointInfo.(fieldname) == udfe; - end - for fn = fieldnames( obj.datapointInfo )' - if any( size( obj.datapointInfo.(fn{1}) ) ~= size( udfeIdxs ) ) - iiDatapointInfo.(fn{1}) = obj.datapointInfo.(fn{1}); - continue - end - iiDatapointInfo.(fn{1}) = obj.datapointInfo.(fn{1})(udfeIdxs); - end - dpiext(ii) = PerformanceMeasures.BAC2( iiDatapointInfo.yTrue, ... - iiDatapointInfo.yPred,... - iiDatapointInfo ); - compiled{ii,1} = udfe; - compiled{ii,2} = dpiext(ii).(compiledPerfField); - end - end - % ----------------------------------------------------------------- - end end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/BAC_BAextended.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/BAC_BAextended.m new file mode 100644 index 0000000..873dd32 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/BAC_BAextended.m @@ -0,0 +1,211 @@ +classdef BAC_BAextended < PerformanceMeasures.Base + + %% -------------------------------------------------------------------- + properties (SetAccess = protected) + tp; + fp; + tn; + fn; + sensitivity; + specificity; + acc; + resc_b; + resc_t; + resc_t2; + end + + %% -------------------------------------------------------------------- + methods + + function obj = BAC_BAextended( yTrue, yPred, varargin ) + obj = obj@PerformanceMeasures.Base( yTrue, yPred, varargin{:} ); + end + % ----------------------------------------------------------------- + + function po = strapOffDpi( obj ) + po = strapOffDpi@PerformanceMeasures.Base( obj ); + po.resc_b = []; + po.resc_t = []; + po.resc_t2 = []; + end + % ----------------------------------------------------------------- + + function b = eqPm( obj, otherPm ) + b = obj.performance == otherPm.performance; + end + % ----------------------------------------------------------------- + + function b = gtPm( obj, otherPm ) + b = obj.performance > otherPm.performance; + end + % ----------------------------------------------------------------- + + function d = double( obj ) + for ii = 1 : size( obj, 2 ) + d(ii) = double( obj(ii).performance ); + end + end + % ----------------------------------------------------------------- + + function s = char( obj ) + if numel( obj ) > 1 + warning( 'only returning first object''s performance' ); + end + s = num2str( obj(1).performance ); + end + % ----------------------------------------------------------------- + + function [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, ~, ~, testSetIdData ) + tps = yTrue == 1 & yPred > 0; + tns = yTrue == -1 & yPred < 0; + fps = yTrue == -1 & yPred > 0; + fns = yTrue == 1 & yPred < 0; + dpi = struct.empty; + obj.tp = sum( tps ); + obj.tn = sum( tns ); + obj.fp = sum( fps ); + obj.fn = sum( fns ); + tp_fn = sum( yTrue == 1 ); + tn_fp = sum( yTrue == -1 ); + if tp_fn == 0; + warning( 'No positive true label.' ); + obj.sensitivity = nan; + else + obj.sensitivity = obj.tp / tp_fn; + end + if tn_fp == 0; + warning( 'No negative true label.' ); + obj.specificity = nan; + else + obj.specificity = obj.tn / tn_fp; + end + obj.acc = (obj.tp + obj.tn) / (tp_fn + tn_fp); + performance = 0.5 * obj.sensitivity + 0.5 * obj.specificity; + obj = obj.analyzeBAextended( yTrue, yPred, testSetIdData ); + end + % ----------------------------------------------------------------- + + function obj = analyzeBAextended( obj, yTrue, yPred, testSetIdData ) + fprintf( 'analyzing BA-extended' ); + obj.resc_b = RescSparse( 'uint32', 'uint8' ); + obj.resc_t = RescSparse( 'uint32', 'uint8' ); + obj.resc_t2 = RescSparse( 'uint32', 'uint8' ); + bapis = cell( numel( testSetIdData.data ), 1 ); + agBapis = cell( numel( testSetIdData.data ), 1 ); + asgns = cell( numel( testSetIdData.data ), 1 ); + agAsgns = cell( numel( testSetIdData.data ), 1 ); + agBapis2 = cell( numel( testSetIdData.data ), 1 ); + agAsgns2 = cell( numel( testSetIdData.data ), 1 ); + blockAnnotsCacheFiles = testSetIdData(:,'blockAnnotsCacheFile'); + [bacfClassIdxs,bacfci_ic] = PerformanceMeasures.BAC_BAextended.getFileIds( blockAnnotsCacheFiles ); + sampleFileIdxs = testSetIdData(:,'pointwiseFileIdxs'); + for ii = 1 : numel( testSetIdData.data ) + scp.classIdx = nan; + scp.dd = nan; + scp.fileClassId = bacfClassIdxs(ii); + scp.fileId = sum( bacfci_ic(1:ii) == bacfci_ic(ii) ); + blockAnnotations_ii = testSetIdData(ii,'blockAnnotations'); + yt_ii = yTrue(sampleFileIdxs==ii,:); + yp_ii = yPred(sampleFileIdxs==ii,:); + bacfIdxs_ii = testSetIdData(ii,'bacfIdxs'); + for jj = 1 : numel( blockAnnotsCacheFiles{ii} ) + scp.id = jj; + blockAnnotations = blockAnnotations_ii(bacfIdxs_ii==jj); + yt = yt_ii(bacfIdxs_ii==jj); + yp = yp_ii(bacfIdxs_ii==jj); + if isempty( blockAnnotations ), continue; end + [bapis{ii,jj},agBapis{ii,jj},agBapis2{ii,jj},... + asgns{ii,jj},agAsgns{ii,jj},agAsgns2{ii,jj}] = ... + PerformanceMeasures.BAC_BAextended.produceBapisAsgns( ... + yt, yp, blockAnnotations,... + scp ); %#ok<*PROPLC> + end + end + asgns = PerformanceMeasures.BAC_BAextended.catAsgns( asgns ); + obj.resc_b = addDpiToResc( obj.resc_b, asgns, cat( 1, bapis{:} ) ); + fprintf( ':' ); + agAsgns = PerformanceMeasures.BAC_BAextended.catAsgns( agAsgns ); + obj.resc_t = addDpiToResc( obj.resc_t, agAsgns, cat( 1, agBapis{:} ) ); + fprintf( ':' ); + agAsgns2 = PerformanceMeasures.BAC_BAextended.catAsgns( agAsgns2 ); + obj.resc_t2 = addDpiToResc( obj.resc_t2, agAsgns2, cat( 1, agBapis2{:} ) ); + fprintf( ';' ); + fprintf( '\n' ); + end + % ----------------------------------------------------------------- + + end + + %% -------------------------------------------------------------------- + methods (Static) + + function asgns = catAsgns( asgns ) + asgns = cat( 1, asgns{:} ); + asgns = {cat( 1, asgns{:,1} ), cat( 1, asgns{:,2} ), ... + cat( 1, asgns{:,3} ), cat( 1, asgns{:,4} )}; + end + % ----------------------------------------------------------------- + + function [bacfClassIdxs,bacfci_ic] = getFileIds( blockAnnotsCacheFiles ) + bacfiles = cellfun( @(x)(applyIfNempty(x,@(c)(c{1}))), blockAnnotsCacheFiles, 'UniformOutput', false )'; + [~,bacfiles] = cellfun( @(x)(applyIfNempty(x,@fileparts)), bacfiles, 'UniformOutput', false ); + [~,bacfClasses] = cellfun( @(c)( strtok(c,'.') ), bacfiles, 'UniformOutput', false ); + [bacfClasses,~] = cellfun( @(c)( strtok(c,'.') ), bacfClasses, 'UniformOutput', false ); + niClasses = {{'alarm'},{'baby'},{'femaleSpeech'},{'fire'},{'crash'},{'dog'},... + {'engine'},{'footsteps'},{'knock'},{'phone'},{'piano'},... + {'maleSpeech'},{'femaleScream','maleScream'},{'general'}}; + bacfClassIdxs = cellfun( ... + @(x)( find( cellfun( @(c)(any( strcmpi( x, c ) )), niClasses ) ) ), ... + bacfClasses, 'UniformOutput', false ); + bacfClassIdxs(cellfun(@isempty,bacfClassIdxs)) = {nan}; + bacfClassIdxs = cell2mat( bacfClassIdxs ); + [~,~,bacfci_ic] = unique( bacfClassIdxs ); + end + % ----------------------------------------------------------------- + + function [pis,agPis,agPis2,asg,agAsg,agAsg2] = produceBapisAsgns( ... + yt, yp, blockAnnotations, scp ) + [blockAnnotations, yt, yp, sameTimeIdxs] = findSameTimeBlocks( blockAnnotations, yt, yp ); + [bap, asg] = extractBAparams( blockAnnotations, scp, yp, yt ); + pis = baParams2bapIdxs( bap ); + fprintf( '.' ); + if isfield( blockAnnotations, 'estAzm' ) % is segId + usti = unique( sameTimeIdxs )'; + agBap = bap; + agBap(numel( usti )+1:end,:) = []; + agBap(:,2:10) = deal( nanRescStruct ); + agYt = yt; + agYt(numel( usti )+1:end,:) = []; + agYp = yp; + agYp(numel( usti )+1:end,:) = []; + maxc = 0; + for bb = 1 : numel( usti ) + stibb = sameTimeIdxs==usti(bb); + sumStibb = sum( stibb ); + maxc = max( maxc, sumStibb ); + agBap(bb,1:sumStibb) = bap(stibb); + agYt(bb,1:sumStibb) = yt(stibb); + agYp(bb,1:sumStibb) = yp(stibb); + end + agBap(:,maxc+1:end) = []; + [agBap2, agAsgn2] = aggregateBlockAnnotations2( agBap, agYp, agYt ); + [agBap, agAsgn] = aggregateBlockAnnotations( agBap, agYp, agYt ); + agAsg = mat2cell( agAsgn, size( agAsgn, 1 ), [1,1,1,1] ); + agPis = baParams2bapIdxs( agBap ); + agAsg2 = mat2cell( agAsgn2, size( agAsgn2, 1 ), [1,1,1,1] ); + agPis2 = baParams2bapIdxs( agBap2 ); + fprintf( ',' ); + else + agPis = []; + agPis2 = []; + agAsg = []; + agAsg2 = []; + end + end + % ----------------------------------------------------------------- + + + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/Base.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/Base.m index 0767607..c98e663 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/Base.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/Base.m @@ -9,18 +9,18 @@ %% -------------------------------------------------------------------- methods - function obj = Base( yTrue, yPred, datapointInfo ) - if nargin < 2 - error( ['Subclass of PerformanceMeasures.Base must call superconstructor ',... - 'and pass yTrue and yPred.'] ); - end + function obj = Base( yTrue, yPred, iw, datapointInfo, testSetIdData ) if nargin < 3 - dpiarg = {}; - else - dpiarg = {datapointInfo}; + iw = ones( size( yTrue ) ); + end + if nargin < 4 + datapointInfo = struct.empty; + end + if nargin < 5 + testSetIdData = []; end [obj, obj.performance, obj.datapointInfo] = ... - obj.calcPerformance( yTrue, yPred, dpiarg{:} ); + obj.calcPerformance( yTrue, yPred, iw, datapointInfo, testSetIdData ); end % ----------------------------------------------------------------- @@ -31,7 +31,7 @@ function po = strapOffDpi( obj ) po = obj; - po.datapointInfo = []; + po.datapointInfo = struct.empty; end % ----------------------------------------------------------------- @@ -77,11 +77,58 @@ function disp( obj ) end % ----------------------------------------------------------------- + function [blockAnnotations, yp, yt] = getBacfDpi( obj, bacfIdx, bacfSubidx ) + allDpi = obj.datapointInfo; + currentFileDpiIdxs = find( allDpi.fileIdxs == bacfIdx ); + currentFileBacfSubIdxs = allDpi.bacfIdxs(currentFileDpiIdxs); + currentBacfDpiIdxs = currentFileDpiIdxs(currentFileBacfSubIdxs == bacfSubidx); + currentBacfUsedIdxs = allDpi.bIdxs(currentBacfDpiIdxs); + bacfile = load( allDpi.blockAnnotsCacheFiles{bacfIdx}{bacfSubidx}, 'blockAnnotations'); + blockAnnotations = bacfile.blockAnnotations(currentBacfUsedIdxs); + yp = allDpi.yPred(currentBacfDpiIdxs); + yt = allDpi.yTrue(currentBacfDpiIdxs); + end + % ----------------------------------------------------------------- + + function [dpiext, compiled] = makeDatapointInfoStats( obj, fieldname, compiledPerfField ) + if isempty( obj.datapointInfo ), dpiext = []; return; end + if ~isfield( obj.datapointInfo, fieldname ) + error( '%s is not a field of datapointInfo', fieldname ); + end + if nargin < 3, compiledPerfField = 'performance'; end + uniqueDpiFieldElems = unique( obj.datapointInfo.(fieldname) ); + for ii = 1 : numel( uniqueDpiFieldElems ) + if iscell( uniqueDpiFieldElems ) + udfe = uniqueDpiFieldElems{ii}; + udfeIdxs = strcmp( obj.datapointInfo.(fieldname), ... + udfe ); + else + udfe = uniqueDpiFieldElems(ii); + udfeIdxs = obj.datapointInfo.(fieldname) == udfe; + end + for fn = fieldnames( obj.datapointInfo )' + if any( size( obj.datapointInfo.(fn{1}) ) ~= size( udfeIdxs ) ) + iiDatapointInfo.(fn{1}) = obj.datapointInfo.(fn{1}); + continue + end + iiDatapointInfo.(fn{1}) = obj.datapointInfo.(fn{1})(udfeIdxs); + end + classInfo = metaclass( obj ); + dpiext(ii) = feval( classInfo.Name, iiDatapointInfo.yTrue, ... + iiDatapointInfo.yPred,... + iiDatapointInfo.iw,... + iiDatapointInfo ); + compiled{ii,1} = udfe; + compiled{ii,2} = dpiext(ii).(compiledPerfField); + end + end + % ----------------------------------------------------------------- + end %% -------------------------------------------------------------------- methods (Abstract) - [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, dpiarg ) + [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, iw, dpi, testSetIdData ) b = eqPm( obj, otherPm ) b = gtPm( obj, otherPm ) s = char( obj ) diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/Fscore.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/Fscore.m index 00942ee..ec59fc6 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/Fscore.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/Fscore.m @@ -14,13 +14,8 @@ %% -------------------------------------------------------------------- methods - function obj = Fscore( yTrue, yPred, datapointInfo ) - if nargin < 3 - dpiarg = {}; - else - dpiarg = {datapointInfo}; - end - obj = obj@PerformanceMeasures.Base( yTrue, yPred, dpiarg{:} ); + function obj = Fscore( yTrue, yPred, varargin ) + obj = obj@PerformanceMeasures.Base( yTrue, yPred, varargin{:} ); end % ----------------------------------------------------------------- @@ -49,16 +44,15 @@ end % ----------------------------------------------------------------- - function [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, dpi ) + function [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, iw, dpi, ~ ) tps = yTrue == 1 & yPred > 0; tns = yTrue == -1 & yPred < 0; fps = yTrue == -1 & yPred > 0; fns = yTrue == 1 & yPred < 0; - if nargin < 4 - dpi = struct.empty; - else + if ~isempty( dpi ) dpi.yTrue = yTrue; dpi.yPred = yPred; + dpi.iw = iw; end obj.tp = sum( tps ); obj.tn = sum( tns ); @@ -84,38 +78,6 @@ end % ----------------------------------------------------------------- - function [dpiext, compiled] = makeDatapointInfoStats( obj, fieldname, compiledPerfField ) - if isempty( obj.datapointInfo ), dpiext = []; return; end - if ~isfield( obj.datapointInfo, fieldname ) - error( '%s is not a field of datapointInfo', fieldname ); - end - if nargin < 3, compiledPerfField = 'performance'; end - uniqueDpiFieldElems = unique( obj.datapointInfo.(fieldname) ); - for ii = 1 : numel( uniqueDpiFieldElems ) - if iscell( uniqueDpiFieldElems ) - udfe = uniqueDpiFieldElems{ii}; - udfeIdxs = strcmp( obj.datapointInfo.(fieldname), ... - udfe ); - else - udfe = uniqueDpiFieldElems(ii); - udfeIdxs = obj.datapointInfo.(fieldname) == udfe; - end - for fn = fieldnames( obj.datapointInfo )' - if any( size( obj.datapointInfo.(fn{1}) ) ~= size( udfeIdxs ) ) - iiDatapointInfo.(fn{1}) = obj.datapointInfo.(fn{1}); - continue - end - iiDatapointInfo.(fn{1}) = obj.datapointInfo.(fn{1})(udfeIdxs); - end - dpiext(ii) = PerformanceMeasures.BAC( iiDatapointInfo.yTrue, ... - iiDatapointInfo.yPred,... - iiDatapointInfo ); - compiled{ii,1} = udfe; - compiled{ii,2} = dpiext(ii).(compiledPerfField); - end - end - % ----------------------------------------------------------------- - end end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/ImportanceWeightedSquareBalancedAccuracy.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/ImportanceWeightedSquareBalancedAccuracy.m new file mode 100644 index 0000000..89e4f60 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/ImportanceWeightedSquareBalancedAccuracy.m @@ -0,0 +1,85 @@ +classdef ImportanceWeightedSquareBalancedAccuracy < PerformanceMeasures.Base + + %% -------------------------------------------------------------------- + properties (SetAccess = protected) + tp; + fp; + tn; + fn; + sensitivity; + specificity; + acc; + bac; + end + + %% -------------------------------------------------------------------- + methods + + function obj = ImportanceWeightedSquareBalancedAccuracy( yTrue, yPred, varargin ) + obj = obj@PerformanceMeasures.Base( yTrue, yPred, varargin{:} ); + end + % ----------------------------------------------------------------- + + function b = eqPm( obj, otherPm ) + b = obj.performance == otherPm.performance; + end + % ----------------------------------------------------------------- + + function b = gtPm( obj, otherPm ) + b = obj.performance > otherPm.performance; + end + % ----------------------------------------------------------------- + + function d = double( obj ) + for ii = 1 : size( obj, 2 ) + d(ii) = double( obj(ii).performance ); + end + end + % ----------------------------------------------------------------- + + function s = char( obj ) + if numel( obj ) > 1 + warning( 'only returning first object''s performance' ); + end + s = num2str( obj(1).performance ); + end + % ----------------------------------------------------------------- + + function [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, iw, dpi, ~ ) + if ~isempty( dpi ) + dpi.yTrue = yTrue; + dpi.yPred = yPred; + dpi.iw = iw; + end + tps = iw .* (yTrue == 1 & yPred > 0); + tns = iw .* (yTrue == -1 & yPred < 0); + fps = iw .* (yTrue == -1 & yPred > 0); + fns = iw .* (yTrue == 1 & yPred < 0); + tp_fn = sum( [tps(:);fns(:)], 1 ); + tn_fp = sum( [tns(:);fps(:)], 1 ); + obj.tp = sum( tps ); + obj.tn = sum( tns ); + obj.fp = sum( fps ); + obj.fn = sum( fns ); + if tp_fn == 0; + warning( 'No positive true label.' ); + obj.sensitivity = nan; + else + obj.sensitivity = obj.tp / tp_fn; + end + if tn_fp == 0; + warning( 'No negative true label.' ); + obj.specificity = nan; + else + obj.specificity = obj.tn / tn_fp; + end + obj.acc = (obj.tp + obj.tn) / (tp_fn + tn_fp); + obj.bac = 0.5 * obj.sensitivity + 0.5 * obj.specificity; + performance = 1 - (((1 - obj.sensitivity)^2 + (1 - obj.specificity)^2) / 2)^0.5; + end + % ----------------------------------------------------------------- + + end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/MultinomialBAC.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/MultinomialBAC.m index 0106fb3..1b15389 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/MultinomialBAC.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/MultinomialBAC.m @@ -10,13 +10,8 @@ %% -------------------------------------------------------------------- methods - function obj = MultinomialBAC( yTrue, yPred, datapointInfo ) - if nargin < 3 - dpiarg = {}; - else - dpiarg = {datapointInfo}; - end - obj = obj@PerformanceMeasures.Base( yTrue, yPred, dpiarg{:} ); + function obj = MultinomialBAC( yTrue, yPred, varargin ) + obj = obj@PerformanceMeasures.Base( yTrue, yPred, varargin{:} ); end % ----------------------------------------------------------------- @@ -45,7 +40,7 @@ end % ----------------------------------------------------------------- - function [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, dpi ) + function [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, iw, dpi, ~ ) labels = unique( [yTrue;yPred] ); n_acc = 0; for tt = 1 : numel( labels ) @@ -61,11 +56,10 @@ end n_acc = n_acc + obj.confusionMatrix(tt,tt); end - if nargin < 4 - dpi = struct.empty; - else + if ~isempty( dpi ) dpi.yTrue = yTrue; dpi.yPred = yPred; + dpi.iw = iw; end obj.acc = n_acc / sum( sum( obj.confusionMatrix ) ); performance = sum( obj.sens(~isnan(obj.sens)) ) / ... @@ -73,38 +67,6 @@ end % ----------------------------------------------------------------- - function [dpiext, compiled] = makeDatapointInfoStats( obj, fieldname, compiledPerfField ) - if isempty( obj.datapointInfo ), dpiext = []; return; end - if ~isfield( obj.datapointInfo, fieldname ) - error( '%s is not a field of datapointInfo', fieldname ); - end - if nargin < 3, compiledPerfField = 'performance'; end - uniqueDpiFieldElems = unique( obj.datapointInfo.(fieldname) ); - for ii = 1 : numel( uniqueDpiFieldElems ) - if iscell( uniqueDpiFieldElems ) - udfe = uniqueDpiFieldElems{ii}; - udfeIdxs = strcmp( obj.datapointInfo.(fieldname), ... - udfe ); - else - udfe = uniqueDpiFieldElems(ii); - udfeIdxs = obj.datapointInfo.(fieldname) == udfe; - end - for fn = fieldnames( obj.datapointInfo )' - if any( size( obj.datapointInfo.(fn{1}) ) ~= size( udfeIdxs ) ) - iiDatapointInfo.(fn{1}) = obj.datapointInfo.(fn{1}); - continue - end - iiDatapointInfo.(fn{1}) = obj.datapointInfo.(fn{1})(udfeIdxs); - end - dpiext(ii) = PerformanceMeasures.BAC( iiDatapointInfo.yTrue, ... - iiDatapointInfo.yPred,... - iiDatapointInfo ); - compiled{ii,1} = udfe; - compiled{ii,2} = dpiext(ii).(compiledPerfField); - end - end - % ----------------------------------------------------------------- - end end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/NSE.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/NSE.m index 6906115..0776b8d 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/NSE.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+PerformanceMeasures/NSE.m @@ -8,13 +8,8 @@ %% -------------------------------------------------------------------- methods - function obj = NSE( yTrue, yPred, datapointInfo ) - if nargin < 3 - dpiarg = {}; - else - dpiarg = {datapointInfo}; - end - obj = obj@PerformanceMeasures.Base( yTrue, yPred, dpiarg{:} ); + function obj = NSE( yTrue, yPred, varargin ) + obj = obj@PerformanceMeasures.Base( yTrue, yPred, varargin{:} ); end % ----------------------------------------------------------------- @@ -43,16 +38,15 @@ end % ----------------------------------------------------------------- - function [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, dpi ) + function [obj, performance, dpi] = calcPerformance( obj, yTrue, yPred, iw, dpi, ~ ) e = yTrue - yPred; se = e.^2; performance = - mean( se ); obj.mae = mean( abs( e ) ); - if nargin < 4 - dpi = struct.empty; - else + if ~isempty( dpi ) dpi.yTrue = yTrue; dpi.yPred = yPred; + dpi.iw = iw; end end % ----------------------------------------------------------------- @@ -85,38 +79,6 @@ performance = sum( sens(~isnan(sens)) ) / numel( sens(~isnan(sens)) ); end % ----------------------------------------------------------------- - - function [dpiext, compiled] = makeDatapointInfoStats( obj, fieldname, compiledPerfField ) - if isempty( obj.datapointInfo ), dpiext = []; return; end - if ~isfield( obj.datapointInfo, fieldname ) - error( '%s is not a field of datapointInfo', fieldname ); - end - if nargin < 3, compiledPerfField = 'performance'; end - uniqueDpiFieldElems = unique( obj.datapointInfo.(fieldname) ); - for ii = 1 : numel( uniqueDpiFieldElems ) - if iscell( uniqueDpiFieldElems ) - udfe = uniqueDpiFieldElems{ii}; - udfeIdxs = strcmp( obj.datapointInfo.(fieldname), ... - udfe ); - else - udfe = uniqueDpiFieldElems(ii); - udfeIdxs = obj.datapointInfo.(fieldname) == udfe; - end - for fn = fieldnames( obj.datapointInfo )' - if any( size( obj.datapointInfo.(fn{1}) ) ~= size( udfeIdxs ) ) - iiDatapointInfo.(fn{1}) = obj.datapointInfo.(fn{1}); - continue - end - iiDatapointInfo.(fn{1}) = obj.datapointInfo.(fn{1})(udfeIdxs); - end - dpiext(ii) = PerformanceMeasures.BAC( iiDatapointInfo.yTrue, ... - iiDatapointInfo.yPred,... - iiDatapointInfo ); - compiled{ii,1} = udfe; - compiled{ii,2} = dpiext(ii).(compiledPerfField); - end - end - % ----------------------------------------------------------------- end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+SceneConfig/BRIRsource.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+SceneConfig/BRIRsource.m index d00cf12..ba2e687 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+SceneConfig/BRIRsource.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+SceneConfig/BRIRsource.m @@ -1,12 +1,18 @@ classdef BRIRsource < SceneConfig.SourceBase & Parameterized %% ----------------------------------------------------------------------------------- + properties brirFName; speakerId; end - %% ----------------------------------------------------------------------------------- + + properties (SetAccess = protected) + azimuth; + end +%% ----------------------------------------------------------------------------------- + methods function obj = BRIRsource( brirFName, varargin ) @@ -27,7 +33,24 @@ strcmp( obj1.brirFName(f1SepIdxs(end-1):end), obj2.brirFName(f2SepIdxs(end-1):end) ); end %% ------------------------------------------------------------------------------- - + + function calcAzimuth( obj, brirHeadOrientIdx ) + brirSofa = SOFAload( db.getFile( obj.brirFName ), 'nodata' ); + headOrientIdx = ceil( brirHeadOrientIdx * size( brirSofa.ListenerView, 1 )); + headOrientation = SOFAconvertCoordinates( ... + brirSofa.ListenerView(headOrientIdx,:),'cartesian','spherical' ); + if isempty( obj.speakerId ) + sid = 1; + else + sid = obj.speakerId; + end + brirSrcPos = SOFAconvertCoordinates( ... + brirSofa.EmitterPosition(sid,:) - brirSofa.ListenerPosition, ... + 'cartesian','spherical' ); + obj.azimuth = brirSrcPos(1) - headOrientation(1); + end + %% ------------------------------------------------------------------------------- end + %% ----------------------------------------------------------------------------------- end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/+SceneConfig/SceneConfiguration.m b/AuditoryMachineLearningTrainingTestingPipeline/src/+SceneConfig/SceneConfiguration.m index 8c0a498..9589095 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/+SceneConfig/SceneConfiguration.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/+SceneConfig/SceneConfiguration.m @@ -4,7 +4,7 @@ properties (SetAccess = protected) sources; SNRs; - snrRefs; + snrRefs; % TODO: there should only be one scene-wide SNRref. loopSrcs; % 'no','self','randomSeq' room; brirHeadOrientIdx; @@ -54,6 +54,11 @@ function addRoom( obj, room ) function setBRIRheadOrientation( obj, brirHeadOrientIdx ) obj.brirHeadOrientIdx = brirHeadOrientIdx; + for ii = 1:numel( obj.sources ) + if isa( obj.sources(ii), 'SceneConfig.BRIRsource' ) + obj.sources(ii).calcAzimuth( brirHeadOrientIdx ); + end + end end %% ------------------------------------------------------------------------------- diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/TwoEarsIdTrainPipe.m b/AuditoryMachineLearningTrainingTestingPipeline/src/TwoEarsIdTrainPipe.m index 27d55af..69ea41b 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/TwoEarsIdTrainPipe.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/TwoEarsIdTrainPipe.m @@ -64,6 +64,10 @@ function init( obj, sceneCfgs, varargin ) ip.addOptional( 'hrir', ... 'impulse_responses/qu_kemar_anechoic/QU_KEMAR_anechoic_3m.sofa' ); ip.addOptional( 'sceneCfgDataUseRatio', 1 ); + ip.addOptional( 'sceneCfgPrioDataUseRatio', 1 ); + ip.addOptional( 'selectPrioClass', [] ); + ip.addOptional( 'dataSelector', DataSelectors.IgnorantSelector() ); + ip.addOptional( 'loadBlockAnnotations', false ); ip.addOptional( 'gatherFeaturesProc', true ); ip.addOptional( 'stopAfterProc', inf ); ip.addOptional( 'fs', 44100 ); @@ -103,8 +107,10 @@ function init( obj, sceneCfgs, varargin ) multiCfgProcs{end+1} = ... DataProcs.MultiSceneCfgsIdProcWrapper( binSim, obj.labelCreator ); if ip.Results.gatherFeaturesProc - gatherFeaturesProc = DataProcs.GatherFeaturesProc(); - gatherFeaturesProc.setSceneCfgDataUseRatio( ip.Results.sceneCfgDataUseRatio ); + gatherFeaturesProc = DataProcs.GatherFeaturesProc( ip.Results.loadBlockAnnotations ); + gatherFeaturesProc.setSceneCfgDataUseRatio( ... + ip.Results.sceneCfgDataUseRatio, ip.Results.dataSelector, ... + ip.Results.sceneCfgPrioDataUseRatio, ip.Results.selectPrioClass ); multiCfgProcs{end+1} = DataProcs.MultiSceneCfgsIdProcWrapper( ... binSim, gatherFeaturesProc ); end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/postproc/breakDownPerformanceDep.m b/AuditoryMachineLearningTrainingTestingPipeline/src/postproc/breakDownPerformanceDep.m new file mode 100644 index 0000000..75686b6 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/postproc/breakDownPerformanceDep.m @@ -0,0 +1,20 @@ +function [bac,sens,spec] = breakDownPerformanceDep( counts, vars ) + +countsSummarizedDown = summarizeDown( counts, [vars, ndims( counts )] ); +dimidxs = [ndims(countsSummarizedDown), 1 : ndims( countsSummarizedDown ) - 1]; +countsSummarizedDown = permute( countsSummarizedDown, dimidxs ); + +dimidxs = size( countsSummarizedDown ); +dimidxs(1) = []; +if numel( dimidxs ) == 1, dimidxs(end+1) = 1; end +tp = reshape( squeeze( countsSummarizedDown(1,:) ), dimidxs ); +tn = reshape( squeeze( countsSummarizedDown(2,:) ), dimidxs ); +fp = reshape( squeeze( countsSummarizedDown(3,:) ), dimidxs ); +fn = reshape( squeeze( countsSummarizedDown(4,:) ), dimidxs ); + +sens = tp./(tp+fn); +spec = tn./(tn+fp); + +bac = 0.5*sens + 0.5*spec; + +end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/postproc/breakDownPerformanceDepClassAvg.m b/AuditoryMachineLearningTrainingTestingPipeline/src/postproc/breakDownPerformanceDepClassAvg.m new file mode 100644 index 0000000..39262c9 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/postproc/breakDownPerformanceDepClassAvg.m @@ -0,0 +1,14 @@ +function [bac,sens,spec,bacstd,sensstd,specstd] = breakDownPerformanceDepClassAvg( counts, classVar, vars ) + +vars = sort( [classVar, vars] ); +[bac,sens,spec] = breakDownPerformanceDep( counts, vars ); +classVarNew = find( vars == classVar ); + +bacstd = squeeze( nanStd( bac, classVarNew ) ); +sensstd = squeeze( nanStd( sens, classVarNew ) ); +specstd = squeeze( nanStd( spec, classVarNew ) ); +bac = squeeze( nanMean( bac, classVarNew ) ); +sens = squeeze( nanMean( sens, classVarNew ) ); +spec = squeeze( nanMean( spec, classVarNew ) ); + +end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/JidoRecInterface.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/JidoRecInterface.m index e00bef6..2669e31 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/JidoRecInterface.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/JidoRecInterface.m @@ -100,8 +100,10 @@ function configureAudioStreamServer(obj, sampleRate, frameSize, ... % Sclaing factor estimated empirically % earSignals = [audioBuffer.left ./ (2^31); ... % 0.7612 .* (audioBuffer.right ./ (2^31))]'; - earSignals = [audioBuffer.left * obj.normFactor; ... - audioBuffer.right * obj.normFactor]'; + earSignals = [audioBuffer.left ./ (2^31); ... + audioBuffer.right ./ (2^31)]'; +% earSignals = [audioBuffer.left * obj.normFactor; ... +% audioBuffer.right * obj.normFactor]'; % Get default buffer size of the audio stream server bufferSize = size(earSignals, 1); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/RescSparse.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/RescSparse.m new file mode 100644 index 0000000..d57ebc4 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/RescSparse.m @@ -0,0 +1,518 @@ +classdef RescSparse + % class for results count sparse matrices + %% ----------------------------------------------------------------------------------- + properties (SetAccess = public) + dataConvert; + dataIdxsConvert; + data; + dataIdxs; + id; + dataInitialize; + dataAdd; + end + + %% ----------------------------------------------------------------------------------- + methods + + function obj = RescSparse( datatype, dataidxstype, dataInitialize, dataAdd ) + if nargin < 1 || isempty( datatype ) + datatype = 'double'; + end + obj = obj.setDataType( datatype ); + if nargin < 2 || isempty( dataidxstype ) + dataidxstype = 'double'; + end + switch dataidxstype + case 'double' + obj.dataIdxsConvert = @double; + case 'single' + obj.dataIdxsConvert = @single; + case 'int64' + obj.dataIdxsConvert = @int64; + case 'int32' + obj.dataIdxsConvert = @int32; + case 'int16' + obj.dataIdxsConvert = @int16; + case 'int8' + obj.dataIdxsConvert = @int8; + case 'uint64' + obj.dataIdxsConvert = @uint64; + case 'uint32' + obj.dataIdxsConvert = @uint32; + case 'uint16' + obj.dataIdxsConvert = @uint16; + case 'uint8' + obj.dataIdxsConvert = @uint8; + case 'logical' + obj.dataIdxsConvert = @logical; + otherwise + obj.dataIdxsConvert = @double; + end + if nargin < 3 || isempty( dataInitialize ) + dataInitialize = obj.dataConvert( 0 ); + end + obj.dataInitialize = dataInitialize; + if nargin < 4 || isempty( dataAdd ) + dataAdd = @(a,b)(a+b); + end + obj.dataAdd = dataAdd; + obj.data = obj.dataConvert( zeros( 0 ) ); + obj.dataIdxs = obj.dataIdxsConvert( zeros( 0 ) ); + obj.id = []; + end + %% ------------------------------------------------------------------------------- + function obj = setDataType( obj, newDataType ) + switch newDataType + case 'double' + obj.dataConvert = @double; + case 'single' + obj.dataConvert = @single; + case 'int64' + obj.dataConvert = @int64; + case 'int32' + obj.dataConvert = @int32; + case 'int16' + obj.dataConvert = @int16; + case 'int8' + obj.dataConvert = @int8; + case 'uint64' + obj.dataConvert = @uint64; + case 'uint32' + obj.dataConvert = @uint32; + case 'uint16' + obj.dataConvert = @uint16; + case 'uint8' + obj.dataConvert = @uint8; + case 'logical' + obj.dataConvert = @logical; + otherwise + obj.dataConvert = @double; + end + end + %% ------------------------------------------------------------------------------- + + function ecpy = emptyCopy( obj ) + ecpy = RescSparse(); + ecpy.dataConvert = obj.dataConvert; + ecpy.dataIdxsConvert = obj.dataIdxsConvert; + ecpy.dataAdd = obj.dataAdd; + ecpy.dataInitialize = obj.dataInitialize; + end + %% ------------------------------------------------------------------------------- + + function value = get( obj, idxs ) + value = 0; + if size( idxs, 2 ) < size( obj.dataIdxs, 2 ) + error( 'AMLTTP:usage:unexpected', 'idxs dimensions too small.' ); + end + if size( idxs, 2 ) > size( obj.dataIdxs, 2 ) + error( 'AMLTTP:usage:unexpected', 'idxs dimensions too big.' ); + end + rowIdxEq = obj.rowSearch( idxs ); + if rowIdxEq ~= 0 + value = obj.data(rowIdxEq,:); + end + end + %% ------------------------------------------------------------------------------- + + function [data,dataIdxs] = getRowIndexed( obj, rowIdxs ) + if max( rowIdxs ) > size( obj.dataIdxs, 1 ) + error( 'AMLTTP:usage:unexpected', 'max rowIdxs too big.' ); + end + data = obj.data(rowIdxs,:); + if nargout > 1 + dataIdxs = obj.dataIdxs(rowIdxs,:); + end + end + %% ------------------------------------------------------------------------------- + + function rowIdxs = getRowIdxs( obj, idxsMask ) + if isempty( obj.dataIdxs ) + rowIdxs = []; + return; + end + if size( idxsMask, 2 ) ~= size( obj.dataIdxs, 2 ) + error( 'AMLTTP:usage:unexpected', 'idxsMask dimensions wrong.' ); + end + dataIdxsMask = true( size( obj.dataIdxs, 1 ), sum( cellfun( @(c)(~ischar( c ) ), idxsMask ) ) ); + jj = 0; + for ii = 1 : size( obj.dataIdxs, 2 ) + if ischar( idxsMask{ii} ) && idxsMask{ii} == ':', continue; end + jj = jj + 1; + if jj == 1 + dataIdxsMask(:,jj) = idxsMask{ii}( obj.dataIdxs(:,ii) ); + else + tmp = dataIdxsMask(:,jj-1); + dataIdxsMask(tmp,jj) = idxsMask{ii}( obj.dataIdxs(tmp,ii) ); + end + end + rowIdxsMask = all( dataIdxsMask, 2 ); + rowIdxs = find( rowIdxsMask ); + end + %% ------------------------------------------------------------------------------- + + function obj = deleteData( obj, rowIdxs ) + if max( rowIdxs ) > size( obj.dataIdxs, 1 ) + error( 'AMLTTP:usage:unexpected', 'max rowIdxs too big.' ); + end + obj.data(rowIdxs,:) = []; + obj.dataIdxs(rowIdxs,:) = []; + end + %% ------------------------------------------------------------------------------- + + function obj = filter( obj, varargin ) + obj = obj.deleteData( obj.getRowIdxs( ... + getIdxMask( size( obj.dataIdxs, 2 ), varargin{:} ) ) ); + end + %% ------------------------------------------------------------------------------- + + function [obj,incidxs,insidxs] = addData( obj, idxs, data, areIdxsPresorted ) + if isempty( idxs ) + incidxs = []; insidxs = []; + return; + end + idxs = obj.dataIdxsConvert( idxs ); + data = obj.dataConvert( data ); + if size( idxs, 2 ) < size( obj.dataIdxs, 2 ) + error( 'AMLTTP:usage:unexpected', 'idxs dimensions too small.' ); + end + if size( idxs, 2 ) > size( obj.dataIdxs, 2 ) + if isempty( obj.dataIdxs ) + obj.dataIdxs = obj.dataIdxsConvert( zeros( 0, size( idxs, 2 ) ) ); + else + obj.dataIdxs(:,size( obj.dataIdxs, 2 )+1:size( idxs, 2 )) = obj.dataIdxsConvert( 1 ); + end + end + rowIdxEq = zeros( size( idxs, 1 ), 1 ); + rowIdxGt = zeros( size( idxs, 1 ), 1 ); + iis = [1, size( idxs, 1 ), 2:size( idxs, 1 )-1]; + for ii = iis + if (nargin >=4) && areIdxsPresorted && (ii==2) && (rowIdxGt(1)==rowIdxGt(end)) && (all( rowIdxEq == 0 )) + rowIdxGt(:) = rowIdxGt(1); + break; + end + [rowIdxEq(ii),~,rowIdxGt(ii)] = obj.rowSearch( idxs(ii,:) ); + if rowIdxEq(ii) ~= 0 + obj.data(rowIdxEq(ii),:) = obj.dataAdd( obj.data(rowIdxEq(ii),:), data(ii,:) ); + end + end + rowIdxGt(rowIdxEq ~= 0) = []; + idxs(rowIdxEq ~= 0,:) = []; + data(rowIdxEq ~= 0,:) = []; + [rigtidxs,order] = sortrows( [rowIdxGt,double( idxs )] ); + insidxs = rigtidxs(:,1); + incidxs = sort( [insidxs; (1:size( obj.dataIdxs, 1 ))'] ); + obj.dataIdxs(end+1,:) = obj.dataIdxsConvert( 0 ); + obj.data(end+1,:) = obj.dataConvert( 0 ); + obj.dataIdxs = obj.dataIdxs(incidxs,:); + obj.data = obj.data(incidxs,:); + insidxs = insidxs + (0:numel( insidxs )-1)'; + obj.dataIdxs(insidxs,:) = rigtidxs(:,2:end); + obj.data(insidxs,:) = data(order,:); + end + %% ------------------------------------------------------------------------------- + + function [rowIdxEq,rowIdxLt,rowIdxGt] = rowSearch( obj, idxs, preRowIdxGt ) +% if numel( idxs ) ~= size( obj.dataIdxs, 2 ) +% error( 'AMLTTP:implementation:unexpected', 'This should not have happened.' ); +% end + rowIdxEq = 0; + rowIdxLt = 0; + if nargin < 3 || isempty( preRowIdxGt ) + preRowIdxGt = size( obj.dataIdxs, 1 ); + end + rowIdxGt = preRowIdxGt + 1; + ni = size( idxs, 2 ); + while rowIdxGt - rowIdxLt > 1 + mRowIdx = floor( 0.5*rowIdxLt + 0.5*rowIdxGt ); + idxAreEq = 1; idxAisltB = 0; idxAisgtB = 0; + for ii = 1 : ni + if idxs(ii) < obj.dataIdxs(mRowIdx,ii) + idxAisltB = 1; + idxAreEq = 0; + break; + elseif idxs(ii) > obj.dataIdxs(mRowIdx,ii) + idxAisgtB = 1; + idxAreEq = 0; + break; + end + end + if idxAreEq + rowIdxEq = mRowIdx; + rowIdxLt = mRowIdx - 1; + rowIdxGt = mRowIdx + 1; + break; + elseif idxAisltB + rowIdxGt = mRowIdx; + elseif idxAisgtB + rowIdxLt = mRowIdx; + end + end + end + %% ------------------------------------------------------------------------------- + + function obj = partJoin( obj, otherObj, keepMask, overrideMask ) + obj = obj.deleteData( obj.getRowIdxs( overrideMask ) ); + otherObj = otherObj.deleteData( otherObj.getRowIdxs( keepMask ) ); + obj = obj.addData( otherObj.dataIdxs, otherObj.data, true ); + end + %% ------------------------------------------------------------------------------- + + function [summedResc,summedDataOrigin] = summarizeDown( obj, keepDims, rowIdxs, idxReplaceMask, fun, sdoPrior, intraGroupNorm ) + summedResc = obj; + if nargin < 2 + return; + end + if isempty( keepDims ) && (nargin < 4 || isempty( idxReplaceMask )) + return; + end + if nargin < 3 || (ischar( rowIdxs ) && (rowIdxs == ':')) %isempty( rowIdxs ) + rowIdxs = 1 : size( summedResc.dataIdxs, 1 ); + end + if nargin >= 4 && ~isempty( idxReplaceMask ) + keepDims = 1:size( obj.dataIdxs, 2 ); + if size( idxReplaceMask, 2 ) ~= size( summedResc.dataIdxs, 2 ) + error( 'AMLTTP:usage:unexpected', 'idxsMask dimensions wrong.' ); + end + for ii = 1 : size( summedResc.dataIdxs, 2 ) + if isempty( idxReplaceMask{ii} ), continue; end + summedResc.dataIdxs(:,ii) = idxReplaceMask{ii}( summedResc.dataIdxs(:,ii) ); + end + end + if nargin < 5 + fun = []; + end + if (nargin < 6 || isempty( sdoPrior )) && nargout > 1 + clear sdoPrior; + sdoPrior(:,1) = mat2cell( summedResc.dataIdxs(rowIdxs,:), ones( size( summedResc.dataIdxs(rowIdxs,:), 1 ), 1 ) ); + sdoPrior(:,2) = mat2cell( summedResc.data(rowIdxs,:), ones( size( summedResc.data(rowIdxs,:), 1 ), 1 ) ); + end + [keepDimsUniqueIdxs,~,ic] = unique( summedResc.dataIdxs(rowIdxs,keepDims), 'rows' ); + summedData = accumarray( ic, summedResc.data(rowIdxs,:), [], fun ); + if nargout > 1 + if nargin >= 7 && intraGroupNorm + ignFactor = arrayfun( @(a,b)(sum( cellfun( @sum, sdoPrior(ic==b,2) ) ) / (sum( a{:} ) * numel( sdoPrior(ic==b,2) ))), sdoPrior(:,2), ic ); + else + ignFactor = ones( size( ic ) ); + end + [summedDataOrigin(:,1), summedDataOrigin(:,2)] = splitapply( @(x,a)(deal({cell2mat( x(:,1) )},{cell2mat( arrayfun( @(c,b)(c{:}*b), x(:,2), a, 'UniformOutput', false ) )})), sdoPrior, ignFactor, ic ); + end + summedResc.dataIdxs = keepDimsUniqueIdxs; + summedResc.data = summedData; + if ~isempty( obj.id ) + idxDescr = fieldnames( obj.id ); + idxDescr = idxDescr(keepDims); + summedResc.id = ... + cell2struct( num2cell( 1:numel( idxDescr ) )', idxDescr ); + end + end + %% ------------------------------------------------------------------------------- + + function robj = resample( obj, depIdx, rIdx, resample_weights, conditions ) + if nargin >= 5 && ~isempty( conditions ) + useIdxs = obj.getRowIdxs( conditions ); + else + useIdxs = ':'; + end + drIdxs = obj.dataIdxs( useIdxs,[depIdx,rIdx] ); + drIdxs = mat2cell( drIdxs, size( drIdxs, 1 ), ones( 1, size( drIdxs, 2 ) ) ); + drIdxs = sub2ind( size( resample_weights ), drIdxs{:} ); + w = resample_weights( drIdxs ); + w(isnan(w)) = 1; + robj = obj; + robj.data(useIdxs,:) = robj.data(useIdxs,:) .* w; + robj.dataIdxs(robj.data==0,:) = []; + robj.data(robj.data==0,:) = []; + end + %% ------------------------------------------------------------------------------- + + function dist = idxDistribution( obj, depIdx, defIdx ) + sobj = obj.summarizeDown( [defIdx, depIdx] ); + defIdx = 1 : numel( defIdx ); + depIdx = defIdx(end) + 1; + maxDepIdx = max( sobj.dataIdxs(:,depIdx) ); + maxDefIdxs = cell( 1, numel( defIdx ) ); + for ii = 1 : numel( defIdx ) + maxDefIdxs{ii} = max( sobj.dataIdxs(:,defIdx(ii)) ); + end + dist = nan( maxDefIdxs{:}, maxDepIdx ); + [defUniqueIdxs,~,ic] = unique( sobj.dataIdxs(:,defIdx), 'rows' ); + for ii = 1 : size( defUniqueIdxs, 1 ) + dui_ii = (ic == ii); + depIdxs_dui = sobj.dataIdxs(dui_ii,depIdx); + depData_dui = sobj.data(dui_ii,:); + dui = num2cell( defUniqueIdxs(ii,:) ); + dist(dui{:},depIdxs_dui) = depData_dui; + end + end + %% ------------------------------------------------------------------------------- + + function [obj, sdo] = combineFun_legacy( obj, fun, cdim, argIdxs, cidx, newDataType, sdo ) + if nargin > 5 && ~isempty( newDataType ) + obj = obj.setDataType( newDataType ); + end + nargs = numel( argIdxs ); + diDim = size( obj.dataIdxs, 2 ); + rowIdxs = cell( 1, nargs ); + for ii = 1 : nargs + idxsMask = repmat( {':'}, 1, diDim ); + idxsMask{cdim} = @(x)(x == argIdxs(ii)); + rowIdxs{ii} = obj.getRowIdxs( idxsMask ); + end + nri = cellfun( @numel, rowIdxs ); + newDataIdxs = zeros( sum( nri ), diDim ); + newData = zeros( sum( nri ), 1 ); + newSdo = cell( size( newData, 1 ), 2 ); + ndiIdx = 1; + riIdx = ones( size( nri ) ); + curDataIdxs = zeros( nargs, diDim ); + iis = 1 : nargs; + args = cell( 1, nargs ); + oldProgress = 0; + while any( riIdx <= nri ) + progress = int8( 100 * sum( riIdx ) / sum( nri ) ); + if progress > oldProgress + fprintf( '.' ); + oldProgress = progress; + end + for ii = iis + if riIdx(ii) > nri(ii) + curDataIdxs(ii,:) = inf( 1, diDim ); + else + curDataIdxs(ii,:) = obj.dataIdxs(rowIdxs{ii}(riIdx(ii)),:); + curDataIdxs(ii,cdim) = cidx; + end + end + me = 1 : nargs; + for cc = 1 : diDim + m = min( curDataIdxs(me,cc), [], 1 ); + me = me(curDataIdxs(me,cc) == m); + if numel( me ) == 1, break; end; + end + newDataIdxs(ndiIdx,:) = curDataIdxs(me(1),:); + iis = []; + for ii = 1 : nargs + if any( ii == me ) + if nargout > 1 + newSdo{ndiIdx,1} = cat( 1, newSdo{ndiIdx,1}, sdo{rowIdxs{ii}(riIdx(ii)),1} ); + newSdo{ndiIdx,2} = cat( 1, newSdo{ndiIdx,2}, sdo{rowIdxs{ii}(riIdx(ii)),2} ); + end + args{ii} = obj.data(rowIdxs{ii}(riIdx(ii))); + riIdx(ii) = riIdx(ii) + 1; + iis(end+1) = ii; + else + args{ii} = obj.dataConvert( 0 ); + end + end + newData(ndiIdx) = fun(args{:}); + ndiIdx = ndiIdx + 1; + end + newDataIdxs(ndiIdx:end,:) = []; + newData(ndiIdx:end) = []; + newSdo(ndiIdx:end,:) = []; + delIdxs = unique( cat( 1, rowIdxs{:} ) ); + obj = obj.deleteData( delIdxs ); + [obj,incidxs,insidxs] = obj.addData( newDataIdxs, newData ); + if nargout > 1 + sdo(delIdxs,:) = []; + sdo(end+1,:) = {[],[]}; + sdo = sdo(incidxs,:); + sdo(insidxs,:) = newSdo; + end + fprintf( '\n' ); + end + %% ------------------------------------------------------------------------------- + + function [obj, sdo] = combineFun( obj, fun, cdim, argIdxs, cidx, newDataType, sdo ) + if nargin > 5 && ~isempty( newDataType ) + obj = obj.setDataType( newDataType ); + end + nargs = numel( argIdxs ); + diDim = size( obj.dataIdxs, 2 ); + dDim = size( obj.data, 2 ); + argRowIdxs = cell( 1, nargs ); + argDataIdxs = cell( 1, nargs ); + argData = cell( 1, nargs ); + argGroups = cell( 1, nargs ); + for ii = 1 : nargs + idxsMask = repmat( {':'}, 1, diDim ); + idxsMask{cdim} = @(x)(x == argIdxs(ii)); + argRowIdxs{ii} = obj.getRowIdxs( idxsMask ); + argDataIdxs{ii} = obj.dataIdxs(argRowIdxs{ii},:); + argData{ii} = obj.data(argRowIdxs{ii},:); + argGroups{ii} = repmat( ii, size( argRowIdxs{ii}, 1 ), 1 ); + end + if all( cellfun( @isempty, argRowIdxs ) ) + return; + end + obj.dataIdxs(cat( 1, argRowIdxs{:} ),:) = []; + obj.data(cat( 1, argRowIdxs{:} ),:) = []; + funnedDataIdxs = cat( 1, argDataIdxs{:} ); + funnedDataIdxs(:,cdim) = cidx; + [funnedDataIdxs,~,ic] = unique( funnedDataIdxs, 'rows' ); + funnedData = obj.dataConvert( zeros( size( funnedDataIdxs, 1 ), dDim ) ); + argGroups = cat( 1, argGroups{:} ); + icGrouped = splitapply( @(x)({x}), ic, argGroups ); + argData2beFunned = repmat( {funnedData}, 1, nargs ); + for ii = 1 : nargs + argData2beFunned{ii}(icGrouped{ii},:) = argData{ii}; + end + funnedData = fun( argData2beFunned{:} ); + obj = obj.addData( funnedDataIdxs, funnedData, true ); + end + %% ------------------------------------------------------------------------------- + + function [mat, sdomat] = resc2mat( obj, ridx2midx, rowIdxs, sdo ) + if nargin < 2 || isempty( ridx2midx ) + ridx2midx = repmat( {@(idx)(idx)}, 1, size( obj.dataIdxs, 2 ) ); + end + if nargin < 3 || isempty( rowIdxs ) + rowIdxs = 1 : size( obj.dataIdxs, 1 ); + end + midxs = obj.dataIdxs(rowIdxs,:); + for ii = 1 : size( obj.dataIdxs, 2 ) + midxs(:,ii) = ridx2midx{ii}( midxs(:,ii) ); + end + maxMidxs = num2cell( max( midxs, [], 1 ) ); + minMidxs = int16( min( midxs, [], 1 ) ); + mat(maxMidxs{:}) = 0; + if nargout > 1 + sdomat{maxMidxs{:},2} = {}; + end + midxs = midxs + repmat( uint8(max(0, 1 - minMidxs)), size( midxs, 1 ), 1 ); + midxs = num2cell( midxs ); + for ii = 1 : size( midxs, 1 ) + mat(midxs{ii,:}) = obj.data(rowIdxs(ii),:); + if nargout > 1 + sdomat(midxs{ii,:},:) = sdo(rowIdxs(ii),:); + end + end + end + %% ------------------------------------------------------------------------------- + + end + %% ----------------------------------------------------------------------------------- + + methods (Static) + + function [idxAreEq,idxAisltB,idxAisgtB] = idxsCmp( idxsA, idxsB ) + idxAreEq = 0; idxAisltB = 0; idxAisgtB = 0; +% if numel( idxsA ) ~= numel( idxsB ) +% error( 'AMLTTP:implementation:unexpected', 'This should not have happened.' ); +% end + for ii = 1 : size( idxsA, 2 ) + if idxsA(ii) < idxsB(ii) + idxAisltB = 1; + return; + elseif idxsA(ii) > idxsB(ii) + idxAisgtB = 1; + return; + end + end + idxAreEq = 1; + end + %% ------------------------------------------------------------------------------- + +end + +end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/addDpiToResc.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/addDpiToResc.m new file mode 100644 index 0000000..7f8e3d3 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/addDpiToResc.m @@ -0,0 +1,37 @@ +function resc = addDpiToResc( resc, assignments, bapi ) + +if isempty( bapi ), return; end + +ci = zeros( numel( bapi ), 1 ); +for aa = 1:4 % 1: TP, 2: TN, 3: FP, 4: FN + ci = ci + aa*[assignments{aa}]; +end + +bapiFields = fieldnames( bapi ); +bapiFields = [{'counts'}; bapiFields]; +if isfield( resc, 'id' ) && ~isempty( resc.id ) + if numel( bapiFields ) ~= numel( fieldnames( resc.id ) ) || ... + ~all( strcmpi( bapiFields, fieldnames( resc.id ) ) ) + error( 'AMLTTP:apiUsage', 'existing RESC structure differs from BAPI to be added' ); + end +else + resc.id.counts = 1; +end + +C = zeros( numel( bapi ), numel( bapiFields ) ); +C(:,1) = ci; +for ii = 2 : numel( bapiFields ) + if isfield( resc.id, bapiFields{ii} ) + ii_ = resc.id.(bapiFields{ii}); + else + ii_ = ii; + end + C(:,ii_) = cat( 1, bapi.(bapiFields{ii}) ); + resc.id.(bapiFields{ii}) = ii_; +end + +[C,~,ic] = unique( C, 'rows' ); +paramFactor = arrayfun( @(x)(sum( x == ic )), 1:size( C, 1 ) ); +resc = resc.addData( C, paramFactor', true ); + +end \ No newline at end of file diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/addSnrs.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/addSnrs.m new file mode 100644 index 0000000..25ec46b --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/addSnrs.m @@ -0,0 +1,8 @@ +function snr = addSnrs( snrs ) + +snr = 0; +for ii = 1 : numel( snrs ) + snr = snr + 10^(-snrs(ii)/10); +end +snr = 1 / snr; +snr = 10 * log10( snr ); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/aggregateBlockAnnotations.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/aggregateBlockAnnotations.m new file mode 100644 index 0000000..f926d38 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/aggregateBlockAnnotations.m @@ -0,0 +1,67 @@ +function [ag, asgn] = aggregateBlockAnnotations( bap, yp, yt ) + +[ytIdxR,ytIdxC] = find( yt > 0 ); +assert( numel( unique( ytIdxR ) ) == numel( ytIdxR ) ); % because I defined it in my test scripts: target sounds only on src1 +isyt = false( size( bap, 1 ), 1 ); +isyt(ytIdxR) = true; +isyp = any( yp > 0, 2 ); + +asgn(:,1) = isyp & isyt; +asgn(:,2) = ~isyp & ~isyt; +asgn(:,3) = isyp & ~isyt; +asgn(:,4) = ~isyp & isyt; + +ag = bap(:,1); +[ag.nAct_segStream] = deal( nan ); + +% tmp = reshape( [bap.multiSrcsAttributability], size( bap ) ); +% tmp = num2cell( nanMean( tmp, 2 ) ); +% [ag.multiSrcsAttributability] = tmp{:}; + +if sum( isyt ) > 0 +ytIdxs = sub2ind( size( yt ), ytIdxR, ytIdxC ); +[ag(isyt).curSnr] = bap(ytIdxs).curSnr; +[ag(isyt).curNrj] = bap(ytIdxs).curNrj; +[ag(isyt).curNrjOthers] = bap(ytIdxs).curNrjOthers; +[ag(isyt).curSnr_db] = bap(ytIdxs).curSnr_db; +[ag(isyt).curNrj_db] = bap(ytIdxs).curNrj_db; +[ag(isyt).curNrjOthers_db] = bap(ytIdxs).curNrjOthers_db; +[ag(isyt).curSnr2] = bap(ytIdxs).curSnr2; +[ag(isyt).azmErr] = bap(ytIdxs).azmErr; +[ag(isyt).dist2bisector] = bap(ytIdxs).dist2bisector; +[ag(isyt).blockClass] = bap(ytIdxs).blockClass; +[ag(isyt).gtAzm] = bap(ytIdxs).gtAzm; +[ag(isyt).estAzm] = bap(ytIdxs).estAzm; +end + +if sum( ~isyt ) > 0 +tmp = reshape( double( [bap(~isyt,:).curSnr] ), size( bap(~isyt,:) ) ); +[~,maxCurSnrIdx] = max( tmp, [], 2 ); +nIdxs = sub2ind( size( yt ), find( ~isyt ), maxCurSnrIdx ); +[ag(~isyt).curSnr] = bap(nIdxs).curSnr; +[ag(~isyt).curNrj] = bap(nIdxs).curNrj; +[ag(~isyt).curNrjOthers] = bap(nIdxs).curNrjOthers; +tmp = reshape( double( [bap(~isyt,:).curSnr_db] ), size( bap(~isyt,:) ) ); +[~,maxCurSnrIdx] = max( tmp, [], 2 ); +nIdxs = sub2ind( size( yt ), find( ~isyt ), maxCurSnrIdx ); +[ag(~isyt).curSnr_db] = bap(nIdxs).curSnr_db; +[ag(~isyt).curNrj_db] = bap(nIdxs).curNrj_db; +[ag(~isyt).curNrjOthers_db] = bap(nIdxs).curNrjOthers_db; +tmp = reshape( double( [bap(~isyt,:).curSnr2] ), size( bap(~isyt,:) ) ); +[~,maxCurSnrIdx] = max( tmp, [], 2 ); +nIdxs = sub2ind( size( yt ), find( ~isyt ), maxCurSnrIdx ); +[ag(~isyt).curSnr2] = bap(nIdxs).curSnr2; +[ag(~isyt).dist2bisector] = bap(nIdxs).dist2bisector; +[ag(~isyt).blockClass] = bap(nIdxs).blockClass; +[ag(~isyt).gtAzm] = bap(nIdxs).gtAzm; +[ag(~isyt).estAzm] = bap(nIdxs).estAzm; +[ag(~isyt).azmErr] = deal( nan ); +end + +end + +function v = nanIfEmpty( v ) +if isempty( v ) + v = nan; +end +end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/aggregateBlockAnnotations2.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/aggregateBlockAnnotations2.m new file mode 100644 index 0000000..a0a2d68 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/aggregateBlockAnnotations2.m @@ -0,0 +1,141 @@ +function [ag, asgn] = aggregateBlockAnnotations2( bap, yp, yt ) + +ag = bap; +validBaps = ~isnan( arrayfun( @(ax)(ax.scpId), bap ) ); + +isyt = yt > 0; +[ytIdxR,ytIdxC] = find( isyt ); +assert( numel( unique( ytIdxR ) ) == numel( ytIdxR ) ); % because I defined it in my test scripts: target sounds only on src1 +isytR = any( isyt, 2 ); +isyp = yp > 0; +isypR = any( isyp, 2 ); +istpR = isytR & isypR; +tpIdxR = ytIdxR(istpR(ytIdxR)); +tpIdxC = ytIdxC(istpR(ytIdxR)); +tpIdx = sub2ind( size( yt ), tpIdxR, tpIdxC ); + +%% compute dist2bisector + +selfIdx = 1 : numel( bap ); +nonemptyBaps = validBaps & ~isnan( arrayfun( @(ax)(ax.gtAzm), bap ) ); +selfIdx = selfIdx(nonemptyBaps(selfIdx)); +[selfIdxR,selfIdxC] = ind2sub( size( bap ), selfIdx ); +otherIdxs = arrayfun( ... + @(r,c)(sub2ind( size( bap ), repmat( r, 1, size( bap, 2 )-1 ), [1:c-1 c+1:size( bap, 2 )] )), ... + selfIdxR, selfIdxC, 'UniformOutput', false ); +otherIdxs = cellfun( @(c)(c(nonemptyBaps(c))), otherIdxs, 'UniformOutput', false ); + +selfGtAzms = wrapTo180( [bap(selfIdx).gtAzm] ); +selfGtAzms = sign(selfGtAzms).*abs(abs(abs(selfGtAzms)-90)-90); % mirror to frontal hemisphere +otherGtAzms = cellfun( @(c)(wrapTo180( [bap(c).gtAzm] )), otherIdxs, 'UniformOutput', false ); +otherGtAzms = cellfun( @(c)(sign(c).*abs(abs(abs(c)-90)-90)), otherGtAzms, 'UniformOutput', false ); +bisectAzms = cellfun( @(s,o)(s + ( o - s )/2), num2cell( selfGtAzms ), otherGtAzms, 'UniformOutput', false ); +spreads = cellfun( @(s,o)(max(eps,abs( o - s ))), num2cell( selfGtAzms ), otherGtAzms, 'UniformOutput', false ); +bisectNormAzms = cellfun( @(b,s)((s - 2*abs( b ))./s), bisectAzms, spreads, 'UniformOutput', false ); +isBnaNeg = cellfun( @(c)(c < 0), bisectNormAzms, 'UniformOutput', false ); +bisectNormAzmsNeg = cellfun( @(b,s)((abs(b)-s/2)./(90-s/2)), ... + bisectAzms, spreads, 'UniformOutput', false ); +bisectNormAzms = cellfun( @(bp,bn,isn)(nansum( [-isn.*bn;(~isn).*bp;repmat(isempty(bp),1,max(1,numel(bp)))], 1 )), ... + bisectNormAzms, bisectNormAzmsNeg, isBnaNeg, 'UniformOutput', false ); +otherSnrs = cellfun( @(c)([bap(c).curSnr2]), otherIdxs, 'UniformOutput', false ); +otherSnrs = cellfun( @(c)(c - max(c)), otherSnrs, 'UniformOutput', false ); +otherSnrNorms = cellfun( @(c)(max(0,1./abs(c-1).^0.2 - 0.4.*abs(c)./100)), otherSnrs, 'UniformOutput', false ); +otherSnrNorms(cellfun(@isempty,otherSnrNorms)) = {1}; + +dist2bisector = cellfun( @(b,s)(double(b)*double(s)'/sum(double(s))), bisectNormAzms, otherSnrNorms, 'UniformOutput', false ); +[ag(selfIdx).dist2bisector] = dist2bisector{:}; + +%% assign tp (and following fp,fn,tn) per time instead of per block + +istp_ = false( size( ag ) ); +if ~isempty( tpIdxR ) + tp_gtAzms = [bap(tpIdx).gtAzm]; + assert( all( ~isnan( tp_gtAzms ) ) ); + azmErrs = arrayfun( @(x)(x.estAzm), bap(tpIdxR,:) ) - repmat( tp_gtAzms', 1, size( bap, 2 ) ); + azmErrs = abs( wrapTo180( azmErrs ) ); + azmErrs(~isyp(tpIdxR,:)) = nan; + [tpAzmErr,tpIdxC_] = min( azmErrs, [], 2 ); + tpAzmErr2 = mean( azmErrs, 2 ); + tpIdx_ = sub2ind( size( ag ), tpIdxR, tpIdxC_ ); + istp_(tpIdx_) = true; +end + +isfp_ = isyp & ~istp_; + +isfnR = isytR & ~isypR; +isfn_ = repmat( isfnR, 1, size( isyt, 2 ) ) & isyt; + +istn_ = ~isfn_ & ~isyp & validBaps; + +%% assign case-insensitive baParams changes + +% [ag.nAct_segStream] = deal( nan ); + +%% assign case-sensitive baParams changes + +if ~isempty( tpIdxR ) + acell = num2cell( tpAzmErr ); + [ag(tpIdx_).azmErr] = acell{:}; + acell = num2cell( tpAzmErr2 ); + [ag(tpIdx_).azmErr2] = acell{:}; + acell = num2cell( [bap(tpIdx).curSnr] ); + a2cell = num2cell( [bap(tpIdx_).curSnr] ); + [ag(tpIdx).curSnr] = a2cell{:}; + [ag(tpIdx_).curSnr] = acell{:}; + acell = num2cell( [bap(tpIdx).curNrj] ); + a2cell = num2cell( [bap(tpIdx_).curNrj] ); + [ag(tpIdx).curNrj] = a2cell{:}; + [ag(tpIdx_).curNrj] = acell{:}; + acell = num2cell( [bap(tpIdx).curNrjOthers] ); + a2cell = num2cell( [bap(tpIdx_).curNrjOthers] ); + [ag(tpIdx).curNrjOthers] = a2cell{:}; + [ag(tpIdx_).curNrjOthers] = acell{:}; + acell = num2cell( [bap(tpIdx).curSnr_db] ); + a2cell = num2cell( [bap(tpIdx_).curSnr_db] ); + [ag(tpIdx).curSnr_db] = a2cell{:}; + [ag(tpIdx_).curSnr_db] = acell{:}; + acell = num2cell( [bap(tpIdx).curNrj_db] ); + a2cell = num2cell( [bap(tpIdx_).curNrj_db] ); + [ag(tpIdx).curNrj_db] = a2cell{:}; + [ag(tpIdx_).curNrj_db] = acell{:}; + acell = num2cell( [bap(tpIdx).curNrjOthers_db] ); + a2cell = num2cell( [bap(tpIdx_).curNrjOthers_db] ); + [ag(tpIdx).curNrjOthers_db] = a2cell{:}; + [ag(tpIdx_).curNrjOthers_db] = acell{:}; + acell_curSnr2 = num2cell( [bap(tpIdx).curSnr2] ); + a2cell = num2cell( [bap(tpIdx_).curSnr2] ); + [ag(tpIdx).curSnr2] = a2cell{:}; + [ag(tpIdx_).curSnr2] = acell_curSnr2{:}; + acell = num2cell( [ag(tpIdx).dist2bisector] ); + acell2 = num2cell( [ag(tpIdx_).dist2bisector] ); + [ag(tpIdx).dist2bisector] = acell2{:}; + [ag(tpIdx_).dist2bisector] = acell{:}; + acell = num2cell( [bap(tpIdx).blockClass] ); + acell2 = num2cell( [bap(tpIdx_).blockClass] ); + [ag(tpIdx).blockClass] = acell2{:}; + [ag(tpIdx_).blockClass] = acell{:}; + acell = num2cell( [bap(tpIdx).gtAzm] ); + acell2 = num2cell( [bap(tpIdx_).gtAzm] ); + [ag(tpIdx).gtAzm] = acell2{:}; + [ag(tpIdx_).gtAzm] = acell{:}; + acell = num2cell( [bap(tpIdx).estAzm] ); + acell2 = num2cell( [bap(tpIdx_).estAzm] ); + [ag(tpIdx).estAzm] = acell2{:}; + [ag(tpIdx_).estAzm] = acell{:}; +end + +[ag(isytR,:).posPresent] = deal( 1 ); +[ag(~isytR,:).posPresent] = deal( 0 ); +acell_curSnr2 = repmat( num2cell( [bap(isyt).curSnr2] )', 1, size( ag, 2 ) ); +[ag(isytR,:).posSnr] = acell_curSnr2{:}; + +%% reshape assignments and aggregate baParams + +asgn(:,1) = istp_(validBaps); +asgn(:,2) = istn_(validBaps); +asgn(:,3) = isfp_(validBaps); +asgn(:,4) = isfn_(validBaps); +ag = ag(validBaps); + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/applyIfNempty.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/applyIfNempty.m new file mode 100644 index 0000000..aa2b578 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/applyIfNempty.m @@ -0,0 +1,10 @@ +function varargout = applyIfNempty( x, fun ) + +if isempty( x ) + varargout(1:nargout) = cell( 1, nargout ); + return; +end + +[varargout{1:nargout}] = fun( x ); + +end \ No newline at end of file diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/baParams2bapIdxs.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/baParams2bapIdxs.m new file mode 100644 index 0000000..2a5d890 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/baParams2bapIdxs.m @@ -0,0 +1,57 @@ +function baParamIdxs = baParams2bapIdxs( baParams ) + +emptyBapi = nanRescStruct; +baParamIdxs = repmat( emptyBapi, numel( baParams ), 1); + +tmp = num2cell( nan2inf( [baParams.classIdx] ) ); +[baParamIdxs.classIdx] = tmp{:}; +tmp = num2cell( nan2inf( [baParams.dd] ) ); +[baParamIdxs.dd] = tmp{:}; +tmp = num2cell( nan2inf( [baParams.nAct] + 1 ) ); +[baParamIdxs.nAct] = tmp{:}; +tmp = num2cell( nan2inf( round( ([baParams.curSnr]+35)/5 ) + 1 ) ); +[baParamIdxs.curSnr] = tmp{:}; +tmp = num2cell( nan2inf( round( ([baParams.curSnr_db]+35)/5 ) + 1 ) ); +[baParamIdxs.curSnr_db] = tmp{:}; +tmp = num2cell( nan2inf( round( ([baParams.curSnr2]+35)/5 ) + 1 ) ); +[baParamIdxs.curSnr2] = tmp{:}; +tmp = num2cell( nan2inf( round( [baParams.azmErr]/5 ) + 1 ) ); +[baParamIdxs.azmErr] = tmp{:}; +tmp = num2cell( nan2inf( round( [baParams.azmErr2]/5 ) + 1 ) ); +[baParamIdxs.azmErr2] = tmp{:}; +tmp = num2cell( nan2inf( round( (wrapTo180([baParams.gtAzm])+180)/5 ) + 1 ) ); +[baParamIdxs.gtAzm] = tmp{:}; +tmp = num2cell( nan2inf( [baParams.nEstErr] + 4 ) ); +[baParamIdxs.nEstErr] = tmp{:}; +tmp = num2cell( nan2inf( [baParams.nAct_segStream] + 1 ) ); +[baParamIdxs.nAct_segStream] = tmp{:}; +tmp = num2cell( nan2inf( round( ([baParams.curNrj]+35)/5 ) + 1 ) ); +[baParamIdxs.curNrj] = tmp{:}; +tmp = num2cell( nan2inf( round( ([baParams.curNrj_db]+35)/5 ) + 1 ) ); +[baParamIdxs.curNrj_db] = tmp{:}; +tmp = num2cell( nan2inf( round( ([baParams.curNrjOthers]+35)/5 ) + 1 ) ); +[baParamIdxs.curNrjOthers] = tmp{:}; +tmp = num2cell( nan2inf( round( ([baParams.curNrjOthers_db]+35)/5 ) + 1 ) ); +[baParamIdxs.curNrjOthers_db] = tmp{:}; +tmp = num2cell( nan2inf( [baParams.scpId] ) ); +[baParamIdxs.scpId] = tmp{:}; +tmp = num2cell( nan2inf( [baParams.fileId] ) ); +[baParamIdxs.fileId] = tmp{:}; +tmp = num2cell( nan2inf( [baParams.fileClassId] ) ); +[baParamIdxs.fileClassId] = tmp{:}; +tmp = num2cell( nan2inf( [baParams.posPresent] + 1 ) ); +[baParamIdxs.posPresent] = tmp{:}; +tmp = num2cell( nan2inf( round( ([baParams.posSnr]+35)/5 ) + 1 ) ); +[baParamIdxs.posSnr] = tmp{:}; +tmp = num2cell( nan2inf( [baParams.blockClass] ) ); +[baParamIdxs.blockClass] = tmp{:}; +tmp = num2cell( nan2inf( ([baParams.dist2bisector]+1)*10 + 1 ) ); +[baParamIdxs.dist2bisector] = tmp{:}; + +baParamIdxs = rmfield( baParamIdxs, 'estAzm' ); + +end + +function v = nan2inf( v ) +v(isnan( v ))= inf; +end \ No newline at end of file diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/createTFfigure.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/createTFfigure.m new file mode 100644 index 0000000..3b08a4d --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/createTFfigure.m @@ -0,0 +1,36 @@ +function createTFfigure( tfData, cm, clim ) + +if nargin < 2 || isempty( cm ), cm = 'parula'; end +if nargin < 3 || isempty( clim ) + cl{1} = 'CLimMode'; + cl{2} = 'auto'; +else + cl{1} = 'CLimMode'; + cl{2} = 'manual'; + cl{3} = 'CLim'; + cl{4} = clim; +end + +figure1 = figure; +colormap( cm ); + +axes1 = axes( 'Parent', figure1 ); +hold( axes1, 'on' ); + +image( tfData, 'Parent', axes1, 'CDataMapping', 'scaled' ); + +xlabel('t/s'); +ylabel('f/Hz'); +xlim(axes1,[0.5 100.5]); +ylim(axes1,[0.5 32.5]); + +box(axes1,'on'); + +set( axes1, 'FontSize', 12, 'Layer', 'top',... + 'XTick', [10,20,30,40,50,60,70,80,90,100], ... + 'XTickLabel', {'0.1','0.2','0.3','0.4','0.5','0.6','0.7','0.8','0.9','1'}, ... + 'YTick', [1 8 16 24 32], 'YTickLabel', {'80','242','572','1131','8000'}, ... + cl{:} ); + +%colorbar( 'peer', axes1 ); + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/createTrainTestSplitFlists.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/createTrainTestSplitFlists.m index 0c2a7ea..fc2284f 100644 --- a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/createTrainTestSplitFlists.m +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/createTrainTestSplitFlists.m @@ -2,15 +2,15 @@ function createTrainTestSplitFlists( inputFlist, outputName, baseDir, nFolds, on if nargin < 5, oneFoldForTrain = false; end; -allData = core.IdentTrainPipeData(); -allData.loadWavFileList( inputFlist ); +allData = Core.IdentTrainPipeData(); +allData.loadFileList( inputFlist ); folds = allData.splitInPermutedStratifiedFolds( nFolds ); for ff = 1 : nFolds foldsIdx = 1 : nFolds; foldsIdx(ff) = []; - foldCombi = core.IdentTrainPipeData.combineData( folds{foldsIdx} ); + foldCombi = Core.IdentTrainPipeData.combineData( folds{foldsIdx} ); if oneFoldForTrain combiTStr = 'Test'; oneTStr = 'Train'; @@ -18,7 +18,7 @@ function createTrainTestSplitFlists( inputFlist, outputName, baseDir, nFolds, on combiTStr = 'Train'; oneTStr = 'Test'; end - foldCombi.saveDataFList( [outputName '_' combiTStr 'Set_' int2str(ff) '.flist'], baseDir ); - folds{ff}.saveDataFList( [outputName '_' oneTStr 'Set_' int2str(ff) '.flist'], baseDir ); + foldCombi.saveFList( [outputName '_' combiTStr 'Set_' int2str(ff) '.flist'], baseDir ); + folds{ff}.saveFList( [outputName '_' oneTStr 'Set_' int2str(ff) '.flist'], baseDir ); end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/extractBAparams.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/extractBAparams.m new file mode 100644 index 0000000..9ce6d48 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/extractBAparams.m @@ -0,0 +1,127 @@ +function [baParams, asgn] = extractBAparams( blockAnnotations, scp, yp, yt ) + +asgn{1} = ((yp == yt) & (yp > 0)); +asgn{2} = ((yp == yt) & (yp < 0)); +asgn{3} = ((yp ~= yt) & (yp > 0)); +asgn{4} = ((yp ~= yt) & (yp < 0)); + +emptyBap = nanRescStruct; +emptyBap.scpId = scp.id; +emptyBap.fileId = scp.fileId; +emptyBap.fileClassId = scp.fileClassId; +emptyBap.classIdx = scp.classIdx; +emptyBap.dd = scp.dd; +baParams = repmat( emptyBap, numel(yt), 1); +isSegId = isfield( blockAnnotations, 'estAzm' ); + +isP = yt > 0; + +baNsa = [blockAnnotations.nActivePointSrcs]'; +tmp = num2cell( baNsa ); +[baParams.nAct] = tmp{:}; +baPp = [blockAnnotations.posPresent]'; +tmp = num2cell( baPp ); +[baParams.posPresent] = tmp{:}; +if sum( baPp ) > 0 + baPs = [blockAnnotations.posSnr]'; + tmp = num2cell( baPs ); + [baParams(logical(baPp)).posSnr] = tmp{:}; +end +if isSegId + baSrcAzms = {blockAnnotations.srcAzms}'; + tmp = num2cell( [blockAnnotations.nSrcs_estimationError] ); + [baParams.nEstErr] = tmp{:}; + tmp = num2cell( min( cellfun( @numel, baSrcAzms ), baNsa ) ); + [baParams.nAct_segStream] = tmp{:}; + estAzm = [blockAnnotations.estAzm]'; + srcAzmP_ = cellfun( @(x)(x(1)), baSrcAzms(isP) ); + tmp = num2cell( srcAzmP_ ); + [baParams(isP).gtAzm] = tmp{:}; + tmp = num2cell( estAzm ); + [baParams.estAzm] = tmp{:}; + azmErr = num2cell( abs( wrapTo180( srcAzmP_ - estAzm(isP) ) ) ); + [baParams(isP).azmErr] = azmErr{:}; + [baParams(isP).azmErr2] = azmErr{:}; +end + +baSrcSnr = {blockAnnotations.srcSNRactive}'; +baSrcSnr_db = {blockAnnotations.srcSNR_db}'; +baSrcSnr2 = {blockAnnotations.srcSNR2}'; +baSrcNrj = {blockAnnotations.nrj}'; +baSrcNrj_db = {blockAnnotations.nrj_db}'; +baSrcNrjOthers = {blockAnnotations.nrjOthers}'; +baSrcNrjOthers_db = {blockAnnotations.nrjOthers_db}'; + +if any( isP ) +% if is positive, the first src in the stream is the positive one, because +% of the restriction of positives to the first source in a scene config +curSnrP_ = cellfun( @(x)(x{1}), baSrcSnr(isP) ); +curSnrP = num2cell( min( max( curSnrP_, -35 ), 35 ) ); +[baParams(isP).curSnr] = curSnrP{:}; +curNrjoP_ = cellfun( @(x)(x{1}), baSrcNrjOthers(isP) ); +curNrjoP = num2cell( min( max( curNrjoP_, -35 ), 35 ) ); +[baParams(isP).curNrjOthers] = curNrjoP{:}; +curNrjP_ = cellfun( @(x)(x{1}), baSrcNrj(isP) ); +curNrjP = num2cell( min( max( curNrjP_, -35 ), 35 ) ); +[baParams(isP).curNrj] = curNrjP{:}; +curSnr_dbP_ = cellfun( @(x)(x{1}), baSrcSnr_db(isP) ); +curSnr_dbP = num2cell( min( max( curSnr_dbP_, -35 ), 35 ) ); +[baParams(isP).curSnr_db] = curSnr_dbP{:}; +curNrjo_dbP_ = cellfun( @(x)(x{1}), baSrcNrjOthers_db(isP) ); +curNrjo_dbP = num2cell( min( max( curNrjo_dbP_, -35 ), 35 ) ); +[baParams(isP).curNrjOthers_db] = curNrjo_dbP{:}; +curNrj_dbP_ = cellfun( @(x)(x{1}), baSrcNrj_db(isP) ); +curNrj_dbP = num2cell( min( max( curNrj_dbP_, -35 ), 35 ) ); +[baParams(isP).curNrj_db] = curNrj_dbP{:}; +curSnr2P_ = cellfun( @(x)(x{1}), baSrcSnr2(isP) ); +curSnr2P = num2cell( min( max( curSnr2P_, -35 ), 35 ) ); +[baParams(isP).curSnr2] = curSnr2P{:}; +[baParams(isP).posSnr] = curSnr2P{:}; +end + +nCond = (yt < 0) & (~isSegId | ~cellfun( @isempty, baSrcSnr )); + +if any( nCond ) +% if is negative, the most dominant src in the stream is selected +[~,curSnr2NmaxIdx] = cellfun( @(x)(max( cell2mat( x ) )), baSrcSnr2(nCond), 'UniformOutput', false ); +curSnr2N_ = cellfun( @(x,x2)(x{x2}), baSrcSnr2(nCond), curSnr2NmaxIdx ); +curSnr2N = num2cell( min( max( curSnr2N_, -35 ), 35 ) ); +[baParams(nCond).curSnr2] = curSnr2N{:}; +curNrjoN_ = cellfun( @(x,x2)(x{x2}), baSrcNrjOthers(nCond), curSnr2NmaxIdx ); +curNrjoN = num2cell( min( max( curNrjoN_, -35 ), 35 ) ); +[baParams(nCond).curNrjOthers] = curNrjoN{:}; +curNrjN_ = cellfun( @(x,x2)(x{x2}), baSrcNrj(nCond), curSnr2NmaxIdx ); +curNrjN = num2cell( min( max( curNrjN_, -35 ), 35 ) ); +[baParams(nCond).curNrj] = curNrjN{:}; +curSnrN_ = cellfun( @(x,x2)(x{x2}), baSrcSnr(nCond), curSnr2NmaxIdx ); +curSnrN = num2cell( min( max( curSnrN_, -35 ), 35 ) ); +[baParams(nCond).curSnr] = curSnrN{:}; +curNrjo_dbN_ = cellfun( @(x,x2)(x{x2}), baSrcNrjOthers_db(nCond), curSnr2NmaxIdx ); +curNrjo_dbN = num2cell( min( max( curNrjo_dbN_, -35 ), 35 ) ); +[baParams(nCond).curNrjOthers_db] = curNrjo_dbN{:}; +curNrj_dbN_ = cellfun( @(x,x2)(x{x2}), baSrcNrj_db(nCond), curSnr2NmaxIdx ); +curNrj_dbN = num2cell( min( max( curNrj_dbN_, -35 ), 35 ) ); +[baParams(nCond).curNrj_db] = curNrj_dbN{:}; +curSnr_dbN_ = cellfun( @(x,x2)(x{x2}), baSrcSnr_db(nCond), curSnr2NmaxIdx ); +curSnr_dbN = num2cell( min( max( curSnr_dbN_, -35 ), 35 ) ); +[baParams(nCond).curSnr_db] = curSnr_dbN{:}; +curAzmN_ = cellfun( @(x,x2)(x(x2)), baSrcAzms(nCond), curSnr2NmaxIdx ); +curAzmN = num2cell( curAzmN_ ); +[baParams(nCond).gtAzm] = curAzmN{:}; +end + +bafiles = cellfun( @(c)(c.srcFile), {blockAnnotations.srcFile}', 'UniformOutput', false ); +nonemptybaf = ~cellfun( @isempty, bafiles ); +bafiles = cellfun( @(c)(applyIfNempty( c, @(x)(x{1}) )), bafiles(nonemptybaf), 'UniformOutput', false ); +bafClasses = cellfun( @fileparts, bafiles, 'UniformOutput', false ); +[~,bafClasses] = cellfun( @fileparts, bafClasses, 'UniformOutput', false ); +niClasses = {{'alarm'},{'baby'},{'femaleSpeech'},{'fire'},{'crash'},{'dog'},... + {'engine'},{'footsteps'},{'knock'},{'phone'},{'piano'},... + {'maleSpeech'},{'femaleScream','maleScream'},{'general'}}; +bafClassIdxs = cellfun( ... + @(x)( find( cellfun( @(c)(any( strcmpi( x, c ) )), niClasses ) ) ), bafClasses, ... + 'UniformOutput', false ); +[baParams(nonemptybaf).blockClass] = bafClassIdxs{:}; + + +end \ No newline at end of file diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/findSameTimeBlocks.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/findSameTimeBlocks.m new file mode 100644 index 0000000..12f4ec8 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/findSameTimeBlocks.m @@ -0,0 +1,12 @@ +function [blockAnnotations,yt,yp,sameTimeIdxs] = findSameTimeBlocks( blockAnnotations,yt,yp ) + +[~,~,sameTimeIdxs] = unique( [blockAnnotations.blockOffset] ); +for bb = 1 : max( sameTimeIdxs ) + [blockAnnotations(sameTimeIdxs==bb).allGtAzms] = deal( [blockAnnotations(sameTimeIdxs==bb).srcAzms] ); + if any( yt(sameTimeIdxs==bb) == 1 ) + [blockAnnotations(sameTimeIdxs==bb).posSnr] = deal( blockAnnotations(sameTimeIdxs==bb & yt==1).srcSNR2{1} ); + end +end + +end + diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/nanRescStruct.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/nanRescStruct.m new file mode 100644 index 0000000..0a42eb7 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/nanRescStruct.m @@ -0,0 +1,29 @@ +function nrs = nanRescStruct() + +sdef = {'classIdx',nan,... + 'dd',nan,... + 'nAct',nan,... + 'curSnr',nan,... + 'curSnr_db',nan,... + 'curSnr2',nan,... + 'dist2bisector',nan,... + 'azmErr',nan,... + 'azmErr2',nan,... + 'nEstErr',nan,... + 'nAct_segStream',nan,... + 'curNrj',nan, ... + 'curNrj_db',nan, ... + 'curNrjOthers',nan, ... + 'curNrjOthers_db',nan, ... + 'scpId', nan,... + 'fileId', nan,... + 'fileClassId', nan,... + 'blockClass', nan,... + 'gtAzm',nan,... + 'estAzm',nan,... + 'posPresent',nan,... + 'posSnr',nan,... + }; +nrs = struct( sdef{:} ); + +end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/pathInsert.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/pathInsert.m new file mode 100644 index 0000000..6077faf --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/pathInsert.m @@ -0,0 +1,4 @@ +function p = pathInsert( p, pinsert, level ) + +pSeps = strfind( p, filesep ); +p = cleanPathFromRelativeRefs( [p(1:pSeps(end+level)) pinsert '/' p(pSeps(end+level)+1:end)] ); diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/shrinkCacheFileSizes.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/shrinkCacheFileSizes.m new file mode 100644 index 0000000..296ddb0 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/shrinkCacheFileSizes.m @@ -0,0 +1,56 @@ +function shrinkCacheFileSizes( cacheDir, cd_idxs ) + +cfgFolders = dir( cacheDir ); +cfgFolders = cfgFolders([cfgFolders.isdir]); +cfgFolders(1:2) = []; + +nextProgressOutStep_ii = 10; +if nargin < 2 || isempty( cd_idxs ), cd_idxs = 1 : numel( cfgFolders ); end; +for ii = cd_idxs + if ii > numel( cfgFolders ), break; end + if round( ii*100/numel( cfgFolders ) ) >= nextProgressOutStep_ii + fprintf( ':' ); + nextProgressOutStep_ii = nextProgressOutStep_ii + 10; + end + cacheFiles = dir( [cacheDir filesep cfgFolders(ii).name filesep '*.wav.mat'] ); + nextProgressOutStep_jj = 10; + for jj = 1 : numel( cacheFiles ) + if round( jj*100/numel( cacheFiles ) ) >= nextProgressOutStep_jj + fprintf( '.' ); + nextProgressOutStep_jj = nextProgressOutStep_jj + 10; + end + cacheContent = load( [cacheDir filesep cfgFolders(ii).name filesep cacheFiles(jj).name] ); + cacheContentFields = fieldnames( cacheContent ); + isChanged = false; + for kk = 1 : numel( cacheContentFields ) + if strcmpi( cacheContentFields{kk}, 'blockAnnotations' ) + baFields = fieldnames( cacheContent.blockAnnotations ); + for bb = 1 : numel( baFields ) + if ~isstruct( cacheContent.blockAnnotations.(baFields{bb}) ) + continue; + end + if ~isfield( cacheContent.blockAnnotations.(baFields{bb}), baFields{bb} ) + continue; + end + if iscell( cacheContent.blockAnnotations.(baFields{bb}).(baFields{bb}) ) ... + && mean( cellfun( @numel, cacheContent.blockAnnotations.(baFields{bb}).(baFields{bb})(:) ) ) == 1 ... + && std( cellfun( @numel, cacheContent.blockAnnotations.(baFields{bb}).(baFields{bb})(:) ) ) == 0 + cacheContent.blockAnnotations.(baFields{bb}).(baFields{bb}) = ... + cell2mat( cacheContent.blockAnnotations.(baFields{bb}).(baFields{bb}) ); + isChanged = true; + end + end + end + if ~isa( cacheContent.(cacheContentFields{kk}), 'double' ), continue; end + cacheContent.(cacheContentFields{kk}) = single( cacheContent.(cacheContentFields{kk}) ); + isChanged = true; + end + if isChanged + save( [cacheDir filesep cfgFolders(ii).name filesep cacheFiles(jj).name], '-struct', 'cacheContent' ); + end + end +end + +fprintf( '\n' ); + +end diff --git a/AuditoryMachineLearningTrainingTestingPipeline/src/tools/summarizeDown.m b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/summarizeDown.m new file mode 100644 index 0000000..b2eb6c5 --- /dev/null +++ b/AuditoryMachineLearningTrainingTestingPipeline/src/tools/summarizeDown.m @@ -0,0 +1,13 @@ +function xsummed = summarizeDown( x, leaveVariables, doNotSqueeze ) + dims = 1 : ndims( x ); + dims(leaveVariables) = []; + dims = flip( dims ); + xsummed = x; + for dd = dims + xsummed = sum( xsummed, dd ); + end + if nargin < 3 || ~doNotSqueeze + xsummed = squeeze( xsummed ); + end +end + diff --git a/BinauralSimulator/src/+db/clearTmp.m b/BinauralSimulator/src/+db/clearTmp.m index fc82c19..3cd7cf0 100644 --- a/BinauralSimulator/src/+db/clearTmp.m +++ b/BinauralSimulator/src/+db/clearTmp.m @@ -3,10 +3,14 @@ function clearTmp() dir_path = db.tmp(); -dirData = dir(dir_path); -dirIndex = [dirData.isdir]; -for idx=find(dirIndex) - if dirData(idx).name(1) ~= '.' - rmdir(fullfile(dir_path, dirData(idx).name), 's') +for dirData = dir(dir_path).' + if dirData.isdir + if dirData.name(1) ~= '.' + rmdir(fullfile(dir_path, dirData.name), 's') + end + else + if dirData.name(1) ~= '.' || strcmp(dirData.name, '.dir.flist') + delete(fullfile(dir_path, dirData.name)); + end end end diff --git a/BinauralSimulator/src/+db/downloadFile.m b/BinauralSimulator/src/+db/downloadFile.m index 3f1c5ff..cdf4a8f 100644 --- a/BinauralSimulator/src/+db/downloadFile.m +++ b/BinauralSimulator/src/+db/downloadFile.m @@ -38,8 +38,9 @@ if nargin == 3 && bVerbose fprintf('Downloading file %s\n', url); end -[~, status] = urlwrite(url, outfile); -if ~status +try + websave(outfile, url, weboptions('Timeout',Inf)); +catch warning('Download failed (url=%s), trying alternative database...', url); % try with alternative URL of database url = [dbalturl, '/', filename]; @@ -47,8 +48,9 @@ if nargin == 3 && bVerbose fprintf('Downloading file %s\n', url); end - [~, status] = urlwrite(url, outfile); - if ~status + try + websave(outfile, url, weboptions('Timeout',Inf)); + catch error('Download also failed with alternative database (url=%s)', url); end end diff --git a/BinauralSimulator/src/+db/url.m b/BinauralSimulator/src/+db/url.m index 4d676e3..2151100 100644 --- a/BinauralSimulator/src/+db/url.m +++ b/BinauralSimulator/src/+db/url.m @@ -23,7 +23,7 @@ if exist('url','var') CachedURL=url; elseif isempty(CachedURL) - CachedURL= 'https://dev.qu.tu-berlin.de/projects/twoears-getdata/repository/raw/'; + CachedURL= 'https://avtshare01.rz.tu-ilmenau.de/two-ears/database/'; end url=CachedURL; @@ -31,6 +31,6 @@ if exist('alturl','var') CachedAltURL=alturl; elseif isempty(CachedAltURL) - CachedAltURL= 'https://dev.qu.tu-berlin.de/projects/twoears-database/repository/revisions/master/raw/'; + CachedAltURL= 'https://dev.qu.tu-berlin.de/projects/twoears-getdata/repository/raw/'; end alturl=CachedAltURL; \ No newline at end of file diff --git a/BlackboardSystem/src/blackboard_core/BlackboardSystem.m b/BlackboardSystem/src/blackboard_core/BlackboardSystem.m index df98173..c265855 100644 --- a/BlackboardSystem/src/blackboard_core/BlackboardSystem.m +++ b/BlackboardSystem/src/blackboard_core/BlackboardSystem.m @@ -11,6 +11,7 @@ locVis; % for visualising localisation afeVis; % for visualising AFE genderVis; % for visualising gender recognition + emDetVis; % for visualising emergency detection end methods @@ -46,6 +47,10 @@ function setLocVis(obj, locVis) obj.locVis = locVis; end + function setEmDetVis(obj, emDetVis) + obj.emDetVis = emDetVis; + end + function setAfeVis(obj, afeVis) obj.afeVis = afeVis; end diff --git a/BlackboardSystem/src/blackboard_data/LocationHypothesis.m b/BlackboardSystem/src/blackboard_data/LocationHypothesis.m index 65445e4..2d3fb71 100644 --- a/BlackboardSystem/src/blackboard_data/LocationHypothesis.m +++ b/BlackboardSystem/src/blackboard_data/LocationHypothesis.m @@ -2,8 +2,8 @@ % class LocationHypothesis represents the source location properties (SetAccess = private) - sourcesPosteriors; % Posterior distribution of source azimuths, relative to head orientation - sourceAzimuths; % Relative azimuths corresponding to sourcesPosteriors + sourcesDistribution; % Posterior distribution of source azimuths, relative to head orientation + azimuths ; % Relative azimuths corresponding to sourcesPosteriors headOrientation; % Head orientation angle azimuth; % Most likely source azimuth @@ -12,8 +12,8 @@ methods function obj = LocationHypothesis(headOrientation, sourceAzimuths, sourcesPosteriors) - obj.sourceAzimuths = sourceAzimuths; - obj.sourcesPosteriors = sourcesPosteriors; + obj.azimuths = sourceAzimuths; + obj.sourcesDistribution = sourcesPosteriors; obj.headOrientation = wrapTo360(headOrientation); [posterior,idx] = max(sourcesPosteriors); diff --git a/BlackboardSystem/src/knowledge_sources/DnnLocationCaffeKS.m b/BlackboardSystem/src/knowledge_sources/DnnLocationCaffeKS.m index 55a1b18..08e0a3d 100644 --- a/BlackboardSystem/src/knowledge_sources/DnnLocationCaffeKS.m +++ b/BlackboardSystem/src/knowledge_sources/DnnLocationCaffeKS.m @@ -43,21 +43,10 @@ azRes = 5; end nChannels = 32; + commonParams = getCommonAFEParams(); param = genParStruct(... - 'fb_type', 'gammatone', ... - 'fb_lowFreqHz', defaultFreqRange(1), ... - 'fb_highFreqHz', defaultFreqRange(2), ... - 'fb_nChannels', nChannels, ... - 'ihc_method', 'halfwave', ... - 'ild_wSizeSec', 20E-3, ... - 'ild_hSizeSec', 10E-3, ... - 'rm_wSizeSec', 20E-3, ... - 'rm_hSizeSec', 10E-3, ... - 'rm_scaling', 'power', ... - 'rm_decaySec', 8E-3, ... - 'cc_wSizeSec', 20E-3, ... - 'cc_hSizeSec', 10E-3, ... - 'cc_wname', 'hann'); + commonParams{:}, ... + 'fb_nChannels', nChannels); requests{1}.name = 'crosscorrelation'; requests{1}.params = param; requests{2}.name = 'ild'; diff --git a/BlackboardSystem/src/knowledge_sources/DnnLocationKS.m b/BlackboardSystem/src/knowledge_sources/DnnLocationKS.m index 784145f..746e31c 100644 --- a/BlackboardSystem/src/knowledge_sources/DnnLocationKS.m +++ b/BlackboardSystem/src/knowledge_sources/DnnLocationKS.m @@ -40,23 +40,11 @@ azRes = 5; end nChannels = 32; - param = genParStruct(... - 'pp_bBinauralRMS', true, ... - 'fb_type', 'gammatone', ... - 'fb_lowFreqHz', defaultFreqRange(1), ... - 'fb_highFreqHz', defaultFreqRange(2), ... - 'fb_nChannels', nChannels, ... - 'ihc_method', 'halfwave', ... - 'ild_wSizeSec', 20E-3, ... - 'ild_hSizeSec', 10E-3, ... - 'rm_wSizeSec', 20E-3, ... - 'rm_hSizeSec', 10E-3, ... - 'rm_scaling', 'magnitude', ... - 'rm_decaySec', 8E-3, ... - 'rm_wname', 'hann', ... - 'cc_wSizeSec', 20E-3, ... - 'cc_hSizeSec', 10E-3, ... - 'cc_wname', 'hann'); + commonParams = getCommonAFEParams(); + param = genParStruct( ... + commonParams{:}, ... + 'pp_bNormalizeRMS', false, ... + 'fb_nChannels', nChannels); requests{1}.name = 'crosscorrelation'; requests{1}.params = param; requests{2}.name = 'ild'; @@ -217,4 +205,4 @@ function execute(obj) end end -% vim: set sw=4 ts=4 et tw=90 cc=+1: \ No newline at end of file +% vim: set sw=4 ts=4 et tw=90 cc=+1: diff --git a/BlackboardSystem/src/knowledge_sources/EmergencyDetectionKS.m b/BlackboardSystem/src/knowledge_sources/EmergencyDetectionKS.m new file mode 100644 index 0000000..2e373c9 --- /dev/null +++ b/BlackboardSystem/src/knowledge_sources/EmergencyDetectionKS.m @@ -0,0 +1,103 @@ +classdef EmergencyDetectionKS < AbstractKS + % EmergencyDetectionKS Checks for an emergency situation by evaluating + % the output of source identification. + + properties (SetAccess = private) + accumulatedIdProbs = zeros(3, 1); + smoothingFactor + forgettingFactor = 0.9; + emergencyThreshold + emergencyProbability = 0; + isEmergencyDetected = false; + end + + methods + function obj = EmergencyDetectionKS(varargin) + obj = obj@AbstractKS(); + obj.invocationMaxFrequency_Hz = inf; + + defaultSmoothingFactor = 0.25; + defaultEmergencyThreshold = 0.5; + + p = inputParser(); + p.addOptional('SmoothingFactor', defaultSmoothingFactor, ... + @(x) validateattributes(x, {'numeric'}, {'scalar', ... + 'real', '>=', 0, '<=', 1})); + p.addOptional('EmergencyThreshold', ... + defaultEmergencyThreshold, @(x) validateattributes(x, ... + {'numeric'}, {'scalar', 'real', '>', 0, '<', 1})); + p.parse(varargin{:}); + + obj.smoothingFactor = p.Results.SmoothingFactor; + obj.emergencyThreshold = p.Results.EmergencyThreshold; + end + + function setEmergencyThreshold(obj, emergencyThreshold) + obj.emergencyThreshold = emergencyThreshold; + end + + function setSmoothingFactor(obj, smoothingFactor) + obj.smoothingFactor = smoothingFactor; + end + + function [bExecute, bWait] = canExecute(obj) + sndTimeIdx = sort(cell2mat(keys(obj.blackboard.data))); + bExecute = isfield(obj.blackboard.data(sndTimeIdx(end)), ... + 'singleBlockObjectHypotheses'); + bWait = false; + end + + function execute(obj) + singleBlockObjHyp = obj.blackboard.getData( ... + 'singleBlockObjectHypotheses', obj.trigger.tmIdx).data; + + numHyps = length(singleBlockObjHyp); + + for idx = 1 : numHyps + % Get class label and detection probability. + hypLabel = singleBlockObjHyp(idx).label; + detProb = singleBlockObjHyp(idx).p; + + if detProb >= 0.5 + switch hypLabel + case 'fire' + obj.accumulatedIdProbs(1) = ... + obj.smoothingFactor * obj.accumulatedIdProbs(1) + ... + (1 - obj.smoothingFactor) * singleBlockObjHyp(idx).p; + case 'alarm' + obj.accumulatedIdProbs(2) = ... + obj.smoothingFactor * obj.accumulatedIdProbs(2) + ... + (1 - obj.smoothingFactor) * singleBlockObjHyp(idx).p; + case 'femaleScreammaleScream' + obj.accumulatedIdProbs(3) = ... + obj.smoothingFactor * obj.accumulatedIdProbs(3) + ... + (1 - obj.smoothingFactor) * singleBlockObjHyp(idx).p; + end + end + end + + obj.accumulatedIdProbs(1) = ... + obj.forgettingFactor * obj.accumulatedIdProbs(1); + obj.accumulatedIdProbs(2) = ... + obj.forgettingFactor * obj.accumulatedIdProbs(2); + obj.accumulatedIdProbs(3) = ... + obj.forgettingFactor * obj.accumulatedIdProbs(3); + + obj.emergencyProbability = ( ... + obj.accumulatedIdProbs(1) + ... + 5 * obj.accumulatedIdProbs(2) + ... + 4 * obj.accumulatedIdProbs(3)) / 10; + + if obj.emergencyProbability >= obj.emergencyThreshold + obj.isEmergencyDetected = true; + end + end + + function visualise(obj) + if ~isempty(obj.blackboardSystem.emDetVis) + obj.blackboardSystem.emDetVis.draw( ... + obj.emergencyProbability, obj.isEmergencyDetected ); + end + end + end +end diff --git a/BlackboardSystem/src/knowledge_sources/FullBodyRotationKS.m b/BlackboardSystem/src/knowledge_sources/FullBodyRotationKS.m index 03c5940..d34687c 100644 --- a/BlackboardSystem/src/knowledge_sources/FullBodyRotationKS.m +++ b/BlackboardSystem/src/knowledge_sources/FullBodyRotationKS.m @@ -48,8 +48,8 @@ function execute(obj) % Compute circular mean. meanSourceDirection = atan2d( ... - mean(ploc.sourcesPosteriors .* sind(wrapTo180(ploc.sourceAzimuths))), ... - mean(ploc.sourcesPosteriors .* cosd(wrapTo180(ploc.sourceAzimuths)))); + mean(ploc.sourcesPosteriors .* sind(wrapTo180(ploc.azimuths))), ... + mean(ploc.sourcesPosteriors .* cosd(wrapTo180(ploc.azimuths)))); % % Compute entropy of posterior distribution. % distEntropy = sum(ploc.sourcesPosteriors .* ... diff --git a/BlackboardSystem/src/knowledge_sources/HeadRotationKS.m b/BlackboardSystem/src/knowledge_sources/HeadRotationKS.m index b6c68c7..4fb14b6 100644 --- a/BlackboardSystem/src/knowledge_sources/HeadRotationKS.m +++ b/BlackboardSystem/src/knowledge_sources/HeadRotationKS.m @@ -41,9 +41,9 @@ function execute(obj) % Get the most likely source direction ploc = obj.blackboard.getData( ... 'locationHypothesis', obj.trigger.tmIdx).data; - [post,idx] = max(ploc.sourcesPosteriors); + [post,idx] = max(ploc.sourcesDistribution); % confHyp.azimuths are relative to the current head orientation - azSrc = wrapTo180(ploc.sourceAzimuths(idx)); + azSrc = wrapTo180(ploc.azimuths(idx)); % We want to turn the head toward the most likely source % direction, but if not a strong source, make a random rotation diff --git a/BlackboardSystem/src/knowledge_sources/IntegrateFullstreamIdentitiesKS.m b/BlackboardSystem/src/knowledge_sources/IntegrateFullstreamIdentitiesKS.m index d05a516..75744b4 100644 --- a/BlackboardSystem/src/knowledge_sources/IntegrateFullstreamIdentitiesKS.m +++ b/BlackboardSystem/src/knowledge_sources/IntegrateFullstreamIdentitiesKS.m @@ -94,9 +94,13 @@ function execute(obj) maxedObjects = obj.onlyAllowNobjectsPerLocation( objects ); % create objectHypotheses for oo = 1 : numel( maxedObjects.labels ) + newProb = maxedObjects.ps(oo); + if newProb > 1 + newProb = 1; + end objectHyp = IdentityHypothesis( ... maxedObjects.labels{oo}, ... - maxedObjects.ps(oo), ... + newProb, ... maxedObjects.ds(oo), ... idloc(1).concernsBlocksize_s, ... nan ); diff --git a/BlackboardSystem/src/knowledge_sources/IntegrateSegregatedIdentitiesKS.m b/BlackboardSystem/src/knowledge_sources/IntegrateSegregatedIdentitiesKS.m index 9eb5a12..8df4c64 100644 --- a/BlackboardSystem/src/knowledge_sources/IntegrateSegregatedIdentitiesKS.m +++ b/BlackboardSystem/src/knowledge_sources/IntegrateSegregatedIdentitiesKS.m @@ -143,10 +143,14 @@ function execute(obj) for ll = 1 : numel( locMaxedObjects ) locObjs = locMaxedObjects{ll}; for oo = 1 : numel( locObjs.labels ) + newProb = locObjs.ps(oo) / obj.maxnpdf; + if newProb > 1 + newProb = 1; + end objectHyp = SingleBlockObjectHypothesis( ... locObjs.labels{oo}, ... locObjs.loc, ... - locObjs.ps(oo) / obj.maxnpdf, ... + newProb, ... locObjs.ds(oo), ... idloc(1).concernsBlocksize_s ); obj.blackboard.addData( 'singleBlockObjectHypotheses', ... diff --git a/BlackboardSystem/src/knowledge_sources/LocalisationDecisionKS.m b/BlackboardSystem/src/knowledge_sources/LocalisationDecisionKS.m index a3f7e4e..5eafdb8 100644 --- a/BlackboardSystem/src/knowledge_sources/LocalisationDecisionKS.m +++ b/BlackboardSystem/src/knowledge_sources/LocalisationDecisionKS.m @@ -62,14 +62,14 @@ function execute(obj) prevHyp = obj.blackboard.getData( ... 'locationHypothesis', obj.prevTimeIdx).data; headRotation = wrapTo180(aziHyp.headOrientation-prevHyp.headOrientation); - prevPost = prevHyp.sourcesPosteriors; + prevPost = prevHyp.sourcesDistribution; currPost = aziHyp.sourcesDistribution; if sum(currPost > obj.postThreshold) > 0 if headRotation ~= 0 % Only if the new location hypothesis contains strong % directional sources, do the removal [prevPost,currPost] = removeFrontBackConfusion(... - prevHyp.sourceAzimuths, prevPost, ... + prevHyp.azimuths, prevPost, ... currPost, headRotation); % Changed int16 to round here, which seems to cause problem % with circshift in the next line @@ -126,7 +126,7 @@ function execute(obj) if obj.bSolveConfusion % Generates location hypotheses if posterior distribution > threshold % Assume a confusion when more than 1 valid location - if sum(ploc.sourcesPosteriors > obj.postThreshold) > 1 ... + if sum(ploc.sourcesDistribution > obj.postThreshold) > 1 ... || (ploc.relativeAzimuth > 150 && ploc.relativeAzimuth < 210) bRotateHead = true; end @@ -147,7 +147,7 @@ function visualise(obj) ploc = obj.blackboard.getData( ... 'locationHypothesis', obj.trigger.tmIdx).data; obj.blackboardSystem.locVis.setPosteriors(... - ploc.sourceAzimuths+ploc.headOrientation, ploc.sourcesPosteriors); + ploc.azimuths+ploc.headOrientation, ploc.sourcesDistribution); end end end diff --git a/BlackboardSystem/src/knowledge_sources/RobotNavigationKS.m b/BlackboardSystem/src/knowledge_sources/RobotNavigationKS.m new file mode 100644 index 0000000..527dd62 --- /dev/null +++ b/BlackboardSystem/src/knowledge_sources/RobotNavigationKS.m @@ -0,0 +1,100 @@ +classdef RobotNavigationKS < AbstractKS + % RobotNavigationKS + % + + properties (SetAccess = private) + movingScheduled = false; + robot + targetSource = []; + robotPositions = [ + %95, 97.5; % Outside kitchen + 91, 98; % Kitchen + %91, 102; % Bed room + 87, 102]; % Living room + idlocIdx = 0; + end + + methods + function obj = RobotNavigationKS(robot, targetSource) + + obj = obj@AbstractKS(); + obj.invocationMaxFrequency_Hz = inf; + obj.robot = robot; + + if exist('targetSource', 'var') + obj.targetSource = targetSource; + else + obj.targetSource = []; + end + + obj.invocationMaxFrequency_Hz = 1; + end + + function setTargetSource(obj, targetSource) + obj.targetSource = targetSource; + end + + function [bExecute, bWait] = canExecute(obj) + bWait = false; + bExecute = false; + hyp = obj.blackboard.getLastData('singleBlockObjectHypotheses'); + if ~isempty(hyp) + if ~isempty(obj.targetSource) + idloc = hyp.data; + idx = strcmp({idloc(:).label}, obj.targetSource); + if max(idx) > 0 && any(cell2mat({idloc(idx).p}) == 0.7) + bExecute = true; + obj.idlocIdx = argmax(cell2mat({idloc.p})); + end + end + end + end + + function execute(obj) + + + % Robot is not moving + % Let us get it to move + hyp = obj.blackboard.getLastData('singleBlockObjectHypotheses'); + + idloc = hyp.data; + + % Now we have identified the target source. We want to + % move the robot towards the source + + % idloc(idx).loc is source location relative to head + targetLocBase = idloc(obj.idlocIdx).loc + obj.robot.getCurrentHeadOrientation; + [posX, posY, theta] = obj.robot.getCurrentRobotPosition; + nRobotPositions = size(obj.robotPositions,1); + + relativeAngles = zeros(nRobotPositions, 1); + for idx = 1 : nRobotPositions + currentTargetPos = obj.robotPositions(idx, :); + relativeAngles(idx) = atan2d(currentTargetPos(2) - posY, ... + currentTargetPos(1) - posX); + end + + distances = 1 - cosd(relativeAngles - targetLocBase + theta*180/pi); + [~, bestPosIdx] = min(distances); + + % Check if the target position is less than 1 metre away from + % the current position and stay put if yes + distMetres = sqrt((posX-obj.robotPositions(bestPosIdx,1))^2 + (posY-obj.robotPositions(bestPosIdx,2))^2); + if distMetres > 1 + % Need to work out which angle to move to + obj.robot.moveRobot(obj.robotPositions(bestPosIdx,1), obj.robotPositions(bestPosIdx,2), theta, 'absolute'); + end + %obj.movingScheduled = true; + + notify(obj, 'KsFiredEvent', BlackboardEventData( obj.trigger.tmIdx )); + end + + % Visualisation + function visualise(obj) + + end + + end +end + +% vim: set sw=4 ts=4 et tw=90 cc=+1: diff --git a/BlackboardSystem/src/knowledge_sources/StreamSegregationKS.m b/BlackboardSystem/src/knowledge_sources/StreamSegregationKS.m index 13eed10..4cdffc8 100644 --- a/BlackboardSystem/src/knowledge_sources/StreamSegregationKS.m +++ b/BlackboardSystem/src/knowledge_sources/StreamSegregationKS.m @@ -109,7 +109,8 @@ function execute(obj) end else locHypos = obj.blackboard.getLastData( 'locationHypothesis' ); - if isempty( locHypos ) + isLocHyp = ~isempty( locHypos ); + if ~isLocHyp locHypos = obj.blackboard.getLastData( 'sourcesAzimuthsDistributionHypotheses' ); end assert( numel( locHypos.data ) == 1 ); @@ -123,11 +124,8 @@ function execute(obj) % segregating into 0 streams seems pointless end refAzm = zeros( 1, numAzimuths ); - if isfield( locData, 'sourcesPosteriors' ) - posteriors = locData.sourcesPosteriors; - else - posteriors = locData.sourcesDistribution; - end + + posteriors = locData.sourcesDistribution; [locPeaks, locPeaksIdxs] = findpeaks( ... [posteriors(end) ... posteriors(:)' ... @@ -139,13 +137,8 @@ function execute(obj) [~, locPeaksSortedAzmIdxs] = sort( locPeaks, 'descend' ); locSortedAzmIdxs = locPeaksIdxs(locPeaksSortedAzmIdxs); for azimuthIdx = 1 : numAzimuths - if isfield( locData, 'sourceAzimuths' ) - refAzm(azimuthIdx) = wrapTo180( ... - locData.sourceAzimuths(locSortedAzmIdxs(azimuthIdx)) ); - else - refAzm(azimuthIdx) = wrapTo180( ... - locData.azimuths(locSortedAzmIdxs(azimuthIdx)) ); - end + refAzm(azimuthIdx) = wrapTo180( ... + locData.azimuths(locSortedAzmIdxs(azimuthIdx)) ); end end likelihoods = zeros( size(itds, 1), size(itds, 2), numAzimuths ); @@ -177,4 +170,4 @@ function execute(obj) BlackboardEventData( obj.trigger.tmIdx ) ); end end -end \ No newline at end of file +end diff --git a/BlackboardSystem/src/tools/VisualiserIdentityLocalisation.m b/BlackboardSystem/src/tools/VisualiserIdentityLocalisation.m index 0dda29a..c664c82 100644 --- a/BlackboardSystem/src/tools/VisualiserIdentityLocalisation.m +++ b/BlackboardSystem/src/tools/VisualiserIdentityLocalisation.m @@ -26,7 +26,7 @@ properties (SetAccess = private) ksColourMap = containers.Map; % identity colour map idRadiusMap = containers.Map; % identity radius map - radiusList = -200:40:200; + radiusList = -250:35:200; colourList = [0.4660 0.6740 0.1880 0.8500 0.3250 0.0980 0.0000 0.4470 0.7410 @@ -42,11 +42,11 @@ NumPosteriors = 72 drawHandle HeadHandle + TextHandle MarkerHandle + MarkerTextHandle MarkerHandles BarHandle - TextHandle - TextHandle2 TextHandles idTextHandles Hue = 50 % hue in HSV space, default is orangey yellow @@ -72,6 +72,8 @@ end function reset(obj) + obj.radiusIndex = 1; + obj.colourIndex = 1; axes(obj.drawHandle); obj.ksColourMap = containers.Map; @@ -84,16 +86,16 @@ function reset(obj) % add a grid if required if (obj.SHOW_GRID) - for i=1:obj.NUM_GRID_LINES - angle_degrees = wrapTo180((i-1)*360/obj.NUM_GRID_LINES); + for ii=1:obj.NUM_GRID_LINES + angle_degrees = wrapTo180((ii-1)*360/obj.NUM_GRID_LINES); angle_rad = -2*pi*angle_degrees/360; sn = sin(angle_rad); cs = cos(angle_rad); plot([obj.INNER_RADIUS*sn 520*sn],[obj.INNER_RADIUS*cs 520*cs],'Color',c); text(560*sn,560*cs,num2str(angle_degrees),'HorizontalAlignment','Center','Color',[0.7 0.7 0.7]); end % circles - for i=0:4 - r=obj.INNER_RADIUS+i*(500-obj.INNER_RADIUS)/4; + for ii=0:4 + r=obj.INNER_RADIUS+ii*(500-obj.INNER_RADIUS)/4; plot(r*sin(linspace(0,2*pi,50)),r*cos(linspace(0,2*pi,50)),'Color',c); end end @@ -111,49 +113,50 @@ function reset(obj) axis off; box on; + % add probability bars + obj.Posteriors = zeros(1,obj.NumPosteriors); + obj.Angles = 0:(360/obj.NumPosteriors):359; + obj.HeadRotationDegrees = 0; + [x,y] = deal(zeros(1,4)); + for ii=1:obj.NumPosteriors + angle_rad1 = -2*pi*(ii-1.5)/obj.NumPosteriors; + angle_rad2 = -2*pi*(ii-0.5)/obj.NumPosteriors; + + sn = sin(angle_rad1); cs = cos(angle_rad1); + x(1) = obj.INNER_RADIUS*sn; + y(1) = obj.INNER_RADIUS*cs; + x(2) = (obj.INNER_RADIUS+obj.OUTER_RADIUS*obj.Posteriors(ii))*sn; + y(2) = (obj.INNER_RADIUS+obj.OUTER_RADIUS*obj.Posteriors(ii))*cs; + sn = sin(angle_rad2); cs = cos(angle_rad2); + x(3) = (obj.INNER_RADIUS+obj.OUTER_RADIUS*obj.Posteriors(ii))*sn; + y(3) = (obj.INNER_RADIUS+obj.OUTER_RADIUS*obj.Posteriors(ii))*cs; + x(4) = obj.INNER_RADIUS*sn; + y(4) = obj.INNER_RADIUS*cs; + % obj.BarHandle(i) = plot([x1 x2],[y1 y2],'Color',[1.0 0.6471 0],'LineWidth',obj.LINE_WIDTH); + obj.BarHandle(ii) = patch('XData',x,'YData',y,'LineStyle','none'); + end + % add markers y2 = obj.MARKER_RADIUS; y1 = obj.INNER_RADIUS; col = [1 1 1]; - obj.MarkerHandle(1) = plot([y1 y2], [y1 y2], 'Color', col, 'LineStyle', '--'); - obj.MarkerHandle(2) = fill(15*sin(-linspace(0,2*pi,30)),y2+15*cos(-linspace(0,2*pi,30)),col,'linestyle','none'); - + + for ii = 1:4 + obj.MarkerHandle(ii) = fill(15*sin(-linspace(0,2*pi,30)),y2+15*cos(-linspace(0,2*pi,30)),col,'EdgeColor',col); + obj.MarkerTextHandle(ii) = text(y1,y2, '', 'Color', col, 'FontSize', 12); + end obj.TextHandle = text(y1,y2, '', 'Color', col, 'FontSize', 12); - obj.TextHandle2 = text(y1,y2, '', 'Color', col, 'FontSize', 12); for ii=1:55 obj.MarkerHandles(ii) = fill(15*sin(-linspace(0,2*pi,30)), ... y2+15*cos(-linspace(0,2*pi,30)), ... col,'linestyle','none'); - obj.TextHandles(ii) = text(y1,y2, '', 'Color', col, 'FontSize', 14); + obj.TextHandles(ii) = text(y1,y2, '', 'Color', col, 'FontSize', 11); end for ii=1:13 - obj.idTextHandles(ii) = text(y1,y2, '', 'Color', col, 'FontSize', 12); - end - - % add probability bars - obj.Posteriors = zeros(1,obj.NumPosteriors); - obj.Angles = 0:(360/obj.NumPosteriors):359; - obj.HeadRotationDegrees = 0; - [x,y] = deal(zeros(1,4)); - for i=1:obj.NumPosteriors - angle_rad1 = -2*pi*(i-1.5)/obj.NumPosteriors; - angle_rad2 = -2*pi*(i-0.5)/obj.NumPosteriors; - - sn = sin(angle_rad1); cs = cos(angle_rad1); - x(1) = obj.INNER_RADIUS*sn; - y(1) = obj.INNER_RADIUS*cs; - x(2) = (obj.INNER_RADIUS+obj.OUTER_RADIUS*obj.Posteriors(i))*sn; - y(2) = (obj.INNER_RADIUS+obj.OUTER_RADIUS*obj.Posteriors(i))*cs; - sn = sin(angle_rad2); cs = cos(angle_rad2); - x(3) = (obj.INNER_RADIUS+obj.OUTER_RADIUS*obj.Posteriors(i))*sn; - y(3) = (obj.INNER_RADIUS+obj.OUTER_RADIUS*obj.Posteriors(i))*cs; - x(4) = obj.INNER_RADIUS*sn; - y(4) = obj.INNER_RADIUS*cs; - % obj.BarHandle(i) = plot([x1 x2],[y1 y2],'Color',[1.0 0.6471 0],'LineWidth',obj.LINE_WIDTH); - obj.BarHandle(i) = patch('XData',x,'YData',y,'LineStyle','none'); + obj.idTextHandles(ii) = text(y1, y2, '', 'Color', col, 'FontSize', 11); end hold off; drawnow; @@ -196,21 +199,25 @@ function reset(obj) end end - function obj = plotMarkerAtAngle(obj,angle,hue) - if nargin < 3 + function obj = plotMarkerAtAngle(obj,idx,angle,str,hue) + if nargin < 5 hue = 100; end sn = sin(-2*pi*angle/360); cs = cos(-2*pi*angle/360); - x2 = obj.MARKER_RADIUS * sn; - y2 = obj.MARKER_RADIUS * cs; - x1 = obj.INNER_RADIUS * sn; - y1 = obj.INNER_RADIUS * cs; - col = hsv2rgb(hue/360,0.9,0.6); - set(obj.MarkerHandle(1), 'Color', col, 'XData', [x1 x2], 'YData', [y1 y2]); - set(obj.MarkerHandle(2), 'FaceColor', col, ... - 'XData', x2+15*sin(-linspace(0,2*pi,30)), ... - 'YData', y2+15*cos(-linspace(0,2*pi,30))); + x2 = (obj.MARKER_RADIUS-20) * sn; + y2 = (obj.MARKER_RADIUS-20) * cs; + col = obj.getIdentityColor(str); + set(obj.MarkerHandle(idx), 'EdgeColor', col, ... + 'XData', x2+15*sin(-linspace(0,2*pi,30)), ... + 'YData', y2+15*cos(-linspace(0,2*pi,30))); + r = (obj.MARKER_RADIUS+20); + str = VisualiserIdentityLocalisation.getShortName(str); + set(obj.MarkerTextHandle(idx), ... + 'Color', col, ... + 'Position', [r*sn, r*cs], ... + 'String', str, ... + 'rotation', angle); end function obj = plotMarkerIdxAtAngle(obj,... @@ -231,7 +238,12 @@ function reset(obj) function obj = plotTextIdxAtAngle(obj, ... idx, label, angle, radiusDelta, color) - angle = angle - 4; + + if angle > 90 && angle <= 270 + angle = angle + 4; + else + angle = angle - 4; + end sn = sin(-2*pi*angle/360); cs = cos(-2*pi*angle/360); %radiusDelta = obj.getIdentityRadius(label); @@ -239,8 +251,11 @@ function reset(obj) radiusInner = obj.INNER_RADIUS + radiusDelta; x2 = radius * sn; y2 = radius * cs; - x1 = radiusInner * sn; - y1 = radiusInner * cs; + + if angle > 90 && angle <= 270 + angle = angle + 180; % let text appear upright + end + set(obj.TextHandles(idx), ... 'Color', color, ... % obj.getIdentityColor(label), ... 'Position', [x2, y2], ... @@ -250,8 +265,8 @@ function reset(obj) function obj = plotIdTextIdxAtAngle(obj, ... idx, label, prob) - x2 = 450; - y2 = 590 + obj.getIdentityRadius(label); + x2 = 480; + y2 = 570 + obj.getIdentityRadius(label); set(obj.idTextHandles(idx), ... 'Color', obj.getIdentityColor(label), ... 'Position', [x2, y2], ... @@ -326,7 +341,7 @@ function draw(obj) color = [0.9 0.9 0.9]; % first clear handles - for ii=1:55 + for ii=1:numel(obj.MarkerHandles) set(obj.MarkerHandles(ii), 'FaceColor', color, ... 'XData', 15*sin(-linspace(0,2*pi,30)), ... 'YData', 15*cos(-linspace(0,2*pi,30))); @@ -339,15 +354,17 @@ function draw(obj) % populate with new info for idx = 1:numel(labels) if ds{idx} == 1 - radius = obj.getIdentityRadius(labels{idx}); - color = obj.getIdentityColor(labels{idx}); + label = VisualiserIdentityLocalisation.getShortName(labels{idx}); + radius = obj.getIdentityRadius(label); + color = obj.getIdentityColor(label); + theta = locs{idx}+obj.HeadRotationDegrees; obj.plotTextIdxAtAngle(idx, ... - sprintf('%s (%.0f%%)', labels{idx}, probs{idx}*100), ... - locs{idx}+obj.HeadRotationDegrees, radius, color); + sprintf('%.0f%% %s', probs{idx}*100, label), ... + theta, radius, color); obj.plotMarkerIdxAtAngle(idx, ... - locs{idx}+obj.HeadRotationDegrees, ... + theta, ... probs{idx}, ... color,... radius); @@ -362,7 +379,7 @@ function draw(obj) y1 = obj.INNER_RADIUS; color = [0.9 0.9 0.9]; - for ii=1:13 + for ii=1:numel(obj.idTextHandles) set(obj.idTextHandles(ii), ... 'Color', color, ... 'Position', [y1, y2], ... @@ -372,7 +389,7 @@ function draw(obj) for idx = 1:numel(labels) % if ds{idx} == 1 obj.plotIdTextIdxAtAngle(idx, ... - labels{idx}, ... + VisualiserIdentityLocalisation.getShortName(labels{idx}), ... probs{idx}); % end end @@ -387,7 +404,7 @@ function draw(obj) x2 = radius * sn; y2 = radius * cs; if ischar(numSrcs) - str = sprintf('Attended to "%s" source', numSrcs); + str = ''; %sprintf('Attended to "%s" source', numSrcs); elseif numSrcs > 1 str = sprintf('%d sources', numSrcs); else @@ -400,4 +417,18 @@ function draw(obj) 'String', str); end end + + methods (Static) + function newName = getShortName(label) + if strcmp(label, 'maleSpeech') + newName = 'male'; + elseif strcmp(label, 'femaleSpeech') + newName = 'female'; + elseif strcmp(label, 'femaleScreammaleScream') + newName = 'scream'; + else + newName = label; + end + end + end end diff --git a/BlackboardSystem/src/tools/VisualiserLocalisation.m b/BlackboardSystem/src/tools/VisualiserLocalisation.m index acb16a9..37625cb 100644 --- a/BlackboardSystem/src/tools/VisualiserLocalisation.m +++ b/BlackboardSystem/src/tools/VisualiserLocalisation.m @@ -126,8 +126,8 @@ function reset(obj) end end - function obj = plotMarkerAtAngle(obj,angle,hue) - if nargin < 3 + function obj = plotMarkerAtAngle(obj,angle,str,hue) + if nargin < 4 hue = 100; end sn = sin(-2*pi*angle/360); @@ -141,9 +141,7 @@ function reset(obj) set(obj.MarkerHandle(2), 'FaceColor', col, ... 'XData', x2+15*sin(-linspace(0,2*pi,30)), ... 'YData', y2+15*cos(-linspace(0,2*pi,30))); - %plot([x1 x2], [y1 y2], 'Color', col, 'LineStyle', '--'); - %fill(x2+15*sin(-linspace(0,2*pi,30)),y2+15*cos(-linspace(0,2*pi,30)),col,'linestyle','none'); - %text(x,y,str,'HorizontalAlignment','Center','VerticalAlignment','Middle','fontsize',18,'Color',[1 1 1]); + text(x,y,str,'HorizontalAlignment','Center','VerticalAlignment','Middle','fontsize',18,'Color',[1 1 1]); end function obj = setHeadRotation(obj,val) diff --git a/BlackboardSystem/src/tools/getCommonAFEParams.m b/BlackboardSystem/src/tools/getCommonAFEParams.m new file mode 100644 index 0000000..3fa2c53 --- /dev/null +++ b/BlackboardSystem/src/tools/getCommonAFEParams.m @@ -0,0 +1,21 @@ +function commonParams = getCommonAFEParams( ) + commonParams = {... + 'pp_bNormalizeRMS', true, ... % default is 0 + 'pp_intTimeSecRMS', 4, ... % default is 500E-3 + 'pp_bBinauralRMS', true, ... % default is true + 'fb_type', 'gammatone', ... + 'fb_lowFreqHz', 80, ... + 'fb_highFreqHz', 8000, ... + 'ihc_method', 'halfwave', ... % stream segr. uses 'dau' + 'ild_wSizeSec', 20E-3, ... % DnnLoc uses 20E-3, stream segr. uses 25E-3 + 'ild_hSizeSec', 10E-3, ... + 'rm_wSizeSec', 20E-3, ... % DnnLoc uses 20E-3, identification uses 25E-3 + 'rm_hSizeSec', 10E-3, ... % DO NOT CHANGE -- important for gabor filters + 'rm_scaling', 'power', ... % DnnLoc uses power, identification uses magnitude + 'rm_decaySec', 8E-3, ... + 'cc_wSizeSec', 20E-3, ... % dnnLocKs uses 20E-3, stream segr. uses 25E-3 + 'cc_hSizeSec', 10E-3, ... % dnnLocKs uses 10E-3, stream segr. uses 10E-2 + 'cc_wname', 'hann' ..., + 'cc_maxDelaySec', 1.1E-3,... % default is 1.1E-3, stream segregation will use 1.1E-3 as well + }; +end diff --git a/RoboticPlatform/bass-genom3/bassStruct.idl b/RoboticPlatform/bass-genom3/bassStruct.idl index 07329bc..480daf0 100644 --- a/RoboticPlatform/bass-genom3/bassStruct.idl +++ b/RoboticPlatform/bass-genom3/bassStruct.idl @@ -32,6 +32,12 @@ #define BASSSTRUCT_IDL module binaudio { + + struct timestamp{ + unsigned long sec; + unsigned long nsec; + }; + struct portStruct { unsigned long sampleRate; unsigned long nChunksOnPort; @@ -39,6 +45,7 @@ module binaudio { unsigned long long lastFrameIndex; sequence left; sequence right; + timestamp stamp; }; }; diff --git a/RoboticPlatform/bass-genom3/codels/Ports.c b/RoboticPlatform/bass-genom3/codels/Ports.c index 0d08f75..cdb2ee3 100644 --- a/RoboticPlatform/bass-genom3/codels/Ports.c +++ b/RoboticPlatform/bass-genom3/codels/Ports.c @@ -29,6 +29,7 @@ */ #include +#include #include "Ports.h" #include "bass_c_types.h" @@ -73,6 +74,7 @@ int publishPort(const bass_Audio *Audio, bass_captureStruct *cap, uint32_t fop; /* total amount of Frames On the Port */ uint32_t bps; /* amout of Bytes Per Sample */ int pos, ii; + struct timeval timeNow; data = Audio->data(self); fpc = data->nFramesPerChunk; @@ -93,6 +95,13 @@ int publishPort(const bass_Audio *Audio, bass_captureStruct *cap, } data->lastFrameIndex += fpc; + + // http://docs.ros.org/indigo/api/rostime/html/src_2time_8cpp_source.html + // Lines 108 to 111. + gettimeofday(&timeNow, NULL); + data->stamp.sec = timeNow.tv_sec; + data->stamp.nsec = timeNow.tv_usec*1000; + Audio->write(self); return 0; } diff --git a/Tools/misc/compressAndScale.m b/Tools/misc/compressAndScale.m index c0a06fa..18fc062 100644 --- a/Tools/misc/compressAndScale.m +++ b/Tools/misc/compressAndScale.m @@ -27,14 +27,20 @@ compressor = 1; end if nargin < 3 - scalor = @(x)(0.5); + scalor = 0.5; end if nargin < 4 dim = 0; end if dim == 0 d = sign(d) .* abs(d).^compressor; - dScalor = scalor( d(:) ); + if isnumeric( scalor ) + dScalor = scalor; + elseif isa( scalor, 'function_handle' ) + dScalor = scalor( d(:) ); + else + error( 'AMLTTP:unsupportedUsage', 'scalor has to be a number or a function handle to a function that produces a number.' ); + end if isnan( dScalor ), scale = 1; else scale = 0.5 / dScalor; end; d = d .* repmat( scale, size( d ) ); diff --git a/Tools/misc/isequalDeepCompare.m b/Tools/misc/isequalDeepCompare.m index aef4244..2ef7f61 100644 --- a/Tools/misc/isequalDeepCompare.m +++ b/Tools/misc/isequalDeepCompare.m @@ -1,6 +1,15 @@ function eq = isequalDeepCompare( a, b ) -if ~isequal( class( a ), class( b ) ), eq = false; return; end +if isnumeric( a ) && islogical( b ) + b = feval( class( a ), b ); +elseif isnumeric( b ) && islogical( a ) + a = feval( class( b ), a ); +end + +if ~isequal( class( a ), class( b ) ) + eq = false; + return; +end if isa( a, 'struct' ) na = numel( a ); @@ -8,7 +17,10 @@ if na ~= nb, eq = false; return; end sortedFieldnamesA = sort( fieldnames( a ) ); sortedFieldnamesB = sort( fieldnames( b ) ); - if ~isequal( sortedFieldnamesA, sortedFieldnamesB ), eq = false; return; end + if ~isequal( sortedFieldnamesA, sortedFieldnamesB ) + eq = false; + return; + end for nn = 1 : na for ff = 1 : length( sortedFieldnamesA ) if ~isequalDeepCompare( a(nn).(sortedFieldnamesA{ff}), ... @@ -22,7 +34,10 @@ nb = numel( b ); if na ~= nb, eq = false; return; end for nn = 1 : na - if ~isequalDeepCompare( a{nn}, b{nn} ), eq = false; return; end + if ~isequalDeepCompare( a{nn}, b{nn} ) + eq = false; + return; + end end % elseif isobject( a ) % only compares public (readable) properties % na = numel( a ); @@ -40,7 +55,10 @@ % end % end else - if ~isequal( a, b ), eq = false; return; end + if ~isequal( a, b ) + eq = false; + return; + end end eq = true; diff --git a/Tools/misc/lMoments.m b/Tools/misc/lMoments.m index 1b9dac7..d9e66c4 100644 --- a/Tools/misc/lMoments.m +++ b/Tools/misc/lMoments.m @@ -11,13 +11,17 @@ selectNl = nL; nL = max( nL ); -[rows cols] = size(X); -if cols == 1 X = X'; end -n = length(X); +[nrows,ncols] = size(X); +if ncols == 1 + X = X'; + n = nrows; +else + n = ncols; +end X = sort(X); b = zeros(1,nL-1); l = zeros(1,nL-1); -b0 = mean(X); +b0 = mean(X,2); for r = 1:nL-1 Num = prod(repmat(r+1:n,r,1)-repmat([1:r]',1,n-r),1); @@ -26,10 +30,11 @@ end tB = [b0 b]'; -B = tB(length(tB):-1:1); +lenB = size(tB,1); +B = tB(lenB:-1:1); for i = 1:nL-1 - Spc = zeros(length(B)-(i+1),1); + Spc = zeros(lenB-(i+1),1); Coeff = [Spc ; legendreShiftPoly(i)]; l(i) = sum((Coeff.*B),1); end diff --git a/examples/first_steps/setting_up_an_acoustic_scene/binaural_renderer.xml b/examples/first_steps/setting_up_an_acoustic_scene/binaural_renderer.xml deleted file mode 100644 index 71b01a6..0000000 --- a/examples/first_steps/setting_up_an_acoustic_scene/binaural_renderer.xml +++ /dev/null @@ -1,18 +0,0 @@ - - - - - - - diff --git a/examples/first_steps/setting_up_an_acoustic_scene/brs_renderer.xml b/examples/first_steps/setting_up_an_acoustic_scene/brs_renderer.xml deleted file mode 100644 index 804728f..0000000 --- a/examples/first_steps/setting_up_an_acoustic_scene/brs_renderer.xml +++ /dev/null @@ -1,17 +0,0 @@ - - - - - - - diff --git a/examples/first_steps/setting_up_an_auditory_model/binaural_renderer.xml b/examples/first_steps/setting_up_an_auditory_model/binaural_renderer.xml deleted file mode 100644 index 269a50a..0000000 --- a/examples/first_steps/setting_up_an_auditory_model/binaural_renderer.xml +++ /dev/null @@ -1,19 +0,0 @@ - - - - - - - diff --git a/examples/first_steps/setting_up_an_auditory_model/blackboard.xml b/examples/first_steps/setting_up_an_auditory_model/blackboard.xml deleted file mode 100644 index b832887..0000000 --- a/examples/first_steps/setting_up_an_auditory_model/blackboard.xml +++ /dev/null @@ -1,31 +0,0 @@ - - - - - - - - 1 - - - robotConnect - - - - scheduler - dataConnect - - - dataConnect - loc - - - loc - dec - - - dec - rot - - -