From d7a7a0ceb8405b8919316a6fde87416f50b4505f Mon Sep 17 00:00:00 2001 From: robin Date: Thu, 26 Jan 2023 16:59:58 +0100 Subject: [PATCH] Update Vedo and SSD versions. --- .../FC/Environment/ArmadilloInteractive.py | 80 ++++++++-------- examples/demos/Armadillo/FC/prediction.py | 3 +- examples/demos/Armadillo/UNet/prediction.py | 3 +- .../Beam/FC/Environment/BeamInteractive.py | 76 ++++++++------- examples/demos/Beam/FC/prediction.py | 3 +- examples/demos/Beam/UNet/prediction.py | 3 +- .../Liver/FC/Environment/LiverInteractive.py | 74 +++++++-------- examples/demos/Liver/FC/prediction.py | 3 +- examples/demos/Liver/UNet/prediction.py | 3 +- examples/features/dataGeneration_multi.py | 8 +- examples/features/dataGeneration_single.py | 4 +- examples/features/gradientDescent.py | 3 +- examples/features/onlineTraining.py | 3 +- examples/features/prediction.py | 3 +- src/Core/AsyncSocket/AbstractEnvironment.py | 13 ++- src/Core/AsyncSocket/TcpIpClient.py | 13 ++- src/Core/AsyncSocket/TcpIpServer.py | 16 ++++ src/Core/Environment/BaseEnvironment.py | 31 ++++--- src/Core/Environment/BaseEnvironmentConfig.py | 7 +- src/Core/Manager/EnvironmentManager.py | 62 +++++-------- src/Core/Network/BaseNetworkConfig.py | 4 +- src/Core/Visualization/VedoFactory.py | 67 ------------- src/Core/Visualization/VedoVisualizer.py | 93 ------------------- src/Core/Visualization/__init__.py | 0 24 files changed, 211 insertions(+), 364 deletions(-) delete mode 100644 src/Core/Visualization/VedoFactory.py delete mode 100644 src/Core/Visualization/VedoVisualizer.py delete mode 100644 src/Core/Visualization/__init__.py diff --git a/examples/demos/Armadillo/FC/Environment/ArmadilloInteractive.py b/examples/demos/Armadillo/FC/Environment/ArmadilloInteractive.py index ba9e071c..531fc298 100644 --- a/examples/demos/Armadillo/FC/Environment/ArmadilloInteractive.py +++ b/examples/demos/Armadillo/FC/Environment/ArmadilloInteractive.py @@ -50,6 +50,8 @@ def __init__(self, self.selected = None self.interactive_window = True self.mouse_factor = 10 * p_model.scale + self.key_on = False + self.click_on = False # Force fields self.arrows = None @@ -89,7 +91,7 @@ def create(self): self.areas[-1].append(i) # Create sphere at initial state self.spheres_init.append(self.mesh_coarse.points(self.areas[-1]).mean(axis=0)) - self.spheres.append(self.sphere(self.spheres_init[-1])) + self.spheres.append(self.sphere(self.spheres_init[-1]).alpha(0.5)) # Define fixed plane mesh_y = self.mesh.points()[:, 1] @@ -105,16 +107,13 @@ def create(self): self.plotter.add(self.mesh) self.plotter.add(Plane(pos=plane_origin, normal=[0, 1, 0], s=(10 * p_model.scale, 10 * p_model.scale), c='darkred', alpha=0.2)) - self.plotter.add(Text2D("Press 'Alt' to interact with the object.\n" - "Left click to select a sphere.\n" - "Right click to unselect a sphere.", s=0.75)) + self.plotter.add(Text2D("Press 'b' to interact with the spheres / with the environment.\n" + "Left click to select / unselect a sphere.", s=0.75)) # Add callbacks - self.plotter.addCallback('KeyPress', self.key_press) - self.plotter.addCallback('KeyRelease', self.key_release) - self.plotter.addCallback('LeftButtonPress', self.left_button_press) - self.plotter.addCallback('RightButtonPress', self.right_button_press) - self.plotter.addCallback('MouseMove', self.mouse_move) + self.plotter.add_callback('KeyPress', self.key_press) + self.plotter.add_callback('LeftButtonPress', self.left_button_press) + self.plotter.add_callback('MouseMove', self.mouse_move) async def step(self): @@ -127,42 +126,39 @@ async def step(self): def key_press(self, evt): # Only react to an 'Alt' press - if 'alt' in evt.keyPressed.lower(): - # Switch from environment to object interaction - self.plotter.interactor.SetInteractorStyle(vtk.vtkInteractorStyle3D()) - self.interactive_window = False - - def key_release(self, evt): - - # Only react with an 'Alt' release - if 'alt' in evt.keyPressed.lower(): - # Switch from object to environment interaction - self.plotter.interactor.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera()) - self.interactive_window = True - self.selected = None - # Reset all - self.update_mesh() - self.update_arrows() - self.update_spheres() + if 'b' in evt.keyPressed.lower(): + self.key_on = not self.key_on + if self.key_on: + # Switch from environment to object interaction + self.plotter.interactor.SetInteractorStyle(vtk.vtkInteractorStyle3D()) + self.interactive_window = False + self.update_spheres() + else: + # Switch from object to environment interaction + self.plotter.interactor.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera()) + self.interactive_window = True + self.selected = None + # Reset all + self.update_mesh() + self.update_arrows() + self.update_spheres() def left_button_press(self, evt): # Select a sphere only in object interaction mode if not self.interactive_window: - # Pick a unique sphere - if evt.actor in self.spheres: - self.selected = self.spheres.index(evt.actor) - self.update_spheres(center=self.spheres_init[self.selected]) - - def right_button_press(self, evt): - - # Unselect a sphere only in object interaction mode - if not self.interactive_window: - self.selected = None - # Reset all - self.update_mesh() - self.update_arrows() - self.update_spheres() + self.click_on = not self.click_on + if self.click_on: + # Pick a unique sphere + if evt.actor in self.spheres: + self.selected = self.spheres.index(evt.actor) + self.update_spheres(center=self.spheres_init[self.selected]) + else: + self.selected = None + # Reset all + self.update_mesh() + self.update_arrows() + self.update_spheres() def mouse_move(self, evt): @@ -215,7 +211,9 @@ def update_spheres(self, center=None): # Remove actual spheres self.plotter.remove(*self.spheres) # If no center provided, reset all the spheres - if center is None: + if self.interactive_window: + self.spheres = [self.sphere(c).alpha(0.5) for c in self.spheres_init] + elif center is None: self.spheres = [self.sphere(c) for c in self.spheres_init] # Otherwise, update the selected cell else: diff --git a/examples/demos/Armadillo/FC/prediction.py b/examples/demos/Armadillo/FC/prediction.py index 65d40446..8ee02bfb 100644 --- a/examples/demos/Armadillo/FC/prediction.py +++ b/examples/demos/Armadillo/FC/prediction.py @@ -11,7 +11,6 @@ from DeepPhysX.Core.Pipelines.BasePrediction import BasePrediction from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig from DeepPhysX.Core.Database.BaseDatabaseConfig import BaseDatabaseConfig -from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer from DeepPhysX.Torch.FC.FCConfig import FCConfig # Session related imports @@ -26,7 +25,7 @@ def launch_runner(): # Environment config environment_config = BaseEnvironmentConfig(environment_class=Armadillo, - visualizer=VedoVisualizer) + visualizer='vedo') # FC config nb_hidden_layers = 3 diff --git a/examples/demos/Armadillo/UNet/prediction.py b/examples/demos/Armadillo/UNet/prediction.py index 9752f206..f66c544e 100644 --- a/examples/demos/Armadillo/UNet/prediction.py +++ b/examples/demos/Armadillo/UNet/prediction.py @@ -11,7 +11,6 @@ from DeepPhysX.Core.Pipelines.BasePrediction import BasePrediction from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig from DeepPhysX.Core.Database.BaseDatabaseConfig import BaseDatabaseConfig -from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer from DeepPhysX.Torch.UNet.UNetConfig import UNetConfig # Session related imports @@ -26,7 +25,7 @@ def launch_runner(): # Environment config environment_config = BaseEnvironmentConfig(environment_class=Armadillo, - visualizer=VedoVisualizer) + visualizer='vedo') # UNet config network_config = UNetConfig(input_size=grid_resolution, diff --git a/examples/demos/Beam/FC/Environment/BeamInteractive.py b/examples/demos/Beam/FC/Environment/BeamInteractive.py index 14f9a069..7eac3b4d 100644 --- a/examples/demos/Beam/FC/Environment/BeamInteractive.py +++ b/examples/demos/Beam/FC/Environment/BeamInteractive.py @@ -45,6 +45,8 @@ def __init__(self, self.selected = None self.interactive_window = True self.mouse_factor = 10 + self.key_on = False + self.click_on = False # Force fields self.arrows = None @@ -87,7 +89,7 @@ def create(self): for z in (np.min(self.mesh_init[:, 2]), np.max(self.mesh_init[:, 2])): center = [x, y, z] self.spheres_init.append(self.mesh_init.copy().tolist().index(center)) - self.spheres.append(self.sphere(self.mesh.points()[self.spheres_init[-1]])) + self.spheres.append(self.sphere(self.mesh.points()[self.spheres_init[-1]]).alpha(0.5)) x_min, x_max = x - sx / 4, x + sx / 4 y_min, y_max = y - sy / 2, y + sy / 2 z_min, z_max = z - sz / 2, z + sz / 2 @@ -101,7 +103,7 @@ def create(self): center = [sx / 2, sy / 2, sz / 2] center[i] = a self.spheres_init.append(self.mesh_init.copy().tolist().index(center)) - self.spheres.append(self.sphere(self.mesh.points()[self.spheres_init[-1]])) + self.spheres.append(self.sphere(self.mesh.points()[self.spheres_init[-1]]).alpha(0.5)) other = [0, 1, 2] other.remove(i) o0_min, o0_max = np.min(self.mesh_init[:, other[0]]), np.max(self.mesh_init[:, other[0]]) @@ -116,15 +118,12 @@ def create(self): self.plotter.add(*self.spheres) self.plotter.add(self.mesh) self.plotter.add(Plane(pos=[0., 0., 0.], normal=[1, 0, 0], s=(20, 20), c='darkred', alpha=0.2)) - self.plotter.add(Text2D("Press 'Alt' to interact with the object.\n" - "Left click to select a sphere.\n" - "Right click to unselect a sphere.", s=0.75)) + self.plotter.add(Text2D("Press 'b' to interact with the spheres / with the environment.\n" + "Left click to select / unselect a sphere.", s=0.75)) # Add callbacks self.plotter.addCallback('KeyPress', self.key_press) - self.plotter.addCallback('KeyRelease', self.key_release) self.plotter.addCallback('LeftButtonPress', self.left_button_press) - self.plotter.addCallback('RightButtonPress', self.right_button_press) self.plotter.addCallback('MouseMove', self.mouse_move) async def step(self): @@ -138,42 +137,39 @@ async def step(self): def key_press(self, evt): # Only react to an 'Alt' press - if 'alt' in evt.keyPressed.lower(): - # Switch from environment to object interaction - self.plotter.interactor.SetInteractorStyle(vtk.vtkInteractorStyle3D()) - self.interactive_window = False - - def key_release(self, evt): - - # Only react with an 'Alt' release - if 'alt' in evt.keyPressed.lower(): - # Switch from object to environment interaction - self.plotter.interactor.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera()) - self.interactive_window = True - self.selected = None - # Reset all - self.update_mesh() - self.update_arrows() - self.update_spheres() + if 'b' in evt.keyPressed.lower(): + self.key_on = not self.key_on + if self.key_on: + # Switch from environment to object interaction + self.plotter.interactor.SetInteractorStyle(vtk.vtkInteractorStyle3D()) + self.interactive_window = False + self.update_spheres() + else: + # Switch from object to environment interaction + self.plotter.interactor.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera()) + self.interactive_window = True + self.selected = None + # Reset all + self.update_mesh() + self.update_arrows() + self.update_spheres() def left_button_press(self, evt): # Select a sphere only in object interaction mode if not self.interactive_window: - # Pick a unique sphere - if evt.actor in self.spheres: - self.selected = self.spheres.index(evt.actor) - self.update_spheres(center=self.spheres_init[self.selected]) - - def right_button_press(self, evt): - - # Unselect a sphere only in object interaction mode - if not self.interactive_window: - self.selected = None - # Reset all - self.update_mesh() - self.update_arrows() - self.update_spheres() + self.click_on = not self.click_on + if self.click_on: + # Pick a unique sphere + if evt.actor in self.spheres: + self.selected = self.spheres.index(evt.actor) + self.update_spheres(center=self.spheres_init[self.selected]) + else: + self.selected = None + # Reset all + self.update_mesh() + self.update_arrows() + self.update_spheres() def mouse_move(self, evt): @@ -224,7 +220,9 @@ def update_spheres(self, center=None): # Remove actual spheres self.plotter.remove(*self.spheres) # If no center provided, reset all the spheres - if center is None: + if self.interactive_window: + self.spheres = [self.sphere(self.mesh_init[c]).alpha(0.5) for c in self.spheres_init] + elif center is None: self.spheres = [self.sphere(self.mesh_init[c]) for c in self.spheres_init] # Otherwise, update the selected cell else: diff --git a/examples/demos/Beam/FC/prediction.py b/examples/demos/Beam/FC/prediction.py index c0fc1a70..477ac104 100644 --- a/examples/demos/Beam/FC/prediction.py +++ b/examples/demos/Beam/FC/prediction.py @@ -11,7 +11,6 @@ from DeepPhysX.Core.Pipelines.BasePrediction import BasePrediction from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig from DeepPhysX.Core.Database.BaseDatabaseConfig import BaseDatabaseConfig -from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer from DeepPhysX.Torch.FC.FCConfig import FCConfig # Session related imports @@ -26,7 +25,7 @@ def launch_runner(): # Environment config environment_config = BaseEnvironmentConfig(environment_class=Beam, - visualizer=VedoVisualizer) + visualizer='vedo') # FC config nb_hidden_layers = 3 diff --git a/examples/demos/Beam/UNet/prediction.py b/examples/demos/Beam/UNet/prediction.py index 8f194c29..fef275db 100644 --- a/examples/demos/Beam/UNet/prediction.py +++ b/examples/demos/Beam/UNet/prediction.py @@ -11,7 +11,6 @@ from DeepPhysX.Core.Pipelines.BasePrediction import BasePrediction from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig from DeepPhysX.Core.Database.BaseDatabaseConfig import BaseDatabaseConfig -from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer from DeepPhysX.Torch.UNet.UNetConfig import UNetConfig # Session related imports @@ -26,7 +25,7 @@ def launch_runner(): # Environment config environment_config = BaseEnvironmentConfig(environment_class=Beam, - visualizer=VedoVisualizer) + visualizer='vedo') # UNet config network_config = UNetConfig(input_size=grid_resolution, diff --git a/examples/demos/Liver/FC/Environment/LiverInteractive.py b/examples/demos/Liver/FC/Environment/LiverInteractive.py index 81fed433..66cc1013 100644 --- a/examples/demos/Liver/FC/Environment/LiverInteractive.py +++ b/examples/demos/Liver/FC/Environment/LiverInteractive.py @@ -50,6 +50,8 @@ def __init__(self, self.selected = None self.interactive_window = True self.mouse_factor = 0.1 + self.key_on = False + self.click_on = False # Force fields self.arrows = None @@ -85,7 +87,7 @@ def create(self): for center in self.mesh_coarse.points()[self.spheres_init]: self.areas.append(np.argwhere(np.sum(np.power(self.mesh_coarse.points() - center, 2), axis=1) <= np.power(radius, 2))) - self.spheres.append(self.sphere(center)) + self.spheres.append(self.sphere(center).alpha(0.5)) # Define boundaries origin, corner = np.array(p_model.boundaries[:3]), np.array(p_model.boundaries[3:]) @@ -98,15 +100,12 @@ def create(self): self.plotter.add(*self.spheres) self.plotter.add(box) self.plotter.add(self.mesh) - self.plotter.add(Text2D("Press 'Alt' to interact with the object.\n" - "Left click to select a sphere.\n" - "Right click to unselect a sphere.", s=0.75)) + self.plotter.add(Text2D("Press 'b' to interact with the spheres / with the environment.\n" + "Left click to select / unselect a sphere.", s=0.75)) # Add callbacks self.plotter.addCallback('KeyPress', self.key_press) - self.plotter.addCallback('KeyRelease', self.key_release) self.plotter.addCallback('LeftButtonPress', self.left_button_press) - self.plotter.addCallback('RightButtonPress', self.right_button_press) self.plotter.addCallback('MouseMove', self.mouse_move) async def step(self): @@ -120,42 +119,39 @@ async def step(self): def key_press(self, evt): # Only react to an 'Alt' press - if 'alt' in evt.keyPressed.lower(): - # Switch from environment to object interaction - self.plotter.interactor.SetInteractorStyle(vtk.vtkInteractorStyle3D()) - self.interactive_window = False - - def key_release(self, evt): - - # Only react with an 'Alt' release - if 'alt' in evt.keyPressed.lower(): - # Switch from object to environment interaction - self.plotter.interactor.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera()) - self.interactive_window = True - self.selected = None - # Reset all - self.update_mesh() - self.update_arrows() - self.update_spheres() + if 'b' in evt.keyPressed.lower(): + self.key_on = not self.key_on + if self.key_on: + # Switch from environment to object interaction + self.plotter.interactor.SetInteractorStyle(vtk.vtkInteractorStyle3D()) + self.interactive_window = False + self.update_spheres() + else: + # Switch from object to environment interaction + self.plotter.interactor.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera()) + self.interactive_window = True + self.selected = None + # Reset all + self.update_mesh() + self.update_arrows() + self.update_spheres() def left_button_press(self, evt): # Select a sphere only in object interaction mode if not self.interactive_window: - # Pick a unique sphere - if evt.actor in self.spheres: - self.selected = self.spheres.index(evt.actor) - self.update_spheres(center=self.mesh_coarse.points()[self.spheres_init[self.selected]]) - - def right_button_press(self, evt): - - # Unselect a sphere only in object interaction mode - if not self.interactive_window: - self.selected = None - # Reset all - self.update_mesh() - self.update_arrows() - self.update_spheres() + self.click_on = not self.click_on + if self.click_on: + # Pick a unique sphere + if evt.actor in self.spheres: + self.selected = self.spheres.index(evt.actor) + self.update_spheres(center=self.mesh_coarse.points()[self.spheres_init[self.selected]]) + else: + self.selected = None + # Reset all + self.update_mesh() + self.update_arrows() + self.update_spheres() def mouse_move(self, evt): @@ -207,7 +203,9 @@ def update_spheres(self, center=None): # Remove actual spheres self.plotter.remove(*self.spheres) # If no center provided, reset all the spheres - if center is None: + if self.interactive_window: + self.spheres = [self.sphere(c).alpha(0.5) for c in self.mesh_coarse.points()[self.spheres_init]] + elif center is None: self.spheres = [self.sphere(c) for c in self.mesh_coarse.points()[self.spheres_init]] # Otherwise, update the selected cell else: diff --git a/examples/demos/Liver/FC/prediction.py b/examples/demos/Liver/FC/prediction.py index 83b3a02e..cafaa391 100644 --- a/examples/demos/Liver/FC/prediction.py +++ b/examples/demos/Liver/FC/prediction.py @@ -11,7 +11,6 @@ from DeepPhysX.Core.Pipelines.BasePrediction import BasePrediction from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig from DeepPhysX.Core.Database.BaseDatabaseConfig import BaseDatabaseConfig -from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer from DeepPhysX.Torch.FC.FCConfig import FCConfig # Session related imports @@ -26,7 +25,7 @@ def launch_runner(): # Environment config environment_config = BaseEnvironmentConfig(environment_class=Liver, - visualizer=VedoVisualizer, + visualizer='vedo', env_kwargs={'nb_forces': 3}) # FC config diff --git a/examples/demos/Liver/UNet/prediction.py b/examples/demos/Liver/UNet/prediction.py index 9d937001..1bd28629 100644 --- a/examples/demos/Liver/UNet/prediction.py +++ b/examples/demos/Liver/UNet/prediction.py @@ -11,7 +11,6 @@ from DeepPhysX.Core.Pipelines.BasePrediction import BasePrediction from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig from DeepPhysX.Core.Database.BaseDatabaseConfig import BaseDatabaseConfig -from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer from DeepPhysX.Torch.UNet.UNetConfig import UNetConfig # Session related imports @@ -26,7 +25,7 @@ def launch_runner(): # Environment config environment_config = BaseEnvironmentConfig(environment_class=Liver, - visualizer=VedoVisualizer, + visualizer='vedo', env_kwargs={'nb_forces': 3}) # UNet config diff --git a/examples/features/dataGeneration_multi.py b/examples/features/dataGeneration_multi.py index e9e98994..0c18500c 100644 --- a/examples/features/dataGeneration_multi.py +++ b/examples/features/dataGeneration_multi.py @@ -12,7 +12,7 @@ from DeepPhysX.Core.Pipelines.BaseDataGeneration import BaseDataGeneration from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig from DeepPhysX.Core.Database.BaseDatabaseConfig import BaseDatabaseConfig -from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer +# from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer # Session related imports from Environment import MeanEnvironment @@ -26,7 +26,7 @@ def launch_data_generation(use_tcp_ip): # Environment configuration environment_config = BaseEnvironmentConfig(environment_class=MeanEnvironment, - visualizer=VedoVisualizer, + visualizer='vedo', as_tcp_ip_client=use_tcp_ip, number_of_thread=5, env_kwargs={'constant': False, @@ -52,10 +52,10 @@ def launch_data_generation(use_tcp_ip): if __name__ == '__main__': # Run single process - single_process_time = launch_data_generation(use_tcp_ip=False) + # single_process_time = launch_data_generation(use_tcp_ip=False) # Run multiprocess multi_process_time = launch_data_generation(use_tcp_ip=True) # Show results print(f"\nSINGLE PROCESS VS MULTIPROCESS" - f"\n Single process elapsed time: {round(single_process_time, 2)}s" + # f"\n Single process elapsed time: {round(single_process_time, 2)}s" f"\n Multiprocess elapsed time: {round(multi_process_time, 2)}s") diff --git a/examples/features/dataGeneration_single.py b/examples/features/dataGeneration_single.py index c9c0295b..f308c400 100644 --- a/examples/features/dataGeneration_single.py +++ b/examples/features/dataGeneration_single.py @@ -7,7 +7,7 @@ from DeepPhysX.Core.Pipelines.BaseDataGeneration import BaseDataGeneration from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig from DeepPhysX.Core.Database.BaseDatabaseConfig import BaseDatabaseConfig -from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer +# from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer # Session related imports from Environment import MeanEnvironment @@ -21,7 +21,7 @@ def launch_data_generation(): # Environment configuration environment_config = BaseEnvironmentConfig(environment_class=MeanEnvironment, - visualizer=VedoVisualizer, + visualizer='vedo', as_tcp_ip_client=False, env_kwargs={'constant': False, 'data_size': (nb_points, dimension), diff --git a/examples/features/gradientDescent.py b/examples/features/gradientDescent.py index 23e8d613..62467663 100644 --- a/examples/features/gradientDescent.py +++ b/examples/features/gradientDescent.py @@ -12,7 +12,6 @@ from DeepPhysX.Core.Pipelines.BaseTraining import BaseTraining from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig from DeepPhysX.Core.Database.BaseDatabaseConfig import BaseDatabaseConfig -from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer from DeepPhysX.Torch.FC.FCConfig import FCConfig # Session related imports @@ -27,7 +26,7 @@ def launch_training(): # Environment configuration environment_config = BaseEnvironmentConfig(environment_class=MeanEnvironment, - visualizer=VedoVisualizer, + visualizer='vedo', as_tcp_ip_client=False, env_kwargs={'constant': True, 'data_size': [nb_points, dimension], diff --git a/examples/features/onlineTraining.py b/examples/features/onlineTraining.py index 780849ec..841bdd47 100644 --- a/examples/features/onlineTraining.py +++ b/examples/features/onlineTraining.py @@ -11,7 +11,6 @@ from DeepPhysX.Core.Pipelines.BaseTraining import BaseTraining from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig from DeepPhysX.Core.Database.BaseDatabaseConfig import BaseDatabaseConfig -from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer from DeepPhysX.Torch.FC.FCConfig import FCConfig # Session imports @@ -26,7 +25,7 @@ def launch_training(): # Environment configuration environment_config = BaseEnvironmentConfig(environment_class=MeanEnvironment, - visualizer=VedoVisualizer, + visualizer='vedo', as_tcp_ip_client=True, number_of_thread=5, env_kwargs={'constant': False, diff --git a/examples/features/prediction.py b/examples/features/prediction.py index 7fd8adf3..49dce1c0 100644 --- a/examples/features/prediction.py +++ b/examples/features/prediction.py @@ -9,7 +9,6 @@ # DeepPhysX related imports from DeepPhysX.Core.Pipelines.BasePrediction import BasePrediction from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig -from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer from DeepPhysX.Torch.FC.FCConfig import FCConfig @@ -25,7 +24,7 @@ def launch_prediction(session): # Environment configuration environment_config = BaseEnvironmentConfig(environment_class=MeanEnvironment, - visualizer=VedoVisualizer, + visualizer='open3d', env_kwargs={'constant': False, 'data_size': [nb_points, dimension], 'delay': True, diff --git a/src/Core/AsyncSocket/AbstractEnvironment.py b/src/Core/AsyncSocket/AbstractEnvironment.py index 229f2de7..2c2a85d1 100644 --- a/src/Core/AsyncSocket/AbstractEnvironment.py +++ b/src/Core/AsyncSocket/AbstractEnvironment.py @@ -1,7 +1,7 @@ from typing import Optional, Dict, Any, Union, Tuple from numpy import ndarray -from SSD.Core.Storage.Database import Database +from SSD.Core.Rendering.UserAPI import UserAPI, Database from DeepPhysX.Core.Database.DatabaseHandler import DatabaseHandler @@ -11,7 +11,8 @@ class AbstractEnvironment: def __init__(self, as_tcp_ip_client: bool = True, instance_id: int = 1, - instance_nb: int = 1): + instance_nb: int = 1, + **kwargs): """ AbstractEnvironment sets the Environment API for TcpIpClient. Do not use AbstractEnvironment to implement an Environment, use BaseEnvironment instead. @@ -43,6 +44,9 @@ def __init__(self, self.sample_training: Optional[Dict[str, Any]] = None self.sample_additional: Optional[Dict[str, Any]] = None + # Visualization Factory + self.factory: Optional[UserAPI] = None + ########################################################################################## ########################################################################################## # Environment initialization # @@ -88,7 +92,10 @@ def close(self) -> None: def get_database_handler(self) -> DatabaseHandler: raise NotImplementedError - def _create_visualization(self, visualization_db: Union[Database, Tuple[str, str]]) -> None: + def _create_visualization(self, visualization_db: Union[Database, Tuple[str, str]], produce_data: bool) -> None: + raise NotImplementedError + + def _connect_visualization(self) -> None: raise NotImplementedError def _send_training_data(self) -> None: diff --git a/src/Core/AsyncSocket/TcpIpClient.py b/src/Core/AsyncSocket/TcpIpClient.py index c46a9fa2..8c1a67da 100644 --- a/src/Core/AsyncSocket/TcpIpClient.py +++ b/src/Core/AsyncSocket/TcpIpClient.py @@ -5,6 +5,8 @@ from asyncio import run as async_run from numpy import ndarray +from SSD.Core.Storage.Database import Database + from DeepPhysX.Core.AsyncSocket.TcpIpObject import TcpIpObject from DeepPhysX.Core.AsyncSocket.AbstractEnvironment import AbstractEnvironment @@ -97,8 +99,9 @@ async def __initialize(self) -> None: self.environment.init() self.environment.init_database() if visualization_db is not None: - self.environment._create_visualization(visualization_db=visualization_db) - self.environment.init_visualization() + db = Database(database_dir=visualization_db[0], + database_name=visualization_db[1]).load() + self.environment._create_visualization(visualization_db=db) # Initialization done await self.send_data(data_to_send='done', loop=loop, receiver=self.sock) @@ -107,6 +110,10 @@ async def __initialize(self) -> None: _ = await self.receive_data(loop=loop, sender=self.sock) self.environment.get_database_handler().load() + # Connect to Visualizer + if visualization_db is not None: + self.environment._connect_visualization() + ########################################################################################## ########################################################################################## # Running Client # @@ -157,6 +164,8 @@ async def __close(self) -> None: self.environment.close() except NotImplementedError: pass + if self.environment.factory is not None: + self.environment.factory.close() # Confirm exit command to the server loop = get_event_loop() await self.send_command_exit(loop=loop, receiver=self.sock) diff --git a/src/Core/AsyncSocket/TcpIpServer.py b/src/Core/AsyncSocket/TcpIpServer.py index e8b3f350..033a0449 100644 --- a/src/Core/AsyncSocket/TcpIpServer.py +++ b/src/Core/AsyncSocket/TcpIpServer.py @@ -187,6 +187,22 @@ async def __initialize(self, for client_id, client in self.clients: await self.send_data(data_to_send='sync', loop=loop, receiver=client) + def connect_visualization(self) -> None: + """ + Connect the Factories of the Clients to the Visualizer. + """ + + async_run(self.__connect_visualization()) + + async def __connect_visualization(self): + """ + Connect the Factories of the Clients to the Visualizer. + """ + + loop = get_event_loop() + for _, client in self.clients: + await self.send_data(data_to_send='conn', loop=loop, receiver=client) + ########################################################################################## ########################################################################################## # Data: produce batch & dispatch batch # diff --git a/src/Core/Environment/BaseEnvironment.py b/src/Core/Environment/BaseEnvironment.py index 41bc90b6..9d3603fc 100644 --- a/src/Core/Environment/BaseEnvironment.py +++ b/src/Core/Environment/BaseEnvironment.py @@ -3,8 +3,8 @@ from os.path import isfile, join from SSD.Core.Storage.Database import Database +from SSD.Core.Rendering.UserAPI import UserAPI -from DeepPhysX.Core.Visualization.VedoFactory import VedoFactory from DeepPhysX.Core.AsyncSocket.AbstractEnvironment import AbstractEnvironment from DeepPhysX.Core.Database.DatabaseHandler import DatabaseHandler @@ -43,9 +43,6 @@ def __init__(self, # Connect the Environment to the data Database self.__database_handler = DatabaseHandler(on_init_handler=self.__database_handler_init) - # Connect the Factory to the visualization Database - self.factory: Optional[VedoFactory] = None - ########################################################################################## ########################################################################################## # Environment initialization # @@ -289,9 +286,6 @@ def update_visualisation(self) -> None: """ if self.factory is not None: - # If Environment is a TcpIpClient, request to the Server - if self.as_tcp_ip_client: - self.tcp_ip_client.request_update_visualization() self.factory.render() def _get_prediction(self): @@ -329,18 +323,29 @@ def __database_handler_init(self): self.__database_handler.load() def _create_visualization(self, - visualization_db: Union[Database, Tuple[str, str]]) -> None: + visualization_db: Union[Database, Tuple[str, str]], + produce_data: bool = True) -> None: """ Create a Factory for the Environment. """ if type(visualization_db) == list: - self.factory = VedoFactory(database_path=visualization_db, - idx_instance=self.instance_id, - remote=True) + self.factory = UserAPI(database_dir=visualization_db[0], + database_name=visualization_db[1], + idx_instance=self.instance_id, + non_storing=not produce_data) else: - self.factory = VedoFactory(database=visualization_db, - idx_instance=self.instance_id) + self.factory = UserAPI(database=visualization_db, + idx_instance=self.instance_id, + non_storing=not produce_data) + self.init_visualization() + + def _connect_visualization(self) -> None: + """ + Connect the Factory to the Visualizer. + """ + + self.factory.connect_visualizer() def _send_training_data(self) -> List[int]: """ diff --git a/src/Core/Environment/BaseEnvironmentConfig.py b/src/Core/Environment/BaseEnvironmentConfig.py index 94d840a9..64651e4b 100644 --- a/src/Core/Environment/BaseEnvironmentConfig.py +++ b/src/Core/Environment/BaseEnvironmentConfig.py @@ -7,7 +7,6 @@ from DeepPhysX.Core.AsyncSocket.TcpIpServer import TcpIpServer from DeepPhysX.Core.Environment.BaseEnvironment import BaseEnvironment -from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer class BaseEnvironmentConfig: @@ -23,7 +22,7 @@ def __init__(self, load_samples: bool = False, only_first_epoch: bool = True, always_produce: bool = False, - visualizer: Optional[Type[VedoVisualizer]] = None, + visualizer: Optional[str] = None, record_wrong_samples: bool = False, env_kwargs: Optional[Dict[str, Any]] = None): """ @@ -41,7 +40,7 @@ def __init__(self, :param only_first_epoch: If True, data will always be created from environment. If False, data will be created from the environment during the first epoch and then re-used from the Dataset. :param always_produce: If True, data will always be produced in Environment(s). - :param visualizer: Class of the Visualizer to use. + :param visualizer: Backend of the Visualizer to use. :param record_wrong_samples: If True, wrong samples are recorded through Visualizer. :param env_kwargs: Additional arguments to pass to the Environment. """ @@ -90,7 +89,7 @@ def __init__(self, self.env_kwargs: Dict[str, Any] = {} if env_kwargs is None else env_kwargs # Visualizer variables - self.visualizer: Optional[Type[VedoVisualizer]] = visualizer + self.visualizer: Optional[str] = visualizer self.record_wrong_samples: bool = record_wrong_samples def create_server(self, diff --git a/src/Core/Manager/EnvironmentManager.py b/src/Core/Manager/EnvironmentManager.py index 1e85c39f..6cbcb126 100644 --- a/src/Core/Manager/EnvironmentManager.py +++ b/src/Core/Manager/EnvironmentManager.py @@ -4,7 +4,7 @@ from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig, TcpIpServer, BaseEnvironment from DeepPhysX.Core.Database.DatabaseHandler import DatabaseHandler -from DeepPhysX.Core.Visualization.VedoVisualizer import VedoVisualizer +from SSD.Core.Rendering.Visualizer import Visualizer, Database class EnvironmentManager: @@ -44,12 +44,10 @@ def __init__(self, # Create a Visualizer to provide the visualization Database force_local = pipeline == 'prediction' - self.visualizer: Optional[VedoVisualizer] = None + visualizer_db: Optional[Database] = None if environment_config.visualizer is not None: - self.visualizer = environment_config.visualizer(database_dir=join(session, 'dataset'), - database_name='Visualization', - remote=environment_config.as_tcp_ip_client and not force_local, - record=produce_data) + visualizer_db = Database(database_dir=join(session, 'dataset'), + database_name='Visualization').new() # Create a single Environment or a TcpIpServer self.number_of_thread: int = 1 if force_local else environment_config.number_of_thread @@ -57,10 +55,19 @@ def __init__(self, self.environment: Optional[BaseEnvironment] = None # Create Server if environment_config.as_tcp_ip_client and not force_local: + if visualizer_db is not None: + visualizer_db.create_table(table_name='Temp') self.server = environment_config.create_server(environment_manager=self, batch_size=batch_size, - visualization_db=None if self.visualizer is None else - self.visualizer.get_path()) + visualization_db=None if visualizer_db is None else + visualizer_db.get_path()) + if visualizer_db is not None: + visualizer_db.remove_table(table_name='Temp') + Visualizer.launch(backend=environment_config.visualizer, + database_dir=join(session, 'dataset'), + database_name=visualizer_db.get_path()[1], + nb_clients=environment_config.number_of_thread) + self.server.connect_visualization() # Create Environment else: self.environment = environment_config.create_environment() @@ -69,21 +76,19 @@ def __init__(self, self.environment.create() self.environment.init() self.environment.init_database() - if self.visualizer is not None: - self.environment._create_visualization(visualization_db=self.visualizer.get_database()) - self.environment.init_visualization() + if visualizer_db is not None: + self.environment._create_visualization(visualization_db=visualizer_db, + produce_data=produce_data) + Visualizer.launch(backend=environment_config.visualizer, + database_dir=join(session, 'dataset'), + database_name=visualizer_db.get_path()[1]) + self.environment._connect_visualization() # Define whether methods are used for environment or server self.get_database_handler = self.__get_server_db_handler if self.server else self.__get_environment_db_handler self.get_data = self.__get_data_from_server if self.server else self.__get_data_from_environment self.dispatch_batch = self.__dispatch_batch_to_server if self.server else self.__dispatch_batch_to_environment - # Init the Visualizer once Environments are initialized - if self.visualizer is not None: - if len(self.visualizer.get_database().get_tables()) == 1: - self.visualizer.get_database().load() - self.visualizer.init_visualizer() - ########################################################################################## ########################################################################################## # DatabaseHandler management # @@ -207,23 +212,6 @@ def __dispatch_batch_to_environment(self, save_data=save_data, request_prediction=request_prediction) - ########################################################################################## - ########################################################################################## - # Requests management # - ########################################################################################## - ########################################################################################## - - def update_visualizer(self, - instance: int) -> None: - """ - Update the Visualizer. - - :param instance: Index of the Environment render to update. - """ - - if self.visualizer is not None: - self.visualizer.render_instance(instance) - ########################################################################################## ########################################################################################## # Manager behavior # @@ -241,12 +229,10 @@ def close(self) -> None: # Environment case if self.environment: + if self.environment.factory is not None: + self.environment.factory.close() self.environment.close() - # Visualizer - if self.visualizer: - self.visualizer.close() - def __str__(self) -> str: description = "\n" diff --git a/src/Core/Network/BaseNetworkConfig.py b/src/Core/Network/BaseNetworkConfig.py index efde937a..d47df632 100644 --- a/src/Core/Network/BaseNetworkConfig.py +++ b/src/Core/Network/BaseNetworkConfig.py @@ -1,6 +1,6 @@ from typing import Any, Optional, Type from os.path import isdir -from numpy import typeDict +from numpy import sctypeDict from DeepPhysX.Core.Network.BaseNetwork import BaseNetwork from DeepPhysX.Core.Network.BaseOptimization import BaseOptimization @@ -70,7 +70,7 @@ def __init__(self, raise TypeError( f"[{self.__class__.__name__}] Wrong 'save each epoch' type: bool required, get {type(save_each_epoch)}") # Check data type - if data_type not in typeDict: + if data_type not in sctypeDict: raise ValueError( f"[{self.__class__.__name__}] The following data type is not a numpy type: {data_type}") diff --git a/src/Core/Visualization/VedoFactory.py b/src/Core/Visualization/VedoFactory.py deleted file mode 100644 index c524be5f..00000000 --- a/src/Core/Visualization/VedoFactory.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Optional, Tuple, List, Dict - -from SSD.Core.Rendering.VedoFactory import VedoFactory as _VedoFactory -from SSD.Core.Rendering.VedoFactory import Database, VedoTable - - -class VedoFactory(_VedoFactory): - - def __init__(self, - database: Optional[Database] = None, - database_path: Optional[Tuple[str, str]] = None, - database_dir: str = '', - database_name: Optional[str] = None, - remove_existing: bool = False, - idx_instance: int = 0, - remote: bool = False): - """ - A Factory to manage objects to render and save in the Database. - User interface to create and update Vedo objects. - - :param database: Database to connect to. - :param database_path: Path to the Database to connect to. - :param database_dir: Directory which contains the Database file (used if 'database' is not defined). - :param database_name: Name of the Database to connect to (used if 'database' is not defined). - :param remove_existing: If True, overwrite a Database with the same path. - :param idx_instance: If several Factories must be created, specify the index of the Factory. - :param remote: If True, the Visualizer will treat the Factories as remote. - """ - - # Define Database - if database is not None: - self.__database: Database = database - elif database_path is not None: - self.__database: Database = Database(database_dir=database_path[0], - database_name=database_path[1]).load() - elif database_name is not None: - self.__database: Database = Database(database_dir=database_dir, - database_name=database_name).new(remove_existing=remove_existing) - else: - raise ValueError("Both 'database' and 'database_name' are not defined.") - - # Information about all Tables - self.__tables: List[VedoTable] = [] - self.__current_id: int = 0 - self.__idx: int = idx_instance - self.__update: Dict[int, bool] = {} - self.__path = database_path - - # ExchangeTable to synchronize Factory and Visualizer - if not remote: - self.__database.register_pre_save_signal(table_name='Sync', - handler=self.__sync_visualizer, - name=f'Factory_{self.__idx}') - self.remote = remote - - def render(self): - """ - Render the current state of Actors in the Plotter. - """ - - if not self.remote: - _VedoFactory.render(self) - - else: - # Reset al the update flags - for i, updated in self.__update.items(): - self.__update[i] = False diff --git a/src/Core/Visualization/VedoVisualizer.py b/src/Core/Visualization/VedoVisualizer.py deleted file mode 100644 index d4187f1d..00000000 --- a/src/Core/Visualization/VedoVisualizer.py +++ /dev/null @@ -1,93 +0,0 @@ -from typing import Optional, Dict, Tuple, Union - -from SSD.Core.Rendering.VedoVisualizer import VedoVisualizer as _VedoVisualizer -from SSD.Core.Rendering.VedoVisualizer import Database, VedoActor, Plotter - - -class VedoVisualizer(_VedoVisualizer): - - def __init__(self, - database: Optional[Database] = None, - database_dir: str = '', - database_name: Optional[str] = None, - remove_existing: bool = False, - offscreen: bool = False, - remote: bool = False, - record: bool = True): - """ - Manage the creation, update and rendering of Vedo Actors. - - :param database: Database to connect to. - :param database_dir: Directory which contains the Database file (used if 'database' is not defined). - :param database_name: Name of the Database (used if 'database' is not defined). - :param remove_existing: If True, overwrite a Database with the same path. - :param offscreen: If True, visual data will be saved but not rendered. - :param remote: If True, the Visualizer will treat the Factories as remote. - :param record: If True, the visualization Database is saved in memory. - """ - - # Define Database - if database is not None: - self.__database: Database = database - elif database_name is not None: - self.__database: Database = Database(database_dir=database_dir, - database_name=database_name).new(remove_existing=remove_existing) - else: - raise ValueError("Both 'database' and 'database_name' are not defined.") - - # Information about all Factories / Actors - self.__actors: Dict[int, Dict[Tuple[int, int], VedoActor]] = {} - self.__all_actors: Dict[Tuple[int, int], VedoActor] = {} - self.__plotter: Optional[Plotter] = None - self.__offscreen: bool = offscreen - self.step: Union[int, Dict[int, int]] = {} if remote else 0 - - self.__database.create_table(table_name='Sync', - storing_table=False, - fields=('step', str)) - - if not remote: - self.__database.register_post_save_signal(table_name='Sync', - handler=self.__sync_visualizer) - self.record = record - - def render_instance(self, instance: int): - """ - Render the current state of Actors managed by a certain Factory in the Plotter. - - :param instance: Index of the Environment render to update. - """ - - # 1. Update Factories steps - if instance not in self.step: - self.step[instance] = 0 - self.step[instance] += 1 - - # 2. Retrieve visual data and update Actors (one Table per Actor) - table_names = self.__database.get_tables() - table_names.remove('Sync') - table_names = [table_name for table_name in table_names if table_name.split('_')[1] == str(instance)] - for table_name in table_names: - # Get the current step line in the Table - data_dict = self.__database.get_line(table_name=table_name, - line_id=self.step[instance]) - # If the id of line is correct, the Actor was updated - if data_dict.pop('id') == self.step[instance]: - self.update_instance(table_name=table_name, - data_dict=data_dict) - # Otherwise, the actor was not updated, then add an empty line - else: - self.__database.add_data(table_name=table_name, - data={}) - - # 3. Render Plotter if offscreen is False - if not self.__offscreen: - self.__plotter.render() - - def close(self): - """ - Launch the closing procedure of the Visualizer. - """ - - if not self.record: - self.__database.close(erase_file=True) diff --git a/src/Core/Visualization/__init__.py b/src/Core/Visualization/__init__.py deleted file mode 100644 index e69de29b..00000000