From 84058735cc741a5e81dfe3e162994d25730ca3e0 Mon Sep 17 00:00:00 2001
From: Pablo Zinemanas
Date: Tue, 11 May 2021 11:29:14 +0200
Subject: [PATCH] Fix tensorflow2 compatibility in visualization tools #26

---
 visualization/app.py       |   8 +-
 visualization/callbacks.py | 208 +------------------------------------
 visualization/index.py     |   6 +-
 visualization/layout.py    |  31 ------
 4 files changed, 13 insertions(+), 240 deletions(-)

diff --git a/visualization/app.py b/visualization/app.py
index 836b2ff..93ea46a 100644
--- a/visualization/app.py
+++ b/visualization/app.py
@@ -2,5 +2,9 @@
 import dash_bootstrap_components as dbc
 
 # Define app
-app = dash.Dash(__name__, external_stylesheets=[dbc.themes.SPACELAB],
-                suppress_callback_exceptions=True)
+app = dash.Dash(
+    __name__,
+    external_stylesheets=[dbc.themes.SPACELAB],
+    suppress_callback_exceptions=False
+)
+
diff --git a/visualization/callbacks.py b/visualization/callbacks.py
index a1d8907..c63bdb2 100644
--- a/visualization/callbacks.py
+++ b/visualization/callbacks.py
@@ -25,7 +25,7 @@
 import numpy as np
 import ast
 import soundfile as sf
-from tensorflow import get_default_graph
+from tensorflow.compat.v1 import get_default_graph
 from sklearn.decomposition import PCA
 import base64
 
@@ -97,7 +97,7 @@ def click_on_plot2d(clickData, x_select, y_select):
 )
 # Input('output_select', 'value')],
 def update_plot2D(samples_per_class, x_select, y_select,
-                  active_tab, fold_ix, model_path, dataset_ix, sr):
+                 active_tab, fold_ix, model_path, dataset_ix, sr):
     global X
     global X_pca
     global Y
@@ -432,10 +432,10 @@ def select_model(model_ix):
     model_name = options_models[model_ix]['label']
     model_class = get_available_models()[model_name]
     default_arguments = get_default_args_of_function(model_class.__init__)
-    delete = ['model', 'model_path', 'n_classes',
-              'n_frames_cnn', 'n_freq_cnn']
+    delete = ['model', 'model_path', 'n_classes', 'n_frames_cnn', 'n_freq_cnn', 'n_frames', 'n_freq', 'n_freqs']
     for key in delete:
-        default_arguments.pop(key)
+        if key in default_arguments:
+            default_arguments.pop(key)
     if model_name in params['models']:
         params_model = params['models'][model_name]['model_arguments']
         for key in params_model.keys():
@@ -766,7 +766,6 @@ def start_training(status, fold_ix, normalizer, model_path,
     else:
         raise dash.exceptions.PreventUpdate
 
-    # return [False, "", 'success', ""]
 
 
 @app.callback(
@@ -781,14 +780,6 @@ def manage_button_train(n_intervals, status):
         button_train = "Train model"
     return button_train
 
-
-# @app.callback(
-#     [Output("plot2D_eval", "figure"),
-#      Output("accuracy", "children")],
-#     [Input("tabs", "active_tab")],
-#     [State('fold_name', 'value'),
-#      State('model_path', 'value')]
-# )
 @app.callback(
     [Output("results", "children"),
      Output("figure_metrics", "figure")],
@@ -837,86 +828,10 @@ def evaluate_model(n_clicks, fold_ix, model_path):
     msg = "Accuracy in fold %s is %1.2f" % (fold_name, accuracy)
     return [msg, figure_metrics]
 
-    # fold_name = dataset.fold_list[fold_ix]
-    # exp_folder_fold = conv_path(os.path.join(model_path, fold_name))
-    # scaler_path = os.path.join(exp_folder_fold, 'scaler.pickle')
-    # scaler = load_pickle(scaler_path)
-
-    # data_generator_test = DataGenerator(
-    #     dataset, feature_extractor, folds=['fold_name'],
-    #     batch_size=params['train']['batch_size'],
-    #     shuffle=True, train=False, scaler=scaler
-    # )
-    # X, Y = data_generator_train.get_data()
-
-    # with graph.as_default():
-    #     model_container.load_model_weights(exp_folder_fold)
-    #     model_embeddings = model_container.cut_network(-2)
-    #     X_emb = model_embeddings.predict(X)
-
-    # pca = PCA(n_components=4)
-    # pca.fit(X_emb)
-
-    # X_test, Y_test = data_generator_test.get_data_for_testing(fold_name)
-    # file_names_test = data_generator_test.features_file_list
-    # X_test = scaler.transform(X_test)
-    # with graph.as_default():
-    #     results = model_container.evaluate((X_test, Y_test))
-
-    # # TODO: Convert to one sequence per file.
-
-    # with graph.as_default():
-    #     X_emb = model_embeddings.predict(X_test)
-
-    # predictions = np.zeros_like(Y_test)
-    # for j in range(len(predictions)):
-    #     predictions[j] = np.mean(results['predictions'][j], axis=0)
-    #     Y_test[j] = results['annotations'][j][0]
-
-    # X_pca_test = pca.transform(X_emb)
-
-    # figure2d = generate_figure2D_eval(
-    #     X_pca_test, predictions, Y_test, dataset.label_list
-    # )
-    # return [
-    #     figure2d,
-    #     "Accuracy in fold %s is %f" % (fold_name, results['accuracy'])]
-    return ['Pa']
     # raise dash.exceptions.PreventUpdate
 
 
-@app.callback(
-    [Output('plot_mel_eval', 'figure'),
-     Output('audio-player-eval', 'overrideProps'),
-     Output('predicted', 'children')],
-    [Input('plot2D_eval', 'selectedData')])
-def click_on_plot2d_eval(clickData):
-    if clickData is None:
-        figure_mel = generate_figure_mel(X_test[0])
-        return [figure_mel, {'autoPlay': False, 'src': ''}, ""]
-    else:
-        point = np.array([clickData['points'][0]['x'],
-                          clickData['points'][0]['y']])
-        distances_to_data = np.sum(
-            np.power(X_pca_test[:, [0, 1]] - point, 2), axis=-1)
-        min_distance_index = np.argmin(distances_to_data)
-        audio_file = data_generator_train.convert_features_path_to_audio_path(
-            file_names_test[min_distance_index])
-        audio_data, sr = sf.read(audio_file)
-        figure_mel = generate_figure_mel(X_test[min_distance_index])
-
-        class_ix = np.argmax(Y_test[min_distance_index])
-        pred_ix = np.argmax(predictions[min_distance_index])
-        predicted_text = "%s predicted as %s" % (dataset.label_list[class_ix],
-                                                 dataset.label_list[pred_ix])
-        return [
-            figure_mel,
-            {'autoPlay': True, 'src': encode_audio(audio_data, sr)},
-            predicted_text
-        ]
-
-
 @app.callback(
     [Output('plot_features', 'figure'),
      Output('audio-player-demo', 'overrideProps'),
@@ -1008,116 +923,3 @@ def generate_demo(n_clicks, list_of_contents, fold_ix,
         figure_features,
         {'autoPlay': False, 'src': ""}, ""
     ]
-
-
-# @app.callback(
-#     [Output('plot_features', 'figure'),
-#      Output('audio-player-demo', 'overrideProps'),
-#      Output('demo_file_label', 'children')],
-#     [#Input("tabs", "active_tab"),
-#      Input("btn_run_demo", "n_clicks"),
-#      Input('upload-data', 'contents')],
-#     [State('fold_name', 'value'),
-#      State('model_path', 'value'),
-#      State('plot2D_eval', 'selectedData'),
-#      State('upload-data', 'filename'),
-#      State('upload-data', 'last_modified'),
-#      State('sr', 'value')])
-# def generate_demo(n_clicks, list_of_contents, fold_ix, #active_tab,
-#                   model_path, selectedData,
-#                   list_of_names, list_of_dates, sr):
-#     ctx = dash.callback_context
-#     button_id = ctx.triggered[0]['prop_id'].split('.')[0]
-#     print('demo', button_id)
-#     if button_id == 'btn_run_demo-data':
-#         n_files = len(data_generator_test.features_file_list)
-
-#         ix = np.random.randint(n_files)
-
-#         fold_name = dataset.fold_list[fold_ix]
-#         exp_folder_fold = conv_path(os.path.join(model_path, fold_name))
-
-#         X_features, Y_file = data_generator_test.get_data_of_file(ix)
-
-#         with graph.as_default():
-#             model_container.load_model_weights(exp_folder_fold)
-#             Y_features = model_container.model.predict(X_features)
-
-#         fig_demo = generate_figure_features(
-#             X_features, Y_features, dataset.label_list)
-
-#         features_file = data_generator_test.features_file_list[ix]
-#         audio_file = data_generator_test.convert_features_path_to_audio_path(
-#             features_file, sr=sr)
-#         audio_data, sr = sf.read(audio_file)
-
-#         class_ix = np.argmax(Y_file[0])
-#         file_label = dataset.label_list[class_ix]
-#         return [
-#             fig_demo,
-#             {'autoPlay': False, 'src': encode_audio(audio_data, sr)},
-#             file_label
-#         ]
-
-#     if button_id == 'upload-data':
-#         fold_name = dataset.fold_list[fold_ix]
-#         exp_folder_fold = conv_path(os.path.join(model_path, fold_name))
-#         scaler_path = os.path.join(exp_folder_fold, 'scaler.pickle')
-#         scaler = load_pickle(scaler_path)
-
-#         filename = conv_path('upload.wav')
-#         data = list_of_contents.encode("utf8").split(b";base64,")[1]
-#         with open(filename, "wb") as fp:
-#             fp.write(base64.decodebytes(data))
-
-#         X_feat = feature_extractor.calculate(filename)
-#         X_feat = scaler.transform(X_feat)
-#         with graph.as_default():
-#             Y_t = model_container.model.predict(X_feat)
-
-#         label_list = dataset.label_list
-#         figure_features = generate_figure_features(X_feat, Y_t, label_list)
-#         return [
-#             figure_features,
-#             {'autoPlay': False, 'src': list_of_contents}, ""
-#         ]
-
-#     if active_tab == 'tab_demo':
-#         fold_name = dataset.fold_list[fold_ix]
-#         if (button_id == 'tabs') and (selectedData is not None):
-#             point = np.array([selectedData['points'][0]['x'],
-#                               selectedData['points'][0]['y']])
-#             distances_to_data = np.sum(
-#                 np.power(X_pca_test[:, [0, 1]] - point, 2), axis=-1)
-#             ix = np.argmin(distances_to_data)
-#         else:
-#             ix = np.random.randint(len(data_generator.data[fold_name]['X']))
-
-#         exp_folder_fold = conv_path(os.path.join(model_path, fold_name))
-#         scaler_path = os.path.join(exp_folder_fold, 'scaler.pickle')
-#         scaler = load_pickle(scaler_path)
-
-#         X_features = data_generator.data[fold_name]['X'][ix]
-#         X_features = scaler.transform(X_features)
-
-#         with graph.as_default():
-#             model_container.load_model_weights(exp_folder_fold)
-#             Y_features = model_container.model.predict(X_features)
-
-#         fig_demo = generate_figure_features(
-#             X_features, Y_features, dataset.label_list)
-
-#         features_file = dataset.file_lists[fold_name][ix]
-#         audio_file = data_generator.convert_features_path_to_audio_path(
-#             features_file)
-#         audio_data, sr = sf.read(audio_file)
-
-#         Y_file = data_generator.data[fold_name]['Y'][ix][0]
-#         class_ix = np.argmax(Y_file)
-#         file_label = dataset.label_list[class_ix]
-#         return [
-#             fig_demo,
-#             {'autoPlay': False, 'src': encode_audio(audio_data, sr)},
-#             file_label
-#         ]
-
-#     raise dash.exceptions.PreventUpdate
diff --git a/visualization/index.py b/visualization/index.py
index 9f0a056..26131e1 100644
--- a/visualization/index.py
+++ b/visualization/index.py
@@ -1,13 +1,11 @@
 from .app import app
 from .layout import layout
-from . import callbacks
-
-import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 
 # Define layout
 app.layout = layout
 
+from . import callbacks
+
 # Run the app
 if (__name__ == '__main__'):
     app.run_server(debug=False)
diff --git a/visualization/layout.py b/visualization/layout.py
index a61cf29..d5ac20b 100644
--- a/visualization/layout.py
+++ b/visualization/layout.py
@@ -471,37 +471,6 @@
     ),
 ])
 
-# Plot 2D graph
-# plot2D_eval = dcc.Graph(id='plot2D_eval', figure=figure2D,
-#                         style={"height": "100%", "width": "100%"})
-
-# # Plot mel-spectrogram
-# plot_mel_eval = dcc.Graph(
-#     id="plot_mel_eval",
-#     figure=figure_mel,
-#     style={"width": "90%", "display": "inline-block", 'float': 'left'}
-# )
-
-# # Audio controls
-# audio_player_eval = dash_audio_components.DashAudioComponents(
-#     id='audio-player-eval', src="", autoPlay=False, controls=True
-# )
-
-# Define Tab Evaluation (4)
-# tab_evaluation = html.Div([
-#     dbc.Row(
-#         [
-#             dbc.Col(html.Div([plot2D_eval]), width=8),
-#             dbc.Col([
-#                 dbc.Row([plot_mel_eval], align='center'),
-#                 dbc.Row([audio_player_eval], align='center'),
-#                 dbc.Row([html.Div("", id="accuracy")], align='center'),
-#                 dbc.Row([html.Div("", id="predicted")], align='center')
-#             ], width=4),
-#         ]
-#     ),
-# ])
-
 figure_metrics = generate_figure_metrics([], [])
 plot_metrics = dcc.Graph(
     id="figure_metrics",
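
Note on the fix (an editor's sketch, not part of the patch): TF2 removed the
top-level tensorflow.get_default_graph, but the same function survives as
tensorflow.compat.v1.get_default_graph, which is why callbacks.py only needs
its import line changed. The usual reason for this pattern is that the
TF1-style default graph stack is thread-local while Dash serves callbacks
from Flask worker threads, so a model loaded at import time must have its
graph re-entered before predict() runs inside a callback; that is what the
"with graph.as_default():" blocks kept in callbacks.py do. A minimal sketch
of the pattern follows; the try/except fallback for very old TF1 installs is
an illustrative assumption, since the patch itself only rewrites the import:

    import threading

    # TF 1.13+ and all TF2 releases expose the function under compat.v1;
    # the fallback covers older TF1 versions that only have the top-level
    # name. (The fallback is illustrative, not part of the patch.)
    try:
        from tensorflow.compat.v1 import get_default_graph
    except ImportError:
        from tensorflow import get_default_graph

    # Capture the default graph once, on the thread that loads the model,
    # exactly as callbacks.py does with "graph = get_default_graph()".
    graph = get_default_graph()

    def dash_callback_worker():
        # Re-enter the captured graph before running TF1-style predictions
        # from a worker thread; skipping this step is what triggers the
        # classic "Tensor is not an element of this graph" error.
        with graph.as_default():
            assert get_default_graph() is graph

    t = threading.Thread(target=dash_callback_worker)
    t.start()
    t.join()

The index.py change is related housekeeping: "from . import callbacks" now
runs after "app.layout = layout", the standard multi-file Dash order in which
the layout exists before the callbacks that reference it are registered, and
the hard-coded CUDA_VISIBLE_DEVICES override is dropped so the visualization
no longer pins itself to a specific GPU.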