Skip to content

Commit

Permalink
Fix tensorflow2 compatibility in visualization tools #26
Browse files Browse the repository at this point in the history
  • Loading branch information
pzinemanas committed May 11, 2021
1 parent 2bbc781 commit 8405873
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 240 deletions.
8 changes: 6 additions & 2 deletions visualization/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,9 @@
import dash_bootstrap_components as dbc

# Define app
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.SPACELAB],
suppress_callback_exceptions=True)
app = dash.Dash(
__name__,
external_stylesheets=[dbc.themes.SPACELAB],
suppress_callback_exceptions=False
)

208 changes: 5 additions & 203 deletions visualization/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import numpy as np
import ast
import soundfile as sf
from tensorflow import get_default_graph
from tensorflow.compat.v1 import get_default_graph
from sklearn.decomposition import PCA
import base64

Expand Down Expand Up @@ -97,7 +97,7 @@ def click_on_plot2d(clickData, x_select, y_select):
)
# Input('output_select', 'value')],
def update_plot2D(samples_per_class, x_select, y_select,
active_tab, fold_ix, model_path, dataset_ix, sr):
active_tab, fold_ix, model_path, dataset_ix, sr):
global X
global X_pca
global Y
Expand Down Expand Up @@ -432,10 +432,10 @@ def select_model(model_ix):
model_name = options_models[model_ix]['label']
model_class = get_available_models()[model_name]
default_arguments = get_default_args_of_function(model_class.__init__)
delete = ['model', 'model_path', 'n_classes',
'n_frames_cnn', 'n_freq_cnn']
delete = ['model', 'model_path', 'n_classes', 'n_frames_cnn', 'n_freq_cnn', 'n_frames', 'n_freq', 'n_freqs']
for key in delete:
default_arguments.pop(key)
if key in default_arguments:
default_arguments.pop(key)
if model_name in params['models']:
params_model = params['models'][model_name]['model_arguments']
for key in params_model.keys():
Expand Down Expand Up @@ -766,7 +766,6 @@ def start_training(status, fold_ix, normalizer, model_path,

else:
raise dash.exceptions.PreventUpdate
# return [False, "", 'success', ""]


@app.callback(
Expand All @@ -781,14 +780,6 @@ def manage_button_train(n_intervals, status):
button_train = "Train model"
return button_train


# @app.callback(
# [Output("plot2D_eval", "figure"),
# Output("accuracy", "children")],
# [Input("tabs", "active_tab")],
# [State('fold_name', 'value'),
# State('model_path', 'value')]
# )
@app.callback(
[Output("results", "children"),
Output("figure_metrics", "figure")],
Expand Down Expand Up @@ -837,86 +828,10 @@ def evaluate_model(n_clicks, fold_ix, model_path):
msg = "Accuracy in fold %s is %1.2f" % (fold_name, accuracy)
return [msg, figure_metrics]

# fold_name = dataset.fold_list[fold_ix]
# exp_folder_fold = conv_path(os.path.join(model_path, fold_name))
# scaler_path = os.path.join(exp_folder_fold, 'scaler.pickle')
# scaler = load_pickle(scaler_path)

# data_generator_test = DataGenerator(
# dataset, feature_extractor, folds=['fold_name'],
# batch_size=params['train']['batch_size'],
# shuffle=True, train=False, scaler=scaler
# )
# X, Y = data_generator_train.get_data()

# with graph.as_default():
# model_container.load_model_weights(exp_folder_fold)
# model_embeddings = model_container.cut_network(-2)
# X_emb = model_embeddings.predict(X)

# pca = PCA(n_components=4)
# pca.fit(X_emb)

# X_test, Y_test = data_generator_test.get_data_for_testing(fold_name)
# file_names_test = data_generator_test.features_file_list
# X_test = scaler.transform(X_test)
# with graph.as_default():
# results = model_container.evaluate((X_test, Y_test))

# # TODO: Convert to one sequence per file.

# with graph.as_default():
# X_emb = model_embeddings.predict(X_test)

# predictions = np.zeros_like(Y_test)
# for j in range(len(predictions)):
# predictions[j] = np.mean(results['predictions'][j], axis=0)
# Y_test[j] = results['annotations'][j][0]

# X_pca_test = pca.transform(X_emb)

# figure2d = generate_figure2D_eval(
# X_pca_test, predictions, Y_test, dataset.label_list
# )
# return [
# figure2d,
# "Accuracy in fold %s is %f" % (fold_name, results['accuracy'])]

return ['Pa']
# raise dash.exceptions.PreventUpdate


@app.callback(
[Output('plot_mel_eval', 'figure'),
Output('audio-player-eval', 'overrideProps'),
Output('predicted', 'children')],
[Input('plot2D_eval', 'selectedData')])
def click_on_plot2d_eval(clickData):
if clickData is None:
figure_mel = generate_figure_mel(X_test[0])
return [figure_mel, {'autoPlay': False, 'src': ''}, ""]
else:
point = np.array([clickData['points'][0]['x'],
clickData['points'][0]['y']])
distances_to_data = np.sum(
np.power(X_pca_test[:, [0, 1]] - point, 2), axis=-1)
min_distance_index = np.argmin(distances_to_data)
audio_file = data_generator_train.convert_features_path_to_audio_path(
file_names_test[min_distance_index])
audio_data, sr = sf.read(audio_file)
figure_mel = generate_figure_mel(X_test[min_distance_index])

class_ix = np.argmax(Y_test[min_distance_index])
pred_ix = np.argmax(predictions[min_distance_index])
predicted_text = "%s predicted as %s" % (dataset.label_list[class_ix],
dataset.label_list[pred_ix])
return [
figure_mel,
{'autoPlay': True, 'src': encode_audio(audio_data, sr)},
predicted_text
]


@app.callback(
[Output('plot_features', 'figure'),
Output('audio-player-demo', 'overrideProps'),
Expand Down Expand Up @@ -1008,116 +923,3 @@ def generate_demo(n_clicks, list_of_contents, fold_ix,
figure_features,
{'autoPlay': False, 'src': ""}, ""
]

# @app.callback(
# [Output('plot_features', 'figure'),
# Output('audio-player-demo', 'overrideProps'),
# Output('demo_file_label', 'children')],
# [#Input("tabs", "active_tab"),
# Input("btn_run_demo", "n_clicks"),
# Input('upload-data', 'contents')],
# [State('fold_name', 'value'),
# State('model_path', 'value'),
# State('plot2D_eval', 'selectedData'),
# State('upload-data', 'filename'),
# State('upload-data', 'last_modified'),
# State('sr', 'value')])
# def generate_demo(n_clicks, list_of_contents, fold_ix, #active_tab,
# model_path, selectedData,
# list_of_names, list_of_dates, sr):
# ctx = dash.callback_context
# button_id = ctx.triggered[0]['prop_id'].split('.')[0]
# print('demo', button_id)
# if button_id == 'btn_run_demo-data':
# n_files = len(data_generator_test.features_file_list)

# ix = np.random.randint(n_files)

# fold_name = dataset.fold_list[fold_ix]
# exp_folder_fold = conv_path(os.path.join(model_path, fold_name))

# X_features, Y_file = data_generator_test.get_data_of_file(ix)

# with graph.as_default():
# model_container.load_model_weights(exp_folder_fold)
# Y_features = model_container.model.predict(X_features)

# fig_demo = generate_figure_features(
# X_features, Y_features, dataset.label_list)

# features_file = data_generator_test.features_file_list[ix]
# audio_file = data_generator_test.convert_features_path_to_audio_path(
# features_file, sr=sr)
# audio_data, sr = sf.read(audio_file)

# class_ix = np.argmax(Y_file[0])
# file_label = dataset.label_list[class_ix]
# return [
# fig_demo,
# {'autoPlay': False, 'src': encode_audio(audio_data, sr)},
# file_label
# ]

# if button_id == 'upload-data':
# fold_name = dataset.fold_list[fold_ix]
# exp_folder_fold = conv_path(os.path.join(model_path, fold_name))
# scaler_path = os.path.join(exp_folder_fold, 'scaler.pickle')
# scaler = load_pickle(scaler_path)

# filename = conv_path('upload.wav')
# data = list_of_contents.encode("utf8").split(b";base64,")[1]
# with open(filename, "wb") as fp:
# fp.write(base64.decodebytes(data))

# X_feat = feature_extractor.calculate(filename)
# X_feat = scaler.transform(X_feat)
# with graph.as_default():
# Y_t = model_container.model.predict(X_feat)

# label_list = dataset.label_list
# figure_features = generate_figure_features(X_feat, Y_t, label_list)
# return [
# figure_features,
# {'autoPlay': False, 'src': list_of_contents}, ""
# ]

# if active_tab == 'tab_demo':
# fold_name = dataset.fold_list[fold_ix]
# if (button_id == 'tabs') and (selectedData is not None):
# point = np.array([selectedData['points'][0]['x'],
# selectedData['points'][0]['y']])
# distances_to_data = np.sum(
# np.power(X_pca_test[:, [0, 1]] - point, 2), axis=-1)
# ix = np.argmin(distances_to_data)
# else:
# ix = np.random.randint(len(data_generator.data[fold_name]['X']))

# exp_folder_fold = conv_path(os.path.join(model_path, fold_name))
# scaler_path = os.path.join(exp_folder_fold, 'scaler.pickle')
# scaler = load_pickle(scaler_path)

# X_features = data_generator.data[fold_name]['X'][ix]
# X_features = scaler.transform(X_features)

# with graph.as_default():
# model_container.load_model_weights(exp_folder_fold)
# Y_features = model_container.model.predict(X_features)

# fig_demo = generate_figure_features(
# X_features, Y_features, dataset.label_list)

# features_file = dataset.file_lists[fold_name][ix]
# audio_file = data_generator.convert_features_path_to_audio_path(
# features_file)
# audio_data, sr = sf.read(audio_file)

# Y_file = data_generator.data[fold_name]['Y'][ix][0]
# class_ix = np.argmax(Y_file)
# file_label = dataset.label_list[class_ix]
# return [
# fig_demo,
# {'autoPlay': False, 'src': encode_audio(audio_data, sr)},
# file_label
# ]

# raise dash.exceptions.PreventUpdate
6 changes: 2 additions & 4 deletions visualization/index.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from .app import app
from .layout import layout
from . import callbacks

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Define layout
app.layout = layout

from . import callbacks

# Run the app
if (__name__ == '__main__'):
app.run_server(debug=False)
31 changes: 0 additions & 31 deletions visualization/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,37 +471,6 @@
),
])

# Plot 2D graph
# plot2D_eval = dcc.Graph(id='plot2D_eval', figure=figure2D,
# style={"height": "100%", "width": "100%"})

# # Plot mel-spectrogram
# plot_mel_eval = dcc.Graph(
# id="plot_mel_eval",
# figure=figure_mel,
# style={"width": "90%", "display": "inline-block", 'float': 'left'}
# )

# # Audio controls
# audio_player_eval = dash_audio_components.DashAudioComponents(
# id='audio-player-eval', src="", autoPlay=False, controls=True
# )

# Define Tab Evaluation (4)
# tab_evaluation = html.Div([
# dbc.Row(
# [
# dbc.Col(html.Div([plot2D_eval]), width=8),
# dbc.Col([
# dbc.Row([plot_mel_eval], align='center'),
# dbc.Row([audio_player_eval], align='center'),
# dbc.Row([html.Div("", id="accuracy")], align='center'),
# dbc.Row([html.Div("", id="predicted")], align='center')
# ], width=4),
# ]
# ),
# ])

figure_metrics = generate_figure_metrics([], [])
plot_metrics = dcc.Graph(
id="figure_metrics",
Expand Down

0 comments on commit 8405873

Please sign in to comment.